
Commit 0dbb1b9

Added missing files

Signed-off-by: Thomas Ortner <[email protected]>
1 parent: 544a6aa

4 files changed, +13 -54 lines changed


tests/models/decoder_only/language/test_granitemoehybrid.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/models/language/generation/test_granitemoehybrid.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_model_equivalence_to_hf_greedy(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-):
+):
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

(The removed and added forms of line 26 appear to differ only in whitespace, which the rendered diff does not show.)

tests/models/language/generation/test_hybrid.py

Lines changed: 1 addition & 4 deletions
@@ -25,7 +25,7 @@
     "ai21labs/Jamba-tiny-dev",
     # NOTE: ibm-research/granite-4.0-tiny-test are skipped currently as
     # the HF model URLs not available yet
-    "ibm-research/granite-4.0-tiny-test",
+    # "ibm-research/granite-4.0-tiny-test",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's
     # not compatible with pip-compile.
@@ -49,9 +49,6 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model == "ibm-research/granite-4.0-tiny-test":
-        pytest.skip(reason="HF model URLs not available yet")
-
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
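
Commenting the entry out removes the model from pytest collection entirely, whereas the deleted in-test pytest.skip still collected the case and reported it as skipped. A third option that keeps the entry visible in test reports is pytest.param with a skip mark; a minimal sketch, assuming the same MODELS list and test shape (not part of this commit):

import pytest

MODELS = [
    "ai21labs/Jamba-tiny-dev",
    # Collected but skipped before running, and still shown as "skipped"
    # in test reports.
    pytest.param("ibm-research/granite-4.0-tiny-test",
                 marks=pytest.mark.skip(
                     reason="HF model URLs not available yet")),
]

@pytest.mark.parametrize("model", MODELS)
def test_models(model: str) -> None:
    ...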

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 11 additions & 9 deletions
@@ -232,8 +232,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj")

-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type == "rope":
+        if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
@@ -244,6 +243,8 @@
                 and config.rope_scaling is not None else None,
                 is_neox_style=True,
             )
+        else:
+            self.rotary_emb = None

         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -263,7 +264,7 @@ def forward(
         key = self.k_proj(hidden_states)[0]
         value = self.v_proj(hidden_states)[0]

-        if self.position_embedding_type == "rope":
+        if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)

         hidden_states = self.attn(query, key, value)
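
This refactor builds the optional rotary embedding once at construction and lets forward() gate on the attribute itself instead of re-reading a config string on every call, so there is a single source of truth for whether RoPE is enabled. A minimal sketch of the pattern (class names here are hypothetical, for illustration only):

import torch.nn as nn

class DummyRope(nn.Module):
    """Stand-in for the module returned by vLLM's get_rope()."""
    def forward(self, positions, q, k):
        return q, k  # a real implementation would rotate q and k here

class OptionalRopeAttention(nn.Module):
    def __init__(self, position_embedding_type: str):
        super().__init__()
        # Build the optional submodule once; None encodes "RoPE disabled".
        self.rotary_emb = (DummyRope()
                           if position_embedding_type == "rope" else None)

    def forward(self, positions, q, k):
        # Gate on the attribute, not on a second copy of the config value.
        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)
        return q, k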
@@ -349,11 +350,11 @@ def forward(
             hidden_states = hidden_states * self.embedding_multiplier
             residual = None
         else:
-            assert intermediate_tensors is not None
+            if intermediate_tensors is None:
+                raise RuntimeError('Intermediate tensors may not be None!')
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        residual = None
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
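
Two things happen in this hunk: the stray residual = None that clobbered the value just read from intermediate_tensors is dropped, and the assert becomes an explicit raise. The latter matters because CPython strips assert statements when run with python -O, so the guard would silently vanish in optimized runs. A minimal illustration of the difference:

def check_assert(tensors):
    # Compiled away entirely under "python -O"; tensors=None slips through.
    assert tensors is not None

def check_raise(tensors):
    # Survives any interpreter flags and always fires on bad input.
    if tensors is None:
        raise RuntimeError("Intermediate tensors may not be None!")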
@@ -463,18 +464,19 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
         scheduler_config = vllm_config.scheduler_config
-        assert not cache_config.enable_prefix_caching, \
-            "GraniteMoeHybrid currently does not support prefix caching"
+        if cache_config.enable_prefix_caching:
+            raise RuntimeError(
+                "GraniteMoeHybrid currently does not support prefix caching")

         self.quant_config = vllm_config.quant_config
-
-        super().__init__()
         self.config = config
         self.scheduler_config = scheduler_config
         self.model = GraniteMoeHybridModel(vllm_config=vllm_config,
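
Moving super().__init__() to the top of __init__ is the idiomatic PyTorch ordering: nn.Module.__init__ creates the internal registries (_parameters, _modules, _buffers) that nn.Module.__setattr__ relies on, so assigning a submodule before the call raises AttributeError. In the old code only plain attributes happened to be set before super().__init__(), which worked by accident; the new order is safe regardless of what gets assigned. A minimal reproduction (hypothetical classes, for illustration only):

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # AttributeError: cannot assign module before Module.__init__() call
        self.proj = nn.Linear(8, 8)
        super().__init__()

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()           # registries exist from here on
        self.proj = nn.Linear(8, 8)  # correctly registered as a submodule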
