Skip to content

Commit

Permalink
split Proxy.Upgrade into Proxy.ParseRequst and Proxy.Proxy (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored Jul 2, 2024
1 parent fa51376 commit 58c562b
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 94 deletions.
13 changes: 10 additions & 3 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"crypto/tls"
"errors"
"flag"
"log"
"log/slog"
Expand Down Expand Up @@ -53,11 +54,17 @@ func main() {
log.Fatalf("failed to parse URI template: %v", err)
}
http.HandleFunc(u.Path, func(w http.ResponseWriter, r *http.Request) {
if err := proxy.Upgrade(w, r); err != nil {
log.Printf("failed to upgrade request from %s: %v", r.RemoteAddr, err)
req, err := proxy.ParseRequest(r)
if err != nil {
var perr *masque.RequestParseError
if errors.As(err, &perr) {
w.WriteHeader(perr.HTTPStatus)
return
}
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
proxy.Proxy(w, req)
})
if err := server.ListenAndServe(); err != nil {
log.Fatalf("failed to run proxy: %v", err)
Expand Down
14 changes: 8 additions & 6 deletions connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ func TestProxying(t *testing.T) {
defer server.Close()
proxy := masque.Proxy{
Template: template,
Allow: func(context.Context, *net.UDPAddr) bool { return true },
}
defer proxy.Close()
mux.HandleFunc("/masque", func(w http.ResponseWriter, r *http.Request) {
if err := proxy.Upgrade(w, r); err != nil {
req, err := proxy.ParseRequest(r)
if err != nil {
t.Log("Upgrade failed:", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
proxy.Proxy(w, req)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down Expand Up @@ -114,14 +115,15 @@ func TestProxyShutdown(t *testing.T) {
defer server.Close()
proxy := masque.Proxy{
Template: template,
Allow: func(context.Context, *net.UDPAddr) bool { return true },
}
mux.HandleFunc("/masque", func(w http.ResponseWriter, r *http.Request) {
if err := proxy.Upgrade(w, r); err != nil {
req, err := proxy.ParseRequest(r)
if err != nil {
t.Log("Upgrade failed:", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
proxy.Proxy(w, req)
})
go func() {
if err := server.Serve(conn); err != nil {
Expand Down
144 changes: 92 additions & 52 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package masque

import (
"context"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -46,98 +45,135 @@ type proxyEntry struct {
conn *net.UDPConn
}

// Request is returned from Proxy.ParseRequest.
// Target is the target server that the client requests to connect to.
// It can either be DNS name:port or an IP:port.
type Request struct {
Target string
}

// RequestParseError is returned from Proxy.ParseRequest if parsing the CONNECT-UDP request fails.
// It is recommended that the request is rejected with the corresponding HTTP status code.
type RequestParseError struct {
HTTPStatus int
Err error
}

func (e *RequestParseError) Error() string { return e.Err.Error() }
func (e *RequestParseError) Unwrap() error { return e.Err }

type Proxy struct {
// Template is the URI template that clients will use to configure this UDP proxy.
Template *uritemplate.Template

// Allow determines if a proxying request from a client is allowed to proceed.
// It is called after the requested target address has been resolved.
Allow func(context.Context, *net.UDPAddr) bool

// DialTarget is called when the proxy needs to open a new UDP socket to the target server.
// It must return a connected UDP socket.
// TODO(#3): support unconnected sockets.
DialTarget func(context.Context, *net.UDPAddr) (*net.UDPConn, error)

closed atomic.Bool

mx sync.Mutex
refCount sync.WaitGroup // counter for the Go routines spawned in Upgrade
conns map[proxyEntry]struct{}
}

func (s *Proxy) Upgrade(w http.ResponseWriter, r *http.Request) error {
if s.closed.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
}

s.refCount.Add(1)
defer s.refCount.Done()

func (s *Proxy) ParseRequest(r *http.Request) (*Request, error) {
if r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed)
return fmt.Errorf("expected CONNECT request, got %s", r.Method)
return nil, &RequestParseError{
HTTPStatus: http.StatusMethodNotAllowed,
Err: fmt.Errorf("expected CONNECT request, got %s", r.Method),
}
}
if r.Proto != requestProtocol {
w.WriteHeader(http.StatusNotImplemented)
return fmt.Errorf("unexpected protocol: %s", r.Proto)
return nil, &RequestParseError{
HTTPStatus: http.StatusNotImplemented,
Err: fmt.Errorf("unexpected protocol: %s", r.Proto),
}
}
// TODO: check :authority
capsuleHeaderValues, ok := r.Header[capsuleHeader]
if !ok {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("missing Capsule-Protocol header")
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("missing Capsule-Protocol header"),
}
}
item, err := httpsfv.UnmarshalItem(capsuleHeaderValues)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("invalid capsule header value: %s", r.Header[capsuleHeader])
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("invalid capsule header value: %s", capsuleHeaderValues),
}
}
if v, ok := item.Value.(int64); !ok || v != 1 {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("incorrect capsule header value: %d", v)
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("incorrect capsule header value: %d", v),
}
}

match := s.Template.Match(r.URL.String())
targetHostEncoded := match.Get(uriTemplateTargetHost).String()
targetPortStr := match.Get(uriTemplateTargetPort).String()
if targetHostEncoded == "" || targetPortStr == "" {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("expected target_host and target_port")
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("expected target_host and target_port"),
}
}
targetHost, err := url.QueryUnescape(targetHostEncoded)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("failed to decode target_host: %w", err)
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("failed to decode target_host: %w", err),
}
}
targetPort, err := strconv.Atoi(targetPortStr)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return errors.New("failed to decode target_port")
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("failed to decode target_port: %w", err),
}
}
w.Header().Set(capsuleHeader, capsuleProtocolHeaderValue)
return &Request{Target: fmt.Sprintf("%s:%d", targetHost, targetPort)}, nil
}

dst := fmt.Sprintf("%s:%d", targetHost, targetPort)
addr, err := net.ResolveUDPAddr("udp", dst)
// Proxy proxies a request on a newly created connected UDP socket.
// For more control over the UDP socket, use ProxyConnectedSocket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
if s.closed.Load() {
return net.ErrClosed
}

addr, err := net.ResolveUDPAddr("udp", r.Target)
if err != nil {
// TODO: set proxy-status header (might want to use structured headers)
w.WriteHeader(http.StatusGatewayTimeout)
return err
}
if s.Allow != nil && !s.Allow(r.Context(), addr) {
w.WriteHeader(http.StatusForbidden)
return errors.New("forbidden")
}

var conn *net.UDPConn
if s.DialTarget == nil {
conn, err = net.DialUDP("udp", nil, addr)
} else {
conn, err = s.DialTarget(r.Context(), addr)
}
conn, err := net.DialUDP("udp", nil, addr)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
// TODO: set proxy-status header (might want to use structured headers)
w.WriteHeader(http.StatusGatewayTimeout)
return err
}
defer conn.Close()

return s.ProxyConnectedSocket(w, r, conn)
}

// ProxyConnectedSocket proxies a request on a connected UDP socket.
// Applications may add custom header fields to the response header,
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error {
if s.closed.Load() {
return net.ErrClosed
}

s.refCount.Add(1)
defer s.refCount.Done()

w.Header().Set(capsuleHeader, capsuleProtocolHeaderValue)
w.WriteHeader(http.StatusOK)

str := w.(http3.HTTPStreamer).HTTPStream()

s.mx.Lock()
Expand All @@ -151,26 +187,29 @@ func (s *Proxy) Upgrade(w http.ResponseWriter, r *http.Request) error {
s.conns = make(map[proxyEntry]struct{})
}
s.conns[proxyEntry{str: str, conn: conn}] = struct{}{}
s.refCount.Add(3)
s.mx.Unlock()

var wg sync.WaitGroup
s.refCount.Add(3)
wg.Add(3)
go func() {
defer wg.Done()
defer s.refCount.Done()
if err := s.proxyConnSend(conn, str); err != nil {
log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
conn.Close()
}()
go func() {
defer wg.Done()
defer s.refCount.Done()
if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() {
log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err)
}
str.Close()
conn.Close()
}()
go func() {
defer wg.Done()
defer s.refCount.Done()
// discard all capsules sent on the request stream
if err := skipCapsules(quicvarint.NewReader(str)); err == io.EOF {
Expand All @@ -179,6 +218,7 @@ func (s *Proxy) Upgrade(w http.ResponseWriter, r *http.Request) error {
str.Close()
conn.Close()
}()
wg.Wait()
return nil
}

Expand Down
Loading

0 comments on commit 58c562b

Please sign in to comment.