package main import ( "context" "encoding/json" "fmt" "log" "math/big" "runtime" "strconv" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type MgoSess struct { Db string Coll string Query interface{} Sorts []string fields interface{} limit int64 skip int64 M *MongodbSim } type MgoIter struct { Cursor *mongo.Cursor } func (mt *MgoIter) Next(result interface{}) bool { if mt.Cursor != nil { if mt.Cursor.Next(nil) { err := mt.Cursor.Decode(result) if err != nil { log.Println("mgo cur err", err.Error()) mt.Cursor.Close(nil) return false } return true } else { mt.Cursor.Close(nil) return false } } else { return false } } func (ms *MgoSess) DB(name string) *MgoSess { ms.Db = name return ms } func (ms *MgoSess) C(name string) *MgoSess { ms.Coll = name return ms } func (ms *MgoSess) Find(q interface{}) *MgoSess { ms.Query = q return ms } func (ms *MgoSess) Select(fields interface{}) *MgoSess { ms.fields = fields return ms } func (ms *MgoSess) Limit(limit int64) *MgoSess { ms.limit = limit return ms } func (ms *MgoSess) Skip(skip int64) *MgoSess { ms.skip = skip return ms } func (ms *MgoSess) Sort(sorts ...string) *MgoSess { ms.Sorts = sorts return ms } func (ms *MgoSess) Iter() *MgoIter { it := &MgoIter{} find := options.Find() if ms.skip > 0 { find.SetSkip(ms.skip) } if ms.limit > 0 { find.SetLimit(ms.limit) } find.SetBatchSize(100) if len(ms.Sorts) > 0 { sort := bson.M{} for _, k := range ms.Sorts { switch k[:1] { case "-": sort[k[1:]] = -1 case "+": sort[k[1:]] = 1 default: sort[k] = 1 } } find.SetSort(sort) } if ms.fields != nil { find.SetProjection(ms.fields) } cur, err := ms.M.C.Database(ms.Db).Collection(ms.Coll).Find(ms.M.Ctx, ms.Query, find) if err != nil { log.Println("mgo find err", err.Error()) } else { it.Cursor = cur } return it } type MongodbSim struct { MongodbAddr string Size int // MinSize int DbName string C *mongo.Client Ctx context.Context ShortCtx context.Context pool chan bool UserName string Password string } func (m *MongodbSim) GetMgoConn() *MgoSess { //m.Open() ms := &MgoSess{} ms.M = m return ms } func (m *MongodbSim) DestoryMongoConn(ms *MgoSess) { //m.Close() ms.M = nil ms = nil } func (m *MongodbSim) InitPool() { opts := options.Client() opts.SetConnectTimeout(3 * time.Second) opts.ApplyURI("mongodb://" + m.MongodbAddr) opts.SetMaxPoolSize(uint64(m.Size)) m.pool = make(chan bool, m.Size) if m.UserName != "" && m.Password != "" { cre := options.Credential{ Username: m.UserName, Password: m.Password, } opts.SetAuth(cre) } opts.SetMaxConnIdleTime(2 * time.Hour) m.Ctx, _ = context.WithTimeout(context.Background(), 99999*time.Hour) m.ShortCtx, _ = context.WithTimeout(context.Background(), 1*time.Minute) client, err := mongo.Connect(m.ShortCtx, opts) if err != nil { log.Println("mgo init error:", err.Error()) } else { m.C = client log.Println("init success") } } func (m *MongodbSim) Open() { m.pool <- true } func (m *MongodbSim) Close() { <-m.pool } // 批量插入 func (m *MongodbSim) UpSertBulk(c string, doc ...[]map[string]interface{}) (map[int64]interface{}, bool) { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewUpdateOneModel() write.SetFilter(d[0]) write.SetUpdate(d[1]) write.SetUpsert(true) writes = append(writes, write) } r, e := coll.BulkWrite(m.Ctx, writes) if e != nil { log.Println("mgo upsert error:", e.Error()) return nil, false } // else { // if r.UpsertedCount != int64(len(doc)) { // log.Println("mgo upsert uncomplete:uc/dc", r.UpsertedCount, len(doc)) // } // return true // } return r.UpsertedIDs, true } // 批量插入 func (m *MongodbSim) SaveBulk(c string, doc ...map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewInsertOneModel() write.SetDocument(d) writes = append(writes, write) } _, e := coll.BulkWrite(m.Ctx, writes) if e != nil { log.Println("mgo savebulk error:", e.Error()) return false } return true } // 保存 func (m *MongodbSim) Save(c string, doc map[string]interface{}) interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.InsertOne(m.Ctx, doc) if err != nil { return nil } return r.InsertedID } // 按条件更新 func (m *MongodbSim) Update(c string, q, u interface{}, upsert bool, multi bool) bool { defer catch() m.Open() defer m.Close() ct := options.Update() if upsert { ct.SetUpsert(true) } coll := m.C.Database(m.DbName).Collection(c) var err error if multi { _, err = coll.UpdateMany(m.Ctx, ObjToM(q), ObjToM(u), ct) } else { _, err = coll.UpdateOne(m.Ctx, ObjToM(q), ObjToM(u), ct) } if err != nil { log.Println("删除错误", err.Error()) return false } return true } // 更新by Id func (m *MongodbSim) UpdateById(c, id string, doc map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) _, err := coll.UpdateOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}, doc) if err != nil { log.Println(err) return false } return true } // 删除by id func (m *MongodbSim) DeleteById(c, id string) int64 { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.DeleteOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}) if err != nil { return 0 } return r.DeletedCount } // 通过条件删除 func (m *MongodbSim) Delete(c string, query map[string]interface{}) int64 { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.DeleteMany(m.Ctx, query) if err != nil { return 0 } return r.DeletedCount } // findbyid func (m *MongodbSim) FindById(c, id string) map[string]interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r := coll.FindOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}) v := map[string]interface{}{} r.Decode(&v) return v } // findone func (m *MongodbSim) FindOne(c string, query map[string]interface{}) map[string]interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r := coll.FindOne(m.Ctx, query) v := map[string]interface{}{} r.Decode(&v) return v } // find func (m *MongodbSim) Find(c string, query map[string]interface{}, sort, fields interface{}) ([]map[string]interface{}, error) { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) op := options.Find() r, err := coll.Find(m.Ctx, query, op.SetSort(sort), op.SetProjection(fields)) if err != nil { log.Fatal(err) return nil, err } var results []map[string]interface{} if err = r.All(m.Ctx, &results); err != nil { log.Fatal(err) return nil, err } return results, nil } func ObjToOth(query interface{}) *bson.M { return ObjToMQ(query, false) } func ObjToM(query interface{}) *bson.M { return ObjToMQ(query, true) } // obj(string,M)转M,查询用到 func ObjToMQ(query interface{}, isQuery bool) *bson.M { data := make(bson.M) defer catch() if s2, ok2 := query.(*map[string]interface{}); ok2 { data = bson.M(*s2) } else if s3, ok3 := query.(*bson.M); ok3 { return s3 } else if s3, ok3 := query.(*primitive.M); ok3 { return s3 } else if s, ok := query.(string); ok { json.Unmarshal([]byte(strings.Replace(s, "'", "\"", -1)), &data) if ss, oks := data["_id"]; oks && isQuery { switch ss.(type) { case string: data["_id"], _ = primitive.ObjectIDFromHex(ss.(string)) case map[string]interface{}: tmp := ss.(map[string]interface{}) for k, v := range tmp { tmp[k], _ = primitive.ObjectIDFromHex(v.(string)) } data["_id"] = tmp } } } else if s1, ok1 := query.(map[string]interface{}); ok1 { data = s1 } else if s4, ok4 := query.(bson.M); ok4 { data = s4 } else if s4, ok4 := query.(primitive.M); ok4 { data = s4 } else { data = nil } return &data } func intAllDef(num interface{}, defaultNum int) int { if i, ok := num.(int); ok { return int(i) } else if i0, ok0 := num.(int32); ok0 { return int(i0) } else if i1, ok1 := num.(float64); ok1 { return int(i1) } else if i2, ok2 := num.(int64); ok2 { return int(i2) } else if i3, ok3 := num.(float32); ok3 { return int(i3) } else if i4, ok4 := num.(string); ok4 { in, _ := strconv.Atoi(i4) return int(in) } else if i5, ok5 := num.(int16); ok5 { return int(i5) } else if i6, ok6 := num.(int8); ok6 { return int(i6) } else if i7, ok7 := num.(*big.Int); ok7 { in, _ := strconv.Atoi(fmt.Sprint(i7)) return int(in) } else if i8, ok8 := num.(*big.Float); ok8 { in, _ := strconv.Atoi(fmt.Sprint(i8)) return int(in) } else { return defaultNum } } // 创建_id func NewObjectId() primitive.ObjectID { return primitive.NewObjectID() } func StringTOBsonId(id string) primitive.ObjectID { objectId, _ := primitive.ObjectIDFromHex(id) return objectId } func BsonTOStringId(id interface{}) string { return id.(primitive.ObjectID).Hex() } // 出错拦截 func catch() { if r := recover(); r != nil { log.Println(r) for skip := 0; ; skip++ { _, file, line, ok := runtime.Caller(skip) if !ok { break } go log.Printf("%v,%v\n", file, line) } } }