Skip to content

Commit

Permalink
[#32167][Prism] Use the worker-id gRPC metadata (#33438)
Browse files Browse the repository at this point in the history
* Implement MultiplexW and Pool

* Add missing license header

* Add multiplex worker to prism execute

* remove unused props

* Fix Prism python precommit

* Handle worker_id is empty string error

* Fix python worker id interceptor

* default empty _worker_id

* Revert defaulting worker id

* Fix worker_id in docker env

* Update per PR comments

* Add lock/unlock to MultiplexW

* Delegate W deletion via MW

* Remove unnecessary guard

* Small fixes after PR review

* Add code comment to MakeWorker

* clean up commented out code

* Revert portable/common changes
  • Loading branch information
damondouglas authored Jan 3, 2025
1 parent 7e6cf18 commit f27547d
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 103 deletions.
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/runners/prism/internal/environments.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock
ccr, err := cli.ContainerCreate(ctx, &container.Config{
Image: dp.GetContainerImage(),
Cmd: []string{
fmt.Sprintf("--id=%v-%v", wk.JobKey, wk.Env),
fmt.Sprintf("--id=%v", wk.ID),
fmt.Sprintf("--control_endpoint=%v", wk.Endpoint()),
fmt.Sprintf("--artifact_endpoint=%v", artifactEndpoint),
fmt.Sprintf("--provision_endpoint=%v", wk.Endpoint()),
Expand Down
46 changes: 15 additions & 31 deletions sdks/go/pkg/beam/runners/prism/internal/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@ func RunPipeline(j *jobservices.Job) {
envs := j.Pipeline.GetComponents().GetEnvironments()
wks := map[string]*worker.W{}
for envID := range envs {
wk, err := makeWorker(envID, j)
if err != nil {
j.Failed(err)
wk := j.MakeWorker(envID)
wks[envID] = wk
if err := runEnvironment(j.RootCtx, j, envID, wk); err != nil {
j.Failed(fmt.Errorf("failed to start environment %v for job %v: %w", envID, j, err))
return
}
wks[envID] = wk
// Check for connection succeeding after we've created the environment successfully.
timeout := 1 * time.Minute
time.AfterFunc(timeout, func() {
if wk.Connected() || wk.Stopped() {
return
}
err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout)
j.Failed(err)
j.CancelFn(err)
})
}

// When this function exits, we cancel the context to clear
// any related job resources.
defer func() {
Expand All @@ -86,33 +97,6 @@ func RunPipeline(j *jobservices.Job) {
j.Done()
}

// makeWorker creates a worker for that environment.
func makeWorker(env string, j *jobservices.Job) (*worker.W, error) {
wk := worker.New(j.String()+"_"+env, env)

wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env]
wk.PipelineOptions = j.PipelineOptions()
wk.JobKey = j.JobKey()
wk.ArtifactEndpoint = j.ArtifactEndpoint()

go wk.Serve()

if err := runEnvironment(j.RootCtx, j, env, wk); err != nil {
return nil, fmt.Errorf("failed to start environment %v for job %v: %w", env, j, err)
}
// Check for connection succeeding after we've created the environment successfully.
timeout := 1 * time.Minute
time.AfterFunc(timeout, func() {
if wk.Connected() || wk.Stopped() {
return
}
err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout)
j.Failed(err)
j.CancelFn(err)
})
return wk, nil
}

type transformExecuter interface {
ExecuteUrns() []string
ExecuteTransform(stageID, tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
Expand Down
13 changes: 13 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -93,6 +94,7 @@ type Job struct {
Logger *slog.Logger

metrics metricsStore
mw *worker.MultiplexW
}

func (j *Job) ArtifactEndpoint() string {
Expand Down Expand Up @@ -198,3 +200,14 @@ func (j *Job) Failed(err error) {
j.sendState(jobpb.JobState_FAILED)
j.CancelFn(fmt.Errorf("jobFailed %v: %w", j, err))
}

// MakeWorker instantiates a worker.W populating environment and pipeline data from the Job.
func (j *Job) MakeWorker(env string) *worker.W {
wk := j.mw.MakeWorker(j.String()+"_"+env, env)
wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env]
wk.PipelineOptions = j.PipelineOptions()
wk.JobKey = j.JobKey()
wk.ArtifactEndpoint = j.ArtifactEndpoint()

return wk
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ *
},
Logger: s.logger, // TODO substitute with a configured logger.
artifactEndpoint: s.Endpoint(),
mw: s.mw,
}
// Stop the idle timer when a new job appears.
if idleTimer := s.idleTimer.Load(); idleTimer != nil {
Expand Down
6 changes: 6 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -60,6 +61,8 @@ type Server struct {

// Artifact hack
artifacts map[string][]byte

mw *worker.MultiplexW
}

// NewServer acquires the indicated port.
Expand All @@ -82,6 +85,9 @@ func NewServer(port int, execute func(*Job)) *Server {
jobpb.RegisterJobServiceServer(s.server, s)
jobpb.RegisterArtifactStagingServiceServer(s.server, s)
jobpb.RegisterArtifactRetrievalServiceServer(s.server, s)

s.mw = worker.NewMultiplexW(lis, s.server, s.logger)

return s
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

func TestBundle_ProcessOn(t *testing.T) {
wk := New("test", "testEnv")
wk := newWorker()
b := &B{
InstID: "testInst",
PBDID: "testPBDID",
Expand Down
187 changes: 133 additions & 54 deletions sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ import (
"fmt"
"io"
"log/slog"
"math"
"net"
"sync"
"sync/atomic"
"time"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
Expand All @@ -38,6 +36,7 @@ import (
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -55,16 +54,14 @@ type W struct {
fnpb.UnimplementedBeamFnLoggingServer
fnpb.UnimplementedProvisionServiceServer

parentPool *MultiplexW

ID, Env string

JobKey, ArtifactEndpoint string
EnvPb *pipepb.Environment
PipelineOptions *structpb.Struct

// Server management
lis net.Listener
server *grpc.Server

// These are the ID sources
inst uint64
connected, stopped atomic.Bool
Expand All @@ -82,45 +79,8 @@ type controlResponder interface {
Respond(*fnpb.InstructionResponse)
}

// New starts the worker server components of FnAPI Execution.
func New(id, env string) *W {
lis, err := net.Listen("tcp", ":0")
if err != nil {
panic(fmt.Sprintf("failed to listen: %v", err))
}
opts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(math.MaxInt32),
}
wk := &W{
ID: id,
Env: env,
lis: lis,
server: grpc.NewServer(opts...),

InstReqs: make(chan *fnpb.InstructionRequest, 10),
DataReqs: make(chan *fnpb.Elements, 10),
StoppedChan: make(chan struct{}),

activeInstructions: make(map[string]controlResponder),
Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor),
}
slog.Debug("Serving Worker components", slog.String("endpoint", wk.Endpoint()))
fnpb.RegisterBeamFnControlServer(wk.server, wk)
fnpb.RegisterBeamFnDataServer(wk.server, wk)
fnpb.RegisterBeamFnLoggingServer(wk.server, wk)
fnpb.RegisterBeamFnStateServer(wk.server, wk)
fnpb.RegisterProvisionServiceServer(wk.server, wk)
return wk
}

func (wk *W) Endpoint() string {
_, port, _ := net.SplitHostPort(wk.lis.Addr().String())
return fmt.Sprintf("localhost:%v", port)
}

// Serve serves on the started listener. Blocks.
func (wk *W) Serve() {
wk.server.Serve(wk.lis)
return wk.parentPool.endpoint
}

func (wk *W) String() string {
Expand Down Expand Up @@ -154,16 +114,7 @@ func (wk *W) shutdown() {
// Stop the GRPC server.
func (wk *W) Stop() {
wk.shutdown()

// Give the SDK side 5 seconds to gracefully stop, before
// hard stopping all RPCs.
tim := time.AfterFunc(5*time.Second, func() {
wk.server.Stop()
})
wk.server.GracefulStop()
tim.Stop()

wk.lis.Close()
wk.parentPool.delete(wk)
slog.Debug("stopped", "worker", wk)
}

Expand Down Expand Up @@ -710,3 +661,131 @@ func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.
},
}).GetMonitoringInfos()
}

// MultiplexW forwards FnAPI gRPC requests to W it manages in an in-memory pool.
type MultiplexW struct {
fnpb.UnimplementedBeamFnControlServer
fnpb.UnimplementedBeamFnDataServer
fnpb.UnimplementedBeamFnStateServer
fnpb.UnimplementedBeamFnLoggingServer
fnpb.UnimplementedProvisionServiceServer

mu sync.Mutex
endpoint string
logger *slog.Logger
pool map[string]*W
}

// NewMultiplexW instantiates a new FnAPI server for multiplexing FnAPI requests to a W.
func NewMultiplexW(lis net.Listener, g *grpc.Server, logger *slog.Logger) *MultiplexW {
_, p, _ := net.SplitHostPort(lis.Addr().String())
mw := &MultiplexW{
endpoint: "localhost:" + p,
logger: logger,
pool: make(map[string]*W),
}

fnpb.RegisterBeamFnControlServer(g, mw)
fnpb.RegisterBeamFnDataServer(g, mw)
fnpb.RegisterBeamFnLoggingServer(g, mw)
fnpb.RegisterBeamFnStateServer(g, mw)
fnpb.RegisterProvisionServiceServer(g, mw)

return mw
}

// MakeWorker creates and registers a W, assigning id and env to W.ID and W.Env, respectively, associating W.ID
// to *W for later lookup. MultiplexW expects FnAPI gRPC requests to contain a matching 'worker_id' in its context
// metadata. A gRPC client should use the grpcx.WriteWorkerID helper method prior to sending the request.
func (mw *MultiplexW) MakeWorker(id, env string) *W {
mw.mu.Lock()
defer mw.mu.Unlock()
w := &W{
ID: id,
Env: env,

InstReqs: make(chan *fnpb.InstructionRequest, 10),
DataReqs: make(chan *fnpb.Elements, 10),
StoppedChan: make(chan struct{}),

activeInstructions: make(map[string]controlResponder),
Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor),
parentPool: mw,
}
mw.pool[id] = w
return w
}

func (mw *MultiplexW) GetProvisionInfo(ctx context.Context, req *fnpb.GetProvisionInfoRequest) (*fnpb.GetProvisionInfoResponse, error) {
return handleUnary(mw, ctx, req, (*W).GetProvisionInfo)
}

func (mw *MultiplexW) Logging(stream fnpb.BeamFnLogging_LoggingServer) error {
return handleStream(mw, stream.Context(), stream, (*W).Logging)
}

func (mw *MultiplexW) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) {
return handleUnary(mw, ctx, req, (*W).GetProcessBundleDescriptor)
}

func (mw *MultiplexW) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
return handleStream(mw, ctrl.Context(), ctrl, (*W).Control)
}

func (mw *MultiplexW) Data(data fnpb.BeamFnData_DataServer) error {
return handleStream(mw, data.Context(), data, (*W).Data)
}

func (mw *MultiplexW) State(state fnpb.BeamFnState_StateServer) error {
return handleStream(mw, state.Context(), state, (*W).State)
}

func (mw *MultiplexW) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.MonitoringInfosMetadataResponse {
mw.mu.Lock()
defer mw.mu.Unlock()
w, err := mw.workerFromMetadataCtx(ctx)
if err != nil {
mw.logger.Error(err.Error())
return nil
}
return w.MonitoringMetadata(ctx, unknownIDs)
}

func (mw *MultiplexW) workerFromMetadataCtx(ctx context.Context) (*W, error) {
mw.mu.Lock()
defer mw.mu.Unlock()
id, err := grpcx.ReadWorkerID(ctx)
if err != nil {
return nil, err
}
if id == "" {
return nil, fmt.Errorf("worker_id read from context metadata is an empty string")
}
w, ok := mw.pool[id]
if !ok {
return nil, fmt.Errorf("worker_id: '%s' read from context metadata but not registered in worker pool", id)
}
return w, nil
}

func (mw *MultiplexW) delete(w *W) {
mw.mu.Lock()
defer mw.mu.Unlock()
delete(mw.pool, w.ID)
}

func handleUnary[Request any, Response any, Method func(*W, context.Context, *Request) (*Response, error)](mw *MultiplexW, ctx context.Context, req *Request, m Method) (*Response, error) {
w, err := mw.workerFromMetadataCtx(ctx)
if err != nil {
return nil, err
}
return m(w, ctx, req)
}

func handleStream[Stream any, Method func(*W, Stream) error](mw *MultiplexW, ctx context.Context, stream Stream, m Method) error {
w, err := mw.workerFromMetadataCtx(ctx)
if err != nil {
return err
}
return m(w, stream)
}
Loading

0 comments on commit f27547d

Please sign in to comment.