Skip to content

Commit

Permalink
Merge pull request #953 from apernet/wip-udphop-listenudpfunc
Browse files Browse the repository at this point in the history
feat: allow set ListenUDP impl for udphop conn
  • Loading branch information
tobyxdd authored Mar 1, 2024
2 parents ea66299 + 1ac9d49 commit 982be54
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion app/cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil)
}
} else {
newFunc = func(addr net.Addr) (net.PacketConn, error) {
Expand Down
37 changes: 23 additions & 14 deletions extras/transport/udphop/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ const (
)

type udpHopPacketConn struct {
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
ListenUDPFunc ListenUDPFunc

connMutex sync.RWMutex
prevConn net.PacketConn
Expand All @@ -43,29 +44,37 @@ type udpPacket struct {
Err error
}

func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
type ListenUDPFunc func() (net.PacketConn, error)

func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
if hopInterval == 0 {
hopInterval = defaultHopInterval
} else if hopInterval < 5*time.Second {
return nil, errors.New("hop interval must be at least 5 seconds")
}
if listenUDPFunc == nil {
listenUDPFunc = func() (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
}
addrs, err := addr.addrs()
if err != nil {
return nil, err
}
curConn, err := net.ListenUDP("udp", nil)
curConn, err := listenUDPFunc()
if err != nil {
return nil, err
}
hConn := &udpHopPacketConn{
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
ListenUDPFunc: listenUDPFunc,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
return make([]byte, udpBufferSize)
Expand Down Expand Up @@ -121,7 +130,7 @@ func (u *udpHopPacketConn) hop() {
if u.closed {
return
}
newConn, err := net.ListenUDP("udp", nil)
newConn, err := u.ListenUDPFunc()
if err != nil {
// Could be temporary, just skip this hop
return
Expand Down

0 comments on commit 982be54

Please sign in to comment.