Skip to content

Commit

Permalink
feat: bedrock stream response support
Browse files Browse the repository at this point in the history
  • Loading branch information
laszukdawid committed Oct 16, 2024
1 parent c39d6c0 commit 5148d43
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 1 deletion.
12 changes: 12 additions & 0 deletions Taskfile.dist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ tasks:
- go run cmd/agent/main.go task --loglevel debug --provider $PROVIDER {{.ARGS}}
silent: false

run:set:openai:
desc: Set the provider to openai
cmds:
- go run cmd/agent/main.go config set provider openai
- go run cmd/agent/main.go config set model gpt-4o-mini-2024-07-18

run:set:bedrock:
desc: Set the provider to bedrock
cmds:
- go run cmd/agent/main.go config set provider bedrock
- go run cmd/agent/main.go config set model anthropic.claude-3-haiku-20240307-v1:0

# Environment related - mainly for testing
env:build:
desc: Build the test environment
Expand Down
98 changes: 98 additions & 0 deletions internal/connector/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,107 @@ func (bc *BedrockConnector) queryBedrock(
return converseOutput, nil
}

func (bc *BedrockConnector) queryBedrockStream(
ctx context.Context, qParams *QueryParams, toolConfig *types.ToolConfiguration,
) (string, error) {
systemPromptContent := []types.SystemContentBlock{
&types.SystemContentBlockMemberText{Value: *qParams.SysPrompt},
}
userPromptContent := []types.ContentBlock{
&types.ContentBlockMemberText{Value: *qParams.UserPrompt},
}

converseInput := &bedrockruntime.ConverseStreamInput{
ModelId: (*string)(&bc.modelID),
System: systemPromptContent,
Messages: []types.Message{
{Role: "user", Content: userPromptContent},
},
}

if toolConfig != nil {
converseInput.ToolConfig = toolConfig
}

converseOutput, err := bc.client.ConverseStream(ctx, converseInput)
if err != nil {
var re *awshttp.ResponseError
if errors.As(err, &re) {
log.Printf("requestID: %s, error: %v", re.ServiceRequestID(), re.Unwrap())
}

if re.ResponseError.HTTPStatusCode() == 403 {
return "", ErrBedrockForbidden
}

return "", fmt.Errorf("failed to send request: %v", err)
}

var acc string

stream := converseOutput.GetStream()
// var stream <-chan types.ConverseStreamOutput = converseOutput.GetStream()
var events <-chan types.ConverseStreamOutput = stream.Events()

for _event := range events {
switch event := _event.(type) {

// Message start contains info about "role"
case *types.ConverseStreamOutputMemberMessageStart:
v := event.Value
if v.Role != "assistant" {
bc.logger.Sugar().Debugw("Weird MessageStart", "role", v.Role)
}

// Message stop contains info about "stopReason" and "additionalModelResponseFields"
case *types.ConverseStreamOutputMemberMessageStop:
v := event.Value
bc.logger.Sugar().Debugw("MessageStop",
"stopReason", v.StopReason, "additionalModelResponseFields", v.AdditionalModelResponseFields)

case *types.ConverseStreamOutputMemberContentBlockStart:
start := event.Value.Start
bc.logger.Debug("ContentBlockStart", zap.Any("start", start))
fmt.Println()

case *types.ConverseStreamOutputMemberContentBlockStop:
stop := event.Value
bc.logger.Debug("ContentBlockStart", zap.Any("stop", stop))
fmt.Println()

case *types.ConverseStreamOutputMemberContentBlockDelta:
chunk, isText := event.Value.Delta.(*types.ContentBlockDeltaMemberText)
if !isText {
continue
}
fmt.Print(chunk.Value)
acc += chunk.Value

case *types.ConverseStreamOutputMemberMetadata:
usage := event.Value.Usage
price := computePriceBedrock(bc.modelID, &BedrockUsage{
InputTokens: *usage.InputTokens,
OutputTokens: *usage.OutputTokens,
TotalTokens: *usage.TotalTokens,
})

bc.logger.Sugar().Debugw("Usage", "usage", usage, "price", price)

default:
bc.logger.Warn("union is nil or unknown type", zap.Any("event", event))
}
}

return acc, nil
}

func (bc *BedrockConnector) Query(ctx context.Context, qParams *QueryParams) (string, error) {
bc.logger.Sugar().Debugw("Query", "model", bc.modelID)

if qParams.Stream {
return bc.queryBedrockStream(ctx, qParams, nil)
}

converseOutput, err := bc.queryBedrock(context.Background(), qParams.UserPrompt, qParams.SysPrompt, nil)
if err != nil {
return "", err
Expand Down
2 changes: 1 addition & 1 deletion internal/connector/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const (

var (
// https://openai.com/api/pricing/
ModelPricesOpenai = map[string]map[string]float64{
ModelPricesOpenai = map[openai.ChatModel]map[string]float64{
openai.ChatModelGPT4o: {
"input": 0.00025,
"output": 0.01000,
Expand Down

0 comments on commit 5148d43

Please sign in to comment.