123456789101112131415161718192021222324252627282930313233343536373839 |
- package util
import (
	"context"
	"fmt"
	"log"

	"github.com/nlpodyssey/cybertron/pkg/models/bert"
	"github.com/nlpodyssey/cybertron/pkg/tasks"
	"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
)
// SafeConvert2String returns obj as a string when its dynamic type is
// exactly string, and "" otherwise — including when obj is nil or holds a
// value of any other type (e.g. an int, a []byte, or a named string type).
//
// It uses the comma-ok form of the type assertion, so unlike a bare
// obj.(string) it never panics.
func SafeConvert2String(obj any) string {
	s, _ := obj.(string) // s is "" when the assertion fails or obj is nil
	return s
}
- var m textencoding.Interface
- func init() {
- modelsDir := "./"
- modelName := "BAAI/bge-base-zh-v1.5"
- //modelName := "bge_base"
- var err error
- m, err = tasks.Load[textencoding.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName, DownloadPolicy: tasks.DownloadNever, ConversionPolicy: tasks.ConvertNever})
- if err != nil {
- log.Fatal(err)
- }
- }
// EncodeVector converts text into an embedding vector using the
// package-level model m, with mean pooling (bert.MeanPooling), and returns
// the vector's data as float64 values.
//
// It is equivalent to EncodeVectorContext(context.Background(), text).
// On failure the returned slice is nil and the error wraps the model's error.
func EncodeVector(text string) ([]float64, error) {
	return EncodeVectorContext(context.Background(), text)
}

// EncodeVectorContext is like EncodeVector but forwards ctx to the model's
// Encode call. This lets callers supply deadlines or cancellation; whether
// they are honored mid-computation depends on the cybertron implementation.
//
// On failure the returned slice is nil and the error wraps the model's error
// (use errors.Is / errors.As to inspect it).
func EncodeVectorContext(ctx context.Context, text string) ([]float64, error) {
	result, err := m.Encode(ctx, text, int(bert.MeanPooling))
	if err != nil {
		return nil, fmt.Errorf("encoding text: %w", err)
	}
	return result.Vector.Data().F64(), nil
}
|