Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce a basic tracer #80

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func testProxyToIP(t *testing.T, addr *net.UDPAddr) {
w.WriteHeader(http.StatusBadRequest)
return
}
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestProxyToHostname(t *testing.T) {
// In this test, we don't actually want to connect to quic-go.net
// Replace the target with the UDP echoer we spun up earlier.
req.Target = remoteServerConn.LocalAddr().String()
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestProxyShutdown(t *testing.T) {
w.WriteHeader(http.StatusBadRequest)
return
}
proxy.Proxy(w, req)
proxy.Proxy(w, req, nil)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down
29 changes: 21 additions & 8 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Proxy struct {
// For more control over the UDP socket, use ProxyConnectedSocket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request, tracer *Tracer) error {
if s.closed.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
Expand All @@ -59,14 +59,14 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
}
defer conn.Close()

return s.ProxyConnectedSocket(w, r, conn)
return s.ProxyConnectedSocket(w, conn, tracer)
}

// ProxyConnectedSocket proxies a request on a connected UDP socket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
// It closes the connection before returning.
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error {
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, conn *net.UDPConn, tracer *Tracer) error {
if s.closed.Load() {
conn.Close()
w.WriteHeader(http.StatusServiceUnavailable)
Expand All @@ -91,17 +91,23 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne
wg.Add(3)
go func() {
defer wg.Done()
if err := s.proxyConnSend(conn, str); err != nil {
if err := s.proxyConnSend(conn, str, tracer); err != nil {
log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
if tracer != nil && tracer.SendDirectionClosed != nil {
tracer.SendDirectionClosed()
}
}()
go func() {
defer wg.Done()
if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() {
if err := s.proxyConnReceive(conn, str, tracer); err != nil && !s.closed.Load() {
log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
if tracer != nil && tracer.ReceiveDirectionClosed != nil {
tracer.ReceiveDirectionClosed()
}
}()
go func() {
defer wg.Done()
Expand All @@ -116,7 +122,7 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne
return nil
}

func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error {
for {
data, err := str.ReceiveDatagram(context.Background())
if err != nil {
Expand All @@ -130,13 +136,17 @@ func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error {
// Drop this datagram. We currently only support proxying of UDP payloads.
continue
}
if _, err := conn.Write(data[n:]); err != nil {
b := data[n:]
if _, err := conn.Write(b); err != nil {
return err
}
if tracer != nil && tracer.SentData != nil {
tracer.SentData(len(b))
}
}
}

func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error {
func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream, tracer *Tracer) error {
b := make([]byte, 1500)
for {
n, err := conn.Read(b)
Expand All @@ -149,6 +159,9 @@ func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error {
if err := str.SendDatagram(data); err != nil {
return err
}
if tracer != nil && tracer.ReceivedData != nil {
tracer.ReceivedData(n)
}
}
}

Expand Down
8 changes: 4 additions & 4 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestProxyCloseProxiedConn(t *testing.T) {
})
r, err := ParseRequest(req, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"))
require.NoError(t, err)
go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r)
go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r, nil)
require.Equal(t, http.StatusOK, rec.Code)

b := make([]byte, 100)
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestProxyDialFailure(t *testing.T) {
require.NoError(t, err)
rec := httptest.NewRecorder()

require.ErrorContains(t, p.Proxy(rec, req), "invalid port")
require.ErrorContains(t, p.Proxy(rec, req, nil), "invalid port")
require.Equal(t, http.StatusGatewayTimeout, rec.Code)
}

Expand All @@ -117,15 +117,15 @@ func TestProxyingAfterClose(t *testing.T) {

t.Run("proxying", func(t *testing.T) {
rec := httptest.NewRecorder()
require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed)
require.ErrorIs(t, p.Proxy(rec, req, nil), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})

t.Run("proxying connected socket", func(t *testing.T) {
rec := httptest.NewRecorder()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, conn, nil), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})
}
13 changes: 13 additions & 0 deletions tracer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package masque

// A Tracer can be used to monitor the progress of a proxied connection.
type Tracer struct {
// SentData is called when data is sent towards the target.
SentData func(n int)
// SentDirectionClosed is called when the send direction (towards the target) is closed.
SendDirectionClosed func()
// ReceivedData is called when data is received from the target.
ReceivedData func(n int)
// ReceiveDirectionClosed is called when the receive direction (from the target) is closed.
ReceiveDirectionClosed func()
}
Loading