ws.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package model
  2. import (
  3. "aiChat/utility/fsw"
  4. . "app.yhyue.com/moapp/jybase/common"
  5. "app.yhyue.com/moapp/jybase/date"
  6. "app.yhyue.com/moapp/jybase/encrypt"
  7. "bufio"
  8. "context"
  9. "encoding/json"
  10. "fmt"
  11. "github.com/gogf/gf/v2/encoding/gjson"
  12. "github.com/gogf/gf/v2/frame/g"
  13. "github.com/gogf/gf/v2/net/ghttp"
  14. "github.com/gogf/gf/v2/os/glog"
  15. "github.com/gogf/gf/v2/util/gconv"
  16. "io"
  17. "time"
  18. )
// WsChat handles chat messages received over a WebSocket connection.
type WsChat struct {
	// Ctx is the context supplied to NewMessage. Handle passes it to the
	// session lookup, configuration reads, logging and chat-history calls.
	Ctx context.Context
}
  22. func NewMessage(ctx context.Context) *WsChat {
  23. return &WsChat{
  24. Ctx: ctx,
  25. }
  26. }
  27. // Handle 处理消息
  28. func (m *WsChat) Handle(ws *ghttp.WebSocket, msg []byte) {
  29. defer Catch()
  30. jSession := SessionCtx.Get(m.Ctx).JSession
  31. if jSession.PositionId == 0 {
  32. _ = ws.WriteJSON(g.Map{
  33. "error_code": -1,
  34. "error_msg": "请登录",
  35. })
  36. return
  37. }
  38. req, from := &QuestionReq{}, 0
  39. if err := gjson.Unmarshal(msg, req); err != nil {
  40. glog.Errorf(m.Ctx, "%d 接收消息Unmarshal出错:%v", jSession.PositionId, err)
  41. return
  42. }
  43. questionId := ChatHistory.Save(m.Ctx, &ChatRecord{
  44. Content: req.Prompt,
  45. Type: 1,
  46. Refer: req.Href,
  47. PersonId: jSession.PositionId,
  48. CreateTime: time.Now().Format(date.Date_Full_Layout),
  49. })
  50. content, res, replyId, errMsg := func() (string, io.ReadCloser, int64, error) {
  51. var (
  52. err error
  53. res io.ReadCloser
  54. reply string
  55. )
  56. errReply := func() string {
  57. // 校验是否在黑名单,黑名单不返回内容
  58. if UserBlackList.CheckBlackList(m.Ctx, jSession.PositionId) {
  59. return g.Cfg().MustGet(m.Ctx, "limit.blackMsg").String()
  60. }
  61. // 校验问答频率
  62. if ChatLimit.GetBucket(m.Ctx, jSession.PositionId).TakeAvailable(1) == 0 {
  63. return g.Cfg().MustGet(m.Ctx, "limit.exceedMsg").String()
  64. }
  65. // 问题敏感词过滤
  66. if fsw.Match(req.Prompt) {
  67. return g.Cfg().MustGet(m.Ctx, "limit.fswMsg").String()
  68. }
  69. return ""
  70. }()
  71. if errReply != "" {
  72. reply, from = errReply, -1
  73. } else {
  74. reply, res, from, err = Question.DetailQuestion(m.Ctx, req)
  75. if err != nil {
  76. g.Log().Error(m.Ctx, "问答异常", err)
  77. reply, from = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String(), -1
  78. }
  79. }
  80. if from == Answer_ChatGPT {
  81. return reply, res, 0, nil
  82. }
  83. if reply == "" {
  84. reply, from = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String(), -1
  85. }
  86. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  87. Content: reply,
  88. Type: 2,
  89. Actions: gconv.Int(If(errReply == "", 1, 0)),
  90. QuestionId: questionId,
  91. PersonId: jSession.PositionId,
  92. Item: gconv.Int(If(errReply == "", from, -1)),
  93. CreateTime: time.Now().Format(date.Date_Full_Layout),
  94. })
  95. if replyId <= 0 {
  96. g.Log().Error(m.Ctx, "问答存储存储异常")
  97. }
  98. if errReply != "" {
  99. return reply, nil, replyId, fmt.Errorf(errReply)
  100. }
  101. return reply, nil, replyId, nil
  102. }()
  103. if res != nil {
  104. defer res.Close()
  105. }
  106. if from != Answer_ChatGPT {
  107. if errMsg != nil {
  108. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": errMsg.Error(), "data": nil})
  109. } else {
  110. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": content, "isEnd": true}})
  111. }
  112. } else if res != nil {
  113. buf, lastData := bufio.NewReader(res), &BufRes{}
  114. for {
  115. line, _, err := buf.ReadLine()
  116. if err == io.EOF {
  117. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  118. Content: lastData.Response,
  119. Type: 2,
  120. Actions: 1,
  121. QuestionId: questionId,
  122. PersonId: jSession.PositionId,
  123. Item: Answer_ChatGPT,
  124. CreateTime: time.Now().Format(date.Date_Full_Layout),
  125. })
  126. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": lastData.Response, "isEnd": lastData.Finished}})
  127. break
  128. }
  129. if _, data := parseEventStream(line); data != nil {
  130. data.Response = fsw.Repl(data.Response)
  131. lastData = data
  132. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"reply": lastData.Response, "isEnd": false}})
  133. }
  134. }
  135. }
  136. }
// saveAndReturnMap is an empty placeholder: it takes no arguments, returns
// nothing and has no body.
// NOTE(review): no caller is visible in this file — confirm whether it is
// still needed.
func saveAndReturnMap() {
}
  139. func parseEventStream(line []byte) (event string, date *BufRes) {
  140. // 如果行以 "event:" 开头,表示这是一个事件的标识符
  141. if len(line) > 6 && string(line[:6]) == "event:" {
  142. event = string(line[6 : len(line)-1])
  143. return event, nil
  144. }
  145. // 如果行以 "data:" 开头,表示这是事件的数据部分
  146. if len(line) > 5 && string(line[:5]) == "data:" {
  147. date = &BufRes{}
  148. if err := json.Unmarshal(line[5:len(line)], date); err == nil {
  149. return
  150. }
  151. return event, nil
  152. }
  153. return "", nil
  154. }