all.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. package main
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/RoaringBitmap/roaring"
	"go.uber.org/zap"
	"gorm.io/gorm"
	util "jygit.jydev.jianyu360.cn/data_processing/common_utils"
	jlog "jygit.jydev.jianyu360.cn/data_processing/common_utils/log"
)
  17. // dealAllFromCompanyBase 从company_base 处理惬意数据存量
  18. func dealAllFromCompanyBase() {
  19. jlog.Info("dealAllFromCompanyBase", zap.String("开始处理", "-------企业库存量数据"))
  20. defer util.Catch()
  21. sess := MgoQY.GetMgoConn()
  22. defer MgoQY.DestoryMongoConn(sess)
  23. where := map[string]interface{}{
  24. "company_type": map[string]interface{}{
  25. "$ne": "个体工商户",
  26. },
  27. }
  28. count := 0
  29. batchSize := 100
  30. ents := make([]EntInfo, 0, batchSize)
  31. it := sess.DB(GF.MongoQy.DB).C("company_base").Find(where).Select(nil).Iter()
  32. for tmp := make(map[string]interface{}); it.Next(&tmp); count++ {
  33. if count%1000 == 0 {
  34. jlog.Info("dealAllFromCompanyBase", zap.Any("current:", count), zap.Any("company_name", tmp["company_name"]))
  35. }
  36. company_status := util.ObjToString(tmp["company_status"])
  37. if strings.Contains(company_status, "注销") || strings.Contains(company_status, "吊销") {
  38. continue
  39. }
  40. if util.IntAll(tmp["use_flag"]) > 0 {
  41. continue
  42. }
  43. if util.ObjToString(tmp["company_type"]) == "事业单位" {
  44. continue
  45. }
  46. var ent EntInfo
  47. ent.CompanyID = util.ObjToString(tmp["company_id"])
  48. ent.CompanyName = util.ObjToString(tmp["company_name"])
  49. ent.CompanyCode = util.ObjToString(tmp["company_code"])
  50. ent.CreditNo = util.ObjToString(tmp["credit_no"])
  51. ent.OrgCode = util.ObjToString(tmp["org_code"])
  52. ent.TaxCode = util.ObjToString(tmp["tax_code"])
  53. ent.EstablishDate = util.ObjToString(tmp["establish_date"])
  54. ent.LegalPerson = util.ObjToString(tmp["legal_person"])
  55. ent.LegalPersonCaption = util.ObjToString(tmp["legal_person_caption"])
  56. ent.CompanyStatus = util.ObjToString(tmp["company_status"])
  57. ent.CompanyType = util.ObjToString(tmp["company_type"])
  58. ent.Authority = util.ObjToString(tmp["authority"])
  59. ent.IssueDate = util.ObjToString(tmp["issue_date"])
  60. ent.OperationStartDate = util.ObjToString(tmp["operation_startdate"])
  61. ent.OperationEndDate = util.ObjToString(tmp["operation_enddate"])
  62. ent.Capital = util.ObjToString(tmp["capital"])
  63. ent.CompanyAddress = util.ObjToString(tmp["company_address"])
  64. ent.BusinessScope = util.ObjToString(tmp["business_scope"])
  65. ent.ComeInTime = time.Now().Unix()
  66. ent.UpdateTime = time.Now().Unix()
  67. ent.LegalPersonType = int8(util.IntAll(tmp["legal_person_type"]))
  68. ent.RealCapital = util.ObjToString(tmp["real_capital"])
  69. ent.EnName = util.ObjToString(tmp["en_name"])
  70. ent.ListCode = util.ObjToString(tmp["list_code"])
  71. //annual_reports
  72. std := getQyxyStd(util.ObjToString(tmp["company_name"]))
  73. if std != nil && len(std) > 0 {
  74. // 取出 annual_reports 字段
  75. reports, ok := std["annual_reports"].([]interface{})
  76. if ok {
  77. var maxYear float64
  78. var employeeNo string
  79. // 遍历 annual_reports 数组
  80. for i, r := range reports {
  81. if reportMap, ok := r.(map[string]interface{}); ok {
  82. year := util.Float64All(reportMap["report_year"])
  83. emp := util.ObjToString(reportMap["employee_no"])
  84. if i == 0 || year > maxYear {
  85. maxYear = year
  86. employeeNo = emp
  87. }
  88. }
  89. }
  90. if maxYear > 0 {
  91. ent.EmployeeNo = util.IntAll(employeeNo)
  92. }
  93. }
  94. }
  95. //
  96. ent.Website = util.ObjToString(tmp["website_url"])
  97. ent.CompanyPhone = util.ObjToString(tmp["company_phone"])
  98. ent.CompanyEmail = util.ObjToString(tmp["company_email"])
  99. //company_industry_tags
  100. whereIndustry := map[string]interface{}{
  101. "company_id": util.ObjToString(tmp["company_id"]),
  102. }
  103. indus, _ := MgoQY.FindOne("company_industry", whereIndustry)
  104. ent.CompanyIndustryTags = "{}" // 先给个默认值
  105. if indus != nil && len(*indus) > 0 {
  106. name_path := make([]string, 0)
  107. name_code := make([]string, 0)
  108. name_path = append(name_path, util.ObjToString((*indus)["industry_l1_name"]))
  109. name_path = append(name_path, util.ObjToString((*indus)["industry_l2_name"]))
  110. name_path = append(name_path, util.ObjToString((*indus)["industry_l3_name"]))
  111. name_path = append(name_path, util.ObjToString((*indus)["industry_l4_name"]))
  112. //
  113. name_code = append(name_code, util.ObjToString((*indus)["industry_l1_code"]))
  114. name_code = append(name_code, util.ObjToString((*indus)["industry_l2_code"]))
  115. name_code = append(name_code, util.ObjToString((*indus)["industry_l3_code"]))
  116. name_code = append(name_code, util.ObjToString((*indus)["industry_l4_code"]))
  117. industry := map[string]interface{}{
  118. "name_path": name_path,
  119. "code_path": name_code,
  120. }
  121. // map 转 JSON
  122. jsonBytes, _ := json.Marshal(industry)
  123. ent.CompanyIndustryTags = string(jsonBytes)
  124. }
  125. //
  126. area, city, district := util.ObjToString((std)["company_area"]), util.ObjToString((std)["company_city"]), util.ObjToString((std)["company_district"])
  127. area_code, city_code, district_code := CalculateRegionCode(area, city, district)
  128. ent.JYAreaCode = area_code
  129. ent.JYCityCode = city_code
  130. ent.JYDistrictCode = district_code
  131. //
  132. query := `
  133. SELECT bitmapToArray(company_label)
  134. FROM ent_info
  135. WHERE company_id = ?
  136. `
  137. var oldLabels = make([]uint64, 0)
  138. row := ClickHouseConn.QueryRow(context.Background(), query, ent.CompanyID)
  139. err := row.Scan(&oldLabels)
  140. if err != nil {
  141. if errors.Is(err, sql.ErrNoRows) {
  142. //jlog.Info("dealIncEntInfo: 没查到数据", zap.String("company_id", ent.CompanyID))
  143. } else {
  144. jlog.Info("dealIncEntInfo: 查询出错", zap.Error(err))
  145. }
  146. }
  147. // 转 RoaringBitmap
  148. rbm := roaring.NewBitmap()
  149. for _, v := range oldLabels {
  150. rbm.Add(uint32(v))
  151. }
  152. bin, _ := rbm.ToBytes()
  153. ent.JYCompanyLabel = bin
  154. ent.JYOrgTopType = "企业"
  155. company_type := util.ObjToString(tmp["company_type"])
  156. if info, ok := nameNorm[company_type]; ok {
  157. ent.JYCompanyTypeOriginCode = info.Code
  158. ent.JYCompanyTypeIsLeaf = 1
  159. ent.JYCompanyTypeLeafCode = info.Code
  160. ent.JYCompanyTypeLeafName = info.Name
  161. ent.JYCompanyTypeLeafTag = info.Tag
  162. ent.JYOrgPropertyOneTag = "工商"
  163. ent.JYOrgPropertyTwoTag = "企业"
  164. }
  165. //保存tidb
  166. //if err := MysqlDB.Create(&ent).Error; err != nil {
  167. // jlog.Info("insert failed: %v", zap.Error(err))
  168. //}
  169. ents = append(ents, ent)
  170. if len(ents) >= batchSize {
  171. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  172. jlog.Error("批量插入失败", zap.Error(err))
  173. }
  174. ents = ents[:0] // 清空 slice
  175. }
  176. }
  177. // 循环结束后如果还有数据
  178. if len(ents) > 0 {
  179. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  180. jlog.Error("批量插入失败", zap.Error(err))
  181. }
  182. }
  183. }
  184. // dealAllFromCompanyBase2 多协程批量数据
  185. func dealAllFromCompanyBase2() {
  186. jlog.Info("dealAllFromCompanyBase", zap.String("开始处理", "-------企业库存量数据"))
  187. defer util.Catch()
  188. sess := MgoQY.GetMgoConn()
  189. defer MgoQY.DestoryMongoConn(sess)
  190. where := map[string]interface{}{
  191. "company_type": map[string]interface{}{
  192. "$ne": "个体工商户",
  193. },
  194. "_id": map[string]interface{}{
  195. //"$lte": 964729447,
  196. "$gt": GF.Env.Startid,
  197. "$lte": GF.Env.Endid,
  198. },
  199. }
  200. if GF.Env.Startid >= GF.Env.Endid || GF.Env.Endid <= 0 {
  201. jlog.Error("dealAllFromCompanyBase2", zap.Any("where", where), zap.Any("查询条件错误", "开始结束ID错误"))
  202. }
  203. // channel 作为队列
  204. jobCh := make(chan map[string]interface{}, 1000) // 缓冲队列
  205. entCh := make(chan EntInfo, 1000) // 结果队列
  206. ctx, cancel := context.WithCancel(context.Background())
  207. defer cancel()
  208. jlog.Error("dealAllFromCompanyBase2", zap.Any("where", where))
  209. // 启动 worker 处理数据
  210. workerNum := 10 // 并发度可调
  211. var wg sync.WaitGroup
  212. for i := 0; i < workerNum; i++ {
  213. wg.Add(1)
  214. go func() {
  215. defer wg.Done()
  216. for tmp := range jobCh {
  217. ent, ok := processCompany(tmp)
  218. if ok {
  219. entCh <- ent
  220. }
  221. }
  222. }()
  223. }
  224. // 启动一个写入 goroutine,专门负责批量写 DB
  225. go func() {
  226. batchSize := 100
  227. ents := make([]EntInfo, 0, batchSize)
  228. for ent := range entCh {
  229. ents = append(ents, ent)
  230. if len(ents) >= batchSize {
  231. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  232. jlog.Error("批量插入失败", zap.Error(err))
  233. }
  234. ents = ents[:0]
  235. }
  236. }
  237. // flush
  238. if len(ents) > 0 {
  239. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  240. jlog.Error("批量插入失败", zap.Error(err))
  241. }
  242. }
  243. }()
  244. // 主协程负责读 Mongo
  245. it := sess.DB(GF.MongoQy.DB).C("company_base").Find(where).Sort("_id").Iter()
  246. count := 0
  247. for tmp := make(map[string]interface{}); it.Next(&tmp); count++ {
  248. if count%1000 == 0 {
  249. jlog.Info("dealAllFromCompanyBase", zap.Any("current:", count), zap.Any("company_name", tmp["company_name"]), zap.Any("id", tmp["_id"]))
  250. }
  251. select {
  252. case jobCh <- tmp:
  253. case <-ctx.Done():
  254. break
  255. }
  256. }
  257. close(jobCh) // 生产完毕
  258. wg.Wait() // 等所有 worker 结束
  259. close(entCh) // 再关掉结果通道,通知写入 goroutine flush 完成
  260. }
  261. // processCompany 处理单条 company_base 数据,生成 EntInfo
  262. func processCompany(tmp map[string]interface{}) (EntInfo, bool) {
  263. // 过滤条件
  264. company_status := util.ObjToString(tmp["company_status"])
  265. if strings.Contains(company_status, "注销") || strings.Contains(company_status, "吊销") {
  266. return EntInfo{}, false
  267. }
  268. if util.IntAll(tmp["use_flag"]) > 0 {
  269. return EntInfo{}, false
  270. }
  271. if util.ObjToString(tmp["company_type"]) == "事业单位" {
  272. return EntInfo{}, false
  273. }
  274. var ent EntInfo
  275. ent.CompanyID = util.ObjToString(tmp["company_id"])
  276. ent.CompanyName = util.ObjToString(tmp["company_name"])
  277. ent.CompanyCode = util.ObjToString(tmp["company_code"])
  278. ent.CreditNo = util.ObjToString(tmp["credit_no"])
  279. ent.OrgCode = util.ObjToString(tmp["org_code"])
  280. ent.TaxCode = util.ObjToString(tmp["tax_code"])
  281. ent.EstablishDate = util.ObjToString(tmp["establish_date"])
  282. ent.LegalPerson = util.ObjToString(tmp["legal_person"])
  283. ent.LegalPersonCaption = util.ObjToString(tmp["legal_person_caption"])
  284. ent.CompanyStatus = company_status
  285. ent.CompanyType = util.ObjToString(tmp["company_type"])
  286. ent.Authority = util.ObjToString(tmp["authority"])
  287. ent.IssueDate = util.ObjToString(tmp["issue_date"])
  288. ent.OperationStartDate = util.ObjToString(tmp["operation_startdate"])
  289. ent.OperationEndDate = util.ObjToString(tmp["operation_enddate"])
  290. ent.Capital = util.ObjToString(tmp["capital"])
  291. ent.CompanyAddress = util.ObjToString(tmp["company_address"])
  292. ent.BusinessScope = util.ObjToString(tmp["business_scope"])
  293. ent.ComeInTime = time.Now().Unix()
  294. ent.UpdateTime = time.Now().Unix()
  295. ent.LegalPersonType = int8(util.IntAll(tmp["legal_person_type"]))
  296. ent.RealCapital = util.ObjToString(tmp["real_capital"])
  297. ent.EnName = util.ObjToString(tmp["en_name"])
  298. ent.ListCode = util.ObjToString(tmp["list_code"])
  299. // annual_reports
  300. std := getQyxyStd(util.ObjToString(tmp["company_name"]))
  301. if std != nil && len(std) > 0 {
  302. reports, ok := std["annual_reports"].([]interface{})
  303. if ok {
  304. var maxYear float64
  305. var employeeNo string
  306. for i, r := range reports {
  307. if reportMap, ok := r.(map[string]interface{}); ok {
  308. year := util.Float64All(reportMap["report_year"])
  309. emp := util.ObjToString(reportMap["employee_no"])
  310. if i == 0 || year > maxYear {
  311. maxYear = year
  312. employeeNo = emp
  313. }
  314. }
  315. }
  316. if maxYear > 0 {
  317. ent.EmployeeNo = util.IntAll(employeeNo)
  318. }
  319. }
  320. }
  321. ent.Website = util.ObjToString(tmp["website_url"])
  322. ent.CompanyPhone = util.ObjToString(tmp["company_phone"])
  323. ent.CompanyEmail = util.ObjToString(tmp["company_email"])
  324. // company_industry_tags
  325. whereIndustry := map[string]interface{}{
  326. "company_id": util.ObjToString(tmp["company_id"]),
  327. }
  328. indus, _ := MgoQY.FindOne("company_industry", whereIndustry)
  329. ent.CompanyIndustryTags = "{}"
  330. if indus != nil && len(*indus) > 0 {
  331. name_path := []string{
  332. util.ObjToString((*indus)["industry_l1_name"]),
  333. util.ObjToString((*indus)["industry_l2_name"]),
  334. util.ObjToString((*indus)["industry_l3_name"]),
  335. util.ObjToString((*indus)["industry_l4_name"]),
  336. }
  337. name_code := []string{
  338. util.ObjToString((*indus)["industry_l1_code"]),
  339. util.ObjToString((*indus)["industry_l2_code"]),
  340. util.ObjToString((*indus)["industry_l3_code"]),
  341. util.ObjToString((*indus)["industry_l4_code"]),
  342. }
  343. industry := map[string]interface{}{
  344. "name_path": name_path,
  345. "code_path": name_code,
  346. }
  347. jsonBytes, _ := json.Marshal(industry)
  348. ent.CompanyIndustryTags = string(jsonBytes)
  349. }
  350. // 行政区划编码
  351. area, city, district := util.ObjToString((std)["company_area"]), util.ObjToString((std)["company_city"]), util.ObjToString((std)["company_district"])
  352. area_code, city_code, district_code := CalculateRegionCode(area, city, district)
  353. ent.JYAreaCode = area_code
  354. ent.JYCityCode = city_code
  355. ent.JYDistrictCode = district_code
  356. // ClickHouse 查询历史标签
  357. query := `SELECT bitmapToArray(company_label) FROM ent_info WHERE company_id = ?`
  358. var oldLabels = make([]uint64, 0)
  359. row := ClickHouseConn.QueryRow(context.Background(), query, ent.CompanyID)
  360. err := row.Scan(&oldLabels)
  361. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  362. jlog.Info("dealIncEntInfo: 查询出错", zap.Error(err))
  363. }
  364. rbm := roaring.NewBitmap()
  365. for _, v := range oldLabels {
  366. rbm.Add(uint32(v))
  367. }
  368. bin, _ := rbm.ToBytes()
  369. ent.JYCompanyLabel = bin
  370. ent.JYOrgTopType = "企业"
  371. company_type := util.ObjToString(tmp["company_type"])
  372. if info, ok := nameNorm[company_type]; ok {
  373. ent.JYCompanyTypeOriginCode = info.Code
  374. ent.JYCompanyTypeIsLeaf = 1
  375. ent.JYCompanyTypeLeafCode = info.Code
  376. ent.JYCompanyTypeLeafName = info.Name
  377. ent.JYCompanyTypeLeafTag = info.Tag
  378. ent.JYOrgPropertyOneTag = "工商"
  379. ent.JYOrgPropertyTwoTag = "企业"
  380. ent.JYOrgPropertyThreeTag = info.Tag2
  381. }
  382. return ent, true
  383. }
  384. // dealLeaf 处理存量非叶子节点的企业数据标签
  385. func dealLeaf() {
  386. const batchSize = 50
  387. lastID := uint64(0)
  388. for {
  389. var companies []EntInfo
  390. // 分批查询
  391. if err := MysqlDB.Model(&EntInfo{}).
  392. Select("id, company_name, credit_no, company_type, jy_company_type_is_leaf").
  393. Where("jy_company_type_is_leaf = ?", 0).
  394. Order("id ASC").
  395. Limit(batchSize).
  396. Find(&companies).Error; err != nil {
  397. panic(err)
  398. }
  399. if len(companies) == 0 {
  400. fmt.Println("处理完成 ✅")
  401. break
  402. }
  403. if lastID%1000 == 0 {
  404. jlog.Info("dealLeaf", zap.Any("lastID", lastID), zap.Any("id", companies[0].ID))
  405. }
  406. // 只存储有变化的公司
  407. updates := make(map[uint64]map[string]interface{})
  408. for i := range companies {
  409. if companies[i].JYCompanyTypeIsLeaf == 1 {
  410. continue
  411. }
  412. company_name := util.ObjToString(companies[i].CompanyName)
  413. top_names := getTopNames(company_name)
  414. for _, top_name := range top_names {
  415. topwhere := map[string]interface{}{
  416. "use_flag": 0,
  417. "company_status": map[string]interface{}{
  418. "$nin": []string{"注销", "吊销", "吊销,已注销"},
  419. },
  420. "company_name": top_name,
  421. }
  422. top_bases, _ := MgoQY.FindOne("company_base", topwhere)
  423. if top_bases != nil && len(*top_bases) > 0 {
  424. //获取上级企业类型
  425. top_company_type := util.ObjToString((*top_bases)["company_type"])
  426. if norm_info, ok := nameNorm[top_company_type]; ok {
  427. // 这里判断:如果已有字段不一样,才算变更
  428. if companies[i].JYCompanyTypeLeafCode != norm_info.Code ||
  429. companies[i].JYCompanyTypeLeafName != norm_info.Name ||
  430. companies[i].JYCompanyTypeLeafTag != norm_info.Tag ||
  431. companies[i].JYOrgPropertyThreeTag != norm_info.Tag2 {
  432. updates[companies[i].ID] = map[string]interface{}{
  433. "jy_company_type_leaf_code": norm_info.Code,
  434. "jy_company_type_leaf_name": norm_info.Name,
  435. "jy_company_type_leaf_tag": norm_info.Tag,
  436. "jy_org_property_three_tag": norm_info.Tag2,
  437. }
  438. }
  439. break
  440. }
  441. } else {
  442. // 去其他库查
  443. where2 := map[string]interface{}{"company_name": top_name}
  444. enterprise, _ := MgoQY.FindOne("special_enterprise", where2)
  445. if enterprise != nil && len(*enterprise) > 0 {
  446. if companies[i].JYOrgPropertyThreeTag != "国企" {
  447. updates[companies[i].ID] = map[string]interface{}{
  448. "jy_org_property_three_tag": "国企",
  449. }
  450. }
  451. break
  452. } else {
  453. gov, _ := MgoQY.FindOne("special_gov_unit", where2)
  454. if gov != nil && len(*gov) > 0 {
  455. if companies[i].JYOrgPropertyThreeTag != "国企" {
  456. updates[companies[i].ID] = map[string]interface{}{
  457. "jy_org_property_three_tag": "国企",
  458. }
  459. }
  460. break
  461. }
  462. }
  463. }
  464. }
  465. }
  466. // 批量更新 (只更新有变化的)
  467. if len(updates) > 0 {
  468. if err := batchUpdateFields(MysqlDB, (EntInfo{}).TableName(), updates); err != nil {
  469. panic(err)
  470. }
  471. }
  472. // 更新游标
  473. lastID = companies[len(companies)-1].ID
  474. }
  475. }
  476. // 允许更新的字段白名单(非常重要,防注入)
  477. var allowedColumns = map[string]struct{}{
  478. "jy_company_type_leaf_code": {},
  479. "jy_company_type_leaf_name": {},
  480. "jy_company_type_leaf_tag": {},
  481. "jy_org_property_three_tag": {},
  482. // 需要的话继续补充其它允许批量更新的字段
  483. }
  484. // batchUpdateFields 批量更新
  485. func batchUpdateFields(db *gorm.DB, tableName string, updates map[uint64]map[string]interface{}) error {
  486. if len(updates) == 0 {
  487. return nil
  488. }
  489. // 1) 收集字段(按白名单过滤)
  490. fieldSet := make(map[string]struct{})
  491. for _, m := range updates {
  492. for col := range m {
  493. if _, ok := allowedColumns[col]; ok {
  494. fieldSet[col] = struct{}{}
  495. }
  496. }
  497. }
  498. if len(fieldSet) == 0 {
  499. return nil
  500. }
  501. fields := make([]string, 0, len(fieldSet))
  502. for col := range fieldSet {
  503. fields = append(fields, col)
  504. }
  505. // 2) 构造 CASE 语句和参数
  506. cases := make([]string, 0, len(fields))
  507. args := make([]interface{}, 0, len(updates)*len(fields)*2)
  508. idSet := make(map[uint64]struct{}, len(updates))
  509. for _, field := range fields {
  510. var sb strings.Builder
  511. sb.WriteString(field)
  512. sb.WriteString(" = CASE id ")
  513. hasWhen := false
  514. for id, m := range updates {
  515. if val, ok := m[field]; ok {
  516. sb.WriteString("WHEN ? THEN ? ")
  517. args = append(args, id, val)
  518. idSet[id] = struct{}{}
  519. hasWhen = true
  520. }
  521. }
  522. // 如果这个字段对所有 id 都没有需要更新的值,就跳过它
  523. if !hasWhen {
  524. continue
  525. }
  526. sb.WriteString("ELSE ")
  527. sb.WriteString(field)
  528. sb.WriteString(" END")
  529. cases = append(cases, sb.String())
  530. }
  531. // 如果所有字段都被跳过了(比如全被白名单过滤),直接返回
  532. if len(cases) == 0 {
  533. return nil
  534. }
  535. // 3) WHERE IN 使用占位符
  536. ids := make([]uint64, 0, len(idSet))
  537. for id := range idSet {
  538. ids = append(ids, id)
  539. }
  540. placeholders := make([]string, 0, len(ids))
  541. for range ids {
  542. placeholders = append(placeholders, "?")
  543. }
  544. for _, id := range ids {
  545. args = append(args, id)
  546. }
  547. // 4) 组装最终 SQL
  548. sql := fmt.Sprintf(
  549. "UPDATE %s SET %s WHERE id IN (%s)",
  550. tableName,
  551. strings.Join(cases, ", "),
  552. strings.Join(placeholders, ","),
  553. )
  554. // 5) 建议放在事务里执行
  555. return db.Transaction(func(tx *gorm.DB) error {
  556. return tx.Exec(sql, args...).Error
  557. })
  558. }
  559. // get 通过companyID 获取法人库数据
  560. func get() {
  561. // 2. 查询一条数据
  562. var ent EntInfo
  563. if err := MysqlDB.Where("company_id = ?", "001c2e9882ae982abf6e1e9ed06e2654").First(&ent).Error; err != nil {
  564. panic(err)
  565. }
  566. // 3. 反序列化 RoaringBitmap
  567. rbm := roaring.NewBitmap()
  568. if len(ent.JYCompanyLabel) > 0 {
  569. if err := rbm.UnmarshalBinary(ent.JYCompanyLabel); err != nil {
  570. panic(err)
  571. }
  572. }
  573. // 4. 转成 []uint64
  574. ids := make([]uint64, 0, rbm.GetCardinality())
  575. it := rbm.Iterator()
  576. for it.HasNext() {
  577. ids = append(ids, uint64(it.Next()))
  578. }
  579. fmt.Println("CompanyID:", ent.CompanyID)
  580. fmt.Println("标签ID集合:", ids)
  581. }