|
@@ -5,7 +5,6 @@ import (
|
|
|
"errors"
|
|
|
"github.com/RoaringBitmap/roaring"
|
|
|
"github.com/yl2chen/cidranger"
|
|
|
- "log"
|
|
|
"net"
|
|
|
"strconv"
|
|
|
"strings"
|
|
@@ -16,6 +15,12 @@ type WhiteIp struct {
|
|
|
Ranger cidranger.Ranger
|
|
|
}
|
|
|
|
|
|
+type IpParameter struct {
|
|
|
+ Ip string
|
|
|
+ IpType int
|
|
|
+ IsWhite int
|
|
|
+}
|
|
|
+
|
|
|
// ip2Uint32
|
|
|
func ip2Uint32(rawIp string) (uint32, error) {
|
|
|
w := strings.Split(rawIp, ".")
|
|
@@ -34,37 +39,50 @@ func ip2Uint32(rawIp string) (uint32, error) {
|
|
|
}
|
|
|
|
|
|
// NewRb init
|
|
|
-func NewRb(ip string) *WhiteIp {
|
|
|
+func NewRb(ips []IpParameter) *WhiteIp {
|
|
|
newIp := new(WhiteIp)
|
|
|
rb := roaring.NewBitmap()
|
|
|
- ranger := cidranger.NewPCTrieRanger()
|
|
|
- for _, v := range strings.Split(ip, "\n") {
|
|
|
- if len(v) < 8 {
|
|
|
+ for _, ipData := range ips {
|
|
|
+ if len(ipData.Ip) < 8 {
|
|
|
continue
|
|
|
}
|
|
|
- if isIP(v) { //精准ip添加
|
|
|
- if ipUint, err := ip2Uint32(v); err == nil {
|
|
|
+ if ipData.IpType == 1 && isIP(ipData.Ip) { //精准ip添加
|
|
|
+ if ipUint, err := ip2Uint32(ipData.Ip); err == nil {
|
|
|
rb.Add(ipUint)
|
|
|
}
|
|
|
- } else if ipNet, isIpv4 := isIPv4Segment(v); isIpv4 && ipNet != nil { //ip段添加
|
|
|
- // 初始化 IP 段
|
|
|
- _ = ranger.Insert(cidranger.NewBasicRangerEntry(*ipNet))
|
|
|
+ } else if ipData.IpType == 1 && isIPv4Segment(ipData.Ip) { //ip段添加
|
|
|
+ for _, v := range cidrToIPList(ipData.Ip) {
|
|
|
+ if ipUint, err1 := ip2Uint32(v); err1 == nil {
|
|
|
+ rb.Add(ipUint)
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
- newIp.Ranger = ranger
|
|
|
newIp.Rb = rb
|
|
|
return newIp
|
|
|
}
|
|
|
|
|
|
-// 验证 IP 是否在 IP 段内
|
|
|
-func (ip *WhiteIp) isIPInRange(ips string) bool {
|
|
|
- checkIP := net.ParseIP(ips)
|
|
|
- contains, err := ip.Ranger.Contains(checkIP)
|
|
|
+func cidrToIPList(cidr string) []string {
|
|
|
+ ip, ipNet, err := net.ParseCIDR(cidr)
|
|
|
if err != nil {
|
|
|
- log.Println("Failed to check IP range:", err)
|
|
|
- return false
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ var ipList []string
|
|
|
+ for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
|
|
|
+ ipList = append(ipList, ip.String())
|
|
|
+ }
|
|
|
+
|
|
|
+ return ipList
|
|
|
+}
|
|
|
+
|
|
|
+func incIP(ip net.IP) {
|
|
|
+ for j := len(ip) - 1; j >= 0; j-- {
|
|
|
+ ip[j]++
|
|
|
+ if ip[j] > 0 {
|
|
|
+ break
|
|
|
+ }
|
|
|
}
|
|
|
- return contains
|
|
|
}
|
|
|
|
|
|
// 判断字符串是否为有效的 IP 地址
|
|
@@ -74,9 +92,9 @@ func isIP(s string) bool {
|
|
|
}
|
|
|
|
|
|
// 判断字符串是否为有效的 IPv4 段
|
|
|
-func isIPv4Segment(s string) (*net.IPNet, bool) {
|
|
|
- _, ipNet, err := net.ParseCIDR(s)
|
|
|
- return ipNet, err == nil
|
|
|
+func isIPv4Segment(s string) bool {
|
|
|
+ _, _, err := net.ParseCIDR(s)
|
|
|
+ return err == nil
|
|
|
}
|
|
|
|
|
|
// Match
|
|
@@ -84,9 +102,6 @@ func (ip *WhiteIp) Match(rawIp string) bool {
|
|
|
if ipUint, err := ip2Uint32(rawIp); err != nil {
|
|
|
return false
|
|
|
} else {
|
|
|
- if ip.Rb.Contains(ipUint) {
|
|
|
- return true
|
|
|
- }
|
|
|
- return ip.isIPInRange(rawIp)
|
|
|
+ return ip.Rb.Contains(ipUint)
|
|
|
}
|
|
|
}
|