A case for performant portable ops #10886
Replies: 4 comments 5 replies
-
Thanks for compiling the numbers on this. I'm excited about it: in addition to the benefits for HF, it will also benefit sequence-to-sequence tasks, ASR, and potentially simplify enablement for emerging architectures.
-
Here is a "good first issue" to reinplace slice_copy with slice: #10917
-
@kimishpatel I think we can probably register a new cache impl with custom ops to perform in-place cache updates in Transformers here: https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
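To make the idea concrete, here is a minimal sketch of such a cache, assuming a transformers-style object with an `update()` hook. The class and method shapes are illustrative rather than the real `cache_utils` API, and the custom op name mentioned in the comment is assumed.

```python
# Sketch of a preallocated KV cache that is updated in place instead of via
# torch.cat. An ExecuTorch integration could route the in-place writes through
# a registered custom op (e.g. torch.ops.llama.update_cache, name assumed)
# instead of the index_copy_ used here.
import torch


class InPlaceStaticCache:
    """Preallocated KV cache updated in place; not the real transformers API."""

    def __init__(self, num_layers: int, batch: int, heads: int, max_seq_len: int, head_dim: int):
        shape = (batch, heads, max_seq_len, head_dim)
        self.key_cache = [torch.zeros(shape) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx, cache_position):
        # Write the new entries at cache_position along the sequence dim.
        k, v = self.key_cache[layer_idx], self.value_cache[layer_idx]
        k.index_copy_(2, cache_position, key_states)
        v.index_copy_(2, cache_position, value_states)
        return k, v


# Usage: write 2 new tokens at positions 5 and 6 of layer 0.
cache = InPlaceStaticCache(num_layers=1, batch=1, heads=4, max_seq_len=128, head_dim=16)
k, v = cache.update(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16),
                    layer_idx=0, cache_position=torch.tensor([5, 6]))
```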
-
optimum-executorch is needed to enable transformer-based models on ExecuTorch (eventually targeting non-LLMs as well) with decent performance, while etLLM focuses on more targeted transforms that enable the best performance across backends.
As part of this effort we enabled custom_sdpa, both via a graph transform (PR) and via HF's attention customization API (PR), which improved out-of-the-box performance significantly. However, there is still a significant gap compared to etLLM. The optimizations left out were the ones that are hard to apply in optimum-executorch, namely the custom KV cache. That module uses a custom op, update_cache, to mutate the cache in place without incurring slicing and indexing costs. We wanted to understand the impact of these and possibly other portable ops, so we can prioritize work on improving portable operator performance.
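For readers unfamiliar with that op, here is a rough reference in plain PyTorch of what an update_cache-style op computes. The [B, seq, H, D] layout and argument order are assumptions for illustration; the actual custom kernel registered in ExecuTorch may differ.

```python
# Reference semantics only: write the new K or V entries into a preallocated
# mutable cache at start_pos, in place. This is NOT the registered custom
# kernel, just an illustration of what it computes.
import torch


def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> None:
    # value: [B, new_seq_len, H, D]; cache: [B, max_seq_len, H, D], mutated in place.
    seq_len = value.shape[1]
    # One strided in-place copy: no slice_copy of the existing entries, no
    # out-of-place index_put, and no copy of the whole buffer afterwards.
    cache.narrow(1, start_pos, seq_len).copy_(value)


# Usage: append 2 tokens starting at position 16.
cache = torch.zeros(1, 128, 8, 64)
update_cache_reference(torch.randn(1, 2, 8, 64), cache, start_pos=16)
```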
To do this we ran four models, using optimum-executorch, on an Ubuntu CI machine. Job details can be found here. Similar profiling on an Android device is under way. The four models were:
Gemma3 1B
Qwen3 0.6B
SmolLM2 135M
Llama 3.2 1B
Following is the operator-level breakdown, where DELEGATE is the XNNPACK delegate lowering 4-bit quantized linear layers.
Qwen3 0.6B:
SmolLM2 135M:
Llama3.2 1B:
Note how copy and index_put take up a significant portion of the runtime, especially in Llama3.2 1B, SmolLM2 and Qwen3 0.6B; the smaller the model, the worse it is. This is because of a) functionalization, which results in a full copy of the data, and b) the lack of mutation, which means we have to copy the entire mutable buffer state back into its original storage. On top of that, index_put is a notoriously hard op to implement in an ATen-compliant manner, so it is really slow.
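A quick eager-mode illustration of (a) and (b), using core ATen ops rather than ExecuTorch portable kernels; the shapes are arbitrary and the timings are only meant to show the shape of the problem.

```python
# Compare the functionalized pattern (out-of-place index_put over the whole
# cache, then copy the result back into the mutable buffer) with a plain
# in-place write that only touches the new rows.
import time
import torch

B, H, S, D = 1, 8, 2048, 64
cache = torch.zeros(B, H, S, D)
k_new = torch.randn(B, H, 1, D)
pos = torch.tensor([100])


def functionalized_update():
    # Roughly what the exported graph does after functionalization.
    updated = torch.ops.aten.index_put(cache, [None, None, pos], k_new)
    cache.copy_(updated)


def in_place_update():
    # Roughly what an update_cache-style custom op amounts to.
    cache.index_copy_(2, pos, k_new)


for name, fn in [("index_put + copy_", functionalized_update), ("in-place", in_place_update)]:
    fn()  # warm-up
    t0 = time.perf_counter()
    for _ in range(100):
        fn()
    print(f"{name}: {(time.perf_counter() - t0) * 1e3:.1f} ms / 100 iters")
```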
What can we do?
Three things:
1. Reverse functionalization: no more copy_ of the mutable buffer back into its original storage.
2. Implement index_put_, the in-place variant, so the update does not materialize a full-size result.
3. Implement a fast path for index_put where index updates along a single dimension reduce to a bunch of memcpy calls (see the sketch after this list).
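Here is a Python sketch of what the fast path in item 3 amounts to; a real implementation would live in the C++ portable kernel, and the function name and layout below are only illustrative.

```python
# Assumed common KV-cache case: a single index tensor selects positions along
# one dimension and the remaining dimensions are written wholesale. A portable
# kernel could implement the inner copy as memcpy over each dense block instead
# of going through the generic advanced-indexing machinery.
import torch


def index_put_single_dim_fast_path(out: torch.Tensor, dim: int, index: torch.Tensor,
                                    values: torch.Tensor) -> None:
    # Writes values into `out` in place; equivalent to
    # out.index_copy_(dim, index, values) for unique indices.
    for i, idx in enumerate(index.tolist()):
        # Each selected slice is one dense block (or a few, when leading dims
        # are > 1), so this copy maps to plain memcpy calls in a kernel.
        out.select(dim, idx).copy_(values.select(dim, i))


# Usage: write 3 new positions into a [B, S, H, D] cache along the seq dim.
cache = torch.zeros(1, 128, 8, 64)
index_put_single_dim_fast_path(cache, dim=1, index=torch.tensor([10, 11, 12]),
                               values=torch.randn(1, 3, 8, 64))
```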
On Gemma3 1B:
Note the stark difference in Gemma3, where indexing itself is a significant chunk of the issue. Why? Because Gemma3 has local attention that uses a sliding window, and the sliding-window implementation literally slices out the last N entries from the cache and moves them up. See here. There isn't a good way to handle this cleanly. We have a ring buffer implementation that does the sliding window in a more efficient way, but it requires a module swap at the moment. I think the best thing to do would be to upstream our implementation to HF so everyone benefits.
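For context, a minimal sketch of the two update strategies, assuming a [B, window, H, D] layout; neither snippet is the actual HF or ExecuTorch implementation.

```python
# sliding_window_shift mimics the shift-based sliding window (slice out the
# oldest entry and move everything up on every decode step). ring_buffer_write
# is the ring-buffer alternative: write in place at pos % window and rely on
# position-aware masking instead of physical order, which is why it currently
# needs a module swap.
import torch

WINDOW = 4
cache = torch.zeros(1, WINDOW, 8, 16)  # [B, window, H, D]


def sliding_window_shift(cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    # Window-sized slice + copy on every step, then append the new entry.
    return torch.cat([cache[:, 1:], new_kv.unsqueeze(1)], dim=1)


def ring_buffer_write(cache: torch.Tensor, new_kv: torch.Tensor, pos: int) -> None:
    # In-place write; nothing else in the cache moves.
    cache[:, pos % WINDOW] = new_kv


# Usage: one decode step at absolute position 7 with new_kv of shape [B, H, D].
new_kv = torch.randn(1, 8, 16)
cache = sliding_window_shift(cache, new_kv)   # out-of-place shift
ring_buffer_write(cache, new_kv, pos=7)       # in-place write
```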