all.go 12 KB


  1. package main
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "strings"
  9. "time"
  10. "github.com/RoaringBitmap/roaring"
  11. "go.uber.org/zap"
  12. "gorm.io/gorm"
  13. util "jygit.jydev.jianyu360.cn/data_processing/common_utils"
  14. jlog "jygit.jydev.jianyu360.cn/data_processing/common_utils/log"
  15. )
  16. // dealAllFromCompanyBase 从company_base 处理惬意数据存量
  17. func dealAllFromCompanyBase() {
  18. jlog.Info("dealAllFromCompanyBase", zap.String("开始处理", "-------企业库存量数据"))
  19. defer util.Catch()
  20. sess := MgoQY.GetMgoConn()
  21. defer MgoQY.DestoryMongoConn(sess)
  22. where := map[string]interface{}{
  23. "company_type": map[string]interface{}{
  24. "$ne": "个体工商户",
  25. },
  26. }
  27. count := 0
  28. batchSize := 100
  29. ents := make([]EntInfo, 0, batchSize)
  30. it := sess.DB(GF.MongoQy.DB).C("company_base").Find(where).Select(nil).Iter()
  31. for tmp := make(map[string]interface{}); it.Next(&tmp); count++ {
  32. if count%1000 == 0 {
  33. jlog.Info("dealAllFromCompanyBase", zap.Any("current:", count), zap.Any("company_name", tmp["company_name"]))
  34. }
  35. company_status := util.ObjToString(tmp["company_status"])
  36. if strings.Contains(company_status, "注销") || strings.Contains(company_status, "吊销") {
  37. continue
  38. }
  39. if util.IntAll(tmp["use_flag"]) > 0 {
  40. continue
  41. }
  42. var ent EntInfo
  43. ent.CompanyID = util.ObjToString(tmp["company_id"])
  44. ent.CompanyName = util.ObjToString(tmp["company_name"])
  45. ent.CompanyCode = util.ObjToString(tmp["company_code"])
  46. ent.CreditNo = util.ObjToString(tmp["credit_no"])
  47. ent.OrgCode = util.ObjToString(tmp["org_code"])
  48. ent.TaxCode = util.ObjToString(tmp["tax_code"])
  49. ent.EstablishDate = util.ObjToString(tmp["establish_date"])
  50. ent.LegalPerson = util.ObjToString(tmp["legal_person"])
  51. ent.LegalPersonCaption = util.ObjToString(tmp["legal_person_caption"])
  52. ent.CompanyStatus = util.ObjToString(tmp["company_status"])
  53. ent.CompanyType = util.ObjToString(tmp["company_type"])
  54. ent.Authority = util.ObjToString(tmp["authority"])
  55. ent.IssueDate = util.ObjToString(tmp["issue_date"])
  56. ent.OperationStartDate = util.ObjToString(tmp["operation_startdate"])
  57. ent.OperationEndDate = util.ObjToString(tmp["operation_enddate"])
  58. ent.Capital = util.ObjToString(tmp["capital"])
  59. ent.CompanyAddress = util.ObjToString(tmp["company_address"])
  60. ent.BusinessScope = util.ObjToString(tmp["business_scope"])
  61. ent.ComeInTime = time.Now().Unix()
  62. ent.UpdateTime = time.Now().Unix()
  63. ent.LegalPersonType = int8(util.IntAll(tmp["legal_person_type"]))
  64. ent.RealCapital = util.ObjToString(tmp["real_capital"])
  65. ent.EnName = util.ObjToString(tmp["en_name"])
  66. ent.ListCode = util.ObjToString(tmp["list_code"])
  67. //annual_reports
  68. std := getQyxyStd(util.ObjToString(tmp["company_name"]))
  69. if std != nil && len(std) > 0 {
  70. // 取出 annual_reports 字段
  71. reports, ok := std["annual_reports"].([]interface{})
  72. if ok {
  73. var maxYear float64
  74. var employeeNo string
  75. // 遍历 annual_reports 数组
  76. for i, r := range reports {
  77. if reportMap, ok := r.(map[string]interface{}); ok {
  78. year := util.Float64All(reportMap["report_year"])
  79. emp := util.ObjToString(reportMap["employee_no"])
  80. if i == 0 || year > maxYear {
  81. maxYear = year
  82. employeeNo = emp
  83. }
  84. }
  85. }
  86. if maxYear > 0 {
  87. ent.EmployeeNo = util.IntAll(employeeNo)
  88. }
  89. }
  90. }
  91. //
  92. ent.Website = util.ObjToString(tmp["website_url"])
  93. ent.CompanyPhone = util.ObjToString(tmp["company_phone"])
  94. ent.CompanyEmail = util.ObjToString(tmp["company_email"])
  95. //company_industry_tags
  96. whereIndustry := map[string]interface{}{
  97. "company_id": util.ObjToString(tmp["company_id"]),
  98. }
  99. indus, _ := MgoQY.FindOne("company_industry", whereIndustry)
  100. ent.CompanyIndustryTags = "{}" // 先给个默认值
  101. if indus != nil && len(*indus) > 0 {
  102. name_path := make([]string, 0)
  103. name_code := make([]string, 0)
  104. name_path = append(name_path, util.ObjToString((*indus)["industry_l1_name"]))
  105. name_path = append(name_path, util.ObjToString((*indus)["industry_l2_name"]))
  106. name_path = append(name_path, util.ObjToString((*indus)["industry_l3_name"]))
  107. name_path = append(name_path, util.ObjToString((*indus)["industry_l4_name"]))
  108. //
  109. name_code = append(name_code, util.ObjToString((*indus)["industry_l1_code"]))
  110. name_code = append(name_code, util.ObjToString((*indus)["industry_l2_code"]))
  111. name_code = append(name_code, util.ObjToString((*indus)["industry_l3_code"]))
  112. name_code = append(name_code, util.ObjToString((*indus)["industry_l4_code"]))
  113. industry := map[string]interface{}{
  114. "name_path": name_path,
  115. "code_path": name_code,
  116. }
  117. // map 转 JSON
  118. jsonBytes, _ := json.Marshal(industry)
  119. ent.CompanyIndustryTags = string(jsonBytes)
  120. }
  121. //
  122. area, city, district := util.ObjToString((std)["company_area"]), util.ObjToString((std)["company_city"]), util.ObjToString((std)["company_district"])
  123. area_code, city_code, district_code := CalculateRegionCode(area, city, district)
  124. ent.JYAreaCode = area_code
  125. ent.JYCityCode = city_code
  126. ent.JYDistrictCode = district_code
  127. //
  128. query := `
  129. SELECT bitmapToArray(company_label)
  130. FROM ent_info
  131. WHERE company_id = ?
  132. `
  133. var oldLabels = make([]uint64, 0)
  134. row := ClickHouseConn.QueryRow(context.Background(), query, ent.CompanyID)
  135. err := row.Scan(&oldLabels)
  136. if err != nil {
  137. if errors.Is(err, sql.ErrNoRows) {
  138. //jlog.Info("dealIncEntInfo: 没查到数据", zap.String("company_id", ent.CompanyID))
  139. } else {
  140. jlog.Info("dealIncEntInfo: 查询出错", zap.Error(err))
  141. }
  142. }
  143. // 转 RoaringBitmap
  144. rbm := roaring.NewBitmap()
  145. for _, v := range oldLabels {
  146. rbm.Add(uint32(v))
  147. }
  148. bin, _ := rbm.ToBytes()
  149. ent.JYCompanyLabel = bin
  150. ent.JYOrgTopType = "企业"
  151. company_type := util.ObjToString(tmp["company_type"])
  152. if info, ok := nameNorm[company_type]; ok {
  153. ent.JYCompanyTypeOriginCode = info.Code
  154. ent.JYCompanyTypeIsLeaf = 1
  155. ent.JYCompanyTypeLeafCode = info.Code
  156. ent.JYCompanyTypeLeafName = info.Name
  157. ent.JYCompanyTypeLeafTag = info.Tag
  158. ent.JYOrgPropertyOneTag = "工商"
  159. ent.JYOrgPropertyTwoTag = "企业"
  160. }
  161. //保存tidb
  162. //if err := MysqlDB.Create(&ent).Error; err != nil {
  163. // jlog.Info("insert failed: %v", zap.Error(err))
  164. //}
  165. ents = append(ents, ent)
  166. if len(ents) >= batchSize {
  167. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  168. jlog.Error("批量插入失败", zap.Error(err))
  169. }
  170. ents = ents[:0] // 清空 slice
  171. }
  172. }
  173. // 循环结束后如果还有数据
  174. if len(ents) > 0 {
  175. if err := MysqlDB.CreateInBatches(ents, batchSize).Error; err != nil {
  176. jlog.Error("批量插入失败", zap.Error(err))
  177. }
  178. }
  179. }
  180. // dealLeaf 处理存量非叶子节点的企业数据标签
  181. func dealLeaf() {
  182. const batchSize = 50
  183. lastID := uint64(0)
  184. for {
  185. var companies []EntInfo
  186. // 分批查询
  187. if err := MysqlDB.Model(&EntInfo{}).
  188. Select("id, company_name, credit_no, company_type, jy_company_type_is_leaf").
  189. Where("jy_company_type_is_leaf = ?", 0).
  190. Order("id ASC").
  191. Limit(batchSize).
  192. Find(&companies).Error; err != nil {
  193. panic(err)
  194. }
  195. if len(companies) == 0 {
  196. fmt.Println("处理完成 ✅")
  197. break
  198. }
  199. if lastID%1000 == 0 {
  200. jlog.Info("dealLeaf", zap.Any("lastID", lastID), zap.Any("id", companies[0].ID))
  201. }
  202. // 只存储有变化的公司
  203. updates := make(map[uint64]map[string]interface{})
  204. for i := range companies {
  205. if companies[i].JYCompanyTypeIsLeaf == 1 {
  206. continue
  207. }
  208. company_name := util.ObjToString(companies[i].CompanyName)
  209. top_names := getTopNames(company_name)
  210. for _, top_name := range top_names {
  211. topwhere := map[string]interface{}{
  212. "use_flag": 0,
  213. "company_status": map[string]interface{}{
  214. "$nin": []string{"注销", "吊销", "吊销,已注销"},
  215. },
  216. "company_name": top_name,
  217. }
  218. top_bases, _ := MgoQY.FindOne("company_base", topwhere)
  219. if top_bases != nil && len(*top_bases) > 0 {
  220. //获取上级企业类型
  221. top_company_type := util.ObjToString((*top_bases)["company_type"])
  222. if norm_info, ok := nameNorm[top_company_type]; ok {
  223. // 这里判断:如果已有字段不一样,才算变更
  224. if companies[i].JYCompanyTypeLeafCode != norm_info.Code ||
  225. companies[i].JYCompanyTypeLeafName != norm_info.Name ||
  226. companies[i].JYCompanyTypeLeafTag != norm_info.Tag ||
  227. companies[i].JYOrgPropertyThreeTag != norm_info.Tag2 {
  228. updates[companies[i].ID] = map[string]interface{}{
  229. "jy_company_type_leaf_code": norm_info.Code,
  230. "jy_company_type_leaf_name": norm_info.Name,
  231. "jy_company_type_leaf_tag": norm_info.Tag,
  232. "jy_org_property_three_tag": norm_info.Tag2,
  233. }
  234. }
  235. break
  236. }
  237. } else {
  238. // 去其他库查
  239. where2 := map[string]interface{}{"company_name": top_name}
  240. enterprise, _ := MgoQY.FindOne("special_enterprise", where2)
  241. if enterprise != nil && len(*enterprise) > 0 {
  242. if companies[i].JYOrgPropertyThreeTag != "国企" {
  243. updates[companies[i].ID] = map[string]interface{}{
  244. "jy_org_property_three_tag": "国企",
  245. }
  246. }
  247. break
  248. } else {
  249. gov, _ := MgoQY.FindOne("special_gov_unit", where2)
  250. if gov != nil && len(*gov) > 0 {
  251. if companies[i].JYOrgPropertyThreeTag != "国企" {
  252. updates[companies[i].ID] = map[string]interface{}{
  253. "jy_org_property_three_tag": "国企",
  254. }
  255. }
  256. break
  257. }
  258. }
  259. }
  260. }
  261. }
  262. // 批量更新 (只更新有变化的)
  263. if len(updates) > 0 {
  264. if err := batchUpdateFields(MysqlDB, (EntInfo{}).TableName(), updates); err != nil {
  265. panic(err)
  266. }
  267. }
  268. // 更新游标
  269. lastID = companies[len(companies)-1].ID
  270. }
  271. }
  // allowedColumns is the whitelist of column names that may appear in a batch
