
Commit b9cb7f2

benchislett and luyuzhe111 authored and committed
[V1][Spec Decode] EAGLE-3 Support (vllm-project#16937)
Signed-off-by: Bryan Lu <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Co-authored-by: Bryan Lu <[email protected]>
1 parent 9607ab7 commit b9cb7f2

File tree

12 files changed: +359 −35 lines changed


examples/offline_inference/eagle.py

Lines changed: 11 additions & 3 deletions
@@ -52,8 +52,8 @@ def main():
 
     args = parse_args()
 
-    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+    model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+    eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
     max_model_len = 2048
 
@@ -81,7 +81,7 @@ def main():
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config={
-            "method": "eagle",
+            "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,
@@ -95,6 +95,9 @@ def main():
     outputs = llm.generate(prompt_token_ids=prompt_ids,
                            sampling_params=sampling_params)
 
+    if not hasattr(outputs, "metrics") or outputs.metrics is None:
+        return
+
     # calculate the average number of accepted tokens per forward pass, +1 is
     # to account for the token from the target model that's always going to be
     # accepted
@@ -109,6 +112,11 @@ def main():
         {sum(acceptance_counts) / acceptance_counts[0]:.2f}")
     print("-" * 50)
 
+    # print acceptance at each token position
+    for i in range(len(acceptance_counts)):
+        print(f"acceptance at token {i}:"
+              f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
+
 
 if __name__ == "__main__":
     main()
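
For reference, the updated example boils down to the call below; a minimal sketch assuming the model names from the diff, with the prompt, sampling parameters, and num_speculative_tokens chosen for illustration rather than taken from the PR.

from vllm import LLM, SamplingParams

# Sketch of offline inference with an EAGLE-3 draft head; the draft and
# target checkpoints are the ones named in the diff above, everything else
# is an illustrative choice.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 2,
    },
    max_model_len=2048,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)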

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -393,6 +393,10 @@ def check_available_online(
                                         trust_remote_code=True,
                                         speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
                                         tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"),  # noqa: E501
+    "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                                              tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
 }
 
 _TRANSFORMERS_MODELS = {

tests/v1/e2e/test_spec_decode.py

Lines changed: 16 additions & 8 deletions
@@ -50,12 +50,15 @@ def sampling_config():
 
 
 @pytest.fixture
 def model_name():
-    return "meta-llama/Meta-Llama-3-8B-Instruct"
+    return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-@pytest.fixture
 def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
+    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+
+def eagle3_model_name():
+    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 
 def test_ngram_correctness(
@@ -102,12 +105,13 @@ def test_ngram_correctness(
     del spec_llm
 
 
+@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
-    eagle_model_name: str,
+    use_eagle3: bool,
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
@@ -116,18 +120,22 @@ def test_eagle_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=2048)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
 
+        spec_model_name = eagle3_model_name(
+        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
            model=model_name,
+            trust_remote_code=True,
            speculative_config={
-                "method": "eagle",
-                "model": eagle_model_name,
+                "method": "eagle3" if use_eagle3 else "eagle",
+                "model": spec_model_name,
                 "num_speculative_tokens": 3,
+                "max_model_len": 2048,
            },
-            max_model_len=1024,
+            max_model_len=2048,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
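
The parametrize decorator above expands test_eagle_correctness into two test ids, "eagle" and "eagle3", each paired with its own draft checkpoint; a small sketch of the resulting mapping (the dict below is illustrative, not part of the PR):

# Each pytest id maps to a speculative method and the draft checkpoint
# returned by eagle_model_name() / eagle3_model_name() above.
cases = {
    "eagle": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    "eagle3": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
}
for method, draft_model in cases.items():
    print(f"method={method!r}, draft model={draft_model!r}")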

vllm/config.py

Lines changed: 10 additions & 3 deletions
@@ -2339,9 +2339,10 @@ def __post_init__(self):
             )
 
             # Automatically detect the method
-            if self.method == 'eagle':
+            if self.method in ('eagle', 'eagle3'):
                 pass
-            elif "eagle-" in self.draft_model_config.model.lower():
+            elif "eagle-" in self.draft_model_config.model.lower() or \
+                    "eagle3-" in self.draft_model_config.model.lower():
                 self.method = "eagle"
             elif self.draft_model_config.hf_config.model_type == "medusa":
                 self.method = "medusa"
@@ -2352,7 +2353,7 @@ def __post_init__(self):
                 self.method = "draft_model"
 
             # Replace hf_config for EAGLE draft_model
-            if self.method == "eagle":
+            if self.method in ("eagle", "eagle3"):
                 if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                     raise ValueError(
                         "Chunked prefill and EAGLE are not compatible "
@@ -2549,6 +2550,12 @@ def _verify_args(self) -> None:
                 "speculative decoding is > 1, but got "
                 f"{self.disable_by_batch_size=}")
 
+        if self.method == "eagle3" and self.target_model_config and \
+            "llama" not in self.target_model_config.hf_text_config.model_type:
+            raise ValueError(
+                "Eagle3 is only supported for Llama models. "
+                f"Got {self.target_model_config.hf_text_config.model_type=}")
+
     @property
     def num_lookahead_slots(self) -> int:
         """The number of additional slots the scheduler should allocate per

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1459,7 +1459,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             if speculative_method:
                 if speculative_method in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-                elif speculative_method == "eagle":
+                elif speculative_method in ("eagle", "eagle3"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")

vllm/model_executor/models/llama.py

Lines changed: 17 additions & 1 deletion
@@ -330,6 +330,8 @@ def __init__(self,
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
@@ -355,7 +357,11 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in self.layers[self.start_layer:self.end_layer]:
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                self.layers[self.start_layer:self.end_layer]):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
@@ -365,6 +371,9 @@ def forward(
             })
 
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str,
@@ -517,6 +526,13 @@ def __init__(self,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",

vllm/model_executor/models/llama_eagle.py

Lines changed: 2 additions & 1 deletion
@@ -82,7 +82,8 @@ def forward(
             hidden_states,
             residual,
         )
-        return hidden_states + residual
+        hidden_states = hidden_states + residual
+        return hidden_states, hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
