diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 63f1a831e26d..0cf15fea32cf 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -41,7 +41,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. var fn func() ([]float32, error) switch model := inferenceModel.(type) { - case *grpc.Client: + case grpc.Backend: fn = func() ([]float32, error) { predictOptions := gRPCPredictOpts(c, loader.ModelPath) if len(tokens) > 0 { diff --git a/api/backend/llm.go b/api/backend/llm.go index bd320b6155ab..9e202c53c53b 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -31,7 +31,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode grpcOpts := gRPCModelOpts(c) - var inferenceModel *grpc.Client + var inferenceModel grpc.Backend var err error opts := modelOpts(c, o, []model.Option{ diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go new file mode 100644 index 000000000000..ae8ffc5fe714 --- /dev/null +++ b/pkg/grpc/backend.go @@ -0,0 +1,46 @@ +package grpc + +import ( + "context" + "github.com/go-skynet/LocalAI/api/schema" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "google.golang.org/grpc" +) + +var embeds = map[string]*embedBackend{} + +func Provide(addr string, llm LLM) { + embeds[addr] = &embedBackend{s: &server{llm: llm}} +} + +func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend { + if bc, ok := embeds[address]; ok { + return bc + } + return NewGrpcClient(address, parallel, wd, enableWatchDog) +} + +func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend { + if !enableWatchDog { + wd = nil + } + return &Client{ + address: address, + parallel: parallel, + wd: wd, + } +} + +type Backend interface { + IsBusy() bool + HealthCheck(ctx context.Context) (bool, error) + Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) + Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) + LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) + PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error + GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) + TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) + TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) + Status(ctx context.Context) (*pb.StatusResponse, error) +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 6f7f83bd434d..5e97ea73e068 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -27,17 +27,6 @@ type WatchDog interface { UnMark(address string) } -func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client { - if !enableWatchDog { - wd = nil - } - return &Client{ - address: address, - parallel: parallel, - wd: wd, - } -} - func (c *Client) IsBusy() bool { c.Lock() defer c.Unlock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go new file mode 100644 index 000000000000..b9ab551f63c3 --- /dev/null +++ b/pkg/grpc/embed.go @@ -0,0 +1,121 @@ +package grpc + +import ( + "context" + "github.com/go-skynet/LocalAI/api/schema" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "time" +) + +var _ Backend = new(embedBackend) +var _ pb.Backend_PredictStreamServer = new(embedBackendServerStream) + +type embedBackend struct { + s *server +} + +func (e *embedBackend) IsBusy() bool { + return e.s.llm.Busy() +} + +func (e *embedBackend) HealthCheck(ctx context.Context) (bool, error) { + return true, nil +} + +func (e *embedBackend) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { + return e.s.Embedding(ctx, in) +} + +func (e *embedBackend) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { + return e.s.Predict(ctx, in) +} + +func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.LoadModel(ctx, in) +} + +func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { + bs := &embedBackendServerStream{ + ctx: ctx, + fn: f, + } + return e.s.PredictStream(in, bs) +} + +func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.GenerateImage(ctx, in) +} + +func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.TTS(ctx, in) +} + +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { + r, err := e.s.AudioTranscription(ctx, in) + if err != nil { + return nil, err + } + tr := &schema.Result{} + for _, s := range r.Segments { + var tks []int + for _, t := range s.Tokens { + tks = append(tks, int(t)) + } + tr.Segments = append(tr.Segments, + schema.Segment{ + Text: s.Text, + Id: int(s.Id), + Start: time.Duration(s.Start), + End: time.Duration(s.End), + Tokens: tks, + }) + } + tr.Text = r.Text + return tr, err +} + +func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { + return e.s.TokenizeString(ctx, in) +} + +func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) { + return e.s.Status(ctx, &pb.HealthMessage{}) +} + +type embedBackendServerStream struct { + ctx context.Context + fn func(s []byte) +} + +func (e *embedBackendServerStream) Send(reply *pb.Reply) error { + e.fn(reply.GetMessage()) + return nil +} + +func (e *embedBackendServerStream) SetHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendServerStream) SendHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendServerStream) SetTrailer(md metadata.MD) { +} + +func (e *embedBackendServerStream) Context() context.Context { + return e.ctx +} + +func (e *embedBackendServerStream) SendMsg(m any) error { + if x, ok := m.(*pb.Reply); ok { + return e.Send(x) + } + return nil +} + +func (e *embedBackendServerStream) RecvMsg(m any) error { + return nil +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 24dbe098eb16..07d055d99717 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -181,3 +181,23 @@ func StartServer(address string, model LLM) error { return nil } + +func RunServer(address string, model LLM) (func() error, error) { + lis, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + s := grpc.NewServer() + pb.RegisterBackendServer(s, &server{llm: model}) + log.Printf("gRPC Server listening at %v", lis.Addr()) + if err = s.Serve(lis); err != nil { + return func() error { + return lis.Close() + }, err + } + + return func() error { + s.GracefulStop() + return nil + }, nil +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index e17fc27fea23..e293669a7e14 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -166,7 +166,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string } } -func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) { +func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) { if parallel { return addr.GRPC(parallel, ml.wd), nil } @@ -177,7 +177,7 @@ func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.C return ml.grpcClients[string(addr)], nil } -func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err error) { +func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) if o.model != "" { @@ -220,7 +220,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err e return ml.resolveAddress(addr, o.parallelRequests) } -func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { +func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) ml.mu.Lock() diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 686b4298a01e..37c2a603a634 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -59,7 +59,7 @@ type ModelLoader struct { ModelPath string mu sync.Mutex // TODO: this needs generics - grpcClients map[string]*grpc.Client + grpcClients map[string]grpc.Backend models map[string]ModelAddress grpcProcesses map[string]*process.Process templates map[TemplateType]map[string]*template.Template @@ -68,7 +68,7 @@ type ModelLoader struct { type ModelAddress string -func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client { +func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend { enableWD := false if wd != nil { enableWD = true @@ -79,7 +79,7 @@ func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client { func NewModelLoader(modelPath string) *ModelLoader { nml := &ModelLoader{ ModelPath: modelPath, - grpcClients: make(map[string]*grpc.Client), + grpcClients: make(map[string]grpc.Backend), models: make(map[string]ModelAddress), templates: make(map[TemplateType]map[string]*template.Template), grpcProcesses: make(map[string]*process.Process), @@ -163,7 +163,7 @@ func (ml *ModelLoader) StopModel(modelName string) error { } func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { - var client *grpc.Client + var client grpc.Backend if m, ok := ml.models[s]; ok { log.Debug().Msgf("Model already loaded in memory: %s", s) if c, ok := ml.grpcClients[s]; ok {