wangchuanjin 5 gadi atpakaļ
vecāks
revīzija
6535cb9c08
2 mainītis faili ar 229 papildinājumiem un 113 dzēšanām
  1. 132 113
      src/qfw/util/mongodb/mongodb.go
  2. 97 0
      src/qfw/util/mongodb/mongodb_test.go

+ 132 - 113
src/qfw/util/mongodb/mongodb.go

@@ -18,13 +18,15 @@ import (
 )
 
 type MgoSess struct {
-	Db     string
-	Coll   string
-	Query  interface{}
-	Sorts  []string
+	db     string
+	coll   string
+	query  interface{}
+	sorts  []string
 	fields interface{}
 	limit  int64
 	skip   int64
+	pipe   []map[string]interface{}
+	all    interface{}
 	M      *MongodbSim
 }
 
@@ -59,21 +61,20 @@ func (mt *MgoIter) Next(result interface{}) bool {
 	} else {
 		return false
 	}
-
 }
 
 func (ms *MgoSess) DB(name string) *MgoSess {
-	ms.Db = name
+	ms.db = name
 	return ms
 }
 
 func (ms *MgoSess) C(name string) *MgoSess {
-	ms.Coll = name
+	ms.coll = name
 	return ms
 }
 
 func (ms *MgoSess) Find(q interface{}) *MgoSess {
-	ms.Query = q
+	ms.query = q
 	return ms
 }
 
@@ -92,10 +93,19 @@ func (ms *MgoSess) Skip(skip int64) *MgoSess {
 }
 
 func (ms *MgoSess) Sort(sorts ...string) *MgoSess {
-	ms.Sorts = sorts
+	ms.sorts = sorts
 	return ms
 }
