1、消息提示

This commit is contained in:
2025-07-15 16:43:50 +08:00
parent ab50cf0dfe
commit 2e35d55838
21 changed files with 1436 additions and 6 deletions

84
app/websocket/client.go Normal file
View File

@ -0,0 +1,84 @@
package websocket
import (
"context"
"log"
"time"
"github.com/gorilla/websocket"
)
type Client struct {
ID string
Conn *websocket.Conn
Send chan []byte
Context context.Context
CancelFunc context.CancelFunc
}
func (c *Client) Read(hub *Hub) {
defer func() {
hub.Unregister <- c
c.Conn.Close()
log.Printf("Client %s disconnected\n", c.ID)
c.CancelFunc()
}()
c.Conn.SetReadLimit(512)
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
select {
case <-c.Context.Done():
return
default:
_, message, err := c.Conn.ReadMessage()
if err != nil {
log.Printf("Read error from client %s: %v", c.ID, err)
return
}
log.Printf("Receive [%s]: %s", c.ID, message)
// 这里你可以把消息发给 hub.Broadcast 或业务处理
}
}
}
func (c *Client) Write() {
ticker := time.NewTicker(54 * time.Second) // 小于读超时保证ping及时发
defer func() {
ticker.Stop()
c.Conn.Close()
c.CancelFunc()
}()
for {
select {
case <-c.Context.Done():
return
case msg, ok := <-c.Send:
if !ok {
// 通道关闭,结束写入
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err := c.Conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
log.Printf("Write error to client %s: %v", c.ID, err)
return
}
case <-ticker.C:
// 发送 ping
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
log.Printf("Ping error to client %s: %v", c.ID, err)
return
}
}
}
}

55
app/websocket/handler.go Normal file
View File

@ -0,0 +1,55 @@
package websocket
import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-admin-team/go-admin-core/logger"
"github.com/go-admin-team/go-admin-core/sdk/pkg/jwtauth"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
func ServeWS(hub *Hub) gin.HandlerFunc {
return func(c *gin.Context) {
// 这里假设 JWT 中间件已验证,且用户 ID 在 Context
claims := jwtauth.ExtractClaims(c)
userID, ok := claims["identity"].(float64)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid user ID in token"})
return
}
// 设置 Subprotocols 支持客户端传来的协议token等
upgrader.Subprotocols = []string{c.GetHeader("Sec-WebSocket-Protocol")}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Errorf("WebSocket upgrade failed: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "upgrade failed"})
return
}
ctx, cancel := context.WithCancel(context.Background())
client := &Client{
ID: fmt.Sprintf("%v", userID),
Conn: conn,
Send: make(chan []byte, 1024),
Context: ctx,
CancelFunc: cancel,
}
hub.Register <- client
go client.Read(hub)
go client.Write()
}
}

63
app/websocket/hub.go Normal file
View File

@ -0,0 +1,63 @@
package websocket
import "sync"
type Hub struct {
Clients map[string]*Client
Register chan *Client
Unregister chan *Client
mu sync.RWMutex
}
func NewHub() *Hub {
return &Hub{
Clients: make(map[string]*Client),
Register: make(chan *Client),
Unregister: make(chan *Client),
}
}
func (h *Hub) Run() {
for {
select {
case client := <-h.Register:
h.mu.Lock()
h.Clients[client.ID] = client
h.mu.Unlock()
case client := <-h.Unregister:
h.mu.Lock()
if _, ok := h.Clients[client.ID]; ok {
delete(h.Clients, client.ID)
close(client.Send)
}
h.mu.Unlock()
}
}
}
func (h *Hub) SendToClient(id string, msg []byte) {
h.mu.RLock()
defer h.mu.RUnlock()
if c, ok := h.Clients[id]; ok {
c.Send <- msg
}
}
func (h *Hub) SendToAll(ids []string, msg []byte) {
h.mu.RLock()
defer h.mu.RUnlock()
for _, id := range ids {
if c, ok := h.Clients[id]; ok {
c.Send <- msg
}
}
}
func (h *Hub) Broadcast(msg []byte) {
h.mu.RLock()
defer h.mu.RUnlock()
for _, client := range h.Clients {
client.Send <- msg
}
}