Sfoglia il codice sorgente

整个逻辑调整

Heartfilia 1 anno fa
parent
commit
e7113e74bc
10 ha cambiato i file con 165 aggiunte e 109 eliminazioni
  1. 17 1
      config/config.go
  2. 45 90
      core/api.go
  3. 75 0
      core/engine.go
  4. 7 0
      core/routers.go
  5. 1 0
      go.mod
  6. 2 0
      go.sum
  7. 1 17
      main.go
  8. 8 0
      utils/code.go
  9. 9 0
      utils/hash.go
  10. 0 1
      utils/logger.go

+ 17 - 1
config/config.go

@@ -3,13 +3,29 @@ package config
 import (
 	"JsRpc/utils"
 	"errors"
+	"flag"
+	log "github.com/sirupsen/logrus"
 	"gopkg.in/yaml.v3"
 	"os"
 )
 
 var DefaultTimeout = 30
 
-func InitConf(path string) (ConfStruct, error) {
+func ReadConf() ConfStruct {
+	var ConfigPath string
+	// 定义命令行参数-c,后面跟着的是默认值以及参数说明
+	flag.StringVar(&ConfigPath, "c", "config.yaml", "指定配置文件的路径")
+	// 解析命令行参数
+	flag.Parse()
+
+	conf, err := initConf(ConfigPath)
+	if err != nil {
+		log.Errorln("读取配置文件错误,将使用默认配置运行。 ", err.Error())
+	}
+	return conf
+}
+
+func initConf(path string) (ConfStruct, error) {
 	defaultConf := ConfStruct{
 		BasicListen: `:12080`,
 		HttpsServices: HttpsConfig{

+ 45 - 90
core/api.go

@@ -3,22 +3,14 @@ package core
 import (
 	"JsRpc/config"
 	"JsRpc/utils"
-	"encoding/json"
-	"errors"
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	log "github.com/sirupsen/logrus"
 	"github.com/unrolled/secure"
-	"math/rand"
 	"net/http"
-	"os"
-	"os/signal"
 	"strconv"
 	"strings"
 	"sync"
-	"syscall"
-	"time"
 )
 
 var (
@@ -75,8 +67,7 @@ func ws(c *gin.Context) {
 	}
 	//没有给客户端id的话 就用时间戳给他生成一个
 	if clientId == "" {
-		millisId := time.Now().UnixNano() / int64(time.Millisecond)
-		clientId = fmt.Sprintf("%d", millisId)
+		clientId = utils.GetUUID()
 	}
 	wsClient, err := upGrader.Upgrade(c.Writer, c.Request, nil)
 	if err != nil {
@@ -98,7 +89,12 @@ func ws(c *gin.Context) {
 		if strIndex >= 1 {
 			action := msg[:strIndex]
 			client.actionData[action] <- msg[strIndex+5:]
-			utils.LogPrint("get_message:", msg[strIndex+5:])
+			if len(msg) > 100 {
+				utils.LogPrint("get_message:", msg[strIndex+5:101]+"......")
+			} else {
+				utils.LogPrint("get_message:", msg[strIndex+5:])
+			}
+
 		} else {
 			log.Error(msg, "message error")
 		}
@@ -134,37 +130,44 @@ func wsTest(c *gin.Context) {
 	}(testClient)
 }
 
-// GQueryFunc 发送请求到客户端
-func (c *Clients) GQueryFunc(funcName string, param string, resChan chan<- string) {
-	WriteDate := Message{Param: param, Action: funcName}
-	data, _ := json.Marshal(WriteDate)
-	clientWs := c.clientWs
-	if c.actionData[funcName] == nil {
-		c.actionData[funcName] = make(chan string, 1) //此次action初始化1个消息
+func GetCookie(c *gin.Context) {
+	var RequestParam ApiParam
+	if err := c.ShouldBind(&RequestParam); err != nil {
+		GinJsonMsg(c, http.StatusBadRequest, err.Error())
+		return
 	}
-	gm.Lock()
-	err := clientWs.WriteMessage(1, data)
-	gm.Unlock()
-	if err != nil {
-		fmt.Println(err, "写入数据失败")
+	group := c.Query("group")
+	if group == "" {
+		GinJsonMsg(c, http.StatusBadRequest, "需要传入group")
+		return
 	}
-	resultFlag := false
-	for i := 0; i < config.DefaultTimeout*10; i++ {
-		if len(c.actionData[funcName]) > 0 {
-			res := <-c.actionData[funcName]
-			resChan <- res
-			resultFlag = true
-			break
-		}
-		time.Sleep(time.Millisecond * 100)
+
+	clientId := RequestParam.ClientId
+	client := getRandomClient(group, clientId)
+
+	c3 := make(chan string, 1)
+	go client.GQueryFunc("_execjs", utils.ConcatCode("document.cookie"), c3)
+	c.JSON(http.StatusOK, gin.H{"status": 200, "group": client.clientGroup, "clientId": client.clientId, "data": <-c3})
+}
+
+func GetHtml(c *gin.Context) {
+	var RequestParam ApiParam
+	if err := c.ShouldBind(&RequestParam); err != nil {
+		GinJsonMsg(c, http.StatusBadRequest, err.Error())
+		return
 	}
-	// 循环完了还是没有数据,那就超时退出
-	if true != resultFlag {
-		resChan <- "黑脸怪:timeout"
+	group := c.Query("group")
+	if group == "" {
+		GinJsonMsg(c, http.StatusBadRequest, "需要传入group")
+		return
 	}
-	defer func() {
-		close(resChan)
-	}()
+
+	clientId := RequestParam.ClientId
+	client := getRandomClient(group, clientId)
+
+	c3 := make(chan string, 1)
+	go client.GQueryFunc("_execjs", utils.ConcatCode("document.documentElement.outerHTML"), c3)
+	c.JSON(http.StatusOK, gin.H{"status": 200, "group": client.clientGroup, "clientId": client.clientId, "data": <-c3})
 }
 
 // GetResult 接收web请求参数,并发给客户端获取结果
@@ -198,40 +201,6 @@ func getResult(c *gin.Context) {
 
 }
 
-func getRandomClient(group string, clientId string) *Clients {
-	var client *Clients
-	// 不传递clientId时候,从group分组随便拿一个
-	if clientId != "" {
-		clientName, ok := hlSyncMap.Load(group + "->" + clientId)
-		if ok == false {
-			return nil
-		}
-		client, _ = clientName.(*Clients)
-		return client
-	}
-	groupClients := make([]*Clients, 0)
-	//循环读取syncMap 获取group名字的
-	hlSyncMap.Range(func(_, value interface{}) bool {
-		tmpClients, ok := value.(*Clients)
-		if !ok {
-			return true
-		}
-		if tmpClients.clientGroup == group {
-			groupClients = append(groupClients, tmpClients)
-		}
-		return true
-	})
-	if len(groupClients) == 0 {
-		return nil
-	}
-	// 使用随机数发生器
-	r := rand.New(rand.NewSource(time.Now().UnixNano()))
-	randomIndex := r.Intn(len(groupClients))
-	client = groupClients[randomIndex]
-	return client
-
-}
-
 func execjs(c *gin.Context) {
 	var RequestParam ApiParam
 	if err := c.ShouldBind(&RequestParam); err != nil {
@@ -322,17 +291,6 @@ func InitAPI(conf config.ConfStruct) {
 
 	setJsRpcRouters(router) // 核心路由
 
-	srv := &http.Server{
-		Addr:    conf.BasicListen,
-		Handler: router,
-	}
-
-	go func() {
-		if err := srv.ListenAndServe(); err != nil && !errors.Is(http.ErrServerClosed, err) {
-			log.Fatalf("listen: %s\n", err)
-		}
-	}()
-
 	var sb strings.Builder
 	sb.WriteString("当前监听地址:")
 	sb.WriteString(conf.BasicListen)
@@ -346,11 +304,8 @@ func InitAPI(conf config.ConfStruct) {
 	}
 	log.Infoln(sb.String())
 
-	//err := router.Run(conf.BasicListen)
-
-	// 设置优雅退出  按 ctrl+c 退出程序的时候 不会报错的那种
-	quit := make(chan os.Signal)
-	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
-	<-quit
-
+	err := router.Run(conf.BasicListen)
+	if err != nil {
+		log.Errorln("服务启动失败..")
+	}
 }

+ 75 - 0
core/engine.go

@@ -1 +1,76 @@
 package core
+
+import (
+	"JsRpc/config"
+	"encoding/json"
+	"fmt"
+	"math/rand"
+	"time"
+)
+
+// GQueryFunc 发送请求到客户端
+func (c *Clients) GQueryFunc(funcName string, param string, resChan chan<- string) {
+	WriteData := Message{Param: param, Action: funcName}
+	data, _ := json.Marshal(WriteData)
+	clientWs := c.clientWs
+	if c.actionData[funcName] == nil {
+		c.actionData[funcName] = make(chan string, 1) //此次action初始化1个消息
+	}
+	gm.Lock()
+	err := clientWs.WriteMessage(1, data)
+	gm.Unlock()
+	if err != nil {
+		fmt.Println(err, "写入数据失败")
+	}
+	resultFlag := false
+	for i := 0; i < config.DefaultTimeout*10; i++ {
+		if len(c.actionData[funcName]) > 0 {
+			res := <-c.actionData[funcName]
+			resChan <- res
+			resultFlag = true
+			break
+		}
+		time.Sleep(time.Millisecond * 100)
+	}
+	// 循环完了还是没有数据,那就超时退出
+	if true != resultFlag {
+		resChan <- "黑脸怪:timeout"
+	}
+	defer func() {
+		close(resChan)
+	}()
+}
+
+func getRandomClient(group string, clientId string) *Clients {
+	var client *Clients
+	// 不传递clientId时候,从group分组随便拿一个
+	if clientId != "" {
+		clientName, ok := hlSyncMap.Load(group + "->" + clientId)
+		if ok == false {
+			return nil
+		}
+		client, _ = clientName.(*Clients)
+		return client
+	}
+	groupClients := make([]*Clients, 0)
+	//循环读取syncMap 获取group名字的
+	hlSyncMap.Range(func(_, value interface{}) bool {
+		tmpClients, ok := value.(*Clients)
+		if !ok {
+			return true
+		}
+		if tmpClients.clientGroup == group {
+			groupClients = append(groupClients, tmpClients)
+		}
+		return true
+	})
+	if len(groupClients) == 0 {
+		return nil
+	}
+	// 使用随机数发生器
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	randomIndex := r.Intn(len(groupClients))
+	client = groupClients[randomIndex]
+	return client
+
+}

+ 7 - 0
core/routers.go

@@ -8,6 +8,12 @@ func setJsRpcRouters(router *gin.Engine) {
 	// 核心部分的的路由组
 	router.GET("/", index)
 
+	page := router.Group("/page")
+	{
+		page.GET("/cookie", GetCookie)
+		page.GET("/html", GetHtml)
+	}
+
 	rpc := router.Group("/")
 	{
 		rpc.GET("go", getResult)
@@ -18,4 +24,5 @@ func setJsRpcRouters(router *gin.Engine) {
 		rpc.POST("execjs", execjs)
 		rpc.GET("list", getList)
 	}
+
 }

+ 1 - 0
go.mod

@@ -19,6 +19,7 @@ require (
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-playground/validator/v10 v10.14.0 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
+	github.com/google/uuid v1.6.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 	github.com/leodido/go-urn v1.2.4 // indirect

+ 2 - 0
go.sum

@@ -27,6 +27,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
 github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
 github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=

+ 1 - 17
main.go

@@ -4,28 +4,12 @@ import (
 	"JsRpc/config"
 	"JsRpc/core"
 	"JsRpc/utils"
-	"flag"
-	log "github.com/sirupsen/logrus"
 )
 
-func readConf() config.ConfStruct {
-	var ConfigPath string
-	// 定义命令行参数-c,后面跟着的是默认值以及参数说明
-	flag.StringVar(&ConfigPath, "c", "config.yaml", "指定配置文件的路径")
-	// 解析命令行参数
-	flag.Parse()
-
-	conf, err := config.InitConf(ConfigPath)
-	if err != nil {
-		log.Errorln("读取配置文件错误,将使用默认配置运行。 ", err.Error())
-	}
-	return conf
-}
-
 func main() {
 	utils.PrintJsRpc() // 开屏打印
-	baseConf := readConf()
 
+	baseConf := config.ReadConf()       // 读取日志信息
 	utils.InitLogger(baseConf.CloseLog) // 初始化日志
 	core.InitAPI(baseConf)              // 初始化api部分
 

+ 8 - 0
utils/code.go

@@ -0,0 +1,8 @@
+package utils
+
+import "fmt"
+
+func ConcatCode(code string) string {
+	// 拼接页面元素的js
+	return fmt.Sprintf("(function(){return %s;})()", code)
+}

+ 9 - 0
utils/hash.go

@@ -0,0 +1,9 @@
+package utils
+
+import "github.com/google/uuid"
+
+func GetUUID() string {
+	u := uuid.New()
+	key := u.String()
+	return key
+}

+ 0 - 1
utils/logger.go

@@ -24,7 +24,6 @@ func (w logWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
-// is print?
 func LogPrint(p ...interface{}) {
 	if isPrint {
 		log.Infoln(p)