diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go
index 2f960a04f0cb..be4809f5e2f6 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/environments.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go
@@ -147,7 +147,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock
 	ccr, err := cli.ContainerCreate(ctx, &container.Config{
 		Image: dp.GetContainerImage(),
 		Cmd: []string{
-			fmt.Sprintf("--id=%v-%v", wk.JobKey, wk.Env),
+			fmt.Sprintf("--id=%v", wk.ID),
 			fmt.Sprintf("--control_endpoint=%v", wk.Endpoint()),
 			fmt.Sprintf("--artifact_endpoint=%v", artifactEndpoint),
 			fmt.Sprintf("--provision_endpoint=%v", wk.Endpoint()),
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 8b56c30eb61b..2cc62769d2a9 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -53,13 +53,24 @@ func RunPipeline(j *jobservices.Job) {
 	envs := j.Pipeline.GetComponents().GetEnvironments()
 	wks := map[string]*worker.W{}
 	for envID := range envs {
-		wk, err := makeWorker(envID, j)
-		if err != nil {
-			j.Failed(err)
+		wk := j.MakeWorker(envID)
+		wks[envID] = wk
+		if err := runEnvironment(j.RootCtx, j, envID, wk); err != nil {
+			j.Failed(fmt.Errorf("failed to start environment %v for job %v: %w", envID, j, err))
 			return
 		}
-		wks[envID] = wk
+		// Check for connection succeeding after we've created the environment successfully.
+		timeout := 1 * time.Minute
+		time.AfterFunc(timeout, func() {
+			if wk.Connected() || wk.Stopped() {
+				return
+			}
+			err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout)
+			j.Failed(err)
+			j.CancelFn(err)
+		})
 	}
+
 	// When this function exits, we cancel the context to clear
 	// any related job resources.
 	defer func() {
@@ -86,33 +97,6 @@ func RunPipeline(j *jobservices.Job) {
 	j.Done()
 }
 
-// makeWorker creates a worker for that environment.
-func makeWorker(env string, j *jobservices.Job) (*worker.W, error) {
-	wk := worker.New(j.String()+"_"+env, env)
-
-	wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env]
-	wk.PipelineOptions = j.PipelineOptions()
-	wk.JobKey = j.JobKey()
-	wk.ArtifactEndpoint = j.ArtifactEndpoint()
-
-	go wk.Serve()
-
-	if err := runEnvironment(j.RootCtx, j, env, wk); err != nil {
-		return nil, fmt.Errorf("failed to start environment %v for job %v: %w", env, j, err)
-	}
-	// Check for connection succeeding after we've created the environment successfully.
-	timeout := 1 * time.Minute
-	time.AfterFunc(timeout, func() {
-		if wk.Connected() || wk.Stopped() {
-			return
-		}
-		err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout)
-		j.Failed(err)
-		j.CancelFn(err)
-	})
-	return wk, nil
-}
-
 type transformExecuter interface {
 	ExecuteUrns() []string
 	ExecuteTransform(stageID, tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
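Note on the hunk above: the control-connection watchdog moved inline into the loop. A minimal sketch of the pattern in isolation — watchConnection, fail, and the connectable interface are illustrative names standing in for *worker.W and j.Failed/j.CancelFn, not part of the change:

package sketch

import (
	"fmt"
	"time"
)

// connectable stands in for the subset of *worker.W the watchdog reads.
type connectable interface {
	Connected() bool
	Stopped() bool
	Endpoint() string
}

// watchConnection arms a one-minute timer after the environment starts.
// time.AfterFunc always fires, so the callback must re-check worker state:
// a worker that connected (or was already torn down) must not fail the job.
func watchConnection(wk connectable, fail func(error)) {
	timeout := 1 * time.Minute
	time.AfterFunc(timeout, func() {
		if wk.Connected() || wk.Stopped() {
			return
		}
		fail(fmt.Errorf("no control connection to %v after %v", wk.Endpoint(), timeout))
	})
}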
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index 6158cd6d612c..4be64e5a9c80 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -38,6 +38,7 @@ import (
 	jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
 	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
 	"google.golang.org/protobuf/types/known/structpb"
 )
 
@@ -93,6 +94,7 @@ type Job struct {
 	Logger *slog.Logger
 
 	metrics metricsStore
+	mw      *worker.MultiplexW
 }
 
 func (j *Job) ArtifactEndpoint() string {
@@ -198,3 +200,14 @@ func (j *Job) Failed(err error) {
 	j.sendState(jobpb.JobState_FAILED)
 	j.CancelFn(fmt.Errorf("jobFailed %v: %w", j, err))
 }
+
+// MakeWorker instantiates a worker.W, populating environment and pipeline data from the Job.
+func (j *Job) MakeWorker(env string) *worker.W {
+	wk := j.mw.MakeWorker(j.String()+"_"+env, env)
+	wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env]
+	wk.PipelineOptions = j.PipelineOptions()
+	wk.JobKey = j.JobKey()
+	wk.ArtifactEndpoint = j.ArtifactEndpoint()
+
+	return wk
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index af559a92ab46..b9a28e4bc652 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -94,6 +94,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ *
 		},
 		Logger:           s.logger, // TODO substitute with a configured logger.
 		artifactEndpoint: s.Endpoint(),
+		mw:               s.mw,
 	}
 	// Stop the idle timer when a new job appears.
 	if idleTimer := s.idleTimer.Load(); idleTimer != nil {
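For context, a hedged sketch of how the execute.go loop consumes Job.MakeWorker; startWorkers is an illustrative name, and the loop condenses the RunPipeline hunk above:

package sketch

import (
	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
)

// startWorkers mirrors the RunPipeline loop: each environment gets a W from
// the job's shared MultiplexW, already registered and reachable on the single
// multiplexed endpoint, so there is no per-worker listener or Serve call.
func startWorkers(j *jobservices.Job) map[string]*worker.W {
	wks := map[string]*worker.W{}
	for envID := range j.Pipeline.GetComponents().GetEnvironments() {
		wks[envID] = j.MakeWorker(envID)
	}
	return wks
}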
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
index bdfe2aff2dd4..fb55fc54bf93 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
@@ -28,6 +28,7 @@ import (
 
 	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
 	jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
 	"google.golang.org/grpc"
 )
 
@@ -60,6 +61,8 @@ type Server struct {
 
 	// Artifact hack
 	artifacts map[string][]byte
+
+	mw *worker.MultiplexW
 }
 
 // NewServer acquires the indicated port.
@@ -82,6 +85,9 @@ func NewServer(port int, execute func(*Job)) *Server {
 	jobpb.RegisterJobServiceServer(s.server, s)
 	jobpb.RegisterArtifactStagingServiceServer(s.server, s)
 	jobpb.RegisterArtifactRetrievalServiceServer(s.server, s)
+
+	s.mw = worker.NewMultiplexW(lis, s.server, s.logger)
+
 	return s
 }
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
index 161fb199ce96..08d30f67e445 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
@@ -25,7 +25,7 @@ import (
 )
 
 func TestBundle_ProcessOn(t *testing.T) {
-	wk := New("test", "testEnv")
+	wk := newWorker()
 	b := &B{
 		InstID: "testInst",
 		PBDID:  "testPBDID",
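The net effect of the server.go change, sketched under assumptions (prism's real Server type, its artifact services, and error handling are elided; oneServer is an illustrative name): job management and the FnAPI now share one listener and one grpc.Server.

package sketch

import (
	"log/slog"
	"net"

	jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
	"google.golang.org/grpc"
)

// oneServer registers a JobService implementation and the multiplexed FnAPI
// services on the same grpc.Server, then serves both from a single port.
func oneServer(lis net.Listener, js jobpb.JobServiceServer, logger *slog.Logger) *worker.MultiplexW {
	g := grpc.NewServer()
	jobpb.RegisterJobServiceServer(g, js)
	// NewMultiplexW registers the BeamFn{Control,Data,State,Logging} and
	// Provision services on g and derives its endpoint from lis.
	mw := worker.NewMultiplexW(lis, g, logger)
	go g.Serve(lis)
	return mw
}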
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 9d9058975b26..b4133b0332a6 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -23,11 +23,9 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"math"
 	"net"
 	"sync"
 	"sync/atomic"
-	"time"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
@@ -38,6 +36,7 @@ import (
 	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -55,16 +54,14 @@ type W struct {
 	fnpb.UnimplementedBeamFnLoggingServer
 	fnpb.UnimplementedProvisionServiceServer
 
+	parentPool *MultiplexW
+
 	ID, Env string
 
 	JobKey, ArtifactEndpoint string
 	EnvPb                    *pipepb.Environment
 	PipelineOptions          *structpb.Struct
 
-	// Server management
-	lis    net.Listener
-	server *grpc.Server
-
 	// These are the ID sources
 	inst               uint64
 	connected, stopped atomic.Bool
@@ -82,45 +79,8 @@ type controlResponder interface {
 	Respond(*fnpb.InstructionResponse)
 }
 
-// New starts the worker server components of FnAPI Execution.
-func New(id, env string) *W {
-	lis, err := net.Listen("tcp", ":0")
-	if err != nil {
-		panic(fmt.Sprintf("failed to listen: %v", err))
-	}
-	opts := []grpc.ServerOption{
-		grpc.MaxRecvMsgSize(math.MaxInt32),
-	}
-	wk := &W{
-		ID:     id,
-		Env:    env,
-		lis:    lis,
-		server: grpc.NewServer(opts...),
-
-		InstReqs:    make(chan *fnpb.InstructionRequest, 10),
-		DataReqs:    make(chan *fnpb.Elements, 10),
-		StoppedChan: make(chan struct{}),
-
-		activeInstructions: make(map[string]controlResponder),
-		Descriptors:        make(map[string]*fnpb.ProcessBundleDescriptor),
-	}
-	slog.Debug("Serving Worker components", slog.String("endpoint", wk.Endpoint()))
-	fnpb.RegisterBeamFnControlServer(wk.server, wk)
-	fnpb.RegisterBeamFnDataServer(wk.server, wk)
-	fnpb.RegisterBeamFnLoggingServer(wk.server, wk)
-	fnpb.RegisterBeamFnStateServer(wk.server, wk)
-	fnpb.RegisterProvisionServiceServer(wk.server, wk)
-	return wk
-}
-
 func (wk *W) Endpoint() string {
-	_, port, _ := net.SplitHostPort(wk.lis.Addr().String())
-	return fmt.Sprintf("localhost:%v", port)
-}
-
-// Serve serves on the started listener. Blocks.
-func (wk *W) Serve() {
-	wk.server.Serve(wk.lis)
+	return wk.parentPool.endpoint
 }
 
 func (wk *W) String() string {
@@ -154,16 +114,7 @@ func (wk *W) shutdown() {
 // Stop the GRPC server.
 func (wk *W) Stop() {
 	wk.shutdown()
-
-	// Give the SDK side 5 seconds to gracefully stop, before
-	// hard stopping all RPCs.
-	tim := time.AfterFunc(5*time.Second, func() {
-		wk.server.Stop()
-	})
-	wk.server.GracefulStop()
-	tim.Stop()
-
-	wk.lis.Close()
+	wk.parentPool.delete(wk)
 	slog.Debug("stopped", "worker", wk)
 }
 
@@ -710,3 +661,131 @@ func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.
 		},
 	}).GetMonitoringInfos()
 }
+
+// MultiplexW forwards FnAPI gRPC requests to the W instances it manages in an in-memory pool.
+type MultiplexW struct {
+	fnpb.UnimplementedBeamFnControlServer
+	fnpb.UnimplementedBeamFnDataServer
+	fnpb.UnimplementedBeamFnStateServer
+	fnpb.UnimplementedBeamFnLoggingServer
+	fnpb.UnimplementedProvisionServiceServer
+
+	mu       sync.Mutex
+	endpoint string
+	logger   *slog.Logger
+	pool     map[string]*W
+}
+
+// NewMultiplexW instantiates a new FnAPI server for multiplexing FnAPI requests to a W.
+func NewMultiplexW(lis net.Listener, g *grpc.Server, logger *slog.Logger) *MultiplexW {
+	_, p, _ := net.SplitHostPort(lis.Addr().String())
+	mw := &MultiplexW{
+		endpoint: "localhost:" + p,
+		logger:   logger,
+		pool:     make(map[string]*W),
+	}
+
+	fnpb.RegisterBeamFnControlServer(g, mw)
+	fnpb.RegisterBeamFnDataServer(g, mw)
+	fnpb.RegisterBeamFnLoggingServer(g, mw)
+	fnpb.RegisterBeamFnStateServer(g, mw)
+	fnpb.RegisterProvisionServiceServer(g, mw)
+
+	return mw
+}
+
+// MakeWorker creates and registers a W, assigning id and env to W.ID and W.Env respectively, and associating W.ID
+// with its *W for later lookup. MultiplexW expects FnAPI gRPC requests to contain a matching 'worker_id' in their context
+// metadata. A gRPC client should use the grpcx.WriteWorkerID helper method prior to sending the request.
+func (mw *MultiplexW) MakeWorker(id, env string) *W {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	w := &W{
+		ID:  id,
+		Env: env,
+
+		InstReqs:    make(chan *fnpb.InstructionRequest, 10),
+		DataReqs:    make(chan *fnpb.Elements, 10),
+		StoppedChan: make(chan struct{}),
+
+		activeInstructions: make(map[string]controlResponder),
+		Descriptors:        make(map[string]*fnpb.ProcessBundleDescriptor),
+		parentPool:         mw,
+	}
+	mw.pool[id] = w
+	return w
+}
+
+func (mw *MultiplexW) GetProvisionInfo(ctx context.Context, req *fnpb.GetProvisionInfoRequest) (*fnpb.GetProvisionInfoResponse, error) {
+	return handleUnary(mw, ctx, req, (*W).GetProvisionInfo)
+}
+
+func (mw *MultiplexW) Logging(stream fnpb.BeamFnLogging_LoggingServer) error {
+	return handleStream(mw, stream.Context(), stream, (*W).Logging)
+}
+
+func (mw *MultiplexW) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) {
+	return handleUnary(mw, ctx, req, (*W).GetProcessBundleDescriptor)
+}
+
+func (mw *MultiplexW) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
+	return handleStream(mw, ctrl.Context(), ctrl, (*W).Control)
+}
+
+func (mw *MultiplexW) Data(data fnpb.BeamFnData_DataServer) error {
+	return handleStream(mw, data.Context(), data, (*W).Data)
+}
+
+func (mw *MultiplexW) State(state fnpb.BeamFnState_StateServer) error {
+	return handleStream(mw, state.Context(), state, (*W).State)
+}
+
+func (mw *MultiplexW) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.MonitoringInfosMetadataResponse {
+	// workerFromMetadataCtx acquires mw.mu itself; taking the lock here as
+	// well would deadlock, since sync.Mutex is not reentrant.
+	w, err := mw.workerFromMetadataCtx(ctx)
+	if err != nil {
+		mw.logger.Error(err.Error())
+		return nil
+	}
+	return w.MonitoringMetadata(ctx, unknownIDs)
+}
+
+func (mw *MultiplexW) workerFromMetadataCtx(ctx context.Context) (*W, error) {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	id, err := grpcx.ReadWorkerID(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if id == "" {
+		return nil, fmt.Errorf("worker_id read from context metadata is an empty string")
+	}
+	w, ok := mw.pool[id]
+	if !ok {
+		return nil, fmt.Errorf("worker_id: '%s' read from context metadata but not registered in worker pool", id)
+	}
+	return w, nil
+}
+
+func (mw *MultiplexW) delete(w *W) {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	delete(mw.pool, w.ID)
+}
+
+func handleUnary[Request any, Response any, Method func(*W, context.Context, *Request) (*Response, error)](mw *MultiplexW, ctx context.Context, req *Request, m Method) (*Response, error) {
+	w, err := mw.workerFromMetadataCtx(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return m(w, ctx, req)
+}
+
+func handleStream[Stream any, Method func(*W, Stream) error](mw *MultiplexW, ctx context.Context, stream Stream, m Method) error {
+	w, err := mw.workerFromMetadataCtx(ctx)
+	if err != nil {
+		return err
+	}
+	return m(w, stream)
+}
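The routing contract above is easiest to see from the client side: every FnAPI RPC arriving at the shared server is dispatched by workerFromMetadataCtx on the worker_id metadata key. A minimal sketch, assuming a multiplexed endpoint is already serving on localhost:8073 and a worker registered under the ID "job-0_go" (both placeholders); grpcx.WriteWorkerID attaches the metadata that the pool reads back:

package main

import (
	"context"
	"log"

	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Placeholder endpoint: in prism this is the single port that the job
	// management services now share with the FnAPI.
	conn, err := grpc.Dial("localhost:8073", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// worker_id is the routing key: without it (or with an unregistered ID),
	// workerFromMetadataCtx rejects the RPC.
	ctx := grpcx.WriteWorkerID(context.Background(), "job-0_go")

	resp, err := fnpb.NewProvisionServiceClient(conn).GetProvisionInfo(ctx, &fnpb.GetProvisionInfoRequest{})
	if err != nil {
		log.Fatalf("GetProvisionInfo: %v", err)
	}
	log.Printf("provision info: %v", resp.GetInfo())
}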
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
index 469e0e2f3d83..a0cf577fbdba 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
@@ -18,34 +18,88 @@ package worker
 import (
 	"bytes"
 	"context"
+	"log/slog"
 	"net"
 	"sort"
 	"sync"
 	"testing"
 	"time"
 
-	"github.com/google/go-cmp/cmp"
-
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
+	"github.com/google/go-cmp/cmp"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/test/bufconn"
 )
 
-func TestWorker_New(t *testing.T) {
-	w := New("test", "testEnv")
+func TestMultiplexW_MakeWorker(t *testing.T) {
+	w := newWorker()
+	if w.parentPool == nil {
+		t.Errorf("MakeWorker instantiated W with a nil reference to MultiplexW")
+	}
 	if got, want := w.ID, "test"; got != want {
-		t.Errorf("New(%q) = %v, want %v", want, got, want)
+		t.Errorf("MakeWorker(%q) = %v, want %v", want, got, want)
+	}
+	got, ok := w.parentPool.pool[w.ID]
+	if !ok || got == nil {
+		t.Errorf("MakeWorker(%q) not registered in worker pool %v", w.ID, w.parentPool.pool)
+	}
+}
+
+func TestMultiplexW_workerFromMetadataCtx(t *testing.T) {
+	for _, tt := range []struct {
+		name    string
+		ctx     context.Context
+		want    *W
+		wantErr string
+	}{
+		{
+			name:    "empty ctx metadata",
+			ctx:     context.Background(),
+			wantErr: "failed to read metadata from context",
+		},
+		{
+			name:    "worker_id empty",
+			ctx:     metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "")),
+			wantErr: "worker_id read from context metadata is an empty string",
+		},
+		{
+			name:    "mismatched worker_id",
+			ctx:     metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "doesn't exist")),
+			wantErr: "worker_id: 'doesn't exist' read from context metadata but not registered in worker pool",
+		},
+		{
+			name: "matched worker_id",
+			ctx:  metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "test")),
"test")), + want: &W{ID: "test"}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + w := newWorker() + got, err := w.parentPool.workerFromMetadataCtx(tt.ctx) + if err != nil && err.Error() != tt.wantErr { + t.Errorf("workerFromMetadataCtx() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr != "" { + return + } + if got.ID != tt.want.ID { + t.Errorf("workerFromMetadataCtx() id = %v, want %v", got.ID, tt.want.ID) + } + }) } } func TestWorker_NextInst(t *testing.T) { - w := New("test", "testEnv") + w := newWorker() instIDs := map[string]struct{}{} for i := 0; i < 100; i++ { @@ -57,7 +111,7 @@ func TestWorker_NextInst(t *testing.T) { } func TestWorker_GetProcessBundleDescriptor(t *testing.T) { - w := New("test", "testEnv") + w := newWorker() id := "available" w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{ @@ -87,19 +141,21 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) - w := New("test", "testEnv") + g := grpc.NewServer() lis := bufconn.Listen(2048) - w.lis = lis - t.Cleanup(func() { w.Stop() }) - go w.Serve() - - clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { - return lis.DialContext(ctx) - }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + mw := NewMultiplexW(lis, g, slog.Default()) + t.Cleanup(func() { g.Stop() }) + go g.Serve(lis) + w := mw.MakeWorker("test", "testEnv") + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("worker_id", w.ID)) + ctx = grpcx.WriteWorkerID(ctx, w.ID) + conn, err := grpc.DialContext(ctx, w.Endpoint(), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return lis.Dial() + }), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatal("couldn't create bufconn grpc connection:", err) } - return ctx, w, clientConn + return ctx, w, conn } type closeSend func() @@ -465,3 +521,10 @@ func TestWorker_State_MultimapSideInput(t *testing.T) { }) } } + +func newWorker() *W { + mw := &MultiplexW{ + pool: map[string]*W{}, + } + return mw.MakeWorker("test", "testEnv") +}