diff --git a/.github/workflows/build_75.yaml b/.github/workflows/build_75.yaml index 59d592c4..6d941d03 100644 --- a/.github/workflows/build_75.yaml +++ b/.github/workflows/build_75.yaml @@ -24,22 +24,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@main + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -47,12 +55,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-75 uses: docker/metadata-action@v4.3.0 @@ -67,6 +77,7 @@ type=semver,pattern=turing-{{major}}.{{minor}} type=raw,value=turing-latest type=raw,value=turing-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-75 uses: docker/build-push-action@v4 @@ -87,6 +98,7 @@ labels: ${{ steps.meta-75.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-75-grpc uses: docker/metadata-action@v4.3.0 @@ -101,6 +113,7 @@ type=semver,pattern=turing-{{major}}.{{minor}}-grpc type=raw,value=turing-latest-grpc type=raw,value=turing-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-75-grpc uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_80.yaml b/.github/workflows/build_80.yaml index c032cb00..7a41c117 100644 --- a/.github/workflows/build_80.yaml +++ b/.github/workflows/build_80.yaml @@ -36,6 +36,12 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@main + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: @@ -43,18 +49,17 @@ config-inline: | [registry."docker.io"] mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: 
github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -62,12 +67,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-80 uses: docker/metadata-action@v4.3.0 @@ -82,6 +89,7 @@ type=semver,pattern={{major}}.{{minor}} type=raw,value=latest type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-80 uses: docker/build-push-action@v4 @@ -101,6 +109,7 @@ labels: ${{ steps.meta-80.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-80-grpc uses: docker/metadata-action@v4.3.0 @@ -115,6 +124,7 @@ type=semver,pattern={{major}}.{{minor}}-grpc type=raw,value=latest-grpc type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-80-grpc uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_86.yaml b/.github/workflows/build_86.yaml index 4c9cb02f..16c182ec 100644 --- a/.github/workflows/build_86.yaml +++ b/.github/workflows/build_86.yaml @@ -24,22 +24,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@main + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -47,12 +55,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-86 uses: docker/metadata-action@v4.3.0 @@ -67,6 +77,7 @@ type=semver,pattern=86-{{major}}.{{minor}} type=raw,value=86-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=86-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-86 uses: docker/build-push-action@v4 @@ -86,6 +97,7 @@ labels: ${{ steps.meta-86.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max cache-to: 
type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-86-grpc uses: docker/metadata-action@v4.3.0 @@ -100,6 +112,7 @@ type=semver,pattern=86-{{major}}.{{minor}}-grpc type=raw,value=86-latest-grpc type=raw,value=86-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-86-grpc uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_89.yaml b/.github/workflows/build_89.yaml index cc7d171e..689838f6 100644 --- a/.github/workflows/build_89.yaml +++ b/.github/workflows/build_89.yaml @@ -24,22 +24,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@main + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -47,12 +55,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-89 uses: docker/metadata-action@v4.3.0 @@ -67,6 +77,7 @@ type=semver,pattern=89-{{major}}.{{minor}} type=raw,value=89-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=89-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-89 uses: docker/build-push-action@v4 @@ -86,6 +97,7 @@ labels: ${{ steps.meta-89.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-89-grpc uses: docker/metadata-action@v4.3.0 @@ -100,6 +112,7 @@ type=semver,pattern=89-{{major}}.{{minor}}-grpc type=raw,value=89-latest-grpc type=raw,value=89-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-89-grpc uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_90.yaml b/.github/workflows/build_90.yaml index 1c468bfa..40f05235 100644 --- a/.github/workflows/build_90.yaml +++ b/.github/workflows/build_90.yaml @@ -24,22 +24,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@main + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + 
mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -47,12 +55,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-90 uses: docker/metadata-action@v4.3.0 @@ -67,6 +77,7 @@ type=semver,pattern=hopper-{{major}}.{{minor}} type=raw,value=hopper-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=hopper-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-90 uses: docker/build-push-action@v4 @@ -86,6 +97,7 @@ labels: ${{ steps.meta-90.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-90-grpc uses: docker/metadata-action@v4.3.0 @@ -100,6 +112,7 @@ type=semver,pattern=hopper-{{major}}.{{minor}}-grpc type=raw,value=hopper-latest-grpc type=raw,value=hopper-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-90-grpc uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_all.yaml b/.github/workflows/build_all.yaml index 36042cd0..58aa7b09 100644 --- a/.github/workflows/build_all.yaml +++ b/.github/workflows/build_all.yaml @@ -24,16 +24,23 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -41,12 +48,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta uses: docker/metadata-action@v4.3.0 @@ -61,6 +70,7 @@ type=semver,pattern=cuda-{{major}}.{{minor}} type=raw,value=cuda-latest,enable=${{ github.ref == format('refs/heads/{0}', 
github.event.repository.default_branch) }} type=raw,value=cuda-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push uses: docker/build-push-action@v4 @@ -76,6 +86,7 @@ labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-all,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-all,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-sagemaker uses: docker/metadata-action@v4.3.0 @@ -89,6 +100,7 @@ type=semver,pattern=cuda-{{major}}.{{minor}} type=raw,value=cuda-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=cuda-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-sagemaker uses: docker/build-push-action@v4 diff --git a/.github/workflows/build_cpu.yaml b/.github/workflows/build_cpu.yaml index ce61ef7f..3cf87f5a 100644 --- a/.github/workflows/build_cpu.yaml +++ b/.github/workflows/build_cpu.yaml @@ -36,22 +36,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 - - name: Initialize Docker Buildx - uses: docker/setup-buildx-action@v2.0.0 - with: - install: true + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale - uses: huggingface/tailscale-action@v1 + uses: huggingface/tailscale-action@main with: authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -59,12 +67,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-cpu uses: docker/metadata-action@v4.3.0 @@ -79,6 +89,7 @@ type=semver,pattern=cpu-{{major}}.{{minor}} type=raw,value=cpu-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=cpu-sha-${{ env.GITHUB_SHA_SHORT }} + - name: Build and push Docker image id: build-and-push-cpu uses: docker/build-push-action@v4 @@ -97,6 +108,7 @@ labels: ${{ steps.meta-cpu.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max + - name: Extract metadata (tags, labels) for Docker id: meta-cpu-grpc uses: docker/metadata-action@v4.3.0 @@ -111,6 +123,7 @@ type=semver,pattern=cpu-{{major}}.{{minor}}-grpc type=raw,value=cpu-latest-grpc type=raw,value=cpu-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + - name: Build and push Docker image id: build-and-push-cpu-grpc uses: docker/build-push-action@v4 diff --git 
a/Cargo.lock b/Cargo.lock index e0424a3b..2b494010 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3555,6 +3555,7 @@ dependencies = [ "async-stream", "axum 0.7.5", "axum-tracing-opentelemetry", + "base64 0.21.7", "clap", "futures", "hf-hub", diff --git a/Cargo.toml b/Cargo.toml index 7d7be46e..f8198679 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,12 +24,17 @@ candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "33b7ecf9ed82bb7c20f1a94555218fabfbaa2fe3", package = "candle-flash-attn" } hf-hub = { git = "https://github.com/huggingface/hf-hub", rev = "b167f69692be5f49eb8003788f7f8a499a98b096" } - [profile.release] debug = 0 -incremental = true lto = "fat" opt-level = 3 codegen-units = 1 strip = "symbols" panic = "abort" + +[profile.release-debug] +inherits = "release" +debug = 1 +lto = "thin" +codegen-units = 16 +strip = "none" diff --git a/Dockerfile b/Dockerfile index c84baa1b..6d7c35d4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ WORKDIR /usr/src ENV SCCACHE=0.5.4 ENV RUSTC_WRAPPER=/usr/local/bin/sccache -# Donwload & configure sccache +# Download, configure sccache RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ chmod +x /usr/local/bin/sccache diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 990e4261..d67eba93 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -35,6 +35,11 @@ ARG CUDA_COMPUTE_CAP=80 ARG GIT_SHA ARG DOCKER_LABEL +# Limit parallelism +ARG RAYON_NUM_THREADS +ARG CARGO_BUILD_JOBS +ARG CARGO_BUILD_INCREMENTAL + # sccache specific variables ARG ACTIONS_CACHE_URL ARG ACTIONS_RUNTIME_TOKEN diff --git a/Dockerfile-cuda-all b/Dockerfile-cuda-all index c7ed2e5b..71321705 100644 --- a/Dockerfile-cuda-all +++ b/Dockerfile-cuda-all @@ -40,8 +40,10 @@ ARG ACTIONS_CACHE_URL ARG ACTIONS_RUNTIME_TOKEN ARG SCCACHE_GHA_ENABLED -# limit the number of kernels built at the same time +# Limit parallelism ARG RAYON_NUM_THREADS=4 +ARG CARGO_BUILD_JOBS +ARG CARGO_BUILD_INCREMENTAL WORKDIR /usr/src diff --git a/README.md b/README.md index 30960206..e9306fda 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ Examples of supported models: | N/A | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) | | N/A | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) | | N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) | +| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) | You can explore the list of best performing text embeddings models [here](https://huggingface.co/spaces/mteb/leaderboard). 
@@ -89,9 +90,9 @@ Example of supported sequence classification models: | Task | Model Type | Model ID | |--------------------|-------------|---------------------------------------------------------------------------------------------| -| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | -| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | +| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | ### Docker diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 241504d0..27c5d843 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -11,13 +11,14 @@ use crate::compute_cap::{ compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap, }; use crate::models::{ - BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel, - NomicConfig, PositionEmbeddingType, + BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel, + Model, NomicBertModel, NomicConfig, }; #[cfg(feature = "cuda")] use crate::models::{ - FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel, + FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel, }; +use anyhow::Context; use candle::{DType, Device}; use candle_nn::VarBuilder; use models::BertConfig; @@ -36,6 +37,10 @@ enum Config { XlmRoberta(BertConfig), Camembert(BertConfig), Roberta(BertConfig), + #[serde(rename(deserialize = "jina_bert"))] + JinaBert(JinaConfig), + #[serde(rename(deserialize = "jina_code_bert"))] + JinaCodeBert(JinaCodeConfig), #[serde(rename(deserialize = "distilbert"))] DistilBert(DistilBertConfig), #[serde(rename(deserialize = "nomic_bert"))] @@ -55,9 +60,11 @@ impl CandleBackend { ) -> Result { // Load config let config: String = std::fs::read_to_string(model_path.join("config.json")) - .map_err(|err| BackendError::Start(err.to_string()))?; + .context("Unable to read config file") + .map_err(|err| BackendError::Start(format!("{err:?}")))?; let config: Config = serde_json::from_str(&config) - .map_err(|err| BackendError::Start(format!("Model is not supported: {}", err)))?; + .context("Model is not supported") + .map_err(|err| BackendError::Start(format!("{err:?}")))?; // Get candle device let device = if candle::utils::cuda_is_available() { @@ -72,7 +79,7 @@ impl CandleBackend { ))) } Err(err) => { - tracing::warn!("Could not find a compatible CUDA device on host: {err}"); + tracing::warn!("Could not find a compatible CUDA device on host: {err:?}"); tracing::warn!("Using CPU instead"); Ok(Device::Cpu) } @@ -117,13 +124,16 @@ impl CandleBackend { "`cuda` feature is not enabled".to_string(), )), (Config::Bert(config), Device::Cpu | Device::Metal(_)) => { - if config.position_embedding_type == PositionEmbeddingType::Alibi { - tracing::info!("Starting JinaBertModel model on {:?}", device); - Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) - } else { - tracing::info!("Starting Bert model on {:?}", device); - Ok(Box::new(BertModel::load(vb, &config, model_type).s()?)) - } + tracing::info!("Starting Bert model on {:?}", device); + 
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?)) + } + (Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting JinaBertModel model on {:?}", device); + Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) + } + (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting JinaCodeBertModel model on {:?}", device); + Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?)) } ( Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config), @@ -154,23 +164,43 @@ impl CandleBackend { && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" { if config.position_embedding_type == PositionEmbeddingType::Alibi { - tracing::info!("Starting FlashJinaBertModel model on {:?}", device); - Ok(Box::new( - FlashJinaBertModel::load(vb, &config, model_type).s()?, - )) - } else { tracing::info!("Starting FlashBert model on {:?}", device); Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?)) - } - } else { - if config.position_embedding_type == PositionEmbeddingType::Alibi { - tracing::info!("Starting JinaBertModel model on {:?}", device); - Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) } else { tracing::info!("Starting Bert model on {:?}", device); Ok(Box::new(BertModel::load(vb, &config, model_type).s()?)) } } + #[cfg(feature = "cuda")] + (Config::JinaBert(config), Device::Cuda(_)) => { + if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) + && dtype == DType::F16 + && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi)) + // Allow disabling because of flash attention v1 precision problems + // See: https://github.com/huggingface/text-embeddings-inference/issues/37 + && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" + { + tracing::info!("Starting FlashJinaBertModel model on {:?}", device); + Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,)) + } else { + tracing::info!("Starting JinaBertModel model on {:?}", device); + Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) + } + #[cfg(feature = "cuda")] + (Config::JinaCodeBert(config), Device::Cuda(_)) => { + if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) + && dtype == DType::F16 + && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi)) + // Allow disabling because of flash attention v1 precision problems + // See: https://github.com/huggingface/text-embeddings-inference/issues/37 + && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" + { + tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device); + Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,)) + } else { + tracing::info!("Starting JinaCodeBertModel model on {:?}", device); + Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?)) + } } #[cfg(feature = "cuda")] ( diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs index ed89482e..7f098cfb 100644 --- a/backends/candle/src/models.rs +++ b/backends/candle/src/models.rs @@ -15,6 +15,9 @@ mod flash_bert; #[cfg(feature = "cuda")] mod flash_jina; +#[cfg(feature = "cuda")] +mod flash_jina_code; + #[cfg(feature = "cuda")] mod flash_nomic; @@ -24,7 +27,8 @@ mod flash_distilbert; pub use 
bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; pub use distilbert::{DistilBertConfig, DistilBertModel}; -pub use jina::JinaBertModel; +pub use jina::{JinaConfig, JinaBertModel}; +pub use jina_code::{JinaCodeConfig, JinaCodeBertModel}; pub use nomic::{NomicBertModel, NomicConfig}; use text_embeddings_backend_core::Batch; @@ -34,6 +38,10 @@ pub use flash_bert::FlashBertModel; #[cfg(feature = "cuda")] pub use flash_jina::FlashJinaBertModel; +#[cfg(feature = "cuda")] +pub use flash_jina_code::FlashJinaCodeBertModel; + + #[cfg(feature = "cuda")] pub use flash_nomic::FlashNomicBertModel; diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index ebf2c292..e128252a 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -1,7 +1,8 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; -use crate::models::bert::{BertConfig, PositionEmbeddingType}; +use crate::models::bert::PositionEmbeddingType; +use crate::models::jina::{JinaConfig, BertEmbeddings}; use crate::models::jina::BertEmbeddings; use crate::models::Model; use candle::{DType, Device, IndexOp, Result, Tensor}; diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs new file mode 100644 index 00000000..97ca5fc0 --- /dev/null +++ b/backends/candle/src/models/flash_jina_code.rs @@ -0,0 +1,458 @@ +use crate::alibi::alibi_head_slopes; +use crate::flash_attn::flash_attn_varlen; +use crate::layers::{HiddenAct, LayerNorm, Linear}; +use crate::models::bert::PositionEmbeddingType; +use crate::models::jina::{JinaCodeConfig, BertEmbeddings}; +use crate::models::Model; +use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle_nn::VarBuilder; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; + +struct AlibiBertAttention { + query_linear: Linear, + key_linear: Linear, + value_linear: Linear, + + dense: Linear, + layer_norm_q: LayerNorm, + layer_norm_k: LayerNorm, + layer_norm_out: LayerNorm, + + alibi_slopes: Option, + + num_attention_heads: usize, + attention_head_size: usize, + softmax_scale: f32, + + span: tracing::Span, +} + +impl AlibiBertAttention { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi_slopes: Option) -> Result { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let hidden_size = config.hidden_size; + + let query_weight = vb + .pp("self.query") + .get((all_head_size, hidden_size), "weight")?; + let query_bias = vb.pp("self.query").get(all_head_size, "bias")?; + let key_weight = vb + .pp("self.key") + .get((all_head_size, hidden_size), "weight")?; + let key_bias = vb.pp("self.key").get(all_head_size, "bias")?; + let value_weight = vb + .pp("self.value") + .get((all_head_size, hidden_size), "weight")?; + let value_bias = vb.pp("self.value").get(all_head_size, "bias")?; + + let layer_norm_q = LayerNorm::load( + vp.pp("self").pp("layer_norm_q"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + let layer_norm_k = LayerNorm::load( + vp.pp("self").pp("layer_norm_k"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + let query_linear = Linear::new(query_weight, Some(query_bias), None); + let key_linear = Linear::new(key_weight, Some(key_bias), None); + let value_linear = Linear::new(value_weight, Some(value_bias), None); + + let 
dense_weight = vb + .pp("output") + .pp("dense") + .get((hidden_size, hidden_size), "weight")?; + let dense_bias = vb.pp("output").pp("dense").get(hidden_size, "bias")?; + + let dense = Linear::new(dense_weight, Some(dense_bias), None); + + let layer_norm_out = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32; + + Ok(Self { + query_linear, + key_linear, + value_linear, + dense, + layer_norm_q, + layer_norm_k, + layer_norm_out, + alibi_slopes, + num_attention_heads: config.num_attention_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + cu_seqlens: &Tensor, + max_s: usize, + ) -> Result { + let _enter = self.span.enter(); + + let residual = hidden_states.clone(); + + let query_layer = self.query_linear.forward(hidden_states)?; + let query_layer = self.layer_norm_q.forward(&query_layer, None)?; + + let key_layer = self.key_linear.forward(hidden_states)?; + let key_layer = self.layer_norm_k.forward(&key_layer, None)?; + + let value_layer = self.value_linear.forward(hidden_states)?; + + let mut new_qkv_shape = qkv.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads); + new_qkv_shape.push(self.attention_head_size); + + let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + + let attention = flash_attn_varlen( + query_layer, + key_layer, + value_layer, + self.alibi_slopes.as_ref(), + cu_seqlens, + cu_seqlens, + max_s, + max_s, + self.softmax_scale, + false, + )?; + let attention = attention.flatten_from(candle::D::Minus2)?; + + let hidden_states = self.dense.forward(&attention)?; + let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?; + + Ok(hidden_states) + } +} + +struct JinaBertLayer { + attention: AlibiBertAttention, + up_gated_layer: Linear, + down_layer: Linear, + layer_norm_1: LayerNorm, + layer_norm_2: LayerNorm, + act: HiddenAct, + + intermediate_size: usize, + + span: tracing::Span, +} + +impl JinaBertLayer { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi: Option) -> Result { + let attention = AlibiBertAttention::load(vb.pp("attention"), config, alibi)?; + + let up_gated_weight = vb + .pp("mlp") + .pp("up_gated_layer") + .get((config.intermediate_size * 2, config.hidden_size), "weight")?; + let up_gated_layer = Linear::new(up_gated_weight, None, None); + + let down_weight = vb + .pp("mlp") + .pp("down_layer") + .get((config.hidden_size, config.intermediate_size), "weight")?; + let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?; + let down_layer = Linear::new(down_weight, Some(down_bias), None); + + let layer_norm_1 = LayerNorm::load( + vb.pp("layer_norm_1"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + let layer_norm_2 = LayerNorm::load( + vb.pp("layer_norm_2"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + Ok(Self { + attention, + up_gated_layer, + down_layer, + layer_norm_1, + layer_norm_2, + act: config.hidden_act.clone(), + intermediate_size: config.intermediate_size, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + 
cu_seqlens: &Tensor, + max_s: usize, + ) -> Result { + let _enter = self.span.enter(); + + let residual = hidden_states.clone(); + let hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)?; + + // Pre-MLP LayerNorm + let hidden_states = self.layer_norm_1.forward(&hidden_states, Some(&residual))?; + + // MLP block + let residual = hidden_states.clone(); + let hidden_states = self.up_gated_layer.forward(&hidden_states)?; + let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?; + let gated = hidden_states.i((.., .., self.intermediate_size..))?; + let gated = match self.act { + HiddenAct::Gelu => gated.gelu(), + HiddenAct::Relu => gated.relu(), + HiddenAct::Swiglu => gated.silu(), + }?; + let hidden_states = (non_gated * gated)?; + let hidden_states = self.down_layer.forward(&hidden_states)?; + + // Post-MLP LayerNorm + let hidden_states = self.layer_norm_2.forward(&hidden_states, Some(&residual))?; + + Ok(hidden_states) + } +} + +struct BertEncoder { + layers: Vec, + span: tracing::Span, +} + +impl BertEncoder { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi: Option) -> Result { + let layers = (0..config.num_hidden_layers) + .map(|index| { + JinaBertLayer::load(vb.pp(format!("layer.{index}")), config, alibi.clone()) + }) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + + Ok(BertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.clone(); + + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, cu_seqlens, max_s)? + } + + Ok(hidden_states) + } +} + +pub struct FlashJinaCodeBertModel { + embeddings: BertEmbeddings, + encoder: BertEncoder, + pool: Pool, + pub device: Device, + + span: tracing::Span, +} + +impl FlashJinaCodeBertModel { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig, model_type: ModelType) -> Result { + let alibi = match config.position_embedding_type { + PositionEmbeddingType::Alibi => { + let alibi_slopes = alibi_head_slopes(config.num_attention_heads); + Some( + Tensor::from_vec(alibi_slopes, config.num_attention_heads, vb.device())? 
+ .to_dtype(DType::F32)?, + ) + } + PositionEmbeddingType::Absolute => None, + }; + + match vb.device() { + Device::Cuda(_) => {} + _ => candle::bail!("FlashJinaCodeBertModel requires Cuda"), + } + + if vb.dtype() != DType::F16 { + candle::bail!("FlashJinaCodeBertModel requires DType::F16") + } + + let pool = match model_type { + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for Jina Code") + } + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Jina Code") + } + pool + } + }; + + let (embeddings, encoder) = match ( + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config, alibi.clone()), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp("bert.embeddings"), config), + BertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } + }; + + Ok(Self { + embeddings, + encoder, + pool, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = batch.len(); + let shape = batch.input_ids.len(); + + // Create Cuda tensors + let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; + let type_ids = Tensor::from_vec(batch.token_type_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + let cu_seqlens = Tensor::from_vec( + batch.cumulative_seq_lengths.clone(), + batch_size + 1, + &self.device, + )?; + + let embedding_output = self + .embeddings + .forward(&input_ids, &type_ids, &position_ids)?; + + let outputs = + self.encoder + .forward(&embedding_output, &cu_seqlens, batch.max_length as usize)?; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + match self.pool { + // CLS pooling + Pool::Cls => { + if batch_size > 1 { + // Get the indices of the cls tokens from cu_seqlens + let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + cls_indices = cls_indices.index_select(&pooled_indices, 0)? + } + + // Select cls tokens + Some(outputs.index_select(&cls_indices, 0)?) + } else { + Some(outputs.i(0)?) + } + } + // Mean pooling + Pool::Mean => { + if batch_size > 1 { + // for each request that requires pooling + let results: Result> = batch + .pooled_indices + .into_iter() + .map(|i| { + let i = i as usize; + let start = batch.cumulative_seq_lengths[i]; + let len = batch.cumulative_seq_lengths[i + 1] - start; + + // Mean + let embeddings = outputs.narrow(0, start as usize, len as usize)?; + embeddings.sum_keepdim(0)? / (len as f64) + }) + .collect(); + + // Concatenate all results + Some(Tensor::cat(&results?, 0)?) + } else { + Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) 
+ } + } + Pool::Splade => { + unreachable!(); + } + } + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + if batch_size > 1 && has_pooling_requests { + // Create indexing vector for the embeddings + let mut final_indices: Vec = Vec::with_capacity(shape); + for i in batch.raw_indices.into_iter() { + let i = i as usize; + // Get start/end token index of this specific member of the batch + let start = batch.cumulative_seq_lengths[i]; + let end = batch.cumulative_seq_lengths[i + 1]; + + for j in start..end { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; + + Ok((pooled_embeddings, raw_embeddings)) + } +} + +impl Model for FlashJinaCodeBertModel { + fn is_padded(&self) -> bool { + false + } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } +} \ No newline at end of file diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index 97bc7f96..3f5d5916 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -1,11 +1,36 @@ use crate::alibi::build_alibi_tensor; use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; use crate::models::Model; -use crate::models::{BertConfig, PositionEmbeddingType}; +use crate::models::PositionEmbeddingType; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; +use serde::Deserialize; +use std::collections::HashMap; use text_embeddings_backend_core::{Batch, ModelType, Pool}; +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct JinaConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: HiddenAct, + pub hidden_dropout_prob: f64, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub initializer_range: f64, + pub layer_norm_eps: f64, + pub pad_token_id: usize, + #[serde(default)] + pub position_embedding_type: PositionEmbeddingType, + #[serde(default)] + pub use_cache: bool, + pub classifier_dropout: Option, + pub id2label: Option>, +} + + #[derive(Debug)] pub struct BertEmbeddings { word_embeddings: Embedding, @@ -16,7 +41,7 @@ pub struct BertEmbeddings { } impl BertEmbeddings { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { + pub fn load(vb: VarBuilder, config: &JinaConfig) -> Result { let position_embeddings = if config.position_embedding_type == PositionEmbeddingType::Absolute { Some(Embedding::new( @@ -88,7 +113,7 @@ struct BertAttention { } impl BertAttention { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { + pub fn load(vb: VarBuilder, config: &JinaConfig) -> Result { let attention_head_size = config.hidden_size / config.num_attention_heads; let all_head_size = config.num_attention_heads * attention_head_size; let hidden_size = config.hidden_size; @@ -249,7 +274,7 @@ struct JinaBertLayer { } impl JinaBertLayer { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { + pub fn load(vb: VarBuilder, config: &JinaConfig) -> Result { let attention = BertAttention::load(vb.pp("attention"), config)?; let gated_weight = vb @@ -316,7 +341,7 @@ struct 
BertEncoder { } impl BertEncoder { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { + pub fn load(vb: VarBuilder, config: &JinaConfig) -> Result { let layers = (0..config.num_hidden_layers) .map(|index| JinaBertLayer::load(vb.pp(format!("layer.{index}")), config)) .collect::>>()?; @@ -354,7 +379,7 @@ pub struct JinaBertModel { } impl JinaBertModel { - pub fn load(vb: VarBuilder, config: &BertConfig, model_type: ModelType) -> Result { + pub fn load(vb: VarBuilder, config: &JinaConfig, model_type: ModelType) -> Result { let alibi = match config.position_embedding_type { PositionEmbeddingType::Alibi => Some(build_alibi_tensor( config.max_position_embeddings, diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs new file mode 100644 index 00000000..cb4084d7 --- /dev/null +++ b/backends/candle/src/models/jina_code.rs @@ -0,0 +1,728 @@ +use crate::alibi::build_alibi_tensor; +use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; +use crate::models::Model; +use crate::models::PositionEmbeddingType; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Embedding, VarBuilder}; +use serde::Deserialize; +use std::collections::HashMap; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct JinaCodeConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: HiddenAct, + pub hidden_dropout_prob: f64, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub initializer_range: f64, + pub layer_norm_eps: f64, + pub pad_token_id: usize, + #[serde(default)] + pub position_embedding_type: PositionEmbeddingType, + #[serde(default)] + pub use_cache: bool, + pub classifier_dropout: Option, + pub id2label: Option>, +} + + +#[derive(Debug)] +pub struct BertEmbeddings { + word_embeddings: Embedding, + token_type_embeddings: Embedding, + position_embeddings: Option, + layer_norm: LayerNorm, + span: tracing::Span, +} + +impl BertEmbeddings { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig) -> Result { + let position_embeddings = + if config.position_embedding_type == PositionEmbeddingType::Absolute { + Some(Embedding::new( + vb.pp("position_embeddings").get( + (config.max_position_embeddings, config.hidden_size), + "weight", + )?, + config.hidden_size, + )) + } else { + None + }; + + Ok(Self { + word_embeddings: Embedding::new( + vb.pp("word_embeddings") + .get((config.vocab_size, config.hidden_size), "weight")?, + config.hidden_size, + ), + token_type_embeddings: Embedding::new( + vb.pp("token_type_embeddings") + .get((config.type_vocab_size, config.hidden_size), "weight")?, + config.hidden_size, + ), + position_embeddings, + layer_norm: LayerNorm::load( + vb.pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?, + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: &Tensor, + position_ids: &Tensor, + ) -> Result { + let _enter = self.span.enter(); + + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; + + if let Some(position_embeddings) = &self.position_embeddings { + let position_embeddings = position_embeddings.forward(position_ids)?; + let embeddings = input_embeddings.add(&token_type_embeddings)?; + 
self.layer_norm + .forward(&embeddings, Some(&position_embeddings)) + } else { + self.layer_norm + .forward(&input_embeddings, Some(&token_type_embeddings)) + } + } +} + +struct BertAttention { + query_linear: Linear, + key_linear: Linear, + value_linear: Linear, + + dense: Linear, + layer_norm_q: LayerNorm, + layer_norm_k: LayerNorm, + layer_norm_out: LayerNorm, + + num_attention_heads: usize, + attention_head_size: usize, + softmax_scale: f64, + + span: tracing::Span, +} + +impl BertAttention { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig) -> Result { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let hidden_size = config.hidden_size; + + let query_weight = vb + .pp("self.query") + .get((all_head_size, hidden_size), "weight")?; + let query_bias = vb.pp("self.query").get(all_head_size, "bias")?; + + let key_weight = vb + .pp("self.key") + .get((all_head_size, hidden_size), "weight")?; + let key_bias = vb.pp("self.key").get(all_head_size, "bias")?; + + let value_weight = vb + .pp("self.value") + .get((all_head_size, hidden_size), "weight")?; + let value_bias = vb.pp("self.value").get(all_head_size, "bias")?; + + let layer_norm_q = LayerNorm::load( + vb.pp("self").pp("layer_norm_q"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + let layer_norm_k = LayerNorm::load( + vb.pp("self").pp("layer_norm_k"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + let query_linear = Linear::new(query_weight, Some(query_bias), None); + let key_linear = Linear::new(key_weight, Some(key_bias), None); + let value_linear = Linear::new(value_weight, Some(value_bias), None); + + let dense_weight = vb + .pp("output") + .pp("dense") + .get((hidden_size, hidden_size), "weight")?; + let dense_bias = vb.pp("output").pp("dense").get(hidden_size, "bias")?; + + let dense = Linear::new(dense_weight, Some(dense_bias), None); + + let layer_norm_out = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + let softmax_scale = 1. 
/ (attention_head_size as f64).sqrt(); + + Ok(Self { + query_linear, + key_linear, + value_linear, + dense, + layer_norm_q, + layer_norm_k, + layer_norm_out, + num_attention_heads: config.num_attention_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + let device = hidden_states.device(); + let residual = hidden_states.clone(); + + let query_layer = self.query_linear.forward(hidden_states)?; + let query_layer = self.layer_norm_q.forward(&query_layer, None)?; + + let key_layer = self.key_linear.forward(hidden_states)?; + let key_layer = self.layer_norm_k.forward(&key_layer, None)?; + + let value_layer = self.value_linear.forward(hidden_states)?; + + let mut new_qkv_shape = query_layer.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads); + new_qkv_shape.push(self.attention_head_size); + + let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + + #[allow(unused_variables)] + let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = + (device, get_cublas_lt_wrapper()) + { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?; + let key_layer = key_layer.flatten(0, 1)?; + let query_layer = query_layer.flatten(0, 1)?; + let value_layer = value_layer.flatten(0, 1)?; + let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?; + + // If attention_bias is set, we fuse the add by giving it as the output matrix + // and setting beta to 1.0 + let beta = match attention_bias.is_some() { + true => Some(1.0), + false => None, + }; + + // Batch matrix multiplication + // Fuse softmax scale and attention_bias add + let attention_scores = cublaslt.batch_matmul( + &key_layer, + &query_layer, + attention_bias.as_ref(), + Some(self.softmax_scale as f32), + beta, + None, + None, + )?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &value_layer.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&query_layer), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape(( + batch_size, + self.num_attention_heads, + seq_len, + self.attention_head_size, + )) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let mut attention_scores = (attention_scores * self.softmax_scale)?; + + if let Some(attention_bias) = attention_bias { + attention_scores = attention_scores.add(attention_bias)?; + } + + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + attention_probs.matmul(&value_layer.contiguous()?) 
+ }?; + + let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; + + let hidden_states = self.dense.forward(&context_layer)?; + let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?; + + Ok(hidden_states) + } +} + +struct JinaCodeBertLayer { + attention: BertAttention, + up_gated_layer: Linear, + down_layer: Linear, + layer_norm_1: LayerNorm, + layer_norm_2: LayerNorm, + act: HiddenAct, + + intermediate_size: usize, + + span: tracing::Span, +} + +impl JinaCodeBertLayer { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig) -> Result { + let attention = BertAttention::load(vb.pp("attention"), config)?; + + let up_gated_weight = vb + .pp("mlp") + .pp("up_gated_layer") + .get((config.intermediate_size * 2, config.hidden_size), "weight")?; + let up_gated_layer = Linear::new(up_gated_weight, None, None); + + let down_weight = vb + .pp("mlp") + .pp("down_layer") + .get((config.hidden_size, config.intermediate_size), "weight")?; + let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?; + let down_layer = Linear::new(down_weight, Some(down_bias), None); + + let layer_norm_1 = LayerNorm::load( + vb.pp("layer_norm_1"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + let layer_norm_2 = LayerNorm::load( + vb.pp("layer_norm_2"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; + + Ok(Self { + attention, + up_gated_layer, + down_layer, + layer_norm_1, + layer_norm_2, + act: config.hidden_act.clone(), + intermediate_size: config.intermediate_size, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + attention_bias: Option<&Tensor>, + ) -> Result { + let _enter = self.span.enter(); + + // Pre-Norm + let residual = hidden_states.clone(); + + // Self-Attention block + let hidden_states = self.attention.forward(hidden_states, attention_bias)?; + + // Pre-MLP LayerNorm + let hidden_states = self.layer_norm_1.forward(&hidden_states, Some(&residual))?; + + // MLP block + let residual = hidden_states.clone(); + let hidden_states = self.up_gated_layer.forward(&hidden_states)?; + let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?; + let gated = hidden_states.i((.., .., self.intermediate_size..))?; + let gated = match self.act { + HiddenAct::Gelu => gated.gelu(), + HiddenAct::Relu => gated.relu(), + HiddenAct::Swiglu => gated.silu(), + }?; + let hidden_states = (non_gated * gated)?; + let hidden_states = self.down_layer.forward(&hidden_states)?; + + // Post-MLP LayerNorm + let hidden_states = self.layer_norm_2.forward(&hidden_states, Some(&residual))?; + + Ok(hidden_states) + } +} + +struct BertEncoder { + layers: Vec, + span: tracing::Span, +} + +impl BertEncoder { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig) -> Result { + let layers = (0..config.num_hidden_layers) + .map(|index| JinaCodeBertLayer::load(vb.pp(format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + + Ok(BertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.clone(); + + // Use a loop rather than a fold as it's easier to modify when adding debug/... 
+ for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, attention_bias)?; + } + + Ok(hidden_states) + } +} + +pub struct JinaCodeBertModel { + embeddings: BertEmbeddings, + encoder: BertEncoder, + pool: Pool, + alibi: Option, + + num_attention_heads: usize, + + device: Device, + dtype: DType, + + span: tracing::Span, +} + +impl JinaCodeBertModel { + pub fn load(vb: VarBuilder, config: &JinaCodeConfig, model_type: ModelType) -> Result { + let alibi = match config.position_embedding_type { + PositionEmbeddingType::Alibi => Some(build_alibi_tensor( + config.max_position_embeddings, + config.num_attention_heads, + vb.device(), + vb.dtype(), + )?), + PositionEmbeddingType::Absolute => None, + }; + + let pool = match model_type { + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for Jina") + } + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Jina") + } + pool + } + }; + + let (embeddings, encoder) = match ( + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp("bert.embeddings"), config), + BertEncoder::load(vb.pp("bert.encoder"), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } + }; + + Ok(Self { + embeddings, + encoder, + pool, + alibi, + num_attention_heads: config.num_attention_heads, + device: vb.device().clone(), + dtype: vb.dtype(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + + let shape = (batch_size, max_length); + + let (input_ids, type_ids, position_ids, input_lengths, attention_bias, attention_mask) = + if batch_size > 1 { + // Prepare padded batch + let elems = batch_size * max_length; + + let mut input_ids = Vec::with_capacity(elems); + let mut type_ids = Vec::with_capacity(elems); + let mut position_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + let mut attention_bias = Vec::with_capacity(elems); + let mut input_lengths = Vec::with_capacity(batch_size); + // Bool to know if we need to use the attention mask + let mut masking = false; + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as u32; + input_lengths.push(seq_length as f32); + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j]); + type_ids.push(batch.token_type_ids[j]); + position_ids.push(batch.position_ids[j]); + attention_mask.push(1.0_f32); + attention_bias.push(0.0); + } + + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + // Set bool to use attention mask + masking = true; + for _ in 0..padding { + input_ids.push(0); + type_ids.push(0); + position_ids.push(0); + attention_mask.push(0.0_f32); + attention_bias.push(f32::NEG_INFINITY); + } + } + } + + let (attention_bias, attention_mask) = match masking { + true => { + // We only need the mask if we use mean pooling + // For CLS pooling, the bias is enough + let attention_mask = if self.pool == Pool::Mean { + let attention_mask = Tensor::from_vec( + attention_mask, + 
+                                (batch_size, max_length, 1),
+                                &self.device,
+                            )?
+                            .to_dtype(self.dtype)?;
+
+                            Some(attention_mask)
+                        } else {
+                            None
+                        };
+
+                        let attention_bias = Tensor::from_vec(
+                            attention_bias,
+                            (batch_size, 1, 1, max_length),
+                            &self.device,
+                        )?
+                        .to_dtype(self.dtype)?;
+
+                        // Broadcast once instead of at every layer
+                        let mut attention_bias = attention_bias.broadcast_as((
+                            batch_size,
+                            self.num_attention_heads,
+                            max_length,
+                            max_length,
+                        ))?;
+
+                        // Add alibi tensor
+                        if let Some(alibi) = &self.alibi {
+                            let alibi = alibi
+                                .i((.., .., 0..max_length, 0..max_length))?
+                                .broadcast_as((
+                                    batch_size,
+                                    self.num_attention_heads,
+                                    max_length,
+                                    max_length,
+                                ))?;
+
+                            attention_bias = attention_bias.add(&alibi)?;
+                        }
+
+                        (Some(attention_bias.contiguous()?), attention_mask)
+                    }
+                    false => {
+                        if let Some(alibi) = &self.alibi {
+                            (
+                                Some(
+                                    alibi
+                                        .i((.., .., 0..max_length, 0..max_length))?
+                                        .broadcast_as((
+                                            batch_size,
+                                            self.num_attention_heads,
+                                            max_length,
+                                            max_length,
+                                        ))?
+                                        .contiguous()?,
+                                ),
+                                None,
+                            )
+                        } else {
+                            (None, None)
+                        }
+                    }
+                };
+
+                (
+                    input_ids,
+                    type_ids,
+                    position_ids,
+                    input_lengths,
+                    attention_bias,
+                    attention_mask,
+                )
+            } else {
+                let attention_bias = if let Some(alibi) = &self.alibi {
+                    Some(
+                        alibi
+                            .i((.., .., 0..max_length, 0..max_length))?
+                            .contiguous()?,
+                    )
+                } else {
+                    None
+                };
+
+                (
+                    batch.input_ids,
+                    batch.token_type_ids,
+                    batch.position_ids,
+                    vec![batch.max_length as f32],
+                    attention_bias,
+                    None,
+                )
+            };
+
+        // Create CPU tensors
+        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+        let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
+        let input_lengths =
+            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
+
+        let embedding_output = self
+            .embeddings
+            .forward(&input_ids, &type_ids, &position_ids)?;
+
+        let outputs = self
+            .encoder
+            .forward(&embedding_output, attention_bias.as_ref())?;
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            let pooled_indices_length = batch.pooled_indices.len();
+            let mut outputs = outputs.clone();
+
+            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
+            let pooled_indices = if has_raw_requests {
+                let pooled_indices =
+                    Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;
+
+                // Select values in the batch
+                outputs = outputs.index_select(&pooled_indices, 0)?;
+                Some(pooled_indices)
+            } else {
+                None
+            };
+
+            let pooled_embeddings = match self.pool {
+                // CLS pooling
+                Pool::Cls => outputs.i((.., 0))?,
+                // Mean pooling
+                Pool::Mean => {
+                    if let Some(ref attention_mask) = attention_mask {
+                        let mut attention_mask = attention_mask.clone();
+
+                        if let Some(pooled_indices) = pooled_indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                        };
+
+                        // Mask padded values
+                        outputs = outputs.broadcast_mul(&attention_mask)?;
+                    }
+
+                    (outputs.sum(1)?.broadcast_div(&input_lengths))?
+                }
+                Pool::Splade => unreachable!(),
+            };
+            Some(pooled_embeddings)
+        } else {
+            None
+        };
+
+        let raw_embeddings = if has_raw_requests {
+            // Reshape outputs
+            let (b, l, h) = outputs.shape().dims3()?;
+            let outputs = outputs.reshape((b * l, h))?;
+
+            // We need to remove the padding tokens only if batch_size > 1 and some
+            // members of the batch require pooling,
+            // or if batch_size > 1 and the members of the batch have different lengths
+            if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
+                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);
+
+                for i in batch.raw_indices.into_iter() {
+                    let start = i * batch.max_length;
+                    let i = i as usize;
+                    let length =
+                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
+
+                    for j in start..start + length {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                // Select the tokens with final indices
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
+            }
+        } else {
+            None
+        };
+
+        Ok((pooled_embeddings, raw_embeddings))
+    }
+}
+
+impl Model for JinaCodeBertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        self.forward(batch)
+    }
+}
diff --git a/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_batch.snap b/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_batch.snap new file mode 100644 index 00000000..3441cb71 --- /dev/null +++ b/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_batch.snap @@ -0,0 +1,772 @@ +--- +source: backends/candle/tests/test_flash_jina_code.rs +expression: embeddings_batch +--- +- - -0.011624558 + - -0.0016926473 + - -0.006434922 + - -0.009909191 + - 0.027506275 + - 0.022786874 + - -0.059169117 + - -0.047887173 + - 0.038912594 + - 0.025214802 + - -0.014350341 + - 0.039117776 + - 0.03485464 + - 0.05296896 + - 0.034907352 + - 0.0013971328 + - 0.0019907136 + - 0.052471623 + - -0.014562107 + - 0.00030206257 + - -0.02770458 + - 0.02685756 + - -0.015385578 + - -0.012668428 + - -0.025259107 + - 0.00836893 + - 0.028925523 + - -0.035507426 + - -0.04220648 + - 0.047328673 + - -0.083232224 + - -0.02608385 + - -0.012809777 + - -0.0140402755 + - -0.013649549 + - -0.013641793 + - 0.02880054 + - 0.04465023 + - -0.0274121 + - -0.03170939 + - -0.047180377 + - -0.0349574 + - -0.037762504 + - 0.024104225 + - -0.039361924 + - -0.024559166 + - -0.0138650155 + - -0.049862508 + - -0.0507675 + - -0.07775355 + - -0.055519626 + - -0.025151063 + - 0.040781654 + - 0.0039354665 + - 0.015940087 + - -0.022214677 + - -0.013484792 + - -0.00070730236 + - -0.009409981 + - 0.0016682384 + - 0.016079267 + - -0.032368172 + - 0.024572799 + - 0.026780155 + - -0.025954 + - 0.021282032 + - -0.047395118 + - 0.044386413 + - -0.023020437 + - 0.0056320634 + - -0.017416032 + - -0.024118245 + - 0.05878816 + - 0.00059366866 + - -0.018553924 + - 0.003762008 + - 0.0035026476 + - -0.0077498616 + - -0.004517831 + - 0.008116448 + - -0.010922055 + - 0.037391223 + - -0.03318112 + - -0.07876148 + - -0.0018068684 + - -0.0484656 + - -0.038104814 + - -0.0070756334 + - -0.015427567 + - -0.04499487 + - 0.008910639 + - 0.029193114 + - 0.032674707 + - 0.0305758 + - -0.017541302 + - 0.028164856 + - -0.0344121 + - 0.016969312 + - 0.04108575 + - -0.054463603 + - -0.005427823 + - 
-0.008252881 + - -0.024992533 + - 0.048412405 + - -0.05769221 + - -0.08227673 + - -0.0523458 + - 0.025244992 + - -0.016289622 + - -0.049253095 + - -0.03999235 + - 0.05642755 + - -0.010015731 + - 0.04892553 + - -0.02504831 + - -0.014305144 + - 0.04907702 + - -0.0096177375 + - -0.0854665 + - 0.017617436 + - -0.014481601 + - 0.046187755 + - -0.030009247 + - -0.049553443 + - -0.006797771 + - -0.034234725 + - 0.06525787 + - 0.017298417 + - -0.018988462 + - -0.08520813 + - -0.011840521 + - -0.042942975 + - -0.056341827 + - 0.08071612 + - 0.028146833 + - -0.0141261155 + - 0.04001012 + - 0.018236378 + - 0.01583886 + - 0.0058086845 + - 0.0005290361 + - 0.0045259795 + - 0.024873685 + - 0.018050008 + - 0.02440984 + - -0.057676647 + - -0.022263395 + - -0.033016473 + - 0.014353932 + - -0.010189813 + - 0.018424626 + - -0.018238619 + - 0.07698089 + - -0.015595315 + - 0.007074499 + - 0.027340613 + - -0.02035102 + - -0.073149584 + - -0.0061021335 + - 0.0209061 + - -0.024985746 + - -0.020401739 + - -0.0048702923 + - 0.0444933 + - -0.02158648 + - 0.010312202 + - 0.009883107 + - 0.022790415 + - -0.01893943 + - 0.017615119 + - 0.023000762 + - -0.051276624 + - 0.019835759 + - 0.016675396 + - 0.021532865 + - 0.023097536 + - 0.01091247 + - 0.007642041 + - 0.0385409 + - 0.044193454 + - 0.003319116 + - 0.019355802 + - -0.055695307 + - -0.07680676 + - -0.035279598 + - -0.0070618573 + - -0.01408385 + - 0.03107026 + - 0.011547187 + - 0.04217532 + - 0.048034772 + - -0.014330861 + - -0.0007901662 + - 0.009103947 + - 0.011091706 + - -0.008960474 + - -0.009301379 + - 0.032424763 + - 0.09180531 + - 0.015491586 + - 0.031114861 + - 0.013289549 + - 0.03616163 + - 0.031084707 + - -0.021437835 + - 0.011905716 + - -0.104623 + - 0.033417992 + - -0.04683295 + - -0.05382028 + - 0.025979578 + - -0.003795323 + - 0.094473585 + - 0.09126942 + - 0.067700386 + - -0.002935823 + - -0.031604428 + - 0.0045677535 + - 0.014042491 + - -0.02969786 + - -0.012263599 + - -0.05807616 + - -0.0107510965 + - 0.022099612 + - 0.02685798 + - -0.048912797 + - 0.020464186 + - 0.011517319 + - -0.0016606319 + - -0.033255223 + - 0.02488959 + - -0.043240122 + - 0.00013934783 + - 0.037608404 + - 0.0748658 + - 0.061115693 + - -0.01670558 + - 0.037827995 + - -0.004162286 + - -0.00011341072 + - -0.031436242 + - -0.040532686 + - -0.0096016815 + - -0.065639205 + - -0.030070387 + - 0.05715196 + - 0.024338327 + - -0.008972259 + - -0.013220871 + - -0.07115904 + - -0.03446362 + - -0.010893238 + - 0.011950567 + - 0.0143028535 + - 0.0022963681 + - -0.060428943 + - 0.047038406 + - -0.017493663 + - 0.0208505 + - -0.010212147 + - -0.03337894 + - 0.026969628 + - -0.019734934 + - 0.0922163 + - -0.056534916 + - 0.017255465 + - 0.021692138 + - 0.023583038 + - 0.017224979 + - -0.0689936 + - 0.23665377 + - 0.04984247 + - -0.0039337534 + - -0.010740561 + - 0.00313393 + - -0.02665381 + - 0.013037893 + - 0.020565815 + - -0.002266256 + - 0.03748113 + - 0.018353125 + - -0.048318923 + - -0.06522075 + - -0.021460611 + - -0.029665926 + - -0.025507431 + - 0.0033435076 + - -0.0071087605 + - 0.001970083 + - -0.0025331723 + - 0.061776515 + - 0.017747495 + - 0.04396135 + - 0.0011617352 + - 0.02509327 + - 0.039078914 + - -0.03762337 + - 0.021575885 + - 0.015548619 + - 0.043990802 + - -0.024633111 + - -0.0013324996 + - 0.063795656 + - -0.02035371 + - -0.0440217 + - -0.033049908 + - 0.031034056 + - -0.004495562 + - -0.0044647786 + - -0.033148468 + - 0.0025072312 + - 0.009637453 + - 0.04357035 + - 0.044546504 + - -0.08865154 + - -0.08487347 + - 0.00205395 + - 0.0060572727 + - 0.023767816 + - 
0.051007573 + - 0.0057035745 + - 0.040539596 + - -0.035988905 + - -0.01824621 + - 0.007887274 + - -0.052848075 + - 0.08228733 + - 0.0015825987 + - 0.0004136183 + - 0.018108545 + - 0.040081892 + - -0.035405345 + - 0.050933696 + - 0.036154125 + - -0.03947257 + - -0.0018412384 + - -0.005589829 + - 0.03620321 + - 0.012144826 + - -0.008619581 + - -0.0043279063 + - 0.016455552 + - -0.015757388 + - 0.047043085 + - 0.08675011 + - -0.009986743 + - 0.022379123 + - 0.009470605 + - 0.034120724 + - 0.009922824 + - -0.014435422 + - 0.017998574 + - -0.03849387 + - 0.042357396 + - -0.010053916 + - 0.057034835 + - -0.027412737 + - 0.017308975 + - 0.02185228 + - -0.048017155 + - -0.025138885 + - 0.016482655 + - -0.01756698 + - 0.031146016 + - 0.021930695 + - -0.00075341173 + - -0.015085438 + - 0.0146785155 + - 0.028547939 + - -0.036677707 + - -0.022699077 + - -0.045135103 + - -0.02802744 + - 0.071454674 + - -0.009201392 + - -0.051717956 + - -0.019421624 + - 0.0034821872 + - 0.06667364 + - 0.09145379 + - -0.068826884 + - 0.053568542 + - 0.01160062 + - -0.025829546 + - 0.04487214 + - 0.030954553 + - -0.022543794 + - -0.009475118 + - -0.014623143 + - -0.02070793 + - -0.02656788 + - 0.009701591 + - -0.025120718 + - -0.018472325 + - -0.027967019 + - -0.021122226 + - -0.009891716 + - 0.018696679 + - -0.068876855 + - -0.023419108 + - -0.025495855 + - 0.016256742 + - 0.00064859784 + - 0.037749656 + - -0.046321914 + - 0.065936595 + - -0.008921658 + - 0.07497468 + - -0.05094385 + - -0.052860104 + - 0.022196138 + - -0.0140462285 + - -0.03562305 + - 0.024858234 + - 0.021310989 + - 0.014657512 + - 0.07767391 + - -0.027777392 + - 0.057577316 + - 0.055513144 + - 0.06322926 + - -0.026312957 + - 0.010970987 + - -0.0057475767 + - 0.0028267235 + - 0.051367335 + - -0.022320578 + - 0.009050165 + - 0.016952222 + - -0.032026373 + - -0.074292615 + - -0.02315535 + - 0.025375988 + - -0.041241057 + - -0.032157563 + - -0.015576387 + - -0.015834223 + - 0.02224181 + - 0.017586967 + - -0.029070066 + - 0.0030065721 + - -0.051857695 + - 0.04008828 + - -0.007960872 + - -0.061745025 + - -0.00086617953 + - 0.026723113 + - 0.008719714 + - -0.049826868 + - -0.0046574236 + - 0.018954279 + - -0.007935451 + - 0.053987946 + - 0.022795292 + - 0.029722994 + - -0.010146585 + - -0.019956842 + - -0.014686722 + - 0.01708331 + - 0.020001508 + - 0.016564105 + - 0.011379248 + - -0.011843253 + - 0.04056168 + - 0.004384286 + - -0.049596023 + - 0.02674251 + - 0.0076475106 + - -0.0064563937 + - -0.014233138 + - 0.014224383 + - 0.0052741244 + - 0.049964864 + - -0.016286546 + - -0.001200327 + - 0.041904222 + - -0.0010395087 + - -0.05105399 + - 0.022099879 + - -0.019455278 + - 0.009444127 + - -0.081325725 + - 0.015994828 + - 0.0010952728 + - 0.0008373874 + - -0.0016424303 + - -0.05096469 + - 0.045976803 + - 0.024695056 + - -0.031656373 + - -0.0013138534 + - -0.0060524447 + - 0.060276203 + - -0.016745795 + - -0.029930653 + - -0.0222771 + - 0.012314711 + - 0.038991332 + - -0.006665343 + - 0.041694533 + - 0.0076992502 + - -0.014353178 + - 0.0025135442 + - -0.0498445 + - 0.020322764 + - 0.0575802 + - -0.006096128 + - -0.010841882 + - -0.024337102 + - 0.021975596 + - 0.0064031687 + - -0.026746146 + - 0.03455729 + - -0.011909055 + - 0.016994143 + - 0.026053395 + - -0.020393625 + - 0.06403403 + - 0.042590734 + - 0.009193913 + - 0.04016698 + - 0.028304791 + - 0.022147119 + - -0.030121539 + - -0.0037334429 + - 0.03235819 + - -0.020825844 + - 0.0009766509 + - -0.012216568 + - 0.07944978 + - 0.04177374 + - -0.008281654 + - -0.008908983 + - 0.04799388 + - -0.012743454 + - 
-0.00076762337 + - 0.012673029 + - -0.018283572 + - 0.022068778 + - -0.009605337 + - 0.0030652087 + - -0.036517244 + - -0.006263211 + - 0.0194632 + - 0.009333852 + - 0.02350168 + - -0.0008530139 + - 0.018934859 + - -0.013986168 + - -0.010833636 + - -0.011189203 + - 0.011567913 + - -0.07253544 + - -0.005748846 + - 0.11930293 + - 0.060044624 + - -0.023167728 + - -0.015781552 + - -0.010494401 + - 0.01930528 + - -0.039608266 + - -0.073587865 + - 0.0019034932 + - -0.0013838339 + - 0.026257295 + - 0.004433007 + - -0.051545423 + - -0.033456888 + - -0.06401291 + - -0.08664347 + - 0.010781564 + - 0.0408775 + - 0.021475399 + - 0.017633006 + - -0.02186024 + - 0.047795497 + - -0.006370007 + - 0.01792626 + - 0.005195737 + - 0.0026206016 + - -0.008816542 + - 0.009266863 + - -0.018453414 + - 0.0040575014 + - 0.0047053 + - -0.0197809 + - -0.01580334 + - 0.007821501 + - -0.06296649 + - -0.06274416 + - -0.0031381177 + - -0.024228694 + - -0.002459634 + - 0.00034323192 + - -0.02430543 + - -0.029262288 + - -0.04606642 + - -0.013138838 + - 0.040473 + - 0.012308485 + - 0.0020701357 + - -0.007718021 + - -0.0064122216 + - -0.024890581 + - 0.014665469 + - -0.0028788927 + - -0.047072053 + - -0.014959743 + - 0.004587824 + - -0.07462158 + - -0.008558996 + - 0.019324543 + - -0.02247574 + - -0.07721102 + - 0.007920586 + - -0.020274863 + - 0.028696692 + - 0.024707401 + - 0.016905285 + - 0.012742534 + - 0.018577736 + - -0.05768951 + - 0.03203929 + - 0.018105863 + - -0.03917534 + - -0.026208939 + - -0.038492158 + - -0.043517314 + - -0.008031121 + - -0.016162876 + - -0.0010640965 + - -0.004164019 + - 0.005193703 + - -0.0058410293 + - 0.029311381 + - -0.016533207 + - 0.05005747 + - 0.012600715 + - 0.016292874 + - -0.008300225 + - -0.08953819 + - 0.06544125 + - -0.010512851 + - -0.021443438 + - 0.030277776 + - -0.06502247 + - 0.004850903 + - -0.013137611 + - 0.0506941 + - 0.0072725127 + - -0.025755724 + - 0.008224718 + - 0.013813313 + - 0.037197027 + - 0.0023671025 + - 0.00763629 + - 0.011905766 + - 0.033143394 + - 0.007750765 + - -0.07725993 + - -0.03193554 + - -0.016900484 + - 0.06256093 + - 0.015902048 + - 0.062251173 + - 0.07478062 + - -0.009171957 + - -0.025452917 + - -0.015754124 + - -0.004426243 + - -0.0873611 + - 0.0077999695 + - 0.061026644 + - 0.016489599 + - 0.0066420045 + - -0.0062355455 + - 0.00345123 + - -0.022935547 + - 0.03939866 + - -0.0037231673 + - 0.0033949488 + - 0.029471302 + - 0.023097953 + - -0.0237214 + - 0.046621986 + - 0.029790087 + - -0.030113066 + - -0.012432801 + - 0.036813233 + - 0.01785254 + - -0.08032645 + - 0.051262226 + - 0.04222712 + - -0.023358794 + - -0.03602671 + - 0.0092950305 + - -0.015663076 + - -0.023873692 + - 0.009383877 + - -0.056770466 + - 0.032832243 + - 0.019920528 + - -0.04734062 + - -0.03368295 + - 0.0041841906 + - 0.03257228 + - -0.07387151 + - 0.011083565 + - 0.008633363 + - -0.04694844 + - 0.0027943242 + - -0.009078945 + - -0.011981829 + - -0.026407028 + - 0.033040557 + - -0.025803888 + - -0.021555608 + - 0.0042838412 + - -0.04263439 + - 0.0465685 + - 0.070476055 + - -0.02560814 + - 0.060619954 + - 0.0071254023 + - -0.062549844 + - -0.021544673 + - -0.03598606 + - -0.04904548 + - 0.04631042 + - -0.029345924 + - -0.015404836 + - 0.0063473387 + - -0.023926385 + - 0.022935716 + - -0.1004558 + - -0.007012574 + - 0.049480513 + - -0.02592937 + - 0.057209775 + - -0.010056263 + - -0.02236333 + - 0.008163001 + - 0.0036735693 + - -0.008406754 + - -0.0029980235 + - -0.0052409745 + - -0.014090007 + - -0.025248934 + - 0.022062942 + - -0.019766279 + - 0.07160526 + - -0.00075892545 + - 
-0.0096911425 + - -0.019483106 + - 0.042904716 + - -0.0013572425 + - 0.04353212 + - -0.07759716 + - -0.011731261 + - -0.10602095 + - 0.0010180247 + - -0.07403971 + - 0.043784548 + - -0.010357722 + - -0.009020027 + - 0.026289912 + - 0.053744033 + - 0.009665143 diff --git a/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_single.snap b/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_single.snap new file mode 100644 index 00000000..5db1bc22 --- /dev/null +++ b/backends/candle/tests/snapshots/test_flash_jina_code__jina_code_single.snap @@ -0,0 +1,772 @@ +--- +source: backends/candle/tests/test_flash_jina_code.rs +expression: embeddings_single +--- +- - -0.011624558 + - -0.0016926473 + - -0.006434922 + - -0.009909191 + - 0.027506275 + - 0.022786874 + - -0.059169117 + - -0.047887173 + - 0.038912594 + - 0.025214802 + - -0.014350341 + - 0.039117776 + - 0.03485464 + - 0.05296896 + - 0.034907352 + - 0.0013971328 + - 0.0019907136 + - 0.052471623 + - -0.014562107 + - 0.00030206257 + - -0.02770458 + - 0.02685756 + - -0.015385578 + - -0.012668428 + - -0.025259107 + - 0.00836893 + - 0.028925523 + - -0.035507426 + - -0.04220648 + - 0.047328673 + - -0.083232224 + - -0.02608385 + - -0.012809777 + - -0.0140402755 + - -0.013649549 + - -0.013641793 + - 0.02880054 + - 0.04465023 + - -0.0274121 + - -0.03170939 + - -0.047180377 + - -0.0349574 + - -0.037762504 + - 0.024104225 + - -0.039361924 + - -0.024559166 + - -0.0138650155 + - -0.049862508 + - -0.0507675 + - -0.07775355 + - -0.055519626 + - -0.025151063 + - 0.040781654 + - 0.0039354665 + - 0.015940087 + - -0.022214677 + - -0.013484792 + - -0.00070730236 + - -0.009409981 + - 0.0016682384 + - 0.016079267 + - -0.032368172 + - 0.024572799 + - 0.026780155 + - -0.025954 + - 0.021282032 + - -0.047395118 + - 0.044386413 + - -0.023020437 + - 0.0056320634 + - -0.017416032 + - -0.024118245 + - 0.05878816 + - 0.00059366866 + - -0.018553924 + - 0.003762008 + - 0.0035026476 + - -0.0077498616 + - -0.004517831 + - 0.008116448 + - -0.010922055 + - 0.037391223 + - -0.03318112 + - -0.07876148 + - -0.0018068684 + - -0.0484656 + - -0.038104814 + - -0.0070756334 + - -0.015427567 + - -0.04499487 + - 0.008910639 + - 0.029193114 + - 0.032674707 + - 0.0305758 + - -0.017541302 + - 0.028164856 + - -0.0344121 + - 0.016969312 + - 0.04108575 + - -0.054463603 + - -0.005427823 + - -0.008252881 + - -0.024992533 + - 0.048412405 + - -0.05769221 + - -0.08227673 + - -0.0523458 + - 0.025244992 + - -0.016289622 + - -0.049253095 + - -0.03999235 + - 0.05642755 + - -0.010015731 + - 0.04892553 + - -0.02504831 + - -0.014305144 + - 0.04907702 + - -0.0096177375 + - -0.0854665 + - 0.017617436 + - -0.014481601 + - 0.046187755 + - -0.030009247 + - -0.049553443 + - -0.006797771 + - -0.034234725 + - 0.06525787 + - 0.017298417 + - -0.018988462 + - -0.08520813 + - -0.011840521 + - -0.042942975 + - -0.056341827 + - 0.08071612 + - 0.028146833 + - -0.0141261155 + - 0.04001012 + - 0.018236378 + - 0.01583886 + - 0.0058086845 + - 0.0005290361 + - 0.0045259795 + - 0.024873685 + - 0.018050008 + - 0.02440984 + - -0.057676647 + - -0.022263395 + - -0.033016473 + - 0.014353932 + - -0.010189813 + - 0.018424626 + - -0.018238619 + - 0.07698089 + - -0.015595315 + - 0.007074499 + - 0.027340613 + - -0.02035102 + - -0.073149584 + - -0.0061021335 + - 0.0209061 + - -0.024985746 + - -0.020401739 + - -0.0048702923 + - 0.0444933 + - -0.02158648 + - 0.010312202 + - 0.009883107 + - 0.022790415 + - -0.01893943 + - 0.017615119 + - 0.023000762 + - -0.051276624 + - 0.019835759 + - 0.016675396 + - 0.021532865 + - 
0.023097536 + - 0.01091247 + - 0.007642041 + - 0.0385409 + - 0.044193454 + - 0.003319116 + - 0.019355802 + - -0.055695307 + - -0.07680676 + - -0.035279598 + - -0.0070618573 + - -0.01408385 + - 0.03107026 + - 0.011547187 + - 0.04217532 + - 0.048034772 + - -0.014330861 + - -0.0007901662 + - 0.009103947 + - 0.011091706 + - -0.008960474 + - -0.009301379 + - 0.032424763 + - 0.09180531 + - 0.015491586 + - 0.031114861 + - 0.013289549 + - 0.03616163 + - 0.031084707 + - -0.021437835 + - 0.011905716 + - -0.104623 + - 0.033417992 + - -0.04683295 + - -0.05382028 + - 0.025979578 + - -0.003795323 + - 0.094473585 + - 0.09126942 + - 0.067700386 + - -0.002935823 + - -0.031604428 + - 0.0045677535 + - 0.014042491 + - -0.02969786 + - -0.012263599 + - -0.05807616 + - -0.0107510965 + - 0.022099612 + - 0.02685798 + - -0.048912797 + - 0.020464186 + - 0.011517319 + - -0.0016606319 + - -0.033255223 + - 0.02488959 + - -0.043240122 + - 0.00013934783 + - 0.037608404 + - 0.0748658 + - 0.061115693 + - -0.01670558 + - 0.037827995 + - -0.004162286 + - -0.00011341072 + - -0.031436242 + - -0.040532686 + - -0.0096016815 + - -0.065639205 + - -0.030070387 + - 0.05715196 + - 0.024338327 + - -0.008972259 + - -0.013220871 + - -0.07115904 + - -0.03446362 + - -0.010893238 + - 0.011950567 + - 0.0143028535 + - 0.0022963681 + - -0.060428943 + - 0.047038406 + - -0.017493663 + - 0.0208505 + - -0.010212147 + - -0.03337894 + - 0.026969628 + - -0.019734934 + - 0.0922163 + - -0.056534916 + - 0.017255465 + - 0.021692138 + - 0.023583038 + - 0.017224979 + - -0.0689936 + - 0.23665377 + - 0.04984247 + - -0.0039337534 + - -0.010740561 + - 0.00313393 + - -0.02665381 + - 0.013037893 + - 0.020565815 + - -0.002266256 + - 0.03748113 + - 0.018353125 + - -0.048318923 + - -0.06522075 + - -0.021460611 + - -0.029665926 + - -0.025507431 + - 0.0033435076 + - -0.0071087605 + - 0.001970083 + - -0.0025331723 + - 0.061776515 + - 0.017747495 + - 0.04396135 + - 0.0011617352 + - 0.02509327 + - 0.039078914 + - -0.03762337 + - 0.021575885 + - 0.015548619 + - 0.043990802 + - -0.024633111 + - -0.0013324996 + - 0.063795656 + - -0.02035371 + - -0.0440217 + - -0.033049908 + - 0.031034056 + - -0.004495562 + - -0.0044647786 + - -0.033148468 + - 0.0025072312 + - 0.009637453 + - 0.04357035 + - 0.044546504 + - -0.08865154 + - -0.08487347 + - 0.00205395 + - 0.0060572727 + - 0.023767816 + - 0.051007573 + - 0.0057035745 + - 0.040539596 + - -0.035988905 + - -0.01824621 + - 0.007887274 + - -0.052848075 + - 0.08228733 + - 0.0015825987 + - 0.0004136183 + - 0.018108545 + - 0.040081892 + - -0.035405345 + - 0.050933696 + - 0.036154125 + - -0.03947257 + - -0.0018412384 + - -0.005589829 + - 0.03620321 + - 0.012144826 + - -0.008619581 + - -0.0043279063 + - 0.016455552 + - -0.015757388 + - 0.047043085 + - 0.08675011 + - -0.009986743 + - 0.022379123 + - 0.009470605 + - 0.034120724 + - 0.009922824 + - -0.014435422 + - 0.017998574 + - -0.03849387 + - 0.042357396 + - -0.010053916 + - 0.057034835 + - -0.027412737 + - 0.017308975 + - 0.02185228 + - -0.048017155 + - -0.025138885 + - 0.016482655 + - -0.01756698 + - 0.031146016 + - 0.021930695 + - -0.00075341173 + - -0.015085438 + - 0.0146785155 + - 0.028547939 + - -0.036677707 + - -0.022699077 + - -0.045135103 + - -0.02802744 + - 0.071454674 + - -0.009201392 + - -0.051717956 + - -0.019421624 + - 0.0034821872 + - 0.06667364 + - 0.09145379 + - -0.068826884 + - 0.053568542 + - 0.01160062 + - -0.025829546 + - 0.04487214 + - 0.030954553 + - -0.022543794 + - -0.009475118 + - -0.014623143 + - -0.02070793 + - -0.02656788 + - 0.009701591 + - -0.025120718 + 
- -0.018472325 + - -0.027967019 + - -0.021122226 + - -0.009891716 + - 0.018696679 + - -0.068876855 + - -0.023419108 + - -0.025495855 + - 0.016256742 + - 0.00064859784 + - 0.037749656 + - -0.046321914 + - 0.065936595 + - -0.008921658 + - 0.07497468 + - -0.05094385 + - -0.052860104 + - 0.022196138 + - -0.0140462285 + - -0.03562305 + - 0.024858234 + - 0.021310989 + - 0.014657512 + - 0.07767391 + - -0.027777392 + - 0.057577316 + - 0.055513144 + - 0.06322926 + - -0.026312957 + - 0.010970987 + - -0.0057475767 + - 0.0028267235 + - 0.051367335 + - -0.022320578 + - 0.009050165 + - 0.016952222 + - -0.032026373 + - -0.074292615 + - -0.02315535 + - 0.025375988 + - -0.041241057 + - -0.032157563 + - -0.015576387 + - -0.015834223 + - 0.02224181 + - 0.017586967 + - -0.029070066 + - 0.0030065721 + - -0.051857695 + - 0.04008828 + - -0.007960872 + - -0.061745025 + - -0.00086617953 + - 0.026723113 + - 0.008719714 + - -0.049826868 + - -0.0046574236 + - 0.018954279 + - -0.007935451 + - 0.053987946 + - 0.022795292 + - 0.029722994 + - -0.010146585 + - -0.019956842 + - -0.014686722 + - 0.01708331 + - 0.020001508 + - 0.016564105 + - 0.011379248 + - -0.011843253 + - 0.04056168 + - 0.004384286 + - -0.049596023 + - 0.02674251 + - 0.0076475106 + - -0.0064563937 + - -0.014233138 + - 0.014224383 + - 0.0052741244 + - 0.049964864 + - -0.016286546 + - -0.001200327 + - 0.041904222 + - -0.0010395087 + - -0.05105399 + - 0.022099879 + - -0.019455278 + - 0.009444127 + - -0.081325725 + - 0.015994828 + - 0.0010952728 + - 0.0008373874 + - -0.0016424303 + - -0.05096469 + - 0.045976803 + - 0.024695056 + - -0.031656373 + - -0.0013138534 + - -0.0060524447 + - 0.060276203 + - -0.016745795 + - -0.029930653 + - -0.0222771 + - 0.012314711 + - 0.038991332 + - -0.006665343 + - 0.041694533 + - 0.0076992502 + - -0.014353178 + - 0.0025135442 + - -0.0498445 + - 0.020322764 + - 0.0575802 + - -0.006096128 + - -0.010841882 + - -0.024337102 + - 0.021975596 + - 0.0064031687 + - -0.026746146 + - 0.03455729 + - -0.011909055 + - 0.016994143 + - 0.026053395 + - -0.020393625 + - 0.06403403 + - 0.042590734 + - 0.009193913 + - 0.04016698 + - 0.028304791 + - 0.022147119 + - -0.030121539 + - -0.0037334429 + - 0.03235819 + - -0.020825844 + - 0.0009766509 + - -0.012216568 + - 0.07944978 + - 0.04177374 + - -0.008281654 + - -0.008908983 + - 0.04799388 + - -0.012743454 + - -0.00076762337 + - 0.012673029 + - -0.018283572 + - 0.022068778 + - -0.009605337 + - 0.0030652087 + - -0.036517244 + - -0.006263211 + - 0.0194632 + - 0.009333852 + - 0.02350168 + - -0.0008530139 + - 0.018934859 + - -0.013986168 + - -0.010833636 + - -0.011189203 + - 0.011567913 + - -0.07253544 + - -0.005748846 + - 0.11930293 + - 0.060044624 + - -0.023167728 + - -0.015781552 + - -0.010494401 + - 0.01930528 + - -0.039608266 + - -0.073587865 + - 0.0019034932 + - -0.0013838339 + - 0.026257295 + - 0.004433007 + - -0.051545423 + - -0.033456888 + - -0.06401291 + - -0.08664347 + - 0.010781564 + - 0.0408775 + - 0.021475399 + - 0.017633006 + - -0.02186024 + - 0.047795497 + - -0.006370007 + - 0.01792626 + - 0.005195737 + - 0.0026206016 + - -0.008816542 + - 0.009266863 + - -0.018453414 + - 0.0040575014 + - 0.0047053 + - -0.0197809 + - -0.01580334 + - 0.007821501 + - -0.06296649 + - -0.06274416 + - -0.0031381177 + - -0.024228694 + - -0.002459634 + - 0.00034323192 + - -0.02430543 + - -0.029262288 + - -0.04606642 + - -0.013138838 + - 0.040473 + - 0.012308485 + - 0.0020701357 + - -0.007718021 + - -0.0064122216 + - -0.024890581 + - 0.014665469 + - -0.0028788927 + - -0.047072053 + - -0.014959743 + - 0.004587824 + 
- -0.07462158 + - -0.008558996 + - 0.019324543 + - -0.02247574 + - -0.07721102 + - 0.007920586 + - -0.020274863 + - 0.028696692 + - 0.024707401 + - 0.016905285 + - 0.012742534 + - 0.018577736 + - -0.05768951 + - 0.03203929 + - 0.018105863 + - -0.03917534 + - -0.026208939 + - -0.038492158 + - -0.043517314 + - -0.008031121 + - -0.016162876 + - -0.0010640965 + - -0.004164019 + - 0.005193703 + - -0.0058410293 + - 0.029311381 + - -0.016533207 + - 0.05005747 + - 0.012600715 + - 0.016292874 + - -0.008300225 + - -0.08953819 + - 0.06544125 + - -0.010512851 + - -0.021443438 + - 0.030277776 + - -0.06502247 + - 0.004850903 + - -0.013137611 + - 0.0506941 + - 0.0072725127 + - -0.025755724 + - 0.008224718 + - 0.013813313 + - 0.037197027 + - 0.0023671025 + - 0.00763629 + - 0.011905766 + - 0.033143394 + - 0.007750765 + - -0.07725993 + - -0.03193554 + - -0.016900484 + - 0.06256093 + - 0.015902048 + - 0.062251173 + - 0.07478062 + - -0.009171957 + - -0.025452917 + - -0.015754124 + - -0.004426243 + - -0.0873611 + - 0.0077999695 + - 0.061026644 + - 0.016489599 + - 0.0066420045 + - -0.0062355455 + - 0.00345123 + - -0.022935547 + - 0.03939866 + - -0.0037231673 + - 0.0033949488 + - 0.029471302 + - 0.023097953 + - -0.0237214 + - 0.046621986 + - 0.029790087 + - -0.030113066 + - -0.012432801 + - 0.036813233 + - 0.01785254 + - -0.08032645 + - 0.051262226 + - 0.04222712 + - -0.023358794 + - -0.03602671 + - 0.0092950305 + - -0.015663076 + - -0.023873692 + - 0.009383877 + - -0.056770466 + - 0.032832243 + - 0.019920528 + - -0.04734062 + - -0.03368295 + - 0.0041841906 + - 0.03257228 + - -0.07387151 + - 0.011083565 + - 0.008633363 + - -0.04694844 + - 0.0027943242 + - -0.009078945 + - -0.011981829 + - -0.026407028 + - 0.033040557 + - -0.025803888 + - -0.021555608 + - 0.0042838412 + - -0.04263439 + - 0.0465685 + - 0.070476055 + - -0.02560814 + - 0.060619954 + - 0.0071254023 + - -0.062549844 + - -0.021544673 + - -0.03598606 + - -0.04904548 + - 0.04631042 + - -0.029345924 + - -0.015404836 + - 0.0063473387 + - -0.023926385 + - 0.022935716 + - -0.1004558 + - -0.007012574 + - 0.049480513 + - -0.02592937 + - 0.057209775 + - -0.010056263 + - -0.02236333 + - 0.008163001 + - 0.0036735693 + - -0.008406754 + - -0.0029980235 + - -0.0052409745 + - -0.014090007 + - -0.025248934 + - 0.022062942 + - -0.019766279 + - 0.07160526 + - -0.00075892545 + - -0.0096911425 + - -0.019483106 + - 0.042904716 + - -0.0013572425 + - 0.04353212 + - -0.07759716 + - -0.011731261 + - -0.10602095 + - 0.0010180247 + - -0.07403971 + - 0.043784548 + - -0.010357722 + - -0.009020027 + - 0.026289912 + - 0.053744033 + - 0.009665143 diff --git a/backends/candle/tests/snapshots/test_jina_code__jina_code_batch.snap b/backends/candle/tests/snapshots/test_jina_code__jina_code_batch.snap new file mode 100644 index 00000000..87f5bbd6 --- /dev/null +++ b/backends/candle/tests/snapshots/test_jina_code__jina_code_batch.snap @@ -0,0 +1,772 @@ +--- +source: backends/candle/tests/test_jina_code.rs +expression: embeddings_batch +--- +- - -0.011624558 + - -0.0016926473 + - -0.006434922 + - -0.009909191 + - 0.027506275 + - 0.022786874 + - -0.059169117 + - -0.047887173 + - 0.038912594 + - 0.025214802 + - -0.014350341 + - 0.039117776 + - 0.03485464 + - 0.05296896 + - 0.034907352 + - 0.0013971328 + - 0.0019907136 + - 0.052471623 + - -0.014562107 + - 0.00030206257 + - -0.02770458 + - 0.02685756 + - -0.015385578 + - -0.012668428 + - -0.025259107 + - 0.00836893 + - 0.028925523 + - -0.035507426 + - -0.04220648 + - 0.047328673 + - -0.083232224 + - -0.02608385 + - -0.012809777 + - 
-0.0140402755 + - -0.013649549 + - -0.013641793 + - 0.02880054 + - 0.04465023 + - -0.0274121 + - -0.03170939 + - -0.047180377 + - -0.0349574 + - -0.037762504 + - 0.024104225 + - -0.039361924 + - -0.024559166 + - -0.0138650155 + - -0.049862508 + - -0.0507675 + - -0.07775355 + - -0.055519626 + - -0.025151063 + - 0.040781654 + - 0.0039354665 + - 0.015940087 + - -0.022214677 + - -0.013484792 + - -0.00070730236 + - -0.009409981 + - 0.0016682384 + - 0.016079267 + - -0.032368172 + - 0.024572799 + - 0.026780155 + - -0.025954 + - 0.021282032 + - -0.047395118 + - 0.044386413 + - -0.023020437 + - 0.0056320634 + - -0.017416032 + - -0.024118245 + - 0.05878816 + - 0.00059366866 + - -0.018553924 + - 0.003762008 + - 0.0035026476 + - -0.0077498616 + - -0.004517831 + - 0.008116448 + - -0.010922055 + - 0.037391223 + - -0.03318112 + - -0.07876148 + - -0.0018068684 + - -0.0484656 + - -0.038104814 + - -0.0070756334 + - -0.015427567 + - -0.04499487 + - 0.008910639 + - 0.029193114 + - 0.032674707 + - 0.0305758 + - -0.017541302 + - 0.028164856 + - -0.0344121 + - 0.016969312 + - 0.04108575 + - -0.054463603 + - -0.005427823 + - -0.008252881 + - -0.024992533 + - 0.048412405 + - -0.05769221 + - -0.08227673 + - -0.0523458 + - 0.025244992 + - -0.016289622 + - -0.049253095 + - -0.03999235 + - 0.05642755 + - -0.010015731 + - 0.04892553 + - -0.02504831 + - -0.014305144 + - 0.04907702 + - -0.0096177375 + - -0.0854665 + - 0.017617436 + - -0.014481601 + - 0.046187755 + - -0.030009247 + - -0.049553443 + - -0.006797771 + - -0.034234725 + - 0.06525787 + - 0.017298417 + - -0.018988462 + - -0.08520813 + - -0.011840521 + - -0.042942975 + - -0.056341827 + - 0.08071612 + - 0.028146833 + - -0.0141261155 + - 0.04001012 + - 0.018236378 + - 0.01583886 + - 0.0058086845 + - 0.0005290361 + - 0.0045259795 + - 0.024873685 + - 0.018050008 + - 0.02440984 + - -0.057676647 + - -0.022263395 + - -0.033016473 + - 0.014353932 + - -0.010189813 + - 0.018424626 + - -0.018238619 + - 0.07698089 + - -0.015595315 + - 0.007074499 + - 0.027340613 + - -0.02035102 + - -0.073149584 + - -0.0061021335 + - 0.0209061 + - -0.024985746 + - -0.020401739 + - -0.0048702923 + - 0.0444933 + - -0.02158648 + - 0.010312202 + - 0.009883107 + - 0.022790415 + - -0.01893943 + - 0.017615119 + - 0.023000762 + - -0.051276624 + - 0.019835759 + - 0.016675396 + - 0.021532865 + - 0.023097536 + - 0.01091247 + - 0.007642041 + - 0.0385409 + - 0.044193454 + - 0.003319116 + - 0.019355802 + - -0.055695307 + - -0.07680676 + - -0.035279598 + - -0.0070618573 + - -0.01408385 + - 0.03107026 + - 0.011547187 + - 0.04217532 + - 0.048034772 + - -0.014330861 + - -0.0007901662 + - 0.009103947 + - 0.011091706 + - -0.008960474 + - -0.009301379 + - 0.032424763 + - 0.09180531 + - 0.015491586 + - 0.031114861 + - 0.013289549 + - 0.03616163 + - 0.031084707 + - -0.021437835 + - 0.011905716 + - -0.104623 + - 0.033417992 + - -0.04683295 + - -0.05382028 + - 0.025979578 + - -0.003795323 + - 0.094473585 + - 0.09126942 + - 0.067700386 + - -0.002935823 + - -0.031604428 + - 0.0045677535 + - 0.014042491 + - -0.02969786 + - -0.012263599 + - -0.05807616 + - -0.0107510965 + - 0.022099612 + - 0.02685798 + - -0.048912797 + - 0.020464186 + - 0.011517319 + - -0.0016606319 + - -0.033255223 + - 0.02488959 + - -0.043240122 + - 0.00013934783 + - 0.037608404 + - 0.0748658 + - 0.061115693 + - -0.01670558 + - 0.037827995 + - -0.004162286 + - -0.00011341072 + - -0.031436242 + - -0.040532686 + - -0.0096016815 + - -0.065639205 + - -0.030070387 + - 0.05715196 + - 0.024338327 + - -0.008972259 + - -0.013220871 + - -0.07115904 + - 
-0.03446362 + - -0.010893238 + - 0.011950567 + - 0.0143028535 + - 0.0022963681 + - -0.060428943 + - 0.047038406 + - -0.017493663 + - 0.0208505 + - -0.010212147 + - -0.03337894 + - 0.026969628 + - -0.019734934 + - 0.0922163 + - -0.056534916 + - 0.017255465 + - 0.021692138 + - 0.023583038 + - 0.017224979 + - -0.0689936 + - 0.23665377 + - 0.04984247 + - -0.0039337534 + - -0.010740561 + - 0.00313393 + - -0.02665381 + - 0.013037893 + - 0.020565815 + - -0.002266256 + - 0.03748113 + - 0.018353125 + - -0.048318923 + - -0.06522075 + - -0.021460611 + - -0.029665926 + - -0.025507431 + - 0.0033435076 + - -0.0071087605 + - 0.001970083 + - -0.0025331723 + - 0.061776515 + - 0.017747495 + - 0.04396135 + - 0.0011617352 + - 0.02509327 + - 0.039078914 + - -0.03762337 + - 0.021575885 + - 0.015548619 + - 0.043990802 + - -0.024633111 + - -0.0013324996 + - 0.063795656 + - -0.02035371 + - -0.0440217 + - -0.033049908 + - 0.031034056 + - -0.004495562 + - -0.0044647786 + - -0.033148468 + - 0.0025072312 + - 0.009637453 + - 0.04357035 + - 0.044546504 + - -0.08865154 + - -0.08487347 + - 0.00205395 + - 0.0060572727 + - 0.023767816 + - 0.051007573 + - 0.0057035745 + - 0.040539596 + - -0.035988905 + - -0.01824621 + - 0.007887274 + - -0.052848075 + - 0.08228733 + - 0.0015825987 + - 0.0004136183 + - 0.018108545 + - 0.040081892 + - -0.035405345 + - 0.050933696 + - 0.036154125 + - -0.03947257 + - -0.0018412384 + - -0.005589829 + - 0.03620321 + - 0.012144826 + - -0.008619581 + - -0.0043279063 + - 0.016455552 + - -0.015757388 + - 0.047043085 + - 0.08675011 + - -0.009986743 + - 0.022379123 + - 0.009470605 + - 0.034120724 + - 0.009922824 + - -0.014435422 + - 0.017998574 + - -0.03849387 + - 0.042357396 + - -0.010053916 + - 0.057034835 + - -0.027412737 + - 0.017308975 + - 0.02185228 + - -0.048017155 + - -0.025138885 + - 0.016482655 + - -0.01756698 + - 0.031146016 + - 0.021930695 + - -0.00075341173 + - -0.015085438 + - 0.0146785155 + - 0.028547939 + - -0.036677707 + - -0.022699077 + - -0.045135103 + - -0.02802744 + - 0.071454674 + - -0.009201392 + - -0.051717956 + - -0.019421624 + - 0.0034821872 + - 0.06667364 + - 0.09145379 + - -0.068826884 + - 0.053568542 + - 0.01160062 + - -0.025829546 + - 0.04487214 + - 0.030954553 + - -0.022543794 + - -0.009475118 + - -0.014623143 + - -0.02070793 + - -0.02656788 + - 0.009701591 + - -0.025120718 + - -0.018472325 + - -0.027967019 + - -0.021122226 + - -0.009891716 + - 0.018696679 + - -0.068876855 + - -0.023419108 + - -0.025495855 + - 0.016256742 + - 0.00064859784 + - 0.037749656 + - -0.046321914 + - 0.065936595 + - -0.008921658 + - 0.07497468 + - -0.05094385 + - -0.052860104 + - 0.022196138 + - -0.0140462285 + - -0.03562305 + - 0.024858234 + - 0.021310989 + - 0.014657512 + - 0.07767391 + - -0.027777392 + - 0.057577316 + - 0.055513144 + - 0.06322926 + - -0.026312957 + - 0.010970987 + - -0.0057475767 + - 0.0028267235 + - 0.051367335 + - -0.022320578 + - 0.009050165 + - 0.016952222 + - -0.032026373 + - -0.074292615 + - -0.02315535 + - 0.025375988 + - -0.041241057 + - -0.032157563 + - -0.015576387 + - -0.015834223 + - 0.02224181 + - 0.017586967 + - -0.029070066 + - 0.0030065721 + - -0.051857695 + - 0.04008828 + - -0.007960872 + - -0.061745025 + - -0.00086617953 + - 0.026723113 + - 0.008719714 + - -0.049826868 + - -0.0046574236 + - 0.018954279 + - -0.007935451 + - 0.053987946 + - 0.022795292 + - 0.029722994 + - -0.010146585 + - -0.019956842 + - -0.014686722 + - 0.01708331 + - 0.020001508 + - 0.016564105 + - 0.011379248 + - -0.011843253 + - 0.04056168 + - 0.004384286 + - -0.049596023 + - 0.02674251 + - 
0.0076475106 + - -0.0064563937 + - -0.014233138 + - 0.014224383 + - 0.0052741244 + - 0.049964864 + - -0.016286546 + - -0.001200327 + - 0.041904222 + - -0.0010395087 + - -0.05105399 + - 0.022099879 + - -0.019455278 + - 0.009444127 + - -0.081325725 + - 0.015994828 + - 0.0010952728 + - 0.0008373874 + - -0.0016424303 + - -0.05096469 + - 0.045976803 + - 0.024695056 + - -0.031656373 + - -0.0013138534 + - -0.0060524447 + - 0.060276203 + - -0.016745795 + - -0.029930653 + - -0.0222771 + - 0.012314711 + - 0.038991332 + - -0.006665343 + - 0.041694533 + - 0.0076992502 + - -0.014353178 + - 0.0025135442 + - -0.0498445 + - 0.020322764 + - 0.0575802 + - -0.006096128 + - -0.010841882 + - -0.024337102 + - 0.021975596 + - 0.0064031687 + - -0.026746146 + - 0.03455729 + - -0.011909055 + - 0.016994143 + - 0.026053395 + - -0.020393625 + - 0.06403403 + - 0.042590734 + - 0.009193913 + - 0.04016698 + - 0.028304791 + - 0.022147119 + - -0.030121539 + - -0.0037334429 + - 0.03235819 + - -0.020825844 + - 0.0009766509 + - -0.012216568 + - 0.07944978 + - 0.04177374 + - -0.008281654 + - -0.008908983 + - 0.04799388 + - -0.012743454 + - -0.00076762337 + - 0.012673029 + - -0.018283572 + - 0.022068778 + - -0.009605337 + - 0.0030652087 + - -0.036517244 + - -0.006263211 + - 0.0194632 + - 0.009333852 + - 0.02350168 + - -0.0008530139 + - 0.018934859 + - -0.013986168 + - -0.010833636 + - -0.011189203 + - 0.011567913 + - -0.07253544 + - -0.005748846 + - 0.11930293 + - 0.060044624 + - -0.023167728 + - -0.015781552 + - -0.010494401 + - 0.01930528 + - -0.039608266 + - -0.073587865 + - 0.0019034932 + - -0.0013838339 + - 0.026257295 + - 0.004433007 + - -0.051545423 + - -0.033456888 + - -0.06401291 + - -0.08664347 + - 0.010781564 + - 0.0408775 + - 0.021475399 + - 0.017633006 + - -0.02186024 + - 0.047795497 + - -0.006370007 + - 0.01792626 + - 0.005195737 + - 0.0026206016 + - -0.008816542 + - 0.009266863 + - -0.018453414 + - 0.0040575014 + - 0.0047053 + - -0.0197809 + - -0.01580334 + - 0.007821501 + - -0.06296649 + - -0.06274416 + - -0.0031381177 + - -0.024228694 + - -0.002459634 + - 0.00034323192 + - -0.02430543 + - -0.029262288 + - -0.04606642 + - -0.013138838 + - 0.040473 + - 0.012308485 + - 0.0020701357 + - -0.007718021 + - -0.0064122216 + - -0.024890581 + - 0.014665469 + - -0.0028788927 + - -0.047072053 + - -0.014959743 + - 0.004587824 + - -0.07462158 + - -0.008558996 + - 0.019324543 + - -0.02247574 + - -0.07721102 + - 0.007920586 + - -0.020274863 + - 0.028696692 + - 0.024707401 + - 0.016905285 + - 0.012742534 + - 0.018577736 + - -0.05768951 + - 0.03203929 + - 0.018105863 + - -0.03917534 + - -0.026208939 + - -0.038492158 + - -0.043517314 + - -0.008031121 + - -0.016162876 + - -0.0010640965 + - -0.004164019 + - 0.005193703 + - -0.0058410293 + - 0.029311381 + - -0.016533207 + - 0.05005747 + - 0.012600715 + - 0.016292874 + - -0.008300225 + - -0.08953819 + - 0.06544125 + - -0.010512851 + - -0.021443438 + - 0.030277776 + - -0.06502247 + - 0.004850903 + - -0.013137611 + - 0.0506941 + - 0.0072725127 + - -0.025755724 + - 0.008224718 + - 0.013813313 + - 0.037197027 + - 0.0023671025 + - 0.00763629 + - 0.011905766 + - 0.033143394 + - 0.007750765 + - -0.07725993 + - -0.03193554 + - -0.016900484 + - 0.06256093 + - 0.015902048 + - 0.062251173 + - 0.07478062 + - -0.009171957 + - -0.025452917 + - -0.015754124 + - -0.004426243 + - -0.0873611 + - 0.0077999695 + - 0.061026644 + - 0.016489599 + - 0.0066420045 + - -0.0062355455 + - 0.00345123 + - -0.022935547 + - 0.03939866 + - -0.0037231673 + - 0.0033949488 + - 0.029471302 + - 0.023097953 + - -0.0237214 + 
- 0.046621986 + - 0.029790087 + - -0.030113066 + - -0.012432801 + - 0.036813233 + - 0.01785254 + - -0.08032645 + - 0.051262226 + - 0.04222712 + - -0.023358794 + - -0.03602671 + - 0.0092950305 + - -0.015663076 + - -0.023873692 + - 0.009383877 + - -0.056770466 + - 0.032832243 + - 0.019920528 + - -0.04734062 + - -0.03368295 + - 0.0041841906 + - 0.03257228 + - -0.07387151 + - 0.011083565 + - 0.008633363 + - -0.04694844 + - 0.0027943242 + - -0.009078945 + - -0.011981829 + - -0.026407028 + - 0.033040557 + - -0.025803888 + - -0.021555608 + - 0.0042838412 + - -0.04263439 + - 0.0465685 + - 0.070476055 + - -0.02560814 + - 0.060619954 + - 0.0071254023 + - -0.062549844 + - -0.021544673 + - -0.03598606 + - -0.04904548 + - 0.04631042 + - -0.029345924 + - -0.015404836 + - 0.0063473387 + - -0.023926385 + - 0.022935716 + - -0.1004558 + - -0.007012574 + - 0.049480513 + - -0.02592937 + - 0.057209775 + - -0.010056263 + - -0.02236333 + - 0.008163001 + - 0.0036735693 + - -0.008406754 + - -0.0029980235 + - -0.0052409745 + - -0.014090007 + - -0.025248934 + - 0.022062942 + - -0.019766279 + - 0.07160526 + - -0.00075892545 + - -0.0096911425 + - -0.019483106 + - 0.042904716 + - -0.0013572425 + - 0.04353212 + - -0.07759716 + - -0.011731261 + - -0.10602095 + - 0.0010180247 + - -0.07403971 + - 0.043784548 + - -0.010357722 + - -0.009020027 + - 0.026289912 + - 0.053744033 + - 0.009665143 diff --git a/backends/candle/tests/snapshots/test_jina_code__jina_code_single.snap b/backends/candle/tests/snapshots/test_jina_code__jina_code_single.snap new file mode 100644 index 00000000..ee1cdfa1 --- /dev/null +++ b/backends/candle/tests/snapshots/test_jina_code__jina_code_single.snap @@ -0,0 +1,772 @@ +--- +source: backends/candle/tests/test_jina_code.rs +expression: embeddings_single +--- +- - -0.011624558 + - -0.0016926473 + - -0.006434922 + - -0.009909191 + - 0.027506275 + - 0.022786874 + - -0.059169117 + - -0.047887173 + - 0.038912594 + - 0.025214802 + - -0.014350341 + - 0.039117776 + - 0.03485464 + - 0.05296896 + - 0.034907352 + - 0.0013971328 + - 0.0019907136 + - 0.052471623 + - -0.014562107 + - 0.00030206257 + - -0.02770458 + - 0.02685756 + - -0.015385578 + - -0.012668428 + - -0.025259107 + - 0.00836893 + - 0.028925523 + - -0.035507426 + - -0.04220648 + - 0.047328673 + - -0.083232224 + - -0.02608385 + - -0.012809777 + - -0.0140402755 + - -0.013649549 + - -0.013641793 + - 0.02880054 + - 0.04465023 + - -0.0274121 + - -0.03170939 + - -0.047180377 + - -0.0349574 + - -0.037762504 + - 0.024104225 + - -0.039361924 + - -0.024559166 + - -0.0138650155 + - -0.049862508 + - -0.0507675 + - -0.07775355 + - -0.055519626 + - -0.025151063 + - 0.040781654 + - 0.0039354665 + - 0.015940087 + - -0.022214677 + - -0.013484792 + - -0.00070730236 + - -0.009409981 + - 0.0016682384 + - 0.016079267 + - -0.032368172 + - 0.024572799 + - 0.026780155 + - -0.025954 + - 0.021282032 + - -0.047395118 + - 0.044386413 + - -0.023020437 + - 0.0056320634 + - -0.017416032 + - -0.024118245 + - 0.05878816 + - 0.00059366866 + - -0.018553924 + - 0.003762008 + - 0.0035026476 + - -0.0077498616 + - -0.004517831 + - 0.008116448 + - -0.010922055 + - 0.037391223 + - -0.03318112 + - -0.07876148 + - -0.0018068684 + - -0.0484656 + - -0.038104814 + - -0.0070756334 + - -0.015427567 + - -0.04499487 + - 0.008910639 + - 0.029193114 + - 0.032674707 + - 0.0305758 + - -0.017541302 + - 0.028164856 + - -0.0344121 + - 0.016969312 + - 0.04108575 + - -0.054463603 + - -0.005427823 + - -0.008252881 + - -0.024992533 + - 0.048412405 + - -0.05769221 + - -0.08227673 + - -0.0523458 + - 0.025244992 + 
- -0.016289622 + - -0.049253095 + - -0.03999235 + - 0.05642755 + - -0.010015731 + - 0.04892553 + - -0.02504831 + - -0.014305144 + - 0.04907702 + - -0.0096177375 + - -0.0854665 + - 0.017617436 + - -0.014481601 + - 0.046187755 + - -0.030009247 + - -0.049553443 + - -0.006797771 + - -0.034234725 + - 0.06525787 + - 0.017298417 + - -0.018988462 + - -0.08520813 + - -0.011840521 + - -0.042942975 + - -0.056341827 + - 0.08071612 + - 0.028146833 + - -0.0141261155 + - 0.04001012 + - 0.018236378 + - 0.01583886 + - 0.0058086845 + - 0.0005290361 + - 0.0045259795 + - 0.024873685 + - 0.018050008 + - 0.02440984 + - -0.057676647 + - -0.022263395 + - -0.033016473 + - 0.014353932 + - -0.010189813 + - 0.018424626 + - -0.018238619 + - 0.07698089 + - -0.015595315 + - 0.007074499 + - 0.027340613 + - -0.02035102 + - -0.073149584 + - -0.0061021335 + - 0.0209061 + - -0.024985746 + - -0.020401739 + - -0.0048702923 + - 0.0444933 + - -0.02158648 + - 0.010312202 + - 0.009883107 + - 0.022790415 + - -0.01893943 + - 0.017615119 + - 0.023000762 + - -0.051276624 + - 0.019835759 + - 0.016675396 + - 0.021532865 + - 0.023097536 + - 0.01091247 + - 0.007642041 + - 0.0385409 + - 0.044193454 + - 0.003319116 + - 0.019355802 + - -0.055695307 + - -0.07680676 + - -0.035279598 + - -0.0070618573 + - -0.01408385 + - 0.03107026 + - 0.011547187 + - 0.04217532 + - 0.048034772 + - -0.014330861 + - -0.0007901662 + - 0.009103947 + - 0.011091706 + - -0.008960474 + - -0.009301379 + - 0.032424763 + - 0.09180531 + - 0.015491586 + - 0.031114861 + - 0.013289549 + - 0.03616163 + - 0.031084707 + - -0.021437835 + - 0.011905716 + - -0.104623 + - 0.033417992 + - -0.04683295 + - -0.05382028 + - 0.025979578 + - -0.003795323 + - 0.094473585 + - 0.09126942 + - 0.067700386 + - -0.002935823 + - -0.031604428 + - 0.0045677535 + - 0.014042491 + - -0.02969786 + - -0.012263599 + - -0.05807616 + - -0.0107510965 + - 0.022099612 + - 0.02685798 + - -0.048912797 + - 0.020464186 + - 0.011517319 + - -0.0016606319 + - -0.033255223 + - 0.02488959 + - -0.043240122 + - 0.00013934783 + - 0.037608404 + - 0.0748658 + - 0.061115693 + - -0.01670558 + - 0.037827995 + - -0.004162286 + - -0.00011341072 + - -0.031436242 + - -0.040532686 + - -0.0096016815 + - -0.065639205 + - -0.030070387 + - 0.05715196 + - 0.024338327 + - -0.008972259 + - -0.013220871 + - -0.07115904 + - -0.03446362 + - -0.010893238 + - 0.011950567 + - 0.0143028535 + - 0.0022963681 + - -0.060428943 + - 0.047038406 + - -0.017493663 + - 0.0208505 + - -0.010212147 + - -0.03337894 + - 0.026969628 + - -0.019734934 + - 0.0922163 + - -0.056534916 + - 0.017255465 + - 0.021692138 + - 0.023583038 + - 0.017224979 + - -0.0689936 + - 0.23665377 + - 0.04984247 + - -0.0039337534 + - -0.010740561 + - 0.00313393 + - -0.02665381 + - 0.013037893 + - 0.020565815 + - -0.002266256 + - 0.03748113 + - 0.018353125 + - -0.048318923 + - -0.06522075 + - -0.021460611 + - -0.029665926 + - -0.025507431 + - 0.0033435076 + - -0.0071087605 + - 0.001970083 + - -0.0025331723 + - 0.061776515 + - 0.017747495 + - 0.04396135 + - 0.0011617352 + - 0.02509327 + - 0.039078914 + - -0.03762337 + - 0.021575885 + - 0.015548619 + - 0.043990802 + - -0.024633111 + - -0.0013324996 + - 0.063795656 + - -0.02035371 + - -0.0440217 + - -0.033049908 + - 0.031034056 + - -0.004495562 + - -0.0044647786 + - -0.033148468 + - 0.0025072312 + - 0.009637453 + - 0.04357035 + - 0.044546504 + - -0.08865154 + - -0.08487347 + - 0.00205395 + - 0.0060572727 + - 0.023767816 + - 0.051007573 + - 0.0057035745 + - 0.040539596 + - -0.035988905 + - -0.01824621 + - 0.007887274 + - -0.052848075 + - 
0.08228733 + - 0.0015825987 + - 0.0004136183 + - 0.018108545 + - 0.040081892 + - -0.035405345 + - 0.050933696 + - 0.036154125 + - -0.03947257 + - -0.0018412384 + - -0.005589829 + - 0.03620321 + - 0.012144826 + - -0.008619581 + - -0.0043279063 + - 0.016455552 + - -0.015757388 + - 0.047043085 + - 0.08675011 + - -0.009986743 + - 0.022379123 + - 0.009470605 + - 0.034120724 + - 0.009922824 + - -0.014435422 + - 0.017998574 + - -0.03849387 + - 0.042357396 + - -0.010053916 + - 0.057034835 + - -0.027412737 + - 0.017308975 + - 0.02185228 + - -0.048017155 + - -0.025138885 + - 0.016482655 + - -0.01756698 + - 0.031146016 + - 0.021930695 + - -0.00075341173 + - -0.015085438 + - 0.0146785155 + - 0.028547939 + - -0.036677707 + - -0.022699077 + - -0.045135103 + - -0.02802744 + - 0.071454674 + - -0.009201392 + - -0.051717956 + - -0.019421624 + - 0.0034821872 + - 0.06667364 + - 0.09145379 + - -0.068826884 + - 0.053568542 + - 0.01160062 + - -0.025829546 + - 0.04487214 + - 0.030954553 + - -0.022543794 + - -0.009475118 + - -0.014623143 + - -0.02070793 + - -0.02656788 + - 0.009701591 + - -0.025120718 + - -0.018472325 + - -0.027967019 + - -0.021122226 + - -0.009891716 + - 0.018696679 + - -0.068876855 + - -0.023419108 + - -0.025495855 + - 0.016256742 + - 0.00064859784 + - 0.037749656 + - -0.046321914 + - 0.065936595 + - -0.008921658 + - 0.07497468 + - -0.05094385 + - -0.052860104 + - 0.022196138 + - -0.0140462285 + - -0.03562305 + - 0.024858234 + - 0.021310989 + - 0.014657512 + - 0.07767391 + - -0.027777392 + - 0.057577316 + - 0.055513144 + - 0.06322926 + - -0.026312957 + - 0.010970987 + - -0.0057475767 + - 0.0028267235 + - 0.051367335 + - -0.022320578 + - 0.009050165 + - 0.016952222 + - -0.032026373 + - -0.074292615 + - -0.02315535 + - 0.025375988 + - -0.041241057 + - -0.032157563 + - -0.015576387 + - -0.015834223 + - 0.02224181 + - 0.017586967 + - -0.029070066 + - 0.0030065721 + - -0.051857695 + - 0.04008828 + - -0.007960872 + - -0.061745025 + - -0.00086617953 + - 0.026723113 + - 0.008719714 + - -0.049826868 + - -0.0046574236 + - 0.018954279 + - -0.007935451 + - 0.053987946 + - 0.022795292 + - 0.029722994 + - -0.010146585 + - -0.019956842 + - -0.014686722 + - 0.01708331 + - 0.020001508 + - 0.016564105 + - 0.011379248 + - -0.011843253 + - 0.04056168 + - 0.004384286 + - -0.049596023 + - 0.02674251 + - 0.0076475106 + - -0.0064563937 + - -0.014233138 + - 0.014224383 + - 0.0052741244 + - 0.049964864 + - -0.016286546 + - -0.001200327 + - 0.041904222 + - -0.0010395087 + - -0.05105399 + - 0.022099879 + - -0.019455278 + - 0.009444127 + - -0.081325725 + - 0.015994828 + - 0.0010952728 + - 0.0008373874 + - -0.0016424303 + - -0.05096469 + - 0.045976803 + - 0.024695056 + - -0.031656373 + - -0.0013138534 + - -0.0060524447 + - 0.060276203 + - -0.016745795 + - -0.029930653 + - -0.0222771 + - 0.012314711 + - 0.038991332 + - -0.006665343 + - 0.041694533 + - 0.0076992502 + - -0.014353178 + - 0.0025135442 + - -0.0498445 + - 0.020322764 + - 0.0575802 + - -0.006096128 + - -0.010841882 + - -0.024337102 + - 0.021975596 + - 0.0064031687 + - -0.026746146 + - 0.03455729 + - -0.011909055 + - 0.016994143 + - 0.026053395 + - -0.020393625 + - 0.06403403 + - 0.042590734 + - 0.009193913 + - 0.04016698 + - 0.028304791 + - 0.022147119 + - -0.030121539 + - -0.0037334429 + - 0.03235819 + - -0.020825844 + - 0.0009766509 + - -0.012216568 + - 0.07944978 + - 0.04177374 + - -0.008281654 + - -0.008908983 + - 0.04799388 + - -0.012743454 + - -0.00076762337 + - 0.012673029 + - -0.018283572 + - 0.022068778 + - -0.009605337 + - 0.0030652087 + - -0.036517244 + - 
-0.006263211 + - 0.0194632 + - 0.009333852 + - 0.02350168 + - -0.0008530139 + - 0.018934859 + - -0.013986168 + - -0.010833636 + - -0.011189203 + - 0.011567913 + - -0.07253544 + - -0.005748846 + - 0.11930293 + - 0.060044624 + - -0.023167728 + - -0.015781552 + - -0.010494401 + - 0.01930528 + - -0.039608266 + - -0.073587865 + - 0.0019034932 + - -0.0013838339 + - 0.026257295 + - 0.004433007 + - -0.051545423 + - -0.033456888 + - -0.06401291 + - -0.08664347 + - 0.010781564 + - 0.0408775 + - 0.021475399 + - 0.017633006 + - -0.02186024 + - 0.047795497 + - -0.006370007 + - 0.01792626 + - 0.005195737 + - 0.0026206016 + - -0.008816542 + - 0.009266863 + - -0.018453414 + - 0.0040575014 + - 0.0047053 + - -0.0197809 + - -0.01580334 + - 0.007821501 + - -0.06296649 + - -0.06274416 + - -0.0031381177 + - -0.024228694 + - -0.002459634 + - 0.00034323192 + - -0.02430543 + - -0.029262288 + - -0.04606642 + - -0.013138838 + - 0.040473 + - 0.012308485 + - 0.0020701357 + - -0.007718021 + - -0.0064122216 + - -0.024890581 + - 0.014665469 + - -0.0028788927 + - -0.047072053 + - -0.014959743 + - 0.004587824 + - -0.07462158 + - -0.008558996 + - 0.019324543 + - -0.02247574 + - -0.07721102 + - 0.007920586 + - -0.020274863 + - 0.028696692 + - 0.024707401 + - 0.016905285 + - 0.012742534 + - 0.018577736 + - -0.05768951 + - 0.03203929 + - 0.018105863 + - -0.03917534 + - -0.026208939 + - -0.038492158 + - -0.043517314 + - -0.008031121 + - -0.016162876 + - -0.0010640965 + - -0.004164019 + - 0.005193703 + - -0.0058410293 + - 0.029311381 + - -0.016533207 + - 0.05005747 + - 0.012600715 + - 0.016292874 + - -0.008300225 + - -0.08953819 + - 0.06544125 + - -0.010512851 + - -0.021443438 + - 0.030277776 + - -0.06502247 + - 0.004850903 + - -0.013137611 + - 0.0506941 + - 0.0072725127 + - -0.025755724 + - 0.008224718 + - 0.013813313 + - 0.037197027 + - 0.0023671025 + - 0.00763629 + - 0.011905766 + - 0.033143394 + - 0.007750765 + - -0.07725993 + - -0.03193554 + - -0.016900484 + - 0.06256093 + - 0.015902048 + - 0.062251173 + - 0.07478062 + - -0.009171957 + - -0.025452917 + - -0.015754124 + - -0.004426243 + - -0.0873611 + - 0.0077999695 + - 0.061026644 + - 0.016489599 + - 0.0066420045 + - -0.0062355455 + - 0.00345123 + - -0.022935547 + - 0.03939866 + - -0.0037231673 + - 0.0033949488 + - 0.029471302 + - 0.023097953 + - -0.0237214 + - 0.046621986 + - 0.029790087 + - -0.030113066 + - -0.012432801 + - 0.036813233 + - 0.01785254 + - -0.08032645 + - 0.051262226 + - 0.04222712 + - -0.023358794 + - -0.03602671 + - 0.0092950305 + - -0.015663076 + - -0.023873692 + - 0.009383877 + - -0.056770466 + - 0.032832243 + - 0.019920528 + - -0.04734062 + - -0.03368295 + - 0.0041841906 + - 0.03257228 + - -0.07387151 + - 0.011083565 + - 0.008633363 + - -0.04694844 + - 0.0027943242 + - -0.009078945 + - -0.011981829 + - -0.026407028 + - 0.033040557 + - -0.025803888 + - -0.021555608 + - 0.0042838412 + - -0.04263439 + - 0.0465685 + - 0.070476055 + - -0.02560814 + - 0.060619954 + - 0.0071254023 + - -0.062549844 + - -0.021544673 + - -0.03598606 + - -0.04904548 + - 0.04631042 + - -0.029345924 + - -0.015404836 + - 0.0063473387 + - -0.023926385 + - 0.022935716 + - -0.1004558 + - -0.007012574 + - 0.049480513 + - -0.02592937 + - 0.057209775 + - -0.010056263 + - -0.02236333 + - 0.008163001 + - 0.0036735693 + - -0.008406754 + - -0.0029980235 + - -0.0052409745 + - -0.014090007 + - -0.025248934 + - 0.022062942 + - -0.019766279 + - 0.07160526 + - -0.00075892545 + - -0.0096911425 + - -0.019483106 + - 0.042904716 + - -0.0013572425 + - 0.04353212 + - -0.07759716 + - -0.011731261 + - 
-0.10602095 + - 0.0010180247 + - -0.07403971 + - 0.043784548 + - -0.010357722 + - -0.009020027 + - 0.026289912 + - 0.053744033 + - 0.009665143 diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 4a5f8276..4e42428a 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -34,7 +34,7 @@ fn test_flash_jina_small() -> Result<()> { let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); let embeddings_batch = SnapshotScores::from(pooled_embeddings); - insta::assert_yaml_snapshot!("jina_batch", embeddings_batch, &matcher); + insta::assert_yaml_snapshot!("flash_jina_batch", embeddings_batch, &matcher); let input_single = batch( vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], @@ -45,7 +45,7 @@ fn test_flash_jina_small() -> Result<()> { let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); let embeddings_single = SnapshotScores::from(pooled_embeddings); - insta::assert_yaml_snapshot!("jina_single", embeddings_single, &matcher); + insta::assert_yaml_snapshot!("flash_jina_single", embeddings_single, &matcher); assert_eq!(embeddings_batch[0], embeddings_single[0]); assert_eq!(embeddings_batch[2], embeddings_single[0]); diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs new file mode 100644 index 00000000..74035117 --- /dev/null +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -0,0 +1,53 @@ +#![allow(dead_code, unused_imports)] +mod common; + +use crate::common::{sort_embeddings, SnapshotScores}; +use anyhow::Result; +use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_core::{Backend, ModelType, Pool}; + +#[test] +#[serial_test::serial] +#[cfg(all(feature = "cuda", feature = "flash-attn"))] +fn test_flash_jina_code_base() -> Result<()> { + let model_root = download_artifacts("jinaai/jina-embeddings-v2-base-code", None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new( + model_root, + "float16".to_string(), + ModelType::Embedding(Pool::Mean), + )?; + + let input_batch = batch( + vec![ + tokenizer.encode("What is Deep Learning?", true).unwrap(), + tokenizer.encode("Deep Learning is...", true).unwrap(), + tokenizer.encode("What is Deep Learning?", true).unwrap(), + ], + [0, 1, 2].to_vec(), + vec![], + ); + + let matcher = relative_matcher(); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); + let embeddings_batch = SnapshotScores::from(pooled_embeddings); + insta::assert_yaml_snapshot!("flash_jina_code_batch", embeddings_batch, &matcher); + + let input_single = batch( + vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], + [0].to_vec(), + vec![], + ); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); + let embeddings_single = SnapshotScores::from(pooled_embeddings); + + insta::assert_yaml_snapshot!("flash_jina_code_single", embeddings_single, &matcher); + assert_eq!(embeddings_batch[0], embeddings_single[0]); + assert_eq!(embeddings_batch[2], embeddings_single[0]); + + Ok(()) +} diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs new file mode 100644 index 00000000..70248e1a --- /dev/null +++ b/backends/candle/tests/test_jina_code.rs @@ -0,0 +1,50 @@ +mod common; + +use crate::common::{sort_embeddings, SnapshotScores}; +use anyhow::Result; 
+use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_core::{Backend, ModelType, Pool}; + +#[test] +fn test_jina_code_base() -> Result<()> { + let model_root = download_artifacts("jinaai/jina-embeddings-v2-base-code", None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new( + model_root, + "float32".to_string(), + ModelType::Embedding(Pool::Mean), + )?; + + let input_batch = batch( + vec![ + tokenizer.encode("What is Deep Learning?", true).unwrap(), + tokenizer.encode("Deep Learning is...", true).unwrap(), + tokenizer.encode("What is Deep Learning?", true).unwrap(), + ], + [0, 1, 2].to_vec(), + vec![], + ); + + let matcher = relative_matcher(); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); + let embeddings_batch = SnapshotScores::from(pooled_embeddings); + insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher); + + let input_single = batch( + vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], + [0].to_vec(), + vec![], + ); + + let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); + let embeddings_single = SnapshotScores::from(pooled_embeddings); + + insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher); + assert_eq!(embeddings_batch[0], embeddings_single[0]); + assert_eq!(embeddings_batch[2], embeddings_single[0]); + + Ok(()) +} diff --git a/core/src/infer.rs b/core/src/infer.rs index 54f755d9..0f95ff8b 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -4,6 +4,7 @@ use crate::TextEmbeddingsError; use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType}; +use tokenizers::TruncationDirection; use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore}; use tracing::instrument; @@ -117,6 +118,7 @@ impl Infer { &self, inputs: I, truncate: bool, + truncation_direction: TruncationDirection, permit: OwnedSemaphorePermit, ) -> Result { let start_time = Instant::now(); @@ -131,7 +133,14 @@ impl Infer { } let results = self - .embed(inputs, truncate, false, &start_time, permit) + .embed( + inputs, + truncate, + truncation_direction, + false, + &start_time, + permit, + ) .await?; let InferResult::AllEmbedding(response) = results else { @@ -165,6 +174,7 @@ impl Infer { &self, inputs: I, truncate: bool, + truncation_direction: TruncationDirection, permit: OwnedSemaphorePermit, ) -> Result { let start_time = Instant::now(); @@ -179,7 +189,14 @@ impl Infer { } let results = self - .embed(inputs, truncate, true, &start_time, permit) + .embed( + inputs, + truncate, + truncation_direction, + true, + &start_time, + permit, + ) .await?; let InferResult::PooledEmbedding(response) = results else { @@ -213,6 +230,7 @@ impl Infer { &self, inputs: I, truncate: bool, + truncation_direction: TruncationDirection, normalize: bool, permit: OwnedSemaphorePermit, ) -> Result { @@ -228,7 +246,14 @@ impl Infer { } let results = self - .embed(inputs, truncate, true, &start_time, permit) + .embed( + inputs, + truncate, + truncation_direction, + true, + &start_time, + permit, + ) .await?; let InferResult::PooledEmbedding(mut response) = results else { @@ -278,6 +303,7 @@ impl Infer { &self, inputs: I, truncate: bool, + truncation_direction: TruncationDirection, pooling: bool, start_time: &Instant, _permit: OwnedSemaphorePermit, @@ -296,7 +322,7 @@ impl Infer { // Tokenization 
let encoding = self .tokenization - .encode(inputs.into(), truncate) + .encode(inputs.into(), truncate, truncation_direction) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); @@ -340,6 +366,7 @@ impl Infer { &self, inputs: I, truncate: bool, + truncation_direction: TruncationDirection, raw_scores: bool, _permit: OwnedSemaphorePermit, ) -> Result { @@ -357,7 +384,7 @@ impl Infer { // Tokenization let encoding = self .tokenization - .encode(inputs.into(), truncate) + .encode(inputs.into(), truncate, truncation_direction) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index 46f1411e..07226823 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -64,6 +64,7 @@ impl Tokenization { &self, inputs: EncodingInput, truncate: bool, + truncation_direction: TruncationDirection, ) -> Result { // Check if inputs is empty if inputs.is_empty() { @@ -80,6 +81,7 @@ impl Tokenization { .send(TokenizerRequest::Encode( inputs, truncate, + truncation_direction, response_sender, Span::current(), )) @@ -163,7 +165,13 @@ fn tokenizer_worker( // Loop over requests while let Some(request) = receiver.blocking_recv() { match request { - TokenizerRequest::Encode(inputs, truncate, response_tx, parent_span) => { + TokenizerRequest::Encode( + inputs, + truncate, + truncation_direction, + response_tx, + parent_span, + ) => { parent_span.in_scope(|| { if !response_tx.is_closed() { // It's possible that the user dropped its request resulting in a send error. @@ -171,6 +179,7 @@ fn tokenizer_worker( let _ = response_tx.send(encode_input( inputs, truncate, + truncation_direction, max_input_length, position_offset, &mut tokenizer, @@ -247,13 +256,14 @@ fn tokenize_input( fn encode_input( inputs: EncodingInput, truncate: bool, + truncation_direction: TruncationDirection, max_input_length: usize, position_offset: usize, tokenizer: &mut Tokenizer, ) -> Result { // Default truncation params let truncate_params = truncate.then_some(TruncationParams { - direction: TruncationDirection::Right, + direction: truncation_direction, max_length: max_input_length, strategy: TruncationStrategy::LongestFirst, stride: 0, @@ -316,6 +326,7 @@ enum TokenizerRequest { Encode( EncodingInput, bool, + TruncationDirection, oneshot::Sender>, Span, ), diff --git a/docs/source/en/supported_models.md b/docs/source/en/supported_models.md index 64ecbc1f..876e5d06 100644 --- a/docs/source/en/supported_models.md +++ b/docs/source/en/supported_models.md @@ -25,13 +25,14 @@ model with Alibi positions. 
Below are some examples of the currently supported models: -| MTEB Rank | Model Type | Model ID | -|-----------|-------------|--------------------------------------------------------------------------------------------------| -| 6 | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) | -| 10 | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) | -| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) | -| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) | -| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) | +| MTEB Rank | Model Type | Model ID | +|-----------|--------------|--------------------------------------------------------------------------------------------------| +| 6 | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) | +| 10 | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) | +| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) | +| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) | +| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) | +| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) | To explore the list of best performing text embeddings models, visit the diff --git a/proto/tei.proto b/proto/tei.proto index 6538e34a..394c0262 100644 --- a/proto/tei.proto +++ b/proto/tei.proto @@ -69,10 +69,16 @@ message Metadata { uint64 inference_time_ns = 6; } +enum TruncationDirection { + TRUNCATION_DIRECTION_RIGHT = 0; + TRUNCATION_DIRECTION_LEFT = 1; +} + message EmbedRequest { string inputs = 1; bool truncate = 2; bool normalize = 3; + TruncationDirection truncation_direction = 4; } message EmbedResponse { @@ -83,6 +89,7 @@ message EmbedSparseRequest { string inputs = 1; bool truncate = 2; + TruncationDirection truncation_direction = 3; } message SparseValue { @@ -98,6 +105,7 @@ message EmbedSparseResponse { message EmbedAllRequest { string inputs = 1; bool truncate = 2; + TruncationDirection truncation_direction = 3; } message TokenEmbedding { @@ -113,12 +121,14 @@ message PredictRequest { string inputs = 1; bool truncate = 2; bool raw_scores = 3; + TruncationDirection truncation_direction = 4; } message PredictPairRequest { repeated string inputs = 1; bool truncate = 2; bool raw_scores = 3; + TruncationDirection truncation_direction = 4; } message Prediction { @@ -137,6 +147,7 @@ message RerankRequest { bool truncate = 3; bool raw_scores = 4; bool return_text = 5; + TruncationDirection truncation_direction = 6; } message RerankStreamRequest{ @@ -147,6 +158,7 @@ message RerankStreamRequest{ bool raw_scores = 4; // The server will only consider the first value bool return_text = 5; + TruncationDirection truncation_direction = 6; } message Rank { diff --git a/router/Cargo.toml b/router/Cargo.toml index 84d5a2ca..2f2b0815 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -43,6 +43,7 @@ mimalloc = { version = "*", default-features = false } # HTTP dependencies axum = { version = "0.7.4", features = ["json"], optional = true } axum-tracing-opentelemetry = { version = "0.17.0", optional = true } +base64 = { version = "0.21.4", optional = true } tower-http = {
version = "0.5.1", features = ["cors"], optional = true } utoipa = { version = "4.2", features = ["axum_extras"], optional = true } utoipa-swagger-ui = { version = "6.0", features = ["axum"], optional = true } @@ -66,7 +67,7 @@ tonic-build = { version = "0.10.2", optional = true } [features] default = ["candle", "http"] -http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] +http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:base64", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"] metal = ["text-embeddings-backend/metal"] mkl = ["text-embeddings-backend/mkl"] diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index 913455b3..98ee5601 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -1,7 +1,7 @@ use crate::grpc::pb::tei::v1::{ EmbedAllRequest, EmbedAllResponse, EmbedSparseRequest, EmbedSparseResponse, EncodeRequest, EncodeResponse, PredictPairRequest, RerankStreamRequest, SimpleToken, SparseValue, - TokenEmbedding, + TokenEmbedding, TruncationDirection, }; use crate::grpc::{ DecodeRequest, DecodeResponse, EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, @@ -80,9 +80,16 @@ impl TextEmbeddingsService { let start_time = Instant::now(); let compute_chars = request.inputs.chars().count(); + let truncation_direction = convert_truncation_direction(request.truncation_direction); let response = self .infer - .embed_pooled(request.inputs, request.truncate, request.normalize, permit) + .embed_pooled( + request.inputs, + request.truncate, + truncation_direction, + request.normalize, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -128,9 +135,15 @@ impl TextEmbeddingsService { let start_time = Instant::now(); let compute_chars = request.inputs.chars().count(); + let truncation_direction = convert_truncation_direction(request.truncation_direction); let response = self .infer - .embed_sparse(request.inputs, request.truncate, permit) + .embed_sparse( + request.inputs, + request.truncate, + truncation_direction, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -187,9 +200,15 @@ impl TextEmbeddingsService { let start_time = Instant::now(); let compute_chars = request.inputs.chars().count(); + let truncation_direction = convert_truncation_direction(request.truncation_direction); let response = self .infer - .embed_all(request.inputs, request.truncate, permit) + .embed_all( + request.inputs, + request.truncate, + truncation_direction, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -236,6 +255,7 @@ impl TextEmbeddingsService { &self, inputs: I, truncate: bool, + truncation_direction: tokenizers::TruncationDirection, raw_scores: bool, permit: OwnedSemaphorePermit, ) -> Result<(PredictResponse, ResponseMetadata), Status> { @@ -251,7 +271,7 @@ impl TextEmbeddingsService { let response = self .infer - .predict(inputs, truncate, raw_scores, permit) + .predict(inputs, truncate, truncation_direction, raw_scores, permit) .await .map_err(ErrorResponse::from)?; @@ -701,8 +721,15 @@ impl grpc::predict_server::Predict for TextEmbeddingsService { .map_err(ErrorResponse::from)?; let request = request.into_inner(); + let truncation_direction = convert_truncation_direction(request.truncation_direction); let (response, metadata) = self - .predict_inner(request.inputs, request.truncate, request.raw_scores, 
permit) + .predict_inner( + request.inputs, + request.truncate, + truncation_direction, + request.raw_scores, + permit, + ) .await?; let headers = HeaderMap::from(metadata); @@ -743,8 +770,15 @@ impl grpc::predict_server::Predict for TextEmbeddingsService { .try_acquire_permit() .map_err(ErrorResponse::from)?; + let truncation_direction = convert_truncation_direction(request.truncation_direction); let (response, metadata) = self - .predict_inner(inputs, request.truncate, request.raw_scores, permit) + .predict_inner( + inputs, + request.truncate, + truncation_direction, + request.raw_scores, + permit, + ) .await?; let headers = HeaderMap::from(metadata); @@ -767,8 +801,15 @@ impl grpc::predict_server::Predict for TextEmbeddingsService { // Clone for move below let clone = self.clone(); let function = |req: PredictRequest, permit: OwnedSemaphorePermit| async move { + let truncation_direction = convert_truncation_direction(req.truncation_direction); clone - .predict_inner(req.inputs, req.truncate, req.raw_scores, permit) + .predict_inner( + req.inputs, + req.truncate, + truncation_direction, + req.raw_scores, + permit, + ) .await }; @@ -800,8 +841,15 @@ impl grpc::predict_server::Predict for TextEmbeddingsService { } }; + let truncation_direction = convert_truncation_direction(req.truncation_direction); clone - .predict_inner(inputs, req.truncate, req.raw_scores, permit) + .predict_inner( + inputs, + req.truncate, + truncation_direction, + req.raw_scores, + permit, + ) .await }; @@ -862,12 +910,19 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { let rerank_inner = move |query: String, text: String, truncate: bool, + truncation_direction: tokenizers::TruncationDirection, raw_scores: bool, infer: Infer| async move { let permit = infer.acquire_permit().await; let response = infer - .predict((query, text), truncate, raw_scores, permit) + .predict( + (query, text), + truncate, + truncation_direction, + raw_scores, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -902,6 +957,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { let mut futures = Vec::with_capacity(batch_size); let query_chars = request.query.chars().count(); let mut total_compute_chars = query_chars * batch_size; + let truncation_direction = convert_truncation_direction(request.truncation_direction); for text in &request.texts { total_compute_chars += text.chars().count(); @@ -910,6 +966,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { request.query.clone(), text.clone(), request.truncate, + truncation_direction, request.raw_scores, local_infer, )) @@ -1027,11 +1084,18 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { query: String, text: String, truncate: bool, + truncation_direction: tokenizers::TruncationDirection, raw_scores: bool, infer: Infer, permit: OwnedSemaphorePermit| async move { let response = infer - .predict((query, text.clone()), truncate, raw_scores, permit) + .predict( + (query, text.clone()), + truncate, + truncation_direction, + raw_scores, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -1055,7 +1119,14 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { // Create bounded channel to have an upper bound of spawned tasks // We will have at most `max_parallel_stream_requests` messages from this stream in the queue let (rerank_sender, mut rerank_receiver) = mpsc::channel::<( - (usize, String, String, bool, bool), + ( + usize, + String, + String, + bool, + tokenizers::TruncationDirection, + bool, + ), oneshot::Sender< Result<(usize, 
usize, Duration, Duration, Duration, f32, String), ErrorResponse>, >, @@ -1066,8 +1137,10 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { // Background task that uses the bounded channel tokio::spawn(async move { - while let Some(((index, query, text, truncate, raw_scores), mut sender)) = - rerank_receiver.recv().await + while let Some(( + (index, query, text, truncate, truncation_direction, raw_scores), + mut sender, + )) = rerank_receiver.recv().await { // Wait on permit before spawning the task to avoid creating more tasks than needed let permit = local_infer.acquire_permit().await; @@ -1079,7 +1152,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { tokio::spawn(async move { // Select on closed to cancel work if the stream was closed tokio::select! { - result = rerank_inner(index, query, text, truncate, raw_scores, task_infer, permit) => { + result = rerank_inner(index, query, text, truncate, truncation_direction, raw_scores, task_infer, permit) => { let _ = sender.send(result); } _ = sender.closed() => {} @@ -1118,6 +1191,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { total_compute_chars += request.query.chars().count(); total_compute_chars += request.text.chars().count(); + let truncation_direction = convert_truncation_direction(request.truncation_direction); rerank_sender .send(( ( @@ -1125,6 +1199,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { request.query, request.text, request.truncate, + truncation_direction, raw_scores.unwrap(), ), result_sender, @@ -1434,3 +1509,10 @@ impl From for Status { Status::new(code, value.error) } } + +fn convert_truncation_direction(value: i32) -> tokenizers::TruncationDirection { + match TruncationDirection::try_from(value).expect("Unexpected enum value") { + TruncationDirection::Right => tokenizers::TruncationDirection::Right, + TruncationDirection::Left => tokenizers::TruncationDirection::Left, + } +} diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 1832abf1..11b3a521 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1,11 +1,11 @@ /// HTTP Server logic use crate::http::types::{ DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse, - EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding, - OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, - PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, - Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse, - VertexPrediction, VertexRequest, VertexResponse, + EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType, + OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, + OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank, + RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput, + TokenizeRequest, TokenizeResponse, VertexPrediction, VertexRequest, VertexResponse, }; use crate::{ shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType, @@ -19,6 +19,8 @@ use axum::http::{Method, StatusCode}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; use futures::future::join_all; use futures::FutureExt; use 
http::header::AUTHORIZATION; @@ -30,6 +32,7 @@ use text_embeddings_core::infer::{ AllEmbeddingsInferResponse, Infer, InferMetadata, PooledEmbeddingsInferResponse, }; use text_embeddings_core::TextEmbeddingsError; +use tokenizers::TruncationDirection; use tokio::sync::OwnedSemaphorePermit; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; @@ -103,7 +106,6 @@ async fn predict( // Closure for predict let predict_inner = move |inputs: Sequence, truncate: bool, - raw_scores: bool, infer: Infer, info: Info, permit: Option| async move { @@ -113,7 +115,13 @@ async fn predict( }; let response = infer - .predict(inputs, truncate, raw_scores, permit) + .predict( + inputs, + truncate, + req.truncation_direction, + req.raw_scores, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -159,15 +167,8 @@ async fn predict( let compute_chars = inputs.count_chars(); let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( - inputs, - truncate, - req.raw_scores, - infer.0, - info.0, - Some(permit), - ) - .await?; + let (prompt_tokens, tokenization, queue, inference, predictions) = + predict_inner(inputs, truncate, infer.0, info.0, Some(permit)).await?; metrics::increment_counter!("te_request_success", "method" => "single"); @@ -211,7 +212,6 @@ async fn predict( futures.push(predict_inner( input, truncate, - req.raw_scores, local_infer.0, local_info.0, None, @@ -321,15 +321,17 @@ async fn rerank( })?; // Closure for rerank - let rerank_inner = move |query: String, - text: String, - truncate: bool, - raw_scores: bool, - infer: Infer| async move { + let rerank_inner = move |query: String, text: String, truncate: bool, infer: Infer| async move { let permit = infer.acquire_permit().await; let response = infer - .predict((query, text), truncate, raw_scores, permit) + .predict( + (query, text), + truncate, + req.truncation_direction, + req.raw_scores, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -375,7 +377,6 @@ async fn rerank( req.query.clone(), text.clone(), truncate, - req.raw_scores, local_infer.0, )) } @@ -484,7 +485,13 @@ async fn embed( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_pooled(input, truncate, req.normalize, permit) + .embed_pooled( + input, + truncate, + req.truncation_direction, + req.normalize, + permit, + ) .await .map_err(ErrorResponse::from)?; @@ -541,7 +548,13 @@ async fn embed( futures.push(async move { let permit = local_infer.acquire_permit().await; local_infer - .embed_pooled(input, truncate, req.normalize, permit) + .embed_pooled( + input, + truncate, + req.truncation_direction, + req.normalize, + permit, + ) .await }) } @@ -641,7 +654,7 @@ async fn embed_sparse( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_sparse(input, truncate, permit) + .embed_sparse(input, truncate, req.truncation_direction, permit) .await .map_err(ErrorResponse::from)?; @@ -697,7 +710,9 @@ async fn embed_sparse( let local_infer = infer.clone(); futures.push(async move { let permit = local_infer.acquire_permit().await; - let response = local_infer.embed_sparse(input, truncate, permit).await?; + let response = local_infer + .embed_sparse(input, truncate, req.truncation_direction, permit) + .await?; Ok((sparsify(response.results), response.metadata)) }) } @@ -789,7 +804,7 @@ async fn embed_all( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let 
response = infer - .embed_all(input, truncate, permit) + .embed_all(input, truncate, req.truncation_direction, permit) .await .map_err(ErrorResponse::from)?; @@ -845,7 +860,9 @@ async fn embed_all( let local_infer = infer.clone(); futures.push(async move { let permit = local_infer.acquire_permit().await; - local_infer.embed_all(input, truncate, permit).await + local_infer + .embed_all(input, truncate, req.truncation_direction, permit) + .await }) } let results = join_all(futures) @@ -923,6 +940,21 @@ async fn openai_embed( Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let encode_embedding = |array: Vec| { + match req.encoding_format { + EncodingFormat::Float => Embedding::Float(array), + EncodingFormat::Base64 => { + // Unsafe is fine here since we do not violate memory ownership: bytes + // is only used in this scope and we return an owned string + let bytes = unsafe { + std::slice::from_raw_parts(array.as_ptr() as *const u8, array.len() * 4) + }; + + Embedding::Base64(BASE64_STANDARD.encode(bytes)) + } + } + }; + let span = tracing::Span::current(); let start_time = Instant::now(); @@ -936,16 +968,17 @@ async fn openai_embed( let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; let response = infer - .embed_pooled(input, truncate, true, permit) + .embed_pooled(input, truncate, TruncationDirection::Right, true, permit) .await .map_err(ErrorResponse::from)?; metrics::increment_counter!("te_request_success", "method" => "single"); + let embedding = encode_embedding(response.results); ( vec![OpenAICompatEmbedding { object: "embedding", - embedding: response.results, + embedding, index: 0, }], ResponseMetadata::new( @@ -997,7 +1030,7 @@ async fn openai_embed( futures.push(async move { let permit = local_infer.acquire_permit().await; local_infer - .embed_pooled(input, truncate, true, permit) + .embed_pooled(input, truncate, TruncationDirection::Right, true, permit) .await }) } @@ -1018,9 +1051,10 @@ async fn openai_embed( total_queue_time += r.metadata.queue.as_nanos() as u64; total_inference_time += r.metadata.inference.as_nanos() as u64; total_compute_tokens += r.metadata.prompt_tokens; + let embedding = encode_embedding(r.results); embeddings.push(OpenAICompatEmbedding { object: "embedding", - embedding: r.results, + embedding, index: i, }); } @@ -1451,7 +1485,6 @@ pub async fn run( // Create router let mut app = Router::new() - .layer(DefaultBodyLimit::max(payload_limit)) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc)) // Base routes .route("/info", get(get_model_info)) @@ -1474,7 +1507,9 @@ pub async fn run( // AWS Sagemaker health route .route("/ping", get(health)) // Prometheus metrics route - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + // Update payload limit + .layer(DefaultBodyLimit::max(payload_limit)); #[cfg(feature = "google")] { diff --git a/router/src/http/types.rs b/router/src/http/types.rs index be514d40..a2a773e8 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -4,6 +4,7 @@ use serde::{de, Deserialize, Deserializer, Serialize}; use serde_json::json; use std::fmt::Formatter; use text_embeddings_core::tokenization::EncodingInput; +use tokenizers::TruncationDirection; use utoipa::openapi::{RefOr, Schema}; use utoipa::ToSchema; @@ -199,6 +200,9 @@ pub(crate) struct PredictRequest { #[schema(default = "false", example = "false", nullable = true)] pub truncate: Option, #[serde(default)] + #[schema(default = "right", example = "right")] + pub truncation_direction: 
TruncationDirection, + #[serde(default)] #[schema(default = "false", example = "false")] pub raw_scores: bool, } @@ -228,6 +232,9 @@ pub(crate) struct RerankRequest { #[schema(default = "false", example = "false", nullable = true)] pub truncate: Option, #[serde(default)] + #[schema(default = "right", example = "right")] + pub truncation_direction: TruncationDirection, + #[serde(default)] #[schema(default = "false", example = "false")] pub raw_scores: bool, #[serde(default)] @@ -278,6 +285,14 @@ pub(crate) enum Input { Batch(Vec), } +#[derive(Deserialize, ToSchema, Default)] +#[serde(rename_all = "snake_case")] +pub(crate) enum EncodingFormat { + #[default] + Float, + Base64, +} + #[derive(Deserialize, ToSchema)] pub(crate) struct OpenAICompatRequest { pub input: Input, @@ -287,6 +302,16 @@ pub(crate) struct OpenAICompatRequest { #[allow(dead_code)] #[schema(nullable = true, example = "null")] pub user: Option, + #[schema(default = "float", example = "float")] + #[serde(default)] + pub encoding_format: EncodingFormat, +} + +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum Embedding { + Float(Vec), + Base64(String), } #[derive(Serialize, ToSchema)] @@ -294,7 +319,7 @@ pub(crate) struct OpenAICompatEmbedding { #[schema(example = "embedding")] pub object: &'static str, #[schema(example = json!([0.0, 1.0, 2.0]))] - pub embedding: Vec, + pub embedding: Embedding, #[schema(example = "0")] pub index: usize, } @@ -323,6 +348,9 @@ pub(crate) struct EmbedRequest { #[serde(default)] #[schema(default = "false", example = "false", nullable = true)] pub truncate: Option, + #[serde(default)] + #[schema(default = "right", example = "right")] + pub truncation_direction: TruncationDirection, #[serde(default = "default_normalize")] #[schema(default = "true", example = "true")] pub normalize: bool, @@ -342,6 +370,9 @@ pub(crate) struct EmbedSparseRequest { #[serde(default)] #[schema(default = "false", example = "false", nullable = true)] pub truncate: Option, + #[serde(default)] + #[schema(default = "right", example = "right")] + pub truncation_direction: TruncationDirection, } #[derive(Serialize, ToSchema)] @@ -359,6 +390,9 @@ pub(crate) struct EmbedAllRequest { #[serde(default)] #[schema(default = "false", example = "false", nullable = true)] pub truncate: Option, + #[serde(default)] + #[schema(default = "right", example = "right")] + pub truncation_direction: TruncationDirection, } #[derive(Serialize, ToSchema)]
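Reviewer note on the new `encoding_format` field: the `openai_embed` handler above serializes each embedding as the raw bytes of the `f32` slice and then base64-encodes them, so a client recovers the vector by reversing those two steps. The helper below is a minimal client-side sketch, not part of this diff; it assumes the server's native byte order is little-endian (true on the usual x86-64/aarch64 hosts) and uses the same `base64` 0.21 crate API as the server code. The function name is illustrative.

```rust
use base64::prelude::BASE64_STANDARD;
use base64::Engine;

/// Hypothetical client-side helper (not part of this PR): turn the base64
/// payload returned for `encoding_format: "base64"` back into f32 values.
/// Assumes the server wrote native little-endian f32 bytes, which is what the
/// slice cast in `openai_embed` produces on little-endian hosts.
fn decode_base64_embedding(payload: &str) -> Result<Vec<f32>, base64::DecodeError> {
    let bytes = BASE64_STANDARD.decode(payload)?;
    // Reinterpret every 4-byte chunk as one f32.
    Ok(bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
        .collect())
}
```

With the default `float` format the `embedding` field stays a plain JSON array and no decoding is needed, which is why `Embedding` is serialized as an untagged enum.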
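The same `truncation_direction` plumbing can be exercised directly against the core `Infer` handle. The snippet below is a hedged sketch of a caller that keeps the tail of an over-long input instead of its head; the function name and the anyhow-based error handling are illustrative assumptions, but the `try_acquire_permit`/`embed_pooled` calls and their argument order follow the updated signatures in this diff.

```rust
use text_embeddings_core::infer::Infer;
use tokenizers::TruncationDirection;

/// Illustrative caller (not part of this PR): embed a possibly over-long text,
/// truncating from the left so the end of the document is what gets embedded.
async fn embed_keeping_tail(infer: &Infer, text: String) -> anyhow::Result<Vec<f32>> {
    // Same permit-acquisition pattern as the HTTP handlers above.
    let permit = infer.try_acquire_permit()?;
    let response = infer
        .embed_pooled(
            text,
            true,                      // truncate
            TruncationDirection::Left, // parameter introduced in this diff
            true,                      // normalize
            permit,
        )
        .await?;
    Ok(response.results)
}
```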