Add GraniteMoeHybrid support for 4.0 #37658

Merged

Conversation

Ssukriti
Contributor

What does this PR do?

The PR adds support for the upcoming Granite 4.0 models. In terms of model architecture, it is a hybrid model that combines a shared MLP layer with Bamba layers.

Fixes # (issue)
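
For context, once the checkpoints are public the new architecture loads like any other causal LM in transformers; a minimal usage sketch (the model id below is a placeholder, check the GraniteMoeHybrid docs for the actual preview checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm-granite/granite-4.0-tiny-preview"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, Granite!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))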

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1
Member

cc @ArthurZucker for text models!

class GraniteMoeHybridSdpaAttention(GraniteMoeSharedSdpaAttention):
    pass


GRANITEMOEHYBRID_ATTENTION_CLASSES = {
Contributor

Just as a heads up, I think it would be nice to follow the new attention interface (see #35235 for the original PR). Llama can also provide a good first pointer for this, e.g.

class LlamaAttention(nn.Module):

(unless I'm missing that this is a more special kind of attention here :D)
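
For readers unfamiliar with #35235: the idea is to drop per-backend subclasses like GraniteMoeHybridSdpaAttention in favour of a single attention class that looks up its backend function at runtime (in transformers the registry is ALL_ATTENTION_FUNCTIONS, keyed by config._attn_implementation). Below is a self-contained toy sketch of that dispatch pattern, with made-up names (ATTN_BACKENDS, SimpleAttention), not the library code:

import torch
import torch.nn.functional as F
from torch import nn


def eager_attention(q, k, v, mask=None, scaling=1.0):
    # Plain softmax attention, the default backend.
    attn = torch.matmul(q, k.transpose(-2, -1)) * scaling
    if mask is not None:
        attn = attn + mask
    attn = F.softmax(attn, dim=-1)
    return torch.matmul(attn, v), attn


def sdpa_attention(q, k, v, mask=None, scaling=1.0):
    # Fused SDPA backend; the scale kwarg needs PyTorch >= 2.1.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling)
    return out, None  # SDPA does not return attention weights


ATTN_BACKENDS = {"eager": eager_attention, "sdpa": sdpa_attention}


class SimpleAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, attn_implementation="sdpa"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.attn_implementation = attn_implementation
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, mask=None):
        b, s, _ = hidden_states.shape
        shape = (b, s, self.num_heads, self.head_dim)
        q = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(shape).transpose(1, 2)
        # The backend is looked up at runtime, so no Sdpa/FlashAttention subclasses are needed.
        attn_fn = ATTN_BACKENDS[self.attn_implementation]
        out, weights = attn_fn(q, k, v, mask=mask, scaling=self.scaling)
        out = out.transpose(1, 2).reshape(b, s, -1)
        return self.o_proj(out), weights


attn = SimpleAttention(hidden_size=64, num_heads=4)
out, _ = attn(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])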

Contributor

Thanks for the heads up @vasqu! We are still cleaning up this branch a bit, will take a look at this once the tests are in a better state 🙂

Contributor

Thanks for the pointer @vasqu! Refactored this PR to the new attention interface 😄

@vasqu
Contributor

vasqu commented Apr 22, 2025

ccing @molbap for mamba2/bamba (feels like I'm pinging you constantly 😆)

@alex-jw-brooks alex-jw-brooks force-pushed the granitemoe_hybrid_external_cleanup branch 5 times, most recently from ac9b018 to d751d26 Compare April 29, 2025 22:52
Ssukriti added 21 commits April 30, 2025 12:53
@alex-jw-brooks alex-jw-brooks force-pushed the granitemoe_hybrid_external_cleanup branch from a70d949 to 8274d2c Compare April 30, 2025 22:40
@Ssukriti Ssukriti marked this pull request as ready for review April 30, 2025 23:43
@alex-jw-brooks
Contributor

Thanks @ArthurZucker! It's ready for another look when you get the chance!

Collaborator

@ArthurZucker left a comment

Very nice use of modular, thanks a lot! 🤗


hidden_states = self.input_layernorm(hidden_states)
self_attn_weights = None
if self.layer_type == "mamba":
Collaborator

I am thinking we should remove the check on the layer type and instead rely on checking whether self.self_attn is not None?

Contributor

@alex-jw-brooks May 1, 2025

I agree, I also didn't like self.mamba being conditionally undefined. Updated this to define both in __init__ and simply run mamba if self.mamba is not None and attention otherwise 🙂
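
A simplified, self-contained sketch of the agreed layout, with toy nn.Linear stand-ins for the real mamba mixer and attention modules: both attributes are created in __init__ and the forward branches on self.mamba is not None rather than on a layer_type string:

import torch
from torch import nn


class HybridDecoderLayerSketch(nn.Module):
    def __init__(self, hidden_size, layer_type):
        super().__init__()
        # Both attributes always exist; exactly one is populated.
        self.self_attn = None
        self.mamba = None
        if layer_type == "mamba":
            self.mamba = nn.Linear(hidden_size, hidden_size)      # toy stand-in for the mamba mixer
        else:
            self.self_attn = nn.Linear(hidden_size, hidden_size)  # toy stand-in for attention

    def forward(self, hidden_states):
        self_attn_weights = None
        if self.mamba is not None:
            hidden_states = self.mamba(hidden_states)
        else:
            hidden_states = self.self_attn(hidden_states)
        return hidden_states, self_attn_weights


x = torch.randn(1, 4, 8)
for layer_type in ("mamba", "attention"):
    out, _ = HybridDecoderLayerSketch(8, layer_type)(x)
    print(layer_type, out.shape)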

Comment on lines 138 to 139
else:
    raise ValueError(f"Expected layer type in ['attention', 'mamba'], got {self.layer_type}")
Collaborator

still todo 😉

hidden_states = self.post_attention_layernorm(hidden_states)
moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)

if self.shared_mlp is None:
Collaborator

I don't know if you answered this already, but are there two different checkpoints being released, one with and one without this?

Contributor

The models that are about to come out do use it! I think there are likely experiments ongoing without it, but I'm not sure about concrete plans for when they'll be released since I'm not the one training the models 🙂

Collaborator

In that case let's remove what's uncertain! 🤗

Contributor

Sounds good! Removed the case with 0 experts, I'll open a follow-up PR if it ends up being used in a model to be released 😄
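
For context, a toy sketch of the block with the shared MLP always created (nn.Linear stand-ins for the routed experts and the shared MLP; the real layer also returns router logits and handles the residual path), assuming the shared-expert output is simply added to the routed-expert output as in GraniteMoeShared:

import torch
from torch import nn


class MoeWithSharedMlpSketch(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.block_sparse_moe = nn.Linear(hidden_size, hidden_size)  # toy stand-in for the routed experts
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)        # always created, no None branch

    def forward(self, hidden_states):
        hidden_states = self.post_attention_layernorm(hidden_states)
        moe_hidden_states = self.block_sparse_moe(hidden_states)
        # Shared-expert output is added to the routed-expert output.
        return moe_hidden_states + self.shared_mlp(hidden_states)


layer = MoeWithSharedMlpSketch(16)
print(layer(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])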

Comment on lines 381 to 403
if self.gradient_checkpointing and self.training:
    layer_outputs = self._gradient_checkpointing_func(
        decoder_layer.__call__,
        hidden_states,
        layer_mask,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        output_router_logits,
        position_embeddings,
    )
else:
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=layer_mask,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        output_router_logits=output_router_logits,
        position_embeddings=position_embeddings,
    )
Collaborator

let's use the new GradientCheckpointingLayer, wdyt?

Contributor

Definitely, that is a lot cleaner! I updated the models in the chain for modular to all use the gradient checkpointing layer (GraniteMoe/GraniteMoeShared/GraniteMoeHybrid)
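
For reference, a self-contained, simplified illustration of the GradientCheckpointingLayer idea (not the implementation that ships with transformers): the layer reroutes its own call through torch.utils.checkpoint when checkpointing is enabled during training, so the model forward can just call decoder_layer(...) without the if/else above:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayerSketch(nn.Module):
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Recompute activations in backward instead of storing them.
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)


class TinyDecoderLayer(GradientCheckpointingLayerSketch):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.linear(hidden_states)


layer = TinyDecoderLayer(8)
layer.gradient_checkpointing = True
layer.train()
out = layer(torch.randn(2, 8, requires_grad=True))
out.sum().backward()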

Comment on lines 428 to 429
if not return_dict:
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
Collaborator

we have a @can_return_tuple decorator for the forward

Suggested change
if not return_dict:
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

Contributor

Nice! Added
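
A toy sketch of what a can_return_tuple-style decorator does (a simplified stand-in, not the transformers implementation): the forward always builds the dict-like output, and the decorator turns it into a tuple of non-None values when return_dict=False, so the manual if not return_dict block can go:

import functools


def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)  # a dict-like model output
        if not return_dict:
            return tuple(v for v in output.values() if v is not None)
        return output
    return wrapper


class TinyModel:
    @can_return_tuple_sketch
    def forward(self, hidden_states):
        return {"last_hidden_state": hidden_states, "past_key_values": None, "attentions": None}


model = TinyModel()
print(model.forward(1.0))                     # full dict output
print(model.forward(1.0, return_dict=False))  # (1.0,)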

)


class GraniteMoeHybridModelTester:
Collaborator

can we try to inherit tests from the closest model (so mamba) in the same fashion as here

class Gemma2ModelTester(GemmaModelTester):

Contributor

Good idea! The closest model for these tests is Bamba. Consolidated a bit to use the Bamba tests, should be way easier to look at now 🤞
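
A hedged sketch of the resulting test layout (the relative import follows the convention inside the transformers tests package; the attribute names below are illustrative, not the exact ones in the merged test file):

from transformers import GraniteMoeHybridConfig, GraniteMoeHybridForCausalLM, GraniteMoeHybridModel

from ..bamba.test_modeling_bamba import BambaModelTester


class GraniteMoeHybridModelTester(BambaModelTester):
    # Only override what differs from Bamba; everything else is inherited.
    config_class = GraniteMoeHybridConfig
    model_class = GraniteMoeHybridModel
    for_causal_lm_class = GraniteMoeHybridForCausalLM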

Collaborator

perfect!

@berserkr left a comment

- std initialized twice: std = self.config.initializer_range

@alex-jw-brooks
Contributor

Thanks @berserkr! There were two because of modular expanding the superclass implementation that also set it. Updated to just pass the config value directly so it's less weird looking 🙂
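
Roughly, the nit and its fix (a simplified _init_weights; the real method covers more module types):

from torch import nn


def _init_weights_before(self, module):
    std = self.config.initializer_range
    std = self.config.initializer_range  # duplicated once modular expanded the parent body
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)


def _init_weights_after(self, module):
    if isinstance(module, nn.Linear):
        # use the config value directly instead of an intermediate std variable
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)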

@alex-jw-brooks
Contributor

Thank you very much for the fast review @ArthurZucker! I've made all the changes 🙂

Collaborator

@ArthurZucker left a comment

Marvelous! Merging once the build PR check passes (should be easy to fix!)

@alex-jw-brooks alex-jw-brooks force-pushed the granitemoe_hybrid_external_cleanup branch from 6b0ba0c to 1c0272a Compare May 2, 2025 14:04
@alex-jw-brooks
Copy link
Contributor

Thanks @ArthurZucker! Added the missing TOC entry and removed the currently unused shared condition for the MLP, should pass now! 🤞

@ArthurZucker ArthurZucker merged commit 471958b into huggingface:main May 6, 2025
18 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* initial config and MLA layer

* first pass at decoder

* completion of layers

* modeling class

* adding hybrid class to imports

* fix imports granitemoehybrid

* fix granitehybrid imports

* fix granitehybrid import

* fix generated modeling file

* add some comments

* minor fixes in layers

* add sharedMLP layer

* correct layer names

* fixes in mamba config

* fix mamba config

* change name of MLP layer

* fix seq mizer layers

* correct mamba config

* fixes in param names

* enable hybrid model

* update config

* fix config granite hybrid

* fix attention layer

* cleanup to re-use mamba code

* keep layer types

* attention bias cleanup

* update mamba layer name

* first pass at tests

* first pass at tests

* use granite attention

* fix: self attn weights

* pass at making pos_emb optional

* initialize self_attn only as needed

* overwrite forward to create HybridMambaCache

* Log invalid layer types

* Add attention outputs test

* Only emit attentions/logits if not None

* Fix config test hidden size divisibility

* mark granitmoehybrid as stateful

* Initialize mamba convolutional layers

* Formatting fixes

* config docstring, removed some unused attrs

* Fix missing arg in models test

* Fix create and check decoder model test

* support logits to keep in granitemoe

* regen to pass logits_to_keep

* Allow None or rope

* Fix gradient checkpointing

* Add granitemoehybrid as special cache for generate check

* Remove unused MLA refs

* Fix mamba layer mask

* Remove logits to keep from config

* Minor docstring nits

* Update licenses

* Enable cache by default

* map layer types to layer block type

* First pass at granite moe hybrid docs

* Ignore granite moe hybrid in valid checkpoint check

* Align attention interfaces

* regenerate modular granitemoeshared attention interface

* Align granite moe hybrid attn interface

* run formatting

* Handle mamba initialization

* avoid conditional attr defs

* Move hybrid layer validation to config

* Add placeholder integration tests

* Docs nits / Update model names

* Clean up forward conditions

* Use gradient checkpointing layer

* Remove some copied bamba tests + inherit

align test init

delete more tests

Use common layer init with bamba tests

finish test consolidation

* avoid redundant intermediate std var

* use @can_return_tuple

* Remove unused moe state

* make skipped test names consistent

* Fix docstring order

* Add missing toc

* Always create the shared mlp

* Fix name in docstring

* link preview model in docs

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Co-authored-by: Alex-Brooks <[email protected]>