diff --git a/executor.go b/executor.go index 789fc1d..fcbb925 100644 --- a/executor.go +++ b/executor.go @@ -3,6 +3,7 @@ package async import ( "context" "errors" + "fmt" "sync" "sync/atomic" ) @@ -13,12 +14,12 @@ type ExecutorStatus uint32 const ( ExecutorStatusRunning ExecutorStatus = iota ExecutorStatusTerminating - ExecutorStatusShutdown + ExecutorStatusShutDown ) var ( ErrExecutorQueueFull = errors.New("async: executor queue is full") - ErrExecutorShutdown = errors.New("async: executor is shut down") + ErrExecutorShutDown = errors.New("async: executor is shut down") ) // ExecutorService is an interface that defines a task executor. @@ -44,6 +45,7 @@ type ExecutorConfig struct { } // NewExecutorConfig returns a new [ExecutorConfig]. +// workerPoolSize must be positive and queueSize non-negative. func NewExecutorConfig(workerPoolSize, queueSize int) *ExecutorConfig { return &ExecutorConfig{ WorkerPoolSize: workerPoolSize, @@ -53,6 +55,7 @@ func NewExecutorConfig(workerPoolSize, queueSize int) *ExecutorConfig { // Executor implements the [ExecutorService] interface. type Executor[T any] struct { + mtx sync.RWMutex cancel context.CancelFunc queue chan executorJob[T] status atomic.Uint32 @@ -65,6 +68,16 @@ type executorJob[T any] struct { task func(context.Context) (T, error) } +// run executes the task, handling possible panics. +func (job *executorJob[T]) run(ctx context.Context) (result T, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered: %v", r) + } + }() + return job.task(ctx) +} + // NewExecutor returns a new [Executor]. func NewExecutor[T any](ctx context.Context, config *ExecutorConfig) *Executor[T] { ctx, cancel := context.WithCancel(ctx) @@ -97,10 +110,11 @@ func (e *Executor[T]) startWorkers(ctx context.Context, poolSize int) { go func() { defer wg.Done() loop: + // check the status to break the loop even if the queue is not empty for ExecutorStatus(e.status.Load()) == ExecutorStatusRunning { select { case job := <-e.queue: - result, err := job.task(ctx) + result, err := job.run(ctx) if err != nil { job.promise.Failure(err) } else { @@ -115,30 +129,39 @@ func (e *Executor[T]) startWorkers(ctx context.Context, poolSize int) { // wait for all workers to exit wg.Wait() + // mark the executor as terminating + e.status.Store(uint32(ExecutorStatusTerminating)) + + // avoid submissions while draining the queue + e.mtx.Lock() + defer e.mtx.Unlock() + // close the queue and cancel all pending tasks close(e.queue) for job := range e.queue { - job.promise.Failure(ErrExecutorShutdown) + job.promise.Failure(ErrExecutorShutDown) } // mark the executor as shut down - e.status.Store(uint32(ExecutorStatusShutdown)) + e.status.Store(uint32(ExecutorStatusShutDown)) } // Submit submits a function to the executor. // The function will be executed asynchronously and the result will be // available via the returned future. func (e *Executor[T]) Submit(f func(context.Context) (T, error)) (Future[T], error) { - promise := NewPromise[T]() + e.mtx.RLock() + defer e.mtx.RUnlock() + if ExecutorStatus(e.status.Load()) == ExecutorStatusRunning { + promise := NewPromise[T]() select { case e.queue <- executorJob[T]{promise, f}: + return promise.Future(), nil default: return nil, ErrExecutorQueueFull } - } else { - return nil, ErrExecutorShutdown } - return promise.Future(), nil + return nil, ErrExecutorShutDown } // Shutdown shuts down the executor. diff --git a/executor_test.go b/executor_test.go index e27fed5..ce554a2 100644 --- a/executor_test.go +++ b/executor_test.go @@ -50,17 +50,17 @@ func TestExecutor(t *testing.T) { // verify that submit fails after the executor was shut down _, err = executor.Submit(job) - assert.ErrorIs(t, err, ErrExecutorShutdown) + assert.ErrorIs(t, err, ErrExecutorShutDown) // validate the executor status assert.Equal(t, executor.Status(), ExecutorStatusTerminating) time.Sleep(10 * time.Millisecond) - assert.Equal(t, executor.Status(), ExecutorStatusShutdown) + assert.Equal(t, executor.Status(), ExecutorStatusShutDown) assert.Equal(t, routines, runtime.NumGoroutine()+4) assertFutureResult(t, 1, future1, future2, future3, future4) - assertFutureError(t, ErrExecutorShutdown, future5, future6) + assertFutureError(t, ErrExecutorShutDown, future5, future6) } func TestExecutor_context(t *testing.T) { @@ -80,7 +80,30 @@ func TestExecutor_context(t *testing.T) { cancel() time.Sleep(5 * time.Millisecond) - assert.Equal(t, executor.Status(), ExecutorStatusShutdown) + + _, err = executor.Submit(job) + assert.ErrorIs(t, err, ErrExecutorShutDown) + + assert.Equal(t, executor.Status(), ExecutorStatusShutDown) +} + +func TestExecutor_jobPanic(t *testing.T) { + ctx := context.Background() + executor := NewExecutor[int](ctx, NewExecutorConfig(2, 2)) + + job := func(_ context.Context) (int, error) { + var i int + return 1 / i, nil + } + + future, err := executor.Submit(job) + assert.IsNil(t, err) + + result, err := future.Join() + assert.Equal(t, result, 0) + assert.ErrorContains(t, err, "integer divide by zero") + + _ = executor.Shutdown() } func submitJob[T any](t *testing.T, executor ExecutorService[T], @@ -88,7 +111,7 @@ func submitJob[T any](t *testing.T, executor ExecutorService[T], future, err := executor.Submit(f) assert.IsNil(t, err) - time.Sleep(time.Millisecond) // switch context + runtime.Gosched() return future }