56 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			56 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								package websocket
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								import (
							 | 
						|||
| 
								 | 
							
									"context"
							 | 
						|||
| 
								 | 
							
									"fmt"
							 | 
						|||
| 
								 | 
							
									"net/http"
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
									"github.com/gin-gonic/gin"
							 | 
						|||
| 
								 | 
							
									"github.com/go-admin-team/go-admin-core/logger"
							 | 
						|||
| 
								 | 
							
									"github.com/go-admin-team/go-admin-core/sdk/pkg/jwtauth"
							 | 
						|||
| 
								 | 
							
									"github.com/gorilla/websocket"
							 | 
						|||
| 
								 | 
							
								)
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								var upgrader = websocket.Upgrader{
							 | 
						|||
| 
								 | 
							
									CheckOrigin: func(r *http.Request) bool {
							 | 
						|||
| 
								 | 
							
										return true // 允许跨域
							 | 
						|||
| 
								 | 
							
									},
							 | 
						|||
| 
								 | 
							
								}
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								func ServeWS(hub *Hub) gin.HandlerFunc {
							 | 
						|||
| 
								 | 
							
									return func(c *gin.Context) {
							 | 
						|||
| 
								 | 
							
										// 这里假设 JWT 中间件已验证,且用户 ID 在 Context
							 | 
						|||
| 
								 | 
							
										claims := jwtauth.ExtractClaims(c)
							 | 
						|||
| 
								 | 
							
										userID, ok := claims["identity"].(float64)
							 | 
						|||
| 
								 | 
							
										if !ok {
							 | 
						|||
| 
								 | 
							
											c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid user ID in token"})
							 | 
						|||
| 
								 | 
							
											return
							 | 
						|||
| 
								 | 
							
										}
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										// 设置 Subprotocols 支持客户端传来的协议(token等)
							 | 
						|||
| 
								 | 
							
										upgrader.Subprotocols = []string{c.GetHeader("Sec-WebSocket-Protocol")}
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
							 | 
						|||
| 
								 | 
							
										if err != nil {
							 | 
						|||
| 
								 | 
							
											logger.Errorf("WebSocket upgrade failed: %v", err)
							 | 
						|||
| 
								 | 
							
											c.JSON(http.StatusBadRequest, gin.H{"error": "upgrade failed"})
							 | 
						|||
| 
								 | 
							
											return
							 | 
						|||
| 
								 | 
							
										}
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										ctx, cancel := context.WithCancel(context.Background())
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										client := &Client{
							 | 
						|||
| 
								 | 
							
											ID:         fmt.Sprintf("%v", userID),
							 | 
						|||
| 
								 | 
							
											Conn:       conn,
							 | 
						|||
| 
								 | 
							
											Send:       make(chan []byte, 1024),
							 | 
						|||
| 
								 | 
							
											Context:    ctx,
							 | 
						|||
| 
								 | 
							
											CancelFunc: cancel,
							 | 
						|||
| 
								 | 
							
										}
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										hub.Register <- client
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
										go client.Read(hub)
							 | 
						|||
| 
								 | 
							
										go client.Write()
							 | 
						|||
| 
								 | 
							
									}
							 | 
						|||
| 
								 | 
							
								}
							 |