Quellcode durchsuchen

fix:ip段修改

duxin vor 1 Jahr
Ursprung
Commit
db0d938799
1 geänderte Dateien mit 40 neuen und 25 gelöschten Zeilen
  1. 40 25
      ipmatch/ipmatch.go

+ 40 - 25
ipmatch/ipmatch.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"github.com/RoaringBitmap/roaring"
 	"github.com/yl2chen/cidranger"
-	"log"
 	"net"
 	"strconv"
 	"strings"
@@ -16,6 +15,12 @@ type WhiteIp struct {
 	Ranger cidranger.Ranger
 }
 
+type IpParameter struct {
+	Ip      string
+	IpType  int
+	IsWhite int
+}
+
 // ip2Uint32
 func ip2Uint32(rawIp string) (uint32, error) {
 	w := strings.Split(rawIp, ".")
@@ -34,37 +39,50 @@ func ip2Uint32(rawIp string) (uint32, error) {
 }
 
 // NewRb init
-func NewRb(ip string) *WhiteIp {
+func NewRb(ips []IpParameter) *WhiteIp {
 	newIp := new(WhiteIp)
 	rb := roaring.NewBitmap()
-	ranger := cidranger.NewPCTrieRanger()
-	for _, v := range strings.Split(ip, "\n") {
-		if len(v) < 8 {
+	for _, ipData := range ips {
+		if len(ipData.Ip) < 8 {
 			continue
 		}
-		if isIP(v) { //精准ip添加
-			if ipUint, err := ip2Uint32(v); err == nil {
+		if ipData.IpType == 1 && isIP(ipData.Ip) { //精准ip添加
+			if ipUint, err := ip2Uint32(ipData.Ip); err == nil {
 				rb.Add(ipUint)
 			}
-		} else if ipNet, isIpv4 := isIPv4Segment(v); isIpv4 && ipNet != nil { //ip段添加
-			// 初始化 IP 段
-			_ = ranger.Insert(cidranger.NewBasicRangerEntry(*ipNet))
+		} else if ipData.IpType == 1 && isIPv4Segment(ipData.Ip) { //ip段添加
+			for _, v := range cidrToIPList(ipData.Ip) {
+				if ipUint, err1 := ip2Uint32(v); err1 == nil {
+					rb.Add(ipUint)
+				}
+			}
 		}
 	}
-	newIp.Ranger = ranger
 	newIp.Rb = rb
 	return newIp
 }
 
-// 验证 IP 是否在 IP 段内
-func (ip *WhiteIp) isIPInRange(ips string) bool {
-	checkIP := net.ParseIP(ips)
-	contains, err := ip.Ranger.Contains(checkIP)
+func cidrToIPList(cidr string) []string {
+	ip, ipNet, err := net.ParseCIDR(cidr)
 	if err != nil {
-		log.Println("Failed to check IP range:", err)
-		return false
+		return nil
+	}
+
+	var ipList []string
+	for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
+		ipList = append(ipList, ip.String())
+	}
+
+	return ipList
+}
+
+func incIP(ip net.IP) {
+	for j := len(ip) - 1; j >= 0; j-- {
+		ip[j]++
+		if ip[j] > 0 {
+			break
+		}
 	}
-	return contains
 }
 
 // 判断字符串是否为有效的 IP 地址
@@ -74,9 +92,9 @@ func isIP(s string) bool {
 }
 
 // 判断字符串是否为有效的 IPv4 段
-func isIPv4Segment(s string) (*net.IPNet, bool) {
-	_, ipNet, err := net.ParseCIDR(s)
-	return ipNet, err == nil
+func isIPv4Segment(s string) bool {
+	_, _, err := net.ParseCIDR(s)
+	return err == nil
 }
 
 // Match
@@ -84,9 +102,6 @@ func (ip *WhiteIp) Match(rawIp string) bool {
 	if ipUint, err := ip2Uint32(rawIp); err != nil {
 		return false
 	} else {
-		if ip.Rb.Contains(ipUint) {
-			return true
-		}
-		return ip.isIPInRange(rawIp)
+		return ip.Rb.Contains(ipUint)
 	}
 }