[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (vllm-project#12304)

Signed-off-by: Shangming Cai <[email protected]>
ShangmingCai authored Feb 17, 2025
1 parent 2010f04 commit 46cdd59
Showing 8 changed files with 273 additions and 18 deletions.
16 changes: 7 additions & 9 deletions docs/source/features/spec_decode.md
@@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
tensor_parallel_size=4,
speculative_model="path/to/modified/eagle/model",
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
speculative_draft_tensor_parallel_size=1,
)

@@ -190,14 +190,12 @@ for output in outputs:

A few important things to consider when using the EAGLE based draft models:

1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
used directly with vLLM due to differences in the expected layer names and model definition.
To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
to convert them. Note that this script does not modify the model's weights.

In the above example, use the script to first convert
the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
and then use the converted checkpoint as the draft model in vLLM.
1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) can be
loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
If you are using a vLLM version from before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use this
[script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model
and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur with
the latest version of vLLM, please leave a comment or raise an issue.

2. The EAGLE based draft models need to be run without tensor parallelism
(i.e. speculative_draft_tensor_parallel_size is set to 1), although
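For reference, a minimal end-to-end sketch of the simplified flow described in this doc change, assuming a vLLM version that already includes PR 12304 (the prompt and printed output format are illustrative; the model names and keyword arguments come from the example earlier in this file):

from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# The EAGLE checkpoint from the yuhuili HF repo is used directly as the
# draft model; no conversion script is needed on a post-PR-12304 vLLM.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    speculative_draft_tensor_parallel_size=1,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Generated text: {output.outputs[0].text!r}")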
144 changes: 144 additions & 0 deletions tests/spec_decode/e2e/test_eagle_correctness.py
@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "meta-llama/Llama-2-7b-chat-hf",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):

run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):

run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "Qwen/Qwen2-7B-Instruct",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):

run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


if __name__ == "__main__":
import pytest
pytest.main([__file__])
40 changes: 39 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -13,15 +13,18 @@
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from vllm.worker.worker import Worker

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker
from .utils import (create_batch, create_sampler_output_list, create_worker,
mock_worker)


@pytest.mark.parametrize('k', [1, 2, 6])
@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1


def test_correctly_load_weight_for_eagle():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
"""
seed = 100
block_size = 32
num_gpu_blocks = 8096 // block_size
target_worker = create_worker(
Worker,
"JackFram/llama-68m",
block_size,
num_gpu_blocks,
seed,
)
draft_worker = create_worker(
MultiStepWorker,
"abhigoyal/vllm-eagle-llama-68m-random",
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)

spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False)
worker.proposer_worker.maybe_load_lm_head_weight(
target_worker.model_runner.model.lm_head.weight.data)
assert torch.allclose(
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
worker.scorer_worker.model_runner.model.lm_head.weight.data)
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -1833,6 +1833,15 @@ def maybe_create_spec_config(

draft_hf_config = draft_model_config.hf_config

# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
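In effect, the block above auto-detects unconverted EAGLE checkpoints by name and wraps their Hugging Face config. A condensed, hypothetical restatement of that rule (the helper name `_wrap_if_eagle` is illustrative; the import and the `EAGLEConfig(hf_config)` call mirror the diff):

from vllm.transformers_utils.configs.eagle import EAGLEConfig

def _wrap_if_eagle(model_name: str, hf_config):
    # Wrap the draft model's config in an EAGLEConfig when the model name
    # looks like an EAGLE checkpoint and the config is not already wrapped.
    if "eagle-" in model_name.lower() and not isinstance(hf_config, EAGLEConfig):
        return EAGLEConfig(hf_config)
    return hf_config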
24 changes: 18 additions & 6 deletions vllm/model_executor/models/eagle.py
@@ -7,6 +7,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +19,8 @@

from .utils import maybe_prefix

logger = init_logger(__name__)


class DummyInputLayerNorm(nn.Module):

@@ -190,8 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
@@ -200,12 +203,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
model_weights[f"model.{name}"] = loaded_weight

lm_head_weight = model_weights.pop("lm_head.weight")
if "lm_head.weight" in model_weights:
lm_head_weight = model_weights.pop("lm_head.weight")

if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:

if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]

lm_head_weight = lm_head_weight[self.token_map]
else:
# NOTE(Shangming): initialize the placeholder for lm_head weight.
lm_head_weight = torch.zeros(
self.lm_head.org_vocab_size,
self.lm_head.embedding_dim,
dtype=self.config.torch_dtype,
)

weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
12 changes: 12 additions & 0 deletions vllm/spec_decode/multi_step_worker.py
@@ -7,6 +7,7 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
@@ -386,3 +387,14 @@ def _raise_if_unsupported(
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")

def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
weight_loader = getattr(
self.worker.model_runner.model_runner.model.lm_head.weight,
"weight_loader", default_weight_loader)
weight_loader(
self.worker.model_runner.model_runner.model.lm_head.weight,
lm_head_weight)
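The lm_head placeholder initialized in eagle.py's load_weights (see the diff above) is filled through this hook. As exercised by the unit test added in tests/spec_decode/test_spec_decode_worker.py, the call is roughly the following sketch, which assumes `worker` and `target_worker` were built as in that test:

# Copy the scorer (target) model's lm_head weight into the EAGLE draft model,
# replacing the zero placeholder created when the checkpoint had no lm_head.
worker.proposer_worker.maybe_load_lm_head_weight(
    target_worker.model_runner.model.lm_head.weight.data)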
19 changes: 19 additions & 0 deletions vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -10,6 +10,7 @@
patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -173,3 +174,21 @@ def get_cache_block_size_bytes(self) -> int:
@property
def vocab_size(self) -> int:
return self._worker.vocab_size

def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
if self._is_dummy:
return

with self._patch_tensor_parallel_group():
weight_loader = getattr(
self._worker.worker.model_runner.model_runner.model.\
lm_head.weight,
"weight_loader",
default_weight_loader)
weight_loader(
self._worker.worker.model_runner.model_runner.model.\
lm_head.weight,
lm_head_weight)