Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

collect some basic proxied connection stats #62

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func runEchoServer(t *testing.T) *net.UDPConn {
// runEchoServer runs an echo server that echos back the data it receives n times.
func runEchoServer(t *testing.T, amplification int) *net.UDPConn {
t.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
Expand All @@ -34,16 +35,19 @@ func runEchoServer(t *testing.T) *net.UDPConn {
if err != nil {
return
}
if _, err := conn.WriteTo(b[:n], addr); err != nil {
return
for i := 0; i < amplification; i++ {
if _, err := conn.WriteTo(b[:n], addr); err != nil {
return
}
}
}
}()
return conn
}

func TestProxyToIP(t *testing.T) {
remoteServerConn := runEchoServer(t)
const amplification = 3
remoteServerConn := runEchoServer(t, 3)
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand All @@ -62,14 +66,16 @@ func TestProxyToIP(t *testing.T) {
defer server.Close()
proxy := masque.Proxy{}
defer proxy.Close()
statsChan := make(chan masque.Stats, 1)
mux.HandleFunc("/masque", func(w http.ResponseWriter, r *http.Request) {
req, err := masque.ParseRequest(r, template)
if err != nil {
t.Log("Upgrade failed:", err)
w.WriteHeader(http.StatusBadRequest)
return
}
proxy.Proxy(w, req)
stats, _ := proxy.Proxy(w, req)
statsChan <- stats
})
go func() {
if err := server.Serve(conn); err != nil {
Expand All @@ -81,20 +87,33 @@ func TestProxyToIP(t *testing.T) {
Template: template,
TLSClientConfig: &tls.Config{ClientCAs: certPool, NextProtos: []string{http3.NextProtoH3}, InsecureSkipVerify: true},
}
defer cl.Close()
proxiedConn, _, err := cl.Dial(context.Background(), remoteServerConn.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)

_, err = proxiedConn.WriteTo([]byte("foobar"), remoteServerConn.LocalAddr())
require.NoError(t, err)
b := make([]byte, 1500)
n, _, err := proxiedConn.ReadFrom(b)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b[:n])
for i := 0; i < amplification; i++ {
b := make([]byte, 1500)
n, _, err := proxiedConn.ReadFrom(b)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b[:n])
}
cl.Close()
select {
case stats := <-statsChan:
require.Equal(t, masque.Stats{
PacketsSent: 1,
DataSent: 6,
PacketsReceived: amplification,
DataReceived: 6 * amplification,
}, stats)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for stats")
}
}

func TestProxyToHostname(t *testing.T) {
remoteServerConn := runEchoServer(t)
remoteServerConn := runEchoServer(t, 1)
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand Down Expand Up @@ -197,7 +216,7 @@ func TestProxyToHostnameMissingPort(t *testing.T) {
}

func TestProxyShutdown(t *testing.T) {
remoteServerConn := runEchoServer(t)
remoteServerConn := runEchoServer(t, 1)
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand Down
51 changes: 35 additions & 16 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
conn *net.UDPConn
}

type Stats struct {
PacketsSent, PacketsReceived uint64
DataSent, DataReceived uint64
}

type Proxy struct {
closed atomic.Bool

Expand All @@ -38,23 +43,23 @@
// For more control over the UDP socket, use ProxyConnectedSocket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) (Stats, error) {
if s.closed.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
return Stats{}, net.ErrClosed
}

addr, err := net.ResolveUDPAddr("udp", r.Target)
if err != nil {
// TODO: set proxy-status header (might want to use structured headers)
w.WriteHeader(http.StatusGatewayTimeout)
return err
return Stats{}, err
}
conn, err := net.DialUDP("udp", nil, addr)
if err != nil {
// TODO: set proxy-status header (might want to use structured headers)
w.WriteHeader(http.StatusGatewayTimeout)
return err
return Stats{}, err

Check warning on line 62 in proxy.go

View check run for this annotation

Codecov / codecov/patch

proxy.go#L62

Added line #L62 was not covered by tests
}
defer conn.Close()

Expand All @@ -64,11 +69,11 @@
// ProxyConnectedSocket proxies a request on a connected UDP socket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error {
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) (Stats, error) {
if s.closed.Load() {
conn.Close()
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
return Stats{}, net.ErrClosed
}

s.refCount.Add(1)
Expand All @@ -87,16 +92,21 @@

var wg sync.WaitGroup
wg.Add(3)
var packetsSent, packetsReceived, dataSent, dataReceived uint64
go func() {
defer wg.Done()
if err := s.proxyConnSend(conn, str); err != nil {
var err error
packetsSent, dataSent, err = s.proxyConnSend(conn, str)
if err != nil && !s.closed.Load() {
log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
}()
go func() {
defer wg.Done()
if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() {
var err error
packetsReceived, dataReceived, err = s.proxyConnReceive(conn, str)
if err != nil && !s.closed.Load() {
log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
Expand All @@ -111,41 +121,50 @@
conn.Close()
}()
wg.Wait()
return nil
return Stats{
PacketsSent: packetsSent,
PacketsReceived: packetsReceived,
DataSent: dataSent,
DataReceived: dataReceived,
}, nil
}

func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) (packetsSent, dataSent uint64, _ error) {
for {
data, err := str.ReceiveDatagram(context.Background())
if err != nil {
return err
return packetsSent, dataSent, err
}
contextID, n, err := quicvarint.Parse(data)
if err != nil {
return err
return packetsSent, dataSent, err

Check warning on line 140 in proxy.go

View check run for this annotation

Codecov / codecov/patch

proxy.go#L140

Added line #L140 was not covered by tests
}
if contextID != 0 {
// Drop this datagram. We currently only support proxying of UDP payloads.
continue
}
packetsSent++
dataSent += uint64(len(data) - n)
if _, err := conn.Write(data[n:]); err != nil {
return err
return packetsSent, dataSent, err

Check warning on line 149 in proxy.go

View check run for this annotation

Codecov / codecov/patch

proxy.go#L149

Added line #L149 was not covered by tests
}
}
}

func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) (packetsReceived, dataReceived uint64, _ error) {
b := make([]byte, 1500)
for {
n, err := conn.Read(b)
if err != nil {
return err
return packetsReceived, dataReceived, err
}
packetsReceived++
dataReceived += uint64(n)
data := make([]byte, 0, len(contextIDZero)+n)
data = append(data, contextIDZero...)
data = append(data, b[:n]...)
if err := str.SendDatagram(data); err != nil {
return err
return packetsReceived, dataReceived, err

Check warning on line 167 in proxy.go

View check run for this annotation

Codecov / codecov/patch

proxy.go#L167

Added line #L167 was not covered by tests
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ func TestProxyDialFailure(t *testing.T) {
require.NoError(t, err)
rec := httptest.NewRecorder()

require.ErrorContains(t, p.Proxy(rec, req), "invalid port")
_, err = p.Proxy(rec, req)
require.ErrorContains(t, err, "invalid port")
require.Equal(t, http.StatusGatewayTimeout, rec.Code)
}

Expand All @@ -117,15 +118,17 @@ func TestProxyingAfterClose(t *testing.T) {

t.Run("proxying", func(t *testing.T) {
rec := httptest.NewRecorder()
require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed)
_, err := p.Proxy(rec, req)
require.ErrorIs(t, err, net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})

t.Run("proxying connected socket", func(t *testing.T) {
rec := httptest.NewRecorder()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed)
_, err = p.ProxyConnectedSocket(rec, req, conn)
require.ErrorIs(t, err, net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})
}