mysql.go 15 KB


  1. package mysql
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "log"
  7. "reflect"
  8. "strings"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. )
  // Mysql holds the connection settings for one MySQL database together with
// the *sql.DB connection pool that Init opens from them.
type Mysql struct {
	// Address is the server location as "host:port".
	Address string
	// UserName is the account used to log in.
	UserName string
	// PassWord is that account's password.
	PassWord string
	// DBName is the database (schema) the connections use.
	DBName string
	// DB is the connection pool; Init assigns it.
	DB *sql.DB
	// MaxOpenConns limits open connections; Init replaces values <= 0 with 30.
	MaxOpenConns int
	// MaxIdleConns limits idle pooled connections; Init replaces values <= 0 with 6.
	MaxIdleConns int
}
  21. func (m *Mysql) Init() {
  22. if m.MaxOpenConns <= 0 {
  23. m.MaxOpenConns = 30
  24. }
  25. if m.MaxIdleConns <= 0 {
  26. m.MaxIdleConns = 6
  27. }
  28. var err error
  29. m.DB, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", m.UserName, m.PassWord, m.Address, m.DBName))
  30. if err != nil {
  31. log.Println(err)
  32. return
  33. }
  34. m.DB.SetMaxOpenConns(m.MaxOpenConns)
  35. m.DB.SetMaxIdleConns(m.MaxIdleConns)
  36. m.DB.SetConnMaxLifetime(14400 * time.Second)
  37. err = m.DB.Ping()
  38. if err != nil {
  39. log.Println(err)
  40. }
  41. }
  42. //新增
  43. func (m *Mysql) Insert(tableName string, data map[string]interface{}) int64 {
  44. return m.InsertByTx(nil, tableName, data)
  45. }
  46. //带有事务的新增
  47. func (m *Mysql) InsertByTx(tx *sql.Tx, tableName string, data map[string]interface{}) int64 {
  48. fields := []string{}
  49. values := []interface{}{}
  50. placeholders := []string{}
  51. if tableName == "dataexport_order" {
  52. if _, ok := data["user_nickname"]; ok {
  53. data["user_nickname"] = ""
  54. }
  55. }
  56. for k, v := range data {
  57. fields = append(fields, k)
  58. values = append(values, v)
  59. placeholders = append(placeholders, "?")
  60. }
  61. q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ","), strings.Join(placeholders, ","))
  62. log.Println("mysql", q, values)
  63. return m.InsertBySqlByTx(tx, q, values...)
  64. }
  65. //sql语句新增
  66. func (m *Mysql) InsertBySql(q string, args ...interface{}) int64 {
  67. return m.InsertBySqlByTx(nil, q, args...)
  68. }
  69. //带有事务的sql语句新增
  70. func (m *Mysql) InsertBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  71. result, _ := m.ExecBySqlByTx(tx, q, args...)
  72. if result == nil {
  73. return -1
  74. }
  75. id, err := result.LastInsertId()
  76. if err != nil {
  77. log.Println(err)
  78. return -1
  79. }
  80. return id
  81. }
  82. //批量新增
  83. func (m *Mysql) InsertIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  84. return m.InsertIgnoreBatchByTx(nil, tableName, fields, values)
  85. }
  86. //带事务的批量新增
  87. func (m *Mysql) InsertIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  88. return m.insertOrReplaceBatchByTx(tx, "INSERT", "IGNORE", tableName, fields, values)
  89. }
  90. //批量新增
  91. func (m *Mysql) InsertBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  92. return m.InsertBatchByTx(nil, tableName, fields, values)
  93. }
  94. //带事务的批量新增
  95. func (m *Mysql) InsertBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  96. return m.insertOrReplaceBatchByTx(tx, "INSERT", "", tableName, fields, values)
  97. }
  98. //批量更新
  99. func (m *Mysql) ReplaceBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  100. return m.ReplaceBatchByTx(nil, tableName, fields, values)
  101. }
  102. //带事务的批量更新
  103. func (m *Mysql) ReplaceBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  104. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "", tableName, fields, values)
  105. }
  106. //批量更新
  107. func (m *Mysql) ReplaceIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  108. return m.ReplaceIgnoreBatchByTx(nil, tableName, fields, values)
  109. }
  110. //带事务的批量更新
  111. func (m *Mysql) ReplaceIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  112. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "IGNORE", tableName, fields, values)
  113. }
  114. func (m *Mysql) insertOrReplaceBatchByTx(tx *sql.Tx, tp string, afterInsert, tableName string, fields []string, values []interface{}) (int64, int64) {
  115. placeholders := []string{}
  116. for range fields {
  117. placeholders = append(placeholders, "?")
  118. }
  119. placeholder := strings.Join(placeholders, ",")
  120. array := []string{}
  121. for i := 0; i < len(values)/len(fields); i++ {
  122. array = append(array, fmt.Sprintf("(%s)", placeholder))
  123. }
  124. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", tp, afterInsert, tableName, strings.Join(fields, ","), strings.Join(array, ","))
  125. result, _ := m.ExecBySqlByTx(tx, q, values...)
  126. if result == nil {
  127. return -1, -1
  128. }
  129. v1, e1 := result.RowsAffected()
  130. if e1 != nil {
  131. log.Println(e1)
  132. return -1, -1
  133. }
  134. v2, e2 := result.LastInsertId()
  135. if e2 != nil {
  136. log.Println(e2)
  137. return -1, -1
  138. }
  139. return v1, v2
  140. }
  141. //sql语句执行
  142. func (m *Mysql) ExecBySql(q string, args ...interface{}) (sql.Result, error) {
  143. return m.ExecBySqlByTx(nil, q, args...)
  144. }
  145. //sql语句执行,带有事务
  146. func (m *Mysql) ExecBySqlByTx(tx *sql.Tx, q string, args ...interface{}) (sql.Result, error) {
  147. var stmtIns *sql.Stmt
  148. var err error
  149. if tx == nil {
  150. stmtIns, err = m.DB.Prepare(q)
  151. } else {
  152. stmtIns, err = tx.Prepare(q)
  153. }
  154. if err != nil {
  155. log.Println(err)
  156. return nil, err
  157. }
  158. defer stmtIns.Close()
  159. result, err := stmtIns.Exec(args...)
  160. if err != nil {
  161. log.Println(args, err)
  162. return nil, err
  163. }
  164. return result, nil
  165. }
  166. /*不等于 map[string]string{"ne":"1"}
  167. *不等于多个 map[string]string{"notin":[]interface{}{1,2}}
  168. *字段为空 map[string]string{"name":"$isNull"}
  169. *字段不为空 map[string]string{"name":"$isNotNull"}
  170. */
  171. func (m *Mysql) Find(tableName string, query map[string]interface{}, fields, order string, start, pageSize int) *[]map[string]interface{} {
  172. fs := []string{}
  173. vs := []interface{}{}
  174. for k, v := range query {
  175. rt := reflect.TypeOf(v)
  176. rv := reflect.ValueOf(v)
  177. if rt.Kind() == reflect.Map {
  178. for _, rv_k := range rv.MapKeys() {
  179. if rv_k.String() == "ne" {
  180. fs = append(fs, fmt.Sprintf("%s!=?", k))
  181. vs = append(vs, rv.MapIndex(rv_k).Interface())
  182. }
  183. if rv_k.String() == "notin" {
  184. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  185. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  186. fs = append(fs, fmt.Sprintf("%s!=?", k))
  187. vs = append(vs, v)
  188. }
  189. }
  190. }
  191. if rv_k.String() == "in" {
  192. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  193. _fs := fmt.Sprintf("%s in (?", k)
  194. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  195. if k > 0 {
  196. _fs += ",?"
  197. }
  198. vs = append(vs, v)
  199. }
  200. _fs += ")"
  201. fs = append(fs, _fs)
  202. }
  203. }
  204. }
  205. } else {
  206. if v == "$isNull" {
  207. fs = append(fs, fmt.Sprintf("%s is null", k))
  208. } else if v == "$isNotNull" {
  209. fs = append(fs, fmt.Sprintf("%s is not null", k))
  210. } else {
  211. fs = append(fs, fmt.Sprintf("%s=?", k))
  212. vs = append(vs, v)
  213. }
  214. }
  215. }
  216. var buffer bytes.Buffer
  217. buffer.WriteString("select ")
  218. if fields == "" {
  219. buffer.WriteString("*")
  220. } else {
  221. buffer.WriteString(fields)
  222. }
  223. buffer.WriteString(" from ")
  224. buffer.WriteString(tableName)
  225. if len(fs) > 0 {
  226. buffer.WriteString(" where ")
  227. buffer.WriteString(strings.Join(fs, " and "))
  228. }
  229. if order != "" {
  230. buffer.WriteString(" order by ")
  231. buffer.WriteString(order)
  232. }
  233. if start > -1 && pageSize > 0 {
  234. buffer.WriteString(" limit ")
  235. buffer.WriteString(fmt.Sprint(start))
  236. buffer.WriteString(",")
  237. buffer.WriteString(fmt.Sprint(pageSize))
  238. }
  239. q := buffer.String()
  240. log.Println(q, vs)
  241. return m.SelectBySql(q, vs...)
  242. }
  243. //sql语句查询
  244. func (m *Mysql) SelectBySql(q string, args ...interface{}) *[]map[string]interface{} {
  245. return m.SelectBySqlByTx(nil, q, args...)
  246. }
  247. func (m *Mysql) SelectBySqlByTx(tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  248. return m.Select(0, nil, tx, q, args...)
  249. }
  250. func (m *Mysql) Select(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  251. var stmtOut *sql.Stmt
  252. var err error
  253. if tx == nil {
  254. stmtOut, err = m.DB.Prepare(q)
  255. } else {
  256. stmtOut, err = tx.Prepare(q)
  257. }
  258. if err != nil {
  259. log.Println(err)
  260. return nil
  261. }
  262. defer stmtOut.Close()
  263. rows, err := stmtOut.Query(args...)
  264. if err != nil {
  265. log.Println(err)
  266. return nil
  267. }
  268. if rows != nil {
  269. defer rows.Close()
  270. }
  271. columns, err := rows.Columns()
  272. if err != nil {
  273. log.Println(err)
  274. return nil
  275. }
  276. list := []map[string]interface{}{}
  277. for rows.Next() {
  278. scanArgs := make([]interface{}, len(columns))
  279. values := make([]interface{}, len(columns))
  280. ret := make(map[string]interface{})
  281. for k, _ := range values {
  282. scanArgs[k] = &values[k]
  283. }
  284. err = rows.Scan(scanArgs...)
  285. if err != nil {
  286. log.Println(err)
  287. break
  288. }
  289. for i, col := range values {
  290. if v, ok := col.([]uint8); ok {
  291. ret[columns[i]] = string(v)
  292. } else {
  293. ret[columns[i]] = col
  294. }
  295. }
  296. list = append(list, ret)
  297. if bath > 0 && len(list) == bath {
  298. f(&list)
  299. list = []map[string]interface{}{}
  300. }
  301. }
  302. if bath > 0 && len(list) > 0 {
  303. f(&list)
  304. list = []map[string]interface{}{}
  305. }
  306. return &list
  307. }
  308. func (m *Mysql) SelectByBath(bath int, f func(l *[]map[string]interface{}), q string, args ...interface{}) {
  309. m.SelectByBathByTx(bath, f, nil, q, args...)
  310. }
  311. func (m *Mysql) SelectByBathByTx(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) {
  312. m.Select(bath, f, tx, q, args...)
  313. }
  314. func (m *Mysql) FindOne(tableName string, query map[string]interface{}, fields, order string) *map[string]interface{} {
  315. list := m.Find(tableName, query, fields, order, 0, 1)
  316. if list != nil && len(*list) == 1 {
  317. temp := (*list)[0]
  318. return &temp
  319. }
  320. return nil
  321. }
  322. //修改
  323. func (m *Mysql) Update(tableName string, query, update map[string]interface{}) bool {
  324. return m.UpdateByTx(nil, tableName, query, update)
  325. }
  326. //带事务的修改
  327. func (m *Mysql) UpdateByTx(tx *sql.Tx, tableName string, query, update map[string]interface{}) bool {
  328. q_fs := []string{}
  329. u_fs := []string{}
  330. values := []interface{}{}
  331. for k, v := range update {
  332. q_fs = append(q_fs, fmt.Sprintf("%s=?", k))
  333. values = append(values, v)
  334. }
  335. for k, v := range query {
  336. u_fs = append(u_fs, fmt.Sprintf("%s=?", k))
  337. values = append(values, v)
  338. }
  339. q := fmt.Sprintf("update %s set %s where %s", tableName, strings.Join(q_fs, ","), strings.Join(u_fs, " and "))
  340. log.Println(q, values)
  341. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) >= 0
  342. }
  343. //批量更新
  344. func (m *Mysql) UpdateBath(tableName string, fields []string, array [][]interface{}) {
  345. ws := []string{}
  346. args := []interface{}{}
  347. ids := []interface{}{}
  348. casethen := []string{}
  349. for n := 0; n < len(array[0]); n++ {
  350. for _, v := range array {
  351. if n == 0 {
  352. ws = append(ws, "?")
  353. ids = append(ids, v[0])
  354. casethen = append(casethen, "when ? then ?")
  355. } else {
  356. args = append(args, v[0], v[n])
  357. }
  358. }
  359. }
  360. ct := strings.Join(casethen, " ")
  361. sql_appends := []string{}
  362. for k, v := range fields {
  363. if k == 0 {
  364. continue
  365. }
  366. sql_appends = append(sql_appends, fmt.Sprintf(`%s=case %s %s end`, v, fields[0], ct))
  367. }
  368. args = append(args, ids...)
  369. sql := fmt.Sprintf(`update %s set %s where %s in (%s)`, tableName, strings.Join(sql_appends, ","), fields[0], strings.Join(ws, ","))
  370. m.UpdateOrDeleteBySql(sql, args...)
  371. }
  372. //删除
  373. func (m *Mysql) Delete(tableName string, query map[string]interface{}) bool {
  374. return m.DeleteByTx(nil, tableName, query)
  375. }
  376. func (m *Mysql) DeleteByTx(tx *sql.Tx, tableName string, query map[string]interface{}) bool {
  377. fields := []string{}
  378. values := []interface{}{}
  379. for k, v := range query {
  380. fields = append(fields, fmt.Sprintf("%s=?", k))
  381. values = append(values, v)
  382. }
  383. q := fmt.Sprintf("delete from %s where %s", tableName, strings.Join(fields, " and "))
  384. log.Println(q, values)
  385. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) > 0
  386. }
  387. //修改或删除
  388. func (m *Mysql) UpdateOrDeleteBySql(q string, args ...interface{}) int64 {
  389. return m.UpdateOrDeleteBySqlByTx(nil, q, args...)
  390. }
  391. //带事务的修改或删除
  392. func (m *Mysql) UpdateOrDeleteBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  393. result, err := m.ExecBySqlByTx(tx, q, args...)
  394. if err != nil {
  395. log.Println(err)
  396. return -1
  397. }
  398. count, err := result.RowsAffected()
  399. if err != nil {
  400. log.Println(err)
  401. return -1
  402. }
  403. return count
  404. }
  405. //总数
  406. func (m *Mysql) Count(tableName string, query map[string]interface{}) int64 {
  407. fields := []string{}
  408. values := []interface{}{}
  409. for k, v := range query {
  410. rt := reflect.TypeOf(v)
  411. rv := reflect.ValueOf(v)
  412. if rt.Kind() == reflect.Map {
  413. for _, rv_k := range rv.MapKeys() {
  414. if rv_k.String() == "ne" {
  415. fields = append(fields, fmt.Sprintf("%s!=?", k))
  416. values = append(values, rv.MapIndex(rv_k).Interface())
  417. }
  418. if rv_k.String() == "notin" {
  419. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  420. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  421. fields = append(fields, fmt.Sprintf("%s!=?", k))
  422. values = append(values, v)
  423. }
  424. }
  425. }
  426. if rv_k.String() == "in" {
  427. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  428. _fs := fmt.Sprintf("%s in (?", k)
  429. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  430. if k > 0 {
  431. _fs += ",?"
  432. }
  433. values = append(values, v)
  434. }
  435. _fs += ")"
  436. fields = append(fields, _fs)
  437. }
  438. }
  439. }
  440. } else if v == "$isNull" {
  441. fields = append(fields, fmt.Sprintf("%s is null", k))
  442. } else if v == "$isNotNull" {
  443. fields = append(fields, fmt.Sprintf("%s is not null", k))
  444. } else {
  445. fields = append(fields, fmt.Sprintf("%s=?", k))
  446. values = append(values, v)
  447. }
  448. }
  449. q := fmt.Sprintf("select count(1) as count from %s", tableName)
  450. if len(query) > 0 {
  451. q += fmt.Sprintf(" where %s", strings.Join(fields, " and "))
  452. }
  453. log.Println(q, values)
  454. return m.CountBySql(q, values...)
  455. }
  456. func (m *Mysql) CountBySql(q string, args ...interface{}) int64 {
  457. stmtIns, err := m.DB.Prepare(q)
  458. if err != nil {
  459. log.Println(err)
  460. return -1
  461. }
  462. defer stmtIns.Close()
  463. rows, err := stmtIns.Query(args...)
  464. if err != nil {
  465. log.Println(err)
  466. return -1
  467. }
  468. if rows != nil {
  469. defer rows.Close()
  470. }
  471. var count int64 = -1
  472. if rows.Next() {
  473. err = rows.Scan(&count)
  474. if err != nil {
  475. log.Println(err)
  476. }
  477. }
  478. return count
  479. }
  480. //执行事务
  481. func (m *Mysql) ExecTx(msg string, f func(tx *sql.Tx) bool) bool {
  482. tx, err := m.DB.Begin()
  483. if err != nil {
  484. log.Println(msg, "获取事务错误", err)
  485. } else {
  486. if f(tx) {
  487. if err := tx.Commit(); err != nil {
  488. log.Println(msg, "提交事务错误", err)
  489. } else {
  490. return true
  491. }
  492. } else {
  493. if err := tx.Rollback(); err != nil {
  494. log.Println(msg, "事务回滚错误", err)
  495. }
  496. }
  497. }
  498. return false
  499. }
  500. /*************方法命名不规范,上面有替代方法*************/
  501. func (m *Mysql) Query(query string, args ...interface{}) *[]map[string]interface{} {
  502. return m.SelectBySql(query, args...)
  503. }
  504. func (m *Mysql) QueryCount(query string, args ...interface{}) (count int) {
  505. count = -1
  506. if !strings.Contains(strings.ToLower(query), "count(*)") {
  507. fmt.Println("QueryCount need query like < select count(*) from ..... >")
  508. return
  509. }
  510. count = int(m.CountBySql(query, args...))
  511. return
  512. }