56 lines
1.4 KiB
Go
56 lines
1.4 KiB
Go
|
|
package websocket
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"fmt"
|
|||
|
|
"net/http"
|
|||
|
|
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"github.com/go-admin-team/go-admin-core/logger"
|
|||
|
|
"github.com/go-admin-team/go-admin-core/sdk/pkg/jwtauth"
|
|||
|
|
"github.com/gorilla/websocket"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
var upgrader = websocket.Upgrader{
|
|||
|
|
CheckOrigin: func(r *http.Request) bool {
|
|||
|
|
return true // 允许跨域
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func ServeWS(hub *Hub) gin.HandlerFunc {
|
|||
|
|
return func(c *gin.Context) {
|
|||
|
|
// 这里假设 JWT 中间件已验证,且用户 ID 在 Context
|
|||
|
|
claims := jwtauth.ExtractClaims(c)
|
|||
|
|
userID, ok := claims["identity"].(float64)
|
|||
|
|
if !ok {
|
|||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid user ID in token"})
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 设置 Subprotocols 支持客户端传来的协议(token等)
|
|||
|
|
upgrader.Subprotocols = []string{c.GetHeader("Sec-WebSocket-Protocol")}
|
|||
|
|
|
|||
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|||
|
|
if err != nil {
|
|||
|
|
logger.Errorf("WebSocket upgrade failed: %v", err)
|
|||
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "upgrade failed"})
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|||
|
|
|
|||
|
|
client := &Client{
|
|||
|
|
ID: fmt.Sprintf("%v", userID),
|
|||
|
|
Conn: conn,
|
|||
|
|
Send: make(chan []byte, 1024),
|
|||
|
|
Context: ctx,
|
|||
|
|
CancelFunc: cancel,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
hub.Register <- client
|
|||
|
|
|
|||
|
|
go client.Read(hub)
|
|||
|
|
go client.Write()
|
|||
|
|
}
|
|||
|
|
}
|