瀏覽代碼

feat:修改

wangchuanjin 3 年之前
父節點
當前提交
b1ef6874e7
共有 1 個文件被更改,包括 69 次插入42 次删除
  1. 69 42
      src/mongodb/mongodb.go

+ 69 - 42
src/mongodb/mongodb.go

@@ -147,7 +147,6 @@ type MgoSess struct {
 	fields interface{}
 	limit  int64
 	skip   int64
-	pipe   []map[string]interface{}
 	all    interface{}
 	M      *MongodbSim
 }
@@ -190,9 +189,12 @@ func (ms *MgoSess) Sort(sorts ...string) *MgoSess {
 	ms.sorts = sorts
 	return ms
 }
-func (ms *MgoSess) Pipe(p []map[string]interface{}) *MgoSess {
-	ms.pipe = p
-	return ms
+func (ms *MgoSess) Pipe(p []map[string]interface{}) *pipe {
+	pe := &pipe{
+		ms:       ms,
+		pipeline: p,
+	}
+	return pe
 }
 func (ms *MgoSess) Insert(doc interface{}) error {
 	_, err := ms.M.C.Database(ms.db).Collection(ms.coll).InsertOne(ms.M.Ctx, doc)
@@ -247,7 +249,16 @@ func (ms *MgoSess) One(v *map[string]interface{}) {
 	}
 }
 func (ms *MgoSess) All(v *[]map[string]interface{}) {
-	cur, err := ms.M.C.Database(ms.db).Collection(ms.coll).Aggregate(ms.M.Ctx, ms.pipe)
+	of := options.Find()
+	of.SetProjection(ms.fields)
+	of.SetSort(ms.sorts)
+	if ms.skip > 0 {
+		of.SetSkip(ms.skip)
+	}
+	if ms.limit > 0 {
+		of.SetLimit(ms.limit)
+	}
+	cur, err := ms.M.C.Database(ms.db).Collection(ms.coll).Find(ms.M.Ctx, ms.query, of)
 	if err == nil && cur.Err() == nil {
 		cur.All(ms.M.Ctx, v)
 	}
@@ -255,45 +266,34 @@ func (ms *MgoSess) All(v *[]map[string]interface{}) {
 func (ms *MgoSess) Iter() *MgoIter {
 	it := &MgoIter{}
 	coll := ms.M.C.Database(ms.db).Collection(ms.coll)
-	var cur *mongo.Cursor
-	var err error
-	if ms.query != nil {
-		find := options.Find()
-		if ms.skip > 0 {
-			find.SetSkip(ms.skip)
-		}
-		if ms.limit > 0 {
-			find.SetLimit(ms.limit)
-		}
-		find.SetBatchSize(100)
-		if len(ms.sorts) > 0 {
-			sort := bson.D{}
-			for _, k := range ms.sorts {
-				switch k[:1] {
-				case "-":
-					sort = append(sort, bson.E{k[1:], -1})
-				case "+":
-					sort = append(sort, bson.E{k[1:], 1})
-				default:
-					sort = append(sort, bson.E{k, 1})
-				}
+	find := options.Find()
+	if ms.skip > 0 {
+		find.SetSkip(ms.skip)
+	}
+	if ms.limit > 0 {
+		find.SetLimit(ms.limit)
+	}
+	find.SetBatchSize(100)
+	if len(ms.sorts) > 0 {
+		sort := bson.D{}
+		for _, k := range ms.sorts {
+			switch k[:1] {
+			case "-":
+				sort = append(sort, bson.E{k[1:], -1})
+			case "+":
+				sort = append(sort, bson.E{k[1:], 1})
+			default:
+				sort = append(sort, bson.E{k, 1})
 			}
-			find.SetSort(sort)
-		}
-		if ms.fields != nil {
-			find.SetProjection(ms.fields)
-		}
-		cur, err = coll.Find(ms.M.Ctx, ms.query, find)
-		if err != nil {
-			log.Println("mgo find err", err.Error())
-		}
-	} else if ms.pipe != nil {
-		aggregate := options.Aggregate()
-		aggregate.SetBatchSize(100)
-		cur, err = coll.Aggregate(ms.M.Ctx, ms.pipe, aggregate)
-		if err != nil {
-			log.Println("mgo aggregate err", err.Error())
 		}
+		find.SetSort(sort)
+	}
+	if ms.fields != nil {
+		find.SetProjection(ms.fields)
+	}
+	cur, err := coll.Find(ms.M.Ctx, ms.query, find)
+	if err != nil {
+		log.Println("mgo find err", err.Error())
 	}
 	if err == nil {
 		it.Cursor = cur
@@ -302,6 +302,33 @@ func (ms *MgoSess) Iter() *MgoIter {
 	return it
 }
 
+type pipe struct {
+	ms       *MgoSess
+	pipeline []map[string]interface{}
+}
+
+func (p *pipe) All(v *[]map[string]interface{}) {
+	cur, err := p.ms.M.C.Database(p.ms.db).Collection(p.ms.coll).Aggregate(p.ms.M.Ctx, p.pipeline)
+	if err == nil && cur.Err() == nil {
+		cur.All(p.ms.M.Ctx, v)
+	}
+}
+func (p *pipe) Iter() *MgoIter {
+	it := &MgoIter{}
+	coll := p.ms.M.C.Database(p.ms.db).Collection(p.ms.coll)
+	aggregate := options.Aggregate()
+	aggregate.SetBatchSize(100)
+	cur, err := coll.Aggregate(p.ms.M.Ctx, p.pipeline, aggregate)
+	if err != nil {
+		log.Println("mgo aggregate err", err.Error())
+	}
+	if err == nil {
+		it.Cursor = cur
+		it.Ctx = p.ms.M.Ctx
+	}
+	return it
+}
+
 type MongodbSim struct {
 	MongodbAddr string
 	Size        int