package main import ( "context" "time" "log" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type MgoSess struct { Db string Coll string Query interface{} Sorts []string fields interface{} limit int64 skip int64 M *MongodbSim } type MgoIter struct { Cursor *mongo.Cursor } func (mt *MgoIter) Next(result interface{}) bool { if mt.Cursor != nil { if mt.Cursor.Next(nil) { err := mt.Cursor.Decode(result) if err != nil { log.Println("mgo cur err", err.Error()) mt.Cursor.Close(nil) return false } return true } else { mt.Cursor.Close(nil) return false } } else { return false } } func (ms *MgoSess) DB(name string) *MgoSess { ms.Db = name return ms } func (ms *MgoSess) C(name string) *MgoSess { ms.Coll = name return ms } func (ms *MgoSess) Find(q interface{}) *MgoSess { ms.Query = q return ms } func (ms *MgoSess) Select(fields interface{}) *MgoSess { ms.fields = fields return ms } func (ms *MgoSess) Limit(limit int64) *MgoSess { ms.limit = limit return ms } func (ms *MgoSess) Skip(skip int64) *MgoSess { ms.skip = skip return ms } func (ms *MgoSess) Sort(sorts ...string) *MgoSess { ms.Sorts = sorts return ms } func (ms *MgoSess) Iter() *MgoIter { it := &MgoIter{} find := options.Find() if ms.skip > 0 { find.SetSkip(ms.skip) } if ms.limit > 0 { find.SetLimit(ms.limit) } find.SetBatchSize(100) if len(ms.Sorts) > 0 { sort := bson.M{} for _, k := range ms.Sorts { switch k[:1] { case "-": sort[k[1:]] = -1 case "+": sort[k[1:]] = 1 default: sort[k] = 1 } } find.SetSort(sort) } if ms.fields != nil { find.SetProjection(ms.fields) } cur, err := ms.M.C.Database(ms.Db).Collection(ms.Coll).Find(ms.M.Ctx, ms.Query, find) if err != nil { log.Println("mgo find err", err.Error()) } else { it.Cursor = cur } return it } type MongodbSim struct { MongodbAddr string Size int // MinSize int DbName string C *mongo.Client Ctx context.Context ShortCtx context.Context pool chan bool UserName string Password string } func (m *MongodbSim) GetMgoConn() *MgoSess { //m.Open() ms := &MgoSess{} ms.M = m return ms } func (m *MongodbSim) DestoryMongoConn(ms *MgoSess) { //m.Close() ms.M = nil ms = nil } func (m *MongodbSim) InitPoolDirect() { opts := options.Client() opts.SetConnectTimeout(3 * time.Second) opts.ApplyURI("mongodb://" + m.MongodbAddr) opts.SetMaxPoolSize(uint64(m.Size)) opts.SetDirect(true) m.pool = make(chan bool, m.Size) if m.UserName != "" && m.Password != "" { cre := options.Credential{ Username: m.UserName, Password: m.Password, AuthSource: "admin", } opts.SetAuth(cre) } opts.SetMaxConnIdleTime(2 * time.Hour) m.Ctx, _ = context.WithTimeout(context.Background(), 99999*time.Hour) m.ShortCtx, _ = context.WithTimeout(context.Background(), 1*time.Minute) client, err := mongo.Connect(m.ShortCtx, opts) if err != nil { log.Println("mgo init error:", err.Error()) } else { m.C = client log.Println("init success") } } func (m *MongodbSim) InitPool() { opts := options.Client() opts.SetConnectTimeout(3 * time.Second) opts.ApplyURI("mongodb://" + m.MongodbAddr) opts.SetMaxPoolSize(uint64(m.Size)) //opts.SetDirect(true) m.pool = make(chan bool, m.Size) if m.UserName != "" && m.Password != "" { cre := options.Credential{ Username: m.UserName, Password: m.Password, AuthSource: "admin", } opts.SetAuth(cre) } opts.SetMaxConnIdleTime(2 * time.Hour) m.Ctx, _ = context.WithTimeout(context.Background(), 99999*time.Hour) m.ShortCtx, _ = context.WithTimeout(context.Background(), 1*time.Minute) client, err := mongo.Connect(m.ShortCtx, opts) if err != nil { log.Println("mgo init error:", err.Error()) } else { m.C = client log.Println("init success") } } func (m *MongodbSim) Open() { m.pool <- true } func (m *MongodbSim) Close() { <-m.pool } // 新建表并生成索引 func (m *MongodbSim) CreateIndex(c string, models []mongo.IndexModel) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) names, err := coll.Indexes().CreateMany(m.Ctx, models) if err == nil && len(names) > 0 { return true } else { log.Println("CreateIndex Error:", err) return false } } // 批量插入 func (m *MongodbSim) UpSertBulk(c string, doc ...[]map[string]interface{}) (map[int64]interface{}, bool) { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewUpdateOneModel() write.SetFilter(d[0]) write.SetUpdate(d[1]) write.SetUpsert(true) writes = append(writes, write) } r, e := coll.BulkWrite(m.Ctx, writes) if e != nil { log.Println("mgo upsert error:", e.Error()) return nil, false } // else { // if r.UpsertedCount != int64(len(doc)) { // log.Println("mgo upsert uncomplete:uc/dc", r.UpsertedCount, len(doc)) // } // return true // } return r.UpsertedIDs, true } // 批量插入 func (m *MongodbSim) SaveBulk(c string, doc ...map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) var writes []mongo.WriteModel for _, d := range doc { write := mongo.NewInsertOneModel() write.SetDocument(d) writes = append(writes, write) } _, e := coll.BulkWrite(m.Ctx, writes) if e != nil { log.Println("mgo savebulk error:", e.Error()) return false } return true } // 保存 func (m *MongodbSim) Save(c string, doc map[string]interface{}) interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.InsertOne(m.Ctx, doc) if err != nil { return nil } return r.InsertedID } // 更新by Id func (m *MongodbSim) UpdateById(c, id string, doc map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) _, err := coll.UpdateOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}, doc) if err != nil { return false } return true } func (m *MongodbSim) UpdateStrId(c, id string, doc map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) _, err := coll.UpdateOne(m.Ctx, map[string]interface{}{"_id": id}, doc) if err != nil { return false } return true } func (m *MongodbSim) UpdateQueryData(c string, query map[string]interface{}, doc map[string]interface{}) bool { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) _, err := coll.UpdateOne(m.Ctx, query, doc) if err != nil { return false } return true } // 删除by id func (m *MongodbSim) DeleteById(c, id string) int64 { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.DeleteOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}) if err != nil { return 0 } return r.DeletedCount } // 通过条件删除 func (m *MongodbSim) Delete(c string, query map[string]interface{}) int64 { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r, err := coll.DeleteMany(m.Ctx, query) if err != nil { return 0 } return r.DeletedCount } // findbyid func (m *MongodbSim) FindById(c, id string) map[string]interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r := coll.FindOne(m.Ctx, map[string]interface{}{"_id": StringTOBsonId(id)}) v := map[string]interface{}{} r.Decode(&v) return v } // findone func (m *MongodbSim) FindOne(c string, query map[string]interface{}) map[string]interface{} { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) r := coll.FindOne(m.Ctx, query) v := map[string]interface{}{} r.Decode(&v) return v } // find func (m *MongodbSim) Find(c string, query map[string]interface{}, sort, fields interface{}) ([]map[string]interface{}, error) { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) op := options.Find() r, err := coll.Find(m.Ctx, query, op.SetSort(sort), op.SetProjection(fields)) if err != nil { log.Fatal(err) return nil, err } var results []map[string]interface{} if err = r.All(m.Ctx, &results); err != nil { log.Fatal(err) return nil, err } return results, nil } // find func (m *MongodbSim) FindLimit(c string, query map[string]interface{}, sort, fields interface{}, limit int64) ([]map[string]interface{}, error) { m.Open() defer m.Close() coll := m.C.Database(m.DbName).Collection(c) op := options.Find() r, err := coll.Find(m.Ctx, query, op.SetSort(sort), op.SetProjection(fields), op.SetLimit(limit)) if err != nil { log.Fatal(err) return nil, err } var results []map[string]interface{} if err = r.All(m.Ctx, &results); err != nil { log.Fatal(err) return nil, err } return results, nil } // 创建_id func NewObjectId() primitive.ObjectID { return primitive.NewObjectID() } func StringTOBsonId(id string) primitive.ObjectID { objectId, _ := primitive.ObjectIDFromHex(id) return objectId } func BsonTOStringId(id interface{}) string { return id.(primitive.ObjectID).Hex() }