utils.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. package main
  2. import (
  3. "fmt"
  4. nebula "github.com/vesoft-inc/nebula-go/v3"
  5. "log"
  6. )
  7. // 表示企业关系结果
  8. //type AllRelationResult struct {
  9. // RelatedCompanies []string // 有关系的企业列表
  10. // Paths []string // 对应的路径
  11. //}
  12. //
  13. //// 批量获取企业的 VID
  14. //func getVidsByName(session *nebula.Session, names []string) (map[string]string, error) {
  15. // if len(names) == 0 {
  16. // return nil, nil
  17. // }
  18. // conditions := ""
  19. // for i, name := range names {
  20. // if i > 0 {
  21. // conditions += " OR "
  22. // }
  23. // conditions += fmt.Sprintf("Legal.name == \"%s\"", name)
  24. // }
  25. // query := fmt.Sprintf(`
  26. //USE %s;
  27. //LOOKUP ON Legal WHERE %s YIELD id(vertex) AS vid, properties(vertex).name AS name
  28. //`, Table_Space, conditions)
  29. // resp, err := session.Execute(query)
  30. // if err != nil {
  31. // return nil, err
  32. // }
  33. // nameToVid := make(map[string]string)
  34. // for _, row := range resp.GetRows() {
  35. // if len(row.Values) >= 2 {
  36. // if row.Values[0].SVal != nil && row.Values[1].SVal != nil {
  37. // nameToVid[string(row.Values[1].SVal)] = string(row.Values[0].SVal)
  38. // }
  39. // }
  40. // }
  41. // return nameToVid, nil
  42. //}
  43. //
  44. //// 获取 VID 对应的名称
  45. //func getVidName(session *nebula.Session, vid string) (string, error) {
  46. // query := fmt.Sprintf(`
  47. //USE %s;
  48. //FETCH PROP ON Legal "%s" YIELD properties(vertex).name AS name
  49. //`, Table_Space, vid)
  50. // resp, err := session.Execute(query)
  51. // if err != nil {
  52. // return "", err
  53. // }
  54. // names, err := getFirstColumnStrings(resp)
  55. // if err != nil || len(names) == 0 {
  56. // return "", fmt.Errorf("未找到 VID %s 的名称", vid)
  57. // }
  58. // return names[0], nil
  59. //}
  60. //
  61. //// 查找路径
  62. //func findPath(session *nebula.Session, fromVid, toVid string, maxStep int, pathCache map[string][]string) ([]string, error) {
  63. // key := fmt.Sprintf("%s->%s:%d", fromVid, toVid, maxStep)
  64. // if cachedPath, ok := pathCache[key]; ok {
  65. // return cachedPath, nil
  66. // }
  67. // query := fmt.Sprintf(`FIND ALL PATH FROM "%s" TO "%s" OVER Invest UPTO %d STEPS YIELD path as p`, fromVid, toVid, maxStep)
  68. // resp, err := session.Execute(query)
  69. // if err != nil {
  70. // return nil, err
  71. // }
  72. // path, err := getFirstColumnStrings(resp)
  73. // if err != nil {
  74. // return nil, err
  75. // }
  76. // pathCache[key] = path
  77. // return path, nil
  78. //}
  79. //
  80. //// 检查共同祖先
  81. //func checkCommonAncestor(session *nebula.Session, aVid, bVid string, deep int, pathCache map[string][]string) (bool, []string, string) {
  82. // key := fmt.Sprintf("%s&%s:%d", aVid, bVid, deep)
  83. // if cachedPath, ok := pathCache[key]; ok {
  84. // if len(cachedPath) > 0 {
  85. // return true, cachedPath, cachedPath[1]
  86. // }
  87. // return false, nil, ""
  88. // }
  89. // query := fmt.Sprintf(`
  90. // (
  91. // GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor
  92. // )
  93. // INTERSECT
  94. // (
  95. // GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor
  96. // );
  97. // `, deep, aVid, deep, bVid)
  98. // resp, err := session.Execute(query)
  99. // if err != nil {
  100. // return false, nil, ""
  101. // }
  102. // ancestors, err := getFirstColumnStrings(resp)
  103. // if err != nil || len(ancestors) == 0 {
  104. // pathCache[key] = nil
  105. // return false, nil, ""
  106. // }
  107. // pathA, _ := findPath(session, aVid, ancestors[0], deep, pathCache)
  108. // pathB, _ := findPath(session, bVid, ancestors[0], deep, pathCache)
  109. // var path []string
  110. // if len(pathB) > 1 {
  111. // path = append(pathA, pathB[1:]...)
  112. // } else {
  113. // path = append(pathA, pathB...)
  114. // }
  115. // pathCache[key] = path
  116. // return true, path, ancestors[0]
  117. //}
  118. //
  119. //// 将 VID 路径转换为名称路径
  120. //func convertVidPathToNamePath(session *nebula.Session, vidPath []string) (string, error) {
  121. // namePath := ""
  122. // for i, vid := range vidPath {
  123. // name, err := getVidName(session, vid)
  124. // if err != nil {
  125. // return "", err
  126. // }
  127. // if i > 0 {
  128. // namePath += "->"
  129. // }
  130. // namePath += name
  131. // }
  132. // return namePath, nil
  133. //}
  134. //
  135. //// 检查企业关系
  136. //func CheckLegalRelations(session *nebula.Session, names []string, deep int) (AllRelationResult, error) {
  137. // result := AllRelationResult{}
  138. // checked := make(map[string]bool)
  139. // nameToVid, err := getVidsByName(session, names)
  140. // if err != nil {
  141. // return result, err
  142. // }
  143. // pathCache := make(map[string][]string)
  144. // relatedCompaniesSet := make(map[string]bool)
  145. // var paths []string
  146. //
  147. // for i := 0; i < len(names); i++ {
  148. // for j := i + 1; j < len(names); j++ {
  149. // a, b := names[i], names[j]
  150. // vidA, okA := nameToVid[a]
  151. // vidB, okB := nameToVid[b]
  152. // if !okA || !okB {
  153. // continue
  154. // }
  155. //
  156. // key := vidA + "|" + vidB
  157. // if checked[key] {
  158. // continue
  159. // }
  160. // checked[key] = true
  161. //
  162. // // 1. a -> b
  163. // pathAB, err := findPath(session, vidA, vidB, deep, pathCache)
  164. // if err != nil {
  165. // log.Printf("查找 %s 到 %s 的路径失败: %v", a, b, err)
  166. // continue
  167. // }
  168. // if len(pathAB) > 0 {
  169. // pathStr, err := convertVidPathToNamePath(session, pathAB)
  170. // if err != nil {
  171. // log.Printf("转换 %s 到 %s 的路径失败: %v", a, b, err)
  172. // continue
  173. // }
  174. // paths = append(paths, pathStr)
  175. // for _, vid := range pathAB {
  176. // name, err := getVidName(session, vid)
  177. // if err != nil {
  178. // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err)
  179. // continue
  180. // }
  181. // relatedCompaniesSet[name] = true
  182. // }
  183. // continue
  184. // }
  185. //
  186. // // 2. b -> a
  187. // pathBA, err := findPath(session, vidB, vidA, deep, pathCache)
  188. // if err != nil {
  189. // log.Printf("查找 %s 到 %s 的路径失败: %v", b, a, err)
  190. // continue
  191. // }
  192. // if len(pathBA) > 0 {
  193. // pathStr, err := convertVidPathToNamePath(session, pathBA)
  194. // if err != nil {
  195. // log.Printf("转换 %s 到 %s 的路径失败: %v", b, a, err)
  196. // continue
  197. // }
  198. // paths = append(paths, pathStr)
  199. // for _, vid := range pathBA {
  200. // name, err := getVidName(session, vid)
  201. // if err != nil {
  202. // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err)
  203. // continue
  204. // }
  205. // relatedCompaniesSet[name] = true
  206. // }
  207. // continue
  208. // }
  209. //
  210. // // 3. common ancestor
  211. // common, path, _ := checkCommonAncestor(session, vidA, vidB, deep, pathCache)
  212. // if common {
  213. // pathStr, err := convertVidPathToNamePath(session, path)
  214. // if err != nil {
  215. // log.Printf("转换 %s 和 %s 到共同祖先的路径失败: %v", a, b, err)
  216. // continue
  217. // }
  218. // paths = append(paths, pathStr)
  219. // for _, vid := range path {
  220. // name, err := getVidName(session, vid)
  221. // if err != nil {
  222. // log.Printf("获取 VID %s 对应的名称失败: %v", vid, err)
  223. // continue
  224. // }
  225. // relatedCompaniesSet[name] = true
  226. // }
  227. // }
  228. // }
  229. // }
  230. //
  231. // for company := range relatedCompaniesSet {
  232. // result.RelatedCompanies = append(result.RelatedCompanies, company)
  233. // }
  234. // result.Paths = paths
  235. //
  236. // return result, nil
  237. //}
  238. //
  239. //// getFirstColumnStrings 适配 nebula-go v3 取出字符串类型列
  240. //func getFirstColumnStrings(resp *nebula.ResultSet) ([]string, error) {
  241. // if resp == nil {
  242. // return nil, fmt.Errorf("result set is nil")
  243. // }
  244. //
  245. // var values []string
  246. // for _, row := range resp.GetRows() {
  247. // if len(row.Values) == 0 {
  248. // continue
  249. // }
  250. // val := row.Values[0]
  251. // switch {
  252. // case val.SVal != nil:
  253. // values = append(values, string(val.SVal))
  254. // case val.IVal != nil:
  255. // values = append(values, fmt.Sprintf("%d", *val.IVal))
  256. // case val.BVal != nil:
  257. // values = append(values, fmt.Sprintf("%v", *val.BVal))
  258. // default:
  259. // log.Printf("未知类型值: %+v", val)
  260. // }
  261. // }
  262. // return values, nil
  263. //}
  264. func CheckLegalRelations(session *nebula.Session, names []string, deep int) ([]RelationResult, error) {
  265. results := []RelationResult{}
  266. checked := make(map[string]bool)
  267. nameToVid, err := getAllVids(session, names)
  268. if err != nil {
  269. return nil, err
  270. }
  271. vidToName := reverseMap(nameToVid)
  272. for i := 0; i < len(names); i++ {
  273. for j := i + 1; j < len(names); j++ {
  274. a, b := names[i], names[j]
  275. vidA, okA := nameToVid[a]
  276. vidB, okB := nameToVid[b]
  277. if !okA || !okB {
  278. continue
  279. }
  280. key := vidA + "|" + vidB
  281. if checked[key] {
  282. continue
  283. }
  284. checked[key] = true
  285. // 1. a -> b
  286. pathAB, err := findPath(session, vidA, vidB, deep)
  287. if err != nil {
  288. return nil, err
  289. }
  290. if len(pathAB) > 0 {
  291. readablePath := convertPathToNames(pathAB, vidToName)
  292. results = append(results, RelationResult{A: a, B: b, RelationType: "direct_or_indirect", Path: readablePath})
  293. continue
  294. }
  295. // 2. b -> a
  296. pathBA, err := findPath(session, vidB, vidA, deep)
  297. if err != nil {
  298. return nil, err
  299. }
  300. if len(pathBA) > 0 {
  301. readablePath := convertPathToNames(pathBA, vidToName)
  302. results = append(results, RelationResult{A: b, B: a, RelationType: "direct_or_indirect", Path: readablePath})
  303. continue
  304. }
  305. // 3. common ancestor
  306. common, ancestorVid, err := checkCommonAncestor(session, vidA, vidB, deep)
  307. if err != nil {
  308. return nil, err
  309. }
  310. if common {
  311. ancestorName := getAncestorName(session, ancestorVid, vidToName)
  312. aName := vidToName[vidA]
  313. bName := vidToName[vidB]
  314. if ancestorName != "" && aName != "" && bName != "" {
  315. readablePath := []string{
  316. fmt.Sprintf("%s -> %s", aName, ancestorName),
  317. fmt.Sprintf("%s -> %s", bName, ancestorName),
  318. }
  319. results = append(results, RelationResult{A: a, B: b, RelationType: "common_ancestor", Path: readablePath})
  320. }
  321. }
  322. }
  323. }
  324. return results, nil
  325. }
  326. func getAllVids(session *nebula.Session, names []string) (map[string]string, error) {
  327. nameToVid := make(map[string]string)
  328. for _, name := range names {
  329. vid, err := getVidByName(session, name)
  330. if err != nil {
  331. log.Printf("获取 %s 的 VID 失败: %v", name, err)
  332. continue
  333. }
  334. nameToVid[name] = vid
  335. }
  336. return nameToVid, nil
  337. }
  338. func checkCommonAncestor(session *nebula.Session, aVid, bVid string, deep int) (bool, string, error) {
  339. query := fmt.Sprintf(`
  340. (
  341. GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor
  342. )
  343. INTERSECT
  344. (
  345. GO 1 TO %d STEPS FROM "%s" OVER Invest REVERSELY YIELD dst(edge) AS ancestor
  346. );
  347. `, deep, aVid, deep, bVid)
  348. resp, err := session.Execute(query)
  349. if err != nil {
  350. return false, "", err
  351. }
  352. ancestors, err := getFirstColumnStrings(resp)
  353. if err != nil || len(ancestors) == 0 {
  354. return false, "", nil
  355. }
  356. return true, ancestors[0], nil
  357. }
  358. func findPath(session *nebula.Session, fromVid, toVid string, maxStep int) ([]string, error) {
  359. query := fmt.Sprintf(`FIND ALL PATH FROM "%s" TO "%s" OVER Invest UPTO %d STEPS YIELD path as p`, fromVid, toVid, maxStep)
  360. resp, err := session.Execute(query)
  361. if err != nil {
  362. return nil, err
  363. }
  364. return getFirstColumnStrings(resp)
  365. }
  366. func getVidByName(session *nebula.Session, name string) (string, error) {
  367. query := fmt.Sprintf(`
  368. USE `+Table_Space+`;
  369. LOOKUP ON Legal WHERE Legal.name == "%s" YIELD id(vertex)`, name)
  370. resp, err := session.Execute(query)
  371. if err != nil {
  372. return "", err
  373. }
  374. values, err := getFirstColumnStrings(resp)
  375. if err != nil || len(values) == 0 {
  376. return "", fmt.Errorf("未找到公司: %s", name)
  377. }
  378. return values[0], nil
  379. }
  380. type RelationResult struct {
  381. A, B string // 公司名
  382. RelationType string // 关系类型:"direct_or_indirect", "common_ancestor"
  383. Path []string // 路径中的公司名称,以更直观的形式展示
  384. }
  385. // getFirstColumnStrings 适配 nebula-go v3 取出字符串类型列
  386. func getFirstColumnStrings(resp *nebula.ResultSet) ([]string, error) {
  387. if resp == nil {
  388. return nil, fmt.Errorf("result set is nil")
  389. }
  390. var values []string
  391. for _, row := range resp.GetRows() {
  392. if len(row.Values) == 0 {
  393. continue
  394. }
  395. val := row.Values[0]
  396. switch {
  397. case val.SVal != nil:
  398. values = append(values, string(val.SVal))
  399. case val.IVal != nil:
  400. values = append(values, fmt.Sprintf("%d", *val.IVal))
  401. case val.BVal != nil:
  402. values = append(values, fmt.Sprintf("%v", *val.BVal))
  403. case val.PVal != nil:
  404. // 处理点类型
  405. //src := val.PVal.GetSrc()
  406. //if src.GetId != nil {
  407. // values = append(values, string(*src.SVal))
  408. //} else if src.IVal != nil {
  409. // values = append(values, fmt.Sprintf("%d", *src.IVal))
  410. //} else {
  411. // log.Printf("未知的点源 ID 类型: %+v", src)
  412. //}
  413. default:
  414. log.Printf("未知类型值: %+v", val)
  415. }
  416. }
  417. return values, nil
  418. }
  419. func reverseMap(m map[string]string) map[string]string {
  420. result := make(map[string]string)
  421. for k, v := range m {
  422. result[v] = k
  423. }
  424. return result
  425. }
  426. func convertPathToNames(path []string, vidToName map[string]string) []string {
  427. readablePath := make([]string, 0, len(path))
  428. for i := 0; i < len(path)-1; i++ {
  429. fromName, okFrom := vidToName[path[i]]
  430. toName, okTo := vidToName[path[i+1]]
  431. if okFrom && okTo && fromName != toName {
  432. readablePath = append(readablePath, fmt.Sprintf("%s -> %s", fromName, toName))
  433. }
  434. }
  435. return readablePath
  436. }
  437. func getAncestorName(session *nebula.Session, ancestorVid string, vidToName map[string]string) string {
  438. if name, ok := vidToName[ancestorVid]; ok {
  439. return name
  440. }
  441. query := fmt.Sprintf(`
  442. USE `+Table_Space+`;
  443. FETCH PROP ON Legal "%s" YIELD Legal.name;
  444. `, ancestorVid)
  445. resp, err := session.Execute(query)
  446. if err != nil {
  447. log.Printf("获取祖先公司名称失败: %v", err)
  448. return ""
  449. }
  450. names, err := getFirstColumnStrings(resp)
  451. if err != nil || len(names) == 0 {
  452. return ""
  453. }
  454. return names[0]
  455. }