|
@@ -4,7 +4,6 @@ import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
|
- "google.golang.org/grpc"
|
|
|
"log"
|
|
|
. "rpc/chat"
|
|
|
"rpc/config"
|
|
@@ -12,25 +11,39 @@ import (
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-type chatServer struct {
|
|
|
+var Chatserver *ChatServer
|
|
|
+
|
|
|
+func init() {
|
|
|
+ Chatserver = NewChatServer()
|
|
|
+}
|
|
|
+
|
|
|
+type ChatServer struct {
|
|
|
UnimplementedChatServiceServer
|
|
|
- clients map[string]chan *Message
|
|
|
- adminMsg chan *Message
|
|
|
- mu sync.RWMutex
|
|
|
+ clients map[string]chan *Message
|
|
|
+ adminMsg chan *Message
|
|
|
+ mu sync.RWMutex
|
|
|
+ shutdownChan chan struct{} // 关闭信号通道
|
|
|
}
|
|
|
|
|
|
-var ChatSrv = &chatServer{
|
|
|
- clients: make(map[string]chan *Message),
|
|
|
- adminMsg: make(chan *Message, 100),
|
|
|
+func NewChatServer() *ChatServer {
|
|
|
+ return &ChatServer{
|
|
|
+ clients: make(map[string]chan *Message),
|
|
|
+ adminMsg: make(chan *Message, 100),
|
|
|
+ shutdownChan: make(chan struct{}),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// 建立连接
|
|
|
-func (s *chatServer) JoinChat(req *JoinRequest, stream ChatService_JoinChatServer) error {
|
|
|
+func (s *ChatServer) JoinChat(req *JoinRequest, stream ChatService_JoinChatServer) error {
|
|
|
// 创建新通道
|
|
|
msgChan := make(chan *Message, 100)
|
|
|
|
|
|
// 注册客户端
|
|
|
s.mu.Lock()
|
|
|
+ if _, exists := s.clients[req.UserId]; exists {
|
|
|
+ s.mu.Unlock()
|
|
|
+ return fmt.Errorf("用户 %s 已连接", req.UserId)
|
|
|
+ }
|
|
|
s.clients[req.UserId] = msgChan
|
|
|
s.mu.Unlock()
|
|
|
|
|
@@ -41,46 +54,39 @@ func (s *chatServer) JoinChat(req *JoinRequest, stream ChatService_JoinChatServe
|
|
|
Timestamp: time.Now().Unix(),
|
|
|
}
|
|
|
if err := stream.Send(welcomeMsg); err != nil {
|
|
|
- s.mu.Lock()
|
|
|
- delete(s.clients, req.UserId)
|
|
|
- s.mu.Unlock()
|
|
|
+ s.removeClient(req.UserId)
|
|
|
return err
|
|
|
}
|
|
|
-
|
|
|
// 清理处理
|
|
|
- defer func() {
|
|
|
- s.mu.Lock()
|
|
|
- if ch, exists := s.clients[req.UserId]; exists {
|
|
|
- delete(s.clients, req.UserId)
|
|
|
- close(ch)
|
|
|
- }
|
|
|
- s.mu.Unlock()
|
|
|
- }()
|
|
|
+ defer s.removeClient(req.UserId)
|
|
|
+ // 消息循环
|
|
|
for {
|
|
|
select {
|
|
|
case msg := <-msgChan:
|
|
|
- if err := stream.Send(msg); err != nil {
|
|
|
+ if err := s.sendWithTimeout(stream, msg, 5*time.Second); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
case adminMsg := <-s.adminMsg:
|
|
|
- if err := stream.Send(adminMsg); err != nil {
|
|
|
+ if err := s.sendWithTimeout(stream, adminMsg, 5*time.Second); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
case <-stream.Context().Done():
|
|
|
return nil
|
|
|
+ case <-s.shutdownChan:
|
|
|
+ return nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 接收消息处理
|
|
|
-func (s *chatServer) SendMessage(ctx context.Context, msg *Message) (*MessageAck, error) {
|
|
|
+func (s *ChatServer) SendMessage(ctx context.Context, msg *Message) (*MessageAck, error) {
|
|
|
msg.Timestamp = time.Now().Unix()
|
|
|
- log.Printf("收到来自 %s 的 %s 消息: %s", msg.UserId, msg.Action, msg.Text)
|
|
|
+ log.Printf("收到来自 %s 的 %s 消息: %s\n", msg.UserId, msg.Action, msg.Text)
|
|
|
|
|
|
// 先处理业务逻辑
|
|
|
switch msg.Action {
|
|
|
case "getContacts":
|
|
|
- log.Printf("接收%s通讯录信息", msg.UserId)
|
|
|
+ log.Printf("接收%s通讯录信息\n", msg.UserId)
|
|
|
//go SynchronousContacts(msg.UserId, msg.Text)
|
|
|
case "chatHistory":
|
|
|
go AddChatRecord(msg.UserId, msg.Text) // 异步处理
|
|
@@ -99,7 +105,7 @@ func (s *chatServer) SendMessage(ctx context.Context, msg *Message) (*MessageAck
|
|
|
select {
|
|
|
case ch <- msg:
|
|
|
default:
|
|
|
- log.Printf("客户端 %s 的消息通道已满", userId)
|
|
|
+ log.Printf("客户端 %s 的消息通道已满\n", userId)
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -107,7 +113,7 @@ func (s *chatServer) SendMessage(ctx context.Context, msg *Message) (*MessageAck
|
|
|
}
|
|
|
|
|
|
// SendAdminMessage 向指定用户发送系统消息
|
|
|
-func (s *chatServer) SendAdminMessage(userId string, text string, action string) error {
|
|
|
+func (s *ChatServer) SendAdminMessage(userId string, text string, action string) error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
// 检查目标用户是否存在
|
|
@@ -125,41 +131,96 @@ func (s *chatServer) SendAdminMessage(userId string, text string, action string)
|
|
|
// 发送消息
|
|
|
select {
|
|
|
case msgChan <- msg:
|
|
|
- log.Printf("已向用户 %s 发送系统消息: %s", userId, text)
|
|
|
+ log.Printf("已向用户 %s 发送系统消息: %s\n", userId, text)
|
|
|
return nil
|
|
|
default:
|
|
|
return fmt.Errorf("用户 %s 的消息通道已满", userId)
|
|
|
}
|
|
|
}
|
|
|
+func (s *ChatServer) StartTimedMessages(ctx context.Context, interval time.Duration, action string) {
|
|
|
+ // 立即执行一次任务
|
|
|
+ s.executeTimedAction(ctx, action)
|
|
|
|
|
|
-// StartTimedMessages 启动定时消息发送
|
|
|
-func (s *chatServer) StartTimedMessages(interval time.Duration, action string) {
|
|
|
ticker := time.NewTicker(interval)
|
|
|
defer ticker.Stop()
|
|
|
- for range ticker.C {
|
|
|
- message := fmt.Sprintf("系统定时消息: 当前时间 %v", time.Now().Format("2006-01-02 15:04:05"))
|
|
|
- // 快速获取客户端列表
|
|
|
- s.mu.RLock()
|
|
|
- clients := make([]string, 0, len(s.clients))
|
|
|
- for userId := range s.clients {
|
|
|
- clients = append(clients, userId)
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ticker.C:
|
|
|
+ s.executeTimedAction(ctx, action)
|
|
|
+ case <-ctx.Done():
|
|
|
+ log.Printf("定时任务[%s]已停止", action)
|
|
|
+ return
|
|
|
+ case <-s.shutdownChan:
|
|
|
+ log.Printf("服务关闭,停止定时任务[%s]", action)
|
|
|
+ return
|
|
|
}
|
|
|
- s.mu.RUnlock()
|
|
|
-
|
|
|
- // 处理发送(无需在锁内)
|
|
|
- switch action {
|
|
|
- case "getContacts":
|
|
|
- s.BroadcastAdminMessage(message, "getContacts")
|
|
|
- case "sendTalk":
|
|
|
- Task()
|
|
|
- case "listenIn":
|
|
|
- s.BroadcastAdminMessage(message, "isnline")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *ChatServer) executeTimedAction(ctx context.Context, action string) {
|
|
|
+ defer func() {
|
|
|
+ if r := recover(); r != nil {
|
|
|
+ log.Printf("定时任务[%s]执行出错: %v\n", action, r)
|
|
|
}
|
|
|
+ }()
|
|
|
+
|
|
|
+ startTime := time.Now()
|
|
|
+ log.Printf("开始执行定时任务[%s]\n", action)
|
|
|
+ message := fmt.Sprintf("系统定时消息: 当前时间 %v", startTime.Format("2006-01-02 15:04:05"))
|
|
|
+ // 使用更安全的方式获取客户端列表
|
|
|
+ clients := s.getClientsSnapshot()
|
|
|
+ if len(clients) > 0 {
|
|
|
+ log.Printf("当前在线客户端数: %d\n", len(clients))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 根据action执行不同操作
|
|
|
+ switch action {
|
|
|
+ case "getContacts":
|
|
|
+ s.BroadcastAdminMessage(message, "getContacts")
|
|
|
+ case "sendTalk":
|
|
|
+ s.executeTask(ctx)
|
|
|
+ case "heartbeat":
|
|
|
+ s.BroadcastAdminMessage(message, "heartbeat")
|
|
|
+ default:
|
|
|
+ log.Printf("未知的定时任务类型: %s\n", action)
|
|
|
+ }
|
|
|
+
|
|
|
+ log.Printf("完成定时任务[%s], 耗时: %v \n", action, time.Since(startTime))
|
|
|
+}
|
|
|
+
|
|
|
+func (s *ChatServer) getClientsSnapshot() []string {
|
|
|
+ s.mu.RLock()
|
|
|
+ defer s.mu.RUnlock()
|
|
|
+
|
|
|
+ clients := make([]string, 0, len(s.clients))
|
|
|
+ for userId := range s.clients {
|
|
|
+ clients = append(clients, userId)
|
|
|
+ }
|
|
|
+ return clients
|
|
|
+}
|
|
|
+
|
|
|
+func (s *ChatServer) executeTask(ctx context.Context) {
|
|
|
+ // 为Task操作添加超时控制
|
|
|
+ taskCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ done := make(chan struct{})
|
|
|
+ go func() {
|
|
|
+ defer close(done)
|
|
|
+ Task() // 假设Task是您定义的任务函数
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-done:
|
|
|
+ log.Println("Task执行完成")
|
|
|
+ case <-taskCtx.Done():
|
|
|
+ log.Println("Task执行超时")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// BroadcastAdminMessage 向所有客户端广播系统消息
|
|
|
-func (s *chatServer) BroadcastAdminMessage(text string, action string) {
|
|
|
+func (s *ChatServer) BroadcastAdminMessage(text string, action string) {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
msg := &Message{
|
|
@@ -171,15 +232,15 @@ func (s *chatServer) BroadcastAdminMessage(text string, action string) {
|
|
|
for userId, ch := range s.clients {
|
|
|
select {
|
|
|
case ch <- msg:
|
|
|
- log.Printf("已广播系统消息到用户 %s: %s", userId, text)
|
|
|
+ log.Printf("已广播系统消息到用户 %s: %s\n", userId, text)
|
|
|
default:
|
|
|
- log.Printf("用户 %s 的消息通道已满,无法广播", userId)
|
|
|
+ log.Printf("用户 %s 的消息通道已满,无法广播\n", userId)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// SpecifyAdminMessage 向制定客户端广播系统消息
|
|
|
-func (s *chatServer) SpecifyAdminMessage(taskId int64, userMap map[string]interface{}, contentData *[]map[string]interface{}, action, batchCode string) error {
|
|
|
+func (s *ChatServer) SpecifyAdminMessage(taskId int64, userMap map[string]interface{}, contentData *[]map[string]interface{}, action, batchCode string) error {
|
|
|
userId := gconv.String(userMap["userId"])
|
|
|
isRefuse := gconv.Int64(userMap["isRefuse"])
|
|
|
if isRefuse == 1 {
|
|
@@ -221,22 +282,22 @@ func (s *chatServer) SpecifyAdminMessage(taskId int64, userMap map[string]interf
|
|
|
}
|
|
|
select {
|
|
|
case ch <- msg:
|
|
|
- log.Printf("系统消息已发送到用户 %s: %s (Action: %s)", userId, text, action)
|
|
|
+ log.Printf("系统消息已发送到用户 %s: %s (Action: %s)\n", userId, text, action)
|
|
|
return nil
|
|
|
default:
|
|
|
- log.Printf("用户 %s 的消息通道已满,丢弃消息", userId)
|
|
|
+ log.Printf("用户 %s 的消息通道已满,丢弃消息\n", userId)
|
|
|
return fmt.Errorf("用户 %s 的消息通道阻塞", userId)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// SpecifysystemMessage 向制定客户端广播系统消息(拒绝也发,不保存发送记录)
|
|
|
-func (s *chatServer) SpecifysystemMessage(userId, wxId string, contentData map[string]interface{}, action string) error {
|
|
|
+func (s *ChatServer) SpecifysystemMessage(userId, wxId string, contentData map[string]interface{}, action string) error {
|
|
|
// 1. 加锁并获取用户channel
|
|
|
s.mu.Lock()
|
|
|
ch, exists := s.clients[userId]
|
|
|
if !exists {
|
|
|
s.mu.Unlock()
|
|
|
- log.Printf("用户 %s 不存在或已离线 (wxId: %s)", userId, wxId)
|
|
|
+ log.Printf("用户 %s 不存在或已离线 (wxId: %s)\n", userId, wxId)
|
|
|
return fmt.Errorf("user %s not found", userId)
|
|
|
}
|
|
|
|
|
@@ -268,15 +329,72 @@ func buildMessageText(contentData map[string]interface{}, wxId string) string {
|
|
|
func trySendMessage(ch chan<- *Message, msg *Message, userId, action string) error {
|
|
|
select {
|
|
|
case ch <- msg:
|
|
|
- log.Printf("系统消息发送成功 | 用户: %s | 动作: %s", userId, action)
|
|
|
+ log.Printf("系统消息发送成功 | 用户: %s | 动作: %s\n", userId, action)
|
|
|
return nil
|
|
|
default:
|
|
|
- log.Printf("消息通道已满 | 用户: %s | 动作: %s", userId, action)
|
|
|
+ log.Printf("消息通道已满 | 用户: %s | 动作: %s\n", userId, action)
|
|
|
return fmt.Errorf("message queue full for user %s", userId)
|
|
|
}
|
|
|
}
|
|
|
-func logPings(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
|
|
- // 只有客户端调用 RPC 方法时才会执行这里!
|
|
|
- log.Println("RPC called:", info.FullMethod)
|
|
|
- return handler(ctx, req)
|
|
|
+
|
|
|
+func (s *ChatServer) Ping(ctx context.Context, req *PingRequest) (*PingResponse, error) {
|
|
|
+ return &PingResponse{Status: "OK"}, nil // 确认返回值类型匹配
|
|
|
+}
|
|
|
+
|
|
|
+// Shutdown 优雅关闭服务
|
|
|
+func (s *ChatServer) Shutdown() {
|
|
|
+ close(s.shutdownChan)
|
|
|
+
|
|
|
+ // 关闭所有客户端连接
|
|
|
+ s.mu.Lock()
|
|
|
+ defer s.mu.Unlock()
|
|
|
+
|
|
|
+ for userId, ch := range s.clients {
|
|
|
+ close(ch)
|
|
|
+ delete(s.clients, userId)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// removeClient 安全地移除客户端连接
|
|
|
+func (s *ChatServer) removeClient(userId string) {
|
|
|
+ s.mu.Lock()
|
|
|
+ defer s.mu.Unlock()
|
|
|
+
|
|
|
+ if ch, exists := s.clients[userId]; exists {
|
|
|
+ // 关闭通道前先检查是否已关闭
|
|
|
+ select {
|
|
|
+ case _, ok := <-ch:
|
|
|
+ if ok {
|
|
|
+ close(ch) // 只有通道未关闭时才关闭它
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ close(ch)
|
|
|
+ }
|
|
|
+ delete(s.clients, userId)
|
|
|
+ log.Printf("客户端 %s 已断开连接", userId)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// sendWithTimeout 带超时的消息发送
|
|
|
+func (s *ChatServer) sendWithTimeout(stream ChatService_JoinChatServer, msg *Message, timeout time.Duration) error {
|
|
|
+ ctx, cancel := context.WithTimeout(stream.Context(), timeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ done := make(chan error, 1)
|
|
|
+ go func() {
|
|
|
+ done <- stream.Send(msg)
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-done:
|
|
|
+ return err
|
|
|
+ case <-ctx.Done():
|
|
|
+ // 超时后检查原始上下文是否已取消
|
|
|
+ select {
|
|
|
+ case <-stream.Context().Done():
|
|
|
+ return stream.Context().Err()
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("消息发送超时")
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|