Skip to content

Commit

Permalink
Rewrite udp associate (#19)
Browse files Browse the repository at this point in the history
* relay packet to remote server

* remove deprecated go version

* package mapper

* nat table

* port for udp associate from request

* echo packet server for test

* time to live packet

* reset packet payload

* test command udp associate

* packet metrics and rules

* packet write timeout

* change closeListenerFn

* comments for options
  • Loading branch information
TuanKiri authored Aug 18, 2024
1 parent cb70d46 commit 7818724
Show file tree
Hide file tree
Showing 11 changed files with 527 additions and 221 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.18.x, 1.22.x]
go-version: [1.22.x]
platform: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
Expand Down
33 changes: 1 addition & 32 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package socks5

import (
"bufio"
"bytes"
"io"
"net"
)
Expand Down Expand Up @@ -68,37 +67,7 @@ func (c *connection) keepAlive() {
return
}

c.closeFn()

close(c.done)
}

type packetConn struct {
net.PacketConn
net.Addr
reader *bytes.Buffer
}

func newPacketConn(conn net.PacketConn, addr net.Addr, data []byte) *packetConn {
return &packetConn{
PacketConn: conn,
Addr: addr,
reader: bytes.NewBuffer(data),
}
}

func (c *packetConn) readByte() (byte, error) {
return c.reader.ReadByte()
}

func (c *packetConn) read(p []byte) (int, error) {
return c.reader.Read(p)
}

func (c *packetConn) write(p []byte) (int, error) {
return c.PacketConn.WriteTo(p, c.Addr)
}

func (c *packetConn) bytes() []byte {
return c.reader.Bytes()
c.closeFn()
}
11 changes: 11 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package socks5

import (
"errors"
"net"
"time"
)
Expand All @@ -9,6 +10,7 @@ type Driver interface {
Listen(network, address string) (net.Listener, error)
ListenPacket(network, address string) (net.PacketConn, error)
Dial(network, address string) (net.Conn, error)
Resolve(network, address string) (net.Addr, error)
}

