diff --git a/Taskfile.dist.yaml b/Taskfile.dist.yaml index 98b7f34..1e3174f 100644 --- a/Taskfile.dist.yaml +++ b/Taskfile.dist.yaml @@ -52,6 +52,18 @@ tasks: - go run cmd/agent/main.go task --loglevel debug --provider $PROVIDER {{.ARGS}} silent: false + run:set:openai: + desc: Set the provider to openai + cmds: + - go run cmd/agent/main.go config set provider openai + - go run cmd/agent/main.go config set model gpt-4o-mini-2024-07-18 + + run:set:bedrock: + desc: Set the provider to bedrock + cmds: + - go run cmd/agent/main.go config set provider bedrock + - go run cmd/agent/main.go config set model anthropic.claude-3-haiku-20240307-v1:0 + # Environment related - mainly for testing env:build: desc: Build the test environment diff --git a/internal/connector/bedrock.go b/internal/connector/bedrock.go index 818f861..0e389be 100644 --- a/internal/connector/bedrock.go +++ b/internal/connector/bedrock.go @@ -194,9 +194,107 @@ func (bc *BedrockConnector) queryBedrock( return converseOutput, nil } +func (bc *BedrockConnector) queryBedrockStream( + ctx context.Context, qParams *QueryParams, toolConfig *types.ToolConfiguration, +) (string, error) { + systemPromptContent := []types.SystemContentBlock{ + &types.SystemContentBlockMemberText{Value: *qParams.SysPrompt}, + } + userPromptContent := []types.ContentBlock{ + &types.ContentBlockMemberText{Value: *qParams.UserPrompt}, + } + + converseInput := &bedrockruntime.ConverseStreamInput{ + ModelId: (*string)(&bc.modelID), + System: systemPromptContent, + Messages: []types.Message{ + {Role: "user", Content: userPromptContent}, + }, + } + + if toolConfig != nil { + converseInput.ToolConfig = toolConfig + } + + converseOutput, err := bc.client.ConverseStream(ctx, converseInput) + if err != nil { + var re *awshttp.ResponseError + if errors.As(err, &re) { + log.Printf("requestID: %s, error: %v", re.ServiceRequestID(), re.Unwrap()) + } + + if re.ResponseError.HTTPStatusCode() == 403 { + return "", ErrBedrockForbidden + } + + return "", fmt.Errorf("failed to send request: %v", err) + } + + var acc string + + stream := converseOutput.GetStream() + // var stream <-chan types.ConverseStreamOutput = converseOutput.GetStream() + var events <-chan types.ConverseStreamOutput = stream.Events() + + for _event := range events { + switch event := _event.(type) { + + // Message start contains info about "role" + case *types.ConverseStreamOutputMemberMessageStart: + v := event.Value + if v.Role != "assistant" { + bc.logger.Sugar().Debugw("Weird MessageStart", "role", v.Role) + } + + // Message stop contains info about "stopReason" and "additionalModelResponseFields" + case *types.ConverseStreamOutputMemberMessageStop: + v := event.Value + bc.logger.Sugar().Debugw("MessageStop", + "stopReason", v.StopReason, "additionalModelResponseFields", v.AdditionalModelResponseFields) + + case *types.ConverseStreamOutputMemberContentBlockStart: + start := event.Value.Start + bc.logger.Debug("ContentBlockStart", zap.Any("start", start)) + fmt.Println() + + case *types.ConverseStreamOutputMemberContentBlockStop: + stop := event.Value + bc.logger.Debug("ContentBlockStart", zap.Any("stop", stop)) + fmt.Println() + + case *types.ConverseStreamOutputMemberContentBlockDelta: + chunk, isText := event.Value.Delta.(*types.ContentBlockDeltaMemberText) + if !isText { + continue + } + fmt.Print(chunk.Value) + acc += chunk.Value + + case *types.ConverseStreamOutputMemberMetadata: + usage := event.Value.Usage + price := computePriceBedrock(bc.modelID, &BedrockUsage{ + InputTokens: *usage.InputTokens, + OutputTokens: *usage.OutputTokens, + TotalTokens: *usage.TotalTokens, + }) + + bc.logger.Sugar().Debugw("Usage", "usage", usage, "price", price) + + default: + bc.logger.Warn("union is nil or unknown type", zap.Any("event", event)) + } + } + + return acc, nil +} + func (bc *BedrockConnector) Query(ctx context.Context, qParams *QueryParams) (string, error) { bc.logger.Sugar().Debugw("Query", "model", bc.modelID) + if qParams.Stream { + return bc.queryBedrockStream(ctx, qParams, nil) + } + converseOutput, err := bc.queryBedrock(context.Background(), qParams.UserPrompt, qParams.SysPrompt, nil) if err != nil { return "", err diff --git a/internal/connector/openai.go b/internal/connector/openai.go index 09c4f5b..bd4a639 100644 --- a/internal/connector/openai.go +++ b/internal/connector/openai.go @@ -19,7 +19,7 @@ const ( var ( // https://openai.com/api/pricing/ - ModelPricesOpenai = map[string]map[string]float64{ + ModelPricesOpenai = map[openai.ChatModel]map[string]float64{ openai.ChatModelGPT4o: { "input": 0.00025, "output": 0.01000,