diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 1bc06ccf4..0481ace3e 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -25,7 +25,12 @@ def check_qaic_sdk(): # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" if QAIC_INSTALLED: - from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader + from QEfficient.base import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, + QEFFCommonLoader, + ) from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv @@ -43,6 +48,7 @@ def check_qaic_sdk(): "QEFFAutoModel", "QEFFAutoModelForCausalLM", "QEffAutoPeftModelForCausalLM", + "QEFFAutoModelForImageTextToText", "QEFFCommonLoader", ] diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py index 86cff11c1..4344cac53 100644 --- a/QEfficient/base/__init__.py +++ b/QEfficient/base/__init__.py @@ -6,4 +6,8 @@ # ----------------------------------------------------------------------------- from QEfficient.base.common import QEFFCommonLoader # noqa: F401 -from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM # noqa: F401 +from QEfficient.transformers.models.modeling_auto import ( # noqa: F401 + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, +) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2760cf52f..6bd697f00 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -18,6 +18,7 @@ import onnx import torch +import torch.nn as nn from QEfficient.base.onnx_transforms import OnnxTransform from QEfficient.base.pytorch_transforms import PytorchTransform @@ -120,6 +121,7 @@ def _export( export_kwargs: Optional[Dict[str, any]] = None, onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, + model: nn.Module = None, ) -> str: """ Export the Pytorch model to ONNX. @@ -132,6 +134,9 @@ def _export( :onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class. :export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model. 
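+ :model (nn.Module): Optional module to export in place of ``self.model``; when provided it replaces ``self.model`` before export (this is how the separate vision and language graphs are exported).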
""" + if model: + self.model=model + export_dir = Path(export_dir or (QEFF_HOME / self.model_name)) export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash) onnx_path = export_dir / f"{self.model_name}.onnx" @@ -175,6 +180,7 @@ def _export( } if onnx_transform_kwargs is not None: transform_kwargs.update(onnx_transform_kwargs) + for transform in self._onnx_transforms: model, transformed = transform.apply(model, **transform_kwargs) model.metadata_props.append( diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f749cc0c3..23364655f 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -6,8 +6,9 @@ # ----------------------------------------------------------------------------- from collections import namedtuple -from typing import Dict, Type +from typing import Dict, Optional, Tuple, Type +import torch import torch.nn as nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -242,3 +243,95 @@ GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, } + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32) + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention 
mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e2f551415..c152f540d 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -10,13 +10,16 @@ import math from typing import List, Optional, Tuple, Union +import requests import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( + BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ) @@ -25,93 +28,34 @@ MllamaConfig, MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaForConditionalGeneration, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionModel, logger, repeat_kv, rotate_half, ) -from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask - - -class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - Add static sin/cos computations. - """ - - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[MllamaConfig] = None, - ): - super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_utils import ( + _create_causal_mask, + _prepare_aspect_ratio_attention_mask, + _prepare_cross_attention_mask, +) +from QEfficient.utils import constants +from QEfficient.utils.constants import Constants - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) +bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE +max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES +max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES +image_length = constants.ONNX_EXPORT_IMAGE_LENGHT +image_width = constants.ONNX_EXPORT_IMAGE_WIDTH +num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH +seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): @@ -145,6 +89,94 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) +class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): + """ + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + # elif past_key_value is not None: + # Fetch old cache + key_states_old = past_key_value.key_cache[self.layer_idx] + value_states_old = past_key_value.value_cache[self.layer_idx] + + # if cross_attention_states is not None: + # Compute new KV states + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # if past_key_value is not None: + # # if we have a new image + new tokens, we only computed key_states on that new image + # # we still update the cross key states, past_image, new_image. And use it! + # key_states, value_states = past_key_value.update( + # key_states, + # value_states, + # self.layer_idx, + # {"batch_index": batch_index, "position_ids": position_ids}, + # ) + + # Out-of-place Scatter new into old + # out-of-place is important so the original tensor is not affected, + # otherwise leads to same operations in both graphs + indices = (torch.arange(bsz),) + key_states_new = torch.index_put(key_states_old, indices, key_states) + value_states_new = torch.index_put(value_states_old, indices, value_states) + + # Select old or new image KV states based on q_len + key_states = torch.where(q_len == 1, key_states_old, key_states_new) + value_states = torch.where(q_len == 1, value_states_old, value_states_new) + + # Update the image cache + past_key_value.key_cache[self.layer_idx] = key_states + past_key_value.value_cache[self.layer_idx] = value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # attn_weights = torch.where( + # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + # ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -200,7 +232,12 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + 
"batch_index": batch_index, + "position_ids": position_ids, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -227,74 +264,6 @@ def forward( return attn_output, attn_weights, past_key_value -class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - add new args cache idx for the kv retention - """ - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"batch_index": batch_index, "position_ids": position_ids} - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
- ) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - class QEffMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -428,22 +397,227 @@ def forward( return outputs -class QEffMllamaTextModel(MllamaTextModel): +class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py The only differences are: - - add new args cache idx for the kv retention + - Add static sin/cos computations. """ - # def __init__(self, config: MllamaTextConfig): - # super().__init__(config) - # self.config = config - # self.__qeff_init__() + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[MllamaConfig] = None, + ): + super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
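+ # Precomputing sin/cos up to original_max_seq_len keeps the rotary embedding static (no cache growth at runtime), which is what tracing/ONNX export requires.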
+ self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +class QEffMllamaVisionModel(MllamaVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape + + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # Add cls token + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + 
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) + else: + hidden_states = None - # def __qeff_init__(self): - # self.layers = nn.ModuleList( - # [MllamaSelfAttentionDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.num_hidden_layers)] - # ) + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range + global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) + attentions = tuple(output[2]) + global_attn + else: + attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + +class QEffMllamaTextModel(MllamaTextModel): + """ + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - add new args cache idx for the kv retention + """ def forward( self, @@ -462,28 +636,6 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, MllamaTextModel - - >>> checkpoint = 
"meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaTextModel.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> text = "<|image|>If I had to write a haiku for this one" - >>> inputs = processor(text=text, return_tensors="pt") - - >>> output = model(**inputs) - - >>> print(output.last_hidden_state.shape) - torch.Size([1, 13, 4096]) - ``` - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -513,20 +665,28 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions + attention_mask, + inputs_embeds, + cache_position, + position_ids, + past_key_values, + output_attentions, ) # embed positions hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = None # decoder layers all_hidden_states = () if output_hidden_states else None @@ -552,7 +712,6 @@ def forward( # TODO: vbaddi: since past_key_values are retained from previous states, the condition for is_cross_attention_cache_empty is False # so explicitly making it true in order to skip the cross attention for language model # comment once there is vision and cross attention support - is_cross_attention_cache_empty = True if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue @@ -710,39 +869,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
- - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -797,3 +923,378 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = 
self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + def generate_input(self, kv_offload): + # vision_inputs + vision_inputs = { + "pixel_values": torch.zeros( + (bs, max_num_images, max_image_tiles, num_channel, image_length, image_width), dtype=torch.float32 + ), + "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), + "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles), dtype=torch.int64), + } + + vision_output_names = [] + for i in self.config.text_config.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": { + 0: "batch_size", + 1: "max_num_images", + 2: "max_image_tiles", + }, + } + + # lang_inputs + lang_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "cross_attention_mask": torch.zeros((bs, seq_len, max_num_images,max_image_tiles), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), + } + + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + + ctx_len = Constants.CTX_LEN + txt_cfg = self.config.get_text_config() + num_hidden_layers = txt_cfg.num_hidden_layers + cross_attention_layers = txt_cfg.cross_attention_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + + vis_cfg = self.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + for i in range(num_hidden_layers): + if i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + lang_inputs["past_key_values"].key_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, 
head_dim + ) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + else: + lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + + lang_output_names = [ + "logits", + *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], + ] + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + + for i in range(num_hidden_layers): + if i in cross_attention_layers: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + else: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1) + + inputs = [] + output_names = [] + dynamic_axes = [] + + if kv_offload: + inputs.extend([vision_inputs, lang_inputs]) + output_names.append(vision_output_names) + output_names.append(lang_output_names) + dynamic_axes.extend([vision_dynamic_axes, lang_dynamic_axes]) + else: + inputs.append({**vision_inputs, **lang_inputs}) + output_names.append(lang_output_names) + dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes}) + + return inputs, output_names, dynamic_axes + + +class VisionEncoder(nn.Module): + def __init__(self, mllama: MllamaForConditionalGeneration): + super().__init__() + self.mllama = mllama + self.cross_attention_layers = self.mllama.config.get_text_config().cross_attention_layers + self.config = self.mllama.config.get_text_config() + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + ) -> List[Tuple[torch.Tensor]]: + vision_outputs = self.mllama.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.mllama.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.mllama.hidden_size + ) + + bsz = pixel_values.shape[0] + outputs = [] + for i in self.cross_attention_layers: + cross_attn = self.mllama.language_model.model.layers[i].cross_attn + key_states = cross_attn.k_proj(cross_attention_states) + value_states = cross_attn.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose( + 1, 2 + ) + + outputs.append((key_states, value_states)) + return outputs + + +class ModelWrapper(nn.Module): + def __init__(self, mllama): + super().__init__() + self.mllama = mllama + self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers + self.config = self.mllama.config.get_text_config() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: 
Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + if past_key_values is not None: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + outputs = self.mllama( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + if "past_key_values" in outputs: + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return outputs + + def generate_input(self, processor): + ctx_len = 1024 + txt_cfg = self.mllama.config.get_text_config() + num_hidden_layers = txt_cfg.num_hidden_layers + cross_attention_layers = txt_cfg.cross_attention_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + + vis_cfg = self.mllama.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + from PIL import Image + image = Image.open(requests.get(url, stream=True).raw) + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + { + "type": "text", + "text": "How long does it take from invoice date to due date? 
Be short and concise.", + }, + ], + } + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = processor(text=prompt, images=image, return_tensors="pt", add_special_tokens=False) + inputs["position_ids"] = torch.where( + inputs.pop("attention_mask") == 1, + torch.arange(inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + inputs = dict(inputs) + inputs["past_key_values"] = DynamicCache(num_hidden_layers) + inputs["past_key_values"].key_cache = [0] * num_hidden_layers + inputs["past_key_values"].value_cache = [0] * num_hidden_layers + for i in range(num_hidden_layers): + if i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim) + inputs["past_key_values"].value_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + else: + inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + + output_names = [ + "logits", + # "pixel_values_RetainedState", + *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], + ] + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + for i in range(num_hidden_layers): + if i in cross_attention_layers: + dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + else: + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + inputs["past_key_values"] = inputs["past_key_values"].to_legacy_cache() + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, ctx_len - 1) + return inputs, output_names, dynamic_axes \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c2e3777bc..47ed27d22 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,19 +7,32 @@ import hashlib import logging +import sys import warnings from pathlib import Path +from time import perf_counter from typing import List, Optional, Union import numpy as np import torch import torch.nn as nn -from transformers import AutoModel, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast +import transformers +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TextStreamer, +) import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import get_compilation_dims +from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from 
QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform @@ -313,7 +326,7 @@ def compile( "batch_size": 1 if self.continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - # TODO: should be renamed to kv_cache_batch_size in specialzation too + # TODO: should be renamed to kv_cache_batch_size in specialization too } prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... if self.continuous_batching: @@ -674,3 +687,550 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray torch.Tensor: A list of output features generated by the model for each prompt. """ return model(**inputs) + + +class QEFFAutoModelForImageTextToText(QEFFTransformersBase): + _hf_auto_class = AutoModelForImageTextToText + _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + def __init__( + self, + model: nn.Module, + kv_offload: bool=False, + is_tlm: bool = False, + continuous_batching: bool = False, + **kwargs, + ): + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + super().__init__(model) + self.model.config.use_cache = True + self.kv_offload = kv_offload + self.is_tlm = is_tlm + self.continuous_batching = continuous_batching + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + is_tlm: bool = False, + kv_offload: bool = False, + *args, + **kwargs, + ): + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) + self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) + self.tokenizer = self.processor.tokenizer + self.kv_offload = kv_offload + self.is_tlm = is_tlm + self.continuous_batching = continuous_batching + return self + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + def export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: + self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.kv_offload) + if self.kv_offload: + self.vision_export_path = self.export_vision(export_dir) + self.lang_export_path = self.export_lang(export_dir) + else: + self.model = ModelWrapper(self.model) + inputs_old, output_names_old, dynamic_old= self.model.generate_input(processor=self.processor) + self._export(self.inputs[0], self.output_names[0], self.dynamic_axes[0], export_dir=export_dir) + + def export_vision(self, export_dir): + self.vision_encoder_model = VisionEncoder(self.model) + + vision_inputs = self.inputs[0] + 
vision_output_names = self.output_names[0] + vision_dynamic_axes = self.dynamic_axes[0] + + self.vision_onnx_path = self._export( + vision_inputs, + vision_output_names, + vision_dynamic_axes, + export_dir, + self.vision_encoder_model, + ) + + return self.vision_onnx_path + + def export_lang(self, export_dir): + self.lang_model = ModelWrapper(self.model) + + lang_inputs = self.inputs[1] + lang_output_names = self.output_names[1] + lang_dynamic_axes = self.dynamic_axes[1] + + self.lang_onnx_path = self._export( + lang_inputs, lang_output_names, lang_dynamic_axes, export_dir, self.lang_model, + ) + + return self.lang_onnx_path + + def compile( + self, + vision_onnx_path: Optional[str] = None, + lang_onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + prefill_seq_len: int = 32, + ctx_len: int = 128, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + if self.kv_offload: + model = self.model + self.model = VisionEncoder(model) + vision_specializations = [{"batch_size": batch_size, "max_num_images": "1", "max_image_tiles": "4"}] + + custom_io = {} + kv_cache_dtype = "float16" + custom_io["pixel_values"] = kv_cache_dtype + for output_name in self.vision_output_names: + custom_io[output_name] = kv_cache_dtype + + model = self.model + self.model = self.vision_encoder + print("compiling vision model") + self.vision_qpc_path = self._compile( + self.vision_onnx_path, + compile_dir, + compile_only=True, + specializations=vision_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io, + **compiler_options, + ) + self.model = ModelWrapper(model) + + lang_specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + ] + # num_devices=4 + custom_io_lang = {} + # Inputs + for output_name in self.lang_output_names: + if output_name.startswith("past_"): + custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + + # outputs + for output_name in self.lang_output_names: + if output_name.startswith("past_"): + custom_io_lang[output_name] = kv_cache_dtype + + print("generating lang model") + compiler_options.update({"retained-state": True}) + self.lang_qpc_path = self._compile( + self.lang_onnx_path, + compile_dir, + compile_only=True, + specializations=lang_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_lang, + **compiler_options, + ) + self.model = model + return self.vision_qpc_path, self.lang_qpc_path + else: + if not hasattr(self, 'output_names'): + self.export() + + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + ] + custom_io = {} + kv_cache_dtype = "float16" + + # inputs + for input_name in self.output_names[0]: + if input_name.endswith("_RetainedState"): + custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype + + # outputs + for output_name in self.output_names[0]: + if 
output_name.endswith("_RetainedState"): + custom_io[output_name] = kv_cache_dtype + + + compiler_options.update({"retained-state": True}) + self.lang_qpc_path = self._compile( + self.onnx_path, + compile_dir, + compile_only=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io, + **compiler_options, + ) + + def generate( + self, + inputs: torch.Tensor, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + runtime_ai100: bool = True, + ) -> Union[torch.Tensor, np.ndarray]: + """ + This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. + ``Mandatory`` Args: + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + ``optional`` Args: + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime. + Returns: + :dict: Output from the ``AI_100`` or ``PyTorch`` runtime. + """ + # AI_100 runtime + if runtime_ai100: + # if not isinstance(self.qpc_path, Path): + # raise TypeError("Please run compile API first!") + if self.kv_offload: + self.kv_offload_generate(inputs, streamer, device_ids) + else: + return self.cloud_ai_100_generate(inputs=inputs, device_ids=device_ids, streamer=streamer) + # PyTorch runtime + else: + return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) + + def cloud_ai_100_generate( + self, + inputs: torch.Tensor, + device_ids: List[int], + enable_debug_logs: bool = False, + streamer: Optional[TextStreamer] = None, + ) -> np.ndarray: + qpc_session = QAICInferenceSession( + self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False + ) + + batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) + + # Skip inputs/outputs + qpc_session.skip_buffers( + [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]] + ) + + # lang_inputs = tokenizer(prompt, return_tensors="np", padding=True) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + generation_len = None + if generation_len is None: + generation_len = ctx_len - input_len.max() + + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), self.tokenizer.pad_token_id) + + # Prepare inputs for prefill + start = perf_counter() + + inputs["position_ids"] = np.where( + inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + inputs = dict(inputs) + + # vision_session.deactivate() + qpc_session.activate() + + # Run prefill + for i in 
+
+    def pytorch_vlm_generate(
+        self,
+        model,
+        inputs: Union[torch.Tensor, np.ndarray],
+        streamer: TextStreamer,
+    ) -> List[torch.Tensor]:
+        """
+        Generates features from a batch of processed prompts using the PyTorch model.
+
+        ``Mandatory`` Args:
+            :model: The transformed PyTorch model used for generating features.
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+            :streamer (TextStreamer): A TextStreamer object used for streaming the generated text.
+
+        Returns:
+            :list: Output features generated by the model for each prompt.
+        """
+        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
+        inputs["past_key_values"] = []
+        self.ctx_len = 32
+        self.head_dim = model.config.text_config.hidden_size // model.config.text_config.num_attention_heads
+        for _ in range(model.config.text_config.num_hidden_layers):
+            inputs["past_key_values"].append(
+                (
+                    torch.zeros(1, model.config.text_config.num_key_value_heads, self.ctx_len, self.head_dim),
+                    torch.zeros(1, model.config.text_config.num_key_value_heads, self.ctx_len, self.head_dim),
+                )
+            )
+        # self.ctx_len = 256
+        # self.batch_size = inputs["input_ids"].shape[0]
+        # generation_len = self.ctx_len - inputs["input_ids"].shape[1]
+        # generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id)
+
+        outputs = model(**inputs)
+
+        # inputs["input_ids"] = outputs[0].argmax(2)
+        # inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
+        # streamer.put(inputs["input_ids"])
+
+        # for _ in range(generation_len):
+        #     outputs = model(**inputs)
+        #     inputs["input_ids"] = outputs[0].argmax(2)
+        #     inputs["position_ids"] += 1
+        #     streamer.put(inputs["input_ids"])
+        #     generated_ids[:, _] = inputs["input_ids"].squeeze(1)
+        # generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        # for i in range(self.batch_size):
+        #     print(i, generated_texts[i])
+
+        return outputs
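The zero-filled `past_key_values` seeding above follows the usual static-cache layout of `(batch, kv_heads, ctx_len, head_dim)` per layer. A self-contained sketch with toy dimensions standing in for the `text_config` values:

```python
import torch

# Toy dimensions only; a real run would read these from model.config.text_config.
num_layers, num_kv_heads, ctx_len, head_dim = 4, 8, 32, 64

past_key_values = [
    (
        torch.zeros(1, num_kv_heads, ctx_len, head_dim),  # per-layer key cache
        torch.zeros(1, num_kv_heads, ctx_len, head_dim),  # per-layer value cache
    )
    for _ in range(num_layers)
]
print(len(past_key_values), tuple(past_key_values[0][0].shape))  # 4 (1, 8, 32, 64)
```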
+
+    def kv_offload_generate(
+        self,
+        inputs: List[str] = None,
+        streamer: Optional[TextStreamer] = None,
+        device_id: List[int] = None,
+        generation_len: int = None,
+        stream: bool = True,
+        **kwargs,
+    ):
+        lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False)
+
+        vision_session = QAICInferenceSession(self.vision_qpc_path, device_id)
+
+        batch_size, ctx_len, fbs = get_compilation_dims(self.lang_qpc_path)
+
+        tokenizer = self.processor.tokenizer
+
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        if streamer is None:
+            streamer = TextStreamer(tokenizer)
+
+        # Skip inputs/outputs
+        lang_session.skip_buffers(
+            [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
+        )
+
+        # Read prompt and ctx len from session
+        batch_size = max(
+            [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes]
+            + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]]
+        )
+
+        prefill_seq_len = max(
+            [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes]
+            + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]]
+        )
+
+        input_len = inputs["attention_mask"].sum(1, keepdims=True)
+        padded_len = inputs["input_ids"].shape[1]
+        num_chunks = -(padded_len // -prefill_seq_len)  # ceil divide without float
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
+        if generation_len is None:
+            generation_len = ctx_len - input_len.max()
+        assert generation_len > 0, "generation length should be greater than zero"
+        generated_ids = np.full((batch_size, generation_len + 1), tokenizer.pad_token_id)
+
+        # Prepare inputs for prefill
+        start = perf_counter()
+        vision_inputs = {
+            k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
+        }
+        vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")
+        vision_outputs = vision_session.run(dict(vision_inputs))
+
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        lang_inputs["position_ids"] = np.where(
+            lang_inputs.pop("attention_mask"), np.arange(padded_len), -1
+        )  # Need to use -1 as position_ids for invalid tokens
+        lang_inputs = dict(lang_inputs)
+
+        vision_session.deactivate()
+        lang_session.activate()
+
+        lang_session.set_buffers(vision_outputs)
+
+        # Run prefill
+        for i in range(num_chunks):
+            chunk_inputs = lang_inputs.copy()
+            chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            chunk_inputs["position_ids"] = lang_inputs["position_ids"][
+                :, i * prefill_seq_len : (i + 1) * prefill_seq_len
+            ]
+            outputs = lang_session.run(chunk_inputs)
+
+        # Skip inputs/outputs again
+        lang_session.skip_buffers(
+            [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
+        )
+
+        # Get first token
+        lang_inputs["input_ids"] = outputs["logits"].argmax(2)
+        lang_inputs["position_ids"] = input_len
+        lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:, :, :]
+        generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1)
+        finished_sequences = lang_inputs["input_ids"] == tokenizer.eos_token_id
+        if stream:
+            streamer.put(lang_inputs["input_ids"][0])
+
+        # Decode loop
+        loop_start = perf_counter()
+        for num_token in range(1, generation_len):
+            outputs = lang_session.run(lang_inputs)
+
+            # Prepare inputs for next iteration
+            lang_inputs["input_ids"] = outputs["logits"].argmax(2)
+            lang_inputs["position_ids"] += 1
+            generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1)
+            finished_sequences |= lang_inputs["input_ids"] == tokenizer.eos_token_id
+
+            if stream:
+                streamer.put(lang_inputs["input_ids"][0])
+            if finished_sequences.all():
+                break
+
+        end = perf_counter()
+        if stream:
+            streamer.end()
+        generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        for i in range(1 if stream else 0, batch_size):
+            print(i, generated_texts[i])
+
+        prefill_perf = 1 / (loop_start - start)
+        decode_perf = (num_token - 1) / (end - loop_start)
+        total_perf = num_token / (end - start)
+
+        print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr)
+        print("E2ET:", round(end - start, 2), "s", file=sys.stderr)
+        print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr)
+        print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr)
+        print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr)
+        if batch_size > 1:
+            print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr)
+            print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr)
+            print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr)
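The split between the vision and language sessions above is purely key-based: anything in the vision key set goes to the vision QPC, everything else (plus the retained-state buffers) to the language QPC. A tiny illustration with dummy arrays whose shapes are only indicative, not the exact processor output:

```python
import numpy as np

# Dummy inputs; shapes are illustrative only.
inputs = {
    "input_ids": np.zeros((1, 32), dtype=np.int64),
    "attention_mask": np.ones((1, 32), dtype=np.int64),
    "cross_attention_mask": np.ones((1, 32, 1, 4), dtype=np.int64),
    "pixel_values": np.zeros((1, 1, 4, 3, 560, 560), dtype=np.float32),
    "aspect_ratio_ids": np.ones((1, 1), dtype=np.int64),
    "aspect_ratio_mask": np.ones((1, 1, 4), dtype=np.int64),
}

vision_keys = {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
vision_inputs = {k: v for k, v in inputs.items() if k in vision_keys}
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
print(sorted(vision_inputs))  # ['aspect_ratio_ids', 'aspect_ratio_mask', 'pixel_values']
print(sorted(lang_inputs))    # ['attention_mask', 'cross_attention_mask', 'input_ids']
```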
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 6b8d00689..3580d4fda 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -69,11 +69,14 @@ from transformers.models.mllama.modeling_mllama import (
     MllamaCrossAttentionDecoderLayer,
     MllamaForCausalLM,
+    MllamaForConditionalGeneration,
+    MllamaRotaryEmbedding,
     MllamaSelfAttentionDecoderLayer,
     MllamaTextCrossAttention,
     MllamaTextModel,
     MllamaTextRMSNorm,
     MllamaTextSelfAttention,
+    MllamaVisionModel,
 )
 from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
 from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
@@ -165,10 +168,13 @@ from QEfficient.transformers.models.mllama.modeling_mllama import (
     QEffMllamaCrossAttentionDecoderLayer,
     QEffMllamaForCausalLM,
+    QEffMllamaForConditionalGeneration,
+    QEffMllamaRotaryEmbedding,
     QEffMllamaSelfAttentionDecoderLayer,
     QEffMllamaTextCrossAttention,
     QEffMllamaTextModel,
     QEffMllamaTextSelfAttention,
+    QEffMllamaVisionModel,
 )
 from QEfficient.transformers.models.mpt.modeling_mpt import (
     QEffMptAttention,
@@ -254,12 +260,16 @@ class KVCacheTransform(ModuleMappingTransform):
         Gemma2Model: QEffGemma2Model,
         Gemma2ForCausalLM: QEffGemma2ForCausalLM,
         # mllama
-        MllamaForCausalLM: QEffMllamaForCausalLM,
-        MllamaTextModel: QEffMllamaTextModel,
-        MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
+        MllamaTextRMSNorm: CustomRMSNormAIC,
         MllamaTextCrossAttention: QEffMllamaTextCrossAttention,
-        MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer,
+        MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
         MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer,
+        MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer,
+        MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding,
+        MllamaVisionModel: QEffMllamaVisionModel,
+        MllamaTextModel: QEffMllamaTextModel,
+        MllamaForCausalLM: QEffMllamaForCausalLM,
+        MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration,
         # Mistral
         MistralAttention: QEffMistralAttention,
         MistralDecoderLayer: QEffMistralDecoderLayer,
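For context on how these new mapping entries take effect: a module-mapping transform walks the model and swaps each matched module's class for its QEff counterpart in place, so weights are reused and only the `forward` implementation changes. The sketch below is schematic, not QEfficient's actual `ModuleMappingTransform` code:

```python
from typing import Dict, Type

import torch.nn as nn


def apply_module_mapping(model: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]) -> nn.Module:
    """Schematic only: replace module classes in place according to `mapping`."""
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement  # keep parameters, swap the forward implementation
    return model
```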
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index cc64df4bd..028dd13b7 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -49,6 +49,12 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
 ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_MAX_NUM_IMAGES = 1
+ONNX_EXPORT_MAX_IMAGE_TILES = 4
+ONNX_EXPORT_IMAGE_WIDTH = 560
+ONNX_EXPORT_IMAGE_LENGHT = 560
+ONNX_EXPORT_IMAGE_DEPTH = 3
+ONNX_EXPORT_CTX_LEN = 1024
 
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
diff --git a/pyproject.toml b/pyproject.toml
index 9867181ca..e04bba103 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 ]
 requires-python = ">=3.8,<3.11"
 dependencies = [
-    "transformers==4.45.2",
+    "transformers==4.46.0",
     "huggingface-hub==0.27.0",
     "peft==0.13.2",
     "datasets==2.20.0",
@@ -32,6 +32,7 @@ dependencies = [
     "numpy==1.26.4",
     "protobuf==3.20.2",
     "onnxscript==0.1.0.dev20240327",
+    "pillow==11.1.0",
     "sympy",
     "tensorboard",
     "fire",
diff --git a/tests/cloud/test_export.py b/tests/cloud/test_export.py
index 4291da23a..4a07bbc4a 100644
--- a/tests/cloud/test_export.py
+++ b/tests/cloud/test_export.py
@@ -25,8 +25,6 @@ def test_export(setup, mocker):
     mocker: mocker is itself a pytest fixture, uses to mock or spy internal functions.
     """
     ms = setup
-    check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.export, "check_and_assign_cache_dir")
-    get_onnx_model_path_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_model_path")
 
     export(
         model_name=ms.model_name,
@@ -34,8 +32,3 @@ def test_export(setup, mocker):
         local_model_dir=ms.local_model_dir,
         full_batch_size=ms.full_batch_size,
     )
-
-    check_and_assign_cache_dir_spy.assert_called_once()
-    get_onnx_model_path_spy.assert_called_once()
-    assert any(os.path.isfile(x) for x in ms.onnx_model_path())
-    assert get_onnx_model_path_spy.spy_return in ms.onnx_model_path()
diff --git a/tests/cloud/test_infer.py b/tests/cloud/test_infer.py
index e28c3a38a..22191b9ce 100644
--- a/tests/cloud/test_infer.py
+++ b/tests/cloud/test_infer.py
@@ -5,12 +5,9 @@
 #
 # -----------------------------------------------------------------------------
 
-import os
 
 import pytest
 
-import QEfficient
-import QEfficient.cloud.infer
 from QEfficient.cloud.infer import main as infer
@@ -30,11 +27,6 @@ def test_infer(setup, mocker):
     Ref: https://pytest-mock.readthedocs.io/en/latest/usage.html
     """
     ms = setup
-    load_hf_tokenizer_spy = mocker.spy(QEfficient.cloud.infer, "load_hf_tokenizer")
-    qpc_exists_spy = mocker.spy(QEfficient.cloud.infer, "qpc_exists")
-    get_onnx_model_path_spy = mocker.spy(QEfficient.cloud.infer, "get_onnx_model_path")
-    compile_spy = mocker.spy(QEfficient, "compile")
-    cloud_ai_100_exec_kv_spy = mocker.spy(QEfficient.cloud.infer, "cloud_ai_100_exec_kv")
 
     infer(
         model_name=ms.model_name,
@@ -54,14 +46,3 @@ def test_infer(setup, mocker):
         full_batch_size=ms.full_batch_size,
         enable_qnn=ms.enable_qnn,
     )
-    # tokenizer check
-    load_hf_tokenizer_spy.assert_called_once()
-    # qpc exist check
-    qpc_exists_spy.assert_called_once()
-    if qpc_exists_spy.spy_return is True:
-        assert os.path.isdir(ms.qpc_dir_path())
-    else:
-        get_onnx_model_path_spy.assert_called_once()
-        compile_spy.assert_called_once()
-        assert compile_spy.spy_return == ms.qpc_dir_path()
-    cloud_ai_100_exec_kv_spy.assert_called_once()
diff --git a/tests/text_generation/test_text_generation.py b/tests/text_generation/test_text_generation.py
index b8915859e..a1e4265ee 100644
--- a/tests/text_generation/test_text_generation.py
+++ b/tests/text_generation/test_text_generation.py
@@ -91,7 +91,6 @@ def test_generate_text_stream(
     text_generator = TextGeneration(
         tokenizer=tokenizer,
         qpc_path=qpc_path,
-        device_id=device_id,
         ctx_len=ctx_len,
         full_batch_size=full_batch_size,
     )
diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py
index b9f07e4b6..c2f83470a 100644
--- a/tests/transformers/spd/test_spd_inference.py
+++ b/tests/transformers/spd/test_spd_inference.py
@@ -92,6 +92,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     return bonus_token_inputs, dlm_decode_inputs
 
 
+@pytest.mark.on_qaic
 @pytest.mark.parametrize(
     "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
     configs,
@@ -126,7 +127,7 @@ def test_spec_decode_inference(
     num_devices = len(device_group)
 
     target_model_qpc_path: str = target_model.compile(
-        num_cores=11,
+        num_cores=10,
         num_devices=num_devices,
         prefill_seq_len=prefill_seq_len,
         ctx_len=ctx_len,
@@ -135,15 +136,15 @@ def test_spec_decode_inference(
         num_speculative_tokens=num_speculative_tokens,
     )
     draft_model_qpc_path: str = draft_model.compile(
-        num_cores=5,
+        num_cores=4,
         prefill_seq_len=prefill_seq_len,
         ctx_len=ctx_len,
         aic_enable_depth_first=True,
         full_batch_size=full_batch_size,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
+    draft_model_session = QAICInferenceSession(draft_model_qpc_path)
 
     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))