Skip to content

Commit

Permalink
proxy: add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andydunstall committed May 3, 2024
1 parent abdce25 commit 2177694
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 81 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ pico:

.PHONY: unit-test
unit-test:
go test ./...
go test ./... -v

.PHONY: integration-test
integration-test:
go test ./... -tags integration
go test ./... -tags integration -v

.PHONY: fmt
fmt:
Expand Down
2 changes: 1 addition & 1 deletion agent/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (e *endpoint) ProxyHTTP(r *http.Request) (*http.Response, error) {
return e.forwarder.Forward(r)
}

func (e *endpoint) connect(ctx context.Context) (*rpc.Stream, error) {
func (e *endpoint) connect(ctx context.Context) (rpc.Stream, error) {
backoff := time.Second
for {
conn, err := conn.DialWebsocket(ctx, e.serverURL())
Expand Down
47 changes: 29 additions & 18 deletions pkg/rpc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,18 @@ type message struct {
//
// Incoming RPC requests are handled in their own goroutine to avoid blocking
// the stream.
type Stream struct {
type Stream interface {
Addr() string
RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, error)
Monitor(
ctx context.Context,
interval time.Duration,
timeout time.Duration,
) error
Close() error
}

type stream struct {
conn conn.Conn
handler *Handler

Expand All @@ -56,8 +67,8 @@ type Stream struct {

// NewStream creates an RPC stream on top of the given message-oriented
// connection.
func NewStream(conn conn.Conn, handler *Handler, logger log.Logger) *Stream {
stream := &Stream{
func NewStream(conn conn.Conn, handler *Handler, logger log.Logger) Stream {
stream := &stream{
conn: conn,
handler: handler,
nextMessageID: atomic.NewUint64(0),
Expand All @@ -73,15 +84,15 @@ func NewStream(conn conn.Conn, handler *Handler, logger log.Logger) *Stream {
return stream
}

func (s *Stream) Addr() string {
func (s *stream) Addr() string {
return s.conn.Addr()
}

// RPC sends the given request message to the peer and returns the response or
// an error.
//
// RPC is thread safe.
func (s *Stream) RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, error) {
func (s *stream) RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, error) {
header := &header{
RPCType: rpcType,
ID: s.nextMessageID.Inc(),
Expand Down Expand Up @@ -117,7 +128,7 @@ func (s *Stream) RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, err
}

// Monitor monitors the stream is healthy using heartbeats.
func (s *Stream) Monitor(
func (s *stream) Monitor(
ctx context.Context,
interval time.Duration,
timeout time.Duration,
Expand All @@ -141,11 +152,11 @@ func (s *Stream) Monitor(
}
}

func (s *Stream) Close() error {
func (s *stream) Close() error {
return s.closeStream(ErrStreamClosed)
}

func (s *Stream) reader() {
func (s *stream) reader() {
defer s.recoverPanic("reader()")

for {
Expand Down Expand Up @@ -192,7 +203,7 @@ func (s *Stream) reader() {
}
}

func (s *Stream) writer() {
func (s *stream) writer() {
defer s.recoverPanic("writer()")

for {
Expand All @@ -216,7 +227,7 @@ func (s *Stream) writer() {
}
}

func (s *Stream) write(req *message) error {
func (s *stream) write(req *message) error {
w, err := s.conn.NextWriter()
if err != nil {
return err
Expand All @@ -232,7 +243,7 @@ func (s *Stream) write(req *message) error {
return w.Close()
}

func (s *Stream) closeStream(err error) error {
func (s *stream) closeStream(err error) error {
// Only shutdown once.
if !s.shutdown.CompareAndSwap(false, true) {
return ErrStreamClosed
Expand All @@ -254,7 +265,7 @@ func (s *Stream) closeStream(err error) error {
return nil
}

func (s *Stream) handleRequest(m *message) {
func (s *stream) handleRequest(m *message) {
handlerFunc, ok := s.handler.Find(m.Header.RPCType)
if !ok {
// If no handler is found, send a 'not supported' error to the client.
Expand Down Expand Up @@ -303,7 +314,7 @@ func (s *Stream) handleRequest(m *message) {
}
}

func (s *Stream) handleResponse(m *message) {
func (s *stream) handleResponse(m *message) {
// If no handler is found, it means RPC has already returned so discard
// the response.
ch, ok := s.findResponseHandler(m.Header.ID)
Expand All @@ -312,35 +323,35 @@ func (s *Stream) handleResponse(m *message) {
}
}

func (s *Stream) recoverPanic(prefix string) {
func (s *stream) recoverPanic(prefix string) {
if r := recover(); r != nil {
_ = s.closeStream(fmt.Errorf("panic: %s: %v", prefix, r))
}
}

func (s *Stream) registerResponseHandler(id uint64, ch chan<- *message) {
func (s *stream) registerResponseHandler(id uint64, ch chan<- *message) {
s.responseHandlersMu.Lock()
defer s.responseHandlersMu.Unlock()

s.responseHandlers[id] = ch
}

func (s *Stream) unregisterResponseHandler(id uint64) {
func (s *stream) unregisterResponseHandler(id uint64) {
s.responseHandlersMu.Lock()
defer s.responseHandlersMu.Unlock()

delete(s.responseHandlers, id)
}

func (s *Stream) findResponseHandler(id uint64) (chan<- *message, bool) {
func (s *stream) findResponseHandler(id uint64) (chan<- *message, bool) {
s.responseHandlersMu.Lock()
defer s.responseHandlersMu.Unlock()

ch, ok := s.responseHandlers[id]
return ch, ok
}

func (s *Stream) heartbeat(ctx context.Context, timeout time.Duration) error {
func (s *stream) heartbeat(ctx context.Context, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

Expand Down
2 changes: 1 addition & 1 deletion pkg/rpc/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ type Type uint16

const (
// TypeHeartbeat sends health checks between peers.
TypeHeartbeat = iota + 1
TypeHeartbeat Type = iota + 1
// TypeProxyHTTP sends a HTTP request and response between the Pico server
// and an upstream listener.
TypeProxyHTTP
Expand Down
24 changes: 13 additions & 11 deletions server/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ import (
)

type localEndpoint struct {
streams []*rpc.Stream
streams []rpc.Stream
nextIndex int
}

func (e *localEndpoint) AddUpstream(s *rpc.Stream) {
func (e *localEndpoint) AddUpstream(s rpc.Stream) {
e.streams = append(e.streams, s)
}

func (e *localEndpoint) RemoveUpstream(s *rpc.Stream) bool {
func (e *localEndpoint) RemoveUpstream(s rpc.Stream) bool {
for i := 0; i != len(e.streams); i++ {
if e.streams[i] == s {
e.streams = append(e.streams[:i], e.streams[i+1:]...)
Expand All @@ -45,7 +45,7 @@ func (e *localEndpoint) RemoveUpstream(s *rpc.Stream) bool {
return len(e.streams) == 0
}

func (e *localEndpoint) Next() *rpc.Stream {
func (e *localEndpoint) Next() rpc.Stream {
if len(e.streams) == 0 {
return nil
}
Expand Down Expand Up @@ -144,7 +144,7 @@ func (p *Proxy) Request(ctx context.Context, r *http.Request) (*http.Response, e
return resp, nil
}

func (p *Proxy) AddUpstream(endpointID string, stream *rpc.Stream) {
func (p *Proxy) AddUpstream(endpointID string, stream rpc.Stream) {
p.networkMap.AddLocalEndpoint(endpointID)

p.mu.Lock()
Expand All @@ -166,7 +166,7 @@ func (p *Proxy) AddUpstream(endpointID string, stream *rpc.Stream) {
p.metrics.Listeners.Inc()
}

func (p *Proxy) RemoveUpstream(endpointID string, stream *rpc.Stream) {
func (p *Proxy) RemoveUpstream(endpointID string, stream rpc.Stream) {
p.networkMap.RemoveLocalEndpoint(endpointID)

p.mu.Lock()
Expand Down Expand Up @@ -216,13 +216,13 @@ func (p *Proxy) request(

return nil, &status.ErrorInfo{
StatusCode: http.StatusServiceUnavailable,
Message: "no upstream found",
Message: "endpoint not found",
}
}

// lookupLocalListener looks up an RPC stream for an upstream listener for this
// endpoint.
func (p *Proxy) lookupLocalListener(endpointID string) (*rpc.Stream, bool) {
func (p *Proxy) lookupLocalListener(endpointID string) (rpc.Stream, bool) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -236,7 +236,7 @@ func (p *Proxy) lookupLocalListener(endpointID string) (*rpc.Stream, bool) {

func (p *Proxy) requestLocal(
ctx context.Context,
stream *rpc.Stream,
stream rpc.Stream,
r *http.Request,
) (*http.Response, error) {
// Write the HTTP request to a buffer.
Expand All @@ -257,13 +257,13 @@ func (p *Proxy) requestLocal(
if errors.Is(err, context.DeadlineExceeded) {
return nil, &status.ErrorInfo{
StatusCode: http.StatusGatewayTimeout,
Message: "upstream timeout",
Message: "endpoint timeout",
}
}

return nil, &status.ErrorInfo{
StatusCode: http.StatusServiceUnavailable,
Message: "upstream unreachable",
Message: "endpoint unreachable",
}
}

Expand Down Expand Up @@ -346,6 +346,8 @@ func (p *Proxy) requestRemote(
return resp, nil
}

// parseEndpointID returns the endpoint ID from the HTTP request, or an empty
// string if no endpoint ID is specified.
func parseEndpointID(r *http.Request) string {
endpointID := r.Header.Get("x-pico-endpoint")
if endpointID != "" {
Expand Down
Loading

0 comments on commit 2177694

Please sign in to comment.