瀏覽代碼

fix:ip段增加

duxin 1 年之前
父節點
當前提交
d6bb788b12
共有 3 個文件被更改,包括 44 次插入4 次删除
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 41 4
      ipmatch/ipmatch.go

+ 1 - 0
go.mod

@@ -14,6 +14,7 @@ require (
 	github.com/gomodule/redigo v1.8.9
 	github.com/howeyc/fsnotify v0.9.0
 	github.com/olivere/elastic/v7 v7.0.22
+	github.com/yl2chen/cidranger v1.0.2
 	github.com/zeromicro/go-zero v1.3.5
 	go.etcd.io/etcd/client/v3 v3.5.4
 	go.mongodb.org/mongo-driver v1.9.1

+ 2 - 0
go.sum

@@ -438,6 +438,8 @@ github.com/xdg-go/scram v1.0.2 h1:akYIkZ28e6A96dkWNJQu3nmCzH3YfwMPQExUYDaRv7w=
 github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs=
 github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc=
 github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM=
+github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU=
+github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g=
 github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA=
 github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
 github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

+ 41 - 4
ipmatch/ipmatch.go

@@ -4,12 +4,16 @@ import (
 	_ "embed"
 	"errors"
 	"github.com/RoaringBitmap/roaring"
+	"github.com/yl2chen/cidranger"
+	"log"
+	"net"
 	"strconv"
 	"strings"
 )
 
 type WhiteIp struct {
-	Rb *roaring.Bitmap
+	Rb     *roaring.Bitmap
+	Ranger cidranger.Ranger
 }
 
 // ip2Uint32
@@ -33,23 +37,56 @@ func ip2Uint32(rawIp string) (uint32, error) {
 func NewRb(ip string) *WhiteIp {
 	newIp := new(WhiteIp)
 	rb := roaring.NewBitmap()
+	ranger := cidranger.NewPCTrieRanger()
 	for _, v := range strings.Split(ip, "\n") {
 		if len(v) < 8 {
 			continue
 		}
-		if ipUint, err := ip2Uint32(v); err == nil {
-			rb.Add(ipUint)
+		if isIP(v) { //精准ip添加
+			if ipUint, err := ip2Uint32(v); err == nil {
+				rb.Add(ipUint)
+			}
+		} else if ipNet, isIpv4 := isIPv4Segment(v); isIpv4 && ipNet != nil { //ip段添加
+			// 初始化 IP 段
+			_ = ranger.Insert(cidranger.NewBasicRangerEntry(*ipNet))
 		}
 	}
+	newIp.Ranger = ranger
 	newIp.Rb = rb
 	return newIp
 }
 
+// 验证 IP 是否在 IP 段内
+func (ip *WhiteIp) isIPInRange(ips string) bool {
+	checkIP := net.ParseIP(ips)
+	contains, err := ip.Ranger.Contains(checkIP)
+	if err != nil {
+		log.Println("Failed to check IP range:", err)
+		return false
+	}
+	return contains
+}
+
+// 判断字符串是否为有效的 IP 地址
+func isIP(s string) bool {
+	ip := net.ParseIP(s)
+	return ip != nil
+}
+
+// 判断字符串是否为有效的 IPv4 段
+func isIPv4Segment(s string) (*net.IPNet, bool) {
+	_, ipNet, err := net.ParseCIDR(s)
+	return ipNet, err == nil
+}
+
 // Match
 func (ip *WhiteIp) Match(rawIp string) bool {
 	if ipUint, err := ip2Uint32(rawIp); err != nil {
 		return false
 	} else {
-		return ip.Rb.Contains(ipUint)
+		if ip.Rb.Contains(ipUint) {
+			return true
+		}
+		return ip.isIPInRange(rawIp)
 	}
 }