wcc 3 місяців тому
батько
коміт
8d256b831d
4 змінених файлів з 429 додано та 12 видалено
  1. 18 2
      graph/graph_test.go
  2. 104 1
      graph/init.go
  3. 48 9
      graph/main.go
  4. 259 0
      graph/utils.go

+ 18 - 2
graph/graph_test.go

@@ -12,15 +12,31 @@ func TestCheckInvestRelation(t *testing.T) {
 	}
 	defer pool.Close()
 	defer session.Release()
-	names := []string{"华芳创业投资有限公司", "北京剑鱼信息技术有限公司", "河南折扣牛哟有限公司", "上海元藩投资有限公司"}
+	names := []string{"新疆拓普丰联网络信息技术有限公司",
+		"上海元藩投资有限公司",
+		"海南清鹤鸣企业管理咨询合伙企业(有限合伙)",
+		"华芳创业投资有限公司",
+		"北京剑鱼信息技术有限公司"}
 	//res, err := CheckLegalRelationsGraph(session, names, 3)
-	has, res, err := CheckLegalRelationships4(session, names, 3, 1)
+	has, res, err := CheckLegalRelationships5(session, names, 3, 1)
 	if err != nil {
 		log.Println(res, err, has)
 	}
 	log.Println(has, res)
 }
 
