diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index b8bc68dcd1b..1aa95bc6ee1 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -17,6 +17,7 @@ package internal
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"sort"
@@ -70,6 +71,13 @@ func RunPipeline(j *jobservices.Job) {
 		j.Failed(err)
 		return
 	}
+
+	if errors.Is(context.Cause(j.RootCtx), jobservices.ErrCancel) {
+		j.SendMsg("pipeline canceled " + j.String())
+		j.Canceled()
+		return
+	}
+
+	j.SendMsg("pipeline completed " + j.String())
 
 	j.SendMsg("terminating " + j.String())
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index bb5eb88c919..6cde48ded9a 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -177,6 +177,16 @@ func (j *Job) Done() {
 	j.sendState(jobpb.JobState_DONE)
 }
 
+// Canceling sends the CANCELLING state for the job, marking cancellation as in progress.
+func (j *Job) Canceling() {
+	j.sendState(jobpb.JobState_CANCELLING)
+}
+
+// Canceled sends the CANCELLED state for the job, which Server.Cancel treats as terminal.
+func (j *Job) Canceled() {
+	j.sendState(jobpb.JobState_CANCELLED)
+}
+
 // Failed indicates that the job completed unsuccessfully.
func (j *Job) Failed(err error) {
 	slog.Error("job failed", slog.Any("job", j), slog.Any("error", err))
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 323d8c46efb..0da37ef0bd7 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -17,6 +17,7 @@ package jobservices
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"sync"
 	"sync/atomic"
@@ -30,6 +31,10 @@ import (
 	"google.golang.org/protobuf/types/known/timestamppb"
 )
 
+// ErrCancel is the cause handed to a Job's CancelFn when cancellation is requested
+// through Server.Cancel; detect it with errors.Is(context.Cause(ctx), ErrCancel).
+var ErrCancel = errors.New("pipeline canceled")
+
 func (s *Server) nextId() string {
 	v := atomic.AddUint32(&s.index, 1)
 	return fmt.Sprintf("job-%03d", v)
@@ -215,6 +220,38 @@ func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJ
 	}, nil
 }
 
+// Cancel requests cancellation of the job identified by the CancelJobRequest.
+//
+// A job that is not yet in a terminal state is moved to CANCELLING and its CancelFn
+// is invoked with ErrCancel; the response reports CANCELLING. A job already in a
+// terminal state is left untouched and its current state is returned. If no job has
+// the requested ID, Cancel returns a nil response and a nil error.
+func (s *Server) Cancel(_ context.Context, req *jobpb.CancelJobRequest) (*jobpb.CancelJobResponse, error) {
+	s.mu.Lock()
+	job, ok := s.jobs[req.GetJobId()]
+	s.mu.Unlock()
+	if !ok {
+		return nil, nil
+	}
+	// Comma-ok keeps an unset state from panicking the RPC; it then reads as
+	// the zero JobState_Enum value and the job is still canceled below.
+	state, _ := job.state.Load().(jobpb.JobState_Enum)
+	switch state {
+	case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_UPDATED, jobpb.JobState_FAILED:
+		// Already at terminal state.
+		return &jobpb.CancelJobResponse{
+			State: state,
+		}, nil
+	}
+	job.SendMsg("canceling " + job.String())
+	job.Canceling()
+	// ErrCancel is passed as the cancellation cause so the runner can tell a cancel apart.
+	job.CancelFn(ErrCancel)
+	return &jobpb.CancelJobResponse{
+		State: jobpb.JobState_CANCELLING,
+	}, nil
+}
+
 // GetMessageStream subscribes to a stream of state changes and messages from the job. If throughput
 // is high, this may cause losses of messages.
func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.JobService_GetMessageStreamServer) error {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
index 5813e6ef73e..176abb8543a 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
@@ -169,6 +169,43 @@ func TestServer(t *testing.T) {
 				}
 			},
 		},
