ws.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package model
  2. import (
  3. "aiChat/utility/fsw"
  4. . "app.yhyue.com/moapp/jybase/common"
  5. "app.yhyue.com/moapp/jybase/date"
  6. "app.yhyue.com/moapp/jybase/encrypt"
  7. "bufio"
  8. "context"
  9. "encoding/json"
  10. "fmt"
  11. "github.com/gogf/gf/v2/encoding/gjson"
  12. "github.com/gogf/gf/v2/frame/g"
  13. "github.com/gogf/gf/v2/net/ghttp"
  14. "github.com/gogf/gf/v2/util/gconv"
  15. "io"
  16. "strings"
  17. "time"
  18. )
  19. type WsChat struct {
  20. Ctx context.Context
  21. }
  22. func NewMessage(ctx context.Context) *WsChat {
  23. return &WsChat{
  24. Ctx: ctx,
  25. }
  26. }
  27. // Handle 处理消息
  28. func (m *WsChat) Handle(ws *ghttp.WebSocket, msg []byte) {
  29. defer Catch()
  30. jSession := SessionCtx.Get(m.Ctx).JSession
  31. if jSession.PositionId == 0 {
  32. _ = ws.WriteJSON(g.Map{
  33. "error_code": -1,
  34. "error_msg": "请登录",
  35. })
  36. return
  37. }
  38. req, from := &QuestionReq{}, 0
  39. if err := gjson.Unmarshal(msg, req); err != nil {
  40. g.Log().Errorf(m.Ctx, "%d 接收消息Unmarshal出错:%v", jSession.PositionId, err)
  41. return
  42. }
  43. questionId := ChatHistory.Save(m.Ctx, &ChatRecord{
  44. Content: req.Prompt,
  45. Type: 1,
  46. Refer: req.Href,
  47. PersonId: jSession.PositionId,
  48. CreateTime: time.Now().Format(date.Date_Full_Layout),
  49. })
  50. content, res, replyId, errMsg := func() (string, io.ReadCloser, int64, error) {
  51. var (
  52. err error
  53. res io.ReadCloser
  54. reply string
  55. )
  56. errReply := func() string {
  57. // 校验是否在黑名单,黑名单不返回内容
  58. if UserBlackList.CheckBlackList(m.Ctx, jSession.PositionId) {
  59. return g.Cfg().MustGet(m.Ctx, "limit.blackMsg").String()
  60. }
  61. // 校验问答频率
  62. if ChatLimit.GetBucket(m.Ctx, jSession.PositionId).TakeAvailable(1) == 0 {
  63. return g.Cfg().MustGet(m.Ctx, "limit.exceedMsg").String()
  64. }
  65. // 问题敏感词过滤
  66. if fsw.Match(req.Prompt) {
  67. return g.Cfg().MustGet(m.Ctx, "limit.fswMsg").String()
  68. }
  69. return ""
  70. }()
  71. if errReply != "" {
  72. reply, from = errReply, -1
  73. } else {
  74. reply, res, from, err = Question.DetailQuestion(m.Ctx, req)
  75. if err != nil {
  76. g.Log().Error(m.Ctx, "问答异常", err)
  77. reply = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String()
  78. errReply = g.Cfg().MustGet(m.Ctx, "limit.errMsg").String()
  79. }
  80. }
  81. if from == Answer_ChatGPT {
  82. return reply, res, 0, nil
  83. }
  84. if reply == "" {
  85. reply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  86. errReply = g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String()
  87. }
  88. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  89. Content: reply,
  90. Type: 2,
  91. Actions: gconv.Int(If(errReply == "", 1, 0)),
  92. QuestionId: questionId,
  93. PersonId: jSession.PositionId,
  94. Item: from,
  95. CreateTime: time.Now().Format(date.Date_Full_Layout),
  96. })
  97. if replyId <= 0 {
  98. g.Log().Error(m.Ctx, "问答存储存储异常")
  99. }
  100. if errReply != "" {
  101. return reply, nil, replyId, fmt.Errorf(errReply)
  102. }
  103. return reply, nil, replyId, nil
  104. }()
  105. if res != nil {
  106. defer res.Close()
  107. }
  108. if from != Answer_ChatGPT {
  109. if errMsg != nil {
  110. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": errMsg.Error(), "data": nil})
  111. } else {
  112. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": content, "isEnd": true}})
  113. }
  114. } else if res != nil {
  115. buf, lastData := bufio.NewReader(res), &BufRes{}
  116. isEmpty := true
  117. for {
  118. line, _, err := buf.ReadLine()
  119. if err == nil {
  120. break
  121. }
  122. if _, data := parseEventStream(line); data != nil && strings.TrimSpace(data.Response) != "" {
  123. data.Response = fsw.Repl(data.Response)
  124. lastData, isEmpty = data, false
  125. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"reply": lastData.Response, "isEnd": false}})
  126. }
  127. }
  128. ChatGptPool.Add() //放回链接池
  129. finalReply := If(isEmpty, g.Cfg().MustGet(m.Ctx, "limit.emptyMsg").String(), lastData.Response).(string)
  130. replyId := ChatHistory.Save(m.Ctx, &ChatRecord{
  131. Content: finalReply,
  132. Type: 2,
  133. Actions: gconv.Int(If(isEmpty, 0, 1)),
  134. QuestionId: questionId,
  135. PersonId: jSession.PositionId,
  136. Item: Answer_ChatGPT,
  137. CreateTime: time.Now().Format(date.Date_Full_Layout),
  138. })
  139. if !isEmpty {
  140. _ = ws.WriteJSON(g.Map{"error_code": 0, "error_msg": "", "data": g.Map{"id": encrypt.SE.Encode2Hex(fmt.Sprintf("%d", replyId)), "reply": finalReply, "isEnd": true}})
  141. } else {
  142. _ = ws.WriteJSON(g.Map{"error_code": -1, "error_msg": finalReply})
  143. }
  144. }
  145. }
  146. func parseEventStream(line []byte) (event string, date *BufRes) {
  147. // 如果行以 "event:" 开头,表示这是一个事件的标识符
  148. if len(line) > 6 && string(line[:6]) == "event:" {
  149. event = string(line[6 : len(line)-1])
  150. return event, nil
  151. }
  152. // 如果行以 "data:" 开头,表示这是事件的数据部分
  153. if len(line) > 5 && string(line[:5]) == "data:" {
  154. date = &BufRes{}
  155. if err := json.Unmarshal(line[5:len(line)], date); err == nil {
  156. return
  157. }
  158. return event, nil
  159. }
  160. return "", nil
  161. }