Skip to content

Commit

Permalink
introduce a basic tracer
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Dec 3, 2024
1 parent b312bc5 commit 5461a7e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 15 deletions.
6 changes: 3 additions & 3 deletions connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func testProxyToIP(t *testing.T, addr *net.UDPAddr) {
w.WriteHeader(http.StatusBadRequest)
return
}
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestProxyToHostname(t *testing.T) {
// In this test, we don't actually want to connect to quic-go.net
// Replace the target with the UDP echoer we spun up earlier.
req.Target = remoteServerConn.LocalAddr().String()
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestProxyShutdown(t *testing.T) {
w.WriteHeader(http.StatusBadRequest)
return
}
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down
29 changes: 21 additions & 8 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Proxy struct {
// For more control over the UDP socket, use ProxyConnectedSocket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request, tracer *Tracer) error {
if s.closed.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
Expand All @@ -59,14 +59,14 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
}
defer conn.Close()

return s.ProxyConnectedSocket(w, r, conn)
return s.ProxyConnectedSocket(w, conn, tracer)
}

// ProxyConnectedSocket proxies a request on a connected UDP socket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
// It closes the connection before returning.
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error {
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, conn *net.UDPConn, tracer *Tracer) error {
if s.closed.Load() {
conn.Close()
w.WriteHeader(http.StatusServiceUnavailable)
Expand All @@ -91,17 +91,23 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne
wg.Add(3)
go func() {
defer wg.Done()
if err := s.proxyConnSend(conn, str); err != nil {
if err := s.proxyConnSend(conn, str, tracer); err != nil {
log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
if tracer != nil && tracer.SendDirectionClosed != nil {
tracer.SendDirectionClosed()
}
}()
go func() {
defer wg.Done()
if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() {
if err := s.proxyConnReceive(conn, str, tracer); err != nil && !s.closed.Load() {
log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
if tracer != nil && tracer.ReceiveDirectionClosed != nil {
tracer.ReceiveDirectionClosed()
}
}()
go func() {
defer wg.Done()
Expand All @@ -116,7 +122,7 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne
return nil
}

func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error {
for {
data, err := str.ReceiveDatagram(context.Background())
if err != nil {
Expand All @@ -130,13 +136,17 @@ func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error {
// Drop this datagram. We currently only support proxying of UDP payloads.
continue
}
if _, err := conn.Write(data[n:]); err != nil {
b := data[n:]
if _, err := conn.Write(b); err != nil {
return err
}
if tracer != nil && tracer.SentData != nil {
tracer.SentData(len(b))
}
}
}

func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error {
b := make([]byte, 1500)
for {
n, err := conn.Read(b)
Expand All @@ -149,6 +159,9 @@ func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error {
if err := str.SendDatagram(data); err != nil {
return err
}
if tracer != nil && tracer.ReceivedData != nil {
tracer.ReceivedData(n)
}
}
}

Expand Down
8 changes: 4 additions & 4 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestProxyCloseProxiedConn(t *testing.T) {
})
r, err := ParseRequest(req, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"))
require.NoError(t, err)
go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r)
go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r, nil)
require.Equal(t, http.StatusOK, rec.Code)

b := make([]byte, 100)
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestProxyDialFailure(t *testing.T) {
require.NoError(t, err)
rec := httptest.NewRecorder()

require.ErrorContains(t, p.Proxy(rec, req), "invalid port")
require.ErrorContains(t, p.Proxy(rec, req, nil), "invalid port")
require.Equal(t, http.StatusGatewayTimeout, rec.Code)
}

Expand All @@ -117,15 +117,15 @@ func TestProxyingAfterClose(t *testing.T) {

t.Run("proxying", func(t *testing.T) {
rec := httptest.NewRecorder()
require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed)
require.ErrorIs(t, p.Proxy(rec, req, nil), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})

t.Run("proxying connected socket", func(t *testing.T) {
rec := httptest.NewRecorder()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, conn, nil), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})
}
13 changes: 13 additions & 0 deletions tracer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package masque

// A Tracer can be used to monitor the progress of a proxied connection.
type Tracer struct {
// SentData is called when data is sent towards the target.
SentData func(n int)
// SentDirectionClosed is called when the send direction (towards the target) is closed.
SendDirectionClosed func()
// ReceivedData is called when data is received from the target.
ReceivedData func(n int)
// ReceiveDirectionClosed is called when the receive direction (from the target) is closed.
ReceiveDirectionClosed func()
}

0 comments on commit 5461a7e

Please sign in to comment.