瀏覽代碼

Update main.go

hliang 2 年之前
父節點
當前提交
9eb384f9a9
共有 1 個文件被更改,包括 42 次插入13 次删除
  1. 42 13
      main.go

+ 42 - 13
main.go

@@ -7,6 +7,7 @@ import (
 	"github.com/gorilla/websocket"
 	"github.com/unrolled/secure"
 	"net/http"
+	"os"
 	"strings"
 	"sync"
 )
@@ -21,9 +22,9 @@ var (
 	upGrader = websocket.Upgrader{
 		CheckOrigin: func(r *http.Request) bool { return true },
 	}
-	hlSyncMap sync.Map
-	gm        = &sync.Mutex{}
-	gchan     = make(chan string)
+	hlSyncMap      sync.Map
+	gm             = &sync.Mutex{}
+	gchan, isPrint = make(chan string), false
 )
 
 type Clients struct {
@@ -38,6 +39,13 @@ type Message struct {
 	Param  string `json:"param"`
 }
 
+// is print?
+func logPrint(p ...interface{}) {
+	if isPrint {
+		fmt.Println(p)
+	}
+}
+
 // NewClient  initializes a new Clients instance
 func NewClient(group string, name string, ws *websocket.Conn) *Clients {
 	return &Clients{
@@ -73,7 +81,7 @@ func ws(c *gin.Context) {
 		if strIndex >= 1 {
 			action := msg[:strIndex]
 			client.Data[action] = msg[strIndex+5:]
-			fmt.Println("get_message:", client.Data[action])
+			logPrint("get_message:", client.Data[action])
 			gchan <- msg[strIndex+5:]
 			hlSyncMap.Store(group+"->"+name, client)
 		} else {
@@ -83,7 +91,7 @@ func ws(c *gin.Context) {
 	}
 	defer func(ws *websocket.Conn) {
 		_ = ws.Close()
-		fmt.Println(group+"->"+name, "下线了")
+		logPrint(group+"->"+name, "下线了")
 		hlSyncMap.Range(func(key, value interface{}) bool {
 			//client, _ := value.(*Clients)
 			if key == group+"->"+name {
@@ -94,6 +102,23 @@ func ws(c *gin.Context) {
 	}(ws)
 }
 
+func wsTest(c *gin.Context) {
+	testClient, _ := upGrader.Upgrade(c.Writer, c.Request, nil)
+	for {
+		//等待数据
+		_, message, err := testClient.ReadMessage()
+		if err != nil {
+			break
+		}
+		msg := string(message)
+		logPrint("接收到测试消息", msg)
+		_ = testClient.WriteMessage(1, []byte(msg))
+	}
+	defer func(ws *websocket.Conn) {
+		_ = ws.Close()
+	}(testClient)
+}
+
 func GQueryFunc(client *Clients, funcName string, param string, resChan chan<- string) {
 	WriteDate := Message{}
 	WriteDate.Action = funcName
@@ -111,7 +136,6 @@ func GQueryFunc(client *Clients, funcName string, param string, resChan chan<- s
 		fmt.Println(err, "写入数据失败")
 	}
 	res := <-gchan
-	fmt.Printf("res: %v\n", res)
 	resChan <- res
 }
 
@@ -127,12 +151,12 @@ func ResultSet(c *gin.Context) {
 	}
 
 	if getGroup == "" || getName == "" {
-		c.String(200, "input group and name")
+		c.JSON(400, gin.H{"status": 400, "data": "input group and name"})
 		return
 	}
 	clientName, ok := hlSyncMap.Load(getGroup + "->" + getName)
 	if ok == false {
-		c.String(200, "注入了ws?没有找到当前组和名字")
+		c.JSON(400, gin.H{"status": 400, "data": "注入了ws?没有找到当前组和名字"})
 		return
 	}
 	if Action == "" {
@@ -148,7 +172,7 @@ func ResultSet(c *gin.Context) {
 	c2 := make(chan string)
 	go GQueryFunc(client, Action, Param, c2)
 	//把管道传过去,获得值就返回了
-	c.JSON(200, gin.H{"status": "200", "group": client.clientGroup, "name": client.clientName, "data": <-c2})
+	c.JSON(200, gin.H{"status": 200, "group": client.clientGroup, "name": client.clientName, "data": <-c2})
 
 }
 
@@ -163,13 +187,13 @@ func Execjs(c *gin.Context) {
 		getGroup, getName, JsCode = c.PostForm("group"), c.PostForm("name"), c.PostForm("jscode")
 	}
 	if getGroup == "" || getName == "" {
-		c.String(200, "input group and name")
+		c.JSON(400, gin.H{"status": 400, "data": "input group and name"})
 		return
 	}
-	fmt.Println(getGroup, getName, JsCode)
+	logPrint(getGroup, getName, JsCode)
 	clientName, ok := hlSyncMap.Load(getGroup + "->" + getName)
 	if ok == false {
-		c.String(200, "注入了ws?没有找到组和名字")
+		c.JSON(400, gin.H{"status": 400, "data": "注入了ws?没有找到当前组和名字"})
 		return
 	}
 	//取一个ws客户端
@@ -213,12 +237,17 @@ func TlsHandler() gin.HandlerFunc {
 }
 
 func main() {
-	//gin.SetMode(gin.ReleaseMode)
+	for _, v := range os.Args {
+		if v == "log" {
+			isPrint = true
+		}
+	}
 	r := gin.Default()
 	r.GET("/", Index)
 	r.GET("/go", ResultSet)
 	r.POST("/go", ResultSet)
 	r.GET("/ws", ws)
+	r.GET("/wst", wsTest)
 	r.GET("/execjs", Execjs)
 	r.POST("/execjs", Execjs)
 	r.GET("/list", getList)