Quellcode durchsuchen

feat:A24176调整

fuwencai vor 8 Monaten
Ursprung
Commit
a25c8dfcc1

+ 5 - 1
entity/db.go

@@ -1,11 +1,15 @@
 package entity
 
 import (
+	elastic "app.yhyue.com/moapp/jybase/es"
 	"app.yhyue.com/moapp/jybase/mysql"
 )
 
 var (
-	Mysql *mysql.Mysql
+	Mysql     *mysql.Mysql
+	ESV7      elastic.Es // new Elasticsearch client used by the knowledge base
+	ESV7Index string     // index name passed to the ESV7 calls
+	ESV7Type  string     // type name passed to the ESV7 calls
 )
 
 // MysqlMainStruct msyql

+ 8 - 1
rpc/knowledge/etc/knowledge.yaml

@@ -20,7 +20,14 @@ Es:
   index: smart_new
   type: smart
 Segment: http://192.168.3.149:9070/api/segment
-
+Esv7:
+  addr: http://192.168.3.149:9200 # NOTE(review): Config.Esv7 reads this into a field named Address via a yaml tag — confirm the config loader honours yaml tags
+  size: 5 # copied into EsV7.Size at startup
+  index: jy_chat_new # copied into entity.ESV7Index
+  type: jy_chat_new # copied into entity.ESV7Type
+  username: ""
+  password: ""
+MinScore: 1.75 # used as min_score in the query built by GetAnswerQueryStr
 TestConf:
   Etcd:
     Hosts:

+ 13 - 4
rpc/knowledge/init/init.go