+		{
+			name: "Canceling",
+			// Unknown job ID: Cancel is expected to return a nil response and a nil error.
+			noJobsCheck: func(ctx context.Context, t *testing.T, undertest *Server) {
+				resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: "job-001"})
+				if resp != nil {
+					t.Errorf("Cancel(%q) = %s, want nil", "job-001", resp)
+				}
+				if err != nil {
+					t.Errorf("Cancel(%q) = %v, want nil", "job-001", err)
+				}
+			},
+			// Prepared job: it is not in a terminal state, so Cancel reports CANCELLING.
+			postPrepCheck: func(ctx context.Context, t *testing.T, undertest *Server) {
+				resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: "job-001"})
+				if err != nil {
+					t.Errorf("Cancel(%q) = %v, want nil", "job-001", err)
+				}
+				if diff := cmp.Diff(&jobpb.CancelJobResponse{
+					State: jobpb.JobState_CANCELLING,
+				}, resp, cmpOpts...); diff != "" {
+					t.Errorf("Cancel(%q) (-want, +got):\n%v", "job-001", diff)
+				}
+			},
+			// Finished job: DONE is terminal, so Cancel only echoes the existing state.
+			postRunCheck: func(ctx context.Context, t *testing.T, undertest *Server, jobID string) {
+				resp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{JobId: jobID})
+				if err != nil {
+					t.Errorf("Cancel(%q) = %v, want nil", jobID, err)
+				}
+				if diff := cmp.Diff(&jobpb.CancelJobResponse{
+					State: jobpb.JobState_DONE,
+				}, resp, cmpOpts...); diff != "" {
+					t.Errorf("Cancel(%q) (-want, +got):\n%v", jobID, diff)
+				}
+			},
+		},
 	}
 	for _, test := range tests {
 		var called sync.WaitGroup
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
index 2223f030ce1..473c84f958e 100644
--- 
a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
@@ -17,6 +17,7 @@ package jobservices
 
 import (
 	"context"
+	"errors"
 	"sync"
 	"testing"
 
@@ -77,3 +78,68 @@ func TestServer_JobLifecycle(t *testing.T) {
 	t.Log("success!")
 	// Nothing to cleanup because we didn't start the server.
 }
+
+// TestServer_RunThenCancel validates that invoking Cancel cancels a running job.
+func TestServer_RunThenCancel(t *testing.T) {
+	var called sync.WaitGroup
+	called.Add(1)
+	undertest := NewServer(0, func(j *Job) {
+		// Release the waiter on every path so an unexpected cause fails the
+		// state assertion below instead of hanging the test.
+		defer called.Done()
+		// This callback may start before Cancel is invoked, so wait for the
+		// job context to be canceled before inspecting the cause.
+		<-j.RootCtx.Done()
+		if errors.Is(context.Cause(j.RootCtx), ErrCancel) {
+			j.state.Store(jobpb.JobState_CANCELLED)
+		}
+	})
+	ctx := context.Background()
+
+	wantPipeline := &pipepb.Pipeline{
+		Requirements: []string{urns.RequirementSplittableDoFn},
+	}
+	wantName := "testJob"
+
+	resp, err := undertest.Prepare(ctx, &jobpb.PrepareJobRequest{
+		Pipeline: wantPipeline,
+		JobName:  wantName,
+	})
+	if err != nil {
+		t.Fatalf("server.Prepare() = %v, want nil", err)
+	}
+
+	if got := resp.GetPreparationId(); got == "" {
+		t.Fatalf("server.Prepare() = returned empty preparation ID, want non-empty: %v", prototext.Format(resp))
+	}
+
+	runResp, err := undertest.Run(ctx, &jobpb.RunJobRequest{
+		PreparationId: resp.GetPreparationId(),
+	})
+	if err != nil {
+		t.Fatalf("server.Run() = %v, want nil", err)
+	}
+	if got := runResp.GetJobId(); got == "" {
+		t.Fatalf("server.Run() = returned empty job ID, want non-empty")
+	}
+
+	cancelResp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{
+		JobId: runResp.GetJobId(),
+	})
+	if err != nil {
+		t.Fatalf("server.Cancel() = %v, want nil", err)
+	}
+	if cancelResp.GetState() != jobpb.JobState_CANCELLING {
+		t.Fatalf("server.Cancel() = %v, want %v", cancelResp.GetState(), jobpb.JobState_CANCELLING)
+	}
+
+	called.Wait()
+
+	stateResp, err := undertest.GetState(ctx, &jobpb.GetJobStateRequest{JobId: runResp.GetJobId()})
+	if err != nil {
+		t.Fatalf("server.GetState() = %v, want nil", err)
+	}
+	if stateResp.GetState() != jobpb.JobState_CANCELLED {
+		t.Fatalf("server.GetState() = %v, want %v", stateResp.GetState(), jobpb.JobState_CANCELLED)
+	}
+}