Merge branch 'master' into llamacli-tools
bandoti committed Feb 25, 2025
2 parents 66eff76 + 401af80 commit 7b076ee
Showing 102 changed files with 4,779 additions and 901 deletions.
16 changes: 12 additions & 4 deletions .github/workflows/build.yml
@@ -173,7 +173,15 @@ jobs:
name: llama-bin-macos-x64.zip

ubuntu-cpu-cmake:
runs-on: ubuntu-22.04
strategy:
matrix:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 'arm64'
os: ubuntu-22.04-arm

runs-on: ${{ matrix.os }}

steps:
- name: Clone
@@ -239,14 +247,14 @@ jobs:
run: |
cp LICENSE ./build/bin/
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip
name: llama-bin-ubuntu-x64.zip
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
name: llama-bin-ubuntu-${{ matrix.build }}.zip

ubuntu-latest-cmake-sanitizer:
runs-on: ubuntu-latest
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
@@ -1,10 +1,12 @@
# Pull requests (for contributors)

- llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier
- Test your changes:
- Execute [the full CI locally on your machine](ci/README.md) before publishing
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends); see the example after this list
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
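
A minimal sketch of such a check, assuming a default CMake build with the tests enabled (add whichever backend toggles apply to your machine):

```bash
# Build with at least two backends available (the CPU backend is always present)
cmake -B build -DGGML_CUDA=ON    # e.g. CUDA; -DGGML_VULKAN=ON or -DGGML_METAL=ON also work
cmake --build build --config Release

# Compare the outputs of the ggml operators across the available backends
./build/bin/test-backend-ops
```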

14 changes: 13 additions & 1 deletion Makefile
@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
endif # GGML_CUDA_CCBIN

ifdef GGML_CUDA_NO_FA
MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA

ifdef GGML_CUDA_FA_ALL_QUANTS
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # GGML_CUDA_FA_ALL_QUANTS
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # GGML_CUDA_NO_PEER_COPY

ifdef GGML_CUDA_NO_FA
HIPFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA

OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
@@ -847,7 +855,7 @@ ifdef GGML_MUSA
CXX := $(MUSA_PATH)/bin/clang++
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc

MUSAFLAGS = -x musa -mtgpu
MUSAFLAGS = -fsigned-char -x musa -mtgpu
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))

ifdef GGML_CUDA_FORCE_MMQ
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # GGML_CUDA_NO_PEER_COPY

ifdef GGML_CUDA_NO_FA
MUSAFLAGS += -DGGML_CUDA_NO_FA
endif # GGML_CUDA_NO_FA

ifdef GGML_CUDA_FA_ALL_QUANTS
MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # GGML_CUDA_FA_ALL_QUANTS
16 changes: 14 additions & 2 deletions docs/backend/SYCL.md
@@ -42,6 +42,16 @@ The following release is verified with good quality:

## News

- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPUs, for all dGPUs and built-in GPUs since MTL. This increases LLM (llama-2-7b.Q4_0.gguf) performance by 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
|-|-|-|-|
|PVC 1550|39|73|+87%|
|Flex 170|39|50|+28%|
|Arc770|42|55|+30%|
|MTL|13|16|+23%|
|ARL-H|14|17|+21%|

- 2024.11
- Use syclcompat to improve performance on some platforms. This requires oneAPI 2025.0 or newer.

@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
| Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
| Intel iGPU | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |

*Notes:*

@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable the optimizations tuned for specific Intel GPU types, e.g. to compare performance with and without them |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
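
For example, a quick way to compare runs with and without the GPU-type-specific optimizations (model path and generation flags below are only illustrative):

```bash
# Baseline: optimizations enabled (default)
./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -ngl 99 -n 64 -p "Hello"

# Same run with the optimizations disabled, for comparison
GGML_SYCL_DISABLE_OPT=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -ngl 99 -n 64 -p "Hello"
```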


## Known Issues

- `Split-mode:[row]` is not supported.
8 changes: 8 additions & 0 deletions docs/build.md
@@ -206,6 +206,14 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GP
cmake --build build --config Release
```

For a static build:

```bash
cmake -B build -DGGML_MUSA=ON \
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
cmake --build build --config Release
```

The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.

The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.
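
For example, a hedged sketch combining the two (device index and model path are placeholders):

```bash
# Use only the first MUSA device and fall back to system RAM when VRAM is exhausted
MUSA_VISIBLE_DEVICES=0 GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 \
  ./build/bin/llama-cli -m models/model.gguf -ngl 99 -p "Hello"
```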
25 changes: 18 additions & 7 deletions examples/llama.swiftui/llama.swiftui/UI/ContentView.swift
@@ -124,15 +124,26 @@ struct ContentView: View {
}
}
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
VStack(alignment: .leading) {
NavigationView {
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
}
Spacer()
}
.navigationTitle("Help")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button("Done") {
showingHelp = false
}
}
}
Spacer()
}
}
}
}
}
183 changes: 183 additions & 0 deletions examples/llava/README-granitevision.md
@@ -0,0 +1,183 @@
# Granite Vision

Download the model and point your `GRANITE_MODEL` environment variable to the path.

```bash
$ git clone https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview
$ export GRANITE_MODEL=./granite-vision-3.1-2b-preview
```


### 1. Running llava surgery v2.
First, we need to run the llava surgery script as shown below:

`python llava_surgery_v2.py -C -m $GRANITE_MODEL`

You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below.

```bash
$ ls $GRANITE_MODEL | grep -i llava
llava.clip
llava.projector
```

We should see that the projector and visual encoder get split out into the llava files. A quick check to make sure they aren't empty:
```python
import os
import torch

MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")

encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip"))
projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector"))

assert len(encoder_tensors) > 0
assert len(projector_tensors) > 0
```

If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`.
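
A hypothetical one-liner to list those projector keys (assumes `GRANITE_MODEL` is still exported and `torch` is installed):

```bash
$ python -c "import os, torch; print(list(torch.load(os.path.join(os.environ['GRANITE_MODEL'], 'llava.projector')).keys()))"
```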


### 2. Creating the Visual Component GGUF
To create the GGUF for the visual components, we need to write a config for the visual encoder; make sure the config contains the correct `image_grid_pinpoints`


Note: we refer to this file as `$VISION_CONFIG` later on.
```json
{
"_name_or_path": "siglip-model",
"architectures": [
"SiglipVisionModel"
],
"image_grid_pinpoints": [
[384,768],
[384,1152],
[384,1536],
[384,1920],
[384,2304],
[384,2688],
[384,3072],
[384,3456],
[384,3840],
[768,384],
[768,768],
[768,1152],
[768,1536],
[768,1920],
[1152,384],
[1152,768],
[1152,1152],
[1536,384],
[1536,768],
[1920,384],
[1920,768],
[2304,384],
[2688,384],
[3072,384],
[3456,384],
[3840,384]
],
"mm_patch_merge_type": "spatial_unpad",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"layer_norm_eps": 1e-6,
"hidden_act": "gelu_pytorch_tanh",
"projection_dim": 0,
"vision_feature_layer": [-24, -20, -12, -1]
}
```

Create a new directory to hold the visual components, and copy the llava.clip/projector files, as well as the vision config into it.

```bash
$ ENCODER_PATH=$PWD/visual_encoder
$ mkdir $ENCODER_PATH

$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin
$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/
$ cp $VISION_CONFIG $ENCODER_PATH/config.json
```

At which point you should have something like this:
```bash
$ ls $ENCODER_PATH
config.json llava.projector pytorch_model.bin
```

Now convert the components to GGUF. Note that we also override the image mean/std dev to `[.5,.5,.5]` since we use the SigLIP visual encoder; in the `transformers` model, you can find these numbers in the [preprocessor_config.json](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview/blob/main/preprocessor_config.json).
```bash
$ python convert_image_encoder_to_gguf.py \
-m $ENCODER_PATH \
--llava-projector $ENCODER_PATH/llava.projector \
--output-dir $ENCODER_PATH \
--clip-model-is-vision \
--clip-model-is-siglip \
--image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
```

This will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the absolute path of this file as `$VISUAL_GGUF_PATH`.
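
For example, you could capture that path now (this simply points at the file produced by the previous step):

```bash
$ export VISUAL_GGUF_PATH=$ENCODER_PATH/mmproj-model-f16.gguf
```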


### 3. Creating the LLM GGUF.
The granite vision model contains a granite LLM as its language model. For now, the easiest way to get the GGUF for the LLM is to load the composite model in `transformers` and export the LLM so that it can be converted directly with the normal conversion path.

First, set `LLM_EXPORT_PATH` to the directory where the `transformers` LLM will be exported.
```bash
$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm
```

```python
import os
import transformers

MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")

LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH")
if not LLM_EXPORT_PATH:
raise ValueError("env var LLM_EXPORT_PATH is unset!")

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)

# NOTE: granite vision support was added to transformers very recently (4.49);
# if you get size mismatches, your version is too old.
# If you are running with an older version, set `ignore_mismatched_sizes=True`
# as shown below; the full model won't load correctly, but the LLM part that
# we are exporting will.
model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True)

tokenizer.save_pretrained(LLM_EXPORT_PATH)
model.language_model.save_pretrained(LLM_EXPORT_PATH)
```

Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama.cpp project.
```bash
$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf
...
$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH
```


### 4. Running the Model in llama.cpp
Build llama.cpp normally; you should have a target binary named `llama-llava-cli`, to which you pass both GGUF files. Sample usage:

Note - the test image shown below can be found [here](https://github-production-user-asset-6210df.s3.amazonaws.com/10740300/415512792-d90d5562-8844-4f34-a0a5-77f62d5a58b5.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20250221%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250221T054145Z&X-Amz-Expires=300&X-Amz-Signature=86c60be490aa49ef7d53f25d6c973580a8273904fed11ed2453d0a38240ee40a&X-Amz-SignedHeaders=host).

```bash
$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
--mmproj $VISUAL_GGUF_PATH \
--image cherry_blossom.jpg \
-c 16384 \
-p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat type of flowers are in this picture?\n<|assistant|>\n" \
--temp 0
```

Sample response: `The flowers in the picture are cherry blossoms, which are known for their delicate pink petals and are often associated with the beauty of spring.`
19 changes: 19 additions & 0 deletions examples/llava/README.md
@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```

**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)

**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way to handle the LLM conversion is to load the model in `transformers` and export only the LLM from the llava next model.

```python
import os
import transformers

model_path = ...
llm_export_path = ...

tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)

tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```

Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
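
For example, assuming the `llm_export_path` directory from the snippet above (paths are placeholders):

```bash
python ./convert_hf_to_gguf.py --outfile <llm_export_path>/model-f16.gguf <llm_export_path>
```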

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`