mysql.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. package util
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "log"
  7. "reflect"
  8. "strings"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. )
  // Mysql bundles the settings needed to reach a MySQL server together with
// the connection pool opened from them. Fill in the connection fields, then
// call Init before using DB.
type Mysql struct {
	Address  string  // server address in host:port form
	UserName string  // login user
	PassWord string  // login password
	DBName   string  // database (schema) to connect to
	DB       *sql.DB // connection pool; populated by Init

	// MaxOpenConns limits the number of open connections.
	// Init replaces values <= 0 with 20.
	MaxOpenConns int
	// MaxIdleConns limits the number of idle connections kept in the pool.
	// Init replaces values <= 0 with 20.
	MaxIdleConns int
}
  21. func (m *Mysql) Init() {
  22. if m.MaxOpenConns <= 0 {
  23. m.MaxOpenConns = 20
  24. }
  25. if m.MaxIdleConns <= 0 {
  26. m.MaxIdleConns = 20
  27. }
  28. var err error //utf8mb4
  29. m.DB, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", m.UserName, m.PassWord, m.Address, m.DBName))
  30. log.Println(fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", m.UserName, m.PassWord, m.Address, m.DBName))
  31. if err != nil {
  32. log.Println(err)
  33. return
  34. }
  35. m.DB.SetMaxOpenConns(m.MaxOpenConns)
  36. m.DB.SetMaxIdleConns(m.MaxIdleConns)
  37. m.DB.SetConnMaxLifetime(time.Minute * 3)
  38. err = m.DB.Ping()
  39. if err != nil {
  40. log.Println(err)
  41. }
  42. }
  43. // 新增
  44. func (m *Mysql) Insert(tableName string, data map[string]interface{}) int64 {
  45. return m.InsertByTx(nil, tableName, data)
  46. }
  47. // 带有事务的新增
  48. func (m *Mysql) InsertByTx(tx *sql.Tx, tableName string, data map[string]interface{}) int64 {
  49. fields := []string{}
  50. values := []interface{}{}
  51. placeholders := []string{}
  52. if tableName == "dataexport_order" {
  53. if _, ok := data["user_nickname"]; ok {
  54. data["user_nickname"] = ""
  55. }
  56. }
  57. for k, v := range data {
  58. fields = append(fields, k)
  59. values = append(values, v)
  60. placeholders = append(placeholders, "?")
  61. }
  62. q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ","), strings.Join(placeholders, ","))
  63. //log.Println("mysql", q, values)
  64. return m.InsertBySqlByTx(tx, q, values...)
  65. }
  66. // sql语句新增
  67. func (m *Mysql) InsertBySql(q string, args ...interface{}) int64 {
  68. return m.InsertBySqlByTx(nil, q, args...)
  69. }
  70. // 带有事务的sql语句新增
  71. func (m *Mysql) InsertBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  72. result, _ := m.ExecBySqlByTx(tx, q, args...)
  73. if result == nil {
  74. return -1
  75. }
  76. id, err := result.LastInsertId()
  77. if err != nil {
  78. log.Println(err)
  79. return -1
  80. }
  81. return id
  82. }
  83. // 批量新增
  84. func (m *Mysql) InsertIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  85. return m.InsertIgnoreBatchByTx(nil, tableName, fields, values)
  86. }
  87. // 带事务的批量新增
  88. func (m *Mysql) InsertIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  89. return m.insertOrReplaceBatchByTx(tx, "INSERT", "IGNORE", tableName, fields, values)
  90. }
  91. // 批量新增
  92. func (m *Mysql) InsertBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  93. return m.InsertBatchByTx(nil, tableName, fields, values)
  94. }
  95. // 带事务的批量新增
  96. func (m *Mysql) InsertBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  97. return m.insertOrReplaceBatchByTx(tx, "INSERT", "", tableName, fields, values)
  98. }
  99. // 批量更新
  100. func (m *Mysql) ReplaceBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  101. return m.ReplaceBatchByTx(nil, tableName, fields, values)
  102. }
  103. // 带事务的批量更新
  104. func (m *Mysql) ReplaceBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  105. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "", tableName, fields, values)
  106. }
  107. func (m *Mysql) insertOrReplaceBatchByTx(tx *sql.Tx, tp string, afterInsert, tableName string, fields []string, values []interface{}) (int64, int64) {
  108. placeholders := []string{}
  109. for range fields {
  110. placeholders = append(placeholders, "?")
  111. }
  112. placeholder := strings.Join(placeholders, ",")
  113. array := []string{}
  114. for i := 0; i < len(values)/len(fields); i++ {
  115. array = append(array, fmt.Sprintf("(%s)", placeholder))
  116. }
  117. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", tp, afterInsert, tableName, strings.Join(fields, ","), strings.Join(array, ","))
  118. result, _ := m.ExecBySqlByTx(tx, q, values...)
  119. if result == nil {
  120. return -1, -1
  121. }
  122. v1, e1 := result.RowsAffected()
  123. if e1 != nil {
  124. log.Println(e1)
  125. return -1, -1
  126. }
  127. v2, e2 := result.LastInsertId()
  128. if e2 != nil {
  129. log.Println(e2)
  130. return -1, -1
  131. }
  132. return v1, v2
  133. }
  134. // sql语句执行
  135. func (m *Mysql) ExecBySql(q string, args ...interface{}) (sql.Result, error) {
  136. return m.ExecBySqlByTx(nil, q, args...)
  137. }
  138. // sql语句执行,带有事务
  139. func (m *Mysql) ExecBySqlByTx(tx *sql.Tx, q string, args ...interface{}) (sql.Result, error) {
  140. var stmtIns *sql.Stmt
  141. var err error
  142. if tx == nil {
  143. stmtIns, err = m.DB.Prepare(q)
  144. } else {
  145. stmtIns, err = tx.Prepare(q)
  146. }
  147. if err != nil {
  148. log.Println(err)
  149. return nil, err
  150. }
  151. defer stmtIns.Close()
  152. result, err := stmtIns.Exec(args...)
  153. if err != nil {
  154. //log.Println(args, err)
  155. log.Println(err)
  156. return nil, err
  157. }
  158. return result, nil
  159. }
  160. /*不等于 map[string]string{"ne":"1"}
  161. *不等于多个 map[string]string{"notin":[]interface{}{1,2}}
  162. *字段为空 map[string]string{"name":"$isNull"}
  163. *字段不为空 map[string]string{"name":"$isNotNull"}
  164. */
  165. func (m *Mysql) Find(tableName string, query map[string]interface{}, fields, order string, start, pageSize int) *[]map[string]interface{} {
  166. fs := []string{}
  167. vs := []interface{}{}
  168. for k, v := range query {
  169. rt := reflect.TypeOf(v)
  170. rv := reflect.ValueOf(v)
  171. if rt.Kind() == reflect.Map {
  172. for _, rv_k := range rv.MapKeys() {
  173. if rv_k.String() == "ne" {
  174. fs = append(fs, fmt.Sprintf("%s!=?", k))
  175. vs = append(vs, rv.MapIndex(rv_k).Interface())
  176. }
  177. if rv_k.String() == "notin" {
  178. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  179. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  180. fs = append(fs, fmt.Sprintf("%s!=?", k))
  181. vs = append(vs, v)
  182. }
  183. }
  184. }
  185. if rv_k.String() == "gte" {
  186. fs = append(fs, fmt.Sprintf("%s>=?", k))
  187. vs = append(vs, rv.MapIndex(rv_k).Interface())
  188. }
  189. if rv_k.String() == "in" {
  190. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  191. _fs := fmt.Sprintf("%s in (?", k)
  192. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  193. if k > 0 {
  194. _fs += ",?"
  195. }
  196. vs = append(vs, v)
  197. }
  198. _fs += ")"
  199. fs = append(fs, _fs)
  200. }
  201. }
  202. }
  203. } else {
  204. if v == "$isNull" {
  205. fs = append(fs, fmt.Sprintf("%s is null", k))
  206. } else if v == "$isNotNull" {
  207. fs = append(fs, fmt.Sprintf("%s is not null", k))
  208. } else {
  209. fs = append(fs, fmt.Sprintf("%s=?", k))
  210. vs = append(vs, v)
  211. }
  212. }
  213. }
  214. var buffer bytes.Buffer
  215. buffer.WriteString("select ")
  216. if fields == "" {
  217. buffer.WriteString("*")
  218. } else {
  219. buffer.WriteString(fields)
  220. }
  221. buffer.WriteString(" from ")
  222. buffer.WriteString(tableName)
  223. if len(fs) > 0 {
  224. buffer.WriteString(" where ")
  225. buffer.WriteString(strings.Join(fs, " and "))
  226. }
  227. if order != "" {
  228. buffer.WriteString(" order by ")
  229. buffer.WriteString(order)
  230. }
  231. if start > -1 && pageSize > 0 {
  232. buffer.WriteString(" limit ")
  233. buffer.WriteString(fmt.Sprint(start))
  234. buffer.WriteString(",")
  235. buffer.WriteString(fmt.Sprint(pageSize))
  236. }
  237. q := buffer.String()
  238. //log.Println(q, vs)
  239. return m.SelectBySql(q, vs...)
  240. }
  241. // sql语句查询
  242. func (m *Mysql) SelectBySql(q string, args ...interface{}) *[]map[string]interface{} {
  243. return m.SelectBySqlByTx(nil, q, args...)
  244. }
  245. func (m *Mysql) SelectBySqlByTx(tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  246. return m.Select(0, nil, tx, q, args...)
  247. }
  248. func (m *Mysql) Select(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  249. var stmtOut *sql.Stmt
  250. var err error
  251. if tx == nil {
  252. stmtOut, err = m.DB.Prepare(q)
  253. } else {
  254. stmtOut, err = tx.Prepare(q)
  255. }
  256. if err != nil {
  257. log.Println(err)
  258. return nil
  259. }
  260. defer stmtOut.Close()
  261. rows, err := stmtOut.Query(args...)
  262. if err != nil {
  263. log.Println(err)
  264. return nil
  265. }
  266. if rows != nil {
  267. defer rows.Close()
  268. }
  269. columns, err := rows.Columns()
  270. if err != nil {
  271. log.Println(err)
  272. return nil
  273. }
  274. list := []map[string]interface{}{}
  275. for rows.Next() {
  276. scanArgs := make([]interface{}, len(columns))
  277. values := make([]interface{}, len(columns))
  278. ret := make(map[string]interface{})
  279. for k, _ := range values {
  280. scanArgs[k] = &values[k]
  281. }
  282. err = rows.Scan(scanArgs...)
  283. if err != nil {
  284. log.Println(err)
  285. break
  286. }
  287. for i, col := range values {
  288. if v, ok := col.([]uint8); ok {
  289. ret[columns[i]] = string(v)
  290. } else {
  291. ret[columns[i]] = col
  292. }
  293. }
  294. list = append(list, ret)
  295. if bath > 0 && len(list) == bath {
  296. f(&list)
  297. list = []map[string]interface{}{}
  298. }
  299. }
  300. if bath > 0 && len(list) > 0 {
  301. f(&list)
  302. list = []map[string]interface{}{}
  303. }
  304. return &list
  305. }
  306. func (m *Mysql) SelectByBath(bath int, f func(l *[]map[string]interface{}), q string, args ...interface{}) {
  307. m.SelectByBathByTx(bath, f, nil, q, args...)
  308. }
  309. func (m *Mysql) SelectByBathByTx(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) {
  310. m.Select(bath, f, tx, q, args...)
  311. }
  312. func (m *Mysql) FindOne(tableName string, query map[string]interface{}, fields, order string) *map[string]interface{} {
  313. list := m.Find(tableName, query, fields, order, 0, 1)
  314. if list != nil && len(*list) == 1 {
  315. temp := (*list)[0]
  316. return &temp
  317. }
  318. return nil
  319. }
  320. // 修改
  321. func (m *Mysql) Update(tableName string, query, update map[string]interface{}) bool {
  322. return m.UpdateByTx(nil, tableName, query, update)
  323. }
  324. // 带事务的修改
  325. func (m *Mysql) UpdateByTx(tx *sql.Tx, tableName string, query, update map[string]interface{}) bool {
  326. q_fs := []string{}
  327. u_fs := []string{}
  328. values := []interface{}{}
  329. for k, v := range update {
  330. q_fs = append(q_fs, fmt.Sprintf("%s=?", k))
  331. values = append(values, v)
  332. }
  333. for k, v := range query {
  334. u_fs = append(u_fs, fmt.Sprintf("%s=?", k))
  335. values = append(values, v)
  336. }
  337. q := fmt.Sprintf("update %s set %s where %s", tableName, strings.Join(q_fs, ","), strings.Join(u_fs, " and "))
  338. //log.Println(q, values)
  339. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) >= 0
  340. }
  341. // 删除
  342. func (m *Mysql) Delete(tableName string, query map[string]interface{}) bool {
  343. return m.DeleteByTx(nil, tableName, query)
  344. }
  345. func (m *Mysql) DeleteByTx(tx *sql.Tx, tableName string, query map[string]interface{}) bool {
  346. fields := []string{}
  347. values := []interface{}{}
  348. for k, v := range query {
  349. fields = append(fields, fmt.Sprintf("%s=?", k))
  350. values = append(values, v)
  351. }
  352. q := fmt.Sprintf("delete from %s where %s", tableName, strings.Join(fields, " and "))
  353. //log.Println(q, values)
  354. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) > 0
  355. }
  356. // 修改或删除
  357. func (m *Mysql) UpdateOrDeleteBySql(q string, args ...interface{}) int64 {
  358. return m.UpdateOrDeleteBySqlByTx(nil, q, args...)
  359. }
  360. // 带事务的修改或删除
  361. func (m *Mysql) UpdateOrDeleteBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  362. result, err := m.ExecBySqlByTx(tx, q, args...)
  363. if err != nil {
  364. log.Println(err)
  365. return -1
  366. }
  367. count, err := result.RowsAffected()
  368. if err != nil {
  369. log.Println(err)
  370. return -1
  371. }
  372. return count
  373. }
  374. // 总数
  375. func (m *Mysql) Count(tableName string, query map[string]interface{}) int64 {
  376. fields := []string{}
  377. values := []interface{}{}
  378. for k, v := range query {
  379. rt := reflect.TypeOf(v)
  380. rv := reflect.ValueOf(v)
  381. if rt.Kind() == reflect.Map {
  382. for _, rv_k := range rv.MapKeys() {
  383. if rv_k.String() == "ne" {
  384. fields = append(fields, fmt.Sprintf("%s!=?", k))
  385. values = append(values, rv.MapIndex(rv_k).Interface())
  386. }
  387. if rv_k.String() == "notin" {
  388. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  389. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  390. fields = append(fields, fmt.Sprintf("%s!=?", k))
  391. values = append(values, v)
  392. }
  393. }
  394. }
  395. if rv_k.String() == "in" {
  396. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  397. _fs := fmt.Sprintf("%s in (?", k)
  398. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  399. if k > 0 {
  400. _fs += ",?"
  401. }
  402. values = append(values, v)
  403. }
  404. _fs += ")"
  405. fields = append(fields, _fs)
  406. }
  407. }
  408. }
  409. } else if v == "$isNull" {
  410. fields = append(fields, fmt.Sprintf("%s is null", k))
  411. } else if v == "$isNotNull" {
  412. fields = append(fields, fmt.Sprintf("%s is not null", k))
  413. } else {
  414. fields = append(fields, fmt.Sprintf("%s=?", k))
  415. values = append(values, v)
  416. }
  417. }
  418. q := fmt.Sprintf("select count(1) as count from %s", tableName)
  419. if len(query) > 0 {
  420. q += fmt.Sprintf(" where %s", strings.Join(fields, " and "))
  421. }
  422. log.Println(q, values)
  423. return m.CountBySql(q, values...)
  424. }
  425. func (m *Mysql) CountBySql(q string, args ...interface{}) int64 {
  426. stmtIns, err := m.DB.Prepare(q)
  427. if err != nil {
  428. log.Println(err)
  429. return -1
  430. }
  431. defer stmtIns.Close()
  432. rows, err := stmtIns.Query(args...)
  433. if err != nil {
  434. log.Println(err)
  435. return -1
  436. }
  437. if rows != nil {
  438. defer rows.Close()
  439. }
  440. var count int64 = -1
  441. if rows.Next() {
  442. err = rows.Scan(&count)
  443. if err != nil {
  444. log.Println(err)
  445. }
  446. }
  447. return count
  448. }
  449. // 执行事务
  450. func (m *Mysql) ExecTx(msg string, f func(tx *sql.Tx) bool) bool {
  451. tx, err := m.DB.Begin()
  452. if err != nil {
  453. log.Println(msg, "获取事务错误", err)
  454. } else {
  455. if f(tx) {
  456. if err := tx.Commit(); err != nil {
  457. log.Println(msg, "提交事务错误", err)
  458. } else {
  459. return true
  460. }
  461. } else {
  462. if err := tx.Rollback(); err != nil {
  463. log.Println(msg, "事务回滚错误", err)
  464. }
  465. }
  466. }
  467. return false
  468. }
  469. /*************方法命名不规范,上面有替代方法*************/
  470. func (m *Mysql) Query(query string, args ...interface{}) *[]map[string]interface{} {
  471. return m.SelectBySql(query, args...)
  472. }
  473. func (m *Mysql) QueryCount(query string, args ...interface{}) (count int) {
  474. count = -1
  475. if !strings.Contains(strings.ToLower(query), "count(*)") {
  476. fmt.Println("QueryCount need query like < select count(*) from ..... >")
  477. return
  478. }
  479. count = int(m.CountBySql(query, args...))
  480. return
  481. }