1
0
hliang 3 сар өмнө
parent
commit
ca03c6ec19
2 өөрчлөгдсөн 54 нэмэгдсэн , 29 устгасан
  1. 14 3
      core/api.go
  2. 40 26
      core/engine.go

+ 14 - 3
core/api.go

@@ -18,7 +18,7 @@ var (
 	upGrader = websocket.Upgrader{
 		CheckOrigin: func(r *http.Request) bool { return true },
 	}
-	gm        = &sync.Mutex{}
+	rwMu      sync.RWMutex
 	hlSyncMap sync.Map
 )
 
@@ -49,6 +49,17 @@ type Clients struct {
 	clientWs    *websocket.Conn
 }
 
+func (c *Clients) readFromMap(funcName string, MessageId string) chan string {
+	rwMu.RLock()
+	defer rwMu.RUnlock()
+	return c.actionData[funcName][MessageId]
+}
+func (c *Clients) writeToMap(funcName string, MessageId string, msg string) {
+	rwMu.Lock()
+	defer rwMu.Unlock()
+	c.actionData[funcName][MessageId] <- msg
+}
+
 // NewClient  initializes a new Clients instance
 func NewClient(group string, uid string, ws *websocket.Conn) *Clients {
 	return &Clients{
@@ -104,10 +115,10 @@ func ws(c *gin.Context) {
 		messageId := messageStruct.MessageId
 		msg := messageStruct.ResponseData
 		// 这里直接给管道塞数据,那么之前发送的时候要初始化好
-		if client.actionData[action][messageId] == nil {
+		if client.readFromMap(action, messageId) == nil {
 			log.Warning("当前消息id:", messageId, " 已被超时释放,回调的数据不做处理")
 		} else {
-			client.actionData[action][messageId] <- msg
+			client.writeToMap(action, messageId, msg)
 		}
 		if len(msg) > 100 {
 			utils.LogPrint("id", messageId, "get_message:", msg[:101]+"......")

+ 40 - 26
core/engine.go

@@ -3,57 +3,71 @@ package core
 import (
 	"JsRpc/config"
 	"JsRpc/utils"
+	"context"
 	"encoding/json"
 	log "github.com/sirupsen/logrus"
 	"math/rand"
-	"sync"
 	"time"
 )
 
-var wsMutex sync.Mutex
-
 // GQueryFunc 发送请求到客户端
 func (c *Clients) GQueryFunc(funcName string, param string, resChan chan<- string) {
 	if c.actionData[funcName] == nil {
+		rwMu.Lock()
 		c.actionData[funcName] = make(map[string]chan string)
+		rwMu.Unlock()
 	}
-	MessageId := ""
-	gm.Lock()
+	var MessageId string
 	for {
 		MessageId = utils.GetUUID()
-		// 先判断action是否需要初始化
-		if c.actionData[funcName][MessageId] == nil {
-			c.actionData[funcName][MessageId] = make(chan string, 1) //此次action初始化1个消息
-			//只有不存在的MessageId才会继续,
+		if c.readFromMap(funcName, MessageId) == nil {
+			rwMu.Lock()
+			c.actionData[funcName][MessageId] = make(chan string, 1)
+			rwMu.Unlock()
 			break
-		} else {
-			utils.LogPrint("存在的消息id,跳过")
 		}
+		utils.LogPrint("存在的消息id,跳过")
 	}
-	gm.Unlock()
+	// 确保资源释放
+	defer func() {
+		rwMu.Lock()
+		delete(c.actionData[funcName], MessageId)
+		rwMu.Unlock()
+		close(resChan)
+	}()
+
+	// 构造消息并发送
 	WriteData := Message{Param: param, MessageId: MessageId, Action: funcName}
-	data, _ := json.Marshal(WriteData)
-	clientWs := c.clientWs
-	wsMutex.Lock()
-	err := clientWs.WriteMessage(1, data)
+	data, err := json.Marshal(WriteData)
+	if err != nil {
+		log.Error(err, "JSON序列化失败")
+		resChan <- "JSON序列化失败"
+		return
+	}
+
+	rwMu.Lock()
+	err = c.clientWs.WriteMessage(1, data)
+	rwMu.Unlock()
 	if err != nil {
 		log.Error(err, "写入数据失败")
 		resChan <- "rpc发送数据失败"
+		return
+	}
+	// 使用 context 控制超时
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.DefaultTimeout)*time.Second)
+	defer cancel()
+	resultChan := c.readFromMap(funcName, MessageId)
+	if resultChan == nil {
+		resChan <- "消息ID对应的管道不存在"
+		return
 	}
 	select {
-	case res := <-c.actionData[funcName][MessageId]:
+	case res := <-resultChan:
 		resChan <- res
-	case <-time.After(time.Duration(config.DefaultTimeout) * time.Second):
+	case <-ctx.Done():
 		utils.LogPrint(MessageId + "超时了")
-		resChan <- "黑脸怪:timeout"
+		resChan <- "获取结果超时 timeout"
 	}
-	wsMutex.Unlock()
-	// 清理资源
-	gm.Lock()
-	delete(c.actionData[funcName], MessageId)
-	gm.Unlock()
-
-	close(resChan)
 }
 
 func getRandomClient(group string, clientId string) *Clients {