// update. Column names are written into the SQL text verbatim, so anything
// not listed here is ignored by batchUpdateFields to prevent injection.
// Add further columns to the list below when they need to be batch-updatable.
var allowedColumns = func() map[string]struct{} {
	names := []string{
		"jy_company_type_leaf_code",
		"jy_company_type_leaf_name",
		"jy_company_type_leaf_tag",
		"jy_org_property_three_tag",
	}
	set := make(map[string]struct{}, len(names))
	for _, name := range names {
		set[name] = struct{}{}
	}
	return set
}()
  280. // batchUpdateFields 批量更新
  281. func batchUpdateFields(db *gorm.DB, tableName string, updates map[uint64]map[string]interface{}) error {
  282. if len(updates) == 0 {
  283. return nil
  284. }
  285. // 1) 收集字段(按白名单过滤)
  286. fieldSet := make(map[string]struct{})
  287. for _, m := range updates {
  288. for col := range m {
  289. if _, ok := allowedColumns[col]; ok {
  290. fieldSet[col] = struct{}{}
  291. }
  292. }
  293. }
  294. if len(fieldSet) == 0 {
  295. return nil
  296. }
  297. fields := make([]string, 0, len(fieldSet))
  298. for col := range fieldSet {
  299. fields = append(fields, col)
  300. }
  301. // 2) 构造 CASE 语句和参数
  302. cases := make([]string, 0, len(fields))
  303. args := make([]interface{}, 0, len(updates)*len(fields)*2)
  304. idSet := make(map[uint64]struct{}, len(updates))
  305. for _, field := range fields {
  306. var sb strings.Builder
  307. sb.WriteString(field)
  308. sb.WriteString(" = CASE id ")
  309. hasWhen := false
  310. for id, m := range updates {
  311. if val, ok := m[field]; ok {
  312. sb.WriteString("WHEN ? THEN ? ")
  313. args = append(args, id, val)
  314. idSet[id] = struct{}{}
  315. hasWhen = true
  316. }
  317. }
  318. // 如果这个字段对所有 id 都没有需要更新的值,就跳过它
  319. if !hasWhen {
  320. continue
  321. }
  322. sb.WriteString("ELSE ")
  323. sb.WriteString(field)
  324. sb.WriteString(" END")
  325. cases = append(cases, sb.String())
  326. }
  327. // 如果所有字段都被跳过了(比如全被白名单过滤),直接返回
  328. if len(cases) == 0 {
  329. return nil
  330. }
  331. // 3) WHERE IN 使用占位符
  332. ids := make([]uint64, 0, len(idSet))
  333. for id := range idSet {
  334. ids = append(ids, id)
  335. }
  336. placeholders := make([]string, 0, len(ids))
  337. for range ids {
  338. placeholders = append(placeholders, "?")
  339. }
  340. for _, id := range ids {
  341. args = append(args, id)
  342. }
  343. // 4) 组装最终 SQL
  344. sql := fmt.Sprintf(
  345. "UPDATE %s SET %s WHERE id IN (%s)",
  346. tableName,
  347. strings.Join(cases, ", "),
  348. strings.Join(placeholders, ","),
  349. )
  350. // 5) 建议放在事务里执行
  351. return db.Transaction(func(tx *gorm.DB) error {
  352. return tx.Exec(sql, args...).Error
  353. })
  354. }
  355. // get 通过companyID 获取法人库数据
  356. func get() {
  357. // 2. 查询一条数据
  358. var ent EntInfo
  359. if err := MysqlDB.Where("company_id = ?", "001c2e9882ae982abf6e1e9ed06e2654").First(&ent).Error; err != nil {
  360. panic(err)
  361. }
  362. // 3. 反序列化 RoaringBitmap
  363. rbm := roaring.NewBitmap()
  364. if len(ent.JYCompanyLabel) > 0 {
  365. if err := rbm.UnmarshalBinary(ent.JYCompanyLabel); err != nil {
  366. panic(err)
  367. }
  368. }
  369. // 4. 转成 []uint64
  370. ids := make([]uint64, 0, rbm.GetCardinality())
  371. it := rbm.Iterator()
  372. for it.HasNext() {
  373. ids = append(ids, uint64(it.Next()))
  374. }
  375. fmt.Println("CompanyID:", ent.CompanyID)
  376. fmt.Println("标签ID集合:", ids)
  377. }