Add GraniteMoeHybrid support for 4.0 #37658
Conversation
cc @ArthurZucker for text models!
class GraniteMoeHybridSdpaAttention(GraniteMoeSharedSdpaAttention):
    pass

GRANITEMOEHYBRID_ATTENTION_CLASSES = {
Just as a heads up, I think it would be nice to use the new attention interface (see #35235 for the original PR). Llama can also provide a good first pointer for this, e.g.
class LlamaAttention(nn.Module):
(Unless I'm missing that this is a more special kind of attention here :D )
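For readers following along, here is a rough, hedged sketch of that interface pattern, loosely modeled on LlamaAttention; the class and attribute names below are illustrative, not the exact code that landed in this PR:

```python
# Hedged sketch of the "new attention interface" pattern; not the code merged in this PR.
from typing import Callable, Optional

import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import eager_attention_forward


class SketchAttention(nn.Module):
    """One attention class; the eager/sdpa/flash backend is picked at runtime from the config."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.shape
        hidden_shape = (bsz, q_len, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Key idea of the refactor: a single class, with the backend looked up from the config
        # instead of maintaining separate *SdpaAttention / *FlashAttention2 subclasses.
        attention_interface: Callable = eager_attention_forward
        if getattr(self.config, "_attn_implementation", "eager") != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
```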
Thanks for the heads up @vasqu! We are still cleaning up this branch a bit, will take a look at this once the tests are in a better state 🙂
Thanks for the pointer @vasqu! Refactored this PR to the new attention interface 😄
ccing @molbap for mamba2/bamba (feels like I'm pinging you constantly 😆)
Thanks @ArthurZucker! It's ready for another look when you get the chance!
Very nice use of modular, thanks a lot! 🤗
hidden_states = self.input_layernorm(hidden_states)
self_attn_weights = None
if self.layer_type == "mamba":
I am thinking let's remove the check on the layer type and rely instead on checking whether self.self_attn is not None?
I agree, I also didn't like self.mamba being conditionally undefined. Updated this to define both in __init__, and to run mamba if self.mamba is not None and attention otherwise 🙂
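A minimal sketch of the dispatch described above, with illustrative stand-ins (nn.Linear takes the place of the real mamba mixer and attention module, and config.layers_block_type is an assumed attribute), not the exact PR code:

```python
import torch
from torch import nn


class SketchHybridDecoderLayer(nn.Module):
    """Illustrative only: both sub-modules are always defined, one of them as None."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        layer_type = config.layers_block_type[layer_idx]  # assumed per-layer type list on the config
        # nn.Linear stands in for the real mamba mixer / attention module.
        self.mamba = nn.Linear(config.hidden_size, config.hidden_size) if layer_type == "mamba" else None
        self.self_attn = nn.Linear(config.hidden_size, config.hidden_size) if layer_type == "attention" else None

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, **kwargs):
        self_attn_weights = None
        # Branch on which sub-module exists rather than on a string layer type.
        if self.mamba is not None:
            hidden_states = self.mamba(hidden_states)
        else:
            hidden_states = self.self_attn(hidden_states)
        return hidden_states, self_attn_weights
```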
else:
    raise ValueError(f"Expected layer type in ['attention', 'mamba'], got {self.layer_type}")
still todo 😉
hidden_states = self.post_attention_layernorm(hidden_states)
moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)

if self.shared_mlp is None:
I don't know if you answered this already, but are there two different checkpoints being released, one with and one without this?
The models that are about to come out do use it! I think there are likely experiments ongoing without it, but I'm not sure about concrete plans for when they'll be released since I'm not the one training the models 🙂
In that case let's remove what's uncertain! 🤗
Sounds good! Removed the case with 0 experts, I'll open a follow-up PR if it ends up being used in a model to be released 😄
if self.gradient_checkpointing and self.training:
    layer_outputs = self._gradient_checkpointing_func(
        decoder_layer.__call__,
        hidden_states,
        layer_mask,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        output_router_logits,
        position_embeddings,
    )
else:
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=layer_mask,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        output_router_logits=output_router_logits,
        position_embeddings=position_embeddings,
    )
let's use the new GradientCheckpointingLayer, wdyt?
Definitely, that is a lot cleaner! I updated the models in the chain for modular to all use the gradient checkpointing layer (GraniteMoe/GraniteMoeShared/GraniteMoeHybrid)
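Roughly, the change looks like the hedged sketch below; the import path reflects my understanding of where GradientCheckpointingLayer lives in transformers, and the layer body is elided:

```python
# Hedged sketch; not the exact code that landed in this PR.
from transformers.modeling_layers import GradientCheckpointingLayer


class SketchDecoderLayer(GradientCheckpointingLayer):
    """Subclassing the helper removes the manual _gradient_checkpointing_func branch."""

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # ... real decoder layer body here ...
        return (hidden_states,)


# The caller then always invokes the layer directly; checkpointing is applied inside
# __call__ when gradient checkpointing is enabled during training:
#
#     layer_outputs = decoder_layer(
#         hidden_states,
#         attention_mask=layer_mask,
#         past_key_value=past_key_values,
#         use_cache=use_cache,
#         cache_position=cache_position,
#         output_router_logits=output_router_logits,
#         position_embeddings=position_embeddings,
#     )
```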
if not return_dict:
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
we have a @can_return_tuple decorator for the forward
Nice! Added
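For context, a hedged sketch of how the decorator is typically applied (heavily simplified; the real GraniteMoeHybridModel.forward takes many more arguments):

```python
import torch
from torch import nn

from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import can_return_tuple


class SketchModel(nn.Module):
    @can_return_tuple
    def forward(self, hidden_states: torch.Tensor, **kwargs) -> BaseModelOutputWithPast:
        # ... run embeddings / decoder layers here ...
        # The manual `if not return_dict: return tuple(...)` block goes away:
        # the decorator converts the ModelOutput to a tuple when the caller
        # asks for return_dict=False.
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
```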
class GraniteMoeHybridModelTester:
can we try to inherit tests from the closest model (so the mamba one), in the same fashion as here:
class Gemma2ModelTester(GemmaModelTester):
Good idea! The closest model for the tests is bamba. Consolidated a bit to use the Bamba tests, should be way easier to look at now 🤞
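A hedged sketch of the resulting test layout; the import path and the overridden attributes are assumptions based on the existing Bamba tests, mirroring how Gemma2ModelTester reuses GemmaModelTester:

```python
# Illustrative only; the real tester presumably overrides more than shown here.
from tests.models.bamba.test_modeling_bamba import BambaModelTester


class GraniteMoeHybridModelTester(BambaModelTester):
    # Inherit the Bamba tester wholesale and only override what differs,
    # e.g. the config/model classes and the hybrid (mamba + attention) layer layout.
    pass
```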
perfect!
transformers/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py, line 1199 in 8274d2c:
std = self.config.initializer_range

std is initialized twice - std = self.config.initializer_range appears more than once.
Thanks @berserkr! There were two because of
Thank you very much for the fast review @ArthurZucker! I've made all the changes 🙂
Marvelous! Merging once the build PR passes (should be easy to fix!)
Thanks @ArthurZucker! Added the missing TOC entry and removed the currently unused shared condition for the MLP, should pass now! 🤞
* initial config and MLA layer
* first pass at decoder
* completion of layers
* modeling class
* adding hybrid class to imports
* fix imports granitemoehybrid
* fix granitehybrid imports
* fix granitehybrid import
* fix generated modeling file
* add some comments
* minor fixes in layers
* add sharedMLP layer
* correct layer names
* fixes in mamba config
* fix mamba config
* change name of MLP layer
* fix seq mizer layers
* correct mamba config
* fixes in param names
* enable hybrid model
* update config
* fix config granite hybrid
* fix attention layer
* cleanup to re-use mamba code
* keep layer types
* attention bias cleanup
* update mamba layer name
* first pass at tests
* first pass at tests
* use granite attention
* fix: self attn weights
* pass at making pos_emb optional
* initialize self_attn only as needed
* overwrite forward to create HybridMambaCache
* Log invalid layer types
* Add attention outputs test
* Only emit attentions/logits if not None
* Fix config test hidden size divisibility
* mark granitmoehybrid as stateful
* Initialize mamba convolutional layers
* Formatting fixes
* config docstring, removed some unused attrs
* Fix missing arg in models test
* Fix create and check decoder model test
* support logits to keep in granitemoe
* regen to pass logits_to_keep
* Allow None or rope
* Fix gradient checkpointing
* Add granitemoehybrid as special cache for generate check
* Remove unused MLA refs
* Fix mamba layer mask
* Remove logits to keep from config
* Minor docstring nits
* Update licenses
* Enable cache by default
* map layer types to layer block type
* First pass at granite moe hybrid docs
* Ignore granite moe hybrid in valid checkpoint check
* Align attention interfaces
* regenerate modular granitemoeshared attention interface
* Align granite moe hybrid attn interface
* run formatting
* Handle mamba initialization
* avoid conditional attr defs
* Move hybrid layer validation to config
* Add placeholder integration tests
* Docs nits / Update model names
* Clean up forward conditions
* Use gradient checkpointing layer
* Remove some copied bamba tests + inherit, align test init, delete more tests, use common layer init with bamba tests, finish test consolidation
* avoid redundant intermediate std var
* use @can_return_tuple
* Remove unused moe state
* make skipped test names consistent
* Fix docstring order
* Add missing toc
* Always create the shared mlp
* Fix name in docstring
* link preview model in docs

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Co-authored-by: Alex-Brooks <[email protected]>
What does this PR do?
The PR adds support for the upcoming Granite 4.0 models. In terms of model architecture, it is a hybrid class with a shared MLP layer and Bamba layers.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.