-
+func (ms *MgoSess) Pipe(p []map[string]interface{}) *MgoSess {
+	ms.pipe = p
+	return ms
+}
+func (ms *MgoSess) All(v *[]map[string]interface{}) {
+	cur, err := ms.M.C.Database(ms.db).Collection(ms.coll).Aggregate(ms.M.Ctx, ms.pipe)
+	if err == nil && cur.Err() == nil {
+		cur.All(ms.M.Ctx, v)
+	}
+}
 func (ms *MgoSess) Iter() *MgoIter {
 	it := &MgoIter{}
 	find := options.Find()
@@ -106,9 +116,9 @@ func (ms *MgoSess) Iter() *MgoIter {
 		find.SetLimit(ms.limit)
 	}
 	find.SetBatchSize(100)
-	if len(ms.Sorts) > 0 {
+	if len(ms.sorts) > 0 {
 		sort := bson.M{}
-		for _, k := range ms.Sorts {
+		for _, k := range ms.sorts {
 			switch k[:1] {
 			case "-":
 				sort[k[1:]] = -1
@@ -123,7 +133,7 @@ func (ms *MgoSess) Iter() *MgoIter {
 	if ms.fields != nil {
 		find.SetProjection(ms.fields)
 	}
-	cur, err := ms.M.C.Database(ms.Db).Collection(ms.Coll).Find(ms.M.Ctx, ms.Query, find)
+	cur, err := ms.M.C.Database(ms.db).Collection(ms.coll).Find(ms.M.Ctx, ms.query, find)
 	if err != nil {
 		log.Println("mgo find err", err.Error())
 	} else {
@@ -141,6 +151,9 @@ type MongodbSim struct {
 	Ctx      context.Context
 	ShortCtx context.Context
 	pool     chan bool
+	UserName string
+	Password string
+	ReplSet  string
 }
 
 func (m *MongodbSim) GetMgoConn() *MgoSess {
@@ -165,8 +178,24 @@ func (m *MongodbSim) Destroy() {
 func (m *MongodbSim) InitPool() {
 	opts := options.Client()
 	opts.SetConnectTimeout(3 * time.Second)
-	opts.ApplyURI("mongodb://" + m.MongodbAddr)
+	opts.SetHosts(strings.Split(m.MongodbAddr, ","))
+	//opts.ApplyURI("mongodb://" + m.MongodbAddr)
 	opts.SetMaxPoolSize(uint64(m.Size))
+	if m.UserName != "" && m.Password != "" {
+		cre := options.Credential{
+			Username: m.UserName,
+			Password: m.Password,
+		}
+		opts.SetAuth(cre)
+	}
+	ms := strings.Split(m.MongodbAddr, ",")
+	if m.ReplSet == "" && len(ms) > 1 {
+		m.ReplSet = "qfws"
+	}
+	if m.ReplSet != "" {
+		opts.SetReplicaSet(m.ReplSet)
+		opts.SetDirect(false)
+	}
 	m.pool = make(chan bool, m.Size)
 	opts.SetMaxConnIdleTime(2 * time.Hour)
 	m.Ctx, _ = context.WithTimeout(context.Background(), 99999*time.Hour)
@@ -187,7 +216,7 @@ func (m *MongodbSim) Close() {
 }
 
 func (m *MongodbSim) Save(c string, doc interface{}) string {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	coll := m.C.Database(m.DbName).Collection(c)
@@ -204,7 +233,7 @@ func (m *MongodbSim) Save(c string, doc interface{}) string {
 
 //原_id不变
 func (m *MongodbSim) SaveByOriID(c string, doc interface{}) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	coll := m.C.Database(m.DbName).Collection(c)
@@ -218,7 +247,7 @@ func (m *MongodbSim) SaveByOriID(c string, doc interface{}) bool {
 
 //批量插入
 func (m *MongodbSim) SaveBulk(c string, doc ...map[string]interface{}) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	coll := m.C.Database(m.DbName).Collection(c)
@@ -242,7 +271,7 @@ func (m *MongodbSim) SaveBulk(c string, doc ...map[string]interface{}) bool {
 
 //批量插入
 func (m *MongodbSim) SaveBulkInterface(c string, doc ...interface{}) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	coll := m.C.Database(m.DbName).Collection(c)
@@ -264,44 +293,6 @@ func (m *MongodbSim) SaveBulkInterface(c string, doc ...interface{}) bool {
 	return true
 }
 
-//批量插入
-func (m *MongodbSim) UpSertBulk(c string, doc ...[]map[string]interface{}) bool {
-	return m.upSertBulk(m.DbName, c, true, doc...)
-}
-
-//批量插入
-func (m *MongodbSim) upSertBulk(db, c string, upsert bool, doc ...[]map[string]interface{}) bool {
-	defer Catch()
-	m.Open()
-	defer m.Close()
-	coll := m.C.Database(db).Collection(c)
-	var writes []mongo.WriteModel
-	for _, d := range doc {
-		write := mongo.NewUpdateOneModel()
-		write.SetFilter(d[0])
-		write.SetUpdate(d[1])
-		write.SetUpsert(upsert)
-		writes = append(writes, write)
-	}
-	br, e := coll.BulkWrite(m.Ctx, writes)
-	if e != nil {
-		log.Println("mgo upsert error:", e.Error())
-		return br == nil || br.UpsertedCount == 0
-	}
-	//	else {
-	//		if r.UpsertedCount != int64(len(doc)) {
-	//			log.Println("mgo upsert uncomplete:uc/dc", r.UpsertedCount, len(doc))
-	//		}
-	//		return true
-	//	}
-	return true
-}
-
-func StringTOBsonId(id string) primitive.ObjectID {
-	objectId, _ := primitive.ObjectIDFromHex(id)
-	return objectId
-}
-
 //按条件统计
 func (m *MongodbSim) Count(c string, q interface{}) int {
 	r, _ := m.CountByErr(c, q)
@@ -310,7 +301,7 @@ func (m *MongodbSim) Count(c string, q interface{}) int {
 
 //统计
 func (m *MongodbSim) CountByErr(c string, q interface{}) (int, error) {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	res, err := m.C.Database(m.DbName).Collection(c).CountDocuments(m.Ctx, ObjToM(q))
@@ -324,7 +315,7 @@ func (m *MongodbSim) CountByErr(c string, q interface{}) (int, error) {
 
 //按条件删除
 func (m *MongodbSim) Delete(c string, q interface{}) int64 {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	res, err := m.C.Database(m.DbName).Collection(c).DeleteMany(m.Ctx, ObjToM(q))
@@ -336,19 +327,20 @@ func (m *MongodbSim) Delete(c string, q interface{}) int64 {
 
 //删除对象
 func (m *MongodbSim) Del(c string, q interface{}) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
-	res, err := m.C.Database(m.DbName).Collection(c).DeleteMany(m.Ctx, ObjToM(q))
-	if err != nil && res == nil {
+	_, err := m.C.Database(m.DbName).Collection(c).DeleteMany(m.Ctx, ObjToM(q))
+	if err != nil {
 		log.Println("删除错误", err.Error())
+		return false
 	}
-	return res.DeletedCount >= 0
+	return true
 }
 
 //按条件更新
 func (m *MongodbSim) Update(c string, q, u interface{}, upsert bool, multi bool) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	ct := options.Update()
@@ -369,7 +361,7 @@ func (m *MongodbSim) Update(c string, q, u interface{}, upsert bool, multi bool)
 	return true
 }
 func (m *MongodbSim) UpdateById(c string, id interface{}, set interface{}) bool {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	q := make(map[string]interface{})
@@ -395,6 +387,39 @@ func (m *MongodbSim) UpdateBulk(c string, doc ...[]map[string]interface{}) bool
 	return m.UpdateBulkAll(m.DbName, c, doc...)
 }
 
+//批量插入
+func (m *MongodbSim) UpSertBulk(c string, doc ...[]map[string]interface{}) bool {
+	return m.upSertBulk(m.DbName, c, true, doc...)
+}
+
+//批量插入
+func (m *MongodbSim) upSertBulk(db, c string, upsert bool, doc ...[]map[string]interface{}) bool {
+	defer catch()
+	m.Open()
+	defer m.Close()
+	coll := m.C.Database(db).Collection(c)
+	var writes []mongo.WriteModel
+	for _, d := range doc {
+		write := mongo.NewUpdateOneModel()
+		write.SetFilter(d[0])
+		write.SetUpdate(d[1])
+		write.SetUpsert(upsert)
+		writes = append(writes, write)
+	}
+	br, e := coll.BulkWrite(m.Ctx, writes)
+	if e != nil {
+		log.Println("mgo upsert error:", e.Error())
+		return br == nil || br.UpsertedCount == 0
+	}
+	//	else {
+	//		if r.UpsertedCount != int64(len(doc)) {
+	//			log.Println("mgo upsert uncomplete:uc/dc", r.UpsertedCount, len(doc))
+	//		}
+	//		return true
+	//	}
+	return true
+}
+
 //查询单条对象
 func (m *MongodbSim) FindOne(c string, query interface{}) (*map[string]interface{}, bool) {
 	return m.FindOneByField(c, query, nil)
@@ -402,7 +427,7 @@ func (m *MongodbSim) FindOne(c string, query interface{}) (*map[string]interface
 
 //查询单条对象
 func (m *MongodbSim) FindOneByField(c string, query interface{}, fields interface{}) (*map[string]interface{}, bool) {
-	defer Catch()
+	defer catch()
 	res, ok := m.Find(c, query, nil, fields, true, -1, -1)
 	if nil != res && len(*res) > 0 {
 		return &((*res)[0]), ok
@@ -412,7 +437,7 @@ func (m *MongodbSim) FindOneByField(c string, query interface{}, fields interfac
 
 //查询单条对象
 func (m *MongodbSim) FindById(c string, query string, fields interface{}) (*map[string]interface{}, bool) {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	of := options.FindOne()
@@ -432,7 +457,7 @@ func (m *MongodbSim) FindById(c string, query string, fields interface{}) (*map[
 
 //底层查询方法
 func (m *MongodbSim) Find(c string, query interface{}, order interface{}, fields interface{}, single bool, start int, limit int) (*[]map[string]interface{}, bool) {
-	defer Catch()
+	defer catch()
 	m.Open()
 	defer m.Close()
 	res := make([]map[string]interface{}, 1)
@@ -460,55 +485,23 @@ func (m *MongodbSim) Find(c string, query interface{}, order interface{}, fields
 	return &res, true
 }
 
-/////////////////////////////////////////////////////
-/*func (m *MongodbSim) InitPool() {
-	defer util.Catch()
-	ms := strings.Split(m.MongodbAddr, ",")
-	if m.ReplSet == "" && len(ms) > 1 {
-		m.ReplSet = "qfws"
-	}
-	m.pool = make(chan *mgo.Session, m.Size)
-	for i := 0; i < m.Size; i++ {
-		sess, err := m.createConn()
-		if err == nil && sess.Ping() == nil {
-			m.pool <- sess
-		}
-	}
-}
-
-//取session链接
-func (m *MongodbSim) createConn() (sess *mgo.Session, err error) {
-	//增加对集群的支持
-	if m.ReplSet != "" {
-		info := mgo.DialInfo{
-			Addrs:          strings.Split(m.MongodbAddr, ","),
-			Timeout:        TIMEOUT,
-			ReplicaSetName: m.ReplSet,
-			Direct:         false,
-		}
-		return mgo.DialWithInfo(&info)
-	}
-	return mgo.Dial(m.MongodbAddr)
-}
-
-*/
-
-/////////////////////////////////////////////
-func ObjToOth(query interface{}) *map[string]interface{} {
+func ObjToOth(query interface{}) *bson.M {
 	return ObjToMQ(query, false)
 }
-func ObjToM(query interface{}) *map[string]interface{} {
+func ObjToM(query interface{}) *bson.M {
 	return ObjToMQ(query, true)
 }
 
 //obj(string,M)转M,查询用到
-func ObjToMQ(query interface{}, isQuery bool) *map[string]interface{} {
-	defer Catch()
-	data := make(map[string]interface{})
+func ObjToMQ(query interface{}, isQuery bool) *bson.M {
+	data := make(bson.M)
+	defer catch()
 	if s2, ok2 := query.(*map[string]interface{}); ok2 {
-		data = *s2
-	} else if s1, ok1 := query.(map[string]interface{}); ok1 {
-		data = s1
+		data = bson.M(*s2)
+	} else if s3, ok3 := query.(*bson.M); ok3 {
+		return s3
+	} else if s3, ok3 := query.(*primitive.M); ok3 {
+		return s3
 	} else if s, ok := query.(string); ok {
 		json.Unmarshal([]byte(strings.Replace(s, "'", "\"", -1)), &data)
 		if ss, oks := data["_id"]; oks && isQuery {
@@ -524,13 +517,18 @@ func ObjToMQ(query interface{}, isQuery bool) *map[string]interface{} {
 			}
 
 		}
+	} else if s1, ok1 := query.(map[string]interface{}); ok1 {
+		data = s1
+	} else if s4, ok4 := query.(bson.M); ok4 {
+		data = s4
+	} else if s4, ok4 := query.(primitive.M); ok4 {
+		data = s4
 	} else {
-		b, _ := json.Marshal(query)
-		json.Unmarshal(b, &data)
+		data = nil
 	}
 	return &data
 }
-func IntAllDef(num interface{}, defaultNum int) int {
+func intAllDef(num interface{}, defaultNum int) int {
 	if i, ok := num.(int); ok {
 		return int(i)
 	} else if i0, ok0 := num.(int32); ok0 {
@@ -560,7 +558,7 @@ func IntAllDef(num interface{}, defaultNum int) int {
 }
 
 //出错拦截
-func Catch() {
+func catch() {
 	if r := recover(); r != nil {
 		log.Println(r)
 		for skip := 0; ; skip++ {
@@ -572,3 +570,24 @@ func Catch() {
 		}
 	}
 }
+
+//根据bsonID转string
+func BsonIdToSId(uid interface{}) string {
+	if uid == nil {
+		return ""
+	} else if u, ok := uid.(string); ok {
+		return u
+	} else if u, ok := uid.(primitive.ObjectID); ok {
+		return u.Hex()
+	} else {
+		return ""
+	}
+}
+
+func StringTOBsonId(id string) (bid primitive.ObjectID) {
+	defer catch()
+	if id != "" {
+		bid, _ = primitive.ObjectIDFromHex(id)
+	}
+	return
+}

+ 97 - 0
src/qfw/util/mongodb/mongodb_test.go

@@ -0,0 +1,97 @@
+package mongodb
+
+import (
+	"log"
+	"testing"
+
+	"go.mongodb.org/mongo-driver/bson"
+	//	"go.mongodb.org/mongo-driver/bson"
+	//	"go.mongodb.org/mongo-driver/bson"
+)
+
+func Test_add(t *testing.T) {
+	m := &MongodbSim{
+		MongodbAddr: "192.168.3.128:27080",
+		Size:        5,
+		DbName:      "wcj",
+	}
+	m.InitPool()
+	// log.Println(m.Save("test", map[string]interface{}{
+	// 	"name": "张三",
+	// 	"age":  12,
+	// }))
+	// log.Println(m.SaveByOriID("test", map[string]interface{}{
+	// 	"name": "张三",
+	// 	"age":  25,
+	// }))
+	log.Println(m.SaveBulkInterface("test", []interface{}{
+		map[string]interface{}{
+			"name": "张三1",
+			"age":  1,
+		},
+		map[string]interface{}{
+			"name": "张三2",
+			"age":  2,
+		},
+	}...))
+}
+func Test_find(t *testing.T) {
+	m := &MongodbSim{
+		MongodbAddr: "192.168.3.128:27080",
+		Size:        5,
+		DbName:      "wcj",
+	}
+	m.InitPool()
+	list, _ := m.Find("test", map[string]interface{}{
+		//"name": "张三",
+		//"_id": _id,
+	}, map[string]interface{}{"age": -1}, map[string]interface{}{"age": 1, "name": 1, "_id": 0}, false, -1, -1)
+	for _, v := range *list {
+		log.Println(v)
+	}
+	// one, _ := m.FindById("test", "5f10204cf54cfedfc09b0d76", nil)
+	// log.Println(one)
+	// log.Println(BsonIdToSId((*one)["_id"]))
+}
+func Test_update(t *testing.T) {
+	m := &MongodbSim{
+		MongodbAddr: "192.168.3.128:27080",
+		Size:        5,
+		DbName:      "wcj",
+	}
+	m.InitPool()
+	s := [][]map[string]interface{}{
+		[]map[string]interface{}{
+			map[string]interface{}{"name": "李四111"},
+			map[string]interface{}{"$set": map[string]interface{}{"type": 1}},
+		},
+		[]map[string]interface{}{
+			map[string]interface{}{"name": "张三111"},
+			map[string]interface{}{"$set": map[string]interface{}{"type": 2}},
+		},
+	}
+	one := m.upSertBulk("wcj", "test", true, s...)
+	log.Println(one)
+}
+func Test_count(t *testing.T) {
+	m := &MongodbSim{
+		MongodbAddr: "192.168.3.128:27080",
+		Size:        5,
+		DbName:      "wcj",
+	}
+	m.InitPool()
+	one := m.Count("test", bson.M{
+		"name": "张三",
+	})
+	log.Println(one)
+}
+func Test_del(t *testing.T) {
+	m := &MongodbSim{
+		MongodbAddr: "192.168.3.128:27080",
+		Size:        5,
+		DbName:      "wcj",
+	}
+	m.InitPool()
+	one := m.Del("test", nil)
+	log.Println(one)
+}