
Allow static cache to be larger than sequence length / batch size for encoder-decoder models #35444

Open
cptspacemanspiff opened this issue Dec 29, 2024 · 6 comments
Labels
Feature request

Comments

@cptspacemanspiff

cptspacemanspiff commented Dec 29, 2024

Feature request

In encoder-decoder models that use an EncoderDecoderCache wrapping static caches:

  1. the cross-attention cache size must equal the encoder sequence length, and
  2. the batch size of both the self-attention and cross-attention caches must equal the batch size used for generation.

Motivation

I have been working on ExecuTorch export for encoder-decoder models. As part of that, I have been digging into the implementation of the encoder-decoder cache and the static cache.

How I would expect static caches to work: once the cache is initialized, generation should succeed as long as the batch size, encoder sequence length, and decoder sequence length are no larger than the values the cache was allocated with.

Currently however:

  1. The cross-attention cache must be exactly the same size as the encoder sequence length.
  2. The batch size that the cache is initialized with must be exactly the batch size that the cache is run with.

Your contribution

As I was digging through this, I updated the T5 attention and the static cache implementation in an attempt to handle both these cases.

#35445

That being said, I am just starting to learn transformers (both the Hugging Face library and the concepts in general), and have no real idea what I am doing.

Here is the code I have been using to reproduce the issue:

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from transformers.cache_utils import (
    StaticCache,
    EncoderDecoderCache,
)

model_name = "google-t5/t5-small"

dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
)


# Cross-attention (encoder) cache preallocated for up to 170 encoder tokens and
# self-attention (decoder) cache for up to 200 decoder tokens, both for a batch
# of up to 4.
encoder_cache = StaticCache(
    model.config, max_cache_len=170, max_batch_size=4, dtype=dtype
)
decoder_cache = StaticCache(
    model.config, max_cache_len=200, max_batch_size=4, dtype=dtype
)
cache = EncoderDecoderCache(decoder_cache, encoder_cache)

strings_1 = [
    "When the night has come and the land is dark, and the moon is the only light we will see.",
    "Abba is the best",
    # "No lindy is the best",
    # "No Elton john is the absolute best.",
]
# Batch of 2 inputs, each shorter than the preallocated limits; passing the preallocated cache here triggers the mismatches described above.
input_ids = tokenizer(strings_1, return_tensors="pt", padding=True)
tokens = model.generate(**input_ids, past_key_values=cache)
text_translated = [tokenizer.decode(t, skip_special_tokens=False) for t in tokens]
print(text_translated)
@cptspacemanspiff added the Feature request label on Dec 29, 2024
@LysandreJik
Member

cc cache masters @gante @zucchini-nlp @ArthurZucker

@zucchini-nlp
Member

@cptspacemanspiff hey!

If you are initializing the cache object outside generate(), you have to set the correct batch size (i.e. the batch size of your current input) and the correct max cache length yourself.

The cross-attention cache must be exactly the same size as the encoder sequence length.
The batch size that the cache is initialized with must be exactly the batch size that the cache is run with.

This is the correct intuition, and it has to be handled by the user if not calling generate() with cache_implementation. Otherwise, generate() will handle everything internally. See the docs below for more on the static cache (a minimal sketch of the generate()-managed path follows the links):

https://huggingface.co/docs/transformers/en/llm_optims#static-kv-cache-and-torchcompile
https://huggingface.co/docs/transformers/en/kv_cache#static-cache
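
As a minimal sketch of the generate()-managed path (assuming T5's generate() accepts cache_implementation="static", as the rest of this thread suggests), reusing the model and tokenizer from the reproduction script above:

# generate() creates and sizes the static cache internally from the current
# batch size and sequence lengths, so no cache object is passed in.
inputs = tokenizer(["Abba is the best"], return_tensors="pt", padding=True)
tokens = model.generate(**inputs, cache_implementation="static")
print(tokenizer.batch_decode(tokens, skip_special_tokens=True))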

@cptspacemanspiff
Author

Hi!

Thanks for the response (and for adding the static cache to the T5 model :)

I guess I am trying to say that I don't think that the current state is the desired behavior.

Take the fixed encoder sequence length: currently, even in generate(), if you call generate() multiple times with different encoder sequence lengths, the cross-attention static cache is not reused.

(It ends up building a new cache for each encoder sequence length, just as it does when you feed it different batch sizes.)

This is not really a problem for models like Whisper, where the encoder length is fixed, but for T5/BART/others one of the main use cases is machine translation, where the encoder is fed inputs of varying length.


I started dealing with this as part of ExecuTorch export, where all memory must be preallocated.

So I would end up having to choose an encoder sequence length, and that is the only sequence length I can use on device, which is less than ideal. There is a similar story with the batch size.

If the encoder-decoder model logic is changed so that it does not have to use the whole available KV cache / batch size, then this issue is resolved and I can choose my encoder sequence length and batch size at runtime (as long as I stay at or below the maxima). A sketch of what that would look like is below.
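
As a rough sketch of the desired usage (not something that works today; it reuses the model, tokenizer, and cache setup from the reproduction script, and assumes the cache can be cleared and reused between generate() calls):

# One preallocated cache, reused across calls whose batch size and encoder
# length vary but stay within the preallocated maxima.
encoder_cache = StaticCache(model.config, max_cache_len=170, max_batch_size=4, dtype=dtype)
decoder_cache = StaticCache(model.config, max_cache_len=200, max_batch_size=4, dtype=dtype)
cache = EncoderDecoderCache(decoder_cache, encoder_cache)

for batch in (["A short input."], ["A much longer input sentence for translation.", "Another one."]):
    inputs = tokenizer(batch, return_tensors="pt", padding=True)
    cache.reset()  # assumed: clears cached values so the same buffers can be reused
    tokens = model.generate(**inputs, past_key_values=cache)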

@zucchini-nlp
Member

Ah, I see what you mean. I agree that making the sequence length flexible is a desirable feature, as we would love to have a reusable cache object with a pre-defined max length. Indeed, this currently works only for decoder-only models, and we can add the same for the encoder cache to be aligned with the general idea.

Regarding the batch size argument, it is something that concerns all models, not only encoder-decoder ones. That means we would need to change all decoder-only models in transformers as well. Personally, I believe it is a good feature to have as long as it doesn't break the main purpose of the static cache (torch.compile compatibility). I don't think it will, because we rely only on the shapes.

Let's see what @gante has to say on that, as a generation code master :)

@cptspacemanspiff
Author

cptspacemanspiff commented Jan 6, 2025

Thanks,


As an aside, I think (maybe, possibly):

The batch size change can be implemented entirely within the static cache in cache_utils. This is because values are pulled from the cache via .update(), where the model passes in the current cache position and the associated key/value states. The batch size of the current run is the first dimension of those key/value states, so the cache object can simply write into and return only the first n rows of its buffers (a rough sketch is included below, after the caveats).

As long as decoder models do not access the cache directly via past_key_value.key_cache[self.layer_idx], it should work.
(From a first-pass search, Jamba, Llama, and Pix2Struct are the only models that access the cache directly, aside from T5/Whisper and derivatives, which access it directly for the cross attention.)
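
A rough sketch of that slicing idea (an illustration, not the actual transformers implementation; it assumes StaticCache keeps per-layer buffers of shape (max_batch_size, num_heads, max_cache_len, head_dim) in key_cache/value_cache and receives "cache_position" through cache_kwargs, internals that may differ between versions):

import torch
from transformers.cache_utils import StaticCache

class BatchFlexibleStaticCache(StaticCache):
    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        batch = key_states.shape[0]  # batch size of the current forward pass
        cache_position = (cache_kwargs or {})["cache_position"]
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        # Write the new entries only into the rows that are actually in use...
        k_out[:batch].index_copy_(2, cache_position, key_states)
        v_out[:batch].index_copy_(2, cache_position, value_states)
        # ...and return only those rows, so attention shapes match the running
        # batch rather than max_batch_size.
        return k_out[:batch], v_out[:batch]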

Along the same lines, should direct access to the cache via past_key_value.key_cache[self.layer_idx] be a public API?


Also, with regard to torch.compile, I got this all working (at least an initial pass) in #35445 (though there the batch size handling is split between the model code and the static cache). That said, while torch.compile does succeed there, I did not check whether the compiled output adds a bunch of extra copy ops and such... the slicing is a pain...

@gante
Member

gante commented Apr 9, 2025

Hey @cptspacemanspiff 👋

#37394 adds support for the caches to be used with smaller batch sizes than max_batch_size. The other part of the issue you raised (variable encoder length) is also relevant, I'll work on it some time in the future :) (a PR to enable it would also be welcome!)
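
As a hedged illustration of what that enables for decoder-only models (the model name and generation settings here are illustrative, and it assumes the PR is merged and the chosen model supports static caches):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tok.pad_token = tok.pad_token or tok.eos_token
lm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# Cache preallocated for a batch of up to 4, used here with a batch of 2.
cache = StaticCache(lm.config, max_batch_size=4, max_cache_len=256, dtype=lm.dtype)
inputs = tok(["Hello there", "Static caches are neat"], return_tensors="pt", padding=True)
out = lm.generate(**inputs, past_key_values=cache, max_new_tokens=16)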
