diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index de9edc1..57926f0 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -5,7 +5,9 @@ package drpcmanager import ( "context" + "errors" "fmt" + "io" "time" "github.com/zeebo/errs" @@ -294,7 +296,11 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { select { case <-m.sigs.term.Signal(): - stream.Cancel(context.Canceled) + err := m.sigs.term.Err() + if errors.Is(err, io.EOF) { + err = context.Canceled + } + stream.Cancel(err) <-m.sterm return diff --git a/internal/integration/transport_test.go b/internal/integration/transport_test.go index d86675b..55f6a42 100644 --- a/internal/integration/transport_test.go +++ b/internal/integration/transport_test.go @@ -5,6 +5,8 @@ package integration import ( "context" + "errors" + "io" "testing" "github.com/zeebo/assert" @@ -80,13 +82,14 @@ func TestTransport_ErrorCausesCancel(t *testing.T) { // create a channel to signal when the rpc has started started := make(chan struct{}) - errs := make(chan error, 2) + serr := make(chan error, 1) + cerr := make(chan error, 1) // create a server that signals then waits for the context to die cli, close := createConnection(impl{ Method2Fn: func(stream DRPCService_Method2Stream) error { started <- struct{}{} - errs <- stream.MsgRecv(nil, Encoding) + serr <- stream.MsgRecv(nil, Encoding) return nil }, }) @@ -96,7 +99,7 @@ func TestTransport_ErrorCausesCancel(t *testing.T) { ctx.Run(func(ctx context.Context) { stream, _ := cli.Method2(ctx) started <- struct{}{} - errs <- stream.MsgRecv(nil, Encoding) + cerr <- stream.MsgRecv(nil, Encoding) }) // wait for it to be started. it is important to wait for @@ -111,6 +114,6 @@ func TestTransport_ErrorCausesCancel(t *testing.T) { assert.NoError(t, cli.DRPCConn().(*drpcconn.Conn).Transport().Close()) // ensure both of the errors we sent are canceled - assert.Equal(t, <-errs, context.Canceled) - assert.Equal(t, <-errs, context.Canceled) + assert.That(t, errors.Is(<-serr, context.Canceled)) + assert.That(t, errors.Is(<-cerr, io.ErrClosedPipe)) }