.Net: New Feature: .NET support calling Inference profiles with AWS Bedrock #10047

Open
minscandboo opened this issue Dec 31, 2024 · 2 comments
Labels
.NET Issue or Pull requests regarding .NET code

Comments

@minscandboo

Name: Feature request: .NET support for calling inference profiles with AWS Bedrock
About: For tracking and billing purposes, using inference profiles with AWS Bedrock allows usage of a specific model to be attributed to individual teams within an organization.


Problem

Currently, Semantic Kernel expects modelId to match the format provider.modelName. The provider part determines which client is instantiated (e.g. meta, amazon, etc.). The modelId is also used to build the request.

When using an inference profile, the inference profile replaces the modelId in the request. By allowing the usage of inferenceProfile, multiple teams can independently track and attribute usage when working with the same model.

Current Implementation of service selection: BedrockServiceFactory.cs

    internal IBedrockChatCompletionService CreateChatCompletionService(string modelId)
    {
        (string modelProvider, string modelName) = this.GetModelProviderAndName(modelId);

        switch (modelProvider.ToUpperInvariant())
        {
            case "AI21":
                if (modelName.StartsWith("jamba", StringComparison.OrdinalIgnoreCase))
                {
                    return new AI21JambaService();
                }
                throw new NotSupportedException($"Unsupported AI21 model: {modelId}");
            case "AMAZON":
                if (modelName.StartsWith("titan-", StringComparison.OrdinalIgnoreCase))
                {
                    return new AmazonService();
                }
                throw new NotSupportedException($"Unsupported Amazon model: {modelId}");
            case "ANTHROPIC":
                if (modelName.StartsWith("claude-", StringComparison.OrdinalIgnoreCase))
                {
                    return new AnthropicService();
                }
                throw new NotSupportedException($"Unsupported Anthropic model: {modelId}");
            case "COHERE":
                if (modelName.StartsWith("command-r", StringComparison.OrdinalIgnoreCase))
                {
                    return new CohereCommandRService();
                }
                throw new NotSupportedException($"Unsupported Cohere model: {modelId}");
            case "META":
                if (modelName.StartsWith("llama3-", StringComparison.OrdinalIgnoreCase))
                {
                    return new MetaService();
                }
                throw new NotSupportedException($"Unsupported Meta model: {modelId}");
            case "MISTRAL":
                if (modelName.StartsWith("mistral-", StringComparison.OrdinalIgnoreCase)
                    || modelName.StartsWith("mixtral-", StringComparison.OrdinalIgnoreCase))
                {
                    return new MistralService();
                }
                throw new NotSupportedException($"Unsupported Mistral model: {modelId}");
            default:
                throw new NotSupportedException($"Unsupported model provider: {modelProvider}");
        }
    }

    internal (string modelProvider, string modelName) GetModelProviderAndName(string modelId)
    {
        string[] parts = modelId.Split('.'); //modelId looks like "amazon.titan-text-premier-v1:0"
        string modelName = parts.Length > 1 ? parts[1].ToUpperInvariant() : string.Empty;
        return (parts[0], modelName);
    }
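
To make the problem concrete, here is a minimal standalone sketch that copies only the splitting logic quoted above into a hypothetical class (ModelIdParsingDemo is not part of Semantic Kernel) so it can be tried against identifiers that are not in provider.modelName form.

```csharp
using System;

// Hypothetical demo class: a standalone copy of the splitting logic quoted above,
// for illustration only. It is not part of Semantic Kernel.
internal static class ModelIdParsingDemo
{
    internal static (string modelProvider, string modelName) GetModelProviderAndName(string modelId)
    {
        // Same rule as the quoted code: first segment is the provider,
        // second segment (upper-cased) is the model name.
        string[] parts = modelId.Split('.');
        string modelName = parts.Length > 1 ? parts[1].ToUpperInvariant() : string.Empty;
        return (parts[0], modelName);
    }
}
```

Under this logic, "us.meta.llama3-1-8b-instruct-v1:0" yields provider "us" and model name "META", and an identifier that contains no '.' at all (for example an ARN-style inference profile identifier) yields the whole string as the provider with an empty model name. Neither matches a case in the switch, so the factory throws NotSupportedException.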

Proposal

Add a new parameter to carry the inference profile, e.g. inferenceProfile.
When building the request to Bedrock, use inferenceProfile in place of the foundation modelId for the request URI if it is present; otherwise use the foundation modelId.

Here are some example suggestions:

internal sealed class BedrockChatCompletionClient
{
    private readonly string _modelId;
    private readonly string? _inferenceProfile;
    private readonly string _modelProvider;
    private readonly IAmazonBedrockRuntime _bedrockRuntime;
    private readonly IBedrockChatCompletionService _ioChatService;
    private Uri? _chatGenerationEndpoint;
    private readonly ILogger _logger;

    /// <summary>
    /// Builds the client object and registers the model input-output service given the user's passed in model ID.
    /// </summary>
    /// <param name="modelId">The model ID for the client.</param>
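    /// <param name="bedrockRuntime">The <see cref="IAmazonBedrockRuntime"/> instance to be used for Bedrock runtime actions.</param>
    /// <param name="inferenceProfile">Optional inference profile to send in place of the model ID in requests. If null, the model ID is used.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>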
    internal BedrockChatCompletionClient(string modelId, IAmazonBedrockRuntime bedrockRuntime, string? inferenceProfile = null, ILoggerFactory? loggerFactory = null)
    {
        var serviceFactory = new BedrockServiceFactory();
        this._modelId = modelId;
        this._bedrockRuntime = bedrockRuntime;
        this._ioChatService = serviceFactory.CreateChatCompletionService(modelId);
        this._modelProvider = serviceFactory.GetModelProviderAndName(modelId).modelProvider;
        this._inferenceProfile = inferenceProfile;
        this._logger = loggerFactory?.CreateLogger(this.GetType()) ?? NullLogger.Instance;
    }

    /// <summary>
    /// Generates a chat message based on the provided chat history and execution settings.
    /// </summary>
    /// <param name="chatHistory">The chat history to use for generating the chat message.</param>
    /// <param name="executionSettings">The execution settings for the chat completion.</param>
    /// <param name="kernel">The Semantic Kernel instance.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>The generated chat message.</returns>
    /// <exception cref="ArgumentNullException">Thrown when the chat history is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the chat is empty.</exception>
    /// <exception cref="InvalidOperationException">Thrown when response content is not available.</exception>
    internal async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
        ChatHistory chatHistory,
        PromptExecutionSettings? executionSettings = null,
        Kernel? kernel = null,
        CancellationToken cancellationToken = default)
    {
        Verify.NotNullOrEmpty(chatHistory);
        ConverseRequest converseRequest = this._ioChatService.GetConverseRequest(this._inferenceProfile ?? this._modelId, chatHistory, executionSettings);

Alternatively, if an InferenceProfile property were exposed on PromptExecutionSettings:

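        // executionSettings is nullable, so use ?. before reading the (proposed) InferenceProfile property.
        ConverseRequest converseRequest = this._ioChatService.GetConverseRequest(executionSettings?.InferenceProfile ?? this._modelId, chatHistory, executionSettings);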
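
Both variants come down to the same fallback rule. A minimal sketch of that rule in isolation, assuming a hypothetical helper (RequestModelIdResolver and Resolve are illustrative names, not existing Semantic Kernel APIs); treating an empty or whitespace profile as absent is an extra assumption beyond the plain ?? shown above.

```csharp
#nullable enable
using System;

// Hypothetical helper, not part of Semantic Kernel: picks the identifier
// that would be placed in the Bedrock request.
internal static class RequestModelIdResolver
{
    internal static string Resolve(string modelId, string? inferenceProfile = null)
    {
        if (string.IsNullOrWhiteSpace(modelId))
        {
            throw new ArgumentException("A foundation model ID is required.", nameof(modelId));
        }

        // Assumption: a null, empty, or whitespace profile means "not set",
        // so the foundation model ID is used instead.
        return string.IsNullOrWhiteSpace(inferenceProfile) ? modelId : inferenceProfile!;
    }
}
```

With this shape, the foundation modelId would still drive provider and service selection in BedrockServiceFactory, and only the identifier sent in the request would change.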

Yours sincerely, Matthew

@evchaki
Contributor

evchaki commented Jan 2, 2025

@RogerBarreto can you take a look at this?

@MarkWard0110

Is this why I had to modify GetModelProviderAndName to work with the MaaS (Model-as-a-Service) instances of the foundation models? Is it because such an ID is really an inference profile, whereas the .NET Semantic Kernel assumes every model ID has the fixed provider.modelName form?

To use Meta Llama 3.1 8B Instruct, I have to use the modelId us.meta.llama3-1-8b-instruct-v1:0. This does not work with the current parser, so I modified GetModelProviderAndName to support the additional segment in the model ID.

    internal (string modelProvider, string modelName) GetModelProviderAndName(string modelId)
    {
        // modelId looks like "amazon.titan-text-premier-v1:0" or, with a prefix segment, "us.meta.llama3-1-8b-instruct-v1:0"
        string[] parts = modelId.Split('.');
        string modelProvider = string.Empty;
        string modelName = string.Empty;
        if (parts.Length >= 3)
        {
            modelProvider = parts[1].ToUpperInvariant();
            modelName = parts[2].ToUpperInvariant();
        }
        else if (parts.Length == 2)
        {
            modelProvider = parts[0].ToUpperInvariant();
            modelName = parts[1].ToUpperInvariant();
        }
        else
        {
            throw new ArgumentException($"Invalid model ID: {modelId}");
        }
         
        return (modelProvider, modelName);
    }

With your proposal, we would use the foundation model's ID to create the kernel instance and pass the inference profile ID alongside it. Because AWS uses the inference profile ID as the model ID in the API request, the internal code would check for an inference profile ID and, if present, use it as the request's model ID.

As I understand it, this would support a user-defined inference profile ID while still letting Semantic Kernel's Amazon Bedrock library create the appropriate foundation model service.

Is this a correct understanding?