type netDriver struct {
Expand All @@ -26,3 +28,12 @@ func (d *netDriver) ListenPacket(network, address string) (net.PacketConn, error
func (d *netDriver) Dial(network, address string) (net.Conn, error) {
return net.DialTimeout(network, address, d.timeout)
}

func (d *netDriver) Resolve(network, address string) (net.Addr, error) {
switch network {
case "udp":
return net.ResolveUDPAddr(network, address)
default:
return nil, errors.New("bad network")
}
}
54 changes: 46 additions & 8 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net"
"net/http"
"time"

"golang.org/x/net/proxy"

Expand Down Expand Up @@ -42,11 +43,12 @@ lPQTC4uW7AAywREZ2ekd8XUX8JwdTJHmPg==
var cert tls.Certificate

func init() {
c, err := tls.X509KeyPair([]byte(certPem), []byte(keyPem))
var err error

cert, err = tls.X509KeyPair([]byte(certPem), []byte(keyPem))
if err != nil {
panic(err)
log.Fatalf("error parsing public/private key pair: %v", err)
}
cert = c

mux := http.NewServeMux()

Expand All @@ -56,13 +58,19 @@ func init() {

go func() {
if err := http.ListenAndServe(":5444", mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
log.Fatalf("error running remote http server: %v", err)
}
}()

go func() {
if err := listenAndServeTLS(":6444", mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
log.Fatalf("error running remote https server: %v", err)
}
}()

go func() {
if err := echoPacketServer(":7444"); err != nil {
log.Fatalf("error running echo packet server: %v", err)
}
}()
}
Expand All @@ -89,6 +97,10 @@ func (d testTLSDriver) ListenPacket(network, address string) (net.PacketConn, er
return nil, nil
}

func (d testTLSDriver) Resolve(network, address string) (net.Addr, error) {
return nil, nil
}

func listenAndServeTLS(address string, handler http.Handler) error {
server := http.Server{
Addr: address,
Expand All @@ -101,15 +113,15 @@ func listenAndServeTLS(address string, handler http.Handler) error {
return server.ListenAndServeTLS("", "")
}

func runProxy(opts []socks5.Option) {
func runProxy(opts ...socks5.Option) {
srv := socks5.New(opts...)

if err := srv.ListenAndServe(); err != nil {
log.Fatalf("runProxy: %v", err)
log.Fatalf("error running socks5 proxy server: %v", err)
}
}

func setupClient(proxyAddress string, auth *proxy.Auth) (*http.Client, error) {
func newHttpClient(proxyAddress string, auth *proxy.Auth) (*http.Client, error) {
socksProxy, err := proxy.SOCKS5(
"tcp",
proxyAddress,
Expand Down Expand Up @@ -137,3 +149,29 @@ func setupClient(proxyAddress string, auth *proxy.Auth) (*http.Client, error) {
},
}, nil
}

func echoPacketServer(address string) error {
packetConn, err := net.ListenPacket("udp", address)
if err != nil {
return err
}
defer packetConn.Close()

buf := make([]byte, 1024)

for {
n, clientAddress, err := packetConn.ReadFrom(buf)
if err != nil {
log.Printf("error reading: %v", err)
continue
}

packetConn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))

_, err = packetConn.WriteTo(buf[:n], clientAddress)
if err != nil {
log.Printf("error writing: %v", err)
continue
}
}
}
81 changes: 81 additions & 0 deletions nat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package socks5

import (
"net"
"sync"
"time"
)

type natEntry struct {
src net.Addr
packet *packet
timestamp time.Time
}

type natTable struct {
mutex sync.RWMutex
table map[string]*natEntry
}

func newNatTable() *natTable {
return &natTable{table: make(map[string]*natEntry)}
}

func (n *natTable) set(src, dst net.Addr, packet *packet) {
n.mutex.Lock()
n.table[dst.String()] = &natEntry{
src: src,
packet: packet,
timestamp: time.Now(),
}
n.mutex.Unlock()
}

func (n *natTable) get(dst net.Addr) (net.Addr, *packet, bool) {
n.mutex.RLock()
defer n.mutex.RUnlock()

val, ok := n.table[dst.String()]
if !ok {
return nil, nil, ok
}

return val.src, val.packet, ok
}

func (n *natTable) delete(dst net.Addr) {
n.mutex.Lock()
delete(n.table, dst.String())
n.mutex.Unlock()
}

func (n *natTable) cleanup(period, ttl time.Duration) func() {
if period <= 0 || ttl <= 0 {
return func() {}
}

ticker := time.NewTicker(period)
done := make(chan struct{})

go func() {
for {
select {
case <-done:
ticker.Stop()
return
case <-ticker.C:
n.mutex.Lock()
for key, val := range n.table {
if time.Since(val.timestamp) >= ttl {
delete(n.table, key)
}
}
n.mutex.Unlock()
}
}
}()

return func() {
close(done)
}
}
48 changes: 47 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ type options struct {
allowCommands map[byte]struct{}
blockListHosts map[string]struct{}
allowIPs []net.IP
maxPacketSize int
packetWriteTimeout time.Duration
ttlPacket time.Duration
natCleanupPeriod time.Duration
logger Logger
store Store
driver Driver
Expand All @@ -56,10 +60,14 @@ func (o options) listenAddress() string {
}

func optsWithDefaults(opts *options) *options {
if opts.port == 0 {
if opts.port <= 0 {
opts.port = 1080
}

if opts.maxPacketSize <= 0 {
opts.maxPacketSize = 1500
}

if opts.publicIP == nil {
opts.publicIP = net.ParseIP("127.0.0.1")
}
Expand Down Expand Up @@ -119,18 +127,23 @@ func WithPort(val int) Option {
}
}

// WithPublicIP sets an IP address that is visible on the external Internet,
// accessible to users outside the local network and will be sent to clients in
// response to a connection request.
func WithPublicIP(val net.IP) Option {
return func(o *options) {
o.publicIP = val
}
}

// WithReadTimeout sets the read timeout for tcp connection.
func WithReadTimeout(val time.Duration) Option {
return func(o *options) {
o.readTimeout = val
}
}

// WithWriteTimeout sets the write timeout for tcp connection.
func WithWriteTimeout(val time.Duration) Option {
return func(o *options) {
o.writeTimeout = val
Expand Down Expand Up @@ -220,3 +233,36 @@ func WithBlockListHosts(hosts ...string) Option {
o.blockListHosts = blockListHosts
}
}

// WithPacketWriteTimeout sets the timeout for waiting to write a packet to the remote host.
func WithPacketWriteTimeout(val time.Duration) Option {
return func(o *options) {
o.packetWriteTimeout = val
}
}

// WithMaxPacketSize sets the maximum size in bytes for the datagram to be read from the socket.
func WithMaxPacketSize(val int) Option {
return func(o *options) {
o.maxPacketSize = val
}
}

// WithTTLPacket sets how long the packet will stay in the table
// that links the sender of the packet to the remote host it was meant for.
// Nat cleanup period must be greater than 0.
func WithTTLPacket(val time.Duration) Option {
return func(o *options) {
o.ttlPacket = val
}
}

// WithNatCleanupPeriod sets the period when the table that links the
// packets from the sender to the remote host will be cleaned.
// It makes sense if there's no time limit on the TCP connection.
// TTL of the packet must be greater than 0.
func WithNatCleanupPeriod(val time.Duration) Option {
return func(o *options) {
o.natCleanupPeriod = val
}
}
Loading

0 comments on commit 7818724

Please sign in to comment.