@@ -1,7 +1,8 @@
 package init
 
 import (
-	elastic "app.yhyue.com/moapp/jybase/esv1"
+	elastic "app.yhyue.com/moapp/jybase/es"
+
 	"app.yhyue.com/moapp/jybase/mysql"
 	"bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/entity"
 	"bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/rpc/knowledge/internal/config"
@@ -17,7 +18,6 @@ var Logc entity.Logc
 
 var configF = flag.String("ff", "etc/knowledge.yaml", "the config file")
 
-//
 var logFile = flag.String("lf", "etc/logs.yaml", "the config file")
 
 func init() {
@@ -42,9 +42,18 @@ func init() {
 	if es.Addr != "" {
 		log.Println("--初始化 elasticsearch--")
 		log.Println(es.Addr, es.Size)
-		elastic.InitElasticSize(es.Addr, es.Size)
+		elastic.NewEs("v1", es.Addr, es.Size, "", "")
 	}
-
+	entity.ESV7 = &elastic.EsV7{
+		Address:  C.Esv7.Address,
+		UserName: C.Esv7.Username,
+		Password: C.Esv7.Password,
+		Size:     C.Esv7.Size,
+	}
+	entity.ESV7Index = C.Esv7.Index
+	entity.ESV7Type = C.Esv7.Type
+	entity.ESV7.Init()
+	log.Println("--初始化 elasticsearch v7--")
 	//初始化日志信息
 	conf.MustLoad(*logFile, &Logc)
 	if len(Logc.Level) > 0 {

+ 9 - 0
rpc/knowledge/internal/config/config.go

@@ -15,4 +15,13 @@ type Config struct {
 	UserCenterConf         zrpc.RpcClientConf
 	FindCount              int
 	RecommendQuestionCount int
+	Esv7                   struct {
+		Address  string `yaml:"addr"`
+		Size     int    `yaml:"size"`
+		Username string `yaml:"username"`
+		Password string `yaml:"password"`
+		Index    string `yaml:"index"`
+		Type     string `yaml:"type"`
+	}
+	MinScore float64
 }

+ 78 - 2
rpc/knowledge/internal/service/knowledgeService.go

@@ -5,6 +5,7 @@ import (
 	. "app.yhyue.com/moapp/jybase/encrypt"
 	elastic "app.yhyue.com/moapp/jybase/esv1"
 	. "bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/entity"
+	"bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/rpc/knowledge/init"
 	"bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/rpc/knowledge/knowledgeclient"
 	"bp.jydev.jianyu360.cn/SocialPlatform/knowledgeBase/rpc/knowledge/util"
 	"database/sql"
@@ -77,6 +78,22 @@ func (k *KnowledgeService) KnowledgeAdd(param *knowledgeclient.AddRequest, segme
 				"entId":        param.EntId,
 			}
 			b := elastic.Save(Index, Type, knowledge)
+			if !b {
+				return false, "es 保存失败"
+			}
+			// Store the new entry in the vector index as well. EncodeVector
+			// returns (vector, error), so it has to be called before the map
+			// literal is built; an encoding or save failure is only logged and
+			// does not change the result returned below.
+			if questionVector, err := util.EncodeVector(param.Question); err != nil {
+				logx.Error("知识库添加向量失败:", err)
+			} else {
+				knowledgeV := map[string]interface{}{
+					"mod_time":       time.Now().Unix(),
+					"answer":         param.Answer,
+					"question":       param.Question,
+					"id":             answerId,
+					"entId":          param.EntId,
+					"questionVector": questionVector,
+				}
+				if !ESV7.Save(ESV7Index, ESV7Type, knowledgeV) {
+					logx.Error("知识库添加向量失败:", knowledgeV)
+				}
+			}
 			return b, ""
 		}
 		return fool, "插入mysql出错"
@@ -143,6 +160,42 @@ func (k *KnowledgeService) KnowledgeEdit(param *knowledgeclient.KnowledgeEditReq
 		}
 		ok1 := elastic.Del(Index, Type, query)
 		ok2 := elastic.Save(Index, Type, newKnowledge)
+		//  查询出来
+		queryByAid := fmt.Sprintf(`{
+    "query": {
+        "bool": {
+            "must": [
+                {
+                    "term": {
+                        "id": %v
+                    }
+                }
+            ]
+        }
+    },
+    "size": 1
+}`, param.AnswerId)
+		rs := ESV7.Get(ESV7Index, ESV7Type, queryByAid)
+		_id := ""
+		if rs != nil && len(*rs) > 0 {
+			_id = cm.InterfaceToStr((*rs)[0]["_id"])
+		}
+		// Store the edited entry in the vector index. EncodeVector returns
+		// (vector, error), so it has to be called before the map literal is
+		// built; an encoding or save failure is only logged.
+		if questionVector, err := util.EncodeVector(param.Question); err != nil {
+			logx.Error("知识库添加向量失败:", err)
+		} else {
+			knowledgeV := map[string]interface{}{
+				"mod_time":       time.Now().Unix(),
+				"answer":         param.Answer,
+				"question":       param.Question,
+				"id":             param.AnswerId,
+				"entId":          param.EntId,
+				"questionVector": questionVector,
+			}
+			// Carry over the _id found by the lookup above, if any.
+			// NOTE(review): presumably this makes Save overwrite the existing
+			// document instead of adding a second one — confirm in the es package.
+			if _id != "" {
+				knowledgeV["_id"] = _id
+			}
+			if !ESV7.Save(ESV7Index, ESV7Type, knowledgeV) {
+				logx.Error("知识库添加向量失败:", knowledgeV)
+			}
+		}
 		return ok1 && ok2
 	}
 	return ok
@@ -204,7 +257,29 @@ func (k *KnowledgeService) KnowledgeDel(answerId int64) (ok bool) {
 		//删除es数据
 		query := `{"query":{"bool":{"must":[{"term":{"answerId":"` + strconv.Itoa(int(answerId)) + `"}}],"must_not":[],"should":[]}},"from":0,"size":1,"sort":[],"facets":{}}`
 		ok = elastic.Del(Index, Type, query)
+		queryByAid := fmt.Sprintf(`{
+    "query": {
+        "bool": {
+            "must": [
+                {
+                    "term": {
+                        "id": %v
+                    }
+                }
+            ]
+        }
+    },
+    "size": 1
+}`, answerId)
+		rs := ESV7.Get(ESV7Index, ESV7Type, queryByAid)
+		if rs != nil && len(*rs) > 0 {
+			_id := cm.InterfaceToStr((*rs)[0]["_id"])
+			if !ESV7.DelById(ESV7Index, ESV7Type, _id) {
+				logx.Error("删除向量库失败:", _id, answerId)
+			}
+		}
 	}
+
 	return ok
 }
 
@@ -213,10 +288,11 @@ func (k *KnowledgeService) FindAnswer(param *knowledgeclient.FindAnswerReq, addr
 	var question knowledgeclient.Question
 	robotEntId := SE.Decode4Hex(param.RobotEntId)
 	//组装es query
-	query := util.DSL4SmartResponse(param.Question, robotEntId, int(param.Type), addr, index, segment)
+	//query := util.DSL4SmartResponse(param.Question, robotEntId, int(param.Type), addr, index, segment)
+	query := util.GetAnswerQueryStr(param.Question, robotEntId, 1, init.C.MinScore)
 	logx.Info("query:", query)
 	if query != "" {
-		res := elastic.Get(Index, Type, query)
+		res := ESV7.Get(ESV7Index, ESV7Type, query)
 		if res != nil && len(*res) > 0 {
 			data := (*res)[0]
 			question.Answer = cm.ObjToString(data["answer"])

+ 9 - 0
rpc/knowledge/util/elasticsearch_dsl.go

@@ -90,3 +90,12 @@ func GetQueryOT(tags, question, keywords, repositoryId string) (qstr string) {
 	qstr = fmt.Sprintf(query, queryMatch, queryTerms, queryId, queryQues)
 	return qstr
 }
+
+// queryStr is the Elasticsearch DSL template filled in by GetAnswerQueryStr.
+// Its verbs are, in order: size, min_score, the entId term value (emitted as a
+// quoted string) and the query vector as a JSON array. The script scores each
+// hit as cosineSimilarity(queryVector, questionVector) + 1.
+var queryStr = `{"_source": ["question","intention","answer"],"size": %d, "min_score":%v, 
+
+ "query": {"bool": {"must": [{"term":{"entId":%q}},{"script_score": {"query": {"match_all": {}},"script": {"source": "cosineSimilarity(params.queryVector,'questionVector')+1", "params": {"queryVector": %s}}}}]}}}`
+
+// GetAnswerQueryStr builds the vector-similarity query for question, limited
+// to documents whose entId equals entId. size and minScore fill the "size" and
+// "min_score" fields. It returns "" when the question cannot be encoded.
+func GetAnswerQueryStr(question string, entId string, size int, minScore float64) string {
+	qv, err := EncodeVector(question)
+	if err != nil {
+		return ""
+	}
+	// Formatting a []float64 with %v yields "[a b c]", which is not JSON, so
+	// the array is assembled by hand with comma separators.
+	vec := make([]byte, 0, 2+16*len(qv))
+	vec = append(vec, '[')
+	for i, f := range qv {
+		if i > 0 {
+			vec = append(vec, ',')
+		}
+		vec = append(vec, fmt.Sprint(f)...)
+	}
+	vec = append(vec, ']')
+	return fmt.Sprintf(queryStr, size, minScore, entId, vec)
+}

+ 31 - 0
rpc/knowledge/util/util.go

@@ -1,8 +1,39 @@
 package util
 
+import (
+	"context"
+	"github.com/nlpodyssey/cybertron/pkg/models/bert"
+	"github.com/nlpodyssey/cybertron/pkg/tasks"
+	"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
+	"log"
+)
+
 // SafeConvert2String returns obj as a string. It returns "" when obj is nil
 // or holds a value of any other type.
 func SafeConvert2String(obj interface{}) string {
 	// The comma-ok form yields "" for nil and for non-string values, where an
 	// unchecked assertion would panic.
 	s, _ := obj.(string)
 	return s
 }
+
+// m is the text-encoding model used by EncodeVector; it is assigned in init.
+var m textencoding.Interface
+
+// init loads the BAAI/bge-base-zh-v1.5 model from the current directory and
+// exits the process via log.Fatal when loading fails. Download and conversion
+// are both set to "never", so presumably the model files must already be in
+// place.
+func init() {
+	cfg := &tasks.Config{
+		ModelsDir:        "./",
+		ModelName:        "BAAI/bge-base-zh-v1.5",
+		DownloadPolicy:   tasks.DownloadNever,
+		ConversionPolicy: tasks.ConvertNever,
+	}
+	loaded, err := tasks.Load[textencoding.Interface](cfg)
+	if err != nil {
+		log.Fatal(err)
+	}
+	m = loaded
+}
+
+// EncodeVector encodes text with the package-level model m using mean pooling
+// and returns the resulting vector as float64 values. When the encoder fails,
+// it returns a nil slice together with the encoder's error.
+func EncodeVector(text string) ([]float64, error) {
+	encoded, err := m.Encode(context.Background(), text, int(bert.MeanPooling))
+	if err != nil {
+		return nil, err
+	}
+	vec := encoded.Vector.Data().F64()
+	return vec, nil
+}