From 5461a7ecd82aa1f2fc1eb724742cdf121364b725 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 3 Dec 2024 18:26:16 +0800 Subject: [PATCH] introduce a basic tracer --- connect-udp_test.go | 6 +++--- proxy.go | 29 +++++++++++++++++++++-------- proxy_test.go | 8 ++++---- tracer.go | 13 +++++++++++++ 4 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 tracer.go diff --git a/connect-udp_test.go b/connect-udp_test.go index 3ce1efa..de06927 100644 --- a/connect-udp_test.go +++ b/connect-udp_test.go @@ -74,7 +74,7 @@ func testProxyToIP(t *testing.T, addr *net.UDPAddr) { w.WriteHeader(http.StatusBadRequest) return } - proxy.Proxy(w, req) + proxy.Proxy(w, req, nil) }) go func() { if err := server.Serve(conn); err != nil { @@ -135,7 +135,7 @@ func TestProxyToHostname(t *testing.T) { // In this test, we don't actually want to connect to quic-go.net // Replace the target with the UDP echoer we spun up earlier. req.Target = remoteServerConn.LocalAddr().String() - proxy.Proxy(w, req) + proxy.Proxy(w, req, nil) }) go func() { if err := server.Serve(conn); err != nil { @@ -231,7 +231,7 @@ func TestProxyShutdown(t *testing.T) { w.WriteHeader(http.StatusBadRequest) return } - proxy.Proxy(w, req) + proxy.Proxy(w, req, nil) }) go func() { if err := server.Serve(conn); err != nil { diff --git a/proxy.go b/proxy.go index e1b754d..75818ba 100644 --- a/proxy.go +++ b/proxy.go @@ -39,7 +39,7 @@ type Proxy struct { // For more control over the UDP socket, use ProxyConnectedSocket. // Applications may add custom header fields to the response header, // but MUST NOT call WriteHeader on the http.ResponseWriter. -func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { +func (s *Proxy) Proxy(w http.ResponseWriter, r *Request, tracer *Tracer) error { if s.closed.Load() { w.WriteHeader(http.StatusServiceUnavailable) return net.ErrClosed @@ -59,14 +59,14 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { } defer conn.Close() - return s.ProxyConnectedSocket(w, r, conn) + return s.ProxyConnectedSocket(w, conn, tracer) } // ProxyConnectedSocket proxies a request on a connected UDP socket. // Applications may add custom header fields to the response header, // but MUST NOT call WriteHeader on the http.ResponseWriter. // It closes the connection before returning. -func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error { +func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, conn *net.UDPConn, tracer *Tracer) error { if s.closed.Load() { conn.Close() w.WriteHeader(http.StatusServiceUnavailable) @@ -91,17 +91,23 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne wg.Add(3) go func() { defer wg.Done() - if err := s.proxyConnSend(conn, str); err != nil { + if err := s.proxyConnSend(conn, str, tracer); err != nil { log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err) } str.Close() + if tracer != nil && tracer.SendDirectionClosed != nil { + tracer.SendDirectionClosed() + } }() go func() { defer wg.Done() - if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() { + if err := s.proxyConnReceive(conn, str, tracer); err != nil && !s.closed.Load() { log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err) } str.Close() + if tracer != nil && tracer.ReceiveDirectionClosed != nil { + tracer.ReceiveDirectionClosed() + } }() go func() { defer wg.Done() @@ -116,7 +122,7 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne return nil } -func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error { +func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error { for { data, err := str.ReceiveDatagram(context.Background()) if err != nil { @@ -130,13 +136,17 @@ func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error { // Drop this datagram. We currently only support proxying of UDP payloads. continue } - if _, err := conn.Write(data[n:]); err != nil { + b := data[n:] + if _, err := conn.Write(b); err != nil { return err } + if tracer != nil && tracer.SentData != nil { + tracer.SentData(len(b)) + } } } -func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error { +func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error { b := make([]byte, 1500) for { n, err := conn.Read(b) @@ -149,6 +159,9 @@ func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error { if err := str.SendDatagram(data); err != nil { return err } + if tracer != nil && tracer.ReceivedData != nil { + tracer.ReceivedData(n) + } } } diff --git a/proxy_test.go b/proxy_test.go index 5438c3f..9f0de53 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -75,7 +75,7 @@ func TestProxyCloseProxiedConn(t *testing.T) { }) r, err := ParseRequest(req, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}")) require.NoError(t, err) - go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r) + go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r, nil) require.Equal(t, http.StatusOK, rec.Code) b := make([]byte, 100) @@ -103,7 +103,7 @@ func TestProxyDialFailure(t *testing.T) { require.NoError(t, err) rec := httptest.NewRecorder() - require.ErrorContains(t, p.Proxy(rec, req), "invalid port") + require.ErrorContains(t, p.Proxy(rec, req, nil), "invalid port") require.Equal(t, http.StatusGatewayTimeout, rec.Code) } @@ -117,7 +117,7 @@ func TestProxyingAfterClose(t *testing.T) { t.Run("proxying", func(t *testing.T) { rec := httptest.NewRecorder() - require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed) + require.ErrorIs(t, p.Proxy(rec, req, nil), net.ErrClosed) require.Equal(t, http.StatusServiceUnavailable, rec.Code) }) @@ -125,7 +125,7 @@ func TestProxyingAfterClose(t *testing.T) { rec := httptest.NewRecorder() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) - require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed) + require.ErrorIs(t, p.ProxyConnectedSocket(rec, conn, nil), net.ErrClosed) require.Equal(t, http.StatusServiceUnavailable, rec.Code) }) } diff --git a/tracer.go b/tracer.go new file mode 100644 index 0000000..7f1352e --- /dev/null +++ b/tracer.go @@ -0,0 +1,13 @@ +package masque + +// A Tracer can be used to monitor the progress of a proxied connection. +type Tracer struct { + // SentData is called when data is sent towards the target. + SentData func(n int) + // SentDirectionClosed is called when the send direction (towards the target) is closed. + SendDirectionClosed func() + // ReceivedData is called when data is received from the target. + ReceivedData func(n int) + // ReceiveDirectionClosed is called when the receive direction (from the target) is closed. + ReceiveDirectionClosed func() +}