mssql.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package mssql
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. log "github.com/sirupsen/logrus"
  9. "github.com/xyproto/algernon/lua/convert"
  10. lua "github.com/xyproto/gopher-lua"
  11. // Using the MSSQL database engine
  12. _ "github.com/denisenkom/go-mssqldb"
  13. )
  14. const (
  15. defaultQuery = "SELECT @@VERSION"
  16. defaultConnectionString = "server=localhost;user=sa;password=Password123,port=1433"
  17. )
  18. var (
  19. // global map from connection string to database connection, to reuse connections, protected by a mutex
  20. reuseDB = make(map[string]*sql.DB)
  21. reuseMut = &sync.RWMutex{}
  22. )
  23. // LValueWrapper decorates lua.LValue to help retieve values from the database.
  24. type LValueWrapper struct {
  25. LValue lua.LValue
  26. }
  27. // Scan implements the sql.Scanner interface for database deserialization.
  28. func (w *LValueWrapper) Scan(value any) error {
  29. if value == nil {
  30. *w = LValueWrapper{lua.LNil}
  31. return nil
  32. }
  33. switch v := value.(type) {
  34. case float32:
  35. *w = LValueWrapper{lua.LNumber(float64(v))}
  36. case float64:
  37. *w = LValueWrapper{lua.LNumber(v)}
  38. case int64:
  39. *w = LValueWrapper{lua.LNumber(float64(v))}
  40. case string:
  41. *w = LValueWrapper{lua.LString(v)}
  42. case []byte:
  43. *w = LValueWrapper{lua.LString(string(v))}
  44. case time.Time:
  45. *w = LValueWrapper{lua.LNumber(float64(v.Unix()))}
  46. default:
  47. return fmt.Errorf("unable to scan type %T into lua value wrapper", value)
  48. }
  49. return nil
  50. }
  51. // LValueWrappers is a convenience type to easily map to a slice of lua.LValue
  52. type LValueWrappers []LValueWrapper
  53. // Unwrap produces a slice of lua.LValue from the given LValueWrappers
  54. func (w LValueWrappers) Unwrap() (s []lua.LValue) {
  55. s = make([]lua.LValue, len(w))
  56. for i, v := range w {
  57. s[i] = v.LValue
  58. }
  59. return
  60. }
  61. // Interfaces returns a slice of any values from the given LValueWrappers
  62. func (w LValueWrappers) Interfaces() (s []any) {
  63. s = make([]any, len(w))
  64. for i := range w {
  65. s[i] = &w[i]
  66. }
  67. return
  68. }
  69. // Load makes functions related to building a library of Lua code available
  70. func Load(L *lua.LState) {
  71. // Register the MSSQL function
  72. L.SetGlobal("MSSQL", L.NewFunction(func(L *lua.LState) int {
  73. // Check if the optional argument is given
  74. query := defaultQuery
  75. if L.GetTop() >= 1 {
  76. query = L.ToString(1)
  77. if query == "" {
  78. query = defaultQuery
  79. }
  80. }
  81. connectionString := defaultConnectionString
  82. if L.GetTop() >= 2 {
  83. connectionString = L.ToString(2)
  84. }
  85. // Get arguments
  86. var queryArgs []any
  87. if L.GetTop() >= 3 {
  88. args := L.ToTable(3)
  89. args.ForEach(func(k lua.LValue, v lua.LValue) {
  90. switch k.Type() {
  91. case lua.LTNumber:
  92. queryArgs = append(queryArgs, v.String())
  93. case lua.LTString:
  94. queryArgs = append(queryArgs, sql.Named(k.String(), v.String()))
  95. }
  96. })
  97. }
  98. // Check if there is a connection that can be reused
  99. var db *sql.DB
  100. reuseMut.RLock()
  101. conn, ok := reuseDB[connectionString]
  102. reuseMut.RUnlock()
  103. if ok {
  104. // It exists, but is it still alive?
  105. err := conn.Ping()
  106. if err != nil {
  107. // no
  108. // log.Info("did not reuse the connection")
  109. reuseMut.Lock()
  110. delete(reuseDB, connectionString)
  111. reuseMut.Unlock()
  112. } else {
  113. // yes
  114. // log.Info("reused the connection")
  115. db = conn
  116. }
  117. }
  118. // Create a new connection, if needed
  119. var err error
  120. if db == nil {
  121. db, err = sql.Open("sqlserver", connectionString)
  122. if err != nil {
  123. log.Error("Could not connect to database using " + connectionString + ": " + err.Error())
  124. return 0 // No results
  125. }
  126. // Save the connection for later
  127. reuseMut.Lock()
  128. reuseDB[connectionString] = db
  129. reuseMut.Unlock()
  130. }
  131. // log.Info(fmt.Sprintf("MSSQL database: %v (%T)\n", db, db))
  132. reuseMut.Lock()
  133. rows, err := db.Query(query, queryArgs...)
  134. reuseMut.Unlock()
  135. if err != nil {
  136. errMsg := err.Error()
  137. if strings.Contains(errMsg, ": connect: connection refused") {
  138. log.Info("MSSQL connection string: " + connectionString)
  139. log.Info("MSSQL query: " + query)
  140. log.Error("Could not connect to database: " + errMsg)
  141. } else if strings.Contains(errMsg, "missing") && strings.Contains(errMsg, "in connection info string") {
  142. log.Info("MSSQL connection string: " + connectionString)
  143. log.Info("MSSQL query: " + query)
  144. log.Error(errMsg)
  145. } else {
  146. log.Info("MSSQL query: " + query)
  147. log.Error("Query failed: " + errMsg)
  148. }
  149. return 0 // No results
  150. }
  151. if rows == nil {
  152. // Return an empty table
  153. L.Push(L.NewTable())
  154. return 1 // number of results
  155. }
  156. cols, err := rows.Columns()
  157. if err != nil {
  158. log.Error("Failed to get columns: " + err.Error())
  159. return 0
  160. }
  161. // Return the rows as a 2-dimensional table
  162. // Outer table is an array of rows
  163. // Inner tables are maps of values with column names as keys
  164. var (
  165. m map[string]lua.LValue
  166. maps []map[string]lua.LValue
  167. values LValueWrappers
  168. cname string
  169. )
  170. for rows.Next() {
  171. values = make(LValueWrappers, len(cols))
  172. err = rows.Scan(values.Interfaces()...)
  173. if err != nil {
  174. log.Error("Failed to scan data: " + err.Error())
  175. break
  176. }
  177. m = make(map[string]lua.LValue, len(cols))
  178. for i, v := range values.Unwrap() {
  179. cname = cols[i]
  180. m[cname] = v
  181. }
  182. maps = append(maps, m)
  183. }
  184. // Convert the strings to a Lua table
  185. table := convert.LValueMaps2table(L, maps)
  186. // Return the table
  187. L.Push(table)
  188. return 1 // number of results
  189. }))
  190. }