@@ -232,8 +232,7 @@ def __init__(
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
 
-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type == "rope":
+        if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
@@ -244,6 +243,8 @@ def __init__(
                 and config.rope_scaling is not None else None,
                 is_neox_style=True,
             )
+        else:
+            self.rotary_emb = None
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -263,7 +264,7 @@ def forward(
         key = self.k_proj(hidden_states)[0]
         value = self.v_proj(hidden_states)[0]
 
-        if self.position_embedding_type == "rope":
+        if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
 
         hidden_states = self.attn(query, key, value)
@@ -349,11 +350,11 @@ def forward(
             hidden_states = hidden_states * self.embedding_multiplier
             residual = None
         else:
-            assert intermediate_tensors is not None
+            if intermediate_tensors is None:
+                raise RuntimeError('Intermediate tensors may not be None!')
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        residual = None
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -463,18 +464,19 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     embedding_padding_modules = ["lm_head"]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
         scheduler_config = vllm_config.scheduler_config
-        assert not cache_config.enable_prefix_caching, \
-            "GraniteMoeHybrid currently does not support prefix caching"
+        if cache_config.enable_prefix_caching:
+            raise RuntimeError(
+                "GraniteMoeHybrid currently does not support prefix caching")
 
         self.quant_config = vllm_config.quant_config
-
-        super().__init__()
         self.config = config
         self.scheduler_config = scheduler_config
         self.model = GraniteMoeHybridModel(vllm_config=vllm_config,