+func TestCheckLegalRelationships(t *testing.T) {
+	client, err := NewNebulaClient(HostList, UserName, PassWord)
+	if err != nil {
+		log.Fatal("连接失败:", err)
+	}
+	defer client.Close()
+	names := []string{"华芳创业投资有限公司", "北京剑鱼信息技术有限公司", "河南折扣牛哟有限公司", "上海元藩投资有限公司"}
+
+	has, res, err := client.CheckLegalRelationships(names, 3, 0)
+	log.Println(has, res, err)
+}
+
 func TestFetchLegalByVid(t *testing.T) {
 	session, pool, err := ConnectToNebula(HostList, UserName, PassWord)
 	if err != nil {

+ 104 - 1
graph/init.go

@@ -1,6 +1,13 @@
 package main
 
-import "jygit.jydev.jianyu360.cn/data_processing/common_utils/mongodb"
+import (
+	"fmt"
+	nebula "github.com/vesoft-inc/nebula-go/v3"
+	"jygit.jydev.jianyu360.cn/data_processing/common_utils/mongodb"
+	"log"
+	"strings"
+	"sync"
+)
 
 func InitMgo() {
 	//181 凭安库
@@ -15,3 +22,99 @@ func InitMgo() {
 	}
 	Mgo181.InitPool()
 }
+
+type NebulaClient struct {
+	hosts    []nebula.HostAddress
+	username string
+	password string
+
+	pool    *nebula.ConnectionPool
+	session *nebula.Session
+	mu      sync.Mutex
+}
+
+// NewNebulaClient 初始化客户端-NebulaGraph
+func NewNebulaClient(hosts []nebula.HostAddress, username, password string) (*NebulaClient, error) {
+	client := &NebulaClient{
+		hosts:    hosts,
+		username: username,
+		password: password,
+	}
+	err := client.connect()
+	if err != nil {
+		return nil, err
+	}
+	return client, nil
+}
+
+func (c *NebulaClient) connect() error {
+	config := nebula.GetDefaultConf()
+	config.UseHTTP2 = false
+	config.HandshakeKey = ""
+
+	pool, err := nebula.NewConnectionPool(c.hosts, config, nebula.DefaultLogger{})
+	if err != nil {
+		return fmt.Errorf("初始化连接池失败: %w", err)
+	}
+
+	session, err := pool.GetSession(c.username, c.password)
+	if err != nil {
+		pool.Close()
+		return fmt.Errorf("获取 session 失败: %w", err)
+	}
+
+	c.pool = pool
+	c.session = session
+	return nil
+}
+
+// 自动重连逻辑:当返回错误里包含 session 不存在时触发重连
+func (c *NebulaClient) ExecuteWithReconnect(query string) (*nebula.ResultSet, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// 如果 session 为 nil,则重新连接
+	if c.session == nil {
+		log.Println("session 为 nil,重连中...")
+		if err := c.connect(); err != nil {
+			return nil, err
+		}
+	}
+
+	resp, err := c.session.Execute(query)
+	if err != nil || !resp.IsSucceed() {
+		// 检查是否是 session 已失效
+		if strings.Contains(resp.GetErrorMsg(), "Session") ||
+			(err != nil && strings.Contains(err.Error(), "Session")) {
+
+			log.Println("session 可能失效,正在重连...")
+			if c.pool != nil {
+				c.pool.Close()
+			}
+			if err := c.connect(); err != nil {
+				return nil, fmt.Errorf("重连失败: %w", err)
+			}
+			// 重试一次查询
+			resp, err = c.session.Execute(query)
+			if err != nil {
+				return nil, err
+			}
+			if !resp.IsSucceed() {
+				return nil, fmt.Errorf("重连后查询失败: %s", resp.GetErrorMsg())
+			}
+		} else {
+			return nil, fmt.Errorf("查询失败: %s", resp.GetErrorMsg())
+		}
+	}
+
+	return resp, nil
+}
+
+func (c *NebulaClient) Close() {
+	if c.session != nil {
+		c.session.Release()
+	}
+	if c.pool != nil {
+		c.pool.Close()
+	}
+}

+ 48 - 9
graph/main.go

@@ -114,15 +114,55 @@ func main() {
 
 	//log.Println("数据处理完毕!!!!!!!")
 	//return
-	//封装对外提供的HTTP
-	session, pool, err := ConnectToNebula(HostList, UserName, PassWord)
-	if err != nil {
-		log.Fatalf("Failed to connect to Nebula Graph: %v", err)
-	}
-	defer pool.Close()
-	defer session.Release()
+	//2、封装对外提供的HTTP
+	//session, pool, err := ConnectToNebula(HostList, UserName, PassWord)
+	//if err != nil {
+	//	log.Fatalf("Failed to connect to Nebula Graph: %v", err)
+	//}
+	//defer pool.Close()
+	//defer session.Release()
+	//// 初始化 Gin 路由
+	//r := gin.Default()
+	//// 注册 POST 接口
+	//r.POST("/check-relations", func(c *gin.Context) {
+	//	var req CheckRequest
+	//	if err := c.ShouldBindJSON(&req); err != nil {
+	//		c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
+	//		return
+	//	}
+	//
+	//	has, results, err := CheckLegalRelationships4(session, req.Names, req.Deep, req.Stype)
+	//	if err != nil {
+	//		res := CheckResponse{
+	//			Code: -1,
+	//			Data: results,
+	//			Msg:  "请求失败",
+	//		}
+	//		c.JSON(http.StatusInternalServerError, res)
+	//		return
+	//	}
+	//
+	//	res := CheckResponse{
+	//		Code: 200,
+	//		Data: results,
+	//	}
+	//	if has {
+	//		res.Msg = "存在投资关系"
+	//	} else {
+	//		res.Msg = "不存在投资关系"
+	//	}
+	//
+	//	c.JSON(http.StatusOK, res)
+	//})
+
+	//3、改造方法,使用连接池,避免session过去//
 	// 初始化 Gin 路由
 	r := gin.Default()
+	client, err := NewNebulaClient(HostList, UserName, PassWord)
+	if err != nil {
+		log.Fatal("连接失败:", err)
+	}
+	defer client.Close()
 	// 注册 POST 接口
 	r.POST("/check-relations", func(c *gin.Context) {
 		var req CheckRequest
@@ -131,7 +171,7 @@ func main() {
 			return
 		}
 
-		has, results, err := CheckLegalRelationships4(session, req.Names, req.Deep, req.Stype)
+		has, results, err := client.CheckLegalRelationships(req.Names, req.Deep, req.Stype)
 		if err != nil {
 			res := CheckResponse{
 				Code: -1,
@@ -154,7 +194,6 @@ func main() {
 
 		c.JSON(http.StatusOK, res)
 	})
-
 	// 启动服务
 	r.Run(":8080")
 }

+ 259 - 0
graph/utils.go

@@ -923,3 +923,262 @@ RETURN p LIMIT 1
 
 	return false, nil, nil
 }
+
+func CheckLegalRelationships5(session *nebula.Session, names []string, deep, stype int) (bool, []string, error) {
+	if len(names) < 2 {
+		return false, nil, fmt.Errorf("企业数量不足,至少需要两个")
+	}
+
+	var rawPaths []string
+	var rawNodeLists [][]string
+
+	for i := 0; i < len(names); i++ {
+		start := names[i]
+
+		targets := []string{}
+		for j := 0; j < len(names); j++ {
+			if i != j {
+				targets = append(targets, fmt.Sprintf(`"%s"`, names[j]))
+			}
+		}
+		targetList := strings.Join(targets, ", ")
+
+		query := fmt.Sprintf(`
+USE %s;
+MATCH p=(a:Legal{name:"%s"})-[*1..%d]-(b:Legal)
+WHERE b.Legal.name IN [%s]
+RETURN p LIMIT 1
+`, Table_Space, start, deep, targetList)
+
+		resp, err := session.Execute(query)
+		if err != nil {
+			return false, nil, fmt.Errorf("查询失败: %w", err)
+		}
+		if !resp.IsSucceed() {
+			return false, nil, fmt.Errorf("查询执行失败: %s", resp.GetErrorMsg())
+		}
+
+		if resp.GetRowSize() > 0 {
+			for _, row := range resp.GetRows() {
+				if len(row.Values) == 0 {
+					continue
+				}
+				val := row.Values[0]
+				if !val.IsSetPVal() {
+					continue
+				}
+
+				path := val.GetPVal()
+				var builder strings.Builder
+				var nodeNames []string
+
+				curName := ""
+				srcVertex := path.Src
+				if srcVertex != nil && srcVertex.Vid != nil && srcVertex.Vid.IsSetSVal() {
+					vid := string(srcVertex.Vid.GetSVal())
+					lea, err := getLegalByVid(session, vid)
+					if err != nil {
+						log.Println("getLegalByVid err:", err, vid)
+					} else {
+						curName = lea.Name
+					}
+				}
+				builder.WriteString(curName)
+				nodeNames = append(nodeNames, curName)
+
+				for _, step := range path.Steps {
+					dstName := ""
+					if step.Dst != nil && step.Dst.Vid != nil && step.Dst.Vid.IsSetSVal() {
+						vid := string(step.Dst.Vid.GetSVal())
+						lea, err := getLegalByVid(session, vid)
+						if err != nil {
+							log.Println("getLegalByVid err:", err, vid)
+						} else {
+							if lea != nil && lea.Name != "" {
+								dstName = lea.Name
+							}
+						}
+					}
+
+					if step.Type > 0 {
+						builder.WriteString(" → ")
+					} else if step.Type < 0 {
+						builder.WriteString(" ← ")
+					} else {
+						builder.WriteString(" - ")
+					}
+					builder.WriteString(dstName)
+					nodeNames = append(nodeNames, dstName)
+				}
+
+				rawPaths = append(rawPaths, builder.String())
+				rawNodeLists = append(rawNodeLists, nodeNames)
+
+				if stype == 0 {
+					return true, []string{builder.String()}, nil
+				}
+			}
+		}
+	}
+
+	// 去重 + 保留最长路径
+	uniqueMap := map[string]string{}
+	for i, nodes := range rawNodeLists {
+		pathStr := rawPaths[i]
+		key := generatePathKey(nodes)
+
+		shouldAdd := true
+		for k, _ := range uniqueMap {
+			existingNodes := strings.Split(k, "|")
+			if isSubPath(nodes, existingNodes) || isSubPath(reverseSlice(nodes), existingNodes) {
+				shouldAdd = false
+				break
+			}
+			if isSubPath(existingNodes, nodes) {
+				delete(uniqueMap, k)
+			}
+		}
+
+		if shouldAdd {
+			uniqueMap[key] = pathStr
+		}
+	}
+
+	var finalPaths []string
+	for _, v := range uniqueMap {
+		finalPaths = append(finalPaths, v)
+	}
+
+	if len(finalPaths) > 0 {
+		return true, finalPaths, nil
+	}
+	return false, nil, nil
+}
+
+// CheckLegalRelationships CheckLegalRelationships
+func (c *NebulaClient) CheckLegalRelationships(names []string, deep, stype int) (bool, []string, error) {
+	if len(names) < 2 {
+		return false, nil, fmt.Errorf("企业数量不足,至少需要两个")
+	}
+
+	var rawPaths []string
+	var rawNodeLists [][]string
+
+	for i := 0; i < len(names); i++ {
+		start := names[i]
+
+		targets := []string{}
+		for j := 0; j < len(names); j++ {
+			if i != j {
+				targets = append(targets, fmt.Sprintf(`"%s"`, names[j]))
+			}
+		}
+		targetList := strings.Join(targets, ", ")
+
+		query := fmt.Sprintf(`
+USE %s;
+MATCH p=(a:Legal{name:"%s"})-[*1..%d]-(b:Legal)
+WHERE b.Legal.name IN [%s]
+RETURN p LIMIT 1
+`, Table_Space, start, deep, targetList)
+
+		resp, err := c.ExecuteWithReconnect(query)
+		if err != nil {
+			return false, nil, fmt.Errorf("查询失败: %w", err)
+		}
+		if !resp.IsSucceed() {
+			return false, nil, fmt.Errorf("执行失败: %s", resp.GetErrorMsg())
+		}
+
+		if resp.GetRowSize() > 0 {
+			for _, row := range resp.GetRows() {
+				if len(row.Values) == 0 || !row.Values[0].IsSetPVal() {
+					continue
+				}
+
+				path := row.Values[0].GetPVal()
+				var builder strings.Builder
+				var nodeNames []string
+
+				// 起点
+				src := path.Src
+				curName := ""
+				if src != nil && src.Vid != nil && src.Vid.IsSetSVal() {
+					vid := string(src.Vid.GetSVal())
+					lea, err := getLegalByVid(c.session, vid)
+					if err != nil {
+						log.Println("getLegalByVid err:", err, vid)
+					} else if lea != nil && lea.Name != "" {
+						curName = lea.Name
+					}
+				}
+				builder.WriteString(curName)
+				nodeNames = append(nodeNames, curName)
+
+				// 步长处理
+				for _, step := range path.Steps {
+					dstName := ""
+					if step.Dst != nil && step.Dst.Vid != nil && step.Dst.Vid.IsSetSVal() {
+						vid := string(step.Dst.Vid.GetSVal())
+						lea, err := getLegalByVid(c.session, vid)
+						if err != nil {
+							log.Println("getLegalByVid err:", err, vid)
+						} else if lea != nil && lea.Name != "" {
+							dstName = lea.Name
+						}
+					}
+
+					if step.Type > 0 {
+						builder.WriteString(" → ")
+					} else if step.Type < 0 {
+						builder.WriteString(" ← ")
+					} else {
+						builder.WriteString(" - ")
+					}
+					builder.WriteString(dstName)
+					nodeNames = append(nodeNames, dstName)
+				}
+
+				rawPaths = append(rawPaths, builder.String())
+				rawNodeLists = append(rawNodeLists, nodeNames)
+
+				if stype == 0 {
+					return true, []string{builder.String()}, nil
+				}
+			}
+		}
+	}
+
+	// 去重 + 最长路径保留
+	uniqueMap := map[string]string{}
+	for i, nodes := range rawNodeLists {
+		pathStr := rawPaths[i]
+		key := generatePathKey(nodes)
+
+		shouldAdd := true
+		for k := range uniqueMap {
+			existingNodes := strings.Split(k, "|")
+			if isSubPath(nodes, existingNodes) || isSubPath(reverseSlice(nodes), existingNodes) {
+				shouldAdd = false
+				break
+			}
+			if isSubPath(existingNodes, nodes) {
+				delete(uniqueMap, k)
+			}
+		}
+
+		if shouldAdd {
+			uniqueMap[key] = pathStr
+		}
+	}
+
+	var finalPaths []string
+	for _, v := range uniqueMap {
+		finalPaths = append(finalPaths, v)
+	}
+
+	if len(finalPaths) > 0 {
+		return true, finalPaths, nil
+	}
+	return false, nil, nil
+}