Skip to content

Commit

Permalink
[Prism] Implement jobservices.Server Cancel (#30178)
Browse files Browse the repository at this point in the history
* Implement jobservices.Server Cancel

* Small code cleanup

* Fix test err; canceled state after complete
  • Loading branch information
damondouglas authored Feb 5, 2024
1 parent c1b3a27 commit a47b1fa
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package internal

import (
"context"
"errors"
"fmt"
"io"
"sort"
Expand Down Expand Up @@ -70,6 +71,13 @@ func RunPipeline(j *jobservices.Job) {
j.Failed(err)
return
}

if errors.Is(context.Cause(j.RootCtx), jobservices.ErrCancel) {
j.SendMsg("pipeline canceled " + j.String())
j.Canceled()
return
}

j.SendMsg("pipeline completed " + j.String())

j.SendMsg("terminating " + j.String())
Expand Down
10 changes: 10 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ func (j *Job) Done() {
j.sendState(jobpb.JobState_DONE)
}

// Canceling indicates that the job is canceling.
func (j *Job) Canceling() {
j.sendState(jobpb.JobState_CANCELLING)
}

// Canceled indicates that the job is canceled.
func (j *Job) Canceled() {
j.sendState(jobpb.JobState_CANCELLED)
}

// Failed indicates that the job completed unsuccessfully.
func (j *Job) Failed(err error) {
slog.Error("job failed", slog.Any("job", j), slog.Any("error", err))
Expand Down
30 changes: 30 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package jobservices

import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
Expand All @@ -30,6 +31,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

var (
ErrCancel = errors.New("pipeline canceled")
)

func (s *Server) nextId() string {
v := atomic.AddUint32(&s.index, 1)
return fmt.Sprintf("job-%03d", v)
Expand Down Expand Up @@ -215,6 +220,31 @@ func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJ
}, nil
}

// Cancel a Job requested by the CancelJobRequest for jobs not in an already terminal state.
// Otherwise, returns nil if Job does not exist or the Job's existing state as part of the CancelJobResponse.
func (s *Server) Cancel(_ context.Context, req *jobpb.CancelJobRequest) (*jobpb.CancelJobResponse, error) {
s.mu.Lock()
job, ok := s.jobs[req.GetJobId()]
s.mu.Unlock()
if !ok {
return nil, nil
}
state := job.state.Load().(jobpb.JobState_Enum)
switch state {
case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_UPDATED, jobpb.JobState_FAILED:
// Already at terminal state.
return &jobpb.CancelJobResponse{
State: state,
}, nil
}
job.SendMsg("canceling " + job.String())
job.Canceling()
job.CancelFn(ErrCancel)
return &jobpb.CancelJobResponse{
State: jobpb.JobState_CANCELLING,
}, nil
}

// GetMessageStream subscribes to a stream of state changes and messages from the job. If throughput
// is high, this may cause losses of messages.
func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.JobService_GetMessageStreamServer) error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,40 @@ func TestServer(t *testing.T) {
}
},
},
{
name: "Canceling",
noJobsCheck: func(ctx context.Context, t *testing.T, undertest *Server) {
resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: "job-001"})
if resp != nil {
t.Errorf("Canceling(\"job-001\") = %s, want nil", resp)
}
if err != nil {
t.Errorf("Canceling(\"job-001\") = %v, want nil", err)
}
},
postPrepCheck: func(ctx context.Context, t *testing.T, undertest *Server) {
resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: "job-001"})
if err != nil {
t.Errorf("Canceling(\"job-001\") = %v, want nil", err)
}
if diff := cmp.Diff(&jobpb.CancelJobResponse{
State: jobpb.JobState_CANCELLING,
}, resp, cmpOpts...); diff != "" {
t.Errorf("Canceling(\"job-001\") (-want, +got):\n%v", diff)
}
},
postRunCheck: func(ctx context.Context, t *testing.T, undertest *Server, jobID string) {
resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: jobID})
if err != nil {
t.Errorf("Canceling(\"%s\") = %v, want nil", jobID, err)
}
if diff := cmp.Diff(&jobpb.CancelJobResponse{
State: jobpb.JobState_DONE,
}, resp, cmpOpts...); diff != "" {
t.Errorf("Canceling(\"%s\") (-want, +got):\n%v", jobID, diff)
}
},
},
}
for _, test := range tests {
var called sync.WaitGroup
Expand Down
61 changes: 61 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package jobservices

import (
"context"
"errors"
"sync"
"testing"

Expand Down Expand Up @@ -77,3 +78,63 @@ func TestServer_JobLifecycle(t *testing.T) {
t.Log("success!")
// Nothing to cleanup because we didn't start the server.
}

// Validates that invoking Cancel cancels a running job.
func TestServer_RunThenCancel(t *testing.T) {
var called sync.WaitGroup
called.Add(1)
undertest := NewServer(0, func(j *Job) {
if errors.Is(context.Cause(j.RootCtx), ErrCancel) {
j.state.Store(jobpb.JobState_CANCELLED)
called.Done()
}
})
ctx := context.Background()

wantPipeline := &pipepb.Pipeline{
Requirements: []string{urns.RequirementSplittableDoFn},
}
wantName := "testJob"

resp, err := undertest.Prepare(ctx, &jobpb.PrepareJobRequest{
Pipeline: wantPipeline,
JobName: wantName,
})
if err != nil {
t.Fatalf("server.Prepare() = %v, want nil", err)
}

if got := resp.GetPreparationId(); got == "" {
t.Fatalf("server.Prepare() = returned empty preparation ID, want non-empty: %v", prototext.Format(resp))
}

runResp, err := undertest.Run(ctx, &jobpb.RunJobRequest{
PreparationId: resp.GetPreparationId(),
})
if err != nil {
t.Fatalf("server.Run() = %v, want nil", err)
}
if got := runResp.GetJobId(); got == "" {
t.Fatalf("server.Run() = returned empty preparation ID, want non-empty")
}

cancelResp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{
JobId: runResp.GetJobId(),
})
if err != nil {
t.Fatalf("server.Canceling() = %v, want nil", err)
}
if cancelResp.State != jobpb.JobState_CANCELLING {
t.Fatalf("server.Canceling() = %v, want %v", cancelResp.State, jobpb.JobState_CANCELLING)
}

called.Wait()

stateResp, err := undertest.GetState(ctx, &jobpb.GetJobStateRequest{JobId: runResp.GetJobId()})
if err != nil {
t.Fatalf("server.GetState() = %v, want nil", err)
}
if stateResp.State != jobpb.JobState_CANCELLED {
t.Fatalf("server.GetState() = %v, want %v", stateResp.State, jobpb.JobState_CANCELLED)
}
}

0 comments on commit a47b1fa

Please sign in to comment.