package main import ( "fmt" nebula "github.com/vesoft-inc/nebula-go/v3" "log" ) // 表示企业关系结果 //type AllRelationResult struct { // RelatedCompanies []string // 有关系的企业列表 // Paths []string // 对应的路径 //} // //// 批量获取企业的 VID //func getVidsByName(session *nebula.Session, names []string) (map[string]string, error) { // if len(names) == 0 { // return nil, nil // } // conditions := "" // for i, name := range names { // if i > 0 { // conditions += " OR " // } // conditions += fmt.Sprintf("Legal.name == \"%s\"", name) // } // query := fmt.Sprintf(` //USE %s; //LOOKUP ON Legal WHERE %s YIELD id(vertex) AS vid, properties(vertex).name AS name //`, Table_Space, conditions) // resp, err := session.Execute(query) // if err != nil { // return nil, err // } // nameToVid := make(map[string]string) // for _, row := range resp.GetRows() { // if len(row.Values) >= 2 { // if row.Values[0].SVal != nil && row.Values[1].SVal != nil { // nameToVid[string(row.Values[1].SVal)] = string(row.Values[0].SVal) // } // } // } // return nameToVid, nil //} // //// 获取 VID 对应的名称 //func getVidName(session *nebula.Session, vid string) (string, error) { // query := fmt.Sprintf(` //USE %s; //FETCH PROP ON Legal "%s" YIELD properties(vertex).name AS name //`, Table_Space, vid) // resp, err := session.Execute(query) // if err != nil { // return "", err // } // names, err := getFirstColumnStrings(resp) // if err != nil || len(names) == 0 { // return "", fmt.Errorf("未找到 VID %s 的名称", vid) // } // return names[0], nil //} // //// 查找路径 //func findPath(session *nebula.Session, fromVid, toVid string, maxStep int, pathCache map[string][]string) ([]string, error) { // key := fmt.Sprintf("%s->%s:%d", fromVid, toVid, maxStep) // if cachedPath, ok := pathCache[key]; ok { // return cachedPath, nil // } // query := fmt.Sprintf(`FIND ALL PATH FROM "%s" TO "%s" OVER Invest UPTO %d STEPS YIELD path as p`, fromVid, toVid, maxStep) // resp, err := session.Execute(query) // if err != nil { // return nil, err // } // path, err := getFirstColumnStrings(resp) // if err != nil { // return nil, err // } // pathCache[key] = path // return path, nil //} // //// 检查共同祖先 //func checkCommonAncestor(session *nebula.Session, aVid, bVid string, deep int, pathCache map[string][]string) (bool, []string, string) { // key := fmt.Sprintf("%s&%s:%d", aVid, bVid, deep) // if cachedPath, ok := pathCache[key]; ok { // if len(cachedPath) > 0 { // return true, cachedPath, cachedPath[1] // } // return false, nil, "" // } // query := fmt.Sprintf(` // ( // GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor // ) // INTERSECT // ( // GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor // ); // `, deep, aVid, deep, bVid) // resp, err := session.Execute(query) // if err != nil { // return false, nil, "" // } // ancestors, err := getFirstColumnStrings(resp) // if err != nil || len(ancestors) == 0 { // pathCache[key] = nil // return false, nil, "" // } // pathA, _ := findPath(session, aVid, ancestors[0], deep, pathCache) // pathB, _ := findPath(session, bVid, ancestors[0], deep, pathCache) // var path []string // if len(pathB) > 1 { // path = append(pathA, pathB[1:]...) // } else { // path = append(pathA, pathB...) // } // pathCache[key] = path // return true, path, ancestors[0] //} // //// 将 VID 路径转换为名称路径 //func convertVidPathToNamePath(session *nebula.Session, vidPath []string) (string, error) { // namePath := "" // for i, vid := range vidPath { // name, err := getVidName(session, vid) // if err != nil { // return "", err // } // if i > 0 { // namePath += "->" // } // namePath += name // } // return namePath, nil //} // //// 检查企业关系 //func CheckLegalRelations(session *nebula.Session, names []string, deep int) (AllRelationResult, error) { // result := AllRelationResult{} // checked := make(map[string]bool) // nameToVid, err := getVidsByName(session, names) // if err != nil { // return result, err // } // pathCache := make(map[string][]string) // relatedCompaniesSet := make(map[string]bool) // var paths []string // // for i := 0; i < len(names); i++ { // for j := i + 1; j < len(names); j++ { // a, b := names[i], names[j] // vidA, okA := nameToVid[a] // vidB, okB := nameToVid[b] // if !okA || !okB { // continue // } // // key := vidA + "|" + vidB // if checked[key] { // continue // } // checked[key] = true // // // 1. a -> b // pathAB, err := findPath(session, vidA, vidB, deep, pathCache) // if err != nil { // log.Printf("查找 %s 到 %s 的路径失败: %v", a, b, err) // continue // } // if len(pathAB) > 0 { // pathStr, err := convertVidPathToNamePath(session, pathAB) // if err != nil { // log.Printf("转换 %s 到 %s 的路径失败: %v", a, b, err) // continue // } // paths = append(paths, pathStr) // for _, vid := range pathAB { // name, err := getVidName(session, vid) // if err != nil { // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err) // continue // } // relatedCompaniesSet[name] = true // } // continue // } // // // 2. b -> a // pathBA, err := findPath(session, vidB, vidA, deep, pathCache) // if err != nil { // log.Printf("查找 %s 到 %s 的路径失败: %v", b, a, err) // continue // } // if len(pathBA) > 0 { // pathStr, err := convertVidPathToNamePath(session, pathBA) // if err != nil { // log.Printf("转换 %s 到 %s 的路径失败: %v", b, a, err) // continue // } // paths = append(paths, pathStr) // for _, vid := range pathBA { // name, err := getVidName(session, vid) // if err != nil { // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err) // continue // } // relatedCompaniesSet[name] = true // } // continue // } // // // 3. common ancestor // common, path, _ := checkCommonAncestor(session, vidA, vidB, deep, pathCache) // if common { // pathStr, err := convertVidPathToNamePath(session, path) // if err != nil { // log.Printf("转换 %s 和 %s 到共同祖先的路径失败: %v", a, b, err) // continue // } // paths = append(paths, pathStr) // for _, vid := range path { // name, err := getVidName(session, vid) // if err != nil { // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err) // continue // } // relatedCompaniesSet[name] = true // } // } // } // } // // for company := range relatedCompaniesSet { // result.RelatedCompanies = append(result.RelatedCompanies, company) // } // result.Paths = paths // // return result, nil //} // //// getFirstColumnStrings 适配 nebula-go v3 取出字符串类型列 //func getFirstColumnStrings(resp *nebula.ResultSet) ([]string, error) { // if resp == nil { // return nil, fmt.Errorf("result set is nil") // } // // var values []string // for _, row := range resp.GetRows() { // if len(row.Values) == 0 { // continue // } // val := row.Values[0] // switch { // case val.SVal != nil: // values = append(values, string(val.SVal)) // case val.IVal != nil: // values = append(values, fmt.Sprintf("%d", *val.IVal)) // case val.BVal != nil: // values = append(values, fmt.Sprintf("%v", *val.BVal)) // default: // log.Printf("未知类型值: %+v", val) // } // } // return values, nil //} func CheckLegalRelations(session *nebula.Session, names []string, deep int) ([]RelationResult, error) { results := []RelationResult{} checked := make(map[string]bool) nameToVid, err := getAllVids(session, names) if err != nil { return nil, err } vidToName := reverseMap(nameToVid) for i := 0; i < len(names); i++ { for j := i + 1; j < len(names); j++ { a, b := names[i], names[j] vidA, okA := nameToVid[a] vidB, okB := nameToVid[b] if !okA || !okB { continue } key := vidA + "|" + vidB if checked[key] { continue } checked[key] = true // 1. a -> b pathAB, err := findPath(session, vidA, vidB, deep) if err != nil { return nil, err } if len(pathAB) > 0 { readablePath := convertPathToNames(pathAB, vidToName) results = append(results, RelationResult{A: a, B: b, RelationType: "direct_or_indirect", Path: readablePath}) continue } // 2. b -> a pathBA, err := findPath(session, vidB, vidA, deep) if err != nil { return nil, err } if len(pathBA) > 0 { readablePath := convertPathToNames(pathBA, vidToName) results = append(results, RelationResult{A: b, B: a, RelationType: "direct_or_indirect", Path: readablePath}) continue } // 3. common ancestor common, ancestorVid, err := checkCommonAncestor(session, vidA, vidB, deep) if err != nil { return nil, err } if common { ancestorName := getAncestorName(session, ancestorVid, vidToName) aName := vidToName[vidA] bName := vidToName[vidB] if ancestorName != "" && aName != "" && bName != "" { readablePath := []string{ fmt.Sprintf("%s -> %s", aName, ancestorName), fmt.Sprintf("%s -> %s", bName, ancestorName), } results = append(results, RelationResult{A: a, B: b, RelationType: "common_ancestor", Path: readablePath}) } } } } return results, nil } func getAllVids(session *nebula.Session, names []string) (map[string]string, error) { nameToVid := make(map[string]string) for _, name := range names { vid, err := getVidByName(session, name) if err != nil { log.Printf("获取 %s 的 VID 失败: %v", name, err) continue } nameToVid[name] = vid } return nameToVid, nil } func checkCommonAncestor(session *nebula.Session, aVid, bVid string, deep int) (bool, string, error) { query := fmt.Sprintf(` ( GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor ) INTERSECT ( GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor ); `, deep, aVid, deep, bVid) resp, err := session.Execute(query) if err != nil { return false, "", err } ancestors, err := getFirstColumnStrings(resp) if err != nil || len(ancestors) == 0 { return false, "", nil } return true, ancestors[0], nil } func findPath(session *nebula.Session, fromVid, toVid string, maxStep int) ([]string, error) { query := fmt.Sprintf(`FIND ALL PATH FROM "%s" TO "%s" OVER Invest UPTO %d STEPS YIELD path as p`, fromVid, toVid, maxStep) resp, err := session.Execute(query) if err != nil { return nil, err } return getFirstColumnStrings(resp) } func getVidByName(session *nebula.Session, name string) (string, error) { query := fmt.Sprintf(` USE `+Table_Space+`; LOOKUP ON Legal WHERE Legal.name == "%s" YIELD id(vertex)`, name) resp, err := session.Execute(query) if err != nil { return "", err } values, err := getFirstColumnStrings(resp) if err != nil || len(values) == 0 { return "", fmt.Errorf("未找到公司: %s", name) } return values[0], nil } type RelationResult struct { A, B string // 公司名 RelationType string // 关系类型:"direct_or_indirect", "common_ancestor" Path []string // 路径中的公司名称,以更直观的形式展示 } // getFirstColumnStrings 适配 nebula-go v3 取出字符串类型列 func getFirstColumnStrings(resp *nebula.ResultSet) ([]string, error) { if resp == nil { return nil, fmt.Errorf("result set is nil") } var values []string for _, row := range resp.GetRows() { if len(row.Values) == 0 { continue } val := row.Values[0] switch { case val.SVal != nil: values = append(values, string(val.SVal)) case val.IVal != nil: values = append(values, fmt.Sprintf("%d", *val.IVal)) case val.BVal != nil: values = append(values, fmt.Sprintf("%v", *val.BVal)) case val.PVal != nil: // 处理点类型 //src := val.PVal.GetSrc() //if src.GetId != nil { // values = append(values, string(*src.SVal)) //} else if src.IVal != nil { // values = append(values, fmt.Sprintf("%d", *src.IVal)) //} else { // log.Printf("未知的点源 ID 类型: %+v", src) //} default: log.Printf("未知类型值: %+v", val) } } return values, nil } func reverseMap(m map[string]string) map[string]string { result := make(map[string]string) for k, v := range m { result[v] = k } return result } func convertPathToNames(path []string, vidToName map[string]string) []string { readablePath := make([]string, 0, len(path)) for i := 0; i < len(path)-1; i++ { fromName, okFrom := vidToName[path[i]] toName, okTo := vidToName[path[i+1]] if okFrom && okTo && fromName != toName { readablePath = append(readablePath, fmt.Sprintf("%s -> %s", fromName, toName)) } } return readablePath } func getAncestorName(session *nebula.Session, ancestorVid string, vidToName map[string]string) string { if name, ok := vidToName[ancestorVid]; ok { return name } query := fmt.Sprintf(` USE `+Table_Space+`; FETCH PROP ON Legal "%s" YIELD Legal.name; `, ancestorVid) resp, err := session.Execute(query) if err != nil { log.Printf("获取祖先公司名称失败: %v", err) return "" } names, err := getFirstColumnStrings(resp) if err != nil || len(names) == 0 { return "" } return names[0] }