mysql.go 14 KB


  1. package mysqldb
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "log"
  7. "reflect"
  8. "strings"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. )
  // Mysql holds the settings for one MySQL database and, after Init has run,
// the *sql.DB connection pool opened from them.
type Mysql struct {
	Address      string  // database address, host:port
	UserName     string  // user name
	PassWord     string  // password
	DBName       string  // database name
	DB           *sql.DB // connection pool object; assigned by Init
	MaxOpenConns int     // maximum number of open connections; Init replaces values <= 0 with 20
	MaxIdleConns int     // maximum number of idle connections; Init replaces values <= 0 with 20
}
  21. func (m *Mysql) Init() {
  22. if m.MaxOpenConns <= 0 {
  23. m.MaxOpenConns = 20
  24. }
  25. if m.MaxIdleConns <= 0 {
  26. m.MaxIdleConns = 20
  27. }
  28. var err error
  29. m.DB, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", m.UserName, m.PassWord, m.Address, m.DBName))
  30. if err != nil {
  31. log.Println(err)
  32. return
  33. }
  34. m.DB.SetMaxOpenConns(m.MaxOpenConns)
  35. m.DB.SetMaxIdleConns(m.MaxIdleConns)
  36. m.DB.SetConnMaxLifetime(time.Minute * 3)
  37. err = m.DB.Ping()
  38. if err != nil {
  39. log.Println(err)
  40. }
  41. }
  42. // 新增
  43. func (m *Mysql) Insert(tableName string, data map[string]interface{}) int64 {
  44. return m.InsertByTx(nil, tableName, data)
  45. }
  46. // 带有事务的新增
  47. func (m *Mysql) InsertByTx(tx *sql.Tx, tableName string, data map[string]interface{}) int64 {
  48. fields := []string{}
  49. values := []interface{}{}
  50. placeholders := []string{}
  51. for k, v := range data {
  52. fields = append(fields, k)
  53. values = append(values, v)
  54. placeholders = append(placeholders, "?")
  55. }
  56. q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ","), strings.Join(placeholders, ","))
  57. //log.Println("mysql", q, values)
  58. return m.InsertBySqlByTx(tx, q, values...)
  59. }
  60. // sql语句新增
  61. func (m *Mysql) InsertBySql(q string, args ...interface{}) int64 {
  62. return m.InsertBySqlByTx(nil, q, args...)
  63. }
  64. // 带有事务的sql语句新增
  65. func (m *Mysql) InsertBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  66. result, _ := m.ExecBySqlByTx(tx, q, args...)
  67. if result == nil {
  68. return -1
  69. }
  70. id, err := result.LastInsertId()
  71. if err != nil {
  72. log.Println(err)
  73. return -1
  74. }
  75. return id
  76. }
  77. // 批量新增
  78. func (m *Mysql) InsertIgnoreBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  79. return m.InsertIgnoreBatchByTx(nil, tableName, fields, values)
  80. }
  81. // 带事务的批量新增
  82. func (m *Mysql) InsertIgnoreBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  83. return m.insertOrReplaceBatchByTx(tx, "INSERT", "IGNORE", tableName, fields, values)
  84. }
  85. // 批量新增
  86. func (m *Mysql) InsertBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  87. return m.InsertBatchByTx(nil, tableName, fields, values)
  88. }
  89. // 带事务的批量新增
  90. func (m *Mysql) InsertBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  91. return m.insertOrReplaceBatchByTx(tx, "INSERT", "", tableName, fields, values)
  92. }
  93. // 批量更新
  94. func (m *Mysql) ReplaceBatch(tableName string, fields []string, values []interface{}) (int64, int64) {
  95. return m.ReplaceBatchByTx(nil, tableName, fields, values)
  96. }
  97. // 带事务的批量更新
  98. func (m *Mysql) ReplaceBatchByTx(tx *sql.Tx, tableName string, fields []string, values []interface{}) (int64, int64) {
  99. return m.insertOrReplaceBatchByTx(tx, "REPLACE", "", tableName, fields, values)
  100. }
  101. func (m *Mysql) insertOrReplaceBatchByTx(tx *sql.Tx, tp string, afterInsert, tableName string, fields []string, values []interface{}) (int64, int64) {
  102. placeholders := []string{}
  103. for range fields {
  104. placeholders = append(placeholders, "?")
  105. }
  106. placeholder := strings.Join(placeholders, ",")
  107. array := []string{}
  108. for i := 0; i < len(values)/len(fields); i++ {
  109. array = append(array, fmt.Sprintf("(%s)", placeholder))
  110. }
  111. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", tp, afterInsert, tableName, strings.Join(fields, ","), strings.Join(array, ","))
  112. result, _ := m.ExecBySqlByTx(tx, q, values...)
  113. if result == nil {
  114. return -1, -1
  115. }
  116. v1, e1 := result.RowsAffected()
  117. if e1 != nil {
  118. log.Println(e1)
  119. return -1, -1
  120. }
  121. v2, e2 := result.LastInsertId()
  122. if e2 != nil {
  123. log.Println(e2)
  124. return -1, -1
  125. }
  126. return v1, v2
  127. }
  128. func (m *Mysql) InsertBulk(tableName string, fields []string, doc ...map[string]interface{}) (int64, int64) {
  129. placeholders := []string{}
  130. for range fields {
  131. placeholders = append(placeholders, "?")
  132. }
  133. placeholder := strings.Join(placeholders, ",")
  134. array := []string{}
  135. for i := 0; i < len(doc); i++ {
  136. array = append(array, fmt.Sprintf("(%s)", placeholder))
  137. }
  138. var values []interface{}
  139. for _, d := range doc {
  140. for _, f := range fields {
  141. if d[f] != nil {
  142. values = append(values, d[f])
  143. } else {
  144. values = append(values, nil)
  145. }
  146. }
  147. }
  148. q := fmt.Sprintf("%s %s INTO %s (%s) VALUES %s", "REPLACE", "", tableName, strings.Join(fields, ","), strings.Join(array, ","))
  149. result, _ := m.ExecBySqlByTx(nil, q, values...)
  150. if result == nil {
  151. return -1, -1
  152. }
  153. v1, e1 := result.RowsAffected()
  154. if e1 != nil {
  155. log.Println(e1)
  156. return -1, -1
  157. }
  158. v2, e2 := result.LastInsertId()
  159. if e2 != nil {
  160. log.Println(e2)
  161. return -1, -1
  162. }
  163. return v1, v2
  164. }
  165. // sql语句执行
  166. func (m *Mysql) ExecBySql(q string, args ...interface{}) (sql.Result, error) {
  167. return m.ExecBySqlByTx(nil, q, args...)
  168. }
  169. // sql语句执行,带有事务
  170. func (m *Mysql) ExecBySqlByTx(tx *sql.Tx, q string, args ...interface{}) (sql.Result, error) {
  171. var stmtIns *sql.Stmt
  172. var err error
  173. if tx == nil {
  174. stmtIns, err = m.DB.Prepare(q)
  175. } else {
  176. stmtIns, err = tx.Prepare(q)
  177. }
  178. if err != nil {
  179. log.Println(err)
  180. return nil, err
  181. }
  182. defer stmtIns.Close()
  183. result, err := stmtIns.Exec(args...)
  184. if err != nil {
  185. log.Println(args, err)
  186. return nil, err
  187. }
  188. return result, nil
  189. }
  190. /*不等于 map[string]string{"ne":"1"}
  191. *不等于多个 map[string]string{"notin":[]interface{}{1,2}}
  192. *字段为空 map[string]string{"name":"$isNull"}
  193. *字段不为空 map[string]string{"name":"$isNotNull"}
  194. */
  195. func (m *Mysql) Find(tableName string, query map[string]interface{}, fields, order string, start, pageSize int) *[]map[string]interface{} {
  196. fs := []string{}
  197. vs := []interface{}{}
  198. for k, v := range query {
  199. rt := reflect.TypeOf(v)
  200. rv := reflect.ValueOf(v)
  201. if rt.Kind() == reflect.Map {
  202. for _, rv_k := range rv.MapKeys() {
  203. if rv_k.String() == "ne" {
  204. fs = append(fs, fmt.Sprintf("%s!=?", k))
  205. vs = append(vs, rv.MapIndex(rv_k).Interface())
  206. }
  207. if rv_k.String() == "notin" {
  208. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  209. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  210. fs = append(fs, fmt.Sprintf("%s!=?", k))
  211. vs = append(vs, v)
  212. }
  213. }
  214. }
  215. if rv_k.String() == "in" {
  216. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  217. _fs := fmt.Sprintf("%s in (?", k)
  218. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  219. if k > 0 {
  220. _fs += ",?"
  221. }
  222. vs = append(vs, v)
  223. }
  224. _fs += ")"
  225. fs = append(fs, _fs)
  226. }
  227. }
  228. }
  229. } else {
  230. if v == "$isNull" {
  231. fs = append(fs, fmt.Sprintf("%s is null", k))
  232. } else if v == "$isNotNull" {
  233. fs = append(fs, fmt.Sprintf("%s is not null", k))
  234. } else {
  235. fs = append(fs, fmt.Sprintf("%s=?", k))
  236. vs = append(vs, v)
  237. }
  238. }
  239. }
  240. var buffer bytes.Buffer
  241. buffer.WriteString("select ")
  242. if fields == "" {
  243. buffer.WriteString("*")
  244. } else {
  245. buffer.WriteString(fields)
  246. }
  247. buffer.WriteString(" from ")
  248. buffer.WriteString(tableName)
  249. if len(fs) > 0 {
  250. buffer.WriteString(" where ")
  251. buffer.WriteString(strings.Join(fs, " and "))
  252. }
  253. if order != "" {
  254. buffer.WriteString(" order by ")
  255. buffer.WriteString(order)
  256. }
  257. if start > -1 && pageSize > 0 {
  258. buffer.WriteString(" limit ")
  259. buffer.WriteString(fmt.Sprint(start))
  260. buffer.WriteString(",")
  261. buffer.WriteString(fmt.Sprint(pageSize))
  262. }
  263. q := buffer.String()
  264. //log.Println(q, vs)
  265. return m.SelectBySql(q, vs...)
  266. }
  267. // sql语句查询
  268. func (m *Mysql) SelectBySql(q string, args ...interface{}) *[]map[string]interface{} {
  269. return m.SelectBySqlByTx(nil, q, args...)
  270. }
  271. func (m *Mysql) SelectBySqlByTx(tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  272. return m.Select(0, nil, tx, q, args...)
  273. }
  274. func (m *Mysql) Select(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) *[]map[string]interface{} {
  275. var stmtOut *sql.Stmt
  276. var err error
  277. if tx == nil {
  278. stmtOut, err = m.DB.Prepare(q)
  279. } else {
  280. stmtOut, err = tx.Prepare(q)
  281. }
  282. if err != nil {
  283. log.Println(err)
  284. return nil
  285. }
  286. defer stmtOut.Close()
  287. rows, err := stmtOut.Query(args...)
  288. if err != nil {
  289. log.Println(err)
  290. return nil
  291. }
  292. if rows != nil {
  293. defer rows.Close()
  294. }
  295. columns, err := rows.Columns()
  296. if err != nil {
  297. log.Println(err)
  298. return nil
  299. }
  300. list := []map[string]interface{}{}
  301. for rows.Next() {
  302. scanArgs := make([]interface{}, len(columns))
  303. values := make([]interface{}, len(columns))
  304. ret := make(map[string]interface{})
  305. for k, _ := range values {
  306. scanArgs[k] = &values[k]
  307. }
  308. err = rows.Scan(scanArgs...)
  309. if err != nil {
  310. log.Println(err)
  311. break
  312. }
  313. for i, col := range values {
  314. if v, ok := col.([]uint8); ok {
  315. ret[columns[i]] = string(v)
  316. } else {
  317. ret[columns[i]] = col
  318. }
  319. }
  320. list = append(list, ret)
  321. if bath > 0 && len(list) == bath {
  322. f(&list)
  323. list = []map[string]interface{}{}
  324. }
  325. }
  326. if bath > 0 && len(list) > 0 {
  327. f(&list)
  328. list = []map[string]interface{}{}
  329. }
  330. return &list
  331. }
  332. func (m *Mysql) SelectByBath(bath int, f func(l *[]map[string]interface{}), q string, args ...interface{}) {
  333. m.SelectByBathByTx(bath, f, nil, q, args...)
  334. }
  335. func (m *Mysql) SelectByBathByTx(bath int, f func(l *[]map[string]interface{}), tx *sql.Tx, q string, args ...interface{}) {
  336. m.Select(bath, f, tx, q, args...)
  337. }
  338. func (m *Mysql) FindOne(tableName string, query map[string]interface{}, fields, order string) *map[string]interface{} {
  339. list := m.Find(tableName, query, fields, order, 0, 1)
  340. if list != nil && len(*list) == 1 {
  341. temp := (*list)[0]
  342. return &temp
  343. }
  344. return nil
  345. }
  346. // 修改
  347. func (m *Mysql) Update(tableName string, query, update map[string]interface{}) bool {
  348. return m.UpdateByTx(nil, tableName, query, update)
  349. }
  350. // 带事务的修改
  351. func (m *Mysql) UpdateByTx(tx *sql.Tx, tableName string, query, update map[string]interface{}) bool {
  352. q_fs := []string{}
  353. u_fs := []string{}
  354. values := []interface{}{}
  355. for k, v := range update {
  356. q_fs = append(q_fs, fmt.Sprintf("%s=?", k))
  357. values = append(values, v)
  358. }
  359. for k, v := range query {
  360. u_fs = append(u_fs, fmt.Sprintf("%s=?", k))
  361. values = append(values, v)
  362. }
  363. q := fmt.Sprintf("update %s set %s where %s", tableName, strings.Join(q_fs, ","), strings.Join(u_fs, " and "))
  364. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) >= 0
  365. }
  366. // 删除
  367. func (m *Mysql) Delete(tableName string, query map[string]interface{}) bool {
  368. return m.DeleteByTx(nil, tableName, query)
  369. }
  370. func (m *Mysql) DeleteByTx(tx *sql.Tx, tableName string, query map[string]interface{}) bool {
  371. fields := []string{}
  372. values := []interface{}{}
  373. for k, v := range query {
  374. fields = append(fields, fmt.Sprintf("%s=?", k))
  375. values = append(values, v)
  376. }
  377. q := fmt.Sprintf("delete from %s where %s", tableName, strings.Join(fields, " and "))
  378. log.Println(q, values)
  379. return m.UpdateOrDeleteBySqlByTx(tx, q, values...) > 0
  380. }
  381. // 修改或删除
  382. func (m *Mysql) UpdateOrDeleteBySql(q string, args ...interface{}) int64 {
  383. return m.UpdateOrDeleteBySqlByTx(nil, q, args...)
  384. }
  385. // 带事务的修改或删除
  386. func (m *Mysql) UpdateOrDeleteBySqlByTx(tx *sql.Tx, q string, args ...interface{}) int64 {
  387. result, err := m.ExecBySqlByTx(tx, q, args...)
  388. if err != nil {
  389. log.Println(err)
  390. return -1
  391. }
  392. count, err := result.RowsAffected()
  393. if err != nil {
  394. log.Println(err)
  395. return -1
  396. }
  397. return count
  398. }
  399. // 总数
  400. func (m *Mysql) Count(tableName string, query map[string]interface{}) int64 {
  401. fields := []string{}
  402. values := []interface{}{}
  403. for k, v := range query {
  404. rt := reflect.TypeOf(v)
  405. rv := reflect.ValueOf(v)
  406. if rt.Kind() == reflect.Map {
  407. for _, rv_k := range rv.MapKeys() {
  408. if rv_k.String() == "ne" {
  409. fields = append(fields, fmt.Sprintf("%s!=?", k))
  410. values = append(values, rv.MapIndex(rv_k).Interface())
  411. }
  412. if rv_k.String() == "notin" {
  413. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  414. for _, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  415. fields = append(fields, fmt.Sprintf("%s!=?", k))
  416. values = append(values, v)
  417. }
  418. }
  419. }
  420. if rv_k.String() == "in" {
  421. if len(rv.MapIndex(rv_k).Interface().([]interface{})) > 0 {
  422. _fs := fmt.Sprintf("%s in (?", k)
  423. for k, v := range rv.MapIndex(rv_k).Interface().([]interface{}) {
  424. if k > 0 {
  425. _fs += ",?"
  426. }
  427. values = append(values, v)
  428. }
  429. _fs += ")"
  430. fields = append(fields, _fs)
  431. }
  432. }
  433. }
  434. } else if v == "$isNull" {
  435. fields = append(fields, fmt.Sprintf("%s is null", k))
  436. } else if v == "$isNotNull" {
  437. fields = append(fields, fmt.Sprintf("%s is not null", k))
  438. } else {
  439. fields = append(fields, fmt.Sprintf("%s=?", k))
  440. values = append(values, v)
  441. }
  442. }
  443. q := fmt.Sprintf("select count(1) as count from %s", tableName)
  444. if len(query) > 0 {
  445. q += fmt.Sprintf(" where %s", strings.Join(fields, " and "))
  446. }
  447. return m.CountBySql(q, values...)
  448. }
  449. func (m *Mysql) CountBySql(q string, args ...interface{}) int64 {
  450. stmtIns, err := m.DB.Prepare(q)
  451. if err != nil {
  452. log.Println(err)
  453. return -1
  454. }
  455. defer stmtIns.Close()
  456. rows, err := stmtIns.Query(args...)
  457. if err != nil {
  458. log.Println(err)
  459. return -1
  460. }
  461. if rows != nil {
  462. defer rows.Close()
  463. }
  464. var count int64 = -1
  465. if rows.Next() {
  466. err = rows.Scan(&count)
  467. if err != nil {
  468. log.Println(err)
  469. }
  470. }
  471. return count
  472. }
  473. // 执行事务
  474. func (m *Mysql) ExecTx(msg string, f func(tx *sql.Tx) bool) bool {
  475. tx, err := m.DB.Begin()
  476. if err != nil {
  477. log.Println(msg, "获取事务错误", err)
  478. } else {
  479. if f(tx) {
  480. if err := tx.Commit(); err != nil {
  481. log.Println(msg, "提交事务错误", err)
  482. } else {
  483. return true
  484. }
  485. } else {
  486. if err := tx.Rollback(); err != nil {
  487. log.Println(msg, "事务回滚错误", err)
  488. }
  489. }
  490. }
  491. return false
  492. }
  /************* Legacy methods with non-standard names; use the replacement methods defined above instead *************/
  494. func (m *Mysql) Query(query string, args ...interface{}) *[]map[string]interface{} {
  495. return m.SelectBySql(query, args...)
  496. }
  497. func (m *Mysql) QueryCount(query string, args ...interface{}) (count int) {
  498. count = -1
  499. if !strings.Contains(strings.ToLower(query), "count(*)") {
  500. fmt.Println("QueryCount need query like < select count(*) from ..... >")
  501. return
  502. }
  503. count = int(m.CountBySql(query, args...))
  504. return
  505. }