util.go 888 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. package util
  2. import (
  3. "context"
  4. "github.com/nlpodyssey/cybertron/pkg/models/bert"
  5. "github.com/nlpodyssey/cybertron/pkg/tasks"
  6. "github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
  7. "log"
  8. )
  // SafeConvert2String returns obj as a string when its dynamic type is
// exactly string. For nil, or for any value of another type (including
// named string types), it returns "" rather than panicking.
func SafeConvert2String(obj any) string {
	// The comma-ok form of a type assertion never panics; when the assertion
	// fails, s is the zero value "", so the ok flag is not needed here.
	s, _ := obj.(string)
	return s
}
  15. var m textencoding.Interface
  16. func init() {
  17. modelsDir := "./"
  18. modelName := "BAAI/bge-base-zh-v1.5"
  19. //modelName := "bge_base"
  20. var err error
  21. m, err = tasks.Load[textencoding.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName, DownloadPolicy: tasks.DownloadNever, ConversionPolicy: tasks.ConvertNever})
  22. if err != nil {
  23. log.Fatal(err)
  24. }
  25. }
  26. // 转向量
  27. func EncodeVector(text string) ([]float64, error) {
  28. result, err := m.Encode(context.Background(), text, int(bert.MeanPooling))
  29. if err != nil {
  30. return nil, err
  31. }
  32. return result.Vector.Data().F64(), nil
  33. }