ws.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package model
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"aiChat/utility/fsw"
	. "app.yhyue.com/moapp/jybase/common"
	"app.yhyue.com/moapp/jybase/date"
	"app.yhyue.com/moapp/jybase/encrypt"
	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/net/ghttp"
	"github.com/gogf/gf/v2/util/gconv"
)
  18. type WsChat struct {
  19. Ctx context.Context
  20. }
  21. func NewMessage(ctx context.Context) *WsChat {
  22. return &WsChat{
  23. Ctx: ctx,
  24. }
  25. }
  26. // Handle 处理消息
  27. func (m *WsChat) Handle(ws *ghttp.WebSocket, msg []byte) {
  28. defer Catch()
  29. jSession := SessionCtx.Get(m.Ctx).JSession
  30. if jSession.PositionId == 0 {
  31. _ = ws.WriteJSON(g.Map{
  32. "error_code": -1,
  33. "error_msg": "请登录",
  34. })
  35. return
  36. }
  37. req, from := &QuestionReq{}, 0
  38. if err := gjson.Unmarshal(msg, req); err != nil {
  39. g.Log().Errorf(m.Ctx, "%d 接收消息Unmarshal出错:%v", jSession.PositionId, err)
  40. return
  41. }
  42. questionId := ChatHistory.Save(m.Ctx, &ChatRecord{
  43. Content: req.Prompt,
  44. Type: 1,
  45. Refer: req.Href,
  46. PersonId: jSession.PositionId,
  47. CreateTime: time.Now().Format(date.Date_Full_Layout),
  48. })
  49. content, res, replyId, errMsg := func() (string, io.ReadCloser, int64, error) {
  50. var (
  51. err error
  52. res io.ReadCloser
  53. reply string
  54. )
  55. errReply := func() string {
  56. // 校验是否在黑名单,黑名单不返回内容
  57. if UserBlackList.CheckBlackList(m.Ctx, jSession.PositionId) {
  58. return g.Cfg().MustGet(m.Ctx, "limit.blackMsg").String()
  59. }
  60. // 校验问答频率
  61. if ChatLimit.GetBucket(m.Ctx, jSession.PositionId).TakeAvailable(1) == 0 {
  62. return g.Cfg().MustGet(m.Ctx, "limit.exceedMsg").String()
  63. }
  64. // 问题敏感词过滤
  65. if fsw.Match(req.Prompt) {
  66. return g.Cfg().MustGet(m.Ctx, "limit.fswMsg").String()
  67. }
  68. return ""
  69. }()
  70. if errReply != "" {
  71. reply, from = errReply, -1
  72. } else {
  73. reply, res, from, err = Question.DetailQuestion(m.Ctx, req)
  74. if err != nil {
  75. g.Log().Error(m.Ctx, "问答异常", err)
  76. reply, from = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String(), -1
  77. }
  78. }
  79. if from == Answer_ChatGPT {
  80. return reply, res, 0, nil
  81. }
  82. if reply == "" {
  83. reply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  84. errReply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  85. }
  86. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  87. Content: reply,
  88. Type: 2,
  89. Actions: gconv.Int(If(errReply == "", 1, 0)),
  90. QuestionId: questionId,
  91. PersonId: jSession.PositionId,
  92. Item: from,
  93. CreateTime: time.Now().Format(date.Date_Full_Layout),
  94. })
  95. if replyId <= 0 {
  96. g.Log().Error(m.Ctx, "问答存储存储异常")
  97. }
  98. if errReply != "" {
  99. return reply, nil, replyId, fmt.Errorf(errReply)
  100. }
  101. return reply, nil, replyId, nil
  102. }()
  103. if res != nil {
  104. defer res.Close()
  105. }
  106. if from != Answer_ChatGPT {
  107. if errMsg != nil {
  108. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": errMsg.Error(), "data": nil})
  109. } else {
  110. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": content, "isEnd": true}})
  111. }
  112. } else if res != nil {
  113. buf, lastData := bufio.NewReader(res), &BufRes{}
  114. isEmpty := true
  115. for {
  116. line, _, err := buf.ReadLine()
  117. if err == io.EOF {
  118. //放回链接池
  119. ChatGptPool.Add()
  120. finalReply := If(isEmpty, g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String(), lastData.Response).(string)
  121. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  122. Content: finalReply,
  123. Type: 2,
  124. Actions: gconv.Int(If(isEmpty, 0, 1)),
  125. QuestionId: questionId,
  126. PersonId: jSession.PositionId,
  127. Item: Answer_ChatGPT,
  128. CreateTime: time.Now().Format(date.Date_Full_Layout),
  129. })
  130. if !isEmpty {
  131. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": finalReply, "isEnd": true}})
  132. } else {
  133. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": finalReply})
  134. }
  135. break
  136. }
  137. if _, data := parseEventStream(line); data != nil {
  138. isEmpty = false
  139. data.Response = fsw.Repl(data.Response)
  140. lastData = data
  141. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"reply": lastData.Response, "isEnd": false}})
  142. }
  143. }
  144. }
  145. }
  146. func parseEventStream(line []byte) (event string, date *BufRes) {
  147. // 如果行以 "event:" 开头,表示这是一个事件的标识符
  148. if len(line) > 6 && string(line[:6]) == "event:" {
  149. event = string(line[6 : len(line)-1])
  150. return event, nil
  151. }
  152. // 如果行以 "data:" 开头,表示这是事件的数据部分
  153. if len(line) > 5 && string(line[:5]) == "data:" {
  154. date = &BufRes{}
  155. if err := json.Unmarshal(line[5:len(line)], date); err == nil {
  156. return
  157. }
  158. return event, nil
  159. }
  160. return "", nil
  161. }