Allow static cache to be larger than sequence length / batch size for encoder-decoder models #35444
Comments
cc cache masters @gante @zucchini-nlp @ArthurZucker
@cptspacemanspiff hey! If you are initializing the cache object outside of `generate`, this is correct intuition and has to be handled by the user if not calling `generate` (see https://huggingface.co/docs/transformers/en/llm_optims#static-kv-cache-and-torchcompile).
Hi! Thanks for the response (and for adding the static cache to the T5 model :) ). I guess what I am trying to say is that I don't think the current state is the desired behavior.

Take the encoder sequence length: currently, even in generate, if you call generate multiple times with different sequence lengths, the cross-attention static cache is not reused (it ends up building a new cache for each encoder sequence length, same as when you feed it multiple batch sizes). This is not really a problem for models like Whisper, where the encoder length is fixed, but for T5/BART/others one of the use cases would be machine translation, where you feed varying sequence lengths on the encoder side.

I started dealing with this as part of doing export with executorch, where all memory must be preallocated. So I end up having to choose an encoder sequence length, and that is the only sequence length I can use on device, which is less than ideal. There is a similar story with the batch size. If the encoder-decoder model logic is changed to not assume the full available kv-cache length / batch size is used, then this issue is resolved and I can choose my encoder sequence length / batch size at runtime (as long as I stay below the max). A minimal sketch of the setup I have in mind is below.
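For concreteness, a sketch of the preallocation I am after (the StaticCache / EncoderDecoderCache argument names follow the current API but may differ slightly between transformers versions, and the sizes are just illustrative):

```python
import torch
from transformers import EncoderDecoderCache, StaticCache, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Upper bounds chosen once, e.g. for whatever fits on the target device.
max_batch_size = 4
max_encoder_len = 128  # sizes the cross-attention cache
max_decoder_len = 64   # sizes the self-attention cache

self_attention_cache = StaticCache(
    config=model.config, max_batch_size=max_batch_size,
    max_cache_len=max_decoder_len, device="cpu", dtype=torch.float32,
)
cross_attention_cache = StaticCache(
    config=model.config, max_batch_size=max_batch_size,
    max_cache_len=max_encoder_len, device="cpu", dtype=torch.float32,
)
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

# What I would like: any generate() call with batch size <= 4 and encoder
# sequence length <= 128 reuses this preallocated cache instead of building
# a new one.
```

On device, those two max values are the only buffers that ever exist, so every smaller batch size / sequence length has to run inside them.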
Ah, I see what you mean. I agree that making the cross-attention cache reusable across encoder sequence lengths would be an improvement. Regarding the batch size argument, that is something that concerns all models, not only encoder-decoder ones, which means we would need to change all decoder-only models in transformers. Personally, I believe it is a good feature to have as long as it doesn't break the main purpose of static cache (compatibility with torch.compile). Let's see what @gante has to say on that, as a generation code master :)
Thanks! As an aside, I think (maybe, possibly) the batch size change can be implemented entirely within the static cache in cache_utils. As long as values are pulled from the cache via `.update()`, where the model passes in the current cache position and the associated key/value states, the batch size of the current run is available as the first dimension of those states, and the cache object can simply return only the first n rows of the cache. As long as decoder models do not access the cache directly a la `past_key_value.key_cache[self.layer_idx]`, it should work. On that note, should direct access to the cache via `past_key_value.key_cache[self.layer_idx]` be a public API? A rough sketch of the slicing idea is below.

Also, with regard to torch compile, I got this all working (at least an initial pass) in #35445, though there the batch size handling is split between the model code and the static cache. That being said, while torch compile does succeed there, I did not check whether the compiled output adds a bunch of extra copy ops and such... the slicing stuff is a pain...
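Roughly what I mean (this is a hypothetical sketch, not the actual cache_utils implementation; it assumes the usual preallocated buffer layout of (max_batch_size, num_key_value_heads, max_cache_len, head_dim)):

```python
from transformers import StaticCache


class BatchFlexibleStaticCache(StaticCache):
    """Hypothetical sketch: return views of the preallocated buffers sliced
    down to the batch size that is actually passed in, instead of assuming
    the full max_batch_size is always used."""

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # The batch size of the current run is the leading dim of the new states.
        batch_size = key_states.shape[0]

        # index_copy_ on the sliced view writes through to the preallocated storage.
        k_out[:batch_size].index_copy_(2, cache_position, key_states)
        v_out[:batch_size].index_copy_(2, cache_position, value_states)

        # Return only the rows belonging to the current batch, so attention
        # never sees the unused preallocated slots.
        return k_out[:batch_size], v_out[:batch_size]
```

Since the returned tensors are just views into the preallocated buffers, no per-call allocation is needed; whether torch.compile handles the varying slice sizes without inserting extra copies is exactly the part I have not verified.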
Hey @cptspacemanspiff 👋 #37394 adds support for the caches to be used with smaller batch sizes than the one they were initialized with.
Feature request
In encoder-decoder models using an EncoderDecoderCache object with a static cache, the cache should be usable whenever the batch size and encoder sequence length are smaller than (or equal to) the sizes it was allocated with.
Motivation
I have been working on executorch export for encoder-decoder models. As part of that, I have been digging into the implementation of the encoder-decoder cache and the static cache.
How I would expect static caches to work: once you initialize the cache, generation should work as long as the batch size, encoder sequence length, and decoder sequence length are no larger than the corresponding cache dimensions.
Currently, however, the cross-attention static cache is rebuilt whenever the encoder sequence length or batch size differs from the values it was initialized with.
Your contribution
As I was digging through this, I updated the T5 attention and the static cache implementation in an attempt to handle both these cases.
#35445
That being said, I am just starting to learn transformers (both the hf library and in general), and have no real idea what I am doing.
Here is the code I have been using to reproduce the issue:
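Roughly, a minimal version looks like this (the model name and sizes are placeholders, and the exact cache constructor arguments may differ between transformers versions):

```python
import torch
from transformers import (
    AutoTokenizer,
    EncoderDecoderCache,
    StaticCache,
    T5ForConditionalGeneration,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Preallocate once: the self-attention cache is sized by the decoder length,
# the cross-attention cache by the encoder length.
self_attention_cache = StaticCache(
    config=model.config, max_batch_size=2, max_cache_len=64,
    device="cpu", dtype=torch.float32,
)
cross_attention_cache = StaticCache(
    config=model.config, max_batch_size=2, max_cache_len=128,
    device="cpu", dtype=torch.float32,
)
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

# Inputs with different encoder sequence lengths (both well under 128).
short_inputs = tokenizer(
    ["translate English to German: Hello there."], return_tensors="pt"
)
long_inputs = tokenizer(
    ["translate English to German: " + "This is a longer sentence. " * 4],
    return_tensors="pt",
)

# I would expect both calls to run inside the same preallocated cache;
# instead, the cross-attention cache only fits one encoder length.
for inputs in (short_inputs, long_inputs):
    cache.reset()
    out = model.generate(**inputs, past_key_values=cache, max_new_tokens=20)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))
```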