Эх сурвалжийг харах

auth token过滤器开发

wanghuidong 4 жил өмнө
parent
commit
db90cb4b02

+ 9 - 5
lock/lock.go

@@ -3,13 +3,17 @@ package lock
 import "sync"
 
 var (
-	MainLock = new(sync.Map)
+	mainLock = new(sync.Map)
 )
 
-func InitUserLock() {
-
+func InitUserLock(appID string) {
+	mainLock.Store(appID, &sync.Mutex{})
 }
 
-func GetUserLock(appID string) {
-
+func GetUserLock(appID string) *sync.Mutex {
+	if userLock, ok := mainLock.Load(appID); ok {
+		_userLock := userLock.(*sync.Mutex)
+		return _userLock
+	}
+	return nil
 }

+ 16 - 0
main.go

@@ -1,12 +1,15 @@
 package main
 
 import (
+	"go.uber.org/zap"
 	"log"
 	"sfbase/core"
 	"sfbase/global"
 	"sfbase/redis"
 	"sfis/db"
+	"sfis/model"
 	"sfis/router"
+	"sfis/utils"
 )
 
 func main() {
@@ -21,7 +24,20 @@ func main() {
 	db.InitDB()
 	if db.GetSFISDB() != nil {
 		//todo other caches service or init operation
+		users := make([]*model.User, 0)
+		db.GetSFISDB().Find(&users)
+		for _, user := range users {
+			utils.UserCaches.Map.Store(user.AppID, user)
+		}
+		global.Logger.Info("初始化用户缓存信息,", zap.Any("用户数量:", len(users)))
 
+		apis := make([]*model.Product, 0)
+		db.GetSFISDB().Find(&apis)
+		for _, api := range apis {
+			utils.ProductCaches.Map.Store(api.ID, api)
+			utils.ApiUrlCache.Store(api.Path, api)
+		}
+		global.Logger.Info("初始化产品缓存信息,", zap.Any("产品数量:", len(users)))
 	}
 
 	//全局redis的使用?

+ 131 - 0
middleware/auth.go

@@ -1,4 +1,135 @@
 package middleware
 
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"go.uber.org/zap"
+	"sfbase/global"
+	sutils "sfbase/utils"
+	"sfis/db"
+	"sfis/lock"
+	"sfis/model"
+	"sfis/model/response"
+	"sfis/utils"
+	"strconv"
+	"strings"
+	"time"
+)
 
+const TimestampExpireTime = 600 //单位秒,header里的时间戳超时时间 10分钟
 
+func TokenAuth() gin.HandlerFunc {
+	return func(context *gin.Context) {
+		var (
+			requestUrl string
+			token      string
+			timestamp  string
+			appID      string
+			productID  int
+		)
+		requestUrl = context.Request.URL.String()
+		requestUrl = strings.Split(requestUrl, "?")[0]
+		requestUrl = strings.Split(requestUrl, "/")[3]
+		if p, ok := utils.ApiUrlCache.Load(requestUrl); ok {
+			productID = p.(*model.Product).ID
+		} else {
+			response.FailWithDetailed(response.ParamError, nil, "url错误", context)
+			context.Abort()
+			return
+		}
+		token = context.Request.Header.Get("token")
+		timestamp = context.Request.Header.Get("timestamp")
+		appID = context.PostForm("app_id")
+
+		if appID == "" || token == "" || timestamp == "" {
+			response.FailWithDetailed(response.ParamEmpty, nil, "参数缺失或为空", context)
+			context.Abort()
+			return
+		}
+		_timestamp, err := strconv.ParseInt(timestamp, 10, 64)
+		if err != nil {
+			response.FailWithDetailed(response.ParamError, nil, "参数异常", context)
+			context.Abort()
+			return
+		}
+		now := time.Now().Unix()
+		if now-_timestamp > TimestampExpireTime {
+			//token时间验证 十分钟
+			response.FailWithDetailed(response.TokenExpired, nil, "签名过期", context)
+			context.Abort()
+			return
+		}
+
+		user := utils.GetUserByAppID(appID)
+		secretKey := user.SecretKey
+		ipWhiteList := user.IpWhiteList
+		userName := user.Name
+		global.Logger.Info("用户:", zap.Any("userName:", userName), zap.Any("appID:", appID), zap.Any("secretKey:", secretKey), zap.Any("ipWhiteList:", ipWhiteList))
+		/**
+		第一步:ip白名单校验
+		*/
+		if ipWhiteList != "*" {
+			requestIp := utils.GetIp(context.Request)
+			if strings.Index(ipWhiteList, requestIp) < 0 {
+				response.FailWithDetailed(response.IpInvalid, nil, "ip不在白名单", context)
+				context.Abort()
+				return
+			}
+		}
+		/**
+		第二步:MD5签名校验
+		*/
+		signToken := sutils.MD5(fmt.Sprintf("%s%s%s", appID, timestamp, user.SecretKey))
+		if token != signToken {
+			response.FailWithDetailed(response.TokenInvalid, nil, "身份验证失败", context)
+			context.Abort()
+			return
+		}
+
+		context.Set("appID", appID)
+		context.Set("productID", productID)
+		if userLock := lock.GetUserLock(appID); userLock != nil {
+			/**
+			第二步:用户接口产品校验-加锁处理
+			*/
+			//2.1 取用户接口状态校验
+			userLock.Lock()
+			userProduct := &model.UserProduct{}
+			db.GetSFISDB().First(userProduct, &model.UserProduct{AppID: appID, ProductID: productID})
+			userLock.Unlock()
+			if userProduct.InterfaceStatus != 0 {
+				response.FailWithDetailed(response.InterfaceDeleted, nil, "该用户接口已停用", context)
+				context.Abort()
+				return
+			}
+
+			//2.2 取用户余量|账户余额 校验
+			costModel := userProduct.CostModel
+			product := utils.GetProductByID(productID)
+			productType := product.ProductType
+			userLock.Lock()
+			switch costModel {
+			case 0:
+				//按剩余量扣费
+				if productType == 0 {
+					//按次扣费-每调一次 剩余量-1
+					userProduct.LeftNum = userProduct.LeftNum - 1
+				} else if productType == 1 {
+					//按条扣费-每调一次剩余量-len(getDataList	)
+				}
+			case 1:
+				//按账户钱包余额扣费
+				if productType == 0 {
+					//按次扣费-每调一次
+					//todo 账户余额表user_account的余额 减去 product单价*1
+				} else if productType == 1 {
+					//按条扣费-每调一次
+					//todo 账户余额表user_account的余额 减去 product单价*len(getDataList)
+				}
+			case 2:
+				//优先扣剩余量,剩余量为0,扣钱包余额
+			}
+		}
+
+	}
+}

+ 13 - 0
model/baseModel.go

@@ -0,0 +1,13 @@
+package model
+
+import (
+	"gorm.io/gorm"
+	"time"
+)
+
+type BaseModel struct {
+	ID        int            `json:"id" form:"id" gorm:"primaryKey"`
+	CreateAt  time.Time      `json:"-" gorm:"autoCreateTime"` //标签autoCreateTime设置如果字段名字不为CreatAt时候自动插入当前时间
+	UpdateAt  time.Time      `json:"-" gorm:"autoUpdateTime"`
+	DeletedAt gorm.DeletedAt `json:"-" `
+}

+ 33 - 0
model/product.go

@@ -0,0 +1,33 @@
+package model
+
+import "time"
+
+type Product struct {
+	BaseModel
+	Name        string `json:"name"`
+	Path        string `json:"url"`
+	UnitPrice   int    `json:"unit_price"`   //单价
+	MinUnit     int    `json:"min_unit"`     //最小单位
+	ProductType int    `json:"product_type"` //产品类型 按次-0,按条-1
+	TestNum     int    `json:"test_num"`     //试用量
+}
+
+func (p *Product) TableName() string {
+	return "product"
+}
+
+type UserProduct struct {
+	ID              int       `json:"id" gorm:"primaryKey"`
+	AppID           string    `json:"app_id"`
+	ProductID       int       `json:"product_id"`
+	CreateAt        time.Time `json:"-" gorm:"autoCreateTime"` //标签autoCreateTime设置如果字段名字不为CreatAt时候自动插入当前时间
+	EndAt           time.Time `json:"end_at"`
+	LeftNum         int       `json:"left_num"`         //剩余量  加锁处理
+	CostModel       int       `json:"cost_model"`       //扣费模式(0-按剩余量扣,1-按账户余额扣,2-优先扣剩余量,量为0扣余额)
+	InterfaceStatus int       `json:"interface_status"` //接口状态(0开启|-1停用|-2异常|-3维护)
+	CallLimitDay    int       `json:"call_limit_day"`   //每天(调用次数|取走数据量)上限
+}
+
+func (p *UserProduct) TableName() string {
+	return "user_product"
+}

+ 14 - 10
model/response/response.go

@@ -11,24 +11,28 @@ const (
 	//服务级错误码
 	EmptyResult                  int = 201 //查询无结果
 	ParamError                   int = 202 //参数错误
-	ParamLenInValid              int = 203 //参数长度小于4
-	Waiting                      int = 204 //等待处理中
-	MoreThanQueryDataNumberLimit int = 205 //请求数据的条数超过上限
+	ParamEmpty                       = 203 //参数为空
+	ParamLenInValid              int = 204 //参数长度小于4
+	Waiting                      int = 205 //等待处理中
+	MoreThanQueryDataNumberLimit int = 206 //请求数据的条数超过上限
+	LeftNumEmpty                     = 207 //余额不足
 	QueryError                   int = 299 //系统查询异常,请联系客服
 
 	//系统级错误码
 	InValidKey    = 101 //当前KEY无效
 	RemainingLack = 102 //当前KEY余额|余量不足
 	DeleteKey     = 103 //当前Key被暂停使用,请联系管理员
-	TokenInvalid  = 104 //身份验证错误或者已过期
+	TokenInvalid  = 104 //身份验证错误
+	TokenExpired  = 107 //身份验证已过期
+
 	//105非法请求过多,请联系管理员
-	IpInvalid                       = 106 //被禁止的IP
-	MoreThanEveryDayQueryTimesLimit = 107 //请求超过每日系统限制
+	IpInvalid                       = 108 //被禁止的IP
+	MoreThanEveryDayQueryTimesLimit = 109 //请求超过每日系统限制
 	//108当前相同查询连续出错,请等2小时后重试
-	InterfaceRightInvalid           = 109 //接口权限未开通
-	InterfaceExpired                = 110 //您的账号剩余使用量已过期
-	InterfaceDeleted                = 111 //接口已停用,请联系管理员
-	MoreThanEveryDayDataNumberLimit = 112 //请求超过每日调用总量限制
+	InterfaceRightInvalid           = 110 //接口权限未开通
+	InterfaceExpired                = 111 //您的账号剩余使用量已过期
+	InterfaceDeleted                = 112 //接口已停用,请联系管理员
+	MoreThanEveryDayDataNumberLimit = 113 //请求超过每日调用总量限制
 	OtherError                      = 199 //系统未知错误,请联系技术客服
 )
 

+ 18 - 0
model/user.go

@@ -0,0 +1,18 @@
+package model
+
+type User struct {
+	BaseModel
+	Name        string `json:"name"`
+	Phone       string `json:"phone"`
+	AppID       string `json:"app_id"`
+	SecretKey   string `json:"secret_key"`
+	IpWhiteList string `json:"ip_white_list"`
+}
+
+
+
+func (user *User) TableName() string {
+	return "user"
+}
+
+

+ 135 - 0
sword_base/utils/simple_encrypt.go

@@ -0,0 +1,135 @@
+package utils
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/hex"
+)
+
+/**
+加解密
+	数据结构
+	密文+sha32校验
+*/
+//
+type SimpleEncrypt struct {
+	Key string //加解密用到的key(加密key索引)+
+}
+
+//计算检验和
+func (s *SimpleEncrypt) calaCheckCode(src []byte) []byte {
+	check := 0
+	for i := 0; i < len(src); i++ {
+		check += int(src[i])
+	}
+	return []byte{byte((check >> 8) & 0xff), byte(check & 0xff)}
+}
+
+//验证数据有效性
+func (s *SimpleEncrypt) verify(src []byte) bool {
+	v := s.calaCheckCode(src[:len(src)-2])
+	return bytes.Equal(v, src[len(src)-2:])
+}
+
+//加密String
+func (s *SimpleEncrypt) EncodeString(str string) string {
+	data := []byte(str)
+	s.doEncode(data)
+	return base64.StdEncoding.EncodeToString(data)
+}
+
+//加密String,ByCheck
+func (s *SimpleEncrypt) EncodeStringByCheck(str string) string {
+	data := []byte(str)
+	s.doEncode(data)
+	v := s.calaCheckCode(data)
+	data = append(data, v...)
+	return base64.StdEncoding.EncodeToString(data)
+}
+
+//
+func (s *SimpleEncrypt) Encode2Hex(str string) string {
+	data := []byte(str)
+	s.doEncode(data)
+	return hex.EncodeToString(data)
+}
+func (s *SimpleEncrypt) Encode2HexByCheck(str string) string {
+	data := []byte(str)
+	s.doEncode(data)
+	v := s.calaCheckCode(data)
+	data = append(data, v...)
+	return hex.EncodeToString(data)
+}
+
+//解密String
+func (s *SimpleEncrypt) DecodeString(str string) string {
+	data, _ := base64.StdEncoding.DecodeString(str)
+	s.doEncode(data)
+	return string(data)
+}
+
+//解密String,ByCheck
+func (s *SimpleEncrypt) DecodeStringByCheck(str string) string {
+	data, _ := base64.StdEncoding.DecodeString(str)
+	if len(data) < 2 || !s.verify(data) {
+		return ""
+	}
+	data = data[:len(data)-2]
+	s.doEncode(data)
+	return string(data)
+}
+
+//
+func (s *SimpleEncrypt) Decode4Hex(str string) string {
+	data, _ := hex.DecodeString(str)
+	s.doEncode(data)
+	return string(data)
+}
+func (s *SimpleEncrypt) Decode4HexByCheck(str string) string {
+	data, _ := hex.DecodeString(str)
+	if len(data) < 2 || !s.verify(data) {
+		return ""
+	}
+	data = data[:len(data)-2]
+	s.doEncode(data)
+	return string(data)
+}
+
+//加密
+func (s *SimpleEncrypt) Encode(data []byte) {
+	s.doEncode(data)
+
+}
+
+func (s *SimpleEncrypt) EncodeByCheck(data []byte) {
+	s.doEncode(data)
+	v := s.calaCheckCode(data)
+	data = append(data, v...)
+}
+
+//解密
+func (s *SimpleEncrypt) Decode(data []byte) {
+	s.doEncode(data)
+}
+
+//解密
+func (s *SimpleEncrypt) DecodeByCheck(data []byte) {
+	if len(data) < 2 || !s.verify(data) {
+		data = []byte{}
+		return
+	}
+	s.doEncode(data)
+}
+
+func (s *SimpleEncrypt) doEncode(bs []byte) {
+	tmp := []byte(s.Key)
+THEFOR:
+	for i := 0; i < len(bs); {
+		for j := 0; j < len(tmp); j, i = j+1, i+1 {
+			if i >= len(bs) {
+				break THEFOR
+			}
+			bs[i] = bs[i] ^ tmp[j]
+		}
+	}
+}

+ 73 - 0
sword_base/utils/stringutil.go

@@ -2,9 +2,11 @@ package utils
 
 import (
 	"crypto/md5"
+	cryptoRand "crypto/rand"
 	"encoding/hex"
 	"fmt"
 	"github.com/dchest/captcha"
+	"io"
 	"math/rand"
 	"strings"
 	"time"
@@ -57,3 +59,74 @@ func GenerateSimpleToken() string {
 	return hex.EncodeToString(h.Sum(nil))
 }
 //var pool = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz@#$"
+
+
+
+//获取复杂的随机数
+func GetLetterRandom(length int, flag ...bool) string {
+	var idChars = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
+	var mod byte = 52
+	if len(flag) > 0 && flag[0] {
+		idChars = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
+		mod = 26
+	}
+	b := make([]byte, length)
+	maxrb := byte(256 - (256 % int(mod)))
+	i := 0
+EXIT:
+	for {
+		r := make([]byte, length+(length/4))
+		if _, err := io.ReadFull(cryptoRand.Reader, r); err != nil {
+			panic("GetLetterRandom: error reading random source: " + err.Error())
+		}
+		for _, c := range r {
+			if c > maxrb {
+				continue
+			}
+			b[i] = c % mod
+			i++
+			if i == length {
+				break EXIT
+			}
+		}
+	}
+	for i, c := range b {
+		b[i] = idChars[c]
+	}
+	return string(b)
+}
+
+/*获取复杂的随机数,数字和字母的组合
+ * c > 2 数字的个数和字母的个数随机分配
+ * n 数字的个数
+ * l 字母的个数
+ */
+func GetComplexRandom(c, n, l int) string {
+	if c < 2 && (n < 1 || l < 1) {
+		return "--"
+	}
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	myCommonMethod := func(flag bool) int {
+		if flag {
+			return r.Intn(c-1) + 1
+		} else {
+			return r.Intn(c)
+		}
+	}
+	if c >= 2 {
+		n = myCommonMethod(true)
+		l = c - n
+	} else {
+		c = l + n
+	}
+	value := MakeSimpleCaptcha(n) + GetLetterRandom(l)
+	var array = strings.Split(value, "")
+	for i := 0; i < c/2; i++ {
+		r1 := myCommonMethod(false)
+		r2 := myCommonMethod(false)
+		o := array[r1]
+		array[r1] = array[r2]
+		array[r2] = o
+	}
+	return strings.Join(array, "")
+}

+ 8 - 2
test/manage/user_test.go

@@ -1,13 +1,19 @@
 package manage
 
 import (
-	"log"
+	"fmt"
+	sutil "sfbase/utils"
+	"sfis/utils"
 	"testing"
+	"time"
 )
 
 func init() {
 	//todo init connection db operation
 }
 func Test_CreateUser(t *testing.T) {
-	log.Println("devUserCreate testing......")
+	//log.Println("devUserCreate testing......")
+	appID := utils.GetAppID(time.Now().Unix())
+	secretKey := sutil.GetComplexRandom(8, 3, 5)
+	fmt.Printf("appID:[%s],secretKey:[%s]", appID, secretKey)
 }

+ 42 - 0
utils/caches.go

@@ -0,0 +1,42 @@
+package utils
+
+import (
+	"sfis/db"
+	"sfis/model"
+	"sync"
+)
+
+var (
+	UserCaches    = new(ResourceCache)
+	ProductCaches = new(ResourceCache)
+	ApiUrlCache   = new(sync.Map)
+)
+
+type ResourceCache struct {
+	Data []interface{}
+	Map  sync.Map
+}
+
+func FLushUserCaches() {
+	users := make([]*model.User, 0)
+	db.GetSFISDB().Find(&users)
+	UserCaches.Map = sync.Map{}
+	for _, user := range users {
+		UserCaches.Map.Store(user.AppID, user)
+	}
+}
+func GetUserByAppID(appID string) *model.User {
+	if m, ok := UserCaches.Map.Load(appID); ok {
+		_m := m.(*model.User)
+		return _m
+	}
+	return nil
+}
+
+func GetProductByID(productID int) *model.Product {
+	if m, ok := ProductCaches.Map.Load(productID); ok {
+		_m := m.(*model.Product)
+		return _m
+	}
+	return nil
+}

+ 49 - 0
utils/user_util.go

@@ -0,0 +1,49 @@
+package utils
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"regexp"
+	"sfbase/utils"
+	"strings"
+)
+
+var (
+	userAppIDEncrypt = &utils.SimpleEncrypt{"sfis20212120"}
+	strReg           = regexp.MustCompile("^[0-9a-zA-Z]+$")
+)
+
+func GetAppID(tn int64) (appID string) {
+	for {
+		randomstr := utils.GetLetterRandom(5)
+		str := fmt.Sprintf("%s%d%s", randomstr[:2], tn, randomstr[2:])
+		appID = userAppIDEncrypt.EncodeString(str)
+		if strReg.MatchString(appID) {
+			break
+		}
+	}
+	appID = "sf" + appID
+	return
+}
+
+func GetIp(req *http.Request) string {
+	if req == nil {
+		return ""
+	}
+	ip_for := req.Header.Get("x-forwarded-for")
+	ip_client := req.Header.Get("http_client_ip")
+	ip_addr := req.Header.Get("Remote_addr")
+	un := "unknown"
+	if (ip_for != un) && (len(strings.TrimSpace(ip_for)) > 0) {
+		return ip_for
+	}
+	if (ip_client != un) && (len(strings.TrimSpace(ip_client)) > 0) {
+		return ip_client
+	}
+	if (ip_addr != un) && (len(strings.TrimSpace(ip_addr)) > 0) {
+		return ip_addr
+	}
+	ip, _, _ := net.SplitHostPort(req.RemoteAddr)
+	return ip
+}