Skip to content

Commit

Permalink
.Net: Fix hugging face embedding (microsoft#6673)
Browse files Browse the repository at this point in the history
### Motivation and Context
Fix the Bug microsoft#6635 HuggingFace Embedding: Unable to Deserialization for
certain models

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description
As mentioned in the issue, the HuggingFace Embedding API interface
returns responses typically in the form of `List<ReadOnlyMemory<float>>`
and occasionally as `List<List<List<ReadOnlyMemory<float>>>>`.
Currently, only the latter format is handled correctly, leading to
deserialization issues.

To address this, I propose the following solution:
```
try {
    // Attempt to parse data as List<ReadOnlyMemory<float>> and return the parsed data
}
catch (KernelException ex1) {
    try {
        // If the first attempt fails, attempt to parse data as List<List<List<ReadOnlyMemory<float>>>>` and return the parsed data
    }
    catch (KernelException ex2) {
        // If both attempts fail, handle the exception (e.g., the model doesn't exist ,the model has still been loading, or an HTTP exception occurred) and rethrow the error
    }
}

```
### Contribution Checklist 

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

---------

Co-authored-by: Dmytro Struk <[email protected]>
  • Loading branch information
N-E-W-T-O-N and dmytrostruk authored Jul 8, 2024
1 parent 13e3a22 commit 32d3f5d
Show file tree
Hide file tree
Showing 5 changed files with 1,105 additions and 2,319 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using Microsoft.SemanticKernel.Connectors.HuggingFace;
using Microsoft.SemanticKernel.Connectors.Sqlite;
using Microsoft.SemanticKernel.Memory;

#pragma warning disable CS8602 // Dereference of a possibly null reference.

namespace Memory;

/// <summary>
/// This example shows how to use custom <see cref="HttpClientHandler"/> to override Hugging Face HTTP response.
/// Generally, an embedding model will return results as a 1 * n matrix for input type [string]. However, the model can have different matrix dimensionality.
/// For example, the <a href="https://huggingface.co/cointegrated/LaBSE-en-ru">cointegrated/LaBSE-en-ru</a> model returns results as a 1 * 1 * 4 * 768 matrix, which is different from Hugging Face embedding generation service implementation.
/// To address this, a custom <see cref="HttpClientHandler"/> can be used to modify the response before sending it back.
/// </summary>
public class HuggingFace_TextEmbeddingCustomHttpHandler(ITestOutputHelper output) : BaseTest(output)
{
public async Task RunInferenceApiEmbeddingCustomHttpHandlerAsync()
{
Console.WriteLine("\n======= Hugging Face Inference API - Embedding Example ========\n");

var hf = new HuggingFaceTextEmbeddingGenerationService(
"cointegrated/LaBSE-en-ru",
apiKey: TestConfiguration.HuggingFace.ApiKey,
httpClient: new HttpClient(new CustomHttpClientHandler()
{
CheckCertificateRevocationList = true
})
);

var sqliteMemory = await SqliteMemoryStore.ConnectAsync("./../../../Sqlite.sqlite");

var skMemory = new MemoryBuilder()
.WithTextEmbeddingGeneration(hf)
.WithMemoryStore(sqliteMemory)
.Build();

await skMemory.SaveInformationAsync("Test", "THIS IS A SAMPLE", "sample", "TEXT");
}

private sealed class CustomHttpClientHandler : HttpClientHandler
{
private readonly JsonSerializerOptions _jsonOptions = new();
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
// Log the request URI
//Console.WriteLine($"Request: {request.Method} {request.RequestUri}");

// Send the request and get the response
HttpResponseMessage response = await base.SendAsync(request, cancellationToken);

// Log the response status code
//Console.WriteLine($"Response: {(int)response.StatusCode} {response.ReasonPhrase}");

// You can manipulate the response here
// For example, add a custom header
// response.Headers.Add("X-Custom-Header", "CustomValue");

// For example, modify the response content
Stream originalContent = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
List<List<List<ReadOnlyMemory<float>>>> modifiedContent = (await JsonSerializer.DeserializeAsync<List<List<List<ReadOnlyMemory<float>>>>>(originalContent, _jsonOptions, cancellationToken).ConfigureAwait(false))!;

Stream modifiedStream = new MemoryStream();
await JsonSerializer.SerializeAsync(modifiedStream, modifiedContent[0][0].ToList(), _jsonOptions, cancellationToken).ConfigureAwait(false);
response.Content = new StreamContent(modifiedStream);

// Return the modified response
return response;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ public async Task ShouldHandleServiceResponseAsync()
//Assert

Assert.NotNull(embeddings);
Assert.Equal(3, embeddings.Count);
Assert.Equal(768, embeddings.First().Length);
Assert.Single(embeddings);
Assert.Equal(1024, embeddings.First().Length);
}

public void Dispose()
Expand Down
Loading

0 comments on commit 32d3f5d

Please sign in to comment.