mysql.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. package mysql
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "log"
  7. "reflect"
  8. "strings"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. )
  // Mysql holds the connection settings for one MySQL database together with
  // the *sql.DB connection pool that Init opens from them.
  type Mysql struct {
  	// Connection settings, in order: server address as "host:port", login
  	// user name, login password, and the database (schema) name.
  	Address, UserName, PassWord, DBName string

  	// DB is the connection pool; Init assigns it.
  	DB *sql.DB `json:",optional"`

  	// Pool limits: maximum open connections and maximum idle connections.
  	// database/sql treats 0 as "no limit", but Init replaces any value <= 0
  	// with its own defaults (30 open, 6 idle).
  	MaxOpenConns, MaxIdleConns int
  }
  21. func (m *Mysql) Init() {
  22. if m.MaxOpenConns <= 0 {
  23. m.MaxOpenConns = 30
  24. }
  25. if m.MaxIdleConns <= 0 {
  26. m.MaxIdleConns = 6
  27. }
  28. var err error
  29. m.DB, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", m.UserName, m.PassWord, m.Address, m.DBName))
  30. if err != nil {
  31. log.Println(err)
  32. return
  33. }
  34. m.DB.SetMaxOpenConns(m.MaxOpenConns)
  35. m.DB.SetMaxIdleConns(m.MaxIdleConns)
  36. m.DB.SetConnMaxLifetime(14400 * time.Second)
  37. err = m.DB.Ping()
  38. if err != nil {
  39. log.Println(err)
  40. }
  41. }
  42. //新增
  43. func (m *Mysql) Insert(tableName string, data map[string]interface{}) int64 {
  44. return m.InsertByTx(nil, tableName, data)
  45. }
  46. //带有事务的新增
  47. func (m *Mysql) InsertByTx(tx *sql.Tx, tableName string, data map[string]interface{}) int64 {
  48. fields := []string{}
  49. values := []interface{}{}
  50. placeholders := []string{}
  51. if tableName == "dataexport_order" {
  52. if _, ok := data["user_nickname"]; ok {
  53. data["user_nickname"] = ""
  54. }
  55. }
  56. for k, v := range data {
  57. fields = append(fields, k)
  58. values = append(values, v)
  59. placeholders = append(placeholders, "?")
  60. }
  61. q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ","), strings.Join(placeholders, ","))
  62. log.Println("mysql", q, values)
  63. return m.InsertBySqlByTx(tx, q, values...)
  64. }
  65. //sql语句新增
  66. func (m *Mysql) InsertBySql(q string, args ...interface{}) int64 {
  67. return m.InsertBySqlByTx(nil, q, args...)
  68. }
  69. //带有事务的sql语句新增
  70. func (m *Mysql) InsertBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  71. result, _ := m.ExecBySqlByTx(tx, q, args...)
  72. if result == nil {
  73. return -1
  74. }
  75. id, err := result.LastInsertId()
  76. if err != nil {
  77. log.Println(err)
  78. return -1
  79. }
  80. return id
  81. }
  82. //批量新增
  83. func (m *Mysql) InsertIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  84. return m.InsertIgnoreBatchByTx(nil, tableName, fields, values)
  85. }
  86. //带事务的批量新增
  87. func (m *Mysql) InsertIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  88. return m.insertOrReplaceBatchByTx(tx, "INSERT", "IGNORE", tableName, fields, values)
  89. }
  90. //批量新增
  91. func (m *Mysql) InsertBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  92. return m.InsertBatchByTx(nil, tableName, fields, values)
  93. }
  94. //带事务的批量新增
  95. func (m *Mysql) InsertBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  96. return m.insertOrReplaceBatchByTx(tx, "INSERT", "", tableName, fields, values)
  97. }
  98. //批量更新
  99. func (m *Mysql) ReplaceBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  100. return m.ReplaceBatchByTx(nil, tableName, fields, values)
  101. }
  102. //带事务的批量更新
  103. func (m *Mysql) ReplaceBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  104. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "", tableName, fields, values)
  105. }
  106. func (m *Mysql) insertOrReplaceBatchByTx(tx *sql.Tx, tp string, afterInsert, tableName string, fields []string, values []interface{}) (int64, int64) {
  107. placeholders := []string{}
  108. for range fields {
  109. placeholders = append(placeholders, "?")
  110. }
  111. placeholder := strings.Join(placeholders, ",")
  112. array := []string{}
  113. for i := 0; i < len(values)/len(fields); i++ {
  114. array = append(array, fmt.Sprintf("(%s)", placeholder))
  115. }
  116. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", tp, afterInsert, tableName, strings.Join(fields, ","), strings.Join(array, ","))
  117. result, _ := m.ExecBySqlByTx(tx, q, values...)
  118. if result == nil {
  119. return -1, -1
  120. }
  121. v1, e1 := result.RowsAffected()
  122. if e1 != nil {
  123. log.Println(e1)
  124. return -1, -1
  125. }
  126. v2, e2 := result.LastInsertId()
  127. if e2 != nil {
  128. log.Println(e2)
  129. return -1, -1
  130. }
  131. return v1, v2
  132. }
  133. //sql语句执行
  134. func (m *Mysql) ExecBySql(q string, args ...interface{}) (sql.Result, error) {
  135. return m.ExecBySqlByTx(nil, q, args...)
  136. }
  137. //sql语句执行,带有事务
  138. func (m *Mysql) ExecBySqlByTx(tx *sql.Tx, q string, args ...interface{}) (sql.Result, error) {
  139. var stmtIns *sql.Stmt
  140. var err error
  141. if tx == nil {
  142. stmtIns, err = m.DB.Prepare(q)
  143. } else {
  144. stmtIns, err = tx.Prepare(q)
  145. }
  146. if err != nil {
  147. log.Println(err)
  148. return nil, err
  149. }
  150. defer stmtIns.Close()
  151. result, err := stmtIns.Exec(args...)
  152. if err != nil {
  153. log.Println(args, err)
  154. return nil, err
  155. }
  156. return result, nil
  157. }
  158. /*不等于 map[string]string{"ne":"1"}
  159. *不等于多个 map[string]string{"notin":[]interface{}{1,2}}
  160. *字段为空 map[string]string{"name":"$isNull"}
  161. *字段不为空 map[string]string{"name":"$isNotNull"}
  162. */
  163. func (m *Mysql) Find(tableName string, query map[string]interface{}, fields, order string, start, pageSize int) *[]map[string]interface{} {
  164. fs := []string{}
  165. vs := []interface{}{}
  166. for k, v := range query {
  167. rt := reflect.TypeOf(v)
  168. rv := reflect.ValueOf(v)
  169. if rt.Kind() == reflect.Map {
  170. for _, rv_k := range rv.MapKeys() {
  171. if rv_k.String() == "ne" {
  172. fs = append(fs, fmt.Sprintf("%s!=?", k))
  173. vs = append(vs, rv.MapIndex(rv_k).Interface())
  174. }
  175. if rv_k.String() == "notin" {
  176. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  177. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  178. fs = append(fs, fmt.Sprintf("%s!=?", k))
  179. vs = append(vs, v)
  180. }
  181. }
  182. }
  183. if rv_k.String() == "in" {
  184. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  185. _fs := fmt.Sprintf("%s in (?", k)
  186. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  187. if k > 0 {
  188. _fs += ",?"
  189. }
  190. vs = append(vs, v)
  191. }
  192. _fs += ")"
  193. fs = append(fs, _fs)
  194. }
  195. }
  196. }
  197. } else {
  198. if v == "$isNull" {
  199. fs = append(fs, fmt.Sprintf("%s is null", k))
  200. } else if v == "$isNotNull" {
  201. fs = append(fs, fmt.Sprintf("%s is not null", k))
  202. } else {
  203. fs = append(fs, fmt.Sprintf("%s=?", k))
  204. vs = append(vs, v)
  205. }
  206. }
  207. }
  208. var buffer bytes.Buffer
  209. buffer.WriteString("select ")
  210. if fields == "" {
  211. buffer.WriteString("*")
  212. } else {
  213. buffer.WriteString(fields)
  214. }
  215. buffer.WriteString(" from ")
  216. buffer.WriteString(tableName)
  217. if len(fs) > 0 {
  218. buffer.WriteString(" where ")
  219. buffer.WriteString(strings.Join(fs, " and "))
  220. }
  221. if order != "" {
  222. buffer.WriteString(" order by ")
  223. buffer.WriteString(order)
  224. }
  225. if start > -1 && pageSize > 0 {
  226. buffer.WriteString(" limit ")
  227. buffer.WriteString(fmt.Sprint(start))
  228. buffer.WriteString(",")
  229. buffer.WriteString(fmt.Sprint(pageSize))
  230. }
  231. q := buffer.String()
  232. log.Println(q, vs)
  233. return m.SelectBySql(q, vs...)
  234. }
  235. //sql语句查询
  236. func (m *Mysql) SelectBySql(q string, args ...interface{}) *[]map[string]interface{} {
  237. return m.SelectBySqlByTx(nil, q, args...)
  238. }
  239. func (m *Mysql) SelectBySqlByTx(tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  240. return m.Select(0, nil, tx, q, args...)
  241. }
  242. func (m *Mysql) Select(bath int, f func(l *[]map[string]interface{}) bool, tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  243. var stmtOut *sql.Stmt
  244. var err error
  245. if tx == nil {
  246. stmtOut, err = m.DB.Prepare(q)
  247. } else {
  248. stmtOut, err = tx.Prepare(q)
  249. }
  250. if err != nil {
  251. log.Println(err)
  252. return nil
  253. }
  254. defer stmtOut.Close()
  255. rows, err := stmtOut.Query(args...)
  256. if err != nil {
  257. log.Println(err)
  258. return nil
  259. }
  260. if rows != nil {
  261. defer rows.Close()
  262. }
  263. columns, err := rows.Columns()
  264. if err != nil {
  265. log.Println(err)
  266. return nil
  267. }
  268. list := []map[string]interface{}{}
  269. for rows.Next() {
  270. scanArgs := make([]interface{}, len(columns))
  271. values := make([]interface{}, len(columns))
  272. ret := make(map[string]interface{})
  273. for k, _ := range values {
  274. scanArgs[k] = &values[k]
  275. }
  276. err = rows.Scan(scanArgs...)
  277. if err != nil {
  278. log.Println(err)
  279. break
  280. }
  281. for i, col := range values {
  282. if v, ok := col.([]uint8); ok {
  283. ret[columns[i]] = string(v)
  284. } else {
  285. ret[columns[i]] = col
  286. }
  287. }
  288. list = append(list, ret)
  289. if bath > 0 && len(list) == bath {
  290. if !f(&list) {
  291. list = []map[string]interface{}{}
  292. break
  293. }
  294. list = []map[string]interface{}{}
  295. }
  296. }
  297. if bath > 0 && len(list) > 0 {
  298. f(&list)
  299. list = []map[string]interface{}{}
  300. }
  301. return &list
  302. }
  303. func (m *Mysql) SelectByBath(bath int, f func(l *[]map[string]interface{}) bool, q string, args ...interface{}) {
  304. m.SelectByBathByTx(bath, f, nil, q, args...)
  305. }
  306. func (m *Mysql) SelectByBathByTx(bath int, f func(l *[]map[string]interface{}) bool, tx *sql.Tx, q string, args ...interface{}) {
  307. m.Select(bath, f, tx, q, args...)
  308. }
  309. func (m *Mysql) FindOne(tableName string, query map[string]interface{}, fields, order string) *map[string]interface{} {
  310. list := m.Find(tableName, query, fields, order, 0, 1)
  311. if list != nil && len(*list) == 1 {
  312. temp := (*list)[0]
  313. return &temp
  314. }
  315. return nil
  316. }
  317. //修改
  318. func (m *Mysql) Update(tableName string, query, update map[string]interface{}) bool {
  319. return m.UpdateByTx(nil, tableName, query, update)
  320. }
  321. //带事务的修改
  322. func (m *Mysql) UpdateByTx(tx *sql.Tx, tableName string, query, update map[string]interface{}) bool {
  323. q_fs := []string{}
  324. u_fs := []string{}
  325. values := []interface{}{}
  326. for k, v := range update {
  327. q_fs = append(q_fs, fmt.Sprintf("%s=?", k))
  328. values = append(values, v)
  329. }
  330. for k, v := range query {
  331. u_fs = append(u_fs, fmt.Sprintf("%s=?", k))
  332. values = append(values, v)
  333. }
  334. q := fmt.Sprintf("update %s set %s where %s", tableName, strings.Join(q_fs, ","), strings.Join(u_fs, " and "))
  335. log.Println(q, values)
  336. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) >= 0
  337. }
  338. //批量更新
  339. func (m *Mysql) UpdateBath(tableName string, fields []string, array [][]interface{}) {
  340. m.UpdateBathByTx(nil, tableName, fields, array)
  341. }
  342. //带事务的批量更新
  343. func (m *Mysql) UpdateBathByTx(tx *sql.Tx, tableName string, fields []string, array [][]interface{}) int64 {
  344. ws := []string{}
  345. args := []interface{}{}
  346. ids := []interface{}{}
  347. casethen := []string{}
  348. for n := 0; n < len(array[0]); n++ {
  349. for _, v := range array {
  350. if n == 0 {
  351. ws = append(ws, "?")
  352. ids = append(ids, v[0])
  353. casethen = append(casethen, "when ? then ?")
  354. } else {
  355. args = append(args, v[0], v[n])
  356. }
  357. }
  358. }
  359. ct := strings.Join(casethen, " ")
  360. sql_appends := []string{}
  361. for k, v := range fields {
  362. if k == 0 {
  363. continue
  364. }
  365. sql_appends = append(sql_appends, fmt.Sprintf(`%s=case %s %s end`, v, fields[0], ct))
  366. }
  367. args = append(args, ids...)
  368. sql := fmt.Sprintf(`update %s set %s where %s in (%s)`, tableName, strings.Join(sql_appends, ","), fields[0], strings.Join(ws, ","))
  369. return m.UpdateOrDeleteBySqlByTx(tx, sql, args...)
  370. }
  371. //删除
  372. func (m *Mysql) Delete(tableName string, query map[string]interface{}) bool {
  373. return m.DeleteByTx(nil, tableName, query)
  374. }
  375. func (m *Mysql) DeleteByTx(tx *sql.Tx, tableName string, query map[string]interface{}) bool {
  376. fields := []string{}
  377. values := []interface{}{}
  378. for k, v := range query {
  379. fields = append(fields, fmt.Sprintf("%s=?", k))
  380. values = append(values, v)
  381. }
  382. q := fmt.Sprintf("delete from %s where %s", tableName, strings.Join(fields, " and "))
  383. log.Println(q, values)
  384. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) > 0
  385. }
  386. //修改或删除
  387. func (m *Mysql) UpdateOrDeleteBySql(q string, args ...interface{}) int64 {
  388. return m.UpdateOrDeleteBySqlByTx(nil, q, args...)
  389. }
  390. //带事务的修改或删除
  391. func (m *Mysql) UpdateOrDeleteBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  392. result, err := m.ExecBySqlByTx(tx, q, args...)
  393. if err != nil {
  394. log.Println(err)
  395. return -1
  396. }
  397. count, err := result.RowsAffected()
  398. if err != nil {
  399. log.Println(err)
  400. return -1
  401. }
  402. return count
  403. }
  404. //总数
  405. func (m *Mysql) Count(tableName string, query map[string]interface{}) int64 {
  406. fields := []string{}
  407. values := []interface{}{}
  408. for k, v := range query {
  409. rt := reflect.TypeOf(v)
  410. rv := reflect.ValueOf(v)
  411. if rt.Kind() == reflect.Map {
  412. for _, rv_k := range rv.MapKeys() {
  413. if rv_k.String() == "ne" {
  414. fields = append(fields, fmt.Sprintf("%s!=?", k))
  415. values = append(values, rv.MapIndex(rv_k).Interface())
  416. }
  417. if rv_k.String() == "notin" {
  418. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  419. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  420. fields = append(fields, fmt.Sprintf("%s!=?", k))
  421. values = append(values, v)
  422. }
  423. }
  424. }
  425. if rv_k.String() == "in" {
  426. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  427. _fs := fmt.Sprintf("%s in (?", k)
  428. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  429. if k > 0 {
  430. _fs += ",?"
  431. }
  432. values = append(values, v)
  433. }
  434. _fs += ")"
  435. fields = append(fields, _fs)
  436. }
  437. }
  438. }
  439. } else if v == "$isNull" {
  440. fields = append(fields, fmt.Sprintf("%s is null", k))
  441. } else if v == "$isNotNull" {
  442. fields = append(fields, fmt.Sprintf("%s is not null", k))
  443. } else {
  444. fields = append(fields, fmt.Sprintf("%s=?", k))
  445. values = append(values, v)
  446. }
  447. }
  448. q := fmt.Sprintf("select count(1) as count from %s", tableName)
  449. if len(query) > 0 {
  450. q += fmt.Sprintf(" where %s", strings.Join(fields, " and "))
  451. }
  452. log.Println(q, values)
  453. return m.CountBySql(q, values...)
  454. }
  455. func (m *Mysql) CountBySql(q string, args ...interface{}) int64 {
  456. stmtIns, err := m.DB.Prepare(q)
  457. if err != nil {
  458. log.Println(err)
  459. return -1
  460. }
  461. defer stmtIns.Close()
  462. rows, err := stmtIns.Query(args...)
  463. if err != nil {
  464. log.Println(err)
  465. return -1
  466. }
  467. if rows != nil {
  468. defer rows.Close()
  469. }
  470. var count int64 = -1
  471. if rows.Next() {
  472. err = rows.Scan(&count)
  473. if err != nil {
  474. log.Println(err)
  475. }
  476. }
  477. return count
  478. }
  479. //执行事务
  480. func (m *Mysql) ExecTx(msg string, f func(tx *sql.Tx) bool) bool {
  481. tx, err := m.DB.Begin()
  482. if err != nil {
  483. log.Println(msg, "获取事务错误", err)
  484. } else {
  485. if f(tx) {
  486. if err := tx.Commit(); err != nil {
  487. log.Println(msg, "提交事务错误", err)
  488. } else {
  489. return true
  490. }
  491. } else {
  492. if err := tx.Rollback(); err != nil {
  493. log.Println(msg, "事务回滚错误", err)
  494. }
  495. }
  496. }
  497. return false
  498. }
  499. /*************方法命名不规范,上面有替代方法*************/
  500. func (m *Mysql) Query(query string, args ...interface{}) *[]map[string]interface{} {
  501. return m.SelectBySql(query, args...)
  502. }
  503. func (m *Mysql) QueryCount(query string, args ...interface{}) (count int) {
  504. count = -1
  505. if !strings.Contains(strings.ToLower(query), "count(*)") {
  506. fmt.Println("QueryCount need query like < select count(*) from ..... >")
  507. return
  508. }
  509. count = int(m.CountBySql(query, args...))
  510. return
  511. }