Skip to content

Commit

Permalink
Improve: Separate encoders & processors
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 20, 2024
1 parent 4f1568f commit cccfc62
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 399 deletions.
20 changes: 10 additions & 10 deletions javascript/encoders.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class ImageEncoder {
}
}

async forward(inputs) {
async forward(images) {
if (!this.session) {
throw new Error("Session is not initialized.");
}
Expand All @@ -206,33 +206,33 @@ class ImageEncoder {
return result;
};

let inputData;
let imagesData;
let dims;

if (Array.isArray(inputs)) {
// Assuming each input in the array is a Float32Array representing an image already processed to a fixed size.
const arrays = inputs.map(ensureFloat32Array);
inputData = concatFloat32Arrays(arrays);
if (Array.isArray(images)) {
// Assuming each image in the array is a Float32Array representing an image already processed to a fixed size.
const arrays = images.map(ensureFloat32Array);
imagesData = concatFloat32Arrays(arrays);
const numImages = arrays.length;
const numChannels = 3;
const height = this.imageSize;
const width = this.imageSize;
dims = [numImages, numChannels, height, width];
} else {
// Single image input, which is already a Float32Array.
inputData = ensureFloat32Array(inputs);
// Single image, which is already a Float32Array.
imagesData = ensureFloat32Array(images);
const numChannels = 3;
const height = this.imageSize;
const width = this.imageSize;
dims = [1, numChannels, height, width];
}

// Create ONNX Tensor
const inputTensor = new Tensor('float32', inputData, dims);
const imagesTensor = new Tensor('float32', imagesData, dims);

// Run model inference
return this.session.run({
input: inputTensor,
images: imagesTensor,
});
}
}
Expand Down
89 changes: 52 additions & 37 deletions python/scripts/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from PIL import Image

import uform
from uform import Modality, get_model, get_model_onnx

# PyTorch is a very heavy dependency, so we may want to skip these tests if it's not installed
try:
Expand All @@ -27,12 +27,16 @@

torch_models = [
"unum-cloud/uform3-image-text-english-small",
"unum-cloud/uform-vl-english",
"unum-cloud/uform-vl-multilingual-v2",
"unum-cloud/uform3-image-text-english-base",
"unum-cloud/uform3-image-text-english-large",
"unum-cloud/uform3-image-text-multilingual-base",
]

onnx_models = [
"unum-cloud/uform3-image-text-english-small",
"unum-cloud/uform3-image-text-english-base",
"unum-cloud/uform3-image-text-english-large",
"unum-cloud/uform3-image-text-multilingual-base",
]

# Let's check if the HuggingFace Hub API token is set in the environment variable.
Expand Down Expand Up @@ -113,51 +117,52 @@ def cross_references_image_and_text_embeddings(text_to_embedding, image_to_embed
@pytest.mark.skipif(not torch_available, reason="PyTorch is not installed")
@pytest.mark.parametrize("model_name", torch_models)
def test_torch_one_embedding(model_name: str):
model, processor = uform.get_model(model_name, token=token)
processors, models = get_model(model_name, token=token)
model_text = models[Modality.TEXT_ENCODER]
model_image = models[Modality.IMAGE_ENCODER]
processor_text = processors[Modality.TEXT_ENCODER]
processor_image = processors[Modality.IMAGE_ENCODER]

text = "a small red panda in a zoo"
image_path = "assets/unum.png"

image = Image.open(image_path)
image_data = processor.preprocess_image(image)
text_data = processor.preprocess_text(text)
image_data = processor_image(image)
text_data = processor_text(text)

image_features, image_embedding = model.encode_image(image_data, return_features=True)
text_features, text_embedding = model.encode_text(text_data, return_features=True)
image_features, image_embedding = model_image.forward(image_data, return_features=True)
text_features, text_embedding = model_text.forward(text_data, return_features=True)

assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"

# Test reranking
score, joint_embedding = model.encode_multimodal(
image_features=image_features,
text_features=text_features,
attention_mask=text_data["attention_mask"],
return_scores=True,
)
assert score.shape[0] == 1, "Matching score batch size is not 1"
assert joint_embedding.shape[0] == 1, "Joint embedding batch size is not 1"

# Test if the model outputs actually make sense
cross_references_image_and_text_embeddings(
lambda text: model.encode_text(processor.preprocess_text(text)),
lambda image: model.encode_image(processor.preprocess_image(image)),
lambda text: model_text(processor_text(text)),
lambda image: model_image(processor_image(image)),
)


@pytest.mark.skipif(not torch_available, reason="PyTorch is not installed")
@pytest.mark.parametrize("model_name", torch_models)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_torch_many_embeddings(model_name: str, batch_size: int):
model, processor = uform.get_model(model_name, token=token)

processors, models = get_model(model_name, token=token)
model_text = models[Modality.TEXT_ENCODER]
model_image = models[Modality.IMAGE_ENCODER]
processor_text = processors[Modality.TEXT_ENCODER]
processor_image = processors[Modality.IMAGE_ENCODER]

texts = ["a small red panda in a zoo"] * batch_size
image_paths = ["assets/unum.png"] * batch_size

images = [Image.open(path) for path in image_paths]
image_data = processor.preprocess_image(images)
text_data = processor.preprocess_text(texts)
image_data = processor_image(images)
text_data = processor_text(texts)

image_embeddings = model.encode_image(image_data, return_features=False)
text_embeddings = model.encode_text(text_data, return_features=False)
image_embeddings = model_image.forward(image_data, return_features=False)
text_embeddings = model_text.forward(text_data, return_features=False)

assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"
Expand All @@ -172,24 +177,29 @@ def test_onnx_one_embedding(model_name: str, device: str):

try:

model, processor = uform.get_model_onnx(model_name, token=token, device=device)
processors, models = get_model_onnx(model_name, token=token, device=device)
model_text = models[Modality.TEXT_ENCODER]
model_image = models[Modality.IMAGE_ENCODER]
processor_text = processors[Modality.TEXT_ENCODER]
processor_image = processors[Modality.IMAGE_ENCODER]

text = "a small red panda in a zoo"
image_path = "assets/unum.png"

image = Image.open(image_path)
image_data = processor.preprocess_image(image)
text_data = processor.preprocess_text(text)
image_data = processor_image(image)
text_data = processor_text(text)

image_features, image_embedding = model.encode_image(image_data, return_features=True)
text_features, text_embedding = model.encode_text(text_data, return_features=True)
image_features, image_embedding = model_image(image_data)
text_features, text_embedding = model_text(text_data)

assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"

# Test if the model outputs actually make sense
cross_references_image_and_text_embeddings(
lambda text: model.encode_text(processor.preprocess_text(text)),
lambda image: model.encode_image(processor.preprocess_image(image)),
lambda text: model_text(processor_text(text)),
lambda image: model_image(processor_image(image)),
)

except ExecutionProviderError as e:
Expand All @@ -206,16 +216,21 @@ def test_onnx_many_embeddings(model_name: str, batch_size: int, device: str):

try:

model, processor = uform.get_model_onnx(model_name, token=token, device=device)
processors, models = get_model_onnx(model_name, token=token, device=device)
model_text = models[Modality.TEXT_ENCODER]
model_image = models[Modality.IMAGE_ENCODER]
processor_text = processors[Modality.TEXT_ENCODER]
processor_image = processors[Modality.IMAGE_ENCODER]

texts = ["a small red panda in a zoo"] * batch_size
image_paths = ["assets/unum.png"] * batch_size

images = [Image.open(path) for path in image_paths]
image_data = processor.preprocess_image(images)
text_data = processor.preprocess_text(texts)
image_data = processor_image(images)
text_data = processor_text(texts)

image_embeddings = model.encode_image(image_data, return_features=False)
text_embeddings = model.encode_text(text_data, return_features=False)
image_embeddings = model_image(image_data, return_features=False)
text_embeddings = model_text(text_data, return_features=False)

assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"
Expand Down
62 changes: 41 additions & 21 deletions python/uform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from json import load
from os.path import join, exists
from typing import Dict, Optional, Tuple, Literal
from typing import Dict, Optional, Tuple, Literal, Union, Callable
from enum import Enum

from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -88,20 +88,30 @@ def get_model(
model_name: str,
*,
token: Optional[str] = None,
modalities: Optional[Tuple[str]] = None,
):
from uform.torch_encoders import TextImageEncoder
from uform.torch_processors import TorchProcessor
modalities: Optional[Tuple[Union[str, Modality]]] = None,
) -> Tuple[Dict[Modality, Callable], Dict]:
from uform.torch_encoders import TextEncoder, ImageEncoder
from uform.torch_processors import TextProcessor, ImageProcessor

config_path, modality_paths, tokenizer_path = get_checkpoint(model_name, token, modalities, format=".pt")
modality_paths = (
{k.value: v for k, v in modality_paths.items()} if isinstance(modality_paths, dict) else modality_paths
)
modalities = normalize_modalities(modalities)
config_path, modality_paths, tokenizer_path = get_checkpoint(model_name, modalities, token=token, format=".pt")

result_processors = {}
result_models = {}

model = TextImageEncoder(config_path, modality_paths)
processor = TorchProcessor(config_path, tokenizer_path)
if Modality.TEXT_ENCODER in modalities:
processor = TextProcessor(config_path, tokenizer_path)
encoder = TextEncoder.from_pretrained(config_path, modality_paths.get(Modality.TEXT_ENCODER)).eval()
result_processors[Modality.TEXT_ENCODER] = processor
result_models[Modality.TEXT_ENCODER] = encoder

return model.eval(), processor
if Modality.IMAGE_ENCODER in modalities:
processor = ImageProcessor(config_path)
encoder = ImageEncoder.from_pretrained(config_path, modality_paths.get(Modality.IMAGE_ENCODER)).eval()
result_processors[Modality.IMAGE_ENCODER] = processor
result_models[Modality.IMAGE_ENCODER] = encoder

return result_processors, result_models


def get_model_onnx(
Expand All @@ -111,15 +121,25 @@ def get_model_onnx(
token: Optional[str] = None,
modalities: Optional[Tuple[str]] = None,
):
from uform.onnx_encoders import TextImageEncoder
from uform.numpy_processors import NumPyProcessor
from uform.onnx_encoders import TextEncoder, ImageEncoder
from uform.numpy_processors import TextProcessor, ImageProcessor

config_path, modality_paths, tokenizer_path = get_checkpoint(model_name, token, modalities, format=".onnx")
modality_paths = (
{k.value: v for k, v in modality_paths.items()} if isinstance(modality_paths, dict) else modality_paths
)
modalities = normalize_modalities(modalities)
config_path, modality_paths, tokenizer_path = get_checkpoint(model_name, modalities, token=token, format=".onnx")

result_processors = {}
result_models = {}

if Modality.TEXT_ENCODER in modalities:
processor = TextProcessor(config_path, tokenizer_path)
encoder = TextEncoder(modality_paths.get(Modality.TEXT_ENCODER), device=device)
result_processors[Modality.TEXT_ENCODER] = processor
result_models[Modality.TEXT_ENCODER] = encoder

model = TextImageEncoder(config_path, modality_paths, device=device)
processor = NumPyProcessor(config_path, tokenizer_path)
if Modality.IMAGE_ENCODER in modalities:
processor = ImageProcessor(config_path)
encoder = ImageEncoder(modality_paths.get(Modality.IMAGE_ENCODER), device=device)
result_processors[Modality.IMAGE_ENCODER] = processor
result_models[Modality.IMAGE_ENCODER] = encoder

return model, processor
return result_processors, result_models
18 changes: 12 additions & 6 deletions python/uform/numpy_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ def __init__(self, config_path: PathLike, tokenizer_path: PathLike):
"""

config = json.load(open(config_path, "r"))
self._max_seq_len = config["text_encoder"]["max_position_embeddings"]
if "text_encoder" in config:
config = config["text_encoder"]

self._max_seq_len = config["max_position_embeddings"]
self._tokenizer = Tokenizer.from_file(tokenizer_path)
self._tokenizer.no_padding()
self._pad_token_idx = config["text_encoder"]["padding_idx"]
self._pad_token_idx = config["padding_idx"]

def __call__(self, texts: Union[str, List[str]]) -> Dict[str, np.ndarray]:
"""Transforms one or more strings into dictionary with tokenized strings and attention masks.
Expand Down Expand Up @@ -50,17 +53,20 @@ def __call__(self, texts: Union[str, List[str]]) -> Dict[str, np.ndarray]:


class ImageProcessor:
def __init__(self, config_path: PathLike, tokenizer_path: PathLike):
def __init__(self, config_path: PathLike, tokenizer_path: PathLike = None):
"""
:param config_path: path to the model config file
:param tokenizer_path: unused by the image processor; kept for signature compatibility with the text processor
"""

config = json.load(open(config_path, "r"))
self._image_size = config["image_encoder"]["image_size"]
self._normalization_means = config["image_encoder"]["normalization_means"]
self._normalization_deviations = config["image_encoder"]["normalization_deviations"]
if "image_encoder" in config:
config = config["image_encoder"]

self._image_size = config["image_size"]
self._normalization_means = config["normalization_means"]
self._normalization_deviations = config["normalization_deviations"]

assert isinstance(self._image_size, int) and self._image_size > 0
assert isinstance(self._normalization_means, list) and isinstance(self._normalization_deviations, list)
Expand Down
Loading

0 comments on commit cccfc62

Please sign in to comment.