package main import ( "context" "log" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type MgoSess struct { Db string Coll string Query interface{} Sorts []string fields interface{} limit int64 skip int64 M *MongodbSim } type MgoIter struct { Cursor *mongo.Cursor } func NewMgo(addr, db string, size int) *MongodbSim { mgo := &MongodbSim{ MongodbAddr: addr, Size: size, DbName: db, } mgo.InitPool() return mgo } func (mt *MgoIter) Next(result interface{}) bool { if mt.Cursor != nil { if mt.Cursor.Next(nil) { err := mt.Cursor.Decode(result) if err != nil { log.Println("mgo cur err", err.Error()) mt.Cursor.Close(nil) return false } return true } else { mt.Cursor.Close(nil) return false } } else { return false } } func (ms *MgoSess) DB(name string) *MgoSess { ms.Db = name return ms } func (ms *MgoSess) C(name string) *MgoSess { ms.Coll = name return ms } func (ms *MgoSess) Find(q interface{}) *MgoSess { ms.Query = q return ms } func (ms *MgoSess) Select(fields interface{}) *MgoSess { ms.fields = fields return ms } func (ms *MgoSess) Limit(limit int64) *MgoSess { ms.limit = limit return ms } func (ms *MgoSess) Skip(skip int64) *MgoSess { ms.skip = skip return ms } func (ms *MgoSess) Sort(sorts ...string) *MgoSess { ms.Sorts = sorts return ms } func (ms *MgoSess) Iter() *MgoIter { it := &MgoIter{} find := options.Find() if ms.skip > 0 { find.SetSkip(ms.skip) } if ms.limit > 0 { find.SetLimit(ms.limit) } find.SetBatchSize(100) if len(ms.Sorts) > 0 { sort := bson.M{} for _, k := range ms.Sorts { switch k[:1] { case "-": sort[k[1:]] = -1 case "+": sort[k[1:]] = 1 default: sort[k] = 1 } } find.SetSort(sort) } if ms.fields != nil { find.SetProjection(ms.fields) } cur, err := ms.M.C.Database(ms.Db).Collection(ms.Coll).Find(ms.M.Ctx, ms.Query, find) if err != nil { log.Println("mgo find err", err.Error()) } else { it.Cursor = cur } return it } type MongodbSim struct { MongodbAddr string Size int // MinSize int DbName string C *mongo.Client Ctx context.Context ShortCtx context.Context pool chan bool } func (m *MongodbSim) GetMgoConn() *MgoSess { //m.Open() ms := &MgoSess{} ms.M = m return ms } func (m *MongodbSim) DestoryMongoConn(ms *MgoSess) { //m.Close() ms.M = nil ms = nil } func (m *MongodbSim) Destroy() { //m.Close() m.C.Disconnect(nil) m.C = nil } func (m *MongodbSim) InitPool() { opts := options.Client() opts.SetConnectTimeout(3 * time.Second) opts.ApplyURI("mongodb://" + m.MongodbAddr) opts.SetMaxPoolSize(uint64(m.Size)) m.pool = make(chan bool, m.Size) opts.SetMaxConnIdleTime(2 * time.Hour) m.Ctx, _ = context.WithTimeout(context.Background(), 99999*time.Hour) m.ShortCtx, _ = context.WithTimeout(context.Background(), 1*time.Minute) client, err := mongo.Connect(m.ShortCtx, opts) if err != nil { log.Println("mgo init error:", err.Error()) } else { m.C = client } } func (m *MongodbSim) Open() { m.pool <- true } func (m *MongodbSim) Close() { <-m.pool } //批量插入 func (m *MongodbSim) UpSertBulk(c string, doc ...[]map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewUpdateOneModel() write.SetFilter(d[0]) write.SetUpdate(d[1]) write.SetUpsert(true) writes = append(writes, write) } br, e := coll.BulkWrite(m.Ctx, writes) if e != nil { log.Println("mgo upsert error:", e.Error()) return br == nil || br.UpsertedCount == 0 } // else { // if r.UpsertedCount != int64(len(doc)) { // log.Println("mgo upsert uncomplete:uc/dc", r.UpsertedCount, len(doc)) // } // return true // } return true } //批量插入 func (m *MongodbSim) SaveBulk(c string, doc ...map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewInsertOneModel() write.SetDocument(d) writes = append(writes, write) } br, e := coll.BulkWrite(m.Ctx, writes) if e != nil { b := strings.Index(e.Error(), "duplicate") > -1 log.Println("mgo savebulk error:", e.Error()) if br != nil { log.Println("mgo savebulk size:", br.InsertedCount) } return b } return true } func StringTOBsonId(id string) primitive.ObjectID { objectId, _ := primitive.ObjectIDFromHex(id) return objectId } //按条件统计 func (m *MongodbSim) Count(c string, q interface{}) int64 { m.Open() defer m.Close() ct := options.Count() ct.SetMaxTime(180 * time.Second) res, err := m.C.Database(m.DbName).Collection(c).CountDocuments(m.Ctx, q, ct) if err != nil { log.Println("统计错误", err.Error()) } return res } //按条件删除 func (m *MongodbSim) Delete(c string, q interface{}) int64 { m.Open() defer m.Close() ct := options.Delete() res, err := m.C.Database(m.DbName).Collection(c).DeleteMany(m.Ctx, q, ct) if err != nil && res == nil { log.Println("删除错误", err.Error()) } return res.DeletedCount } //按条件更新 func (m *MongodbSim) Update(c string, q, u interface{}) int64 { m.Open() defer m.Close() ct := options.Update() res, err := m.C.Database(m.DbName).Collection(c).UpdateMany(m.Ctx, q, u, ct) if err != nil && res == nil { log.Println("删除错误", err.Error()) } return res.ModifiedCount } //查找一条数据 func (m *MongodbSim) FindOne(c string, q, fields interface{}, sorts []string) (res map[string]interface{}) { m.Open() defer m.Close() of := options.FindOne() if fields != nil { of.SetProjection(fields) } if len(sorts) > 0 { sort := bson.M{} for _, k := range sorts { switch k[:1] { case "-": sort[k[1:]] = -1 case "+": sort[k[1:]] = 1 default: sort[k] = 1 } } of.SetSort(sort) } sr := m.C.Database(m.DbName).Collection(c).FindOne(m.Ctx, q, of) sr.Decode(&res) return }