123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- package mssql
- import (
- "database/sql"
- "fmt"
- "strings"
- "sync"
- "time"
- log "github.com/sirupsen/logrus"
- "github.com/xyproto/algernon/lua/convert"
- lua "github.com/xyproto/gopher-lua"
- // Using the MSSQL database engine
- _ "github.com/denisenkom/go-mssqldb"
- )
const (
	// defaultQuery is the query Load's MSSQL function runs when the caller
	// gives no query, or an empty one.
	defaultQuery = "SELECT @@VERSION"

	// defaultConnectionString is the connection string Load's MSSQL function
	// uses when the caller gives none. Every key=value pair is separated by a
	// semicolon; a comma between the password and the port would make the
	// port part of the password value.
	defaultConnectionString = "server=localhost;user=sa;password=Password123;port=1433"
)
var (
	// reuseDB caches one database handle per connection string so that
	// repeated calls can reuse it. Access is guarded by reuseMut.
	reuseDB = map[string]*sql.DB{}

	// reuseMut is the read/write lock that guards reuseDB.
	reuseMut = new(sync.RWMutex)
)
- // LValueWrapper decorates lua.LValue to help retieve values from the database.
- type LValueWrapper struct {
- LValue lua.LValue
- }
- // Scan implements the sql.Scanner interface for database deserialization.
- func (w *LValueWrapper) Scan(value any) error {
- if value == nil {
- *w = LValueWrapper{lua.LNil}
- return nil
- }
- switch v := value.(type) {
- case float32:
- *w = LValueWrapper{lua.LNumber(float64(v))}
- case float64:
- *w = LValueWrapper{lua.LNumber(v)}
- case int64:
- *w = LValueWrapper{lua.LNumber(float64(v))}
- case string:
- *w = LValueWrapper{lua.LString(v)}
- case []byte:
- *w = LValueWrapper{lua.LString(string(v))}
- case time.Time:
- *w = LValueWrapper{lua.LNumber(float64(v.Unix()))}
- default:
- return fmt.Errorf("unable to scan type %T into lua value wrapper", value)
- }
- return nil
- }
- // LValueWrappers is a convenience type to easily map to a slice of lua.LValue
- type LValueWrappers []LValueWrapper
- // Unwrap produces a slice of lua.LValue from the given LValueWrappers
- func (w LValueWrappers) Unwrap() (s []lua.LValue) {
- s = make([]lua.LValue, len(w))
- for i, v := range w {
- s[i] = v.LValue
- }
- return
- }
- // Interfaces returns a slice of any values from the given LValueWrappers
- func (w LValueWrappers) Interfaces() (s []any) {
- s = make([]any, len(w))
- for i := range w {
- s[i] = &w[i]
- }
- return
- }
- // Load makes functions related to building a library of Lua code available
- func Load(L *lua.LState) {
- // Register the MSSQL function
- L.SetGlobal("MSSQL", L.NewFunction(func(L *lua.LState) int {
- // Check if the optional argument is given
- query := defaultQuery
- if L.GetTop() >= 1 {
- query = L.ToString(1)
- if query == "" {
- query = defaultQuery
- }
- }
- connectionString := defaultConnectionString
- if L.GetTop() >= 2 {
- connectionString = L.ToString(2)
- }
- // Get arguments
- var queryArgs []any
- if L.GetTop() >= 3 {
- args := L.ToTable(3)
- args.ForEach(func(k lua.LValue, v lua.LValue) {
- switch k.Type() {
- case lua.LTNumber:
- queryArgs = append(queryArgs, v.String())
- case lua.LTString:
- queryArgs = append(queryArgs, sql.Named(k.String(), v.String()))
- }
- })
- }
- // Check if there is a connection that can be reused
- var db *sql.DB
- reuseMut.RLock()
- conn, ok := reuseDB[connectionString]
- reuseMut.RUnlock()
- if ok {
- // It exists, but is it still alive?
- err := conn.Ping()
- if err != nil {
- // no
- // log.Info("did not reuse the connection")
- reuseMut.Lock()
- delete(reuseDB, connectionString)
- reuseMut.Unlock()
- } else {
- // yes
- // log.Info("reused the connection")
- db = conn
- }
- }
- // Create a new connection, if needed
- var err error
- if db == nil {
- db, err = sql.Open("sqlserver", connectionString)
- if err != nil {
- log.Error("Could not connect to database using " + connectionString + ": " + err.Error())
- return 0 // No results
- }
- // Save the connection for later
- reuseMut.Lock()
- reuseDB[connectionString] = db
- reuseMut.Unlock()
- }
- // log.Info(fmt.Sprintf("MSSQL database: %v (%T)\n", db, db))
- reuseMut.Lock()
- rows, err := db.Query(query, queryArgs...)
- reuseMut.Unlock()
- if err != nil {
- errMsg := err.Error()
- if strings.Contains(errMsg, ": connect: connection refused") {
- log.Info("MSSQL connection string: " + connectionString)
- log.Info("MSSQL query: " + query)
- log.Error("Could not connect to database: " + errMsg)
- } else if strings.Contains(errMsg, "missing") && strings.Contains(errMsg, "in connection info string") {
- log.Info("MSSQL connection string: " + connectionString)
- log.Info("MSSQL query: " + query)
- log.Error(errMsg)
- } else {
- log.Info("MSSQL query: " + query)
- log.Error("Query failed: " + errMsg)
- }
- return 0 // No results
- }
- if rows == nil {
- // Return an empty table
- L.Push(L.NewTable())
- return 1 // number of results
- }
- cols, err := rows.Columns()
- if err != nil {
- log.Error("Failed to get columns: " + err.Error())
- return 0
- }
- // Return the rows as a 2-dimensional table
- // Outer table is an array of rows
- // Inner tables are maps of values with column names as keys
- var (
- m map[string]lua.LValue
- maps []map[string]lua.LValue
- values LValueWrappers
- cname string
- )
- for rows.Next() {
- values = make(LValueWrappers, len(cols))
- err = rows.Scan(values.Interfaces()...)
- if err != nil {
- log.Error("Failed to scan data: " + err.Error())
- break
- }
- m = make(map[string]lua.LValue, len(cols))
- for i, v := range values.Unwrap() {
- cname = cols[i]
- m[cname] = v
- }
- maps = append(maps, m)
- }
- // Convert the strings to a Lua table
- table := convert.LValueMaps2table(L, maps)
- // Return the table
- L.Push(table)
- return 1 // number of results
- }))
- }
|