diff --git a/core/client/reconnect.go b/core/client/reconnect.go index 46f72bc072..05d60b3f4b 100644 --- a/core/client/reconnect.go +++ b/core/client/reconnect.go @@ -56,7 +56,11 @@ func (rc *reconnectableClientImpl) reconnect() error { } } -func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { +// clientDo calls f with the current client. +// If the client is nil, it will first reconnect. +// It will also detect if the client is closed, and if so, +// set it to nil for reconnect next time. +func (rc *reconnectableClientImpl) clientDo(f func(Client) (interface{}, error)) (interface{}, error) { rc.m.Lock() if rc.closed { rc.m.Unlock() @@ -72,46 +76,37 @@ func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { client := rc.client rc.m.Unlock() - conn, err := client.TCP(addr) + ret, err := f(client) if _, ok := err.(coreErrs.ClosedError); ok { // Connection closed, set client to nil for reconnect next time rc.m.Lock() - // In case the client has already been reconnected by another goroutine if rc.client == client { + // This check is in case the client is already changed by another goroutine rc.client = nil } rc.m.Unlock() } - return conn, err + return ret, err } -func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { - rc.m.Lock() - if rc.closed { - rc.m.Unlock() - return nil, coreErrs.ClosedError{} - } - if rc.client == nil { - // No active connection, connect first - if err := rc.reconnect(); err != nil { - rc.m.Unlock() - return nil, err - } +func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { + if c, err := rc.clientDo(func(client Client) (interface{}, error) { + return client.TCP(addr) + }); err != nil { + return nil, err + } else { + return c.(net.Conn), nil } - client := rc.client - rc.m.Unlock() +} - conn, err := client.UDP() - if _, ok := err.(coreErrs.ClosedError); ok { - // Connection closed, set client to nil for reconnect next time - rc.m.Lock() - // In case the client has already been reconnected by another goroutine - if rc.client == client { - rc.client = nil - } - rc.m.Unlock() +func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { + if c, err := rc.clientDo(func(client Client) (interface{}, error) { + return client.UDP() + }); err != nil { + return nil, err + } else { + return c.(HyUDPConn), nil } - return conn, err } func (rc *reconnectableClientImpl) Close() error {