[WIP] Llava support #352

Open · wants to merge 2 commits into main
28 changes: 28 additions & 0 deletions python/tests/llava_test.py
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

import torch
from scalellm import VLM

def test_pixel_value_llava_generate():
vlm = VLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
image = torch.load("images/stop_sign_pixel_values.pt")

    output = vlm.generate(image, prompt)
    print(output.outputs[0].text)

def main():
test_pixel_value_llava_generate()

if __name__ == "__main__":
main()
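The test loads a precomputed pixel-values tensor from disk. As a point of reference, a minimal sketch of how such a tensor could be produced offline with the Hugging Face processor for the same checkpoint; the transformers/Pillow usage and the source-image path are assumptions, not part of this change.

# Sketch (not part of this PR): precompute images/stop_sign_pixel_values.pt offline.
# Assumes the Hugging Face `transformers` and `Pillow` packages and a local stop-sign image.
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
img = Image.open("images/stop_sign.jpg").convert("RGB")  # hypothetical source image
# The CLIP image processor resizes and normalizes to the expected 1x3x336x336 layout.
pixel_values = processor.image_processor(images=img, return_tensors="pt")["pixel_values"]
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")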
1 change: 1 addition & 0 deletions scalellm/CMakeLists.txt
@@ -9,6 +9,7 @@ pybind_extension(
csrc/sampling_params.cpp
csrc/output.cpp
csrc/llm_handler.cpp
csrc/vlm_handler.cpp
csrc/module.cpp
DEPS
:llm_handler
2 changes: 2 additions & 0 deletions scalellm/_C/__init__.pyi
@@ -2,6 +2,7 @@ from scalellm._C.llm_handler import LLMHandler, Message, Priority
from scalellm._C.output import (LogProb, LogProbData, RequestOutput,
SequenceOutput, Status, StatusCode, Usage)
from scalellm._C.sampling_params import SamplingParams
from scalellm._C.vlm_handler import VLMHandler

# Defined in scalellm/csrc/module.cpp
def get_metrics() -> str: ...
@@ -18,5 +19,6 @@ __all__ = [
"StatusCode",
"Usage",
"LLMHandler",
"VLMHandler",
"get_metrics",
]
47 changes: 47 additions & 0 deletions scalellm/_C/vlm_handler.pyi
@@ -0,0 +1,47 @@
from typing import Callable, List, Optional

import torch

from scalellm._C.llm_handler import Future, Priority
from scalellm._C.output import RequestOutput
from scalellm._C.sampling_params import SamplingParams

class VLMHandler:
class Options:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
model_path: str
devices: Optional[str]
block_size: int
max_cache_size: int
max_memory_utilization: float
enable_prefix_cache: bool
enable_cuda_graph: bool
cuda_graph_max_seq_len: int
cuda_graph_batch_sizes: Optional[List[int]]
max_tokens_per_batch: int
max_seqs_per_batch: int
num_handling_threads: int
image_input_type: str
image_token_id: int
image_input_shape: str
image_feature_size: int

def __init__(self, options: Options) -> None: ...
def __repr__(self) -> str: ...
def schedule_async(
self,
image: torch.Tensor,
prompt: str,
sp: SamplingParams,
priority: Priority,
stream: bool,
callback: Callable[[RequestOutput], bool],
) -> Future: ...
def start(self) -> None: ...
def stop(self) -> None: ...
def run_until_complete(self) -> None: ...
def reset(self) -> None: ...
# helper functions
def encode(self, text: str) -> List[int]: ...
def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
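These stubs mirror the blocking flow used by scalellm/vlm.py below. A minimal sketch of driving VLMHandler directly, assuming the same scheduling semantics as the wrapper; the model path and tensor file are placeholders.

# Minimal sketch of using VLMHandler directly; mirrors VLM.generate() in scalellm/vlm.py.
# The model path and the tensor file below are placeholders.
import torch
from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler

options = VLMHandler.Options()
options.model_path = "/path/to/llava-1.5-7b-hf"
options.image_input_type = "pixel_values"
options.image_token_id = 32000
options.image_input_shape = "1,3,336,336"
options.image_feature_size = 576
handler = VLMHandler(options)

outputs = []
def on_output(output: RequestOutput) -> bool:
    outputs.append(output)  # collect the output; returning False would cancel the request
    return True

image = torch.load("pixel_values.pt")
prompt = "<image>" * 576 + "\nUSER: What is the content of this image?\nASSISTANT:"
future = handler.schedule_async(
    image, prompt, SamplingParams(), Priority.NORMAL, False, on_output
)
future.wait()                 # block until the request has been scheduled
handler.run_until_complete()  # drive execution until all scheduled requests finish
handler.reset()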
3 changes: 2 additions & 1 deletion scalellm/__init__.py
@@ -12,7 +12,7 @@

from scalellm._C import (LLMHandler, LogProb, LogProbData, Message, Priority,
RequestOutput, SamplingParams, SequenceOutput, Status,
StatusCode, Usage, get_metrics)
StatusCode, Usage, VLMHandler, get_metrics)
from scalellm.errors import ValidationError
from scalellm.llm import LLM
from scalellm.llm_engine import AsyncLLMEngine, OutputAsyncStream, OutputStream
@@ -34,5 +34,6 @@
"StatusCode",
"Usage",
"LLMHandler",
"VLMHandler",
"get_metrics",
]
4 changes: 3 additions & 1 deletion scalellm/csrc/module.cpp
@@ -11,6 +11,7 @@ namespace py = pybind11;
extern void init_sampling_params(py::module_& m);
extern void init_output(py::module_& m);
extern void init_llm_handler(py::module_& m);
extern void init_vlm_handler(py::module_& m);

// NOLINTNEXTLINE
static std::string get_metrics() { return Metrics::Instance().GetString(); }
@@ -26,6 +27,7 @@ PYBIND11_MODULE(PY_MODULE_NAME, m) {
init_sampling_params(m);
init_output(m);
init_llm_handler(m);
init_vlm_handler(m);
}

}  // namespace llm::csrc
100 changes: 100 additions & 0 deletions scalellm/csrc/vlm_handler.cpp
@@ -0,0 +1,100 @@
#include "handlers/vlm_handler.h"

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_vlm_handler(py::module_& m) {
auto vlm_handler =
py::class_<VLMHandler>(m, "VLMHandler")
.def(py::init<const VLMHandler::Options&>(), py::arg("options"))
.def("schedule_async",
&VLMHandler::schedule_async,
py::call_guard<py::gil_scoped_release>())
.def("start",
&VLMHandler::start,
py::call_guard<py::gil_scoped_release>())
.def("stop",
&VLMHandler::stop,
py::call_guard<py::gil_scoped_release>())
.def("run_until_complete",
&VLMHandler::run_until_complete,
py::call_guard<py::gil_scoped_release>())
.def("encode",
&VLMHandler::encode,
py::call_guard<py::gil_scoped_release>())
.def("decode",
&VLMHandler::decode,
py::call_guard<py::gil_scoped_release>())
.def("reset",
&VLMHandler::reset,
py::call_guard<py::gil_scoped_release>())
.def("__repr__", [](const VLMHandler& self) {
return "VLMHandler({})"_s.format(self.options());
});

// VLMHandler::Options
py::class_<VLMHandler::Options>(vlm_handler, "Options")
.def(py::init())
.def_readwrite("model_path", &VLMHandler::Options::model_path_)
.def_readwrite("devices", &VLMHandler::Options::devices_)
.def_readwrite("block_size", &VLMHandler::Options::block_size_)
.def_readwrite("max_cache_size", &VLMHandler::Options::max_cache_size_)
.def_readwrite("max_memory_utilization",
&VLMHandler::Options::max_memory_utilization_)
.def_readwrite("enable_prefix_cache",
&VLMHandler::Options::enable_prefix_cache_)
.def_readwrite("enable_cuda_graph",
&VLMHandler::Options::enable_cuda_graph_)
.def_readwrite("cuda_graph_max_seq_len",
&VLMHandler::Options::cuda_graph_max_seq_len_)
.def_readwrite("cuda_graph_batch_sizes",
&VLMHandler::Options::cuda_graph_batch_sizes_)
.def_readwrite("max_tokens_per_batch",
&VLMHandler::Options::max_tokens_per_batch_)
.def_readwrite("max_seqs_per_batch",
&VLMHandler::Options::max_seqs_per_batch_)
.def_readwrite("num_handling_threads",
&VLMHandler::Options::num_handling_threads_)
.def_readwrite("image_input_type",
&VLMHandler::Options::image_input_type_)
.def_readwrite("image_token_id", &VLMHandler::Options::image_token_id_)
.def_readwrite("image_input_shape",
&VLMHandler::Options::image_input_shape_)
.def_readwrite("image_feature_size",
&VLMHandler::Options::image_feature_size_)
.def("__repr__", [](const VLMHandler::Options& self) {
return "Options(model_path={}, devices={}, "
"block_size={}, max_cache_size={}, "
"max_memory_utilization={}, enable_prefix_cache={}, "
"enable_cuda_graph={}, cuda_graph_max_seq_len={}, "
"cuda_graph_batch_sizes={}, "
"max_tokens_per_batch={}, max_seqs_per_batch={}, "
"num_handling_threads={}, "
"image_input_type={}, image_token_id={}, "
"image_input_shape={}, image_feature_size={})"_s.format(
self.model_path_,
self.devices_,
self.block_size_,
self.max_cache_size_,
self.max_memory_utilization_,
self.enable_prefix_cache_,
self.enable_cuda_graph_,
self.cuda_graph_max_seq_len_,
self.cuda_graph_batch_sizes_,
self.max_tokens_per_batch_,
self.max_seqs_per_batch_,
self.num_handling_threads_,
self.image_input_type_,
self.image_token_id_,
self.image_input_shape_,
self.image_feature_size_);
});
}

} // namespace llm::csrc
127 changes: 127 additions & 0 deletions scalellm/vlm.py
@@ -0,0 +1,127 @@
import os
from typing import List, Optional

import torch

from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler
from scalellm.downloader import download_hf_model
from scalellm.errors import ValidationError


class VLM:
def __init__(
self,
model: str,
revision: Optional[str] = None,
allow_patterns: Optional[str] = None,
cache_dir: Optional[str] = None,
convert_to_safetensors: bool = False,
devices: Optional[str] = None,
block_size: int = 16,
max_cache_size: int = 20 * 1024 * 1024 * 1024,
max_memory_utilization: float = 0.9,
enable_prefix_cache: bool = True,
enable_cuda_graph: bool = True,
cuda_graph_max_seq_len: int = 2048,
cuda_graph_batch_sizes: Optional[List[int]] = None,
max_tokens_per_batch: int = 409600, # a big number to disable chunked prefill
max_seqs_per_batch: int = 2048, # a big number for better throughput
num_handling_threads: int = 4,
# vision encoder configuration
image_input_type: Optional[str] = None,
image_token_id: Optional[int] = None,
image_input_shape: Optional[str] = None,
image_feature_size: Optional[int] = None,
) -> None:
# download hf model if it does not exist
self._model = model
model_path = model
if not os.path.exists(model_path):
model_path = download_hf_model(
repo_id=model_path,
revision=revision,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
convert_to_safetensors=convert_to_safetensors,
)

options = VLMHandler.Options()
options.model_path = model_path
options.devices = devices
options.block_size = block_size
options.max_cache_size = max_cache_size
options.max_memory_utilization = max_memory_utilization
options.enable_prefix_cache = enable_prefix_cache
options.enable_cuda_graph = enable_cuda_graph
options.cuda_graph_max_seq_len = cuda_graph_max_seq_len
options.cuda_graph_batch_sizes = cuda_graph_batch_sizes
options.max_tokens_per_batch = max_tokens_per_batch
options.max_seqs_per_batch = max_seqs_per_batch
options.num_handling_threads = num_handling_threads
options.image_input_type = image_input_type
options.image_token_id = image_token_id
options.image_input_shape = image_input_shape
options.image_feature_size = image_feature_size
        # create the VLM handler
self._handler = VLMHandler(options)

def generate(
self,
        image: Optional[torch.Tensor] = None,
        prompt: Optional[str] = None,
sampling_params: Optional[SamplingParams] = None,
priority: Priority = Priority.NORMAL,
wait_for_schedule: bool = True,
) -> RequestOutput:
# use default sampling parameters if not provided
if sampling_params is None:
sampling_params = SamplingParams()

        output = None

        def callback(async_output: RequestOutput) -> bool:
            nonlocal output
            output = async_output
            return True

# schedule the batch requests
future = self._handler.schedule_async(
image, prompt, sampling_params, priority, False, callback
)

# wait for batch request to be scheduled
if wait_for_schedule:
future.wait()

        # run until all scheduled requests complete
self._handler.run_until_complete()

# throw an exception if there is any error
if output is None:
raise RuntimeError("Request failed, no output received")
if output.status is not None and not output.status.ok:
raise ValidationError(output.status.code, output.status.message)
# carry over the prompt to the output
output.prompt = prompt
return output

def encode(self, text: str) -> List[int]:
return self._handler.encode(text)

def decode(
self, tokens: List[int], skip_special_tokens: bool = True
) -> Optional[str]:
return self._handler.decode(tokens, skip_special_tokens)

def __del__(self):
self._handler.reset()

def __enter__(self):
return self

def __exit__(self, *args):
self.__del__()
return False

    def __repr__(self) -> str:
        return f"VLM(model={self._model})"
4 changes: 4 additions & 0 deletions src/engine/CMakeLists.txt
@@ -10,14 +10,18 @@ cc_library(
batch.h
model_runner.h
worker.h
vlm_worker.h
engine.h
llm_engine.h
vlm_engine.h
SRCS
utils.cpp
batch.cpp
model_runner.cpp
worker.cpp
vlm_worker.cpp
llm_engine.cpp
vlm_engine.cpp
DEPS
torch
:common
3 changes: 3 additions & 0 deletions src/engine/batch.cpp
@@ -49,6 +49,8 @@ void Batch::add(Sequence* sequence, uint32_t token_budget) {
sequences_.push_back(sequence);
token_budgets_.push_back(token_budget);
budget_used_.push_back(0);

input_embedding_ = sequence->get_input_embedding();
}

void Batch::add(const std::vector<Sequence*>& sequences) {
@@ -258,6 +260,7 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,

pad_2d_vector(block_tables_vec, /*pad_value=*/0);
input_params.block_tables = create_2d_tensor(block_tables_vec, torch::kInt);
input_params.input_embedding = input_embedding_;

CHECK_EQ(sampling_params.size(), selected_token_idxes.size());
if (!selected_token_idxes.empty()) {