Skip to content

Commit

Permalink
add a DialAddr method to the client (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored Jul 4, 2024
1 parent ada223f commit f8c4793
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 15 deletions.
30 changes: 28 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ type Client struct {
rt *http3.SingleDestinationRoundTripper
}

// DialAddr dials a proxied connection to a target server.
// The target address is sent to the proxy, and the DNS resolution is left to the proxy.
// The target must be given as a host:port.
func (c *Client) DialAddr(ctx context.Context, target string) (net.PacketConn, error) {
if c.Template == nil {
return nil, errors.New("masque: no template")
}
host, port, err := net.SplitHostPort(target)
if err != nil {
return nil, fmt.Errorf("failed to parse target: %w", err)
}
str, err := c.Template.Expand(uritemplate.Values{
uriTemplateTargetHost: uritemplate.String(host),
uriTemplateTargetPort: uritemplate.String(port),
})
if err != nil {
return nil, fmt.Errorf("masque: failed to expand Template: %w", err)
}
return c.dial(ctx, str)
}

// Dial dials a proxied connection to a target server.
func (c *Client) Dial(ctx context.Context, raddr *net.UDPAddr) (net.PacketConn, error) {
if c.Template == nil {
return nil, errors.New("masque: no template")
Expand All @@ -48,7 +70,11 @@ func (c *Client) Dial(ctx context.Context, raddr *net.UDPAddr) (net.PacketConn,
if err != nil {
return nil, fmt.Errorf("masque: failed to expand Template: %w", err)
}
u, err := url.Parse(str)
return c.dial(ctx, str)
}

func (c *Client) dial(ctx context.Context, expandedTemplate string) (net.PacketConn, error) {
u, err := url.Parse(expandedTemplate)
if err != nil {
return nil, fmt.Errorf("masque: failed to parse URI: %w", err)
}
Expand Down Expand Up @@ -120,7 +146,7 @@ func (c *Client) Dial(ctx context.Context, raddr *net.UDPAddr) (net.PacketConn,
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
return nil, fmt.Errorf("masque: server responded with %d", rsp.StatusCode)
}
return newProxiedConn(rstr, conn.LocalAddr(), raddr), nil
return newProxiedConn(rstr, conn.LocalAddr()), nil
}

func (c *Client) Close() error {
Expand Down
18 changes: 7 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package masque
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
Expand Down Expand Up @@ -41,12 +40,11 @@ type proxiedConn struct {

var _ net.PacketConn = &proxiedConn{}

func newProxiedConn(str http3.Stream, local, remote net.Addr) *proxiedConn {
func newProxiedConn(str http3.Stream, local net.Addr) *proxiedConn {
c := &proxiedConn{
str: str,
localAddr: local,
remoteAddr: remote,
readDone: make(chan struct{}),
str: str,
localAddr: local,
readDone: make(chan struct{}),
}
c.readCtx, c.readCtxCancel = context.WithCancel(context.Background())
go func() {
Expand Down Expand Up @@ -85,11 +83,9 @@ start:
return n, c.remoteAddr, nil
}

func (c *proxiedConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// A CONNECT-UDP connection mirrors a connected UDP socket.
if addr != c.remoteAddr {
return 0, fmt.Errorf("unexpected remote address: %s, expected %s", addr, c.remoteAddr)
}
// WriteTo sends a UDP datagram to the target.
// The net.Addr parameter is ignored.
func (c *proxiedConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
return len(p), c.str.SendDatagram(p)
}

Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestReadDeadline(t *testing.T) {
<-done
return 0, errors.New("test done")
}).MaxTimes(1)
return str, newProxiedConn(str, nil, nil)
return str, newProxiedConn(str, nil)
}

t.Run("read after deadline", func(t *testing.T) {
Expand Down
70 changes: 69 additions & 1 deletion connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func runEchoServer(t *testing.T) *net.UDPConn {
return conn
}

func TestProxying(t *testing.T) {
func TestProxyToIP(t *testing.T) {
remoteServerConn := runEchoServer(t)
defer remoteServerConn.Close()

Expand Down Expand Up @@ -93,6 +93,74 @@ func TestProxying(t *testing.T) {
require.Equal(t, []byte("foobar"), b[:n])
}

func TestProxyToHostname(t *testing.T) {
remoteServerConn := runEchoServer(t)
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
t.Logf("server listening on %s", conn.LocalAddr())
template := uritemplate.MustNew(fmt.Sprintf("https://localhost:%d/masque?h={target_host}&p={target_port}", conn.LocalAddr().(*net.UDPAddr).Port))

mux := http.NewServeMux()
server := http3.Server{
TLSConfig: tlsConf,
QUICConfig: &quic.Config{EnableDatagrams: true},
EnableDatagrams: true,
Handler: mux,
}
defer server.Close()
proxy := masque.Proxy{}
defer proxy.Close()
mux.HandleFunc("/masque", func(w http.ResponseWriter, r *http.Request) {
req, err := masque.ParseRequest(r, template)
if err != nil {
t.Log("Upgrade failed:", err)
w.WriteHeader(http.StatusBadRequest)
return
}
if req.Target != "quic-go.net:1234" {
t.Log("unexpected request target:", req.Target)
w.WriteHeader(http.StatusServiceUnavailable)
}
// In this test, we don't actually want to connect to quic-go.net
// Replace the target with the UDP echoer we spun up earlier.
req.Target = remoteServerConn.LocalAddr().String()
proxy.Proxy(w, req)
})
go func() {
if err := server.Serve(conn); err != nil {
return
}
}()

cl := masque.Client{
Template: template,
TLSClientConfig: &tls.Config{ClientCAs: certPool, NextProtos: []string{http3.NextProtoH3}, InsecureSkipVerify: true},
}
defer cl.Close()
proxiedConn, err := cl.DialAddr(context.Background(), "quic-go.net:1234") // the proxy doesn't actually resolve this hostname
require.NoError(t, err)

_, err = proxiedConn.WriteTo([]byte("foobar"), nil)
require.NoError(t, err)
b := make([]byte, 1500)
n, _, err := proxiedConn.ReadFrom(b)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b[:n])
}

func TestProxyToHostnameMissingPort(t *testing.T) {
cl := masque.Client{
Template: uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"),
TLSClientConfig: &tls.Config{ClientCAs: certPool, NextProtos: []string{http3.NextProtoH3}, InsecureSkipVerify: true},
}
defer cl.Close()
_, err := cl.DialAddr(context.Background(), "quic-go.net") // missing port
require.ErrorContains(t, err, "address quic-go.net: missing port in address")
}

func TestProxyShutdown(t *testing.T) {
remoteServerConn := runEchoServer(t)
defer remoteServerConn.Close()
Expand Down

0 comments on commit f8c4793

Please sign in to comment.