ws.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package model
  2. import (
  3. "aiChat/utility/fsw"
  4. . "app.yhyue.com/moapp/jybase/common"
  5. "app.yhyue.com/moapp/jybase/date"
  6. "app.yhyue.com/moapp/jybase/encrypt"
  7. "bufio"
  8. "context"
  9. "encoding/json"
  10. "fmt"
  11. "github.com/gogf/gf/v2/encoding/gjson"
  12. "github.com/gogf/gf/v2/frame/g"
  13. "github.com/gogf/gf/v2/net/ghttp"
  14. "github.com/gogf/gf/v2/os/glog"
  15. "github.com/gogf/gf/v2/util/gconv"
  16. "io"
  17. "time"
  18. )
  // WsChat handles chat messages that arrive over a WebSocket connection.
// Ctx is the context handed to every session lookup, configuration read,
// log call and history write performed by Handle.
type WsChat struct {
	Ctx context.Context
}

// NewMessage returns a new WsChat bound to ctx.
func NewMessage(ctx context.Context) *WsChat {
	chat := new(WsChat)
	chat.Ctx = ctx
	return chat
}
  27. // Handle 处理消息
  28. func (m *WsChat) Handle(ws *ghttp.WebSocket, msg []byte) {
  29. defer Catch()
  30. jSession := SessionCtx.Get(m.Ctx).JSession
  31. if jSession.PositionId == 0 {
  32. _ = ws.WriteJSON(g.Map{
  33. "error_code": -1,
  34. "error_msg": "请登录",
  35. })
  36. return
  37. }
  38. req, from := &QuestionReq{}, 0
  39. if err := gjson.Unmarshal(msg, req); err != nil {
  40. glog.Errorf(m.Ctx, "%d 接收消息Unmarshal出错:%v", jSession.PositionId, err)
  41. return
  42. }
  43. questionId := ChatHistory.Save(m.Ctx, &ChatRecord{
  44. Content: req.Prompt,
  45. Type: 1,
  46. Refer: req.Href,
  47. PersonId: jSession.PositionId,
  48. CreateTime: time.Now().Format(date.Date_Full_Layout),
  49. })
  50. content, buf, replyId, errMsg := func() (string, *bufio.Reader, int64, error) {
  51. var (
  52. err error
  53. buf *bufio.Reader
  54. reply string
  55. )
  56. errReply := func() string {
  57. // 校验是否在黑名单,黑名单不返回内容
  58. if UserBlackList.CheckBlackList(m.Ctx, jSession.PositionId) {
  59. return g.Cfg().MustGet(m.Ctx, "limit.blackMsg").String()
  60. }
  61. // 校验问答频率
  62. if ChatLimit.GetBucket(m.Ctx, jSession.PositionId).TakeAvailable(1) == 0 {
  63. return g.Cfg().MustGet(m.Ctx, "limit.exceedMsg").String()
  64. }
  65. // 问题敏感词过滤
  66. if fsw.Match(req.Prompt) {
  67. return g.Cfg().MustGet(m.Ctx, "limit.fswMsg").String()
  68. }
  69. return ""
  70. }()
  71. if errReply != "" {
  72. reply, from = errReply, -1
  73. } else {
  74. reply, buf, from, err = Question.DetailQuestion(m.Ctx, req)
  75. if err != nil {
  76. g.Log().Error(m.Ctx, "问答异常", err)
  77. reply, from = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String(), -1
  78. }
  79. }
  80. if from != Answer_ChatGPT {
  81. if reply == "" {
  82. reply, from = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String(), -1
  83. }
  84. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  85. Content: reply,
  86. Type: 2,
  87. Actions: gconv.Int(If(errReply == "", 1, 0)),
  88. QuestionId: questionId,
  89. PersonId: jSession.PositionId,
  90. Item: gconv.Int(If(errReply == "", from, -1)),
  91. CreateTime: time.Now().Format(date.Date_Full_Layout),
  92. })
  93. if replyId <= 0 {
  94. g.Log().Error(m.Ctx, "问答存储存储异常")
  95. }
  96. return reply, nil, replyId, nil
  97. }
  98. return reply, buf, 0, nil
  99. }()
  100. if from != Answer_ChatGPT {
  101. if errMsg != nil {
  102. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": errMsg.Error(), "data": nil})
  103. } else {
  104. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": content, "isEnd": true}})
  105. }
  106. } else if buf != nil {
  107. for {
  108. line, _, err := buf.ReadLine()
  109. if err == io.EOF {
  110. break
  111. }
  112. if _, data := parseEventStream(line); data != nil {
  113. data.Response = fsw.Repl(data.Response)
  114. if data.Finished {
  115. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  116. Content: data.Response,
  117. Type: 2,
  118. Actions: 1,
  119. QuestionId: questionId,
  120. PersonId: jSession.PositionId,
  121. Item: Answer_ChatGPT,
  122. CreateTime: time.Now().Format(date.Date_Full_Layout),
  123. })
  124. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": data.Response, "isEnd": data.Finished}})
  125. } else {
  126. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"reply": data.Response, "isEnd": data.Finished}})
  127. }
  128. }
  129. }
  130. }
  131. }
  132. func parseEventStream(line []byte) (event string, date *BufRes) {
  133. // 如果行以 "event:" 开头,表示这是一个事件的标识符
  134. if len(line) > 6 && string(line[:6]) == "event:" {
  135. event = string(line[6 : len(line)-1])
  136. return event, nil
  137. }
  138. // 如果行以 "data:" 开头,表示这是事件的数据部分
  139. if len(line) > 5 && string(line[:5]) == "data:" {
  140. date = &BufRes{}
  141. if err := json.Unmarshal(line[5:len(line)], date); err == nil {
  142. return
  143. }
  144. return event, nil
  145. }
  146. return "", nil
  147. }