diff --git a/README.md b/README.md
index af388c3..f4c4b58 100644
--- a/README.md
+++ b/README.md
@@ -349,8 +349,26 @@ and not a pre-trained model (any model with a `-pt` suffix).
 
 **How do I convert my fine-tune to a `.sbs` compressed model file?**
 
-See compression/convert_weights.py to convert a pytorch checkpint. (The code may
-need updates to work with Gemma-2 models.)
+For PaliGemma (1 and 2) checkpoints, you can use
+python/convert_from_safetensors.py to convert from safetensors format (tested
+using a bazel build). For an adapter model, you will likely need to call
+merge_and_unload() first to merge the adapter into the base model weights
+before converting.
+
+Here is how to run it with a bazel build of the compression library, assuming
+torch, numpy, safetensors, absl-py, etc. are installed locally (e.g. in a venv):
+
+```sh
+bazel build //compression/python:compression
+BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
+python3 -c "import site; print(site.getsitepackages())"
+# Use your site-packages path here:
+ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
+python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
+```
+
+See also compression/convert_weights.py for a slightly older option to convert a
+pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
 
 **What are some easy ways to make the model run faster?**
 
diff --git a/python/BUILD.bazel b/python/BUILD.bazel
index 5b7dabd..29de6bc 100644
--- a/python/BUILD.bazel
+++ b/python/BUILD.bazel
@@ -36,7 +36,7 @@ py_binary(
     deps = [
         ":gemma",
         "@python_deps//absl_py",
-        # placeholder forabsl/flags
+        # placeholder for absl/flags
         "@compression_deps//numpy",
     ],
 )
diff --git a/python/convert_from_safetensors.py b/python/convert_from_safetensors.py
new file mode 100644
index 0000000..8d28e7c
--- /dev/null
+++ b/python/convert_from_safetensors.py
@@ -0,0 +1,436 @@
+# Copyright 2025 Google LLC
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Convert a PaliGemma[1/2] model from SafeTensors to gemma.cpp format.""" +# Tested with: +# - PG1: huggingface.co/google/paligemma-3b-pt-224 +# - PG1: huggingface.co/merve/paligemma_vqav2 +# - PG2: huggingface.co/google/paligemma2-3b-pt-448 +# - PG2: huggingface.co/merve/paligemma2-3b-vqav2 +# The last one above is a Lora model, so the merged weights were saved using: +# model_name = "google/paligemma2-3b-pt-448" +# lora_weights_path = "merve/paligemma2-3b-vqav2" +# model = PaliGemmaForConditionalGeneration.from_pretrained(model_name) +# model = PeftModel.from_pretrained(model, lora_weights_path) +# model = model.merge_and_unload() +# model.save_pretrained("/tmp/lora-model") + +from collections.abc import Sequence +import csv +import json +import os +import sys +from typing import Any, Dict + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import safetensors +import torch + +from compression.python import compression + + +def flatten_f32(x: np.ndarray) -> np.ndarray: + """Flattens an array. + + Args: + x: input array + + Returns: + Flattened array. + """ + return x.ravel().astype(np.float32, copy=False) + + +def compute_scale(x: np.ndarray) -> float: + """Rescales weight tensor to fit max magnitude within 1.875. + + Args: + x: input array + + Returns: + Scale value (1.0 means no rescaling). + """ + magnitude = np.max(np.abs(x)) + return max(1.0, magnitude / 1.875) + + +def _is_float_param(param_name: str) -> bool: + for prefix in ["img_pos_emb", "attn_out_b", "linear_0_b", "linear_1_b", + "qkv_ein_b", "img_emb_bias", "img_head_bias"]: + if param_name.startswith(prefix): + return True + return False + + +def _is_bf16_param(param_name: str) -> bool: + for prefix in ["pre_", "post_", "c_", "img_head_kernel"]: + if param_name.startswith(prefix): + return True + return False + + +# Layernorm names are slightly confusing in HF transformers between versions. +# Gemma layernorms: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py +# input_layernorm attn residual post_attention_layernorm mlp residual +# Gemma2 layernorms: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py +# input_layernorm attn post_attention_layernorm residual +# pre_feedforward_layernorm mlp post_feedforward_layernorm residual +# Note that post_attention_layernorm denotes something different. +# For comparison, the Big Vision Gemma2 keeps the same name for the same norm: +# pre_attention_norm attn [post_attention_norm] residual +# pre_ffw_norm mlp [post_ffw_norm] residual + + +# Tuples correspond to (transformers-name, shape, sbs-name). +# The qkv-einsum weights are part of llm-layers but are handled separately and +# thus not included in the list. +def _get_layer_config(dims: Dict[str, Any]): + """Returns a dictionary of layer configurations. + + Args: + dims: A dictionary of (mostly) dimension values. + Returns: + A dictionary of layer configurations. 
+ """ + model_dim = dims["model_dim"] + hidden_dim = dims["hidden_dim"] + vit_seq_len = dims["vit_seq_len"] + config = { + "llm-non-layers": [ + ("language_model.model.embed_tokens.weight", (257152, model_dim), "c_embedding"), + ("language_model.model.norm.weight", (model_dim,), "c_final_norm"), + ], + "llm-layers": [ + ("language_model.model.layers.%d.mlp.down_proj.weight", (model_dim, hidden_dim), "linear_w"), + ], + "img-non-layers": [ + ("vision_tower.vision_model.post_layernorm.bias", (1152,), "enc_norm_bias"), + ("vision_tower.vision_model.post_layernorm.weight", (1152,), "enc_norm_scale"), + ("vision_tower.vision_model.embeddings.patch_embedding.bias", (1152,), "img_emb_bias"), + ("vision_tower.vision_model.embeddings.patch_embedding.weight", (1152, 14, 14, 3), "img_emb_kernel"), + ("multi_modal_projector.linear.bias", (model_dim,), "img_head_bias"), + ("multi_modal_projector.linear.weight", (model_dim, 1152), "img_head_kernel"), + ("vision_tower.vision_model.embeddings.position_embedding.weight", (vit_seq_len, 1152), "img_pos_emb"), + ], + "img-layers": [ + ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.bias", (1152,), "ln_0_bias"), + ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.weight", (1152,), "ln_0_scale"), + ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.bias", (1152,), "ln_1_bias"), + ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.weight", (1152,), "ln_1_scale"), + ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.bias", (4304,), "linear_0_b"), + ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.weight", (4304, 1152), "linear_0_w"), + ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.bias", (1152,), "linear_1_b"), + ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.weight", (1152, 4304), "linear_1_w"), + ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.bias", (1152,), "attn_out_b"), + ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.weight", (1152, 16 * 72), "attn_out_w"), + ], + } + if dims["has_post_norm"]: # See longer comment above. + config["llm-layers"] += [ + ("language_model.model.layers.%d.input_layernorm.weight", (model_dim,), "pre_att_ns"), + ("language_model.model.layers.%d.pre_feedforward_layernorm.weight", (model_dim,), "pre_ff_ns"), + ("language_model.model.layers.%d.post_attention_layernorm.weight", (model_dim,), "post_att_ns"), + ("language_model.model.layers.%d.post_feedforward_layernorm.weight", (model_dim,), "post_ff_ns"), + ] + else: + config["llm-layers"] += [ + ("language_model.model.layers.%d.input_layernorm.weight", (model_dim,), "pre_att_ns"), + ("language_model.model.layers.%d.post_attention_layernorm.weight", (model_dim,), "pre_ff_ns"), + ] + return config + + +def _get_dimensions(params): + """Returns a dictionary of dimension values. + + Args: + params: A dictionary with parameters. + Returns: + A dictionary of dimension values. + """ + dims = {} + # For PG1 and PG2-{3B,10B} head_dim is 256, would need update for PG2-28B. + # Unfortunately not easily available in any of the input sizes. 
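+  # (PG2-28B is based on Gemma-2 27B, which is believed to use head_dim 128,
+  # so this constant would likely change to 128 there.)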
+ dims["head_dim"] = 256 + dims["model_dim"] = params["multi_modal_projector.linear.bias"].shape[0] + dims["hidden_dim"] = params[ + "language_model.model.layers.0.mlp.gate_proj.weight" + ].shape[0] + dims["num_heads"] = ( + params["language_model.model.layers.0.self_attn.q_proj.weight"].shape[0] + // dims["head_dim"] + ) + dims["vit_seq_len"] = params[ + "vision_tower.vision_model.embeddings.position_embedding.weight" + ].shape[0] + dims["num_llm_layers"] = len( + set([k for k in params.keys() if "input_layernorm.weight" in k]) + ) + dims["has_post_norm"] = ( + "language_model.model.layers.0.post_feedforward_layernorm.weight" + in params + ) + return dims + + +def export_paligemma_sbs( + load_path: str, + csv_file: str, + sbs_file: str, +) -> None: + """Exports sbs file from paligemma safetensors file(s).""" + + # If this is a multi-part checkpoint, get the list of files from the json. + if load_path.endswith(".json"): + with open(load_path, "r") as f: + j_obj = json.load(f) + files = list(set(j_obj["weight_map"].values())) + files = [os.path.join(os.path.dirname(load_path), f) for f in files] + else: + files = [load_path] + + # Read the parameters from the files. + params = {} + for file in files: + with safetensors.safe_open(file, framework="pt") as f: + for k in f.keys(): + params[k] = f.get_tensor(k) + print(k, params[k].shape, params[k].view(-1)[0].item()) + + # See https://tinyurl.com/paligemmavocab - HF transformers extends the + # embedding matrix by 64. Undo that here. + params["language_model.model.embed_tokens.weight"] = params[ + "language_model.model.embed_tokens.weight" + ][:-64] + + # Initialize a few things. + writer = compression.SbsWriter(compression.CompressorMode.NO_TOC) + metadata = [] + scales = {} + dims = _get_dimensions(params) + layer_config = _get_layer_config(dims) + + # Adds a parameter with expected shape to the writer. + def add_data(param_name, data, expected_shape, sbs_name, layer_index=None): + # Check shape. + if not isinstance(expected_shape, tuple): + expected_shape = (expected_shape,) + print(f"Writing {param_name} with shape {data.shape} e:{expected_shape}") + assert data.shape == expected_shape, param_name + + # Here we assume that the read data is a torch tensor and then convert it to + # a numpy array. + assert isinstance(data, torch.Tensor) + data = data.to(torch.float32).numpy() + data = np.array(data) + + # Add the layer index to the param name and sbs name if needed. + if layer_index is not None: + param_name = param_name % layer_index + sbs_name = sbs_name + f"_{layer_index}" + + # Flatten the data and get scale. + value = flatten_f32(data) + scale = compute_scale(value) + both_names = param_name + "::" + sbs_name + print(f"Param {both_names} has scale {scale}") + metadata.append((both_names, data.dtype, data.shape, scale)) + + # Determine the type as which to insert. + if _is_float_param(sbs_name): + insert = writer.insert_float # Insert as float. + print(f"Inserting {both_names} as float (f32) (no scaling)") + elif _is_bf16_param(sbs_name) or param_name.startswith("vision_tower"): + insert = writer.insert_bf16 # Insert as BF16. + print(f"Inserting {both_names} as BF16 (no scaling)") + else: + insert = writer.insert_sfp # Insert as SFP. + # Assumes that all scales are 1.0 for SFP. Consider adding scales. + # They would still need to be written, but would be collected here. + assert scale == 1.0, f"Scale for {both_names} is not 1.0" + if scale != 1.0: + value = value / scale + scales[sbs_name] = scale # Unused at the moment. 
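+      # SFP is gemma.cpp's 8-bit compressed format; the 1.875 cap in
+      # compute_scale() is assumed to match its maximum representable magnitude.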
+ print(f"Inserting {both_names} as SFP with scale {scale}") + sys.stdout.flush() + + # Add the data to the writer. + insert(sbs_name, value) + + def add_qkv_einsum(i): # Handle qkv for layer i. + name = "language_model.model.layers.%d.self_attn.q_proj.weight" # (N*H, D) + q_i = params.pop(name % i) + (nh, d) = q_i.shape + h = dims["head_dim"] + n = dims["num_heads"] + assert nh == n * h + assert dims["model_dim"] == d + q_i = q_i.reshape(n, h, d) + name = "language_model.model.layers.%d.self_attn.k_proj.weight" # (K*H, D) + k_i = params.pop(name % i) + kh = k_i.shape[0] + k = kh // h + assert k_i.shape[1] == d + k_i = k_i.reshape(k, h, d) + name = "language_model.model.layers.%d.self_attn.v_proj.weight" # (K*H, D) + v_i = params.pop(name % i) + assert v_i.shape[0] == kh + assert v_i.shape[1] == d + v_i = v_i.reshape(k, h, d) + # Stack and reshape KV to interleave (k,v), (k,v), ... + stacked = torch.stack((k_i, v_i), dim=0) # (2, K, H, D) + transposed = stacked.transpose(0, 1) # (K, 2, H, D) + reshaped = transposed.reshape(2 * k, h, d) # (2K, H, D) + # Concatenate Q and KV to get the full qkv. + qkv_i = torch.cat([q_i, reshaped], dim=0) + name = "language_model.model.layers.%d.self_attn.qkv_proj.weight" + expected_shape = (n + 2 * k, h, d) # (N+2K, H, D) + sbs_name = "qkv_ein" + add_data(name, qkv_i, expected_shape, sbs_name, i) + + def add_att_einsum(i): # Handle att_ein for layer i. + name = "language_model.model.layers.%d.self_attn.o_proj.weight" # (D, N*H) + o_i = params.pop(name % i) + (d, nh) = o_i.shape + h = dims["head_dim"] + n = dims["num_heads"] + assert nh == n * h + o_i = o_i.reshape(d, n, h).permute(1, 0, 2) # (D, N, H) -> (N, D, H) + expected_shape = (n, d, h) + sbs_name = "att_ein" + add_data(name, o_i, expected_shape, sbs_name, i) + + # Join gate and up projection weights to gating_einsum for layer i. + def add_gating_einsum(i): + name = "language_model.model.layers.%d.mlp.gate_proj.weight" + gate_i = params.pop(name % i) + f, d = gate_i.shape + assert dims["hidden_dim"] == f + assert dims["model_dim"] == d + name = "language_model.model.layers.%d.mlp.up_proj.weight" + up_i = params.pop(name % i) + assert up_i.shape == gate_i.shape + gating_einsum_i = torch.stack([gate_i, up_i], dim=0) + name = "language_model.model.layers.%d.mlp.gating_einsum.weight" + expected_shape = (2, f, d) + sbs_name = "gating_ein" + add_data(name, gating_einsum_i, expected_shape, sbs_name, i) + + # Handle the q and kv einsum parts for layer i in the ViT - merge into qkv. + def add_vit_qkv_einsum(i): + # Weights first. + prefix = "vision_tower.vision_model.encoder.layers.%d.self_attn" + name = prefix + ".q_proj.weight" # (16 * 72, 1152) + q_i = params.pop(name % i) + q_i = q_i.reshape(16, 72, 1152) + name = prefix + ".k_proj.weight" # (16 * 72, 1152) + k_i = params.pop(name % i) + k_i = k_i.reshape(16, 72, 1152) + name = prefix + ".v_proj.weight" # (16 * 72, 1152) + v_i = params.pop(name % i) + v_i = v_i.reshape(16, 72, 1152) + qkv_i, shape = torch.stack([q_i, k_i, v_i], dim=1), (16, 3, 72, 1152) + name = prefix + ".qkv_proj.weight" + sbs_name = "qkv_ein_w" + add_data(name, qkv_i, shape, sbs_name, i) + # Now the biases. 
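+    # Biases are stacked the same way as the weights above, i.e. to
+    # (heads, q/k/v, head_dim) = (16, 3, 72), just without the 1152 input dim.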
+ name = prefix + ".q_proj.bias" # (16 * 72) + q_i = params.pop(name % i) + q_i = q_i.reshape(16, 72) + name = prefix + ".k_proj.bias" # (16 * 72) + k_i = params.pop(name % i) + k_i = k_i.reshape(16, 72) + name = prefix + ".v_proj.bias" # (16 * 72) + v_i = params.pop(name % i) + v_i = v_i.reshape(16, 72) + qkv_i, shape = torch.stack([q_i, k_i, v_i], dim=1), (16, 3, 72) + name = prefix + ".qkv_proj.bias" + sbs_name = "qkv_ein_b" + add_data(name, qkv_i, shape, sbs_name, i) + + # Handle the image embedding kernel transpose. + name = "vision_tower.vision_model.embeddings.patch_embedding.weight" + assert params[name].shape == (1152, 3, 14, 14,) + params[name] = params[name].permute(0, 2, 3, 1) + + # Add the non-layer params. + for name, shape, sbs_name in ( + layer_config["llm-non-layers"] + layer_config["img-non-layers"] + ): + add_data(name, params.pop(name), shape, sbs_name) + + # Go through the LLM layers and add the weights. + for i in range(dims["num_llm_layers"]): + add_att_einsum(i) + add_gating_einsum(i) + for name, shape, sbs_name in layer_config["llm-layers"]: + add_data(name, params.pop(name % i), shape, sbs_name, i) + add_qkv_einsum(i) + + # Go through the Vit layers and add the weights. + for i in range(27): + for name, shape, sbs_name in layer_config["img-layers"]: + add_data(name, params.pop(name % i), shape, sbs_name, i) + add_vit_qkv_einsum(i) + + assert not params, "Some params were not used: %s" % params.keys() + + # Write everything to the sbs file. + writer.write(sbs_file) + + # Write the metadata for manual inspection. + with open(csv_file, "w") as csv_handle: + csv.writer(csv_handle).writerows(metadata) + + +_LOAD_PATH = flags.DEFINE_string( + "load_path", + "", + "Path to the safetensors index.json file to read", +) +_METADATA_FILE = flags.DEFINE_string( + "metadata_file", + "/tmp/gemmacpp.csv", + "Path to the metadata file to write", +) +_SBS_FILE = flags.DEFINE_string( + "sbs_file", + "/tmp/gemmacpp.sbs", + "Path to the sbs file to write", +) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + logging.use_python_logging() + logging.set_verbosity(logging.INFO) + load_path = _LOAD_PATH.value + metadata_file = _METADATA_FILE.value + sbs_file = _SBS_FILE.value + + logging.info( + "\n====\nReading from %s and writing to %s\n====", load_path, sbs_file + ) + export_paligemma_sbs(load_path, metadata_file, sbs_file) + + +if __name__ == "__main__": + app.run(main)