From 3683d85cb9c7bf64286fe7453e510624d46d79ba Mon Sep 17 00:00:00 2001
From: Tongxuan Liu
Date: Fri, 28 Jun 2024 16:37:35 +0800
Subject: [PATCH 1/2] [model] support vision language model llava. (#178)

(cherry picked from commit 437be3f35da4616881a96125ad31e6bdc4ea7a91)
---
 python/tests/llava_test.py     |  28 ++
 scalellm/_C/__init__.pyi       |   2 +
 scalellm/_C/vlm_handler.pyi    |  47 ++
 scalellm/__init__.py           |   3 +-
 scalellm/csrc/module.cpp       |   4 +-
 scalellm/csrc/vlm_handler.cpp  | 115 +++++
 scalellm/vlm.py                | 127 ++++++
 src/engine/CMakeLists.txt      |   4 +
 src/engine/batch.cpp           |   3 +
 src/engine/batch.h             |   6 +
 src/engine/engine.h            |   8 +
 src/engine/parameters.h        |   2 +-
 src/engine/vlm_engine.cpp      | 303 +++++++++++++
 src/engine/vlm_engine.h        | 139 ++++++
 src/engine/vlm_worker.cpp      | 284 ++++++++++++
 src/engine/vlm_worker.h        | 114 +++++
 src/handlers/CMakeLists.txt    |   2 +
 src/handlers/vlm_handler.cpp   | 475 ++++++++++++++++++++
 src/handlers/vlm_handler.h     | 154 +++++++
 src/layers/activation.cpp      |   4 +
 src/models/CMakeLists.txt      |   2 +
 src/models/causal_vlm.cpp      |  26 ++
 src/models/causal_vlm.h        |  76 ++++
 src/models/huggingface/llava.h | 771 +++++++++++++++++++++++++++++++++
 src/models/model_args.h        | 108 +++++
 src/models/model_registry.cpp  |  16 +
 src/models/model_registry.h    |  33 ++
 src/models/parameters.h        |   4 +
 src/request/request.cpp        |  19 +
 src/request/request.h          |  10 +
 src/request/sequence.cpp       |  34 ++
 src/request/sequence.h         |  13 +
 32 files changed, 2933 insertions(+), 3 deletions(-)
 create mode 100644 python/tests/llava_test.py
 create mode 100644 scalellm/_C/vlm_handler.pyi
 create mode 100644 scalellm/csrc/vlm_handler.cpp
 create mode 100644 scalellm/vlm.py
 create mode 100644 src/engine/vlm_engine.cpp
 create mode 100644 src/engine/vlm_engine.h
 create mode 100644 src/engine/vlm_worker.cpp
 create mode 100644 src/engine/vlm_worker.h
 create mode 100644 src/handlers/vlm_handler.cpp
 create mode 100644 src/handlers/vlm_handler.h
 create mode 100644 src/models/causal_vlm.cpp
 create mode 100644 src/models/causal_vlm.h
 create mode 100644 src/models/huggingface/llava.h

diff --git a/python/tests/llava_test.py b/python/tests/llava_test.py
new file mode 100644
index 00000000..0eb994b6
--- /dev/null
+++ b/python/tests/llava_test.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+
+import torch
+from scalellm import VLM
+
+def test_pixel_value_llava_generate():
+    vlm = VLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        image_input_type="pixel_values",
+        image_token_id=32000,
+        image_input_shape="1,3,336,336",
+        image_feature_size=576,
+    )
+
+    prompt = "<image>" * 576 + (
+        "\nUSER: What is the content of this image?\nASSISTANT:")
+
+    # This should be provided by another online or offline component.
+    image = torch.load("images/stop_sign_pixel_values.pt")
+
+    output = vlm.generate(image, prompt)
+    print(output.outputs[0].text)
+
+def main():
+    test_pixel_value_llava_generate()
+
+if __name__ == "__main__":
+    main()

diff --git a/scalellm/_C/__init__.pyi b/scalellm/_C/__init__.pyi
index bab85d93..0e5253e6
--- a/scalellm/_C/__init__.pyi
+++ b/scalellm/_C/__init__.pyi
@@ -2,6 +2,7 @@ from scalellm._C.llm_handler import LLMHandler, Message, Priority
 from scalellm._C.output import (LogProb, LogProbData, RequestOutput,
                                 SequenceOutput, Status, StatusCode, Usage)
 from scalellm._C.sampling_params import SamplingParams
+from scalellm._C.vlm_handler import VLMHandler
 
 # Defined in scalellm/csrc/module.cpp
 def get_metrics() -> str: ...
@@ -18,5 +19,6 @@ __all__ = [ "StatusCode", "Usage", "LLMHandler", + "VLMHandler", "get_metrics", ] diff --git a/scalellm/_C/vlm_handler.pyi b/scalellm/_C/vlm_handler.pyi new file mode 100644 index 00000000..4d4d85a8 --- /dev/null +++ b/scalellm/_C/vlm_handler.pyi @@ -0,0 +1,47 @@ +from typing import Callable, List, Optional + +import torch + +from scalellm._C.llm_handler import Future, Priority +from scalellm._C.output import RequestOutput +from scalellm._C.sampling_params import SamplingParams + +class VLMHandler: + class Options: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + model_path: str + devices: Optional[str] + block_size: int + max_cache_size: int + max_memory_utilization: float + enable_prefix_cache: bool + enable_cuda_graph: bool + cuda_graph_max_seq_len: int + cuda_graph_batch_sizes: Optional[List[int]] + max_tokens_per_batch: int + max_seqs_per_batch: int + num_handling_threads: int + image_input_type: str + image_token_id: int + image_input_shape: str + image_feature_size: int + + def __init__(self, options: Options) -> None: ... + def __repr__(self) -> str: ... + def schedule_async( + self, + image: torch.Tensor, + prompt: str, + sp: SamplingParams, + priority: Priority, + stream: bool, + callback: Callable[[RequestOutput], bool], + ) -> Future: ... + def start(self) -> None: ... + def stop(self) -> None: ... + def run_until_complete(self) -> None: ... + def reset(self) -> None: ... + # helper functions + def encode(self, text: str) -> List[int]: ... + def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ... diff --git a/scalellm/__init__.py b/scalellm/__init__.py index 6dceb1b3..28f1f609 100644 --- a/scalellm/__init__.py +++ b/scalellm/__init__.py @@ -12,7 +12,7 @@ from scalellm._C import (LLMHandler, LogProb, LogProbData, Message, Priority, RequestOutput, SamplingParams, SequenceOutput, Status, - StatusCode, Usage, get_metrics) + StatusCode, Usage, VLMHandler, get_metrics) from scalellm.errors import ValidationError from scalellm.llm import LLM from scalellm.llm_engine import AsyncLLMEngine, OutputAsyncStream, OutputStream @@ -34,5 +34,6 @@ "StatusCode", "Usage", "LLMHandler", + "VLMHandler", "get_metrics", ] diff --git a/scalellm/csrc/module.cpp b/scalellm/csrc/module.cpp index 0d1e7802..bc65d2b9 100644 --- a/scalellm/csrc/module.cpp +++ b/scalellm/csrc/module.cpp @@ -11,6 +11,7 @@ namespace py = pybind11; extern void init_sampling_params(py::module_& m); extern void init_output(py::module_& m); extern void init_llm_handler(py::module_& m); +extern void init_vlm_handler(py::module_& m); // NOLINTNEXTLINE static std::string get_metrics() { return Metrics::Instance().GetString(); } @@ -26,6 +27,7 @@ PYBIND11_MODULE(PY_MODULE_NAME, m) { init_sampling_params(m); init_output(m); init_llm_handler(m); + init_vlm_handler(m); } -} // namespace llm::csrc \ No newline at end of file +} // namespace llm::csrc diff --git a/scalellm/csrc/vlm_handler.cpp b/scalellm/csrc/vlm_handler.cpp new file mode 100644 index 00000000..fef28ff3 --- /dev/null +++ b/scalellm/csrc/vlm_handler.cpp @@ -0,0 +1,115 @@ +#include "handlers/vlm_handler.h" + +#include +#include +#include +#include + +namespace llm::csrc { +namespace py = pybind11; +using namespace pybind11::literals; + +void init_vlm_handler(py::module_& m) { + py::enum_(m, "Priority") + .value("DEFAULT", Priority::NORMAL) + .value("LOW", Priority::LOW) + .value("NORMAL", Priority::NORMAL) + .value("HIGH", Priority::HIGH) + .export_values(); + + py::class_>(m, "Future") + .def("wait", + 
               &std::future<bool>::wait,
               py::call_guard<py::gil_scoped_release>())
          .def("get",
               &std::future<bool>::get,
               py::call_guard<py::gil_scoped_release>());
+
+  auto vlm_handler =
+      py::class_<VLMHandler>(m, "VLMHandler")
+          .def(py::init<const VLMHandler::Options&>(), py::arg("options"))
+          .def("schedule_async",
+               &VLMHandler::schedule_async,
+               py::call_guard<py::gil_scoped_release>())
+          .def("start",
+               &VLMHandler::start,
+               py::call_guard<py::gil_scoped_release>())
+          .def("stop",
+               &VLMHandler::stop,
+               py::call_guard<py::gil_scoped_release>())
+          .def("run_until_complete",
+               &VLMHandler::run_until_complete,
+               py::call_guard<py::gil_scoped_release>())
+          .def("encode",
+               &VLMHandler::encode,
+               py::call_guard<py::gil_scoped_release>())
+          .def("decode",
+               &VLMHandler::decode,
+               py::call_guard<py::gil_scoped_release>())
+          .def("reset",
+               &VLMHandler::reset,
+               py::call_guard<py::gil_scoped_release>())
+          .def("__repr__", [](const VLMHandler& self) {
+            return "VLMHandler({})"_s.format(self.options());
+          });
+
+  // VLMHandler::Options
+  py::class_<VLMHandler::Options>(vlm_handler, "Options")
+      .def(py::init<>())
+      .def_readwrite("model_path", &VLMHandler::Options::model_path_)
+      .def_readwrite("devices", &VLMHandler::Options::devices_)
+      .def_readwrite("block_size", &VLMHandler::Options::block_size_)
+      .def_readwrite("max_cache_size", &VLMHandler::Options::max_cache_size_)
+      .def_readwrite("max_memory_utilization",
+                     &VLMHandler::Options::max_memory_utilization_)
+      .def_readwrite("enable_prefix_cache",
+                     &VLMHandler::Options::enable_prefix_cache_)
+      .def_readwrite("enable_cuda_graph",
+                     &VLMHandler::Options::enable_cuda_graph_)
+      .def_readwrite("cuda_graph_max_seq_len",
+                     &VLMHandler::Options::cuda_graph_max_seq_len_)
+      .def_readwrite("cuda_graph_batch_sizes",
+                     &VLMHandler::Options::cuda_graph_batch_sizes_)
+      .def_readwrite("max_tokens_per_batch",
+                     &VLMHandler::Options::max_tokens_per_batch_)
+      .def_readwrite("max_seqs_per_batch",
+                     &VLMHandler::Options::max_seqs_per_batch_)
+      .def_readwrite("num_handling_threads",
+                     &VLMHandler::Options::num_handling_threads_)
+      .def_readwrite("image_input_type",
+                     &VLMHandler::Options::image_input_type_)
+      .def_readwrite("image_token_id", &VLMHandler::Options::image_token_id_)
+      .def_readwrite("image_input_shape",
+                     &VLMHandler::Options::image_input_shape_)
+      .def_readwrite("image_feature_size",
+                     &VLMHandler::Options::image_feature_size_)
+      .def("__repr__", [](const VLMHandler::Options& self) {
+        return "Options(model_path={}, devices={}, "
+               "block_size={}, max_cache_size={}, "
+               "max_memory_utilization={}, enable_prefix_cache={}, "
+               "enable_cuda_graph={}, cuda_graph_max_seq_len={}, "
+               "cuda_graph_batch_sizes={}, "
+               "max_tokens_per_batch={}, max_seqs_per_batch={}, "
+               "num_handling_threads={}, "
+               "image_input_type={}, image_token_id={}, "
+               "image_input_shape={}, image_feature_size={})"_s.format(
+                   self.model_path_,
+                   self.devices_,
+                   self.block_size_,
+                   self.max_cache_size_,
+                   self.max_memory_utilization_,
+                   self.enable_prefix_cache_,
+                   self.enable_cuda_graph_,
+                   self.cuda_graph_max_seq_len_,
+                   self.cuda_graph_batch_sizes_,
+                   self.max_tokens_per_batch_,
+                   self.max_seqs_per_batch_,
+                   self.num_handling_threads_,
+                   self.image_input_type_,
+                   self.image_token_id_,
+                   self.image_input_shape_,
+                   self.image_feature_size_);
+      });
+}
+
+}  // namespace llm::csrc

diff --git a/scalellm/vlm.py b/scalellm/vlm.py
new file mode 100644
index 00000000..e1e601c1
--- /dev/null
+++ b/scalellm/vlm.py
@@ -0,0 +1,127 @@
+import os
+from typing import List, Optional
+
+import torch
+
+from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler
+from scalellm.downloader import download_hf_model
+from scalellm.errors import ValidationError
+
+
+class VLM:
+    def __init__(
+        self,
+        model: str,
+        revision: Optional[str] = None,
+        allow_patterns: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        convert_to_safetensors: bool = False,
+        devices: Optional[str] = None,
+        block_size: int = 16,
+        max_cache_size: int = 20 * 1024 * 1024 * 1024,
+        max_memory_utilization: float = 0.9,
+        enable_prefix_cache: bool = True,
+        enable_cuda_graph: bool = True,
+        cuda_graph_max_seq_len: int = 2048,
+        cuda_graph_batch_sizes: Optional[List[int]] = None,
+        max_tokens_per_batch: int = 409600,  # a big number to disable chunked prefill
+        max_seqs_per_batch: int = 2048,  # a big number for better throughput
+        num_handling_threads: int = 4,
+        # vision encoder configuration
+        image_input_type: Optional[str] = None,
+        image_token_id: Optional[int] = None,
+        image_input_shape: Optional[str] = None,
+        image_feature_size: Optional[int] = None,
+    ) -> None:
+        # download hf model if it does not exist
+        self._model = model
+        # VLM has no draft model; kept so __repr__ below stays valid
+        self._draft_model = None
+        model_path = model
+        if not os.path.exists(model_path):
+            model_path = download_hf_model(
+                repo_id=model_path,
+                revision=revision,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                convert_to_safetensors=convert_to_safetensors,
+            )
+
+        options = VLMHandler.Options()
+        options.model_path = model_path
+        options.devices = devices
+        options.block_size = block_size
+        options.max_cache_size = max_cache_size
+        options.max_memory_utilization = max_memory_utilization
+        options.enable_prefix_cache = enable_prefix_cache
+        options.enable_cuda_graph = enable_cuda_graph
+        options.cuda_graph_max_seq_len = cuda_graph_max_seq_len
+        options.cuda_graph_batch_sizes = cuda_graph_batch_sizes
+        options.max_tokens_per_batch = max_tokens_per_batch
+        options.max_seqs_per_batch = max_seqs_per_batch
+        options.num_handling_threads = num_handling_threads
+        options.image_input_type = image_input_type
+        options.image_token_id = image_token_id
+        options.image_input_shape = image_input_shape
+        options.image_feature_size = image_feature_size
+        # create the VLM handler
+        self._handler = VLMHandler(options)
+
+    def generate(
+        self,
+        image: Optional[torch.Tensor] = None,
+        prompt: Optional[str] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        priority: Priority = Priority.NORMAL,
+        wait_for_schedule: bool = True,
+    ) -> RequestOutput:
+        # use default sampling parameters if not provided
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+
+        output = None
+
+        def callback(async_output: RequestOutput) -> bool:
+            # capture the final output from the handler
+            nonlocal output
+            output = async_output
+            return True
+
+        # schedule the batch requests
+        future = self._handler.schedule_async(
+            image, prompt, sampling_params, priority, False, callback
+        )
+
+        # wait for batch request to be scheduled
+        if wait_for_schedule:
+            future.wait()
+
+        # run until all scheduled requests complete
+        self._handler.run_until_complete()
+
+        # throw an exception if there is any error
+        if output is None:
+            raise RuntimeError("Request failed, no output received")
+        if output.status is not None and not output.status.ok:
+            raise ValidationError(output.status.code, output.status.message)
+        # carry over the prompt to the output
+        output.prompt = prompt
+        return output
+
+    def encode(self, text: str) -> List[int]:
+        return self._handler.encode(text)
+
+    def decode(
+        self, tokens: List[int], skip_special_tokens: bool = True
+    ) -> Optional[str]:
+        return self._handler.decode(tokens, skip_special_tokens)
+
+    def __del__(self):
+        self._handler.reset()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.__del__()
+        return False
+
+    def __repr__(self) -> str:
+        if self._draft_model:
+            return f"VLM(model={self._model}, 
draft_model={self._draft_model})" + return f"VLM(model={self._model})" diff --git a/src/engine/CMakeLists.txt b/src/engine/CMakeLists.txt index 0ed44e43..8f1483e4 100644 --- a/src/engine/CMakeLists.txt +++ b/src/engine/CMakeLists.txt @@ -10,14 +10,18 @@ cc_library( batch.h model_runner.h worker.h + vlm_worker.h engine.h llm_engine.h + vlm_engine.h SRCS utils.cpp batch.cpp model_runner.cpp worker.cpp + vlm_worker.cpp llm_engine.cpp + vlm_engine.cpp DEPS torch :common diff --git a/src/engine/batch.cpp b/src/engine/batch.cpp index cedf7cc2..1b3b1b3d 100644 --- a/src/engine/batch.cpp +++ b/src/engine/batch.cpp @@ -49,6 +49,8 @@ void Batch::add(Sequence* sequence, uint32_t token_budget) { sequences_.push_back(sequence); token_budgets_.push_back(token_budget); budget_used_.push_back(0); + + input_embedding_ = sequence->get_input_embedding(); } void Batch::add(const std::vector& sequences) { @@ -258,6 +260,7 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens, pad_2d_vector(block_tables_vec, /*pad_value=*/0); input_params.block_tables = create_2d_tensor(block_tables_vec, torch::kInt); + input_params.input_embedding = input_embedding_; CHECK_EQ(sampling_params.size(), selected_token_idxes.size()); if (!selected_token_idxes.empty()) { diff --git a/src/engine/batch.h b/src/engine/batch.h index b452b6ea..225e1e46 100644 --- a/src/engine/batch.h +++ b/src/engine/batch.h @@ -57,6 +57,9 @@ class Batch { // set the engine type for the batch void set_engine_type(EngineType engine_type); + // TODO: + torch::Tensor get_input_embedding() const { return input_embedding_; } + private: static Token build_token(int64_t index, torch::Tensor token_ids, @@ -72,6 +75,9 @@ class Batch { // number of used budget for each sequence std::vector budget_used_; + + // TODO: + torch::Tensor input_embedding_; }; } // namespace llm diff --git a/src/engine/engine.h b/src/engine/engine.h index 21e44498..2fb16237 100644 --- a/src/engine/engine.h +++ b/src/engine/engine.h @@ -27,4 +27,12 @@ class Engine { virtual const TokenizerArgs& tokenizer_args() const = 0; }; +class VisionEngine { + public: + virtual ~VisionEngine() = default; + + virtual torch::Tensor vision_encode(torch::Tensor image, + torch::Tensor tokens) = 0; +}; + } // namespace llm diff --git a/src/engine/parameters.h b/src/engine/parameters.h index cee92c9e..b91d9e81 100644 --- a/src/engine/parameters.h +++ b/src/engine/parameters.h @@ -39,4 +39,4 @@ struct ModelOutput { torch::Tensor logits; }; -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/engine/vlm_engine.cpp b/src/engine/vlm_engine.cpp new file mode 100644 index 00000000..9ea7116f --- /dev/null +++ b/src/engine/vlm_engine.cpp @@ -0,0 +1,303 @@ +#include "vlm_engine.h" + +#include +#include + +#include +#include +#include + +#include "common/metrics.h" +#include "common/pretty_print.h" +#include "model_loader/model_loader.h" +#include "model_parallel/parallel_args.h" +#include "models/model_args.h" +#include "vlm_worker.h" + +DEFINE_COUNTER(prepare_input_latency_seconds, + "Latency of preparing input in seconds"); + +namespace llm { +namespace { +// clang-format off +const std::vector kDefaultBatchSizesForCudaGraph = + {1, 2, 4, 8, 16, 24, 32, 48, 64}; +// clang-format on + +torch::ScalarType parse_dtype(const std::string& dtype_str, + const torch::Device& device) { + if (device.is_cpu()) { + // cpu only supports float32 for now + return torch::kFloat32; + } + + if (boost::iequals(dtype_str, "half") || + boost::iequals(dtype_str, "float16")) { + return 
torch::kFloat16; + } + if (boost::iequals(dtype_str, "bfloat16")) { + return torch::kBFloat16; + } + if ((boost::iequals(dtype_str, "float") || + boost::iequals(dtype_str, "float32"))) { + // cuda only supports float16 and bfloat16 for now + return torch::kFloat16; + } + + if (dtype_str.empty() || boost::iequals(dtype_str, "auto")) { + return torch::kFloat16; + } + CHECK(false) << "Unsupported dtype: " << dtype_str << " on device " << device; +} +} // namespace + +VLMEngine::VLMEngine(const Options& options) : options_(options) { + const auto& devices = options_.devices(); + CHECK_GT(devices.size(), 0) << "At least one device is required"; + + const auto device_type = devices[0].type(); + if (devices[0].is_cuda()) { + // check cuda compute capability + const auto* properties = at::cuda::getDeviceProperties(devices[0].index()); + const bool is_sm8x = properties->major == 8 && properties->minor >= 0; + const bool is_sm90 = properties->major == 9 && properties->minor == 0; + CHECK(is_sm90 || is_sm8x) << "Engine only supports Ampere GPUs or newer."; + // TODO: add Turing(sm75) support in the near future. + } + + // sort cuda graph batch sizes + if (options_.enable_cuda_graph()) { + batch_sizes_ = options_.cuda_graph_batch_sizes().value_or( + kDefaultBatchSizesForCudaGraph); + std::sort(batch_sizes_.begin(), batch_sizes_.end()); + } + + // create a worker for each device + ModelRunner::Options runner_options; + runner_options.block_size(options_.block_size()) + .num_decoding_tokens(options_.num_decoding_tokens()) + .cuda_graph_max_seq_len(options_.cuda_graph_max_seq_len()) + .cuda_graph_batch_sizes(batch_sizes_); + ParallelArgs parallel_args(0, 1, nullptr); + worker_ = + std::make_unique(parallel_args, devices[0], runner_options); +} + +bool VLMEngine::init(const std::string& model_weights_path) { + if (!init_model(model_weights_path)) { + LOG(ERROR) << "Failed to initialize model from: " << model_weights_path; + return false; + } + + // initialize kv cache + const int64_t cache_size_in_bytes = profile_memory_for_kv_cache(); + CHECK_GT(cache_size_in_bytes, 0); + LOG(INFO) << "Initializing kv cache with size: " + << readable_size(cache_size_in_bytes); + const int64_t n_blocks = calculate_kv_cache_blocks(cache_size_in_bytes); + if (!init_kv_cache(n_blocks)) { + LOG(ERROR) << "Failed to initialize kv cache"; + return false; + } + if (!capture_cuda_graphs()) { + LOG(ERROR) << "Failed to warmup model."; + return false; + } + return true; +} + +bool VLMEngine::init_model(const std::string& model_weights_path) { + auto model_loader = ModelLoader::create(model_weights_path); + LOG(INFO) << "Initializing model from: " << model_weights_path; + + tokenizer_ = model_loader->tokenizer(); + CHECK(tokenizer_ != nullptr); + + args_ = model_loader->model_args(); + quant_args_ = model_loader->quant_args(); + tokenizer_args_ = model_loader->tokenizer_args(); + + // compute the number of local kv heads and head dim + const int world_size = 1; + const int64_t n_heads = args_.n_heads(); + const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads); + n_local_kv_heads_ = std::max(1, n_kv_heads / world_size); + head_dim_ = args_.head_dim(); + dtype_ = parse_dtype(args_.dtype(), options_.devices()[0]); + + // key + value for all layers + LOG(INFO) << "Block info, block_size: " << options_.block_size() + << ", n_local_kv_heads: " << n_local_kv_heads_ + << ", head_dim: " << head_dim_ << ", n_layers: " << args_.n_layers() + << ", dtype: " << dtype_; + + if (tokenizer_->vocab_size() != args_.vocab_size()) { + // use 
tokenizer vocab size if model vocab size is not set + if (args_.vocab_size() <= 0) { + LOG(WARNING) << "Model vocab size is not set, using tokenizer vocab " + "size: " + << tokenizer_->vocab_size(); + args_.vocab_size(tokenizer_->vocab_size()); + } else { + LOG(WARNING) << "Vocab size mismatch: tokenizer: " + << tokenizer_->vocab_size() + << ", model: " << args_.vocab_size(); + } + } + + LOG(INFO) << "Initializing model with " << args_; + LOG(INFO) << "Initializing model with quant args: " << quant_args_; + LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_; + + VLMWorker* worker = worker_.get(); + // only one worker, call init_model in current thread + if (!worker->init_model(dtype_, args_, quant_args_)) { + return false; + } + // load the weights from the checkpoint + for (const auto& state_dict : *model_loader) { + worker->load_state_dict(state_dict); + } + worker->verify_loaded_weights(); + return true; +} + +bool VLMEngine::capture_cuda_graphs() { + if (!options_.enable_cuda_graph()) { + return true; + } + + LOG(INFO) << "Capturing CUDA graphs: num_decoding_tokens: " + << options_.num_decoding_tokens() + << ", batch sizes: " << batch_sizes_; + + for (const auto batch_size : batch_sizes_) { + std::vector> futures; + futures.emplace_back(worker_->capture_cuda_graph_async(batch_size)); + // wait up to 4 seconds for all futures to complete + folly::collectAll(futures).within(std::chrono::seconds(4)).get(); + } + return true; +} + +int64_t VLMEngine::profile_memory_for_kv_cache() { + const int64_t max_cache_size = options_.max_cache_size(); + const double max_memory_utilization = options_.max_memory_utilization(); + + const auto& device = worker_->device(); + if (device.is_cpu()) { + // use max memory cache size for CPU + LOG(INFO) << "Initializing CPU cache with max cache size: " + << readable_size(max_cache_size); + // TODO: add CPU memory profiling + return max_cache_size; + } + CHECK(device.is_cuda()) << "Only support CPU and CUDA device for now."; + + // call worker to profile memory usage + std::vector>> futures; + futures.push_back(worker_->profile_device_memory_async()); + + // pick smallest available memory from all devices + int64_t smallest_available_memory = std::numeric_limits::max(); + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (size_t i = 0; i < results.size(); ++i) { + const auto device = worker_->device(); + if (!results[i].hasValue()) { + LOG(ERROR) << "Failed to profile memory usage for device: " << device; + continue; + } + auto [available_memory, total_memory] = results[i].value(); + LOG(INFO) << device + << ": available memory: " << readable_size(available_memory) + << ", total memory: " << readable_size(total_memory); + + LOG(INFO) << "Using max_memory_utilization: " << max_memory_utilization + << ", max_cache_size: " << readable_size(max_cache_size); + // apply memory cap from config if it is set + if (max_memory_utilization < 1.0) { + const int64_t buffer_memory = + total_memory * (1.0 - max_memory_utilization); + available_memory -= buffer_memory; + } + if (max_cache_size > 0) { + available_memory = std::min(available_memory, max_cache_size); + } + smallest_available_memory = + std::min(smallest_available_memory, available_memory); + } + return std::max(smallest_available_memory, int64_t(0)); +} + +bool VLMEngine::init_kv_cache(int64_t n_blocks) { + CHECK_GT(n_blocks, 0) << "no memory for kv cache"; + const int32_t block_size = options_.block_size(); + + // init kv cache for each worker + 
const std::vector kv_cache_shape = { + n_blocks, block_size, n_local_kv_heads_, head_dim_}; + LOG(INFO) << "Initializing kv cache with shape: [" << kv_cache_shape << "]"; + + // initialize block manager + BlockManager::Options options; + options.num_blocks(n_blocks) + .block_size(block_size) + .enable_prefix_cache(options_.enable_prefix_cache()); + block_manager_ = std::make_unique(options); + + // only one worker, call init_kv_cache in current thread + return worker_->init_kv_cache(kv_cache_shape); +} + +torch::Tensor VLMEngine::vision_encode(torch::Tensor image, + torch::Tensor tokens) { + return worker_->vision_encode(image, tokens); +} + +ModelOutput VLMEngine::execute_model(Batch& batch) { + // prepare inputs for worker + uint32_t adjusted_batch_size = 0; + if (options_.enable_cuda_graph()) { + // find the closest batch size in the captured graph + const auto it = std::lower_bound( + batch_sizes_.begin(), batch_sizes_.end(), batch.size()); + if (it != batch_sizes_.end()) { + adjusted_batch_size = *it; + } + } + + Timer timer; + auto model_inputs = batch.prepare_model_input(options_.num_decoding_tokens(), + adjusted_batch_size); + COUNTER_ADD(prepare_input_latency_seconds, timer.elapsed_seconds()); + + if (!model_inputs.token_ids.defined()) { + // empty input, just return + return {}; + } + + // only one worker, call blocking forward + auto model_output = worker_->execute_model(model_inputs); + DCHECK(model_output.has_value()) << "Failed to execute model"; + batch.process_sample_output(model_output.value().sample_output); + return model_output.value(); +} + +int64_t VLMEngine::kv_cache_slot_size_in_bytes() const { + const auto dtype_size = torch::scalarTypeToTypeMeta(dtype_).itemsize(); + // key + value for all layers + const int64_t slot_size_in_bytes = + 2 * n_local_kv_heads_ * head_dim_ * args_.n_layers() * dtype_size; + return slot_size_in_bytes; +} + +int64_t VLMEngine::calculate_kv_cache_blocks( + int64_t cache_size_in_bytes) const { + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = + block_size * kv_cache_slot_size_in_bytes(); + return cache_size_in_bytes / block_size_in_bytes; +} + +} // namespace llm diff --git a/src/engine/vlm_engine.h b/src/engine/vlm_engine.h new file mode 100644 index 00000000..f2cfd267 --- /dev/null +++ b/src/engine/vlm_engine.h @@ -0,0 +1,139 @@ +#pragma once + +#include + +#include "batch.h" +#include "common/macros.h" +#include "engine.h" +#include "memory/block_manager.h" +#include "quantization/quant_args.h" +#include "tokenizer/tokenizer.h" +#include "tokenizer/tokenizer_args.h" +#include "vlm_worker.h" + +namespace llm { + +// The Large Language Model (LLM) engine is a model runner designed to execute +// inference procedures incrementally using batches of requests. It comprises +// three critical components: a model, a tokenizer, and a resource manager. +// The inference process is primarily divided into two stages: 'prefill' and +// 'decode'. +// * 'Prefill': This is the more costly phase, as it involves processing the +// prompt and generating kv caches. +// * 'decode': In this stage, subsequent tokens are generated using the +// previously generated kv caches. +// +// A single batch may contain requests from two stages of the inference +// process. The engine must be adept at handling these diverse requests, +// ensuring optimal resource management. 
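+//
+// A rough KV-cache sizing sketch (illustrative numbers only, assuming
+// Llama-7B-like dimensions for llava-1.5-7b: 32 layers, 32 KV heads,
+// head_dim 128, fp16; the real values come from the loaded ModelArgs):
+//   slot size  = 2 * n_kv_heads * head_dim * n_layers * dtype_size
+//              = 2 * 32 * 128 * 32 * 2 bytes = 512 KiB per token slot
+//   block size = 16 slots -> 8 MiB per block
+//   max_cache_size = 10 GiB -> ~1280 blocks
+// See kv_cache_slot_size_in_bytes() and calculate_kv_cache_blocks() below.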
+ +class VLMEngine : public Engine, public VisionEngine { + public: + struct Options { + DEFINE_ARG(std::vector, devices); + + // the number of slots per block, default 16, value must be multiple of 16 + DEFINE_ARG(int32_t, block_size) = 16; + + // the maximum cache size in bytes, default 10GB + DEFINE_ARG(int64_t, max_cache_size) = 10737418240; + + // maximum memory utilization allowed, default 0.9 + DEFINE_ARG(double, max_memory_utilization) = 0.9; + + // enable prefix cache + DEFINE_ARG(bool, enable_prefix_cache) = true; + + // number of decoding tokens per sequence + // in speculative decoding, it is the number of speculative tokens + 1 + DEFINE_ARG(int64_t, num_decoding_tokens) = 1; + + // enable cuda graph + DEFINE_ARG(bool, enable_cuda_graph) = true; + + // max sequence length used to capture cuda graphs + DEFINE_ARG(int64_t, cuda_graph_max_seq_len) = 1024; + + // batch sizes to capture cuda graphs + DEFINE_ARG(std::optional>, cuda_graph_batch_sizes); + }; + + // create an engine with the given devices + VLMEngine(const Options& options); + + virtual ~VLMEngine() = default; + + // step the engine forward by one step with the batch + ModelOutput execute_model(Batch& batch) override; + + const Tokenizer* tokenizer() const override { return tokenizer_.get(); } + + BlockManager* block_manager() const override { return block_manager_.get(); } + + const ModelArgs& model_args() const override { return args_; } + + const TokenizerArgs& tokenizer_args() const override { + return tokenizer_args_; + } + + torch::Tensor vision_encode(torch::Tensor image, + torch::Tensor tokens) override; + + const QuantArgs& quant_args() const { return quant_args_; } + + const Options& options() const { return options_; } + + // initialize the engine with the given model weights + bool init(const std::string& model_weights_path); + + bool init_model(const std::string& model_weights_path); + + bool init_kv_cache(int64_t n_blocks); + + bool capture_cuda_graphs(); + + // returns the memory size for the kv cache + int64_t profile_memory_for_kv_cache(); + + // returns the memory size in bytes for each kv cache slot + int64_t kv_cache_slot_size_in_bytes() const; + + // returns the number of kv cache blocks from the given cache size in bytes + int64_t calculate_kv_cache_blocks(int64_t cache_size_in_bytes) const; + + private: + // options + Options options_; + + // dtype + torch::ScalarType dtype_; + + // model args + ModelArgs args_; + + // quantization args + QuantArgs quant_args_; + + // Tokenizer args + TokenizerArgs tokenizer_args_; + + // block manager + std::unique_ptr block_manager_; + + // a list of process groups, with each process group handling a single device + std::vector> process_groups_; + + // batch sizes to capture cuda graphs + std::vector batch_sizes_; + + // tokenizer + std::unique_ptr tokenizer_; + + std::unique_ptr worker_; + + // config for kv cache + int64_t n_local_kv_heads_ = 0; + int64_t head_dim_ = 0; +}; + +} // namespace llm diff --git a/src/engine/vlm_worker.cpp b/src/engine/vlm_worker.cpp new file mode 100644 index 00000000..42c1f737 --- /dev/null +++ b/src/engine/vlm_worker.cpp @@ -0,0 +1,284 @@ +#include "vlm_worker.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/metrics.h" +#include "common/threadpool.h" +#include "common/timer.h" +#include "memory/kv_cache.h" +#include "memory/memory.h" +#include "model_loader/state_dict.h" +#include "model_parallel/model_parallel.h" +#include "models/parameters.h" 
+#include "sampling/logits_processor.h" +#include "sampling/sampler.h" + +// latency metrics +DEFINE_COUNTER_FAMILY(execution_latency_seconds, + "Execution latency in seconds"); +DEFINE_COUNTER_INSTANCE(model_execution_latency_seconds, + execution_latency_seconds, + {{"stage", "model"}}); +DEFINE_COUNTER_INSTANCE(logits_processing_latency_seconds, + execution_latency_seconds, + {{"stage", "logits_processing"}}); +DEFINE_COUNTER_INSTANCE(sampling_latency_seconds, + execution_latency_seconds, + {{"stage", "sampling"}}); + +namespace llm { + +VLMWorker::VLMWorker(const ParallelArgs& parallel_args, + const torch::Device& device, + const ModelRunner::Options& runner_options) + : parallel_args_(parallel_args), + device_(device), + runner_options_(runner_options) { + // first worker is the driver + driver_ = parallel_args.rank() == 0; +} + +bool VLMWorker::init_model(torch::ScalarType dtype, + const ModelArgs& args, + const QuantArgs& quant_args) { + CHECK(model_ == nullptr) << "Model is already initialized."; + + // initialize model + args_ = args; + dtype_ = dtype; + const auto options = torch::dtype(dtype_).device(device_); + model_ = CausalVLM::create(args, quant_args, parallel_args_, options); + CHECK(model_ != nullptr) << "Failed to create model."; + model_runner_ = + std::make_unique(model_.get(), device_, runner_options_); + return true; +} + +torch::Tensor VLMWorker::vision_encode(torch::Tensor image, + torch::Tensor tokens) { + return model_->vision_encode(image, tokens); +} + +bool VLMWorker::init_kv_cache(const std::vector& kv_cache_shape) { + CHECK(model_ != nullptr) << "Model is not initialized."; + CHECK(kv_caches_.empty()) << "KV caches are already initialized."; + + // create a KVCache for each layer + const int64_t num_layers = args_.n_layers(); + kv_caches_.reserve(num_layers); + for (int64_t i = 0; i < num_layers; ++i) { + auto key_cache = + torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_)); + auto value_cache = + torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_)); + kv_caches_.emplace_back(key_cache, value_cache); + } + return true; +} + +void VLMWorker::capture_cuda_graph(uint32_t batch_size) { + CHECK(model_ != nullptr) << "Model is not initialized."; + CHECK(!kv_caches_.empty()) << "KV caches are not initialized."; + return model_runner_->capture_cuda_graphs(batch_size, kv_caches_); +} + +void VLMWorker::load_state_dict(const StateDict& state_dict) { + CHECK(model_ != nullptr) << "Model is not initialized."; + model_->load_state_dict(state_dict); +} + +void VLMWorker::verify_loaded_weights() const { + CHECK(model_ != nullptr) << "Model is not initialized."; + model_->verify_loaded_weights(); +} + +std::tuple VLMWorker::profile_device_memory() { + CHECK(model_ != nullptr) << "Model is not initialized."; + CHECK(device_.is_cuda()) << "Memory profiling is only supported on GPU."; + + const auto available_memory = memory::available_memory(device_); + const auto total_memory = memory::total_memory(device_); + + return {available_memory, total_memory}; +} + +void VLMWorker::process_group_test() { + torch::DeviceGuard device_guard(device_); + torch::cuda::synchronize(); + + // create random tensors + const auto options = torch::dtype(torch::kHalf).device(device_); + torch::Tensor tensor = torch::randn({10, 10}, options); + // call allreduce + reduce_from_model_parallel_region(tensor, parallel_args_); + // call allgather + gather_from_model_parallel_region(tensor, parallel_args_); + torch::cuda::synchronize(); +} + +std::optional 
VLMWorker::execute_model(const ModelInput& inputs) { + torch::DeviceGuard device_guard(device_); + at::cuda::getCurrentCUDAStream().synchronize(); + + Timer timer; + + // all tensors should be on the same device as model + auto flatten_tokens = inputs.token_ids.to(device_); + auto flatten_positions = inputs.positions.to(device_); + auto params = inputs.input_params.to(device_); + auto sampling_params = inputs.sampling_params.to(device_, dtype_); + + // call model runner forward to get hidden states + auto hidden_states = model_runner_->forward( + flatten_tokens, flatten_positions, kv_caches_, params); + + torch::Tensor logits; + if (sampling_params.selected_token_idxes.defined()) { + logits = + model_->logits(hidden_states, sampling_params.selected_token_idxes); + } + + at::cuda::getCurrentCUDAStream().synchronize(); + COUNTER_ADD(model_execution_latency_seconds, timer.elapsed_seconds()); + + if (!driver_) { + return std::nullopt; + } + + // driver prepare model output + ModelOutput output; + if (sampling_params.selected_token_idxes.defined()) { + // create and call logits processors + timer.reset(); + auto logits_processor = LogitsProcessor::create(sampling_params); + // apply logits processors to logits (in place) + logits = logits_processor->forward(logits, + sampling_params.unique_token_ids, + sampling_params.unique_token_counts, + sampling_params.unique_token_ids_lens); + COUNTER_ADD(logits_processing_latency_seconds, timer.elapsed_seconds()); + + // set logits to output + output.logits = logits; + + timer.reset(); + auto sampler = std::make_unique(sampling_params.do_sample, + sampling_params.logprobs, + sampling_params.max_top_logprobs); + // select sample logits + auto sample_logits = + logits.index_select(/*dim=*/0, sampling_params.sample_idxes); + auto sample_output = sampler->forward(sample_logits); + COUNTER_ADD(sampling_latency_seconds, timer.elapsed_seconds()); + + // set sample output to output + output.sample_output = sample_output; + + // carry over the sampling params + output.do_sample = sampling_params.do_sample; + output.logprobs = sampling_params.logprobs; + output.max_top_logprobs = sampling_params.max_top_logprobs; + } + return output; +} + +folly::SemiFuture> +VLMWorker::profile_device_memory_async() { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule([this, promise = std::move(promise)]() mutable { + const auto output = this->profile_device_memory(); + promise.setValue(output); + }); + return future; +} + +folly::SemiFuture> VLMWorker::execute_model_async( + const ModelInput& inputs) { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule( + [this, inputs = inputs, promise = std::move(promise)]() mutable { + // run the model on the given input in working thread + const auto output = this->execute_model(inputs); + promise.setValue(output); + }); + return future; +} + +folly::SemiFuture VLMWorker::process_group_test_async() { + folly::Promise promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule([this, promise = std::move(promise)]() mutable { + this->process_group_test(); + promise.setValue(); + }); + return future; +} + +// initialize model, cache manager. 
async call +folly::SemiFuture VLMWorker::init_model_async( + torch::ScalarType dtype, + const ModelArgs& args, + const QuantArgs& quant_args) { + folly::Promise promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule([this, + dtype, + &args, + &quant_args, + promise = std::move(promise)]() mutable { + const bool success = this->init_model(dtype, args, quant_args); + promise.setValue(success); + }); + return future; +} + +folly::SemiFuture VLMWorker::init_kv_cache_async( + const std::vector& kv_cache_shape) { + folly::Promise promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule( + [this, &kv_cache_shape, promise = std::move(promise)]() mutable { + const bool success = this->init_kv_cache(kv_cache_shape); + promise.setValue(success); + }); + return future; +} + +folly::SemiFuture VLMWorker::capture_cuda_graph_async( + uint32_t batch_size) { + folly::Promise promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule( + [this, batch_size = batch_size, promise = std::move(promise)]() mutable { + this->capture_cuda_graph(batch_size); + promise.setValue(); + }); + return future; +} + +folly::SemiFuture VLMWorker::load_state_dict_async( + const StateDict& state_dict) { + folly::Promise promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule( + [this, &state_dict, promise = std::move(promise)]() mutable { + // load the model weights from state_dict within the working thread + this->load_state_dict(state_dict); + promise.setValue(); + }); + return future; +} + +} // namespace llm diff --git a/src/engine/vlm_worker.h b/src/engine/vlm_worker.h new file mode 100644 index 00000000..28aa69b6 --- /dev/null +++ b/src/engine/vlm_worker.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include "common/threadpool.h" +#include "model_loader/state_dict.h" +#include "model_parallel/parallel_args.h" +#include "model_runner.h" +#include "models/causal_vlm.h" +#include "models/model_args.h" +#include "models/parameters.h" +#include "parameters.h" +#include "quantization/quant_args.h" + +namespace llm { + +class VLMWorker final { + public: + VLMWorker(const ParallelArgs& parallel_args, + const torch::Device& device, + const ModelRunner::Options& runner_options); + + ~VLMWorker() = default; + + // initialize model, cache manager. blocking call + bool init_model(torch::ScalarType dtype, + const ModelArgs& args, + const QuantArgs& quant_args); + + torch::Tensor vision_encode(torch::Tensor image, torch::Tensor tokens); + + // Load the model weights from state_dict. blocking call + // can be called multiple times to reload the model with different parameters + void load_state_dict(const StateDict& state_dict); + + // verify if the model is loaded correctly + void verify_loaded_weights() const; + + // returns available memory and total memory + std::tuple profile_device_memory(); + + // initialize kv cache. blocking call + bool init_kv_cache(const std::vector& kv_cache_shape); + + // Run the model on the given input. blocking call + std::optional execute_model(const ModelInput& inputs); + + // capture cuda graph for the model. blocking call + void capture_cuda_graph(uint32_t batch_size); + + // initialize model, cache manager. async call + folly::SemiFuture init_model_async(torch::ScalarType dtype, + const ModelArgs& args, + const QuantArgs& quant_args); + + // Load the model weights from state_dict. 
async call + // the future returns a successfull status with no meaningful value + folly::SemiFuture load_state_dict_async( + const StateDict& state_dict); + + folly::SemiFuture> profile_device_memory_async(); + + // initialize kv cache. async call + folly::SemiFuture init_kv_cache_async( + const std::vector& kv_cache_shape); + + // Run the model on the given input. async call + // the future returns a successfull status with no meaningful value + folly::SemiFuture> execute_model_async( + const ModelInput& inputs); + + folly::SemiFuture process_group_test_async(); + + // capture cuda graph for the model. async call + folly::SemiFuture capture_cuda_graph_async(uint32_t batch_size); + + const torch::Device& device() const { return device_; } + + private: + void process_group_test(); + + // whether the worker is a driver, who takes care of the sampling + bool driver_ = false; + + // working thread + ThreadPool threadpool_; + + // dtype of the model + torch::ScalarType dtype_; + + // device to run the model on + torch::Device device_; + + // parallel args + ParallelArgs parallel_args_; + + // model args + ModelArgs args_; + + // kv caches + std::vector kv_caches_; + + // causal VLM model + std::unique_ptr model_; + + // runner options + ModelRunner::Options runner_options_; + + // model runner that runs the model, with cuda graph if enabled + std::unique_ptr model_runner_; +}; + +} // namespace llm diff --git a/src/handlers/CMakeLists.txt b/src/handlers/CMakeLists.txt index 424afa56..e23e4fa8 100644 --- a/src/handlers/CMakeLists.txt +++ b/src/handlers/CMakeLists.txt @@ -6,8 +6,10 @@ cc_library( HDRS sampling_params.h llm_handler.h + vlm_handler.h SRCS llm_handler.cpp + vlm_handler.cpp DEPS :common :scheduler diff --git a/src/handlers/vlm_handler.cpp b/src/handlers/vlm_handler.cpp new file mode 100644 index 00000000..aca66940 --- /dev/null +++ b/src/handlers/vlm_handler.cpp @@ -0,0 +1,475 @@ +#include "vlm_handler.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "common/metrics.h" +#include "common/scope_guard.h" +#include "common/timer.h" +#include "engine/utils.h" +#include "engine/vlm_engine.h" +#include "models/model_args.h" +#include "models/model_registry.h" +#include "request/output.h" +#include "request/request.h" +#include "speculative/speculative_engine.h" + +DEFINE_COUNTER_FAMILY(request_status_total, "Total number of request status"); +DEFINE_COUNTER_INSTANCE(request_ok, request_status_total, {{"code", "OK"}}); +DEFINE_COUNTER_INSTANCE(request_cancelled, + request_status_total, + {{"code", "CANCELLED"}}); +DEFINE_COUNTER_INSTANCE(request_unknown, + request_status_total, + {{"code", "UNKNOWN"}}); +DEFINE_COUNTER_INSTANCE(request_invalid_argument, + request_status_total, + {{"code", "INVALID_ARGUMENT"}}); +DEFINE_COUNTER_INSTANCE(request_deadline_exceeded, + request_status_total, + {{"code", "DEADLINE_EXCEEDED"}}); +DEFINE_COUNTER_INSTANCE(request_resource_exhausted, + request_status_total, + {{"code", "RESOURCE_EXHAUSTED"}}); +DEFINE_COUNTER_INSTANCE(request_unauthenticated, + request_status_total, + {{"code", "UNAUTHENTICATED"}}); +DEFINE_COUNTER_INSTANCE(request_unavailable, + request_status_total, + {{"code", "UNAVAILABLE"}}); +DEFINE_COUNTER_INSTANCE(request_unimplemented, + request_status_total, + {{"code", "UNIMPLEMENTED"}}); + +DEFINE_COUNTER_FAMILY(request_handling_latency_seconds, + "Request handling latency in seconds"); +DEFINE_COUNTER_INSTANCE(chat_handling_latency_seconds, + request_handling_latency_seconds, + {{"type", "chat"}}); 
+DEFINE_COUNTER_INSTANCE(completion_handling_latency_seconds, + request_handling_latency_seconds, + {{"type", "completion"}}); + +DEFINE_COUNTER(tokenization_latency_seconds, + "Prompt tokenization latency in seconds"); +DEFINE_COUNTER(chat_template_latency_seconds, + "Chat template latency in seconds"); + +namespace llm { +namespace { + +#define CALLBACK_WITH_ERROR(CODE, MSG) callback(Status{CODE, MSG}); + +void log_request_status(StatusCode code) { + switch (code) { + case StatusCode::OK: + COUNTER_INC(request_ok); + break; + case StatusCode::CANCELLED: + COUNTER_INC(request_cancelled); + break; + case StatusCode::UNKNOWN: + COUNTER_INC(request_unknown); + break; + case StatusCode::INVALID_ARGUMENT: + COUNTER_INC(request_invalid_argument); + break; + case StatusCode::DEADLINE_EXCEEDED: + COUNTER_INC(request_deadline_exceeded); + break; + case StatusCode::RESOURCE_EXHAUSTED: + COUNTER_INC(request_resource_exhausted); + break; + case StatusCode::UNAUTHENTICATED: + COUNTER_INC(request_unauthenticated); + break; + case StatusCode::UNAVAILABLE: + COUNTER_INC(request_unavailable); + break; + case StatusCode::UNIMPLEMENTED: + COUNTER_INC(request_unimplemented); + break; + default: + COUNTER_INC(request_unknown); + break; + } +} + +bool verify_params(const SamplingParams& sp, OutputCallback callback) { + if (sp.n == 0) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "n should be greater than 0"); + return false; + } + if (sp.best_of.has_value()) { + if (sp.n > sp.best_of.value()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "n should be less than or equal to best_of"); + return false; + } + } + + // up to 4 stop sequences + if (sp.stop.has_value() && sp.stop.value().size() > 4) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "stop size is too large"); + return false; + } + + // temperature between [0.0, 2.0] + if (sp.temperature < 0.0 || sp.temperature > 2.0) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "temperature must be between 0.0 and 2.0"); + return false; + } + + // top_p between [0.0, 1.0] + if (sp.top_p < 0.0 || sp.top_p > 1.0) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "top_p must be between 0.0 and 1.0"); + return false; + } + + if (sp.logprobs) { + if (sp.echo) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "logprobs is not supported with echo"); + return false; + } + if (sp.top_logprobs < 0 || sp.top_logprobs > 20) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "logprobs must be between 0 and 20"); + return false; + } + } + + // presence_penalty between [-2.0, 2.0] + if (sp.presence_penalty < -2.0 || sp.presence_penalty > 2.0) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "presence_penalty must be between -2.0 and 2.0"); + return false; + } + + // frequency_penalty between [0.0, 2.0] + if (sp.frequency_penalty < 0.0 || sp.frequency_penalty > 2.0) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "frequency_penalty must be between 0.0 and 2.0"); + return false; + } + return true; +} + +} // namespace + +VLMHandler::VLMHandler(const Options& options) : options_(options) { + // construct engine + const auto devices = parse_devices(options.devices().value_or("auto")); + LOG(INFO) << "Creating engine with devices: " << to_string(devices); + + VLMEngine::Options eng_options; + eng_options.devices(devices) + .block_size(options.block_size()) + .max_cache_size(options.max_cache_size()) + .max_memory_utilization(options.max_memory_utilization()) + .enable_prefix_cache(options.enable_prefix_cache()) + 
.enable_cuda_graph(options.enable_cuda_graph()) + .cuda_graph_max_seq_len(options.cuda_graph_max_seq_len()) + .cuda_graph_batch_sizes(options.cuda_graph_batch_sizes()); + + auto engine = std::make_unique(eng_options); + CHECK(engine->init(options.model_path())); + engine_ = std::move(engine); + + model_args_ = engine_->model_args(); + + ContinuousScheduler::Options scheduler_options; + scheduler_options.max_tokens_per_batch(options.max_tokens_per_batch()) + .max_seqs_per_batch(options.max_seqs_per_batch()); + scheduler_ = + std::make_unique(engine_.get(), scheduler_options); + + // construct tokenizers and handling threads + const auto* tokenizer = engine_->tokenizer(); + for (size_t i = 0; i < options.num_handling_threads(); ++i) { + // create a tokenizer for each thread for now + tokenizers_.emplace_back(tokenizer->clone()); + handling_threads_.emplace_back([this, i] { handling_loop(i); }); + } +} + +VLMHandler::~VLMHandler() { reset(); } + +std::future VLMHandler::schedule_async(torch::Tensor image, + std::string prompt, + SamplingParams sp, + Priority priority, + bool stream, + OutputCallback callback) { + // add one pending request + scheduler_->inc_pending_requests(1); + return schedule( + std::move(image), + std::move(prompt), + std::move(sp), + priority, + stream, + [callback = std::move(callback)](const RequestOutput& output) { + if (output.status.has_value()) { + log_request_status(output.status.value().code()); + } + return callback(output); + }); +} + +std::future VLMHandler::schedule(torch::Tensor image, + std::string prompt, + SamplingParams sp, + Priority priority, + bool stream, + OutputCallback callback) { + std::promise promise; + auto future = promise.get_future(); + // add into the queue + queue_.push([this, + promise = std::move(promise), + image = std::move(image), + prompt = std::move(prompt), + sp = std::move(sp), + priority, + stream, + callback = std::move(callback)](size_t tid) mutable { + AUTO_COUNTER(completion_handling_latency_seconds); + + // remove the pending request after scheduling + SCOPE_GUARD([this] { scheduler_->dec_pending_requests(); }); + + Timer timer; + // verify the prompt + if (!verify_params(sp, callback)) { + promise.set_value(false); + return; + } + + auto request = create_request(tid, + std::move(image), + std::move(prompt), + sp, + priority, + stream, + callback); + if (!request) { + promise.set_value(false); + return; + } + + if (!scheduler_->schedule(request)) { + CALLBACK_WITH_ERROR(StatusCode::RESOURCE_EXHAUSTED, + "No available resources to schedule request"); + promise.set_value(false); + return; + } + promise.set_value(true); + }); + return future; +} + +void VLMHandler::handling_loop(size_t tid) { + while (true) { + Task task = queue_.pop(); + if (task == nullptr) { + // nullptr is a signal to exit + break; + } + task(tid); + } +} + +void VLMHandler::start() { + loop_thread_ = std::thread([this]() { + const bool running = running_.load(std::memory_order_relaxed); + CHECK(!running) << "Handler is already running"; + + running_.store(true, std::memory_order_relaxed); + const auto timeout = absl::Milliseconds(500); + while (!stoped_.load(std::memory_order_relaxed)) { + // move scheduler forward + scheduler_->step(timeout); + } + running_.store(false, std::memory_order_relaxed); + }); +} + +// stop the engine +void VLMHandler::stop() { + // set stop flag + stoped_.store(true, std::memory_order_relaxed); + // wait for the loop thread to finish + if (loop_thread_.joinable()) { + loop_thread_.join(); + } +} + +void 
VLMHandler::run_until_complete() { + const bool running = running_.load(std::memory_order_relaxed); + CHECK(!running) << "Handler is already running"; + + running_.store(true, std::memory_order_relaxed); + scheduler_->run_until_complete(); + running_.store(false, std::memory_order_relaxed); +} + +std::unique_ptr VLMHandler::create_request(size_t tid, + torch::Tensor image, + std::string prompt, + const SamplingParams& sp, + Priority priority, + bool stream, + OutputCallback callback) { + if (prompt.empty()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is empty"); + return nullptr; + } + + Timer timer; + std::vector prompt_tokens; + if (!tokenizers_[tid]->encode(prompt, &prompt_tokens)) { + LOG(ERROR) << "Failed to encode prompt: " << prompt; + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Failed to encode prompt"); + return nullptr; + } + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + // encode the image, encode & projector + auto vision_engine = dynamic_cast(engine_.get()); + auto input_embedding = vision_engine->vision_encode( + image, torch::tensor(prompt_tokens, torch::kInt)); + + // TODO: prompt_token is not enough, need to add image token size + const int64_t max_context_len = model_args_.max_position_embeddings(); + if (prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << prompt_tokens.size(); + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); + return nullptr; + } + + uint32_t max_tokens = sp.max_tokens; + if (max_tokens == 0) { + const uint32_t kDefaultMaxTokens = 16; + max_tokens = kDefaultMaxTokens; + } + + // allocate enough capacity for prompt tokens, max tokens, and speculative + // tokens, TODO: add image token size as well. + const size_t capacity = prompt_tokens.size() + max_tokens + 1; + const size_t best_of = sp.best_of.value_or(sp.n); + auto request = std::make_unique(std::move(prompt), + std::move(prompt_tokens), + std::move(input_embedding), + capacity, + sp.n, + best_of, + sp.logprobs); + + // sampling parameters + auto& sampling_param = request->sampling_param; + sampling_param.frequency_penalty = sp.frequency_penalty; + sampling_param.presence_penalty = sp.presence_penalty; + sampling_param.repetition_penalty = sp.repetition_penalty; + sampling_param.temperature = sp.temperature; + sampling_param.top_p = sp.top_p; + sampling_param.top_k = sp.top_k; + sampling_param.logprobs = sp.logprobs; + sampling_param.top_logprobs = sp.top_logprobs; + if (best_of > sp.n) { + // enable logprobs for best_of to generate sequence logprob + sampling_param.logprobs = true; + } + // sampling_param.do_sample = sp.do_sample; + // sampling_param.seed = sp.seed; + + // stopping criteria + auto& stopping_criteria = request->stopping_criteria; + stopping_criteria.max_tokens = max_tokens; + stopping_criteria.max_context_len = max_context_len; + stopping_criteria.ignore_eos = sp.ignore_eos; + stopping_criteria.eos_token_id = model_args_.eos_token_id(); + + if (sp.stop_token_ids.has_value()) { + const auto& stop_token_ids = sp.stop_token_ids.value(); + stopping_criteria.stop_token_ids.insert(stop_token_ids.begin(), + stop_token_ids.end()); + } else { + // otherwise use default stop token id from model args + stopping_criteria.stop_token_ids = model_args_.stop_token_ids(); + } + + if (sp.stop.has_value()) { + for (const auto& s : sp.stop.value()) { + std::vector stop_tokens; + if (!tokenizers_[tid]->encode(s, &stop_tokens)) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Failed to 
encode stop sequence"); + LOG(ERROR) << "Failed to encode stop sequence: " << s; + return nullptr; + } + stopping_criteria.stop_sequences.push_back(std::move(stop_tokens)); + } + } + + // results cannot be streamed when best_of != n + if (best_of != sp.n) { + stream = false; + } + request->stream = stream; + request->priority = priority; + request->echo = sp.echo; + + // set callback for outputs + request->on_output = callback; + + // add one sequence, rest will be added by scheduler + request->add_sequence(); + return request; +} + +std::vector VLMHandler::encode(const std::string& text) { + std::vector tokens; + engine_->tokenizer()->encode(text, &tokens); + return tokens; +} + +std::string VLMHandler::decode(const std::vector& tokens, + bool skip_special_tokens) { + return engine_->tokenizer()->decode(tokens, skip_special_tokens); +} + +void VLMHandler::reset() { + stop(); + + // stop all handling threads + // push nullptr to the queue to signal threads to exit + for (size_t i = 0; i < handling_threads_.size(); ++i) { + queue_.push(nullptr); + } + // wait for all threads to finish + for (auto& thread : handling_threads_) { + thread.join(); + } + handling_threads_.clear(); + + // release all underlying resources + scheduler_.reset(); + engine_.reset(); + tokenizers_.clear(); + + // torch::cuda::empty_cache(); + c10::cuda::CUDACachingAllocator::emptyCache(); +} + +} // namespace llm diff --git a/src/handlers/vlm_handler.h b/src/handlers/vlm_handler.h new file mode 100644 index 00000000..82c9b533 --- /dev/null +++ b/src/handlers/vlm_handler.h @@ -0,0 +1,154 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "chat_template/chat_template.h" +#include "common/concurrent_queue.h" +#include "engine/engine.h" +#include "request/output.h" +#include "sampling_params.h" +#include "scheduler/continuous_scheduler.h" + +namespace llm { + +// callback function for output, return true to continue, false to stop/cancel +using OutputCallback = std::function; + +// NOLINTNEXTLINE +class VLMHandler { + public: + struct Options { + DEFINE_ARG(std::string, model_path); + + DEFINE_ARG(std::optional, devices); + + // the number of slots per block, default 16, value must be multiple of 16 + DEFINE_ARG(int32_t, block_size) = 16; + + // the maximum cache size in bytes, default 10GB + DEFINE_ARG(int64_t, max_cache_size) = static_cast(10) * 1024 * + 1024 * 1024; + + // maximum memory utilization allowed, default 0.9 + DEFINE_ARG(double, max_memory_utilization) = 0.9; + + // enable prefix cache + DEFINE_ARG(bool, enable_prefix_cache) = true; + + // enable cuda graph + DEFINE_ARG(bool, enable_cuda_graph) = true; + + // max sequence length used to capture cuda graphs + DEFINE_ARG(int64_t, cuda_graph_max_seq_len) = 2048; + + // batch sizes to capture cuda graphs + DEFINE_ARG(std::optional>, cuda_graph_batch_sizes); + + // the maximum number of tokens per batch + DEFINE_ARG(int32_t, max_tokens_per_batch) = 256; + + // the maximum number of sequences per batch + DEFINE_ARG(int32_t, max_seqs_per_batch) = 64; + + // the number of threads to use for handling requests + DEFINE_ARG(size_t, num_handling_threads) = 4; + + DEFINE_ARG(std::string, image_input_type) = "pixel_values"; + + DEFINE_ARG(int64_t, image_token_id) = 32000; + + DEFINE_ARG(std::string, image_input_shape) = "1,3,336,336"; + + DEFINE_ARG(int32_t, image_feature_size) = 576; + }; + + VLMHandler(const Options& options); + + ~VLMHandler(); + + // schedule a request, the engine will execute the request 
asynchronously + // and call the callback with output when the request is done + // the callback will be called multiple times if the request is a streaming + // request + std::future schedule_async(torch::Tensor image, + std::string prompt, + SamplingParams sp, + Priority priority, + bool stream, + OutputCallback callback); + + // start the handling loop + void start(); + + // stop the engine + void stop(); + + // run until complete, blocking call + void run_until_complete(); + + std::vector encode(const std::string& text); + + std::string decode(const std::vector& tokens, + bool skip_special_tokens); + + // release underlying resources + void reset(); + + const Options& options() const { return options_; } + + private: + using Task = folly::Function; + std::unique_ptr create_request(size_t tid, + torch::Tensor image, + std::string prompt, + const SamplingParams& sp, + Priority priority, + bool stream, + OutputCallback callback); + + std::future schedule(torch::Tensor image, + std::string prompt, + SamplingParams sp, + Priority priority, + bool stream, + OutputCallback callback); + + void handling_loop(size_t tid); + + const Options options_; + + std::unique_ptr engine_; + + std::unique_ptr scheduler_; + + // model args + ModelArgs model_args_; + + // thread pool for handling requests + std::vector handling_threads_; + + // queue for tasks + ConcurrentQueue queue_; + + // we don't know if tokenizer is thread safe, so we create one for each thread + // for now + std::vector> tokenizers_; + + // thread for moving forward the scheduler + std::thread loop_thread_; + + // flag to stop the loop + std::atomic_bool stoped_{false}; + + // flag to indicate if the handler is running + std::atomic_bool running_{false}; +}; + +} // namespace llm diff --git a/src/layers/activation.cpp b/src/layers/activation.cpp index e314c586..2067b66c 100644 --- a/src/layers/activation.cpp +++ b/src/layers/activation.cpp @@ -84,6 +84,10 @@ ActFunc Activation::get_act_func(const std::string& name, if (boost::iequals(name, "gelu")) { return gelu; } + // TODO: need to support quick_gelu + if (boost::iequals(name, "quick_gelu")) { + return gelu; + } if (boost::iequals(name, "gelu_fast")) { return device.is_cuda() && !FLAGS_disable_custom_kernels ? 
kernel::gelu_fast : gelu_fast; diff --git a/src/models/CMakeLists.txt b/src/models/CMakeLists.txt index 582ff68a..7f95c02a 100644 --- a/src/models/CMakeLists.txt +++ b/src/models/CMakeLists.txt @@ -9,9 +9,11 @@ cc_library( parameters.h model_registry.h causal_lm.h + causal_vlm.h SRCS model_registry.cpp causal_lm.cpp + causal_vlm.cpp DEPS :common :layers diff --git a/src/models/causal_vlm.cpp b/src/models/causal_vlm.cpp new file mode 100644 index 00000000..e53cc985 --- /dev/null +++ b/src/models/causal_vlm.cpp @@ -0,0 +1,26 @@ +#include "causal_vlm.h" + +#include +#include + +#include "model_args.h" +#include "models/model_registry.h" + +namespace llm { + +std::unique_ptr CausalVLM::create( + const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // get the factory function for the model type from model registry + auto factory = ModelRegistry::get_causalvlm_factory(args.model_type()); + if (factory) { + return factory(args, quant_args, parallel_args, options); + } + + LOG(ERROR) << "Unsupported model type: " << args.model_type(); + return nullptr; +} + +} // namespace llm diff --git a/src/models/causal_vlm.h b/src/models/causal_vlm.h new file mode 100644 index 00000000..12fec8e6 --- /dev/null +++ b/src/models/causal_vlm.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include + +#include + +#include "causal_lm.h" +#include "memory/kv_cache.h" +#include "model_args.h" +#include "model_loader/state_dict.h" +#include "model_parallel/parallel_args.h" +#include "parameters.h" +#include "quantization/quant_args.h" + +namespace llm { + +// An interface for causal language models that can hold different models. +class CausalVLM : public CausalLM { + public: + ~CausalVLM() override = default; + + virtual torch::Tensor vision_encode(torch::Tensor image, + torch::Tensor tokens) = 0; + + static std::unique_ptr create(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + +// an template class to hold different models without using virtual functions. 
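// [Editorial sketch, not part of this patch] Rough illustration of the
// intended call flow, assuming a factory has been registered with
// REGISTER_CAUSAL_VLM_MODEL (see model_registry.h later in this patch):
//
//   std::unique_ptr<CausalVLM> model =
//       CausalVLM::create(args, quant_args, parallel_args, options);
//   // build the merged text + image embedding for the prompt
//   torch::Tensor embedding = model->vision_encode(image, token_ids);
//   // the embedding is then passed to the regular forward() call through
//   // InputParameters::input_embedding
//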
+template +class CausalVLMImpl : public CausalVLM { + public: + CausalVLMImpl(Model model, const torch::TensorOptions& options) + : model_(std::move(model)), options_(options) {} + + torch::Tensor vision_encode(torch::Tensor image, + torch::Tensor tokens) override { + return model_->vision_encode(image, tokens); + } + + torch::Tensor forward(const torch::Tensor& tokens, // [num_tokens] + const torch::Tensor& positions, // [num_tokens] + std::vector& kv_caches, + const InputParameters& parameters) override { + return model_->forward(tokens, positions, kv_caches, parameters); + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) override { + return model_->logits(hidden_states, seleted_idxes); + } + + void load_state_dict(const StateDict& state_dict) override { + model_->load_state_dict(state_dict); + } + + void verify_loaded_weights() const override { + return model_->verify_loaded_weights(); + } + + torch::Device device() const override { return options_.device(); } + + const torch::TensorOptions& options() const override { return options_; } + + private: + // underlying model + Model model_; + + // tensor options + torch::TensorOptions options_; +}; + +} // namespace llm diff --git a/src/models/huggingface/llava.h b/src/models/huggingface/llava.h new file mode 100644 index 00000000..bf5c1d09 --- /dev/null +++ b/src/models/huggingface/llava.h @@ -0,0 +1,771 @@ +#pragma once + +#include +#include + +#include "chat_template/common_chat_template.h" +#include "layers/activation.h" +#include "layers/attention/attention.h" +#include "layers/attention/handler.h" +#include "layers/embedding.h" +#include "layers/linear.h" +#include "layers/normalization.h" +#include "layers/qkv_linear.h" +#include "memory/kv_cache.h" +#include "models/huggingface/llama.h" +#include "models/model_args.h" +#include "models/model_registry.h" +#include "models/parameters.h" + +// llava model: llava-hf/llava-1.5-7b-hf + +namespace llm::hf { +class CLIPVisionEmbeddingImpl : public torch::nn::Module { + public: + CLIPVisionEmbeddingImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + embed_dim_ = args.mm_hidden_size(); + class_embedding_ = + register_parameter("class_embedding", torch::randn({embed_dim_})); + patch_embedding_ = register_module( + "patch_embedding", + torch::nn::Conv2d(torch::nn::Conv2dOptions(args.mm_num_channels(), + embed_dim_, + args.mm_patch_size()) + .stride(args.mm_patch_size()) + .bias(false))); + auto num_patches = (args.mm_image_size() / args.mm_patch_size()) * + (args.mm_image_size() / args.mm_patch_size()); + auto num_positions = num_patches + 1; + position_embedding_ = register_parameter( + "position_embedding", torch::randn({num_positions, embed_dim_})); + position_ids_ = register_buffer( + "position_ids", + torch::arange(0, num_positions, torch::kLong).unsqueeze(0)); + } + + torch::Tensor forward(const torch::Tensor& pixel_values) { + int64_t batch_size = pixel_values.size(0); + auto patch_embeds = + patch_embedding_ + ->forward(pixel_values.to(patch_embedding_->weight.dtype())) + .flatten(2) + .transpose(1, 2); + + auto class_embeds = class_embedding_.expand({batch_size, 1, embed_dim_}); + auto embeddings = torch::cat({class_embeds, patch_embeds}, 1); + embeddings += position_embedding_.index({position_ids_}); + return embeddings; + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + const auto cls = 
state_dict.get_tensor("class_embedding"); + DCHECK_EQ(cls.sizes(), class_embedding_.sizes()); + class_embedding_.copy_(cls); + + const auto pos = state_dict.get_tensor("position_embedding.weight"); + DCHECK_EQ(pos.sizes(), position_embedding_.sizes()); + position_embedding_.copy_(pos); + + const auto weight = state_dict.get_tensor("patch_embedding.weights"); + DCHECK_EQ(patch_embedding_->weight.sizes(), weight.sizes()); + patch_embedding_->weight.copy_(weight); + } + + void verify_loaded_weights(const std::string& prefix) const { + // No need to verify, already checked in load_state_dict + } + + private: + int64_t embed_dim_; + + torch::Tensor class_embedding_; + torch::Tensor position_ids_; + torch::nn::Conv2d patch_embedding_{nullptr}; + torch::Tensor position_embedding_{nullptr}; +}; +TORCH_MODULE(CLIPVisionEmbedding); + +class CLIPMLPImpl : public torch::nn::Module { + public: + CLIPMLPImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // TODO: default activation is quick_gelu, need to support quick_gelu + // https://github.com/huggingface/transformers/.../configuration_clip.py + act_ = Activation::get_act_func(args.mm_hidden_act(), options.device()); + CHECK(act_ != nullptr); + + fc1_ = register_module("fc1", + ColumnParallelLinear(args.mm_hidden_size(), + args.mm_intermediate_size(), + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + fc2_ = register_module("fc2", + RowParallelLinear(args.mm_intermediate_size(), + args.mm_hidden_size(), + /*bias=*/true, + /*input_is_parallelized*/ true, + quant_args, + parallel_args, + options)); + } + + torch::Tensor forward(const torch::Tensor& hidden_states) { + return fc2_(act_(fc1_(hidden_states))); + } + + void load_state_dict(const StateDict& state_dict) { + fc1_->load_state_dict(state_dict.select("fc1.")); + fc2_->load_state_dict(state_dict.select("fc2.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + fc1_->verify_loaded_weights(prefix + "fc1."); + fc2_->verify_loaded_weights(prefix + "fc2."); + } + + private: + ActFunc act_{nullptr}; + ColumnParallelLinear fc1_{nullptr}; + RowParallelLinear fc2_{nullptr}; +}; +TORCH_MODULE(CLIPMLP); + +// TODO: Optimize CLIPAttention +class CLIPAttentionImpl : public torch::nn::Module { + public: + CLIPAttentionImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + CHECK(args.mm_hidden_size() % args.mm_num_attention_heads() == 0); + + head_dim_ = args.mm_head_dim(); + embed_dim_ = args.mm_hidden_size(); + const int32_t world_size = parallel_args.world_size(); + num_heads_ = args.mm_num_attention_heads(); + const int64_t n_local_heads = num_heads_ / world_size; + + qkv_sizes_ = {n_local_heads * args.mm_head_dim(), + n_local_heads * args.mm_head_dim(), + n_local_heads * args.mm_head_dim()}; + + scale_ = 1.0f / std::sqrt(static_cast(args.mm_head_dim())); + dropout_ = args.mm_dropout(); + + // register submodules + qkv_proj_ = register_module("qkv_proj", + QKVColumnParallelLinear(args.mm_hidden_size(), + num_heads_, + num_heads_, + head_dim_, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + + o_proj_ = register_module("o_proj", + RowParallelLinear(args.mm_hidden_size(), + args.mm_hidden_size(), + /*bias=*/false, + /*input_is_parallelized=*/true, + quant_args, + parallel_args, + options)); + } + + torch::Tensor forward(const torch::Tensor& 
hidden_states) { + auto qkv = + qkv_proj_(hidden_states).split(/*split_size=*/qkv_sizes_, /*dim=*/-1); + DCHECK_EQ(qkv.size(), 3); + + auto query_states = qkv[0] * scale_; + auto bsz = hidden_states.size(0); + auto tgz_len = hidden_states.size(1); + auto key_states = shape(qkv[1], -1, bsz); + auto value_states = shape(qkv[2], -1, bsz); + + auto proj_shape = std::vector{bsz * num_heads_, -1, head_dim_}; + query_states = shape(query_states, tgz_len, bsz).view(proj_shape); + key_states = key_states.view(proj_shape); + value_states = value_states.view(proj_shape); + + auto src_len = key_states.size(1); + auto attn_weights = torch::bmm(query_states, key_states.transpose(1, 2)); + DCHECK_EQ(attn_weights.sizes(), + torch::IntArrayRef({bsz * num_heads_, tgz_len, src_len})); + + attn_weights = torch::softmax(attn_weights, -1); + auto attn_probs = torch::dropout(attn_weights, dropout_, false); + auto attn_output = torch::bmm(attn_probs, value_states); + + DCHECK_EQ(attn_output.sizes(), + torch::IntArrayRef({bsz * num_heads_, tgz_len, head_dim_})); + attn_output = + attn_output + .view(torch::IntArrayRef({bsz, num_heads_, tgz_len, head_dim_})) + .transpose(1, 2) + .contiguous(); + attn_output = + attn_output.view(torch::IntArrayRef({bsz, tgz_len, embed_dim_})); + + return o_proj_(attn_output); + } + + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + qkv_proj_->load_state_dict( + state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); + o_proj_->load_state_dict(state_dict.select("out_proj.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); + o_proj_->verify_loaded_weights(prefix + "out_proj."); + } + + private: + torch::Tensor shape(torch::Tensor tensor, int64_t seq_len, int64_t bsz) { + return tensor.view({bsz, seq_len, num_heads_, head_dim_}) + .transpose(1, 2) + .contiguous(); + } + + private: + int64_t embed_dim_; + int64_t num_heads_; + int64_t head_dim_; + float scale_; + float dropout_; + std::vector qkv_sizes_; + QKVColumnParallelLinear qkv_proj_{nullptr}; + RowParallelLinear o_proj_{nullptr}; +}; +TORCH_MODULE(CLIPAttention); + +class CLIPEncoderLayerImpl : public torch::nn::Module { + public: + CLIPEncoderLayerImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + self_attn_ = register_module( + "self_attn", CLIPAttention(args, quant_args, parallel_args, options)); + layer_norm1_ = register_module("layer_norm1", + LayerNorm(args.mm_hidden_size(), + args.mm_layer_norm_eps(), + /*bias=*/true, + options)); + layer_norm2_ = register_module("layer_norm2", + LayerNorm(args.mm_hidden_size(), + args.mm_layer_norm_eps(), + /*bias=*/true, + options)); + } + + // TODO: self_attn, attention_mask, causal_attention_mask + torch::Tensor forward(const torch::Tensor& hidden_states) { + auto residual = hidden_states; + auto h = self_attn_(layer_norm1_(hidden_states)) + residual; + residual = h; + h = mlp_(layer_norm2_(h)) + residual; + return h; + } + + void load_state_dict(const StateDict& state_dict) { + self_attn_->load_state_dict(state_dict.select("self_attn.")); + layer_norm1_->load_state_dict(state_dict.select("layer_norm1.")); + mlp_->load_state_dict(state_dict.select("mlp.")); + layer_norm2_->load_state_dict(state_dict.select("layer_norm2.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + self_attn_->verify_loaded_weights(prefix 
+ "self_attn."); + layer_norm1_->verify_loaded_weights(prefix + "layer_norm1."); + mlp_->verify_loaded_weights(prefix + "mlp."); + layer_norm2_->verify_loaded_weights(prefix + "layer_norm2."); + } + + private: + CLIPAttention self_attn_{nullptr}; + LayerNorm layer_norm1_{nullptr}; + CLIPMLP mlp_{nullptr}; + LayerNorm layer_norm2_{nullptr}; +}; +TORCH_MODULE(CLIPEncoderLayer); + +class CLIPEncoderImpl : public torch::nn::Module { + public: + CLIPEncoderImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + layers_.reserve(args.mm_num_hidden_layers()); + for (int32_t i = 0; i < args.mm_num_hidden_layers(); i++) { + auto block = CLIPEncoderLayer(args, quant_args, parallel_args, options); + layers_.push_back(block); + blocks_->push_back(block); + } + } + + // Output hidden states for all intermediate layers + std::vector forward(const torch::Tensor& embeddings) { + std::vector output_hidden_states; + auto hidden_states = embeddings; + for (size_t i = 0; i < layers_.size(); ++i) { + output_hidden_states.emplace_back(hidden_states); + auto& layer = layers_[i]; + hidden_states = layer(hidden_states); + } + output_hidden_states.emplace_back(hidden_states); + return output_hidden_states; + } + + void load_state_dict(const StateDict& state_dict) { + for (size_t i = 0; i < layers_.size(); ++i) { + layers_[i]->load_state_dict( + state_dict.select("layers." + std::to_string(i) + ".")); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + for (size_t i = 0; i < layers_.size(); ++i) { + layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + + "."); + } + } + + private: + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; +}; +TORCH_MODULE(CLIPEncoder); + +class CLIPVisionTransformerImpl : public torch::nn::Module { + public: + CLIPVisionTransformerImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + embeddings_ = register_module( + "embeddings", + CLIPVisionEmbedding(args, quant_args, parallel_args, options)); + pre_layernorm_ = register_module( + "pre_layernorm", + LayerNorm(args.mm_hidden_size(), args.layer_norm_eps(), true, options)); + encoder_ = register_module( + "encoder", CLIPEncoder(args, quant_args, parallel_args, options)); + post_layernorm_ = register_module( + "post_layernorm", + LayerNorm(args.mm_hidden_size(), args.layer_norm_eps(), true, options)); + } + + std::vector forward(const torch::Tensor& pixel_values) { + auto hidden_states = embeddings_(pixel_values); + hidden_states = pre_layernorm_(hidden_states); + + auto encoder_output = encoder_(hidden_states); + // when return_dict = False, skip pooled output step. 
+ // auto pooled_output = encoder_outputs.slice(1,0,1).squeeze(1); + // pooled_ouput = post_layernorm_(pooled_output);*/ + return encoder_output; + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + embeddings_->load_state_dict(state_dict.select("embeddings.")); + pre_layernorm_->load_state_dict(state_dict.select("pre_layrnorm.")); + encoder_->load_state_dict(state_dict.select("encoder.")); + post_layernorm_->load_state_dict(state_dict.select("post_layernorm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + embeddings_->verify_loaded_weights(prefix + "embeddings."); + pre_layernorm_->verify_loaded_weights(prefix + "pre_layrnorm."); + encoder_->verify_loaded_weights(prefix + "encoder."); + post_layernorm_->verify_loaded_weights(prefix + "post_layernorm."); + } + + private: + CLIPVisionEmbedding embeddings_{nullptr}; + LayerNorm pre_layernorm_{nullptr}; + CLIPEncoder encoder_{nullptr}; + LayerNorm post_layernorm_{nullptr}; +}; +TORCH_MODULE(CLIPVisionTransformer); + +// Follow implementation: https://github.com/huggingface/transformers +class CLIPVisionModelImpl : public torch::nn::Module { + public: + CLIPVisionModelImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + transformer_ = register_module( + "transformer", + CLIPVisionTransformer(args, quant_args, parallel_args, options)); + } + + // return hidden_state (TODO support return: output_attention, return_dict) + std::vector forward(const torch::Tensor& images) { + return transformer_(images); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + transformer_->load_state_dict(state_dict.select("vision_model.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + transformer_->verify_loaded_weights(prefix + "vision_model."); + } + + private: + CLIPVisionTransformer transformer_{nullptr}; +}; +TORCH_MODULE(CLIPVisionModel); + +// Not used but need to support +class LlavaProjectorLinearImpl : public torch::nn::Module { + public: + LlavaProjectorLinearImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + linear_ = register_module("linear", + ColumnParallelLinear(args.mm_hidden_size(), + args.hidden_size(), + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + } + + torch::Tensor forward(torch::Tensor image_features) { + return linear_(image_features); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + linear_->load_state_dict(state_dict.select("0.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + linear_->verify_loaded_weights(prefix + "0."); + } + + private: + ColumnParallelLinear linear_{nullptr}; +}; +TORCH_MODULE(LlavaProjectorLinear); + +class LlavaProjectorMLP2XImpl : public torch::nn::Module { + public: + LlavaProjectorMLP2XImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + linear_1_ = register_module("linear_0", + ColumnParallelLinear(args.mm_hidden_size(), + args.hidden_size(), + /*bias=*/true, + /*gather_output*/ false, + quant_args, + parallel_args, + options)); + linear_2_ = register_module("linear_2", + RowParallelLinear(args.hidden_size(), + args.hidden_size(), + /*bias=*/true, + /*gather_output*/ false, + 
quant_args, + parallel_args, + options)); + // projector's activation type is "gelu" + act_ = Activation::get_act_func(args.mm_projector_hidden_act(), + options.device()); + CHECK(act_ != nullptr); + } + + torch::Tensor forward(torch::Tensor image_features) { + return linear_2_(act_(linear_1_(image_features))); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + linear_1_->load_state_dict(state_dict.select("linear_1.")); + linear_2_->load_state_dict(state_dict.select("linear_2.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + linear_1_->verify_loaded_weights(prefix + "linear_1."); + linear_2_->verify_loaded_weights(prefix + "linear_2."); + } + + private: + ColumnParallelLinear linear_1_{nullptr}; + RowParallelLinear linear_2_{nullptr}; + ActFunc act_{nullptr}; +}; +TORCH_MODULE(LlavaProjectorMLP2X); + +class LlavaModelImpl : public torch::nn::Module { + public: + LlavaModelImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + vision_feature_layer_ = args.mm_vision_feature_layer(); + vision_model_ = register_module( + "vision_model", + CLIPVisionModel(args, quant_args, parallel_args, options)); + + projector_ = register_module( + "projector", + LlavaProjectorMLP2X(args, quant_args, parallel_args, options)); + + embed_tokens_ = register_module( + "embed_tokens", + ParallelEmbedding( + args.vocab_size(), args.hidden_size(), parallel_args, options)); + handler_ = AttentionHandler::create_handler_with_rope( + args, /*interleaved=*/false, options); + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(args.mm_num_hidden_layers()); + for (int32_t i = 0; i < args.mm_num_hidden_layers(); i++) { + auto block = LlamaDecoderLayer( + args, quant_args, parallel_args, options, handler_.get()); + layers_.push_back(block); + blocks_->push_back(block); + } + norm_ = register_module( + "norm", RMSNorm(args.hidden_size(), args.rms_norm_eps(), options)); + } + + torch::Tensor vision_encode(const torch::Tensor& image, + const torch::Tensor& tokens) { + auto text_embedding = embed_tokens_(tokens); + // TODO: filter last_hidden_states + const auto& image_hidden_states = vision_model_(image); + // Only use the last hidden states + const auto& last_hidden_states = image_hidden_states[vision_feature_layer_]; + const auto& vision_embedding = projector_(last_hidden_states); + merge_text_vision_embeddings(text_embedding, vision_embedding, tokens); + return text_embedding; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const InputParameters& input_params) { + auto input_embedding = input_params.input_embedding; + torch::Tensor hidden_states; + if (!input_embedding.defined()) { + hidden_states = embed_tokens_(tokens); + } else { + hidden_states = input_embedding; + } + // TODO: set working space for attention handler + for (size_t i = 0; i < layers_.size(); i++) { + auto& layer = layers_[i]; + hidden_states = + layer(hidden_states, positions, kv_caches[i], input_params); + } + return norm_(hidden_states); + } + + void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.select("language_model.model.embed_tokens.")); + projector_->load_state_dict(state_dict.select("multi_modal_projector.")); + vision_model_->load_state_dict(state_dict.select("vision_tower.")); + // call each layer's load_state_dict function + for (int i = 
0; i < layers_.size(); i++) { + layers_[i]->load_state_dict(state_dict.select( + "language_model.model.layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.select("language_model.model.norm.")); + } + + void verify_loaded_weights() const { + embed_tokens_->verify_loaded_weights("language_model.model.embed_tokens."); + projector_->verify_loaded_weights("multi_modal_projector."); + vision_model_->verify_loaded_weights("vision_tower."); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights("language_model.model.layers." + + std::to_string(i) + "."); + } + norm_->verify_loaded_weights("language_model.model.norm."); + } + + private: + void merge_text_vision_embeddings(torch::Tensor& text_embedding, + const torch::Tensor& vision_embedding, + const torch::Tensor& token_ids) { + // TODO: configure image_token_ids + constexpr int32_t image_token_id = 512; + + auto mask = (token_ids == image_token_id); + text_embedding.index_put_( + {mask}, vision_embedding.view({-1, vision_embedding.size(-1)})); + } + + private: + CLIPVisionModel vision_model_{nullptr}; + int64_t vision_feature_layer_; + + LlavaProjectorMLP2X projector_{nullptr}; + + ParallelEmbedding embed_tokens_{nullptr}; + // attention handler + std::unique_ptr handler_{nullptr}; + + torch::nn::ModuleList blocks_{nullptr}; + // hold same data but different type as blocks_ to avoid type cast + std::vector layers_; + + RMSNorm norm_{nullptr}; +}; +TORCH_MODULE(LlavaModel); + +class LlavaForCausalVLMImpl : public torch::nn::Module { + public: + LlavaForCausalVLMImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + model_ = register_module( + "model", LlavaModel(args, quant_args, parallel_args, options)); + + lm_head_ = register_module("lm_head", + ColumnParallelLinear(args.hidden_size(), + args.vocab_size(), + /*bias=*/false, + /*gather_output*/ true, + parallel_args, + options)); + } + + torch::Tensor vision_encode(const torch::Tensor& image, + const torch::Tensor& tokens) { + return model_->vision_encode(image, tokens); + } + + // images is stored in input_params + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const InputParameters& input_params) { + return model_(tokens, positions, kv_caches, input_params); + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& selected_idxes) { + auto h = hidden_states; + if (selected_idxes.defined()) { + h = h.index_select(/*dim=*/0, selected_idxes); + } + return lm_head_(h); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + model_->load_state_dict(state_dict); + lm_head_->load_state_dict(state_dict.select("language_model.lm_head.")); + } + + void verify_loaded_weights() const { + model_->verify_loaded_weights(); + lm_head_->verify_loaded_weights("language_model.lm_head."); + } + + private: + LlavaModel model_{nullptr}; + ColumnParallelLinear lm_head_{nullptr}; +}; +TORCH_MODULE(LlavaForCausalVLM); + +REGISTER_CAUSAL_VLM_MODEL(llava, LlavaForCausalVLM); + +// REGISTER_DEFAULT_CHAT_TEMPLATE(llama, Llama2ChatTemplate); + +REGISTER_MODEL_ARGS(llava, [&] { + // vision config + LOAD_ARG_OR(model_type, "model_type", "llava"); + LOAD_ARG_OR(mm_dropout, "vision_config.dropout", 0.0f); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "quick_gelu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1024); + 
LOAD_ARG_OR(mm_image_size, "vision_config.image_size", 336); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4096); + LOAD_ARG_OR(mm_num_channels, "vision_config.num_channels", 3); + LOAD_ARG_OR(mm_initializer_range, "vision_config.initializer_range", 0.02f); + LOAD_ARG_OR(mm_layer_norm_eps, "vision_config.layer_norm_eps", 1e-05); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_attention_heads", 12); + LOAD_ARG_OR(mm_num_beam_groups, "vision_config.num_beam_groups", 1); + LOAD_ARG_OR(mm_num_beams, "vision_config.num_beams", 1); + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.num_hidden_layers", 12); + LOAD_ARG_OR(mm_num_return_sequences, "vision_config.num_return_sequences", 1); + LOAD_ARG_OR(mm_output_attentions, "vision_config.output_attentions", false); + LOAD_ARG_OR( + mm_output_hidden_states, "vision_config.output_hidden_states", false); + LOAD_ARG_OR(mm_output_scores, "vision_config.output_scores", false); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 32); + LOAD_ARG_OR(mm_projection_dim, "vision_config.projection_dim", 512); + LOAD_ARG_OR( + mm_remove_invalid_values, "vision_config.remove_invalid_values", false); + LOAD_ARG_OR(mm_repetition_penalty, "vision_config.repetition_penalty", 1.0f); + LOAD_ARG_OR(mm_return_dict, "vision_config.return_dict", true); + LOAD_ARG_OR(mm_return_dict_in_generate, + "vision_config.return_dict_in_generate", + false); + LOAD_ARG_OR(mm_temperature, "vision_config.temperature", 1.0f); + LOAD_ARG_OR( + mm_tie_encoder_decoder, "vision_config.tie_encoder_decoder", false); + LOAD_ARG_OR( + mm_tie_word_embeddings, "vision_config.tie_word_embeddings", true); + LOAD_ARG_OR(mm_top_k, "vision_config.top_k", 50); + LOAD_ARG_OR(mm_top_p, "vision_config.top_p", 1.0f); + LOAD_ARG_OR(mm_torchscript, "vision_config.torchscript", false); + LOAD_ARG_OR(mm_use_bfloat16, "vision_config.user_bfloat16", false); + LOAD_ARG_OR_FUNC(mm_head_dim, "mm_head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + LOAD_ARG_OR(mm_vocab_size, "vision_config.vocab_size", 32000); + // projector config + LOAD_ARG_OR(mm_projector_type, "mm_projector_type", "mlp2x_gelu"); + LOAD_ARG_OR(mm_projector_hidden_act, "projector_hidden_act", "gelu"); + LOAD_ARG_OR(mm_projector_n_layers, "mm_projector_n_layers", 2); + LOAD_ARG_OR(mm_vision_feature_layer, "vision_feature_layer", -2); + LOAD_ARG_OR(mm_vision_feature_select_strategy, + "vision_feature_select_strategy", + "default"); + // text config + LOAD_ARG_OR(hidden_size, "hidden_size", 4096); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 32); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 11008); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 4096); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 1); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 2); + LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f); + LOAD_ARG_OR(rope_scaling, "rope_scaling", 1.0f); + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 32064); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 32); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); +}); +} // namespace llm::hf diff --git a/src/models/model_args.h b/src/models/model_args.h index ddbd45e5..f8ed7833 100644 --- a/src/models/model_args.h +++ b/src/models/model_args.h @@ -103,6 +103,114 @@ struct 
ModelArgs { // Stop token ids for decoding. DEFINE_ARG(std::unordered_set, stop_token_ids); + + // Vision model's dropout + DEFINE_ARG(float, mm_dropout) = 0.0f; + + // Vision model's hidden_act + DEFINE_ARG(std::string, mm_hidden_act); + + // Vision model's mm_hidden_size + DEFINE_ARG(int64_t, mm_hidden_size) = 0; + + // Vision model's mm_image_size + DEFINE_ARG(int64_t, mm_image_size) = 0; + + // Vision model's mm_intermediate_size + DEFINE_ARG(int64_t, mm_intermediate_size) = 0; + + // Vision model's mm_num_channels + DEFINE_ARG(int64_t, mm_num_channels) = 0; + + // Vision model's mm_initializer_range + DEFINE_ARG(float, mm_initializer_range) = 0.0f; + + // Vision model's mm_layer_norm_eps + DEFINE_ARG(float, mm_layer_norm_eps) = 0; + + // Vision model's mm_num_attention_heads + DEFINE_ARG(int64_t, mm_num_attention_heads) = 0; + + // Vision model's mm_num_beam_groups + DEFINE_ARG(int64_t, mm_num_beam_groups) = 0; + + // Vision model's mm_num_beams + DEFINE_ARG(int64_t, mm_num_beams) = 0; + + // Vision model's mm_num_hidden_layers + DEFINE_ARG(int64_t, mm_num_hidden_layers) = 0; + + // Vision model's mm_num_return_sequences + DEFINE_ARG(int64_t, mm_num_return_sequences) = 0; + + // Vision model's mm_output_attentions + DEFINE_ARG(bool, mm_output_attentions) = false; + + // Vision model's mm_output_hidden_states + DEFINE_ARG(bool, mm_output_hidden_states) = false; + + // Vision model's mm_output_scores + DEFINE_ARG(bool, mm_output_scores) = false; + + // Vision model's mm_patch_size + DEFINE_ARG(int64_t, mm_patch_size) = 0; + + // Vision model's mm_projection_dim + DEFINE_ARG(int64_t, mm_projection_dim) = 0; + + // Vision model's mm_remove_invalid_values + DEFINE_ARG(bool, mm_remove_invalid_values) = false; + + // Vision model's mm_repetition_penalty + DEFINE_ARG(float, mm_repetition_penalty) = 0.0f; + + // Vision model's mm_return_dict + DEFINE_ARG(bool, mm_return_dict) = false; + + // Vision model's mm_return_dict_in_generate + DEFINE_ARG(bool, mm_return_dict_in_generate) = false; + + // Vision model's mm_temperature + DEFINE_ARG(float, mm_temperature) = 0.0f; + + // Vision model's mm_tie_encoder_decoder + DEFINE_ARG(bool, mm_tie_encoder_decoder) = false; + + // Vision model's mm_tie_word_embeddings + DEFINE_ARG(bool, mm_tie_word_embeddings) = false; + + // Vision model's mm_top_k + DEFINE_ARG(int64_t, mm_top_k) = 0; + + // Vision model's mm_top_p + DEFINE_ARG(float, mm_top_p) = 0.0f; + + // Vision model's mm_torchscript + DEFINE_ARG(bool, mm_torchscript) = false; + + // Vision model's mm_use_bfloat16 + DEFINE_ARG(bool, mm_use_bfloat16) = false; + + // Vision model's mm_head_dim + DEFINE_ARG(int64_t, mm_head_dim) = 0; + + // Vision model's mm_vocab_size + DEFINE_ARG(int64_t, mm_vocab_size) = 0; + + // VLM model projector's mm_projector_type + DEFINE_ARG(std::string, mm_projector_type); + + // VLM model projector's mm_projector_hidden_act + DEFINE_ARG(std::string, mm_projector_hidden_act); + + // VLM model projector's mm_projector_n_layers + DEFINE_ARG(int64_t, mm_projector_n_layers) = 0; + + // VLM model projector's mm_vision_feature_layer + DEFINE_ARG(int64_t, mm_vision_feature_layer) = 0; + + // VLM model projector's mm_vision_feature_select_strategy + DEFINE_ARG(std::string, mm_vision_feature_select_strategy); }; inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) { diff --git a/src/models/model_registry.cpp b/src/models/model_registry.cpp index 5f0b10e0..5fe88d3f 100644 --- a/src/models/model_registry.cpp +++ b/src/models/model_registry.cpp @@ -13,6 +13,7 @@ 
#include "huggingface/gpt_neox.h" // IWYU pragma: keep #include "huggingface/internlm.h" // IWYU pragma: keep #include "huggingface/llama.h" // IWYU pragma: keep +#include "huggingface/llava.h" // IWYU pragma: keep #include "huggingface/mistral.h" // IWYU pragma: keep #include "huggingface/mpt.h" // IWYU pragma: keep #include "huggingface/phi.h" // IWYU pragma: keep @@ -36,6 +37,16 @@ void ModelRegistry::register_causallm_factory(const std::string& name, } } +void ModelRegistry::register_causalvlm_factory(const std::string& name, + CausalVLMFactory factory) { + ModelRegistry* instance = get_instance(); + if (instance->model_registry_[name].causal_vlm_factory != nullptr) { + LOG(WARNING) << "causal vlm factory for " << name << "already registered."; + } else { + instance->model_registry_[name].causal_vlm_factory = factory; + } +} + void ModelRegistry::register_model_args_loader(const std::string& name, ModelArgsLoader loader) { ModelRegistry* instance = get_instance(); @@ -84,6 +95,11 @@ CausalLMFactory ModelRegistry::get_causallm_factory(const std::string& name) { return instance->model_registry_[name].causal_lm_factory; } +CausalVLMFactory ModelRegistry::get_causalvlm_factory(const std::string& name) { + ModelRegistry* instance = get_instance(); + return instance->model_registry_[name].causal_vlm_factory; +} + ModelArgsLoader ModelRegistry::get_model_args_loader(const std::string& name) { ModelRegistry* instance = get_instance(); return instance->model_registry_[name].model_args_loader; diff --git a/src/models/model_registry.h b/src/models/model_registry.h index b472fd50..fb18eb3a 100644 --- a/src/models/model_registry.h +++ b/src/models/model_registry.h @@ -5,6 +5,7 @@ #include #include "causal_lm.h" +#include "causal_vlm.h" #include "chat_template/chat_template.h" #include "common/json_reader.h" #include "common/type_traits.h" // IWYU pragma: keep @@ -21,6 +22,13 @@ using CausalLMFactory = std::function( const ParallelArgs& parallel_args, const torch::TensorOptions& options)>; +using CausalVLMFactory = std::function( + const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options)>; + +using ChatTemplateFactory = std::function()>; using ChatTemplateFactory = std::function()>; using ModelArgsLoader = @@ -35,6 +43,7 @@ using TokenizerArgsLoader = // TODO: add default args loader. 
struct ModelMeta { CausalLMFactory causal_lm_factory; + CausalVLMFactory causal_vlm_factory; ChatTemplateFactory chat_template_factory; ModelArgsLoader model_args_loader; QuantArgsLoader quant_args_loader; @@ -50,6 +59,9 @@ class ModelRegistry { static void register_causallm_factory(const std::string& name, CausalLMFactory factory); + static void register_causalvlm_factory(const std::string& name, + CausalVLMFactory factory); + static void register_model_args_loader(const std::string& name, ModelArgsLoader loader); @@ -65,6 +77,8 @@ class ModelRegistry { static CausalLMFactory get_causallm_factory(const std::string& name); + static CausalVLMFactory get_causalvlm_factory(const std::string& name); + static ModelArgsLoader get_model_args_loader(const std::string& name); static QuantArgsLoader get_quant_args_loader(const std::string& name); @@ -98,6 +112,25 @@ class ModelRegistry { #define REGISTER_CAUSAL_MODEL(ModelType, ModelClass) \ REGISTER_CAUSAL_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) +#define REGISTER_CAUSAL_VLM_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) \ + const bool VarName##_registered = []() { \ + ModelRegistry::register_causalvlm_factory( \ + #ModelType, \ + [](const ModelArgs& args, \ + const QuantArgs& quant_args, \ + const ParallelArgs& parallel_args, \ + const torch::TensorOptions& options) { \ + ModelClass model(args, quant_args, parallel_args, options); \ + model->eval(); \ + return std::make_unique>( \ + std::move(model), options); \ + }); \ + return true; \ + }() + +#define REGISTER_CAUSAL_VLM_MODEL(ModelType, ModelClass) \ + REGISTER_CAUSAL_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) + #define REGISTER_DEFAULT_CHAT_TEMPLATE_WITH_VARNAME( \ VarName, ModelType, ChatTemplateClass) \ const bool VarName##_chat_template_registered = []() { \ diff --git a/src/models/parameters.h b/src/models/parameters.h index 561d60ef..35453f3e 100644 --- a/src/models/parameters.h +++ b/src/models/parameters.h @@ -23,6 +23,7 @@ struct InputParameters { params.new_cache_slots = safe_to(new_cache_slots, device); params.block_tables = safe_to(block_tables, device); + params.input_embedding = safe_to(input_embedding, device); return params; } @@ -54,6 +55,9 @@ struct InputParameters { // used in attention kernel to fetch cached key-value. 
// IntTensor: [n_seq, max_n_blocks] torch::Tensor block_tables; + + // input embedding = text_embedding + image_embedding + torch::Tensor input_embedding; }; } // namespace llm diff --git a/src/request/request.cpp b/src/request/request.cpp index 2dbe451e..4e2a3ac5 100644 --- a/src/request/request.cpp +++ b/src/request/request.cpp @@ -29,6 +29,24 @@ Request::Request(std::string prompt, CHECK_GE(best_of, n); } +Request::Request(std::string prompt, + std::vector prompt_tokens, + torch::Tensor input_embedding, + size_t seq_capacity, + size_t n, + size_t best_of, + bool logprobs) + : prompt(std::move(prompt)), + prompt_tokens(std::move(prompt_tokens)), + input_embedding(input_embedding), + seq_capacity(seq_capacity), + n(n), + best_of(best_of), + logprobs(logprobs), + created_time(absl::Now()) { + CHECK_GE(best_of, n); +} + void Request::add_sequence() { Sequence::Options options; options.sampling_param = this->sampling_param; @@ -40,6 +58,7 @@ void Request::add_sequence() { sequences.emplace_back(index, this->prompt, this->prompt_tokens, + this->input_embedding, this->created_time, this->seq_capacity, options); diff --git a/src/request/request.h b/src/request/request.h index d1bef10d..5cd11deb 100644 --- a/src/request/request.h +++ b/src/request/request.h @@ -33,6 +33,14 @@ struct Request final { size_t best_of, bool logprobs); + Request(std::string prompt, + std::vector prompt_tokens, + torch::Tensor input_embedding, + size_t seq_capacity, + size_t n, + size_t best_of, + bool logprobs); + void add_sequence(); bool is_finished() const; @@ -70,6 +78,8 @@ struct Request final { // NOLINTNEXTLINE const std::vector prompt_tokens; + torch::Tensor input_embedding; + // the number of sequences to generate completions for the prompt. // NOLINTNEXTLINE const size_t n; diff --git a/src/request/sequence.cpp b/src/request/sequence.cpp index dfa662b8..2b4e0162 100644 --- a/src/request/sequence.cpp +++ b/src/request/sequence.cpp @@ -55,6 +55,40 @@ Sequence::Sequence(size_t index, } } +Sequence::Sequence(size_t index, + const std::string_view& prompt, + const std::vector& prompt_token_ids, + torch::Tensor input_embedding, + const absl::Time& created_time, + size_t capacity, + const Options& option) + : index_(index), + last_token_time_(created_time), + options_(option), + incremental_decoder_(prompt, + prompt_token_ids.size(), + option.echo, + option.skip_special_tokens), + num_kv_cache_tokens_(static_cast(EngineType::COUNT), 0) { + CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; + CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; + + num_prompt_tokens_ = prompt_token_ids.size(); + // allocate space for token ids, logprobs, top tokens and top logprobs + token_ids_.resize(capacity); + logprobs_.resize(capacity); + top_tokens_.resize(capacity); + top_logprobs_.resize(capacity); + + // add the prompt tokens + for (const auto token_id : prompt_token_ids) { + token_ids_[num_tokens_++] = token_id; + token_to_count_map_[token_id]++; + } + + input_embedding_ = input_embedding; +} + Sequence::Sequence(const std::string_view& prompt, const std::vector& prompt_token_ids, size_t capacity, diff --git a/src/request/sequence.h b/src/request/sequence.h index d3d9abb8..5376be09 100644 --- a/src/request/sequence.h +++ b/src/request/sequence.h @@ -66,6 +66,14 @@ class Sequence final { size_t capacity, const Options& option); + Sequence(size_t index, + const std::string_view& prompt, + const std::vector& prompt_token_ids, + torch::Tensor input_embedding, + const absl::Time& created_time, + 
size_t capacity, + const Options& option); + // simple constructor for testing Sequence(const std::string_view& prompt, const std::vector& prompt_token_ids, @@ -82,6 +90,9 @@ class Sequence final { // get token ids Slice token_ids() const { return {token_ids_, num_tokens_}; } + // get input embedding + torch::Tensor get_input_embedding() const { return input_embedding_; } + // get token ids to count map const std::unordered_map& token_to_count_map() const { return token_to_count_map_; @@ -252,6 +263,8 @@ class Sequence final { // token ids generated for the sequence std::vector token_ids_; + torch::Tensor input_embedding_; + // log probabilities of the sequence std::vector> logprobs_; From 30f2b1a8ac48c837314d52bebcb2d81d509118d3 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 1 Jul 2024 00:52:30 -0700 Subject: [PATCH 2/2] build: fix multiple definition issue (#256) - [ ] @liutongxuan will help refactor code to get rid of duplications in following diffs (cherry picked from commit d711c5553f7746ef7efbd8dac429c1e711250ee8) --- scalellm/CMakeLists.txt | 1 + scalellm/csrc/vlm_handler.cpp | 17 +--- src/engine/vlm_engine.cpp | 6 +- src/engine/vlm_worker.cpp | 28 +++---- src/handlers/vlm_handler.cpp | 154 +++++++++++++++++----------------- 5 files changed, 96 insertions(+), 110 deletions(-) diff --git a/scalellm/CMakeLists.txt b/scalellm/CMakeLists.txt index 33dafd84..4de7cec6 100644 --- a/scalellm/CMakeLists.txt +++ b/scalellm/CMakeLists.txt @@ -9,6 +9,7 @@ pybind_extension( csrc/sampling_params.cpp csrc/output.cpp csrc/llm_handler.cpp + csrc/vlm_handler.cpp csrc/module.cpp DEPS :llm_handler diff --git a/scalellm/csrc/vlm_handler.cpp b/scalellm/csrc/vlm_handler.cpp index fef28ff3..4602ced1 100644 --- a/scalellm/csrc/vlm_handler.cpp +++ b/scalellm/csrc/vlm_handler.cpp @@ -10,21 +10,6 @@ namespace py = pybind11; using namespace pybind11::literals; void init_vlm_handler(py::module_& m) { - py::enum_(m, "Priority") - .value("DEFAULT", Priority::NORMAL) - .value("LOW", Priority::LOW) - .value("NORMAL", Priority::NORMAL) - .value("HIGH", Priority::HIGH) - .export_values(); - - py::class_>(m, "Future") - .def("wait", - &std::future::wait, - py::call_guard()) - .def("get", - &std::future::get, - py::call_guard()); - auto vlm_handler = py::class_(m, "VLMHandler") .def(py::init(), py::arg("options")) @@ -91,7 +76,7 @@ void init_vlm_handler(py::module_& m) { "cuda_graph_batch_sizes={}, " "max_tokens_per_batch={}, max_seqs_per_batch={}, " "num_handling_threads={}, " - "image_input_type={}, image_token_id={}, + "image_input_type={}, image_token_id={}, " "image_input_shape={}, image_feature_size={})"_s.format( self.model_path_, self.devices_, diff --git a/src/engine/vlm_engine.cpp b/src/engine/vlm_engine.cpp index 9ea7116f..17f2a2cd 100644 --- a/src/engine/vlm_engine.cpp +++ b/src/engine/vlm_engine.cpp @@ -14,8 +14,8 @@ #include "models/model_args.h" #include "vlm_worker.h" -DEFINE_COUNTER(prepare_input_latency_seconds, - "Latency of preparing input in seconds"); +// DEFINE_COUNTER(prepare_input_latency_seconds, +// "Latency of preparing input in seconds"); namespace llm { namespace { @@ -270,7 +270,7 @@ ModelOutput VLMEngine::execute_model(Batch& batch) { Timer timer; auto model_inputs = batch.prepare_model_input(options_.num_decoding_tokens(), adjusted_batch_size); - COUNTER_ADD(prepare_input_latency_seconds, timer.elapsed_seconds()); + // COUNTER_ADD(prepare_input_latency_seconds, timer.elapsed_seconds()); if (!model_inputs.token_ids.defined()) { // empty input, just return diff --git 
a/src/engine/vlm_worker.cpp b/src/engine/vlm_worker.cpp index 42c1f737..7a6bf02a 100644 --- a/src/engine/vlm_worker.cpp +++ b/src/engine/vlm_worker.cpp @@ -24,17 +24,17 @@ #include "sampling/sampler.h" // latency metrics -DEFINE_COUNTER_FAMILY(execution_latency_seconds, - "Execution latency in seconds"); -DEFINE_COUNTER_INSTANCE(model_execution_latency_seconds, - execution_latency_seconds, - {{"stage", "model"}}); -DEFINE_COUNTER_INSTANCE(logits_processing_latency_seconds, - execution_latency_seconds, - {{"stage", "logits_processing"}}); -DEFINE_COUNTER_INSTANCE(sampling_latency_seconds, - execution_latency_seconds, - {{"stage", "sampling"}}); +// DEFINE_COUNTER_FAMILY(execution_latency_seconds, +// "Execution latency in seconds"); +// DEFINE_COUNTER_INSTANCE(model_execution_latency_seconds, +// execution_latency_seconds, +// {{"stage", "model"}}); +// DEFINE_COUNTER_INSTANCE(logits_processing_latency_seconds, +// execution_latency_seconds, +// {{"stage", "logits_processing"}}); +// DEFINE_COUNTER_INSTANCE(sampling_latency_seconds, +// execution_latency_seconds, +// {{"stage", "sampling"}}); namespace llm { @@ -149,7 +149,7 @@ std::optional VLMWorker::execute_model(const ModelInput& inputs) { } at::cuda::getCurrentCUDAStream().synchronize(); - COUNTER_ADD(model_execution_latency_seconds, timer.elapsed_seconds()); + // COUNTER_ADD(model_execution_latency_seconds, timer.elapsed_seconds()); if (!driver_) { return std::nullopt; @@ -166,7 +166,7 @@ std::optional VLMWorker::execute_model(const ModelInput& inputs) { sampling_params.unique_token_ids, sampling_params.unique_token_counts, sampling_params.unique_token_ids_lens); - COUNTER_ADD(logits_processing_latency_seconds, timer.elapsed_seconds()); + // COUNTER_ADD(logits_processing_latency_seconds, timer.elapsed_seconds()); // set logits to output output.logits = logits; @@ -179,7 +179,7 @@ std::optional VLMWorker::execute_model(const ModelInput& inputs) { auto sample_logits = logits.index_select(/*dim=*/0, sampling_params.sample_idxes); auto sample_output = sampler->forward(sample_logits); - COUNTER_ADD(sampling_latency_seconds, timer.elapsed_seconds()); + // COUNTER_ADD(sampling_latency_seconds, timer.elapsed_seconds()); // set sample output to output output.sample_output = sample_output; diff --git a/src/handlers/vlm_handler.cpp b/src/handlers/vlm_handler.cpp index aca66940..80def6a7 100644 --- a/src/handlers/vlm_handler.cpp +++ b/src/handlers/vlm_handler.cpp @@ -20,86 +20,86 @@ #include "request/request.h" #include "speculative/speculative_engine.h" -DEFINE_COUNTER_FAMILY(request_status_total, "Total number of request status"); -DEFINE_COUNTER_INSTANCE(request_ok, request_status_total, {{"code", "OK"}}); -DEFINE_COUNTER_INSTANCE(request_cancelled, - request_status_total, - {{"code", "CANCELLED"}}); -DEFINE_COUNTER_INSTANCE(request_unknown, - request_status_total, - {{"code", "UNKNOWN"}}); -DEFINE_COUNTER_INSTANCE(request_invalid_argument, - request_status_total, - {{"code", "INVALID_ARGUMENT"}}); -DEFINE_COUNTER_INSTANCE(request_deadline_exceeded, - request_status_total, - {{"code", "DEADLINE_EXCEEDED"}}); -DEFINE_COUNTER_INSTANCE(request_resource_exhausted, - request_status_total, - {{"code", "RESOURCE_EXHAUSTED"}}); -DEFINE_COUNTER_INSTANCE(request_unauthenticated, - request_status_total, - {{"code", "UNAUTHENTICATED"}}); -DEFINE_COUNTER_INSTANCE(request_unavailable, - request_status_total, - {{"code", "UNAVAILABLE"}}); -DEFINE_COUNTER_INSTANCE(request_unimplemented, - request_status_total, - {{"code", "UNIMPLEMENTED"}}); - 
-DEFINE_COUNTER_FAMILY(request_handling_latency_seconds, - "Request handling latency in seconds"); -DEFINE_COUNTER_INSTANCE(chat_handling_latency_seconds, - request_handling_latency_seconds, - {{"type", "chat"}}); -DEFINE_COUNTER_INSTANCE(completion_handling_latency_seconds, - request_handling_latency_seconds, - {{"type", "completion"}}); - -DEFINE_COUNTER(tokenization_latency_seconds, - "Prompt tokenization latency in seconds"); -DEFINE_COUNTER(chat_template_latency_seconds, - "Chat template latency in seconds"); +// DEFINE_COUNTER_FAMILY(request_status_total, "Total number of request +// status"); DEFINE_COUNTER_INSTANCE(request_ok, request_status_total, {{"code", +// "OK"}}); DEFINE_COUNTER_INSTANCE(request_cancelled, +// request_status_total, +// {{"code", "CANCELLED"}}); +// DEFINE_COUNTER_INSTANCE(request_unknown, +// request_status_total, +// {{"code", "UNKNOWN"}}); +// DEFINE_COUNTER_INSTANCE(request_invalid_argument, +// request_status_total, +// {{"code", "INVALID_ARGUMENT"}}); +// DEFINE_COUNTER_INSTANCE(request_deadline_exceeded, +// request_status_total, +// {{"code", "DEADLINE_EXCEEDED"}}); +// DEFINE_COUNTER_INSTANCE(request_resource_exhausted, +// request_status_total, +// {{"code", "RESOURCE_EXHAUSTED"}}); +// DEFINE_COUNTER_INSTANCE(request_unauthenticated, +// request_status_total, +// {{"code", "UNAUTHENTICATED"}}); +// DEFINE_COUNTER_INSTANCE(request_unavailable, +// request_status_total, +// {{"code", "UNAVAILABLE"}}); +// DEFINE_COUNTER_INSTANCE(request_unimplemented, +// request_status_total, +// {{"code", "UNIMPLEMENTED"}}); + +// DEFINE_COUNTER_FAMILY(request_handling_latency_seconds, +// "Request handling latency in seconds"); +// DEFINE_COUNTER_INSTANCE(chat_handling_latency_seconds, +// request_handling_latency_seconds, +// {{"type", "chat"}}); +// DEFINE_COUNTER_INSTANCE(completion_handling_latency_seconds, +// request_handling_latency_seconds, +// {{"type", "completion"}}); + +// DEFINE_COUNTER(tokenization_latency_seconds, +// "Prompt tokenization latency in seconds"); +// DEFINE_COUNTER(chat_template_latency_seconds, +// "Chat template latency in seconds"); namespace llm { namespace { #define CALLBACK_WITH_ERROR(CODE, MSG) callback(Status{CODE, MSG}); -void log_request_status(StatusCode code) { - switch (code) { - case StatusCode::OK: - COUNTER_INC(request_ok); - break; - case StatusCode::CANCELLED: - COUNTER_INC(request_cancelled); - break; - case StatusCode::UNKNOWN: - COUNTER_INC(request_unknown); - break; - case StatusCode::INVALID_ARGUMENT: - COUNTER_INC(request_invalid_argument); - break; - case StatusCode::DEADLINE_EXCEEDED: - COUNTER_INC(request_deadline_exceeded); - break; - case StatusCode::RESOURCE_EXHAUSTED: - COUNTER_INC(request_resource_exhausted); - break; - case StatusCode::UNAUTHENTICATED: - COUNTER_INC(request_unauthenticated); - break; - case StatusCode::UNAVAILABLE: - COUNTER_INC(request_unavailable); - break; - case StatusCode::UNIMPLEMENTED: - COUNTER_INC(request_unimplemented); - break; - default: - COUNTER_INC(request_unknown); - break; - } -} +// void log_request_status(StatusCode code) { +// switch (code) { +// case StatusCode::OK: +// COUNTER_INC(request_ok); +// break; +// case StatusCode::CANCELLED: +// COUNTER_INC(request_cancelled); +// break; +// case StatusCode::UNKNOWN: +// COUNTER_INC(request_unknown); +// break; +// case StatusCode::INVALID_ARGUMENT: +// COUNTER_INC(request_invalid_argument); +// break; +// case StatusCode::DEADLINE_EXCEEDED: +// COUNTER_INC(request_deadline_exceeded); +// break; +// case 
StatusCode::RESOURCE_EXHAUSTED: +// COUNTER_INC(request_resource_exhausted); +// break; +// case StatusCode::UNAUTHENTICATED: +// COUNTER_INC(request_unauthenticated); +// break; +// case StatusCode::UNAVAILABLE: +// COUNTER_INC(request_unavailable); +// break; +// case StatusCode::UNIMPLEMENTED: +// COUNTER_INC(request_unimplemented); +// break; +// default: +// COUNTER_INC(request_unknown); +// break; +// } +// } bool verify_params(const SamplingParams& sp, OutputCallback callback) { if (sp.n == 0) { @@ -220,7 +220,7 @@ std::future VLMHandler::schedule_async(torch::Tensor image, stream, [callback = std::move(callback)](const RequestOutput& output) { if (output.status.has_value()) { - log_request_status(output.status.value().code()); + // log_request_status(output.status.value().code()); } return callback(output); }); @@ -243,7 +243,7 @@ std::future VLMHandler::schedule(torch::Tensor image, priority, stream, callback = std::move(callback)](size_t tid) mutable { - AUTO_COUNTER(completion_handling_latency_seconds); + // AUTO_COUNTER(completion_handling_latency_seconds); // remove the pending request after scheduling SCOPE_GUARD([this] { scheduler_->dec_pending_requests(); }); @@ -343,7 +343,7 @@ std::unique_ptr VLMHandler::create_request(size_t tid, "Failed to encode prompt"); return nullptr; } - COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + // COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); // encode the image, encode & projector auto vision_engine = dynamic_cast(engine_.get());
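
Editorial note: the snippet below is a minimal, self-contained sketch (not part of this patch) of the text/vision embedding merge that vision_encode() ultimately performs, mirroring LlavaModelImpl::merge_text_vision_embeddings. The shapes and the image_token_id of 32000 are assumptions taken from the llava-1.5-7b defaults in VLMHandler::Options (hidden_size = 4096, image_feature_size = 576); the model-side merge currently hard-codes a placeholder id, see its TODO above.

// merge_sketch.cpp -- illustrative only; compile against libtorch.
#include <iostream>
#include <torch/torch.h>

torch::Tensor merge_text_vision_sketch(
    const torch::Tensor& text_embedding,    // [n_tokens, 4096]
    const torch::Tensor& vision_embedding,  // [1, 576, 4096], already projected
    const torch::Tensor& token_ids,         // [n_tokens]
    int64_t image_token_id) {
  auto merged = text_embedding.clone();
  // boolean mask over the prompt: true where an image placeholder token sits
  auto mask = (token_ids == image_token_id);
  // scatter the 576 projected patch embeddings into those positions
  merged.index_put_(
      {mask}, vision_embedding.reshape({-1, vision_embedding.size(-1)}));
  return merged;
}

int main() {
  const int64_t n_text = 32, n_image = 576, hidden = 4096;
  const int64_t image_token_id = 32000;
  // prompt layout as in the llava test: 576 image placeholders, then text
  auto token_ids =
      torch::cat({torch::full({n_image}, image_token_id, torch::kLong),
                  torch::randint(0, 32000, {n_text}, torch::kLong)});
  auto text_embedding = torch::randn({n_image + n_text, hidden});
  auto vision_embedding = torch::randn({1, n_image, hidden});
  auto merged = merge_text_vision_sketch(text_embedding, vision_embedding,
                                         token_ids, image_token_id);
  std::cout << merged.sizes() << std::endl;  // [608, 4096]
  return 0;
}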