ws.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package model
  2. import (
  3. "aiChat/utility/fsw"
  4. . "app.yhyue.com/moapp/jybase/common"
  5. "app.yhyue.com/moapp/jybase/date"
  6. "app.yhyue.com/moapp/jybase/encrypt"
  7. "bufio"
  8. "context"
  9. "encoding/json"
  10. "fmt"
  11. "github.com/gogf/gf/v2/encoding/gjson"
  12. "github.com/gogf/gf/v2/frame/g"
  13. "github.com/gogf/gf/v2/net/ghttp"
  14. "github.com/gogf/gf/v2/util/gconv"
  15. "io"
  16. "time"
  17. )
  18. type WsChat struct {
  19. Ctx context.Context
  20. }
  21. func NewMessage(ctx context.Context) *WsChat {
  22. return &WsChat{
  23. Ctx: ctx,
  24. }
  25. }
  26. // Handle 处理消息
  27. func (m *WsChat) Handle(ws *ghttp.WebSocket, msg []byte) {
  28. defer Catch()
  29. jSession := SessionCtx.Get(m.Ctx).JSession
  30. if jSession.PositionId == 0 {
  31. _ = ws.WriteJSON(g.Map{
  32. "error_code": -1,
  33. "error_msg": "请登录",
  34. })
  35. return
  36. }
  37. req, from := &QuestionReq{}, 0
  38. if err := gjson.Unmarshal(msg, req); err != nil {
  39. g.Log().Errorf(m.Ctx, "%d 接收消息Unmarshal出错:%v", jSession.PositionId, err)
  40. return
  41. }
  42. questionId := ChatHistory.Save(m.Ctx, &ChatRecord{
  43. Content: req.Prompt,
  44. Type: 1,
  45. Refer: req.Href,
  46. PersonId: jSession.PositionId,
  47. CreateTime: time.Now().Format(date.Date_Full_Layout),
  48. })
  49. content, res, replyId, errMsg := func() (string, io.ReadCloser, int64, error) {
  50. var (
  51. err error
  52. res io.ReadCloser
  53. reply string
  54. )
  55. errReply := func() string {
  56. // 校验是否在黑名单,黑名单不返回内容
  57. if UserBlackList.CheckBlackList(m.Ctx, jSession.PositionId) {
  58. return g.Cfg().MustGet(m.Ctx, "limit.blackMsg").String()
  59. }
  60. // 校验问答频率
  61. if ChatLimit.GetBucket(m.Ctx, jSession.PositionId).TakeAvailable(1) == 0 {
  62. return g.Cfg().MustGet(m.Ctx, "limit.exceedMsg").String()
  63. }
  64. // 问题敏感词过滤
  65. if fsw.Match(req.Prompt) {
  66. return g.Cfg().MustGet(m.Ctx, "limit.fswMsg").String()
  67. }
  68. return ""
  69. }()
  70. if errReply != "" {
  71. reply, from = errReply, -1
  72. } else {
  73. reply, res, from, err = Question.DetailQuestion(m.Ctx, req)
  74. if err != nil {
  75. g.Log().Error(m.Ctx, "问答异常", err)
  76. reply = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String()
  77. errReply = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String()
  78. }
  79. }
  80. if from == Answer_ChatGPT {
  81. return reply, res, 0, nil
  82. }
  83. if reply == "" {
  84. reply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  85. errReply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  86. }
  87. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  88. Content: reply,
  89. Type: 2,
  90. Actions: gconv.Int(If(errReply == "", 1, 0)),
  91. QuestionId: questionId,
  92. PersonId: jSession.PositionId,
  93. Item: from,
  94. CreateTime: time.Now().Format(date.Date_Full_Layout),
  95. })
  96. if replyId <= 0 {
  97. g.Log().Error(m.Ctx, "问答存储存储异常")
  98. }
  99. if errReply != "" {
  100. return reply, nil, replyId, fmt.Errorf(errReply)
  101. }
  102. return reply, nil, replyId, nil
  103. }()
  104. if res != nil {
  105. defer res.Close()
  106. }
  107. if from != Answer_ChatGPT {
  108. if errMsg != nil {
  109. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": errMsg.Error(), "data": nil})
  110. } else {
  111. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": content, "isEnd": true}})
  112. }
  113. } else if res != nil {
  114. buf, lastData := bufio.NewReader(res), &BufRes{}
  115. isEmpty := true
  116. for {
  117. line, _, err := buf.ReadLine()
  118. if err == io.EOF {
  119. //放回链接池
  120. ChatGptPool.Add()
  121. finalReply := If(isEmpty, g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String(), lastData.Response).(string)
  122. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  123. Content: finalReply,
  124. Type: 2,
  125. Actions: gconv.Int(If(isEmpty, 0, 1)),
  126. QuestionId: questionId,
  127. PersonId: jSession.PositionId,
  128. Item: Answer_ChatGPT,
  129. CreateTime: time.Now().Format(date.Date_Full_Layout),
  130. })
  131. if !isEmpty {
  132. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": finalReply, "isEnd": true}})
  133. } else {
  134. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": finalReply})
  135. }
  136. break
  137. }
  138. if _, data := parseEventStream(line); data != nil {
  139. isEmpty = false
  140. data.Response = fsw.Repl(data.Response)
  141. lastData = data
  142. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"reply": lastData.Response, "isEnd": false}})
  143. }
  144. }
  145. }
  146. }
  147. func parseEventStream(line []byte) (event string, date *BufRes) {
  148. // 如果行以 "event:" 开头,表示这是一个事件的标识符
  149. if len(line) > 6 && string(line[:6]) == "event:" {
  150. event = string(line[6 : len(line)-1])
  151. return event, nil
  152. }
  153. // 如果行以 "data:" 开头,表示这是事件的数据部分
  154. if len(line) > 5 && string(line[:5]) == "data:" {
  155. date = &BufRes{}
  156. if err := json.Unmarshal(line[5:len(line)], date); err == nil {
  157. return
  158. }
  159. return event, nil
  160. }
  161. return "", nil
  162. }