diff --git a/client.go b/client.go index d1cfe03..7fb0db1 100644 --- a/client.go +++ b/client.go @@ -10,13 +10,6 @@ import ( "time" ) -const ( - writeWait = 10 * time.Second // 写等待 - pongWait = 60 * time.Second // 心跳等待 - pingPeriod = (pongWait * 9) / 10 // 心跳频率 - maxMessageSize = 524288 // 512 kb -) - // 一个连接一个Client,负责处理连接的I/O type Client struct { id string @@ -27,17 +20,19 @@ type Client struct { req *http.Request mu sync.Mutex channels map[string]struct{} + conf *Config } -func NewClient(conn *websocket.Conn, handler Handler, hub *SubscriptionHub, req *http.Request) *Client { +func NewClient(conn *websocket.Conn, handler Handler, hub *SubscriptionHub, req *http.Request, c *Config) *Client { return &Client{ id: uuid.NewV4().String(), - writeChan: make(chan []byte, 256), + writeChan: make(chan []byte, c.WriteChanBuffer), conn: conn, handler: handler, hub: hub, channels: map[string]struct{}{}, req: req, + conf: c, } } @@ -55,10 +50,10 @@ func (c *Client) Run() { // 读取客户端发过来的内容 func (c *Client) reader() { - c.conn.SetReadLimit(maxMessageSize) - c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetReadLimit(c.conf.MaxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(c.conf.PongWait)) c.conn.SetPongHandler(func(string) error { - return c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return c.conn.SetReadDeadline(time.Now().Add(c.conf.PongWait)) }) for { _, d, err := c.conn.ReadMessage() @@ -78,18 +73,18 @@ func (c *Client) reader() { // 向客户端写入内容 func (c *Client) writer() { - tik := time.NewTicker(pingPeriod) + tik := time.NewTicker(c.conf.PingPeriod) for { select { case buf := <-c.writeChan: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + c.conn.SetWriteDeadline(time.Now().Add(c.conf.WriteWait)) err := c.conn.WriteMessage(websocket.TextMessage, buf) if err != nil { c.Close() return } case <-tik.C: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + c.conn.SetWriteDeadline(time.Now().Add(c.conf.WriteWait)) err := c.conn.WriteMessage(websocket.PingMessage, nil) if err != nil { c.Close() @@ -133,6 +128,7 @@ func (c *Client) UnsubscribeAll() { for channel := range c.channels { c.hub.Unsubscribe(channel, c) } + c.channels = map[string]struct{}{} } // 获取已订阅的主题列表 diff --git a/config.go b/config.go new file mode 100644 index 0000000..56c7024 --- /dev/null +++ b/config.go @@ -0,0 +1,22 @@ +package webreal + +import "time" + +// 客户端配置 +type Config struct { + WriteWait time.Duration // 写等待时间 + WriteChanBuffer int // 写缓冲长度 + PongWait time.Duration // 心跳等待时间 + PingPeriod time.Duration // 心跳频率 + MaxMessageSize int64 // 最大消息字节数 +} + +func DefaultConfig() *Config { + return &Config{ + WriteWait: 10 * time.Second, + WriteChanBuffer: 256, + PongWait: 60 * time.Second, + PingPeriod: 54 * time.Second, + MaxMessageSize: 524288, // 512KB + } +} diff --git a/server.go b/server.go index 0c6edec..5e9219a 100644 --- a/server.go +++ b/server.go @@ -9,9 +9,10 @@ type Server struct { hub *SubscriptionHub handler Handler upgrader websocket.Upgrader + conf *Config } -func NewServer(handler Handler, hub *SubscriptionHub) *Server { +func NewServer(handler Handler, hub *SubscriptionHub, c *Config) *Server { return &Server{ hub: hub, handler: handler, @@ -20,6 +21,7 @@ func NewServer(handler Handler, hub *SubscriptionHub) *Server { return true }, }, + conf: c, } } @@ -29,7 +31,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { return } - NewClient(conn, s.handler, s.hub, r).Run() + NewClient(conn, s.handler, s.hub, r, s.conf).Run() } // 使用默认的http启动监听服务 diff --git a/test/main.go b/test/main.go index 5bbd89f..72e03e5 100644 --- a/test/main.go +++ b/test/main.go @@ -36,6 +36,6 @@ func main() { } } }() - server := webreal.NewServer(&PushingHandler{}, hub) + server := webreal.NewServer(&PushingHandler{}, hub, webreal.DefaultConfig()) server.Run("127.0.0.1:8080", "/ws") }