56 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package websocket
 | ||
| 
 | ||
| import (
 | ||
| 	"context"
 | ||
| 	"fmt"
 | ||
| 	"net/http"
 | ||
| 
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| 	"github.com/go-admin-team/go-admin-core/logger"
 | ||
| 	"github.com/go-admin-team/go-admin-core/sdk/pkg/jwtauth"
 | ||
| 	"github.com/gorilla/websocket"
 | ||
| )
 | ||
| 
 | ||
| var upgrader = websocket.Upgrader{
 | ||
| 	CheckOrigin: func(r *http.Request) bool {
 | ||
| 		return true // 允许跨域
 | ||
| 	},
 | ||
| }
 | ||
| 
 | ||
| func ServeWS(hub *Hub) gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 这里假设 JWT 中间件已验证,且用户 ID 在 Context
 | ||
| 		claims := jwtauth.ExtractClaims(c)
 | ||
| 		userID, ok := claims["identity"].(float64)
 | ||
| 		if !ok {
 | ||
| 			c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid user ID in token"})
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 设置 Subprotocols 支持客户端传来的协议(token等)
 | ||
| 		upgrader.Subprotocols = []string{c.GetHeader("Sec-WebSocket-Protocol")}
 | ||
| 
 | ||
| 		conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
 | ||
| 		if err != nil {
 | ||
| 			logger.Errorf("WebSocket upgrade failed: %v", err)
 | ||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": "upgrade failed"})
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		ctx, cancel := context.WithCancel(context.Background())
 | ||
| 
 | ||
| 		client := &Client{
 | ||
| 			ID:         fmt.Sprintf("%v", userID),
 | ||
| 			Conn:       conn,
 | ||
| 			Send:       make(chan []byte, 1024),
 | ||
| 			Context:    ctx,
 | ||
| 			CancelFunc: cancel,
 | ||
| 		}
 | ||
| 
 | ||
| 		hub.Register <- client
 | ||
| 
 | ||
| 		go client.Read(hub)
 | ||
| 		go client.Write()
 | ||
| 	}
 | ||
| }
 |