mysql.go 16 KB


  1. package mysql
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "log"
  7. "reflect"
  8. "strings"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. )
  // CLICKHOUSE is the driver name compared against Mysql.driverName. When it
// matches, statements bypass prepared statements and result metadata
// (LastInsertId / RowsAffected) is not read.
const CLICKHOUSE = "clickhouse"
  // Mysql wraps a *sql.DB connection pool together with the settings Init
// uses to build a MySQL DSN.
type Mysql struct {
	Address      string  // database address as host:port
	UserName     string  // login user name
	PassWord     string  // login password
	DBName       string  // database name
	DB           *sql.DB `json:",optional"` // connection pool, populated by Init / NewInit
	MaxOpenConns int     // max open connections; NewInit replaces values <= 0 with 30, so 0 does NOT mean unlimited here
	MaxIdleConns int     // max idle connections; NewInit replaces values <= 0 with 6
	driverName   string  // driver in use; compared against CLICKHOUSE to pick the execution path
}
  25. //
  26. func NewInit(driverName, dataSourceName string, maxOpenConns, maxIdleConns int) *Mysql {
  27. if maxOpenConns <= 0 {
  28. maxOpenConns = 30
  29. }
  30. if maxIdleConns <= 0 {
  31. maxIdleConns = 6
  32. }
  33. db, err := sql.Open(driverName, dataSourceName)
  34. if err != nil {
  35. log.Println("open error", err)
  36. return nil
  37. }
  38. db.SetMaxOpenConns(maxOpenConns)
  39. db.SetMaxIdleConns(maxIdleConns)
  40. db.SetConnMaxLifetime(14400 * time.Second)
  41. if err := db.Ping(); err != nil {
  42. log.Println("ping error", err)
  43. return nil
  44. }
  45. return &Mysql{DB: db, driverName: driverName}
  46. }
  47. //
  48. func (m *Mysql) Init() {
  49. m.driverName = "mysql"
  50. nm := NewInit(m.driverName, fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", m.UserName, m.PassWord, m.Address, m.DBName), m.MaxOpenConns, m.MaxIdleConns)
  51. if nm == nil {
  52. return
  53. }
  54. m.DB = nm.DB
  55. }
  56. //新增
  57. func (m *Mysql) Insert(tableName string, data map[string]interface{}) int64 {
  58. return m.InsertByTx(nil, tableName, data)
  59. }
  60. //带有事务的新增
  61. func (m *Mysql) InsertByTx(tx *sql.Tx, tableName string, data map[string]interface{}) int64 {
  62. fields := []string{}
  63. values := []interface{}{}
  64. placeholders := []string{}
  65. if tableName == "dataexport_order" {
  66. if _, ok := data["user_nickname"]; ok {
  67. data["user_nickname"] = ""
  68. }
  69. }
  70. for k, v := range data {
  71. fields = append(fields, k)
  72. values = append(values, v)
  73. placeholders = append(placeholders, "?")
  74. }
  75. q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ","), strings.Join(placeholders, ","))
  76. log.Println(q, values)
  77. return m.InsertBySqlByTx(tx, q, values...)
  78. }
  79. //sql语句新增
  80. func (m *Mysql) InsertBySql(q string, args ...interface{}) int64 {
  81. return m.InsertBySqlByTx(nil, q, args...)
  82. }
  83. //带有事务的sql语句新增
  84. func (m *Mysql) InsertBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  85. result, _ := m.ExecBySqlByTx(tx, q, args...)
  86. if result == nil {
  87. return -1
  88. }
  89. if m.driverName == CLICKHOUSE {
  90. return 0
  91. }
  92. id, err := result.LastInsertId()
  93. if err != nil {
  94. log.Println(err)
  95. return -1
  96. }
  97. return id
  98. }
  99. //批量新增
  100. func (m *Mysql) InsertIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  101. return m.InsertIgnoreBatchByTx(nil, tableName, fields, values)
  102. }
  103. //带事务的批量新增
  104. func (m *Mysql) InsertIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  105. return m.insertOrReplaceBatchByTx(tx, "INSERT", "IGNORE", tableName, fields, values)
  106. }
  107. //批量新增
  108. func (m *Mysql) InsertBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  109. return m.InsertBatchByTx(nil, tableName, fields, values)
  110. }
  111. //带事务的批量新增
  112. func (m *Mysql) InsertBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  113. return m.insertOrReplaceBatchByTx(tx, "INSERT", "", tableName, fields, values)
  114. }
  115. //批量更新
  116. func (m *Mysql) ReplaceBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  117. return m.ReplaceBatchByTx(nil, tableName, fields, values)
  118. }
  119. //带事务的批量更新
  120. func (m *Mysql) ReplaceBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  121. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "", tableName, fields, values)
  122. }
  123. func (m *Mysql) insertOrReplaceBatchByTx(tx *sql.Tx, tp string, afterInsert, tableName string, fields []string, values []interface{}) (int64, int64) {
  124. placeholders := []string{}
  125. for range fields {
  126. placeholders = append(placeholders, "?")
  127. }
  128. placeholder := strings.Join(placeholders, ",")
  129. array := []string{}
  130. for i := 0; i < len(values)/len(fields); i++ {
  131. array = append(array, fmt.Sprintf("(%s)", placeholder))
  132. }
  133. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", tp, afterInsert, tableName, strings.Join(fields, ","), strings.Join(array, ","))
  134. result, _ := m.ExecBySqlByTx(tx, q, values...)
  135. if result == nil {
  136. return -1, -1
  137. }
  138. if m.driverName == CLICKHOUSE {
  139. return 0, 0
  140. }
  141. v1, e1 := result.RowsAffected()
  142. if e1 != nil {
  143. log.Println(e1)
  144. return -1, -1
  145. }
  146. v2, e2 := result.LastInsertId()
  147. if e2 != nil {
  148. log.Println(e2)
  149. return -1, -1
  150. }
  151. return v1, v2
  152. }
  153. //sql语句执行
  154. func (m *Mysql) ExecBySql(q string, args ...interface{}) (sql.Result, error) {
  155. return m.ExecBySqlByTx(nil, q, args...)
  156. }
  157. //sql语句执行,带有事务
  158. func (m *Mysql) ExecBySqlByTx(tx *sql.Tx, q string, args ...interface{}) (sql.Result, error) {
  159. var result sql.Result
  160. var err error
  161. if m.driverName == CLICKHOUSE {
  162. result, err = m.DB.Exec(q, args...)
  163. } else {
  164. var stmtIns *sql.Stmt
  165. if tx == nil {
  166. stmtIns, err = m.DB.Prepare(q)
  167. } else {
  168. stmtIns, err = tx.Prepare(q)
  169. }
  170. if err != nil {
  171. log.Println(err)
  172. return nil, err
  173. }
  174. defer stmtIns.Close()
  175. result, err = stmtIns.Exec(args...)
  176. }
  177. if err != nil {
  178. log.Println(args, err)
  179. return nil, err
  180. }
  181. return result, nil
  182. }
  183. /*不等于 map[string]string{"ne":"1"}
  184. *不等于多个 map[string]string{"notin":[]interface{}{1,2}}
  185. *字段为空 map[string]string{"name":"$isNull"}
  186. *字段不为空 map[string]string{"name":"$isNotNull"}
  187. */
  188. func (m *Mysql) Find(tableName string, query map[string]interface{}, fields, order string, start, pageSize int) *[]map[string]interface{} {
  189. fs := []string{}
  190. vs := []interface{}{}
  191. for k, v := range query {
  192. rt := reflect.TypeOf(v)
  193. rv := reflect.ValueOf(v)
  194. if rt.Kind() == reflect.Map {
  195. for _, rv_k := range rv.MapKeys() {
  196. if rv_k.String() == "ne" {
  197. fs = append(fs, fmt.Sprintf("%s!=?", k))
  198. vs = append(vs, rv.MapIndex(rv_k).Interface())
  199. }
  200. if rv_k.String() == "notin" {
  201. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  202. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  203. fs = append(fs, fmt.Sprintf("%s!=?", k))
  204. vs = append(vs, v)
  205. }
  206. }
  207. }
  208. if rv_k.String() == "in" {
  209. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  210. _fs := fmt.Sprintf("%s in (?", k)
  211. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  212. if k > 0 {
  213. _fs += ",?"
  214. }
  215. vs = append(vs, v)
  216. }
  217. _fs += ")"
  218. fs = append(fs, _fs)
  219. }
  220. }
  221. }
  222. } else {
  223. if v == "$isNull" {
  224. fs = append(fs, fmt.Sprintf("%s is null", k))
  225. } else if v == "$isNotNull" {
  226. fs = append(fs, fmt.Sprintf("%s is not null", k))
  227. } else {
  228. fs = append(fs, fmt.Sprintf("%s=?", k))
  229. vs = append(vs, v)
  230. }
  231. }
  232. }
  233. var buffer bytes.Buffer
  234. buffer.WriteString("select ")
  235. if fields == "" {
  236. buffer.WriteString("*")
  237. } else {
  238. buffer.WriteString(fields)
  239. }
  240. buffer.WriteString(" from ")
  241. buffer.WriteString(tableName)
  242. if len(fs) > 0 {
  243. buffer.WriteString(" where ")
  244. buffer.WriteString(strings.Join(fs, " and "))
  245. }
  246. if order != "" {
  247. buffer.WriteString(" order by ")
  248. buffer.WriteString(order)
  249. }
  250. if start > -1 && pageSize > 0 {
  251. buffer.WriteString(" limit ")
  252. buffer.WriteString(fmt.Sprint(start))
  253. buffer.WriteString(",")
  254. buffer.WriteString(fmt.Sprint(pageSize))
  255. }
  256. q := buffer.String()
  257. log.Println(q, vs)
  258. return m.SelectBySql(q, vs...)
  259. }
  260. //sql语句查询
  261. func (m *Mysql) SelectBySql(q string, args ...interface{}) *[]map[string]interface{} {
  262. return m.SelectBySqlByTx(nil, q, args...)
  263. }
  264. func (m *Mysql) SelectBySqlByTx(tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  265. return m.Select(0, nil, tx, q, args...)
  266. }
  267. func (m *Mysql) Select(bath int, f func(l *[]map[string]interface{}) bool, tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  268. var rows *sql.Rows
  269. var err error
  270. if m.driverName == CLICKHOUSE {
  271. rows, err = m.DB.Query(q, args...)
  272. } else {
  273. var stmtOut *sql.Stmt
  274. if tx == nil {
  275. stmtOut, err = m.DB.Prepare(q)
  276. } else {
  277. stmtOut, err = tx.Prepare(q)
  278. }
  279. if err != nil {
  280. log.Println(err)
  281. return nil
  282. }
  283. defer stmtOut.Close()
  284. rows, err = stmtOut.Query(args...)
  285. }
  286. if err != nil {
  287. log.Println(err)
  288. return nil
  289. }
  290. if rows != nil {
  291. defer rows.Close()
  292. }
  293. columns, err := rows.Columns()
  294. if err != nil {
  295. log.Println(err)
  296. return nil
  297. }
  298. list := []map[string]interface{}{}
  299. for rows.Next() {
  300. scanArgs := make([]interface{}, len(columns))
  301. values := make([]interface{}, len(columns))
  302. ret := make(map[string]interface{})
  303. for k, _ := range values {
  304. scanArgs[k] = &values[k]
  305. }
  306. err = rows.Scan(scanArgs...)
  307. if err != nil {
  308. log.Println(err)
  309. break
  310. }
  311. for i, col := range values {
  312. if v, ok := col.([]uint8); ok {
  313. ret[columns[i]] = string(v)
  314. } else {
  315. ret[columns[i]] = col
  316. }
  317. }
  318. list = append(list, ret)
  319. if bath > 0 && len(list) == bath {
  320. if !f(&list) {
  321. list = []map[string]interface{}{}
  322. break
  323. }
  324. list = []map[string]interface{}{}
  325. }
  326. }
  327. if bath > 0 && len(list) > 0 {
  328. f(&list)
  329. list = []map[string]interface{}{}
  330. }
  331. return &list
  332. }
  333. func (m *Mysql) SelectByBath(bath int, f func(l *[]map[string]interface{}) bool, q string, args ...interface{}) {
  334. m.SelectByBathByTx(bath, f, nil, q, args...)
  335. }
  336. func (m *Mysql) SelectByBathByTx(bath int, f func(l *[]map[string]interface{}) bool, tx *sql.Tx, q string, args ...interface{}) {
  337. m.Select(bath, f, tx, q, args...)
  338. }
  339. func (m *Mysql) FindOne(tableName string, query map[string]interface{}, fields, order string) *map[string]interface{} {
  340. list := m.Find(tableName, query, fields, order, 0, 1)
  341. if list != nil && len(*list) == 1 {
  342. temp := (*list)[0]
  343. return &temp
  344. }
  345. return nil
  346. }
  347. //修改
  348. func (m *Mysql) Update(tableName string, query, update map[string]interface{}) bool {
  349. return m.UpdateByTx(nil, tableName, query, update)
  350. }
  351. //带事务的修改
  352. func (m *Mysql) UpdateByTx(tx *sql.Tx, tableName string, query, update map[string]interface{}) bool {
  353. q_fs := []string{}
  354. u_fs := []string{}
  355. values := []interface{}{}
  356. for k, v := range update {
  357. q_fs = append(q_fs, fmt.Sprintf("%s=?", k))
  358. values = append(values, v)
  359. }
  360. for k, v := range query {
  361. u_fs = append(u_fs, fmt.Sprintf("%s=?", k))
  362. values = append(values, v)
  363. }
  364. q := fmt.Sprintf("update %s set %s where %s", tableName, strings.Join(q_fs, ","), strings.Join(u_fs, " and "))
  365. log.Println(q, values)
  366. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) >= 0
  367. }
  368. //批量更新
  369. func (m *Mysql) UpdateBath(tableName string, fields []string, array [][]interface{}) {
  370. m.UpdateBathByTx(nil, tableName, fields, array)
  371. }
  372. //带事务的批量更新
  373. func (m *Mysql) UpdateBathByTx(tx *sql.Tx, tableName string, fields []string, array [][]interface{}) int64 {
  374. ws := []string{}
  375. args := []interface{}{}
  376. ids := []interface{}{}
  377. casethen := []string{}
  378. for n := 0; n < len(array[0]); n++ {
  379. for _, v := range array {
  380. if n == 0 {
  381. ws = append(ws, "?")
  382. ids = append(ids, v[0])
  383. casethen = append(casethen, "when ? then ?")
  384. } else {
  385. args = append(args, v[0], v[n])
  386. }
  387. }
  388. }
  389. ct := strings.Join(casethen, " ")
  390. sql_appends := []string{}
  391. for k, v := range fields {
  392. if k == 0 {
  393. continue
  394. }
  395. sql_appends = append(sql_appends, fmt.Sprintf(`%s=case %s %s end`, v, fields[0], ct))
  396. }
  397. args = append(args, ids...)
  398. sql := fmt.Sprintf(`update %s set %s where %s in (%s)`, tableName, strings.Join(sql_appends, ","), fields[0], strings.Join(ws, ","))
  399. return m.UpdateOrDeleteBySqlByTx(tx, sql, args...)
  400. }
  401. //删除
  402. func (m *Mysql) Delete(tableName string, query map[string]interface{}) bool {
  403. return m.DeleteByTx(nil, tableName, query)
  404. }
  405. func (m *Mysql) DeleteByTx(tx *sql.Tx, tableName string, query map[string]interface{}) bool {
  406. fields := []string{}
  407. values := []interface{}{}
  408. for k, v := range query {
  409. fields = append(fields, fmt.Sprintf("%s=?", k))
  410. values = append(values, v)
  411. }
  412. q := fmt.Sprintf("delete from %s where %s", tableName, strings.Join(fields, " and "))
  413. log.Println(q, values)
  414. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) > 0
  415. }
  416. //修改或删除
  417. func (m *Mysql) UpdateOrDeleteBySql(q string, args ...interface{}) int64 {
  418. return m.UpdateOrDeleteBySqlByTx(nil, q, args...)
  419. }
  420. //带事务的修改或删除
  421. func (m *Mysql) UpdateOrDeleteBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  422. result, _ := m.ExecBySqlByTx(tx, q, args...)
  423. if result == nil {
  424. return -1
  425. }
  426. if m.driverName == CLICKHOUSE {
  427. return 0
  428. }
  429. count, err := result.RowsAffected()
  430. if err != nil {
  431. log.Println(err)
  432. return -1
  433. }
  434. return count
  435. }
  436. //总数
  437. func (m *Mysql) Count(tableName string, query map[string]interface{}) int64 {
  438. fields := []string{}
  439. values := []interface{}{}
  440. for k, v := range query {
  441. rt := reflect.TypeOf(v)
  442. rv := reflect.ValueOf(v)
  443. if rt.Kind() == reflect.Map {
  444. for _, rv_k := range rv.MapKeys() {
  445. if rv_k.String() == "ne" {
  446. fields = append(fields, fmt.Sprintf("%s!=?", k))
  447. values = append(values, rv.MapIndex(rv_k).Interface())
  448. }
  449. if rv_k.String() == "notin" {
  450. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  451. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  452. fields = append(fields, fmt.Sprintf("%s!=?", k))
  453. values = append(values, v)
  454. }
  455. }
  456. }
  457. if rv_k.String() == "in" {
  458. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  459. _fs := fmt.Sprintf("%s in (?", k)
  460. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  461. if k > 0 {
  462. _fs += ",?"
  463. }
  464. values = append(values, v)
  465. }
  466. _fs += ")"
  467. fields = append(fields, _fs)
  468. }
  469. }
  470. }
  471. } else if v == "$isNull" {
  472. fields = append(fields, fmt.Sprintf("%s is null", k))
  473. } else if v == "$isNotNull" {
  474. fields = append(fields, fmt.Sprintf("%s is not null", k))
  475. } else {
  476. fields = append(fields, fmt.Sprintf("%s=?", k))
  477. values = append(values, v)
  478. }
  479. }
  480. q := fmt.Sprintf("select count(1) as count from %s", tableName)
  481. if len(query) > 0 {
  482. q += fmt.Sprintf(" where %s", strings.Join(fields, " and "))
  483. }
  484. log.Println(q, values)
  485. return m.CountBySql(q, values...)
  486. }
  487. func (m *Mysql) CountBySql(q string, args ...interface{}) int64 {
  488. return m.CountBySqlByTx(nil, q, args...)
  489. }
  490. func (m *Mysql) CountBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  491. var rows *sql.Rows
  492. var err error
  493. if m.driverName == CLICKHOUSE {
  494. rows, err = m.DB.Query(q, args...)
  495. } else {
  496. var stmtOut *sql.Stmt
  497. if tx == nil {
  498. stmtOut, err = m.DB.Prepare(q)
  499. } else {
  500. stmtOut, err = tx.Prepare(q)
  501. }
  502. if err != nil {
  503. log.Println(err)
  504. return -1
  505. }
  506. defer stmtOut.Close()
  507. rows, err = stmtOut.Query(args...)
  508. }
  509. if err != nil {
  510. log.Println(err)
  511. return -1
  512. }
  513. if rows != nil {
  514. defer rows.Close()
  515. }
  516. var count int64 = -1
  517. if rows.Next() {
  518. err = rows.Scan(&count)
  519. if err != nil {
  520. log.Println(err)
  521. }
  522. }
  523. return count
  524. }
  525. //执行事务
  526. func (m *Mysql) ExecTx(msg string, f func(tx *sql.Tx) bool) bool {
  527. tx, err := m.DB.Begin()
  528. if err != nil {
  529. log.Println(msg, "获取事务错误", err)
  530. } else {
  531. if f(tx) {
  532. if err := tx.Commit(); err != nil {
  533. log.Println(msg, "提交事务错误", err)
  534. } else {
  535. return true
  536. }
  537. } else {
  538. if err := tx.Rollback(); err != nil {
  539. log.Println(msg, "事务回滚错误", err)
  540. }
  541. }
  542. }
  543. return false
  544. }
  545. /*************方法命名不规范,上面有替代方法*************/
  546. func (m *Mysql) Query(query string, args ...interface{}) *[]map[string]interface{} {
  547. return m.SelectBySql(query, args...)
  548. }
  549. func (m *Mysql) QueryCount(query string, args ...interface{}) (count int) {
  550. count = -1
  551. if !strings.Contains(strings.ToLower(query), "count(*)") {
  552. fmt.Println("QueryCount need query like < select count(*) from ..... >")
  553. return
  554. }
  555. count = int(m.CountBySql(query, args...))
  556. return
  557. }