From df66ac31171cf0bcaad736948335216c375efb74 Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Thu, 10 Oct 2024 10:50:36 -0700 Subject: [PATCH 1/9] Initial commit for finite loras implementation Signed-off-by: Jou-An Chen --- QEfficient/__init__.py | 2 + .../exporter/export_hf_to_cloud_ai_100.py | 16 +- QEfficient/exporter/export_utils.py | 2 + .../generation/text_generation_inference.py | 26 ++ QEfficient/lora/__init__.py | 12 + QEfficient/lora/auto.py | 398 ++++++++++++++++++ QEfficient/lora/layers.py | 65 +++ QEfficient/lora/lora_model.py | 88 ++++ QEfficient/lora/pytorch_transforms.py | 53 +++ .../models/llama/modeling_llama.py | 16 +- .../models/mistral/modeling_mistral.py | 14 +- QEfficient/utils/generate_inputs.py | 23 +- docs/source/hl_api.md | 6 + examples/lora_models.py | 118 ++++++ tests/lora/test_lora_model.py | 216 ++++++++++ 15 files changed, 1043 insertions(+), 12 deletions(-) create mode 100644 QEfficient/lora/__init__.py create mode 100644 QEfficient/lora/auto.py create mode 100644 QEfficient/lora/layers.py create mode 100644 QEfficient/lora/lora_model.py create mode 100644 QEfficient/lora/pytorch_transforms.py create mode 100644 examples/lora_models.py create mode 100644 tests/lora/test_lora_model.py diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 0f7f40483..7adbbd6f7 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -9,6 +9,7 @@ from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv +from QEfficient.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft import QEffAutoPeftModelForCausalLM from QEfficient.transformers.transform import transform @@ -24,5 +25,6 @@ "QEffAutoModel", "QEFFAutoModelForCausalLM", "QEffAutoPeftModelForCausalLM", + "QEffAutoLoraModelForCausalLM", "QEFFCommonLoader", ] diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index 55f2ac3be..5b2319edb 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -16,6 +16,7 @@ from QEfficient.base.common import AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP, QEFF_MODEL_TYPE, QEFFCommonLoader from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.exporter.export_utils import export_onnx, fix_onnx_fp16, generate_input_files, run_model_on_ort +from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM from QEfficient.transformers.modeling_utils import get_lists_of_cb_qeff_models from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.utils import load_hf_tokenizer @@ -148,6 +149,7 @@ def convert_to_cloud_kvstyle( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], onnx_dir_path: str, seq_len: int, + max_num_adapters: int, ) -> str: """ API to convert model with kv retention and export to ONNX. @@ -176,7 +178,7 @@ def convert_to_cloud_kvstyle( # Decide path for saving exported ONNX files. model_name = export_kvstyle_transformed_model_to_onnx( - model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len + model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len, max_num_adapters ) # type: ignore # return the model path for automation. @@ -190,6 +192,7 @@ def export_kvstyle_transformed_model_to_onnx( onnx_dir_path: str, seq_len: int, full_batch_size: Optional[int] = None, + max_num_adapters: Optional[int] = None, ) -> str: # Disabling requires_grad on all parameters for _, p in enumerate(transformed_model.parameters()): @@ -208,6 +211,7 @@ def export_kvstyle_transformed_model_to_onnx( prompt_len=Constants.PROMPT_LEN, ctx_len=seq_len, full_batch_size=full_batch_size, + max_num_adapters=max_num_adapters, ) inputs = input_handler.prepare_pytorch_inputs() @@ -315,6 +319,7 @@ def export_for_cloud( onnx_dir_path: str, seq_length: int = Constants.SEQ_LEN, full_batch_size: Optional[int] = None, + max_num_adapters: Optional[int] = None, ) -> str: # Check if model architecture is supported for continuous batching. if full_batch_size and qeff_model.model.config.architectures[0].lower() not in { @@ -325,7 +330,10 @@ def export_for_cloud( ) # FIXME: move all this to class instead of here, and just call qeff_model.export here. - if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM: # type: ignore + if ( + AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM + or qeff_model.__class__ == QEffAutoLoraModelForCausalLM + ): # type: ignore return export_lm_model_for_cloud( model_name=model_name, qeff_model=qeff_model, # type: ignore @@ -333,6 +341,7 @@ def export_for_cloud( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, + max_num_adapters=max_num_adapters, ) else: raise NotImplementedError( @@ -347,6 +356,7 @@ def export_lm_model_for_cloud( onnx_dir_path: str, seq_length: int, full_batch_size: Optional[int] = None, + max_num_adapters: Optional[int] = None, ) -> str: if os.path.exists(onnx_dir_path): logger.warning(f"Overriding {onnx_dir_path}") @@ -375,6 +385,7 @@ def qualcomm_efficient_converter( kv: bool = True, form_factor: str = "cloud", full_batch_size: Optional[int] = None, + max_num_adapters: Optional[int] = None, ) -> Tuple[str, str]: """ This method is an alias for ``QEfficient.export``. @@ -450,6 +461,7 @@ def qualcomm_efficient_converter( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, + max_num_adapters=max_num_adapters, ) return onnx_dir_path, generated_onnx_model_path else: diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py index d7da3ae04..46a1082e2 100644 --- a/QEfficient/exporter/export_utils.py +++ b/QEfficient/exporter/export_utils.py @@ -83,6 +83,8 @@ def export_onnx( dynamic_axes[iname] = {0: dynamic_axis_past_key, 2: "ctx_len"} elif iname == "batch_index": dynamic_axes[iname] = {0: "batch_size"} + elif iname == "lora_ids": + dynamic_axes[iname] = {0: "batch_size"} if "past_key.0" in input_names and "attention_mask" in input_names: dynamic_axes["attention_mask"] = {0: "batch_size", 1: "ctx_len"} diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 0ddb0acc9..a624ce24c 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -230,6 +230,8 @@ def cloud_ai_100_exec_kv( stream: bool = True, write_io_dir: Optional[str] = None, automation=False, + full_batch_size: Optional[int] = None, + prompt_to_lora_id_mapping: Optional[List[int]] = None, ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. @@ -277,6 +279,7 @@ def cloud_ai_100_exec_kv( stream=stream, write_io_dir=write_io_dir, full_batch_size=full_batch_size, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, ) if full_batch_size is None: exec_info = [ @@ -313,6 +316,7 @@ def __init__( qpc_path: str, prompt: List[str], full_batch_size: Optional[int] = None, + prompt_to_lora_id_mapping: Optional[List[int]] = None, ctx_len: Optional[int] = None, generation_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -342,6 +346,13 @@ def __init__( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled + if prompt_to_lora_id_mapping: + self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping) + self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping + else: + self.prompt_to_lora_id_mapping_prefill = None + self.prompt_to_lora_id_mapping_decode = None + self.set_tokenizer_params() # set tokenizer params # Initialize the storage variables. @@ -461,6 +472,10 @@ def prepare_decode_inputs(self): if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index + if self.prompt_to_lora_id_mapping_decode and self.full_batch_size is not None: + first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)] + decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(self.full_batch_size, 1) + return decode_inputs def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None): @@ -549,6 +564,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.prompt_to_lora_id_mapping_prefill: + inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape( + 1, 1 + ) + for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ @@ -636,6 +656,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): ) generated_id_current_index[decode_batch_id] += 1 + + if self.prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] + return decode_pause_time def run_decode(self, decode_inputs, generation_len): diff --git a/QEfficient/lora/__init__.py b/QEfficient/lora/__init__.py new file mode 100644 index 000000000..75966ff66 --- /dev/null +++ b/QEfficient/lora/__init__.py @@ -0,0 +1,12 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM + +__all__ = [ + "QEffAutoLoraModelForCausalLM", +] diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py new file mode 100644 index 000000000..06e2ca59c --- /dev/null +++ b/QEfficient/lora/auto.py @@ -0,0 +1,398 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import hashlib +import os +import sys +from pathlib import Path +from typing import List, Optional + +import torch +import torch.nn as nn +from peft import PeftConfig, load_peft_weights + +import QEfficient +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform +from QEfficient.transformers.pytorch_transforms import CBTransform +from QEfficient.utils import get_qpc_dir_path, qpc_exists +from QEfficient.utils.cache import to_hashable +from QEfficient.utils.constants import QEFF_MODELS_DIR +from QEfficient.utils.logging_utils import logger + +INTMAX = sys.maxsize + + +class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): + """ + QEff class for loading models with mutltiple LoRA adapters. + Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping. + + Args: + :model (nn.Module): PyTorch model + :base_model_name (str): Model card name for base model + :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping + :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping + :active_adapters (Set): A set of lora_names that are currently active + :max_num_adapters (int): Total number of active adapters that to be exported and compiled + :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping + + .. code-block:: python + + from QEfficient import QEffAutoLoraModelForCausalLM + + m = QEffAutoPeftModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + m.load_adapter("predibase/gsm8k", "gsm8k") + m.load_adapter("predibase/magicoder", "magicoder") + gsm8k_id = m.set_adapter("gsm8k") + magicoder_id = m.set_adapter("magicoder") + m.export(full_batch_size=3) + m.compile(num_cores=16, device_group=[0]) + + prompts=["code prompt", "math prompt", "generic"] + m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,INTMAX]) + + """ + + # inherit __init__() from QEFFAutoModelForCausalLM + def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None: + super().__init__(model, pretrained_model_name_or_path) + self.base_model_name = pretrained_model_name_or_path + self.adapter_weights = {} + self.adapter_configs = {} + self.active_adapters = set() + self.max_num_adapters = 0 + self.active_adapter_to_id = {} + + self.lora_rank = 0 + self.target_modules_for_all_adapters = [] + + @property + def model_hash(self) -> str: + mhash = hashlib.sha256() + + # should use model config here + mhash.update(to_hashable(self.model.model.config.to_diff_dict())) + + # create active adapter config dict + active_adapter_configs = {} + for adpt in self.active_adapters: + active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict() + mhash.update(to_hashable(active_adapter_configs)) + + # ensure model will be exported again if order of adapters changes + mhash.update(to_hashable(self.active_adapter_to_id)) + + mhash = mhash.hexdigest()[:16] + return mhash + + def load_adapter(self, adapter_model_id: str, adapter_name: str): + """Loads a new adapter from huggingface hub or local path into CPU cache + + Args: + :adapter_model_id (str): Adapter model ID from huggingface hub or local path + :adapter_name (str): Adapter name to be used to set this adapter as current + """ + if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()): + logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}") + + self.adapter_weights[adapter_name] = { + k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() + } + self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) + + def unload_adapter(self, adapter_name: str): + if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys(): + print(f"Adapter name {adapter_name} is not loaded yet") + return False + + if adapter_name in self.active_adapters: + print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading") + return False + + self.adapter_weights.pop(adapter_name) + self.adapter_configs.pop(adapter_name) + logger.warning(f"Unloading {adapter_name} from CPU cache.") + return True + + def set_adapter(self, adapter_name: str): + "Sets active adapter from one of the loaded adapters" + + assert (adapter_name in self.adapter_weights.keys()) and ( + adapter_name in self.adapter_configs.keys() + ), f"Adapter name {adapter_name} has not been loaded yet" + + assert ( + list(self.adapter_configs.values())[0] + and self.adapter_configs[adapter_name].target_modules + == list(self.adapter_configs.values())[0].target_modules + ), "Not all adapters have the same target modules" + + assert ( + list(self.adapter_configs.values())[0] + and self.adapter_configs[adapter_name].r == list(self.adapter_configs.values())[0].r + ), "Not all adapters have the same ranks" + + # set active adapter id to current max + self.active_adapter_to_id[adapter_name] = self.max_num_adapters + + # add active adapter to set + self.active_adapters.add(adapter_name) + self.max_num_adapters = len(self.active_adapters) + + return self.active_adapter_to_id[adapter_name] + + def delete_adapter(self, adapter_name: str): + if adapter_name not in self.active_adapters: + print(f"Adapter name {adapter_name} is not set active yet") + return False + + self.active_adapters.discard(adapter_name) + self.max_num_adapters -= 1 + self.active_adapter_to_id.pop(adapter_name) + + # renumbering of active adapter id + for index, (key, value) in enumerate(self.active_adapter_to_id.items()): + self.active_adapter_to_id[key] = index + + logger.warning(f"Deleting {adapter_name} from active adapters.") + if self.onnx_path or self.qpc_path: + logger.warning("Please redo compile_and_export() to reflect the active adapters changes.") + + return True + + def get_adapter_id(self, adapter_name): + "get the adapter_id that maps to the adapter_name" + + return self.active_adapter_to_id[adapter_name] + + def load_adapter_weights_to_model(self): + "Loads adapter weights to the model's multilora layer in a stacked format" + + num_hidden_layers = len(self.model.model.layers) + for i in range(num_hidden_layers): + for target_module in self.target_modules_for_all_adapters: + # stack all adapters weights + a_tensor_list = list(range(self.max_num_adapters)) + b_tensor_list = list(range(self.max_num_adapters)) + c_tensor_list = list(range(self.max_num_adapters)) + + for lora_name, lora_id in self.active_adapter_to_id.items(): + if ( + target_module == "q_proj" + or target_module == "k_proj" + or target_module == "v_proj" + or target_module == "o_proj" + ): + a_tensor_list[lora_id] = torch.from_numpy( + self.adapter_weights[lora_name][ + f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_A.weight" + ] + ) + b_tensor_list[lora_id] = torch.from_numpy( + self.adapter_weights[lora_name][ + f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_B.weight" + ] + ) + else: + raise NotImplementedError("Target module not supported!!") + + c_tensor_list[lora_id] = torch.tensor( + self.adapter_configs[lora_name].lora_alpha / self.adapter_configs[lora_name].r, + dtype=torch.float16, + ) + + stacked_lora_A = ( + torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) + ) # + stacked_lora_B = ( + torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) + ) # + stacked_lora_C = ( + torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3) + ) # + + # stored weight to corresponding ops + if target_module == "q_proj": + module = self.model.model.layers[i].self_attn.q_proj + elif target_module == "k_proj": + module = self.model.model.layers[i].self_attn.k_proj + elif target_module == "v_proj": + module = self.model.model.layers[i].self_attn.v_proj + elif target_module == "o_proj": + module = self.model.model.layers[i].self_attn.o_proj + else: + raise NotImplementedError("Target module not supported!!") + + module.lora_weight_A.copy_(stacked_lora_A) + module.lora_weight_B.copy_(stacked_lora_B) + module.lora_weight_C.copy_(stacked_lora_C) + + def init_adapter_model(self): + "Initialize the fixed lora model with multiple adapter weigths standby" + + # assume all adapters have same target_modules and ranks + assert self.max_num_adapters == len(self.active_adapters), "Inconsistent max_num_adapters and active_adapters" + + assert list(self.adapter_configs.values())[0] and all( + list(self.adapter_configs.values())[i].target_modules + == list(self.adapter_configs.values())[0].target_modules + for i in range(self.max_num_adapters) + ), "Not all adapters have the same target modules" + + assert list(self.adapter_configs.values())[0] and all( + list(self.adapter_configs.values())[i].r == list(self.adapter_configs.values())[0].r + for i in range(self.max_num_adapters) + ), "Not all adapters have the same ranks" + self.lora_rank = list(self.adapter_configs.values())[0].r + + # do the module replacement + _, transformed = LoraModelInputsTransform.apply(self.model) + + self.target_modules_for_all_adapters = list(self.adapter_configs.values())[0].target_modules + _, transformed = TargetModulesTransform.apply( + self.model, self.target_modules_for_all_adapters, self.lora_rank, self.max_num_adapters + ) + + # load_weight to model + self.load_adapter_weights_to_model() + + def export(self, **kwargs) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + The model should already be transformed i.e. ``self.is_transformed`` should be ``True``. + Otherwise, this will raise an ``AssertionError``. + We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." + + ``Optional`` Args: + does not any arguments. + + Raises: + :AttributeError: If ``pretrained_model_name_or_path`` is a path, this function needs model card name of the model so that it can distinguish between directories while saving the ``ONNX`` files generated. So, user needs to pass ``model_card_name`` as a valid ``string`` in that case, Otherwise this will raise the error. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + + self.full_batch_size = kwargs.get("full_batch_size", self.full_batch_size) + export_dir = kwargs.get("export_dir", None) + + # obtain all necessary information to initialize the model + self.init_adapter_model() + + assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object" + + # Caching export onnx + if export_dir is None: + model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) + export_dir = Path(model_card_dir).with_name(str(self.model_card_name).split("/")[1] + "-" + self.model_hash) + else: + export_dir = Path(export_dir).with_name(export_dir.name + "-" + self.model_hash) + onnx_dir_path = os.path.join(export_dir, "onnx") + model_base_name = self.model_card_name.replace("/", "_") + "_kv" + onnx_path = os.path.join(onnx_dir_path, f"{model_base_name}.onnx") + + if Path(onnx_path).is_file(): + self.onnx_path = onnx_path + print(f"Using existing onnx path:-{self.onnx_path}") + return self.onnx_path + + # Export + os.makedirs(onnx_dir_path, exist_ok=True) + _, onnx_model_path = QEfficient.export( + model_name=self.model_card_name, + model_kv=self, + tokenizer=self.tokenizer, + full_batch_size=self.full_batch_size, + max_num_adapters=self.max_num_adapters, + onnx_dir_path=onnx_dir_path, + ) + self.onnx_path = onnx_model_path + + return self.onnx_path + + def export_and_compile( + self, + num_cores: int, + device_group: List[int], + batch_size: int = 1, + prompt_len: int = 32, + ctx_len: int = 128, + mxfp6: bool = True, + mxint8: bool = False, + mos: int = -1, + aic_enable_depth_first: bool = False, + qpc_dir_suffix: Optional[str] = None, + full_batch_size: Optional[int] = None, + ) -> str: + """ + This API is specific to Internal VLLM use-case and is not recommended to be used in your application unless your are using VLLM. + """ + _, transformed = CBTransform.apply(self.model) + if not transformed: + raise RuntimeError("Could not apply Continuous batch transform on the model") + if full_batch_size is not None: + self.full_batch_size = full_batch_size + + self.export() + + qpc_base_dir_name = get_qpc_dir_path( + model_card_name=self.model_card_name, + num_cores=num_cores, + mos=mos, + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + mxfp6=mxfp6, + mxint8=mxint8, + device_group=device_group, + full_batch_size=self.full_batch_size, + ) + + # Caching compiled qpc + model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) + export_dir = Path(model_card_dir).with_name(str(self.model_card_name).split("/")[1] + "-" + self.model_hash) + qpc_dir_path = qpc_base_dir_name.replace(model_card_dir, str(export_dir)) + qpc_path = os.path.join(qpc_dir_path, "qpcs") + + if not qpc_exists(qpc_path): + # Compile + self.qpc_path = QEfficient.compile( + onnx_path=self.onnx_path, + qpc_path=qpc_dir_path, + num_cores=num_cores, + device_group=device_group, + aic_enable_depth_first=aic_enable_depth_first, + mos=mos, + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + mxfp6=mxfp6, + mxint8=mxint8, + full_batch_size=full_batch_size, + ) + print(f"Generated qpc:-{qpc_path}") + else: + self.qpc_path = qpc_path + print(f"Using existing qpc path:-{self.qpc_path}") + + return self.qpc_path + + def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs): + assert isinstance(self.qpc_path, str), "Please run compile API first!" + generation_len = kwargs.pop("generation_len", None) + default_mapping = [INTMAX for _ in range(len(prompts))] + prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", default_mapping) + return QEfficient.cloud_ai_100_exec_kv( + self.tokenizer, + self.qpc_path, + prompt=prompts, + device_id=device_id, + generation_len=generation_len, + full_batch_size=self.full_batch_size, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) diff --git a/QEfficient/lora/layers.py b/QEfficient/lora/layers.py new file mode 100644 index 000000000..49694aacf --- /dev/null +++ b/QEfficient/lora/layers.py @@ -0,0 +1,65 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from QEfficient.customop import CtxGatherFuncCB + + +class LinearMultiLoRA(nn.Linear): + def multilora_init(self, lora_rank, max_num_adapters): + self.max_num_adapters = max_num_adapters + self.lora_rank = lora_rank + + self.lora_weight_A = nn.Parameter( + self.weight.new_zeros(self.max_num_adapters, 1, self.in_features, self.lora_rank) + ) + self.lora_weight_A.requires_grad = False + self.lora_weight_B = nn.Parameter( + self.weight.new_zeros(self.max_num_adapters, 1, self.lora_rank, self.out_features) + ) + self.lora_weight_B.requires_grad = False + self.lora_weight_C = torch.full((self.max_num_adapters, 1, 1, 1), 1.0, dtype=torch.float) + + nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_weight_B) + + def forward(self, x: torch.Tensor, **kwargs: Any): + lora_ids = kwargs.pop("lora_ids", torch.zeros((x.shape[0]), dtype=torch.int64).view(-1, 1)) + + with torch.no_grad(): + result = F.linear(x, self.weight, bias=self.bias) + + # multilora implementation: lora_ids + other_indices_A = torch.arange(self.lora_weight_A.shape[2]).view(1, 1, -1) + A_embedding = CtxGatherFuncCB.apply( + self.lora_weight_A, lora_ids, other_indices_A + ) # + other_indices_B = torch.arange(self.lora_weight_B.shape[2]).view(1, 1, -1) + B_embedding = CtxGatherFuncCB.apply( + self.lora_weight_B, lora_ids, other_indices_B + ) # + other_indices_C = torch.arange(self.lora_weight_C.shape[2]).view(1, 1, -1) + C_embedding = CtxGatherFuncCB.apply(self.lora_weight_C, lora_ids, other_indices_C) # + + A_embedding = A_embedding.squeeze(1) + B_embedding = B_embedding.squeeze(1) + C_embedding = C_embedding.squeeze(1) + + result = result + x @ A_embedding @ B_embedding * C_embedding + + return result + + +class LinearBase(nn.Linear): + def forward(self, x: torch.Tensor, **kwargs: Any): + return super().forward(x) diff --git a/QEfficient/lora/lora_model.py b/QEfficient/lora/lora_model.py new file mode 100644 index 000000000..456d3fdde --- /dev/null +++ b/QEfficient/lora/lora_model.py @@ -0,0 +1,88 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import List, Optional, Tuple, Union + +import torch +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) + +from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM +from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralForCausalLM + + +class QEffLoraModelMistralForCausalLM(QEffMistralForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + lora_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + kwargs["lora_ids"] = lora_ids + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + +class QEffLoraModelLlamaForCausalLM(QEffLlamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + lora_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + kwargs["lora_ids"] = lora_ids + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) diff --git a/QEfficient/lora/pytorch_transforms.py b/QEfficient/lora/pytorch_transforms.py new file mode 100644 index 000000000..db70a984d --- /dev/null +++ b/QEfficient/lora/pytorch_transforms.py @@ -0,0 +1,53 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Dict, Optional, Tuple + +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.lora.layers import LinearBase, LinearMultiLoRA +from QEfficient.lora.lora_model import QEffLoraModelLlamaForCausalLM, QEffLoraModelMistralForCausalLM +from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM +from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralForCausalLM + + +class LoraModelInputsTransform(ModuleMappingTransform): + _module_mapping = { + QEffMistralForCausalLM: QEffLoraModelMistralForCausalLM, + QEffLlamaForCausalLM: QEffLoraModelLlamaForCausalLM, + } + + +class TargetModulesTransform(ModuleMappingTransform): + _module_mapping = {nn.Linear: LinearMultiLoRA} + + _module_mapping_nontarget = {nn.Linear: LinearBase} + + # whole set of supported target modules for now (make sure **kwargs are passed in on modeling file) + all_modules = {"q_proj", "k_proj", "v_proj", "o_proj"} + + # a class method that deals with target module names + @classmethod + def apply( + cls, model: nn.Module, target_modules: Optional[Dict], lora_rank: int, max_num_adapters: int + ) -> Tuple[nn.Module, bool]: + transformed = False + nontarget_modules = {key for key in cls.all_modules if key not in target_modules} + + for name, module in model.named_modules(): + if repl_module := cls._module_mapping.get(type(module)): + if name.split(".")[-1] in target_modules: + module.__class__ = repl_module + if hasattr(module, "multilora_init"): + module.multilora_init(lora_rank, max_num_adapters) + transformed = True + elif name.split(".")[-1] in nontarget_modules: + module.__class__ = cls._module_mapping_nontarget.get(type(module)) + transformed = True + + return model, transformed diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 4a1870380..679b4a2f9 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -168,6 +168,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -190,9 +191,9 @@ def forward( value_states = torch.cat(value_states, dim=-1) else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -244,7 +245,7 @@ def forward( o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output, **kwargs) if not output_attentions: attn_weights = None @@ -273,6 +274,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -318,6 +320,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT @@ -374,6 +377,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -403,6 +407,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -443,6 +448,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -515,6 +521,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) else: layer_outputs = decoder_layer( @@ -525,6 +532,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ae913b42d..9fc71dc02 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -136,12 +136,13 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -187,7 +188,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output, **kwargs) if not output_attentions: attn_weights = None @@ -215,6 +216,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -294,6 +296,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] @@ -413,6 +416,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -439,7 +443,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -459,6 +462,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Cast to int32 to avoid ONNXRT issue diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index c45cfec41..98f45ac0a 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +from typing import Optional import numpy as np import torch @@ -12,7 +13,17 @@ class InputHandler: - def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): + def __init__( + self, + batch_size, + tokenizer, + config, + prompt, + prompt_len, + ctx_len, + full_batch_size, + max_num_adapters: Optional[int] = None, + ): """ Initialization @@ -32,6 +43,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size + self.max_num_adapters = max_num_adapters self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -76,6 +88,11 @@ def prepare_pytorch_inputs(self): inputs["position_ids"] = torch.arange(input_len).view(1, input_len) inputs["batch_index"] = torch.arange(1).view(-1, 1) + # lora_ids for prefill + if self.max_num_adapters: + lora_ids = torch.zeros((1), dtype=torch.int64).view(-1, 1) + inputs["lora_ids"] = lora_ids + past_key_values = [] for i in range(self.n_layer): past_key = torch.zeros((self.padding_shape), dtype=torch.float32) @@ -119,6 +136,10 @@ def update_pytorch_inputs(self, inputs, pt_outputs): [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] ) + if self.max_num_adapters: + lora_ids = torch.zeros((self.full_batch_size), dtype=torch.int64).view(-1, 1) + updated_inputs["lora_ids"] = lora_ids + return updated_inputs def prepare_ort_inputs(self): diff --git a/docs/source/hl_api.md b/docs/source/hl_api.md index 8ddc65ca7..798157be0 100644 --- a/docs/source/hl_api.md +++ b/docs/source/hl_api.md @@ -16,6 +16,12 @@ :members: ``` +## `QEffAutoLoraModelForCausalLM` +```{eval-rst} +.. autoclass:: QEfficient.lora.auto.QEffAutoLoraModelForCausalLM + :members: +``` + ## `export` ```{eval-rst} .. automodule:: QEfficient.exporter.export_hf_to_cloud_ai_100 diff --git a/examples/lora_models.py b/examples/lora_models.py new file mode 100644 index 000000000..2c83374e5 --- /dev/null +++ b/examples/lora_models.py @@ -0,0 +1,118 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## This example works on continuous batching with different lora adapters in the same batch ## + +import sys + +from QEfficient import QEffAutoLoraModelForCausalLM + +INTMAX = sys.maxsize + +base_model_name = "mistralai/Mistral-7B-v0.1" +seq_len = 128 +ctx_len = 256 +full_batch_size = 4 +device_group = [0] + +## STEP 1 -- init base model + +# **Option1**: Download model weights from hugging face & Init it with QEffAuto model to apply QEff transforms +# model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) +# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) + +# **Option2**: Initialize the model using from_pretrained() method +qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name) + +## STEP 2 -- load adapter & set adapter +qeff_model.load_adapter("predibase/gsm8k", "gsm8k") +adapter_id_gsm8k = qeff_model.set_adapter("gsm8k") +print(f"Activating gsm8k as adapter_id {adapter_id_gsm8k}") + +qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") +adapter_id_tldr = qeff_model.set_adapter("tldr_content_gen") +print(f"Activating tldr_content_gen as adapter_id {adapter_id_tldr}") + +# STEP 2 (optional) -- delete adapter & unload adapter +qeff_model.load_adapter("predibase/dbpedia", "dbpedia") +adapter_id_dbpedia = qeff_model.set_adapter("dbpedia") +print(f"Activating dbpedia as adapter_id {adapter_id_dbpedia}") + +delete_status = qeff_model.delete_adapter("dbpedia") +print(f"Deleting dbpedia success: {delete_status}") +unload_status = qeff_model.unload_adapter("dbpedia") +print(f"Unloading dbpedia success: {unload_status}") + +# get adapter id +# NOTE: should rely on get_adapter_id in case the id obtained at set_adpater() get updated +gsm8k_id = qeff_model.get_adapter_id("gsm8k") +tldr_id = qeff_model.get_adapter_id("tldr_content_gen") + +## STEP 3 -- export & compile qeff model +args = { + "num_cores": 16, + "device_group": device_group, + "batch_size": 1, + "prompt_len": seq_len, + "ctx_len": ctx_len, + "mxfp6": True, + "mxint8": True, + "mos": -1, + "aic_enable_depth_first": True, + "qpc_dir_suffix": None, + "full_batch_size": full_batch_size, +} +qpc_path = qeff_model.export_and_compile(**args) + +## STEP 4 -- run inference on the generate function +# prompt_to_lora_id_mapping is a list of lora_id of which the size matches num of prompts +# and is a one-on-one mapping for the prompt-to-loraid +# e.g., prompt_to_lora_id_mapping = [{adapter_id_0}, {adapter_id_1}, {adapter_id_0}, {adapter_id_1}, ...] +# setting INTMAX means using base model +prompts = [ + """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""", + """Please answer the following question: Gene is sewing a quilt out of old souvenir t-shirts. He has one shirt from each vacation he has been on. Every shirt is its own quilt block. Each row is made of blocks from a different year of vacations. He goes on four vacations a year and has been vacationing since he was 23 years old. He is now 34. How many quilt blocks does he have in total?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: New neurons for life? Old people can still make fresh brain cells, study finds\n\nContent:""", + """Please answer the following question: Harry slept 9 hours last night. His friend James slept only 2/3 of what Harry slept. How many more hours did Harry sleep than James?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Latest success from Google’s AI group: Controlling a fusion reactor\n\nContent:""", + """Please answer the following question: Gene is sewing a quilt out of old souvenir t-shirts. He has one shirt from each vacation he has been on. Every shirt is its own quilt block. Each row is made of blocks from a different year of vacations. He goes on four vacations a year and has been vacationing since he was 23 years old. He is now 34. How many quilt blocks does he have in total?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: TikTok Picks Streaming Service Audius to Power New ‘Sounds’ Library\n\nContent:""", +] +qeff_model.generate( + prompts, + device_group, + prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, INTMAX, gsm8k_id, tldr_id, gsm8k_id, tldr_id], +) + +""" +expected response: + +He runs 3*3=<<3*3=9>>9 sprints a week +So he runs 9*60=<<9*60=540>>540 meters a week +#### 540 + +Researchers at Harvard have created a microrobot that is smaller than a penny. The robot is made of a flexible polymer that can be folded and unfolded to move. It is powered by a laser and can be controlled by a computer. The robot is able to move on its own, but it can also be controlled remotely. It can be used to deliver drugs or to perform other tasks. A 1-minute video that shows the robot in action is available in the article. + +He has been on 34-23=<<34-23=11>>11 vacations +He has 11*4=<<11*4=44>>44 blocks +#### 44 + +A study has found that the human brain can continue to make new neurons throughout life. The study was conducted on 12 people aged 18 to 79. It found that the brains of older people had more new neurons were found in the hippocampus, a part of the brain that is important for memory. The study suggests that the brain may be able to compensate for age-related memory loss. + +James slept 2/3 * 9 = <<2/3*9=6>>6 hours. +Harry slept 9 - 6 = <<9-6=3>>3 hours more than James. +#### 3 + +He has been on 34-23=<<34-23=11>>11 vacations. +He has 11*4=<<11*4=44>>44 blocks. +#### 44 + +AI group has developed a system that can control a fusion reactor. The system uses a deep reinforcement learning + +TikTok has partnered with Audius to power its new Sounds library. The Sounds library will allow users to discover and share sounds from a wide range of creators. Audius is a music streaming platform that allows artists to upload their music and share it with fans. It has a community of over 1.5 million users. TikTok has been working on the Sounds library for over a year. The library will be available in the US, Canada, and Australia. +""" diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py new file mode 100644 index 000000000..2951cb306 --- /dev/null +++ b/tests/lora/test_lora_model.py @@ -0,0 +1,216 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import sys +from pathlib import Path +from time import perf_counter + +import numpy as np +import pytest +from peft import LoraConfig +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient import QEffAutoLoraModelForCausalLM + +INTMAX = sys.maxsize + +configs = [ + pytest.param( + AutoConfig.for_model( + "llama", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128 + ), + LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=8), + id="llama-2l-4h-2kvh-128d-qv", + ), + pytest.param( + AutoConfig.for_model( + "mistral", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128 + ), + LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=6), + id="mistral-2l-4h-128d-qv", + ), +] + +model_samples = [ + pytest.param("mistralai/Mistral-7B-v0.1", "predibase/gsm8k", "predibase/dbpedia"), + pytest.param( + "meta-llama/Meta-Llama-3-8B", + "hallisky/lora-type-narrative-llama-3-8b", + "hallisky/lora-grade-elementary-llama-3-8b", + ), +] + + +def create_lora_base_model(base_config): + base_model = AutoModelForCausalLM.from_config(base_config, attn_implementation="eager") + lora_base_model = QEffAutoLoraModelForCausalLM( + base_model, pretrained_model_name_or_path=str(base_config.model_type) + ) + + return lora_base_model + + +def load_adapter_with_random_weights(lora_base_model, adapter_config, adapter_name): + lora_base_model.adapter_configs[adapter_name] = adapter_config + lora_base_model.adapter_weights[adapter_name] = {"weights": np.ones((3, 3))} + + return lora_base_model + + +# test model initialization using __init__ approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapter_id_1): + model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + qeff_model = QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) + + assert qeff_model.base_model_name == base_model_name + assert len(qeff_model.adapter_weights) == 0 + assert len(qeff_model.adapter_configs) == 0 + assert len(qeff_model.active_adapters) == 0 + assert qeff_model.max_num_adapters == 0 + assert len(qeff_model.active_adapter_to_id) == 0 + + +# test model initialization using from_pretrained approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + assert qeff_model.base_model_name == base_model_name + assert len(qeff_model.adapter_weights) == 0 + assert len(qeff_model.adapter_configs) == 0 + assert len(qeff_model.active_adapters) == 0 + assert qeff_model.max_num_adapters == 0 + assert len(qeff_model.active_adapter_to_id) == 0 + + +# test model hash +def test_auto_lora_model_for_causal_lm_hash(): + base_config_0, adapter_config_0 = configs[0].values + base_config_1, adapter_config_1 = configs[1].values + + qeff_model_0 = create_lora_base_model(base_config_0) + qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_0, "adapter_0") + qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_1, "adapter_1") + + qeff_model_1 = create_lora_base_model(base_config_1) + qeff_model_1 = load_adapter_with_random_weights(qeff_model_1, adapter_config_0, "adapter_0") + qeff_model_1 = load_adapter_with_random_weights(qeff_model_1, adapter_config_1, "adapter_1") + + qeff_model_0_1 = create_lora_base_model(base_config_0) + qeff_model_0_1 = load_adapter_with_random_weights(qeff_model_0_1, adapter_config_0, "adapter_0") + qeff_model_0_1 = load_adapter_with_random_weights(qeff_model_0_1, adapter_config_1, "adapter_1") + qeff_model_0_1.set_adapter("adapter_0") + model_hash_0_1_0 = qeff_model_0_1.model_hash + qeff_model_0_1.set_adapter("adapter_1") + model_hash_0_1_1 = qeff_model_0_1.model_hash + + # check num of adapter config matters + qeff_model_0.set_adapter("adapter_0") + model_hash_0_0 = qeff_model_0.model_hash + qeff_model_0.set_adapter("adapter_1") + model_hash_0_1 = qeff_model_0.model_hash + assert model_hash_0_0 != model_hash_0_1 + + # check if same model, same adapter result in same hash + assert model_hash_0_1_0 == model_hash_0_0 + assert model_hash_0_1_1 == model_hash_0_1 + + # check base model configs matters + qeff_model_1.set_adapter("adapter_0") + qeff_model_1.set_adapter("adapter_1") + model_hash_1_1 = qeff_model_1.model_hash + assert model_hash_1_1 != model_hash_0_1 + + # check adapter orders matters + qeff_model_1.delete_adapter("adapter_0") + qeff_model_1.delete_adapter("adapter_1") + qeff_model_1.set_adapter("adapter_1") + qeff_model_1.set_adapter("adapter_0") + model_hash_1_2 = qeff_model_1.model_hash + assert model_hash_1_2 != model_hash_1_1 + + # check if different adapter config, but same adapter name matters + qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_0, "adapter_1") + model_hash_0_2 = qeff_model_0.model_hash + assert model_hash_0_2 != model_hash_0_1 + + +# test load_adapter() and set_adapter() and get_adapter_id() +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) +def test_auto_lora_model_for_causal_lm_load_set_adapter_id_check(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + qeff_model.load_adapter(adapter_id_0, "adapter_0") + qeff_model.load_adapter(adapter_id_1, "adapter_1") + set_id_0 = qeff_model.set_adapter("adapter_0") + set_id_1 = qeff_model.set_adapter("adapter_1") + + assert set_id_1 == set_id_0 + 1 + + qeff_model.load_adapter(adapter_id_1, "adapter_2") + qeff_model.set_adapter("adapter_2") + qeff_model.delete_adapter("adapter_1") + + update_id_0 = qeff_model.get_adapter_id("adapter_0") + update_id_2 = qeff_model.get_adapter_id("adapter_2") + assert set_id_0 == update_id_0 + assert set_id_1 == update_id_2 + + with pytest.raises(KeyError): + qeff_model.get_adapter_id("adapter_1") + + +# test unload_adapter() and delete_adapter() +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[1:]) +def test_auto_lora_model_for_causal_lm_unload_delete_adapter(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + qeff_model.load_adapter(adapter_id_0, "adapter_0") + qeff_model.load_adapter(adapter_id_1, "adapter_1") + qeff_model.set_adapter("adapter_0") + + assert not qeff_model.unload_adapter("adapter_0") # active adapter + assert qeff_model.unload_adapter("adapter_1") # valid unload + assert not qeff_model.unload_adapter("adapter_2") # not loaded adapter + + assert qeff_model.delete_adapter("adapter_0") # active adapter + assert not qeff_model.delete_adapter("adapter_1") # not active adapter + + +# test the export, export caching, compile, generate workflow +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) +def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + qeff_model.load_adapter(adapter_id_0, "adapter_0") + qeff_model.load_adapter(adapter_id_1, "adapter_1") + qeff_model.set_adapter("adapter_0") + qeff_model.set_adapter("adapter_1") + + # export + start = perf_counter() + qeff_model.export(export_dir=tmp_path, full_batch_size=1) # NOTE: should export with full_batch_size enabled + end = perf_counter() + export_time_0 = end - start + model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) + assert model_path.is_dir() + assert Path(qeff_model.onnx_path).is_file() + + # test export caching + start = perf_counter() + qeff_model.export(export_dir=tmp_path, full_batch_size=1) + end = perf_counter() + export_time_1 = end - start + assert export_time_1 < 0.01 * export_time_0 + + # test compile + qeff_model.compile(num_cores=16, device_group=[0]) + assert Path(qeff_model.qpc_path).is_dir() + + # test generate + prompts = ["hello!", "hi", "hello, my name is", "hey"] + qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[0, 1, 0, INTMAX]) From 0a406d5de9e7579df7aec05c749b467250a2946e Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Thu, 17 Oct 2024 13:36:45 -0700 Subject: [PATCH 2/9] Remove set delete adapter, add init assertion, update LinearMultiLoRA Signed-off-by: Jou-An Chen --- QEfficient/lora/auto.py | 76 ++++++++++++------ QEfficient/lora/layers.py | 35 ++++---- examples/lora_models.py | 15 ++-- tests/lora/test_lora_model.py | 145 ++++++++++++++++++---------------- 4 files changed, 149 insertions(+), 122 deletions(-) diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py index 06e2ca59c..947c0afb4 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/lora/auto.py @@ -9,7 +9,7 @@ import os import sys from pathlib import Path -from typing import List, Optional +from typing import Any, List, Optional import torch import torch.nn as nn @@ -61,6 +61,10 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): # inherit __init__() from QEFFAutoModelForCausalLM def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None: super().__init__(model, pretrained_model_name_or_path) + assert ( + type(self.model).__name__ == "QEffMistralForCausalLM" or type(self.model).__name__ == "QEffLlamaForCausalLM" + ), f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {type(self.model).__name__}" + self.base_model_name = pretrained_model_name_or_path self.adapter_weights = {} self.adapter_configs = {} @@ -84,13 +88,19 @@ def model_hash(self) -> str: active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict() mhash.update(to_hashable(active_adapter_configs)) + # create active adapter weight dict + active_adapter_weights = {} + for adpt in self.active_adapters: + active_adapter_weights[adpt] = {key: value.tolist() for key, value in self.adapter_weights[adpt].items()} + mhash.update(to_hashable(active_adapter_weights)) + # ensure model will be exported again if order of adapters changes mhash.update(to_hashable(self.active_adapter_to_id)) mhash = mhash.hexdigest()[:16] return mhash - def load_adapter(self, adapter_model_id: str, adapter_name: str): + def download_adapter(self, adapter_model_id: str, adapter_name: str): """Loads a new adapter from huggingface hub or local path into CPU cache Args: @@ -105,27 +115,26 @@ def load_adapter(self, adapter_model_id: str, adapter_name: str): } self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) - def unload_adapter(self, adapter_name: str): - if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys(): - print(f"Adapter name {adapter_name} is not loaded yet") - return False - - if adapter_name in self.active_adapters: - print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading") - return False + def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any): + "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters" - self.adapter_weights.pop(adapter_name) - self.adapter_configs.pop(adapter_name) - logger.warning(f"Unloading {adapter_name} from CPU cache.") - return True + # check if adapter name already exist, if so, overwrite it + if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()): + logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}") - def set_adapter(self, adapter_name: str): - "Sets active adapter from one of the loaded adapters" + adapter_weight = kwargs.pop("adapter_weight", None) + adapter_config = kwargs.pop("adapter_config", None) - assert (adapter_name in self.adapter_weights.keys()) and ( - adapter_name in self.adapter_configs.keys() - ), f"Adapter name {adapter_name} has not been loaded yet" + if adapter_weight and adapter_config: # if sufficiently get adapter weight and adpater config + self.adapter_weights[adapter_name] = adapter_weight + self.adapter_configs[adapter_name] = adapter_config + else: # load from hugging face + self.adapter_weights[adapter_name] = { + k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() + } + self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) + # check if adapters has same target module and rank assert ( list(self.adapter_configs.values())[0] and self.adapter_configs[adapter_name].target_modules @@ -137,16 +146,18 @@ def set_adapter(self, adapter_name: str): and self.adapter_configs[adapter_name].r == list(self.adapter_configs.values())[0].r ), "Not all adapters have the same ranks" - # set active adapter id to current max - self.active_adapter_to_id[adapter_name] = self.max_num_adapters + # set active adapter id to current max if adapter_name is new + if adapter_name not in self.active_adapter_to_id.keys(): + self.active_adapter_to_id[adapter_name] = self.max_num_adapters - # add active adapter to set - self.active_adapters.add(adapter_name) - self.max_num_adapters = len(self.active_adapters) + # add active adapter to set + self.active_adapters.add(adapter_name) + self.max_num_adapters = len(self.active_adapters) return self.active_adapter_to_id[adapter_name] - def delete_adapter(self, adapter_name: str): + def unload_adapter(self, adapter_name: str): + # remove from active list if adapter_name not in self.active_adapters: print(f"Adapter name {adapter_name} is not set active yet") return False @@ -162,6 +173,21 @@ def delete_adapter(self, adapter_name: str): logger.warning(f"Deleting {adapter_name} from active adapters.") if self.onnx_path or self.qpc_path: logger.warning("Please redo compile_and_export() to reflect the active adapters changes.") + self.onnx_path = None + self.qpc_path = None + + # delete from cache + if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys(): + print(f"Adapter name {adapter_name} is not loaded yet") + return False + + if adapter_name in self.active_adapters: + print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading") + return False + + self.adapter_weights.pop(adapter_name) + self.adapter_configs.pop(adapter_name) + logger.warning(f"Unloading {adapter_name} from CPU cache.") return True diff --git a/QEfficient/lora/layers.py b/QEfficient/lora/layers.py index 49694aacf..f726633c6 100644 --- a/QEfficient/lora/layers.py +++ b/QEfficient/lora/layers.py @@ -33,31 +33,24 @@ def multilora_init(self, lora_rank, max_num_adapters): nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_weight_B) - def forward(self, x: torch.Tensor, **kwargs: Any): - lora_ids = kwargs.pop("lora_ids", torch.zeros((x.shape[0]), dtype=torch.int64).view(-1, 1)) - - with torch.no_grad(): - result = F.linear(x, self.weight, bias=self.bias) + def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): + result = F.linear(x, self.weight, bias=self.bias) - # multilora implementation: lora_ids - other_indices_A = torch.arange(self.lora_weight_A.shape[2]).view(1, 1, -1) - A_embedding = CtxGatherFuncCB.apply( - self.lora_weight_A, lora_ids, other_indices_A - ) # - other_indices_B = torch.arange(self.lora_weight_B.shape[2]).view(1, 1, -1) - B_embedding = CtxGatherFuncCB.apply( - self.lora_weight_B, lora_ids, other_indices_B - ) # - other_indices_C = torch.arange(self.lora_weight_C.shape[2]).view(1, 1, -1) - C_embedding = CtxGatherFuncCB.apply(self.lora_weight_C, lora_ids, other_indices_C) # + # multilora implementation: lora_ids + other_indices_A = torch.arange(self.lora_weight_A.shape[2]).view(1, 1, -1) + A_embedding = CtxGatherFuncCB.apply(self.lora_weight_A, lora_ids, other_indices_A) # + other_indices_B = torch.arange(self.lora_weight_B.shape[2]).view(1, 1, -1) + B_embedding = CtxGatherFuncCB.apply(self.lora_weight_B, lora_ids, other_indices_B) # + other_indices_C = torch.arange(self.lora_weight_C.shape[2]).view(1, 1, -1) + C_embedding = CtxGatherFuncCB.apply(self.lora_weight_C, lora_ids, other_indices_C) # - A_embedding = A_embedding.squeeze(1) - B_embedding = B_embedding.squeeze(1) - C_embedding = C_embedding.squeeze(1) + A_embedding = A_embedding.squeeze(1) + B_embedding = B_embedding.squeeze(1) + C_embedding = C_embedding.squeeze(1) - result = result + x @ A_embedding @ B_embedding * C_embedding + result = result + x @ A_embedding @ B_embedding * C_embedding - return result + return result class LinearBase(nn.Linear): diff --git a/examples/lora_models.py b/examples/lora_models.py index 2c83374e5..64132ab87 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -28,22 +28,17 @@ # **Option2**: Initialize the model using from_pretrained() method qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name) -## STEP 2 -- load adapter & set adapter -qeff_model.load_adapter("predibase/gsm8k", "gsm8k") -adapter_id_gsm8k = qeff_model.set_adapter("gsm8k") +## STEP 2 -- load adapter adapter +adapter_id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k") print(f"Activating gsm8k as adapter_id {adapter_id_gsm8k}") -qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") -adapter_id_tldr = qeff_model.set_adapter("tldr_content_gen") +adapter_id_tldr = qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") print(f"Activating tldr_content_gen as adapter_id {adapter_id_tldr}") -# STEP 2 (optional) -- delete adapter & unload adapter -qeff_model.load_adapter("predibase/dbpedia", "dbpedia") -adapter_id_dbpedia = qeff_model.set_adapter("dbpedia") +adapter_id_dbpedia = qeff_model.load_adapter("predibase/dbpedia", "dbpedia") print(f"Activating dbpedia as adapter_id {adapter_id_dbpedia}") -delete_status = qeff_model.delete_adapter("dbpedia") -print(f"Deleting dbpedia success: {delete_status}") +# STEP 2 (optional) -- unload adapter unload_status = qeff_model.unload_adapter("dbpedia") print(f"Unloading dbpedia success: {unload_status}") diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py index 2951cb306..0c381d847 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/lora/test_lora_model.py @@ -53,13 +53,6 @@ def create_lora_base_model(base_config): return lora_base_model -def load_adapter_with_random_weights(lora_base_model, adapter_config, adapter_name): - lora_base_model.adapter_configs[adapter_name] = adapter_config - lora_base_model.adapter_weights[adapter_name] = {"weights": np.ones((3, 3))} - - return lora_base_model - - # test model initialization using __init__ approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapter_id_1): @@ -87,73 +80,98 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ assert len(qeff_model.active_adapter_to_id) == 0 +# test the init assertion for models that are not supported +@pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"]) +def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name): + model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + with pytest.raises(AssertionError): + QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) + + with pytest.raises(AssertionError): + QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + # test model hash def test_auto_lora_model_for_causal_lm_hash(): base_config_0, adapter_config_0 = configs[0].values base_config_1, adapter_config_1 = configs[1].values qeff_model_0 = create_lora_base_model(base_config_0) - qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_0, "adapter_0") - qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_1, "adapter_1") + qeff_model_0.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_0.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_0_0 = qeff_model_0.model_hash qeff_model_1 = create_lora_base_model(base_config_1) - qeff_model_1 = load_adapter_with_random_weights(qeff_model_1, adapter_config_0, "adapter_0") - qeff_model_1 = load_adapter_with_random_weights(qeff_model_1, adapter_config_1, "adapter_1") + qeff_model_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_1_0 = qeff_model_1.model_hash qeff_model_0_1 = create_lora_base_model(base_config_0) - qeff_model_0_1 = load_adapter_with_random_weights(qeff_model_0_1, adapter_config_0, "adapter_0") - qeff_model_0_1 = load_adapter_with_random_weights(qeff_model_0_1, adapter_config_1, "adapter_1") - qeff_model_0_1.set_adapter("adapter_0") + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) model_hash_0_1_0 = qeff_model_0_1.model_hash - qeff_model_0_1.set_adapter("adapter_1") - model_hash_0_1_1 = qeff_model_0_1.model_hash - - # check num of adapter config matters - qeff_model_0.set_adapter("adapter_0") - model_hash_0_0 = qeff_model_0.model_hash - qeff_model_0.set_adapter("adapter_1") - model_hash_0_1 = qeff_model_0.model_hash - assert model_hash_0_0 != model_hash_0_1 - # check if same model, same adapter result in same hash + # check if same model, same adapter config, same adapter weight, result in same hash assert model_hash_0_1_0 == model_hash_0_0 - assert model_hash_0_1_1 == model_hash_0_1 - # check base model configs matters - qeff_model_1.set_adapter("adapter_0") - qeff_model_1.set_adapter("adapter_1") - model_hash_1_1 = qeff_model_1.model_hash - assert model_hash_1_1 != model_hash_0_1 + # check if same model, same adapter config, but different weight, result in different hash + qeff_model_0_1.unload_adapter("adapter_1") + qeff_model_0_1.unload_adapter("adapter_0") + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.random.randn(3, 3)} + ) + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.random.randn(3, 3)} + ) + model_hash_0_1_1 = qeff_model_0_1.model_hash + assert model_hash_0_1_1 != model_hash_0_0 - # check adapter orders matters - qeff_model_1.delete_adapter("adapter_0") - qeff_model_1.delete_adapter("adapter_1") - qeff_model_1.set_adapter("adapter_1") - qeff_model_1.set_adapter("adapter_0") - model_hash_1_2 = qeff_model_1.model_hash - assert model_hash_1_2 != model_hash_1_1 + # check base model configs difference result in different hash + assert model_hash_0_0 != model_hash_1_0 - # check if different adapter config, but same adapter name matters - qeff_model_0 = load_adapter_with_random_weights(qeff_model_0, adapter_config_0, "adapter_1") - model_hash_0_2 = qeff_model_0.model_hash - assert model_hash_0_2 != model_hash_0_1 + # check different adapter orders, result in different hash + qeff_model_1.unload_adapter("adapter_0") + qeff_model_1.unload_adapter("adapter_1") + qeff_model_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_1_1 = qeff_model_1.model_hash + assert model_hash_1_1 != model_hash_1_0 + # check if same adapter name, but different config, result in different hash + qeff_model_0.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_0_1 = qeff_model_0.model_hash + assert model_hash_0_1 != model_hash_0_0 -# test load_adapter() and set_adapter() and get_adapter_id() + +# test load_adapter() and get_adapter_id() @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) -def test_auto_lora_model_for_causal_lm_load_set_adapter_id_check(base_model_name, adapter_id_0, adapter_id_1): +def test_auto_lora_model_for_causal_lm_load_get_adapter_id_check(base_model_name, adapter_id_0, adapter_id_1): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - qeff_model.load_adapter(adapter_id_0, "adapter_0") - qeff_model.load_adapter(adapter_id_1, "adapter_1") - set_id_0 = qeff_model.set_adapter("adapter_0") - set_id_1 = qeff_model.set_adapter("adapter_1") - + set_id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0") + set_id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1") assert set_id_1 == set_id_0 + 1 qeff_model.load_adapter(adapter_id_1, "adapter_2") - qeff_model.set_adapter("adapter_2") - qeff_model.delete_adapter("adapter_1") + qeff_model.unload_adapter("adapter_1") update_id_0 = qeff_model.get_adapter_id("adapter_0") update_id_2 = qeff_model.get_adapter_id("adapter_2") @@ -164,21 +182,18 @@ def test_auto_lora_model_for_causal_lm_load_set_adapter_id_check(base_model_name qeff_model.get_adapter_id("adapter_1") -# test unload_adapter() and delete_adapter() +# test download_adapter(), load_adapter() and unload_adapter() @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[1:]) -def test_auto_lora_model_for_causal_lm_unload_delete_adapter(base_model_name, adapter_id_0, adapter_id_1): +def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adapter_id_0, adapter_id_1): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - qeff_model.load_adapter(adapter_id_0, "adapter_0") - qeff_model.load_adapter(adapter_id_1, "adapter_1") - qeff_model.set_adapter("adapter_0") + qeff_model.download_adapter(adapter_id_0, "adapter_0") + qeff_model.download_adapter(adapter_id_1, "adapter_1") - assert not qeff_model.unload_adapter("adapter_0") # active adapter - assert qeff_model.unload_adapter("adapter_1") # valid unload - assert not qeff_model.unload_adapter("adapter_2") # not loaded adapter + qeff_model.load_adapter(adapter_id_0, "adapter_0") - assert qeff_model.delete_adapter("adapter_0") # active adapter - assert not qeff_model.delete_adapter("adapter_1") # not active adapter + assert not qeff_model.unload_adapter("adapter_1") # not active adapter + assert qeff_model.unload_adapter("adapter_0") # valid unload # test the export, export caching, compile, generate workflow @@ -186,10 +201,8 @@ def test_auto_lora_model_for_causal_lm_unload_delete_adapter(base_model_name, ad def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - qeff_model.load_adapter(adapter_id_0, "adapter_0") - qeff_model.load_adapter(adapter_id_1, "adapter_1") - qeff_model.set_adapter("adapter_0") - qeff_model.set_adapter("adapter_1") + id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0") + id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1") # export start = perf_counter() @@ -205,7 +218,7 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, qeff_model.export(export_dir=tmp_path, full_batch_size=1) end = perf_counter() export_time_1 = end - start - assert export_time_1 < 0.01 * export_time_0 + assert export_time_1 < export_time_0 # test compile qeff_model.compile(num_cores=16, device_group=[0]) @@ -213,4 +226,4 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, # test generate prompts = ["hello!", "hi", "hello, my name is", "hey"] - qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[0, 1, 0, INTMAX]) + qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, INTMAX]) From e38c62f2af19eb1b2b1ab03355978b5bbf164fa8 Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Thu, 17 Oct 2024 16:32:29 -0700 Subject: [PATCH 3/9] Fix base model inference index INTMAX issue Signed-off-by: Jou-An Chen --- .../generation/text_generation_inference.py | 8 +++--- QEfficient/lora/auto.py | 28 +++++++++++-------- QEfficient/lora/layers.py | 6 ++-- examples/lora_models.py | 9 +++--- tests/lora/test_lora_model.py | 5 +--- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index a624ce24c..bb556722c 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -657,10 +657,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): generated_id_current_index[decode_batch_id] += 1 - if self.prompt_to_lora_id_mapping_decode: - decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ - batch_id_map[decode_batch_id] - ] + if self.prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] return decode_pause_time diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py index 947c0afb4..263a5a7cb 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/lora/auto.py @@ -7,7 +7,6 @@ import hashlib import os -import sys from pathlib import Path from typing import Any, List, Optional @@ -24,8 +23,6 @@ from QEfficient.utils.constants import QEFF_MODELS_DIR from QEfficient.utils.logging_utils import logger -INTMAX = sys.maxsize - class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): """ @@ -54,7 +51,7 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): m.compile(num_cores=16, device_group=[0]) prompts=["code prompt", "math prompt", "generic"] - m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,INTMAX]) + m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,0]) """ @@ -148,7 +145,7 @@ def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any): # set active adapter id to current max if adapter_name is new if adapter_name not in self.active_adapter_to_id.keys(): - self.active_adapter_to_id[adapter_name] = self.max_num_adapters + self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1 # reserve 0 for base # add active adapter to set self.active_adapters.add(adapter_name) @@ -168,7 +165,7 @@ def unload_adapter(self, adapter_name: str): # renumbering of active adapter id for index, (key, value) in enumerate(self.active_adapter_to_id.items()): - self.active_adapter_to_id[key] = index + self.active_adapter_to_id[key] = index + 1 logger.warning(f"Deleting {adapter_name} from active adapters.") if self.onnx_path or self.qpc_path: @@ -203,9 +200,9 @@ def load_adapter_weights_to_model(self): for i in range(num_hidden_layers): for target_module in self.target_modules_for_all_adapters: # stack all adapters weights - a_tensor_list = list(range(self.max_num_adapters)) - b_tensor_list = list(range(self.max_num_adapters)) - c_tensor_list = list(range(self.max_num_adapters)) + a_tensor_list = list(range(self.max_num_adapters + 1)) + b_tensor_list = list(range(self.max_num_adapters + 1)) + c_tensor_list = list(range(self.max_num_adapters + 1)) for lora_name, lora_id in self.active_adapter_to_id.items(): if ( @@ -232,12 +229,18 @@ def load_adapter_weights_to_model(self): dtype=torch.float16, ) + # dummy zero tensor for base model + a_tensor_list[0] = torch.zeros_like(a_tensor_list[1]) + b_tensor_list[0] = torch.zeros_like(b_tensor_list[1]) + c_tensor_list[0] = torch.zeros_like(c_tensor_list[1]) + + # stack weight tensors stacked_lora_A = ( torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) - ) # + ) # stacked_lora_B = ( torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) - ) # + ) # stacked_lora_C = ( torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3) ) # @@ -308,6 +311,7 @@ def export(self, **kwargs) -> str: export_dir = kwargs.get("export_dir", None) # obtain all necessary information to initialize the model + assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" self.init_adapter_model() assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object" @@ -411,7 +415,7 @@ def export_and_compile( def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs): assert isinstance(self.qpc_path, str), "Please run compile API first!" generation_len = kwargs.pop("generation_len", None) - default_mapping = [INTMAX for _ in range(len(prompts))] + default_mapping = [0 for _ in range(len(prompts))] prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", default_mapping) return QEfficient.cloud_ai_100_exec_kv( self.tokenizer, diff --git a/QEfficient/lora/layers.py b/QEfficient/lora/layers.py index f726633c6..4adeaceaa 100644 --- a/QEfficient/lora/layers.py +++ b/QEfficient/lora/layers.py @@ -21,14 +21,14 @@ def multilora_init(self, lora_rank, max_num_adapters): self.lora_rank = lora_rank self.lora_weight_A = nn.Parameter( - self.weight.new_zeros(self.max_num_adapters, 1, self.in_features, self.lora_rank) + self.weight.new_zeros(self.max_num_adapters + 1, 1, self.in_features, self.lora_rank) ) self.lora_weight_A.requires_grad = False self.lora_weight_B = nn.Parameter( - self.weight.new_zeros(self.max_num_adapters, 1, self.lora_rank, self.out_features) + self.weight.new_zeros(self.max_num_adapters + 1, 1, self.lora_rank, self.out_features) ) self.lora_weight_B.requires_grad = False - self.lora_weight_C = torch.full((self.max_num_adapters, 1, 1, 1), 1.0, dtype=torch.float) + self.lora_weight_C = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float) nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_weight_B) diff --git a/examples/lora_models.py b/examples/lora_models.py index 64132ab87..14e7ad402 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -7,12 +7,9 @@ ## This example works on continuous batching with different lora adapters in the same batch ## -import sys from QEfficient import QEffAutoLoraModelForCausalLM -INTMAX = sys.maxsize - base_model_name = "mistralai/Mistral-7B-v0.1" seq_len = 128 ctx_len = 256 @@ -67,7 +64,7 @@ # prompt_to_lora_id_mapping is a list of lora_id of which the size matches num of prompts # and is a one-on-one mapping for the prompt-to-loraid # e.g., prompt_to_lora_id_mapping = [{adapter_id_0}, {adapter_id_1}, {adapter_id_0}, {adapter_id_1}, ...] -# setting INTMAX means using base model +# setting 0 means using base model prompts = [ """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""", """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""", @@ -81,9 +78,11 @@ qeff_model.generate( prompts, device_group, - prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, INTMAX, gsm8k_id, tldr_id, gsm8k_id, tldr_id], + prompt_to_lora_id_mapping=[0, 0, 0, 0, 0, 0, 0, 0], ) +# [gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id] + """ expected response: diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py index 0c381d847..3243fcba3 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/lora/test_lora_model.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- -import sys from pathlib import Path from time import perf_counter @@ -15,8 +14,6 @@ from QEfficient import QEffAutoLoraModelForCausalLM -INTMAX = sys.maxsize - configs = [ pytest.param( AutoConfig.for_model( @@ -226,4 +223,4 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, # test generate prompts = ["hello!", "hi", "hello, my name is", "hey"] - qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, INTMAX]) + qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0]) From aed1e67c6adfb9c3602ed71bb82844d3e578c085 Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Fri, 18 Oct 2024 10:57:00 -0700 Subject: [PATCH 4/9] Addressed review comments Signed-off-by: Jou-An Chen --- .../generation/text_generation_inference.py | 11 +- QEfficient/lora/auto.py | 183 ++++++++---------- QEfficient/lora/layers.py | 47 +++-- examples/lora_models.py | 25 ++- tests/lora/test_lora_model.py | 3 +- 5 files changed, 141 insertions(+), 128 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index bb556722c..ea9507ec6 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -645,6 +645,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.session.set_buffers({"logits": logits_out_placeholder}) decode_pause_time += perf_counter() - start + + if self.prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -657,11 +663,6 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): generated_id_current_index[decode_batch_id] += 1 - if self.prompt_to_lora_id_mapping_decode: - decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ - batch_id_map[decode_batch_id] - ] - return decode_pause_time def run_decode(self, decode_inputs, generation_len): diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py index 263a5a7cb..93b541b22 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/lora/auto.py @@ -8,7 +8,7 @@ import hashlib import os from pathlib import Path -from typing import Any, List, Optional +from typing import List, Optional import torch import torch.nn as nn @@ -26,7 +26,7 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): """ - QEff class for loading models with mutltiple LoRA adapters. + QEff class for loading models with multiple LoRA adapters. Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping. Args: @@ -34,7 +34,6 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): :base_model_name (str): Model card name for base model :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping - :active_adapters (Set): A set of lora_names that are currently active :max_num_adapters (int): Total number of active adapters that to be exported and compiled :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping @@ -65,7 +64,6 @@ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwarg self.base_model_name = pretrained_model_name_or_path self.adapter_weights = {} self.adapter_configs = {} - self.active_adapters = set() self.max_num_adapters = 0 self.active_adapter_to_id = {} @@ -81,13 +79,13 @@ def model_hash(self) -> str: # create active adapter config dict active_adapter_configs = {} - for adpt in self.active_adapters: + for adpt in self.active_adapter_to_id.keys(): active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict() mhash.update(to_hashable(active_adapter_configs)) # create active adapter weight dict active_adapter_weights = {} - for adpt in self.active_adapters: + for adpt in self.active_adapter_to_id.keys(): active_adapter_weights[adpt] = {key: value.tolist() for key, value in self.adapter_weights[adpt].items()} mhash.update(to_hashable(active_adapter_weights)) @@ -97,69 +95,78 @@ def model_hash(self) -> str: mhash = mhash.hexdigest()[:16] return mhash - def download_adapter(self, adapter_model_id: str, adapter_name: str): + def download_adapter( + self, + adapter_model_id: str, + adapter_name: str, + adapter_weight: Optional[dict] = None, + adapter_config: Optional[PeftConfig] = None, + ): """Loads a new adapter from huggingface hub or local path into CPU cache Args: :adapter_model_id (str): Adapter model ID from huggingface hub or local path :adapter_name (str): Adapter name to be used to set this adapter as current """ - if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()): - logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}") - self.adapter_weights[adapter_name] = { - k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() - } - self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) - - def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any): - "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters" - - # check if adapter name already exist, if so, overwrite it + # check if adapter name already loaded if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()): - logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}") - - adapter_weight = kwargs.pop("adapter_weight", None) - adapter_config = kwargs.pop("adapter_config", None) - - if adapter_weight and adapter_config: # if sufficiently get adapter weight and adpater config - self.adapter_weights[adapter_name] = adapter_weight - self.adapter_configs[adapter_name] = adapter_config - else: # load from hugging face - self.adapter_weights[adapter_name] = { - k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() - } - self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) - - # check if adapters has same target module and rank - assert ( - list(self.adapter_configs.values())[0] - and self.adapter_configs[adapter_name].target_modules - == list(self.adapter_configs.values())[0].target_modules - ), "Not all adapters have the same target modules" - - assert ( - list(self.adapter_configs.values())[0] - and self.adapter_configs[adapter_name].r == list(self.adapter_configs.values())[0].r - ), "Not all adapters have the same ranks" - - # set active adapter id to current max if adapter_name is new - if adapter_name not in self.active_adapter_to_id.keys(): - self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1 # reserve 0 for base + logger.warning(f"{adapter_name} has been loaded. Skip download.") + else: + if adapter_weight and adapter_config: # if sufficiently get adapter weight and adpater config + self.adapter_weights[adapter_name] = adapter_weight + self.adapter_configs[adapter_name] = adapter_config + else: # donwload with adapter_model_id + self.adapter_weights[adapter_name] = { + k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() + } + self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) + + def load_adapter( + self, + adapter_model_id: str, + adapter_name: str, + adapter_weight: Optional[dict] = None, + adapter_config: Optional[PeftConfig] = None, + ): + "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters" - # add active adapter to set - self.active_adapters.add(adapter_name) - self.max_num_adapters = len(self.active_adapters) + # check if adapter name already exist and activated + if adapter_name in self.active_adapter_to_id.keys(): + logger.warning(f"{adapter_name} exists and activated. Please provide a different adapter_name.") + else: + self.download_adapter(adapter_model_id, adapter_name, adapter_weight, adapter_config) + + # starting from the second adapter_name, check if adapters has same target module and rank + if list(self.adapter_configs.values())[0] and ( + self.adapter_configs[adapter_name].target_modules + != list(self.adapter_configs.values())[0].target_modules + ): + raise ValueError( + f"{adapter_name} must have same target_modules as {list(self.adapter_configs.keys())[0]}" + ) + if list(self.adapter_configs.values())[0] and ( + self.adapter_configs[adapter_name].r != list(self.adapter_configs.values())[0].r + ): + raise ValueError(f"{adapter_name} must have same rank as {list(self.adapter_configs.keys())[0]}") + + # set active adapter id to current max if adapter_name is new + if adapter_name not in self.active_adapter_to_id.keys(): + self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1 # reserve 0 for base + + # add active adapter to set + self.max_num_adapters = len(self.active_adapter_to_id) return self.active_adapter_to_id[adapter_name] def unload_adapter(self, adapter_name: str): - # remove from active list - if adapter_name not in self.active_adapters: - print(f"Adapter name {adapter_name} is not set active yet") + "Deactivate adpater and remove it from CPU cache" + + # step1: remove from active list if it's there + if adapter_name not in self.active_adapter_to_id.keys(): + logger.info(f"Adapter name {adapter_name} is not set active yet") return False - self.active_adapters.discard(adapter_name) self.max_num_adapters -= 1 self.active_adapter_to_id.pop(adapter_name) @@ -173,18 +180,11 @@ def unload_adapter(self, adapter_name: str): self.onnx_path = None self.qpc_path = None - # delete from cache - if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys(): - print(f"Adapter name {adapter_name} is not loaded yet") - return False - - if adapter_name in self.active_adapters: - print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading") - return False - - self.adapter_weights.pop(adapter_name) - self.adapter_configs.pop(adapter_name) - logger.warning(f"Unloading {adapter_name} from CPU cache.") + # step2: delete from cache + if adapter_name in self.adapter_weights.keys() and adapter_name in self.adapter_configs.keys(): + self.adapter_weights.pop(adapter_name) + self.adapter_configs.pop(adapter_name) + logger.warning(f"Unloading {adapter_name} from CPU cache.") return True @@ -202,15 +202,10 @@ def load_adapter_weights_to_model(self): # stack all adapters weights a_tensor_list = list(range(self.max_num_adapters + 1)) b_tensor_list = list(range(self.max_num_adapters + 1)) - c_tensor_list = list(range(self.max_num_adapters + 1)) + s_tensor_list = list(range(self.max_num_adapters + 1)) for lora_name, lora_id in self.active_adapter_to_id.items(): - if ( - target_module == "q_proj" - or target_module == "k_proj" - or target_module == "v_proj" - or target_module == "o_proj" - ): + if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]: a_tensor_list[lora_id] = torch.from_numpy( self.adapter_weights[lora_name][ f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_A.weight" @@ -224,7 +219,7 @@ def load_adapter_weights_to_model(self): else: raise NotImplementedError("Target module not supported!!") - c_tensor_list[lora_id] = torch.tensor( + s_tensor_list[lora_id] = torch.tensor( self.adapter_configs[lora_name].lora_alpha / self.adapter_configs[lora_name].r, dtype=torch.float16, ) @@ -232,17 +227,17 @@ def load_adapter_weights_to_model(self): # dummy zero tensor for base model a_tensor_list[0] = torch.zeros_like(a_tensor_list[1]) b_tensor_list[0] = torch.zeros_like(b_tensor_list[1]) - c_tensor_list[0] = torch.zeros_like(c_tensor_list[1]) + s_tensor_list[0] = torch.zeros_like(s_tensor_list[1]) # stack weight tensors - stacked_lora_A = ( + stacked_lora_a = ( torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) ) # - stacked_lora_B = ( + stacked_lora_b = ( torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) ) # - stacked_lora_C = ( - torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3) + stacked_lora_s = ( + torch.stack(s_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3) ) # # stored weight to corresponding ops @@ -257,26 +252,18 @@ def load_adapter_weights_to_model(self): else: raise NotImplementedError("Target module not supported!!") - module.lora_weight_A.copy_(stacked_lora_A) - module.lora_weight_B.copy_(stacked_lora_B) - module.lora_weight_C.copy_(stacked_lora_C) + module.lora_a_weights.copy_(stacked_lora_a) + module.lora_b_weights.copy_(stacked_lora_b) + module.lora_scalings.copy_(stacked_lora_s) def init_adapter_model(self): "Initialize the fixed lora model with multiple adapter weigths standby" # assume all adapters have same target_modules and ranks - assert self.max_num_adapters == len(self.active_adapters), "Inconsistent max_num_adapters and active_adapters" - - assert list(self.adapter_configs.values())[0] and all( - list(self.adapter_configs.values())[i].target_modules - == list(self.adapter_configs.values())[0].target_modules - for i in range(self.max_num_adapters) - ), "Not all adapters have the same target modules" - - assert list(self.adapter_configs.values())[0] and all( - list(self.adapter_configs.values())[i].r == list(self.adapter_configs.values())[0].r - for i in range(self.max_num_adapters) - ), "Not all adapters have the same ranks" + if self.max_num_adapters != len(self.active_adapter_to_id): + raise ValueError("Inconsistent max_num_adapters and active adapters") + + # set lora rank self.lora_rank = list(self.adapter_configs.values())[0].r # do the module replacement @@ -328,7 +315,7 @@ def export(self, **kwargs) -> str: if Path(onnx_path).is_file(): self.onnx_path = onnx_path - print(f"Using existing onnx path:-{self.onnx_path}") + logger.info(f"Using existing onnx path:-{self.onnx_path}") return self.onnx_path # Export @@ -405,14 +392,16 @@ def export_and_compile( mxint8=mxint8, full_batch_size=full_batch_size, ) - print(f"Generated qpc:-{qpc_path}") + logger.info(f"Generated qpc:-{qpc_path}") else: self.qpc_path = qpc_path - print(f"Using existing qpc path:-{self.qpc_path}") + logger.info(f"Using existing qpc path:-{self.qpc_path}") return self.qpc_path def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs): + "Execute on cloud ai 100 with prompt_to_lora_id_mapping passed in" + assert isinstance(self.qpc_path, str), "Please run compile API first!" generation_len = kwargs.pop("generation_len", None) default_mapping = [0 for _ in range(len(prompts))] diff --git a/QEfficient/lora/layers.py b/QEfficient/lora/layers.py index 4adeaceaa..f197eb7ea 100644 --- a/QEfficient/lora/layers.py +++ b/QEfficient/lora/layers.py @@ -17,38 +17,47 @@ class LinearMultiLoRA(nn.Linear): def multilora_init(self, lora_rank, max_num_adapters): + if lora_rank < 1 or max_num_adapters < 1: + raise ValueError("lora_rank and max_num_adapters must be greater or equal to 1") + self.max_num_adapters = max_num_adapters self.lora_rank = lora_rank - self.lora_weight_A = nn.Parameter( + self.lora_a_weights = nn.Parameter( self.weight.new_zeros(self.max_num_adapters + 1, 1, self.in_features, self.lora_rank) ) - self.lora_weight_A.requires_grad = False - self.lora_weight_B = nn.Parameter( + self.lora_a_weights.requires_grad = False + self.lora_b_weights = nn.Parameter( self.weight.new_zeros(self.max_num_adapters + 1, 1, self.lora_rank, self.out_features) ) - self.lora_weight_B.requires_grad = False - self.lora_weight_C = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float) + self.lora_b_weights.requires_grad = False + self.lora_scalings = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float) - nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_weight_B) + nn.init.kaiming_uniform_(self.lora_a_weights, a=math.sqrt(5)) + nn.init.zeros_(self.lora_b_weights) def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): result = F.linear(x, self.weight, bias=self.bias) # multilora implementation: lora_ids - other_indices_A = torch.arange(self.lora_weight_A.shape[2]).view(1, 1, -1) - A_embedding = CtxGatherFuncCB.apply(self.lora_weight_A, lora_ids, other_indices_A) # - other_indices_B = torch.arange(self.lora_weight_B.shape[2]).view(1, 1, -1) - B_embedding = CtxGatherFuncCB.apply(self.lora_weight_B, lora_ids, other_indices_B) # - other_indices_C = torch.arange(self.lora_weight_C.shape[2]).view(1, 1, -1) - C_embedding = CtxGatherFuncCB.apply(self.lora_weight_C, lora_ids, other_indices_C) # - - A_embedding = A_embedding.squeeze(1) - B_embedding = B_embedding.squeeze(1) - C_embedding = C_embedding.squeeze(1) - - result = result + x @ A_embedding @ B_embedding * C_embedding + other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) + selected_lora_a_weights = CtxGatherFuncCB.apply( + self.lora_a_weights, lora_ids, other_indices_a + ) # + other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) + selected_lora_b_weights = CtxGatherFuncCB.apply( + self.lora_b_weights, lora_ids, other_indices_b + ) # + other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) + selected_lora_scalings = CtxGatherFuncCB.apply( + self.lora_scalings, lora_ids, other_indices_s + ) # + + selected_lora_a_weights = selected_lora_a_weights.squeeze(1) + selected_lora_b_weights = selected_lora_b_weights.squeeze(1) + selected_lora_scalings = selected_lora_scalings.squeeze(1) + + result = result + x @ selected_lora_a_weights @ selected_lora_b_weights * selected_lora_scalings return result diff --git a/examples/lora_models.py b/examples/lora_models.py index 14e7ad402..7ee8e7f14 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -78,35 +78,50 @@ qeff_model.generate( prompts, device_group, - prompt_to_lora_id_mapping=[0, 0, 0, 0, 0, 0, 0, 0], + prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id], ) -# [gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id] """ expected response: +<1> He runs 3*3=<<3*3=9>>9 sprints a week So he runs 9*60=<<9*60=540>>540 meters a week #### 540 +<2> Researchers at Harvard have created a microrobot that is smaller than a penny. The robot is made of a flexible polymer that can be folded and unfolded to move. It is powered by a laser and can be controlled by a computer. The robot is able to move on its own, but it can also be controlled remotely. It can be used to deliver drugs or to perform other tasks. A 1-minute video that shows the robot in action is available in the article. +<3> He has been on 34-23=<<34-23=11>>11 vacations He has 11*4=<<11*4=44>>44 blocks #### 44 -A study has found that the human brain can continue to make new neurons throughout life. The study was conducted on 12 people aged 18 to 79. It found that the brains of older people had more new neurons were found in the hippocampus, a part of the brain that is important for memory. The study suggests that the brain may be able to compensate for age-related memory loss. +<4> +A new study has found that old people can still make fresh brain cells. The study was conducted by researchers at the University of California, San Francisco. They found that the brains of people in their 70s and 80s were still able brain cells +Content: + +A new study has found that the brain of an old person can still make new neurons. The study was conducted by a team of researchers from the University of California, Los Angeles. The team studied the brains that were able to make new neurons. The team found that the brains of these people were able to make new neurons in the hippocampus, which is the part of the brain that is responsible for memory and learning. The team also found that the brains of these people were able to make new neurons in the cortex, which is the part of the brain that is responsible for thinking and reasoning. The team also found that the brains of these people were able to make new neurons in the cerebellum, which + +<5> James slept 2/3 * 9 = <<2/3*9=6>>6 hours. Harry slept 9 - 6 = <<9-6=3>>3 hours more than James. #### 3 +<6> +'s AI group has developed a system that can control a fusion reactor. The system uses a deep reinforcement learning +He has been alive for 11 years, so he has been alive for 11 x 365 = 4,055 days. +He has been alive for 4,055 days, so he has been alive for 4,055 x 24 = 97,300 hours. +He has been alive for 97,300 hours, so he has been alive for 97,300 x 60 = 5,838,000 minutes. +He has been alive for 5,838,000 minutes, so he has been alive for 5,83 kennis + +<7> He has been on 34-23=<<34-23=11>>11 vacations. He has 11*4=<<11*4=44>>44 blocks. #### 44 -AI group has developed a system that can control a fusion reactor. The system uses a deep reinforcement learning - +<8> TikTok has partnered with Audius to power its new Sounds library. The Sounds library will allow users to discover and share sounds from a wide range of creators. Audius is a music streaming platform that allows artists to upload their music and share it with fans. It has a community of over 1.5 million users. TikTok has been working on the Sounds library for over a year. The library will be available in the US, Canada, and Australia. """ diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py index 3243fcba3..c7d4b4264 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/lora/test_lora_model.py @@ -59,7 +59,6 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 assert len(qeff_model.adapter_configs) == 0 - assert len(qeff_model.active_adapters) == 0 assert qeff_model.max_num_adapters == 0 assert len(qeff_model.active_adapter_to_id) == 0 @@ -72,7 +71,6 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 assert len(qeff_model.adapter_configs) == 0 - assert len(qeff_model.active_adapters) == 0 assert qeff_model.max_num_adapters == 0 assert len(qeff_model.active_adapter_to_id) == 0 @@ -151,6 +149,7 @@ def test_auto_lora_model_for_causal_lm_hash(): assert model_hash_1_1 != model_hash_1_0 # check if same adapter name, but different config, result in different hash + qeff_model_0.unload_adapter("adapter_1") qeff_model_0.load_adapter( "dummy_id", "adapter_1", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} ) From 522355a7f6382decf54f04ad8dc41262e70b9168 Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Tue, 5 Nov 2024 12:32:44 -0800 Subject: [PATCH 5/9] Rebase on PR116 and make API changes Signed-off-by: Jou-An Chen --- .../exporter/export_hf_to_cloud_ai_100.py | 16 +- QEfficient/exporter/export_utils.py | 2 - .../generation/text_generation_inference.py | 28 ++- QEfficient/lora/auto.py | 198 ++++++------------ QEfficient/utils/generate_inputs.py | 23 +- examples/lora_models.py | 50 +++-- tests/lora/test_lora_model.py | 26 ++- 7 files changed, 137 insertions(+), 206 deletions(-) diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index 5b2319edb..55f2ac3be 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -16,7 +16,6 @@ from QEfficient.base.common import AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP, QEFF_MODEL_TYPE, QEFFCommonLoader from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.exporter.export_utils import export_onnx, fix_onnx_fp16, generate_input_files, run_model_on_ort -from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM from QEfficient.transformers.modeling_utils import get_lists_of_cb_qeff_models from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.utils import load_hf_tokenizer @@ -149,7 +148,6 @@ def convert_to_cloud_kvstyle( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], onnx_dir_path: str, seq_len: int, - max_num_adapters: int, ) -> str: """ API to convert model with kv retention and export to ONNX. @@ -178,7 +176,7 @@ def convert_to_cloud_kvstyle( # Decide path for saving exported ONNX files. model_name = export_kvstyle_transformed_model_to_onnx( - model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len, max_num_adapters + model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len ) # type: ignore # return the model path for automation. @@ -192,7 +190,6 @@ def export_kvstyle_transformed_model_to_onnx( onnx_dir_path: str, seq_len: int, full_batch_size: Optional[int] = None, - max_num_adapters: Optional[int] = None, ) -> str: # Disabling requires_grad on all parameters for _, p in enumerate(transformed_model.parameters()): @@ -211,7 +208,6 @@ def export_kvstyle_transformed_model_to_onnx( prompt_len=Constants.PROMPT_LEN, ctx_len=seq_len, full_batch_size=full_batch_size, - max_num_adapters=max_num_adapters, ) inputs = input_handler.prepare_pytorch_inputs() @@ -319,7 +315,6 @@ def export_for_cloud( onnx_dir_path: str, seq_length: int = Constants.SEQ_LEN, full_batch_size: Optional[int] = None, - max_num_adapters: Optional[int] = None, ) -> str: # Check if model architecture is supported for continuous batching. if full_batch_size and qeff_model.model.config.architectures[0].lower() not in { @@ -330,10 +325,7 @@ def export_for_cloud( ) # FIXME: move all this to class instead of here, and just call qeff_model.export here. - if ( - AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM - or qeff_model.__class__ == QEffAutoLoraModelForCausalLM - ): # type: ignore + if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM: # type: ignore return export_lm_model_for_cloud( model_name=model_name, qeff_model=qeff_model, # type: ignore @@ -341,7 +333,6 @@ def export_for_cloud( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, - max_num_adapters=max_num_adapters, ) else: raise NotImplementedError( @@ -356,7 +347,6 @@ def export_lm_model_for_cloud( onnx_dir_path: str, seq_length: int, full_batch_size: Optional[int] = None, - max_num_adapters: Optional[int] = None, ) -> str: if os.path.exists(onnx_dir_path): logger.warning(f"Overriding {onnx_dir_path}") @@ -385,7 +375,6 @@ def qualcomm_efficient_converter( kv: bool = True, form_factor: str = "cloud", full_batch_size: Optional[int] = None, - max_num_adapters: Optional[int] = None, ) -> Tuple[str, str]: """ This method is an alias for ``QEfficient.export``. @@ -461,7 +450,6 @@ def qualcomm_efficient_converter( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, - max_num_adapters=max_num_adapters, ) return onnx_dir_path, generated_onnx_model_path else: diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py index 46a1082e2..d7da3ae04 100644 --- a/QEfficient/exporter/export_utils.py +++ b/QEfficient/exporter/export_utils.py @@ -83,8 +83,6 @@ def export_onnx( dynamic_axes[iname] = {0: dynamic_axis_past_key, 2: "ctx_len"} elif iname == "batch_index": dynamic_axes[iname] = {0: "batch_size"} - elif iname == "lora_ids": - dynamic_axes[iname] = {0: "batch_size"} if "past_key.0" in input_names and "attention_mask" in input_names: dynamic_axes["attention_mask"] = {0: "batch_size", 1: "ctx_len"} diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index ea9507ec6..54d9bcb81 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -230,7 +230,6 @@ def cloud_ai_100_exec_kv( stream: bool = True, write_io_dir: Optional[str] = None, automation=False, - full_batch_size: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): """ @@ -348,7 +347,10 @@ def __init__( if prompt_to_lora_id_mapping: self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping) - self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping + if self.full_batch_size: + self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping + else: + self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping) else: self.prompt_to_lora_id_mapping_prefill = None self.prompt_to_lora_id_mapping_decode = None @@ -472,9 +474,15 @@ def prepare_decode_inputs(self): if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index - if self.prompt_to_lora_id_mapping_decode and self.full_batch_size is not None: - first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)] - decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(self.full_batch_size, 1) + if self.prompt_to_lora_id_mapping_decode: + if self.full_batch_size: + first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)] + decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape( + self.full_batch_size, 1 + ) + else: + batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)] + decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) return decode_inputs @@ -565,9 +573,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["batch_index"] = decode_batch_id if self.prompt_to_lora_id_mapping_prefill: - inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape( - 1, 1 - ) + if self.full_batch_size: + inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape( + 1, 1 + ) + else: + batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] + inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) for i in range(num_chunks): chunk_inputs = inputs.copy() diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py index 93b541b22..495da57a1 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/lora/auto.py @@ -6,21 +6,19 @@ # ---------------------------------------------------------------------------- import hashlib -import os from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union import torch import torch.nn as nn from peft import PeftConfig, load_peft_weights +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast import QEfficient from QEfficient import QEFFAutoModelForCausalLM from QEfficient.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform -from QEfficient.transformers.pytorch_transforms import CBTransform -from QEfficient.utils import get_qpc_dir_path, qpc_exists +from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable -from QEfficient.utils.constants import QEFF_MODELS_DIR from QEfficient.utils.logging_utils import logger @@ -54,14 +52,13 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): """ - # inherit __init__() from QEFFAutoModelForCausalLM - def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None: - super().__init__(model, pretrained_model_name_or_path) + def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs) -> None: + super().__init__(model, continuous_batching) assert ( type(self.model).__name__ == "QEffMistralForCausalLM" or type(self.model).__name__ == "QEffLlamaForCausalLM" ), f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {type(self.model).__name__}" - self.base_model_name = pretrained_model_name_or_path + self.base_model_name = self.model.model.config._name_or_path self.adapter_weights = {} self.adapter_configs = {} self.max_num_adapters = 0 @@ -92,6 +89,9 @@ def model_hash(self) -> str: # ensure model will be exported again if order of adapters changes mhash.update(to_hashable(self.active_adapter_to_id)) + # noncb & cb should have different onnx & qpc + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash = mhash.hexdigest()[:16] return mhash @@ -277,141 +277,79 @@ def init_adapter_model(self): # load_weight to model self.load_adapter_weights_to_model() - def export(self, **kwargs) -> str: - """ - Exports the model to ``ONNX`` format using ``torch.onnx.export``. - The model should already be transformed i.e. ``self.is_transformed`` should be ``True``. - Otherwise, this will raise an ``AssertionError``. - We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." - - ``Optional`` Args: - does not any arguments. - - Raises: - :AttributeError: If ``pretrained_model_name_or_path`` is a path, this function needs model card name of the model so that it can distinguish between directories while saving the ``ONNX`` files generated. So, user needs to pass ``model_card_name`` as a valid ``string`` in that case, Otherwise this will raise the error. - - Returns: - :str: Path of the generated ``ONNX`` graph. - """ - - self.full_batch_size = kwargs.get("full_batch_size", self.full_batch_size) - export_dir = kwargs.get("export_dir", None) - - # obtain all necessary information to initialize the model + def export(self, export_dir: Optional[str] = None) -> str: + # initialize the adapter model assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" self.init_adapter_model() - assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object" - - # Caching export onnx - if export_dir is None: - model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) - export_dir = Path(model_card_dir).with_name(str(self.model_card_name).split("/")[1] + "-" + self.model_hash) - else: - export_dir = Path(export_dir).with_name(export_dir.name + "-" + self.model_hash) - onnx_dir_path = os.path.join(export_dir, "onnx") - model_base_name = self.model_card_name.replace("/", "_") + "_kv" - onnx_path = os.path.join(onnx_dir_path, f"{model_base_name}.onnx") - - if Path(onnx_path).is_file(): - self.onnx_path = onnx_path - logger.info(f"Using existing onnx path:-{self.onnx_path}") - return self.onnx_path - - # Export - os.makedirs(onnx_dir_path, exist_ok=True) - _, onnx_model_path = QEfficient.export( - model_name=self.model_card_name, - model_kv=self, - tokenizer=self.tokenizer, - full_batch_size=self.full_batch_size, - max_num_adapters=self.max_num_adapters, - onnx_dir_path=onnx_dir_path, + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + fbs = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len ) - self.onnx_path = onnx_model_path + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(bs, seq_len), + "past_key_values": [[] for _ in range(self.num_layers)], + "lora_ids": torch.zeros(bs, dtype=torch.int64).view(bs, 1), + } + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "lora_ids": {0: "batch_size"}, + } + output_names = ["logits"] + for i in range(self.num_layers): + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + dynamic_axes[f"past_{kv}.{i}"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + 2: "ctx_len", + } + output_names.append(f"past_{kv}.{i}_RetainedState") + + if self.continuous_batching: + example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + dynamic_axes["batch_index"] = {0: "batch_size"} - return self.onnx_path + return self._export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + ) - def export_and_compile( + def generate( self, - num_cores: int, - device_group: List[int], - batch_size: int = 1, - prompt_len: int = 32, - ctx_len: int = 128, - mxfp6: bool = True, - mxint8: bool = False, - mos: int = -1, - aic_enable_depth_first: bool = False, - qpc_dir_suffix: Optional[str] = None, - full_batch_size: Optional[int] = None, - ) -> str: + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + prompts: List[str], + device_id: List[int] = None, + runtime: str = "AI_100", + **kwargs, + ): """ - This API is specific to Internal VLLM use-case and is not recommended to be used in your application unless your are using VLLM. + This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. + This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed. + If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped. + + ``Mandatory`` Args: + :prompts (List[str]): List of prompts to run the execution. + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + ``optional`` Args: + :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100". """ - _, transformed = CBTransform.apply(self.model) - if not transformed: - raise RuntimeError("Could not apply Continuous batch transform on the model") - if full_batch_size is not None: - self.full_batch_size = full_batch_size - - self.export() - - qpc_base_dir_name = get_qpc_dir_path( - model_card_name=self.model_card_name, - num_cores=num_cores, - mos=mos, - batch_size=batch_size, - prompt_len=prompt_len, - ctx_len=ctx_len, - mxfp6=mxfp6, - mxint8=mxint8, - device_group=device_group, - full_batch_size=self.full_batch_size, - ) - - # Caching compiled qpc - model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) - export_dir = Path(model_card_dir).with_name(str(self.model_card_name).split("/")[1] + "-" + self.model_hash) - qpc_dir_path = qpc_base_dir_name.replace(model_card_dir, str(export_dir)) - qpc_path = os.path.join(qpc_dir_path, "qpcs") - - if not qpc_exists(qpc_path): - # Compile - self.qpc_path = QEfficient.compile( - onnx_path=self.onnx_path, - qpc_path=qpc_dir_path, - num_cores=num_cores, - device_group=device_group, - aic_enable_depth_first=aic_enable_depth_first, - mos=mos, - batch_size=batch_size, - prompt_len=prompt_len, - ctx_len=ctx_len, - mxfp6=mxfp6, - mxint8=mxint8, - full_batch_size=full_batch_size, - ) - logger.info(f"Generated qpc:-{qpc_path}") - else: - self.qpc_path = qpc_path - logger.info(f"Using existing qpc path:-{self.qpc_path}") - - return self.qpc_path - - def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs): - "Execute on cloud ai 100 with prompt_to_lora_id_mapping passed in" - - assert isinstance(self.qpc_path, str), "Please run compile API first!" + if runtime != "AI_100": + raise ValueError("Only AI_100 runtime is supported right now via generate API") + if not isinstance(self.qpc_path, Path): + raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) - default_mapping = [0 for _ in range(len(prompts))] - prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", default_mapping) + prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", [0 for _ in range(len(prompts))]) return QEfficient.cloud_ai_100_exec_kv( - self.tokenizer, + tokenizer, self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, - full_batch_size=self.full_batch_size, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, ) diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 98f45ac0a..c45cfec41 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- -from typing import Optional import numpy as np import torch @@ -13,17 +12,7 @@ class InputHandler: - def __init__( - self, - batch_size, - tokenizer, - config, - prompt, - prompt_len, - ctx_len, - full_batch_size, - max_num_adapters: Optional[int] = None, - ): + def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): """ Initialization @@ -43,7 +32,6 @@ def __init__( self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size - self.max_num_adapters = max_num_adapters self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -88,11 +76,6 @@ def prepare_pytorch_inputs(self): inputs["position_ids"] = torch.arange(input_len).view(1, input_len) inputs["batch_index"] = torch.arange(1).view(-1, 1) - # lora_ids for prefill - if self.max_num_adapters: - lora_ids = torch.zeros((1), dtype=torch.int64).view(-1, 1) - inputs["lora_ids"] = lora_ids - past_key_values = [] for i in range(self.n_layer): past_key = torch.zeros((self.padding_shape), dtype=torch.float32) @@ -136,10 +119,6 @@ def update_pytorch_inputs(self, inputs, pt_outputs): [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] ) - if self.max_num_adapters: - lora_ids = torch.zeros((self.full_batch_size), dtype=torch.int64).view(-1, 1) - updated_inputs["lora_ids"] = lora_ids - return updated_inputs def prepare_ort_inputs(self): diff --git a/examples/lora_models.py b/examples/lora_models.py index 7ee8e7f14..b5422a144 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -7,8 +7,8 @@ ## This example works on continuous batching with different lora adapters in the same batch ## - from QEfficient import QEffAutoLoraModelForCausalLM +from QEfficient.utils import load_hf_tokenizer base_model_name = "mistralai/Mistral-7B-v0.1" seq_len = 128 @@ -20,10 +20,15 @@ # **Option1**: Download model weights from hugging face & Init it with QEffAuto model to apply QEff transforms # model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) -# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) +# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, continuous_batching=True) # **Option2**: Initialize the model using from_pretrained() method -qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name) +qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=base_model_name, continuous_batching=True +) + +# (alternative) non-cb initialization +# qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name, continuous_batching=False) ## STEP 2 -- load adapter adapter adapter_id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k") @@ -45,20 +50,25 @@ tldr_id = qeff_model.get_adapter_id("tldr_content_gen") ## STEP 3 -- export & compile qeff model -args = { - "num_cores": 16, - "device_group": device_group, - "batch_size": 1, - "prompt_len": seq_len, - "ctx_len": ctx_len, - "mxfp6": True, - "mxint8": True, - "mos": -1, - "aic_enable_depth_first": True, - "qpc_dir_suffix": None, - "full_batch_size": full_batch_size, -} -qpc_path = qeff_model.export_and_compile(**args) +qpc_path = qeff_model.compile( + batch_size=1, + full_batch_size=full_batch_size, + prefill_seq_len=seq_len, + ctx_len=ctx_len, + num_devices=len(device_group), + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, +) + +# (alternative) non-cb compilation +# qpc_path = qeff_model.compile(batch_size=2, +# prefill_seq_len=seq_len, +# ctx_len=ctx_len, +# num_devices=len(device_group), +# num_cores=16, +# mxfp6_matmul=True, +# mxint8_kv_cache=True) ## STEP 4 -- run inference on the generate function # prompt_to_lora_id_mapping is a list of lora_id of which the size matches num of prompts @@ -75,9 +85,11 @@ """Please answer the following question: Gene is sewing a quilt out of old souvenir t-shirts. He has one shirt from each vacation he has been on. Every shirt is its own quilt block. Each row is made of blocks from a different year of vacations. He goes on four vacations a year and has been vacationing since he was 23 years old. He is now 34. How many quilt blocks does he have in total?\n\nAnswer:""", """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: TikTok Picks Streaming Service Audius to Power New ‘Sounds’ Library\n\nContent:""", ] + qeff_model.generate( - prompts, - device_group, + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), + prompts=prompts, + device_id=device_group, prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id], ) diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py index c7d4b4264..b0bbb5f18 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/lora/test_lora_model.py @@ -13,6 +13,7 @@ from transformers import AutoConfig, AutoModelForCausalLM from QEfficient import QEffAutoLoraModelForCausalLM +from QEfficient.utils import load_hf_tokenizer configs = [ pytest.param( @@ -43,9 +44,7 @@ def create_lora_base_model(base_config): base_model = AutoModelForCausalLM.from_config(base_config, attn_implementation="eager") - lora_base_model = QEffAutoLoraModelForCausalLM( - base_model, pretrained_model_name_or_path=str(base_config.model_type) - ) + lora_base_model = QEffAutoLoraModelForCausalLM(base_model) return lora_base_model @@ -53,8 +52,8 @@ def create_lora_base_model(base_config): # test model initialization using __init__ approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapter_id_1): - model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - qeff_model = QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) + model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) + qeff_model = QEffAutoLoraModelForCausalLM(model_hf) assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 @@ -66,7 +65,7 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt # test model initialization using from_pretrained approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): - qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name) assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 @@ -80,7 +79,7 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name): model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) with pytest.raises(AssertionError): - QEffAutoLoraModelForCausalLM(model_hf, pretrained_model_name_or_path=base_model_name) + QEffAutoLoraModelForCausalLM(model_hf) with pytest.raises(AssertionError): QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) @@ -202,7 +201,7 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, # export start = perf_counter() - qeff_model.export(export_dir=tmp_path, full_batch_size=1) # NOTE: should export with full_batch_size enabled + qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) @@ -211,15 +210,20 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, # test export caching start = perf_counter() - qeff_model.export(export_dir=tmp_path, full_batch_size=1) + qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_1 = end - start assert export_time_1 < export_time_0 # test compile - qeff_model.compile(num_cores=16, device_group=[0]) + qeff_model.compile(prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() # test generate prompts = ["hello!", "hi", "hello, my name is", "hey"] - qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0]) + qeff_model.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), + prompts=prompts, + device_id=[0], + prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0], + ) From 96ce832832bd0a51af7a6f16c4f252b21f40628d Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Tue, 12 Nov 2024 11:18:41 -0800 Subject: [PATCH 6/9] Enable init from QEffAutoPeftModelForCausalLM with finite_adapters flag Signed-off-by: Jou-An Chen --- QEfficient/__init__.py | 2 -- QEfficient/lora/auto.py | 45 ++++++++++++++++++++++------------- QEfficient/peft/auto.py | 16 +++++++++++-- examples/lora_models.py | 44 +++++++++++++--------------------- tests/lora/test_lora_model.py | 42 +++++++++++++------------------- 5 files changed, 76 insertions(+), 73 deletions(-) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 7adbbd6f7..0f7f40483 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -9,7 +9,6 @@ from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv -from QEfficient.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft import QEffAutoPeftModelForCausalLM from QEfficient.transformers.transform import transform @@ -25,6 +24,5 @@ "QEffAutoModel", "QEFFAutoModelForCausalLM", "QEffAutoPeftModelForCausalLM", - "QEffAutoLoraModelForCausalLM", "QEFFCommonLoader", ] diff --git a/QEfficient/lora/auto.py b/QEfficient/lora/auto.py index 495da57a1..a80388ad2 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/lora/auto.py @@ -24,8 +24,8 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): """ - QEff class for loading models with multiple LoRA adapters. - Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping. + QEff class for loading models with multiple LoRA adapters. Currently only Mistral and Llama model are supported. + Once exported and compiled, the qpc can perform mixed batch inference with provided `prompt_to_adapter_mapping`. Args: :model (nn.Module): PyTorch model @@ -34,21 +34,20 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping :max_num_adapters (int): Total number of active adapters that to be exported and compiled :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping + :lora_rank (int): The consistent lora rank across all active adapters + :target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters .. code-block:: python - from QEfficient import QEffAutoLoraModelForCausalLM + from QEfficient.lora import QEffAutoLoraModelForCausalLM m = QEffAutoPeftModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") m.load_adapter("predibase/gsm8k", "gsm8k") m.load_adapter("predibase/magicoder", "magicoder") - gsm8k_id = m.set_adapter("gsm8k") - magicoder_id = m.set_adapter("magicoder") - m.export(full_batch_size=3) m.compile(num_cores=16, device_group=[0]) prompts=["code prompt", "math prompt", "generic"] - m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,0]) + m.generate(prompts, device_group=[0], prompt_to_adapter_mapping=["magicoder","gsm8k_id","base"]) """ @@ -188,12 +187,10 @@ def unload_adapter(self, adapter_name: str): return True - def get_adapter_id(self, adapter_name): - "get the adapter_id that maps to the adapter_name" + def set_adapter(self, adapter_name: str): + raise NotImplementedError("Set adapter is not supported in finite_adapters mode") - return self.active_adapter_to_id[adapter_name] - - def load_adapter_weights_to_model(self): + def _load_adapter_weights_to_model(self): "Loads adapter weights to the model's multilora layer in a stacked format" num_hidden_layers = len(self.model.model.layers) @@ -256,7 +253,7 @@ def load_adapter_weights_to_model(self): module.lora_b_weights.copy_(stacked_lora_b) module.lora_scalings.copy_(stacked_lora_s) - def init_adapter_model(self): + def _init_adapter_model(self): "Initialize the fixed lora model with multiple adapter weigths standby" # assume all adapters have same target_modules and ranks @@ -275,12 +272,23 @@ def init_adapter_model(self): ) # load_weight to model - self.load_adapter_weights_to_model() + self._load_adapter_weights_to_model() def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." + + ``Optional`` Args: + does not any arguments. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + # initialize the adapter model assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" - self.init_adapter_model() + self._init_adapter_model() bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -338,18 +346,21 @@ def generate( :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model ``optional`` Args: :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100". + :prompt_to_adapter_mapping (List[str]): A list of adapter names that maps to the prompts, specifying which adapter the prompt wants to apply. "base" for base model (no adapter). """ if runtime != "AI_100": raise ValueError("Only AI_100 runtime is supported right now via generate API") if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) - prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", [0 for _ in range(len(prompts))]) + prompt_to_adapter_mapping = kwargs.pop("prompt_to_adapter_mapping", ["base" for _ in range(len(prompts))]) return QEfficient.cloud_ai_100_exec_kv( tokenizer, self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, - prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + prompt_to_lora_id_mapping=[ + self.active_adapter_to_id[name] if name != "base" else 0 for name in prompt_to_adapter_mapping + ], ) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 85a66c527..4acabf233 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -12,7 +12,7 @@ import numpy as np import torch -from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights +from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights from torch import nn from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList from transformers.generation.streamers import BaseStreamer @@ -21,6 +21,7 @@ from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform @@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel): Args: :model (nn.Module): PyTorch model + :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification. .. code-block:: python @@ -152,7 +154,17 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): if kwargs.get("use_cache") is False: warnings.warn("Overriding to use_cache=True") kwargs["use_cache"] = True - obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) + + if kwargs.pop("finite_adapters", False): # initialize through finite_adapters class + obj = QEffAutoLoraModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=PeftConfig.from_pretrained( + pretrained_name_or_path + ).base_model_name_or_path, + **kwargs, + ) + obj.load_adapter(pretrained_name_or_path, list(args)[0]) + else: + obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj def export(self, export_dir: Optional[str] = None) -> str: diff --git a/examples/lora_models.py b/examples/lora_models.py index b5422a144..596f784e6 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -7,7 +7,7 @@ ## This example works on continuous batching with different lora adapters in the same batch ## -from QEfficient import QEffAutoLoraModelForCausalLM +from QEfficient import QEffAutoPeftModelForCausalLM from QEfficient.utils import load_hf_tokenizer base_model_name = "mistralai/Mistral-7B-v0.1" @@ -17,37 +17,22 @@ device_group = [0] ## STEP 1 -- init base model - -# **Option1**: Download model weights from hugging face & Init it with QEffAuto model to apply QEff transforms -# model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) -# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, continuous_batching=True) - -# **Option2**: Initialize the model using from_pretrained() method -qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=base_model_name, continuous_batching=True +qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( + "predibase/gsm8k", "gsm8k", continuous_batching=True, finite_adapters=True ) -# (alternative) non-cb initialization -# qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name, continuous_batching=False) +# (alternative) non-cb compilation +# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True) ## STEP 2 -- load adapter adapter -adapter_id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k") -print(f"Activating gsm8k as adapter_id {adapter_id_gsm8k}") - -adapter_id_tldr = qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") -print(f"Activating tldr_content_gen as adapter_id {adapter_id_tldr}") +qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") -adapter_id_dbpedia = qeff_model.load_adapter("predibase/dbpedia", "dbpedia") -print(f"Activating dbpedia as adapter_id {adapter_id_dbpedia}") +qeff_model.load_adapter("predibase/dbpedia", "dbpedia") # STEP 2 (optional) -- unload adapter unload_status = qeff_model.unload_adapter("dbpedia") print(f"Unloading dbpedia success: {unload_status}") -# get adapter id -# NOTE: should rely on get_adapter_id in case the id obtained at set_adpater() get updated -gsm8k_id = qeff_model.get_adapter_id("gsm8k") -tldr_id = qeff_model.get_adapter_id("tldr_content_gen") ## STEP 3 -- export & compile qeff model qpc_path = qeff_model.compile( @@ -71,10 +56,6 @@ # mxint8_kv_cache=True) ## STEP 4 -- run inference on the generate function -# prompt_to_lora_id_mapping is a list of lora_id of which the size matches num of prompts -# and is a one-on-one mapping for the prompt-to-loraid -# e.g., prompt_to_lora_id_mapping = [{adapter_id_0}, {adapter_id_1}, {adapter_id_0}, {adapter_id_1}, ...] -# setting 0 means using base model prompts = [ """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""", """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""", @@ -90,7 +71,16 @@ tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), prompts=prompts, device_id=device_group, - prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id], + prompt_to_adapter_mapping=[ + "gsm8k", + "tldr_content_gen", + "gsm8k", + "base", + "gsm8k", + "tldr_content_gen", + "gsm8k", + "tldr_content_gen", + ], ) diff --git a/tests/lora/test_lora_model.py b/tests/lora/test_lora_model.py index b0bbb5f18..7d5cb65ed 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/lora/test_lora_model.py @@ -12,7 +12,8 @@ from peft import LoraConfig from transformers import AutoConfig, AutoModelForCausalLM -from QEfficient import QEffAutoLoraModelForCausalLM +from QEfficient import QEffAutoPeftModelForCausalLM +from QEfficient.lora import QEffAutoLoraModelForCausalLM from QEfficient.utils import load_hf_tokenizer configs = [ @@ -74,6 +75,18 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ assert len(qeff_model.active_adapter_to_id) == 0 +# test peft model initialization using from_pretrained approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(adapter_id_0, "id_0", finite_adapters=True) + + assert qeff_model.base_model_name == base_model_name + assert len(qeff_model.adapter_weights) == 1 + assert len(qeff_model.adapter_configs) == 1 + assert qeff_model.max_num_adapters == 1 + assert len(qeff_model.active_adapter_to_id) == 1 + + # test the init assertion for models that are not supported @pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"]) def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name): @@ -156,27 +169,6 @@ def test_auto_lora_model_for_causal_lm_hash(): assert model_hash_0_1 != model_hash_0_0 -# test load_adapter() and get_adapter_id() -@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) -def test_auto_lora_model_for_causal_lm_load_get_adapter_id_check(base_model_name, adapter_id_0, adapter_id_1): - qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - - set_id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0") - set_id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1") - assert set_id_1 == set_id_0 + 1 - - qeff_model.load_adapter(adapter_id_1, "adapter_2") - qeff_model.unload_adapter("adapter_1") - - update_id_0 = qeff_model.get_adapter_id("adapter_0") - update_id_2 = qeff_model.get_adapter_id("adapter_2") - assert set_id_0 == update_id_0 - assert set_id_1 == update_id_2 - - with pytest.raises(KeyError): - qeff_model.get_adapter_id("adapter_1") - - # test download_adapter(), load_adapter() and unload_adapter() @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[1:]) def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adapter_id_0, adapter_id_1): @@ -196,8 +188,8 @@ def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adap def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0") - id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1") + qeff_model.load_adapter(adapter_id_0, "adapter_0") + qeff_model.load_adapter(adapter_id_1, "adapter_1") # export start = perf_counter() @@ -225,5 +217,5 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), prompts=prompts, device_id=[0], - prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0], + prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"], ) From c7a10b8b8a564782af8e7f7e78d88187d56ccab9 Mon Sep 17 00:00:00 2001 From: Jou-An Chen Date: Thu, 14 Nov 2024 14:01:14 -0800 Subject: [PATCH 7/9] Address review comments Signed-off-by: Jou-An Chen --- QEfficient/peft/auto.py | 5 +- QEfficient/{ => peft}/lora/__init__.py | 2 +- QEfficient/{ => peft}/lora/auto.py | 88 +++++++++++-------- QEfficient/{ => peft}/lora/layers.py | 0 QEfficient/{ => peft}/lora/lora_model.py | 0 .../{ => peft}/lora/pytorch_transforms.py | 4 +- examples/lora_models.py | 20 +++-- tests/{ => peft}/lora/test_lora_model.py | 21 +++-- 8 files changed, 84 insertions(+), 56 deletions(-) rename QEfficient/{ => peft}/lora/__init__.py (83%) rename QEfficient/{ => peft}/lora/auto.py (82%) rename QEfficient/{ => peft}/lora/layers.py (100%) rename QEfficient/{ => peft}/lora/lora_model.py (100%) rename QEfficient/{ => peft}/lora/pytorch_transforms.py (92%) rename tests/{ => peft}/lora/test_lora_model.py (94%) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 4acabf233..d2c62ef88 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -21,7 +21,7 @@ from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.lora import QEffAutoLoraModelForCausalLM +from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform @@ -147,6 +147,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): """ Args: :pretrained_name_or_path (str): Model card name from huggingface or local path to model directory. + :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification. :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM. """ if kwargs.get("full_batch_size"): @@ -162,6 +163,8 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): ).base_model_name_or_path, **kwargs, ) + if len(args) == 0 or not isinstance(list(args)[0], str): + raise TypeError("Required adapter name argument in string format") obj.load_adapter(pretrained_name_or_path, list(args)[0]) else: obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) diff --git a/QEfficient/lora/__init__.py b/QEfficient/peft/lora/__init__.py similarity index 83% rename from QEfficient/lora/__init__.py rename to QEfficient/peft/lora/__init__.py index 75966ff66..361972ba7 100644 --- a/QEfficient/lora/__init__.py +++ b/QEfficient/peft/lora/__init__.py @@ -5,7 +5,7 @@ # # ---------------------------------------------------------------------------- -from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM +from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM __all__ = [ "QEffAutoLoraModelForCausalLM", diff --git a/QEfficient/lora/auto.py b/QEfficient/peft/lora/auto.py similarity index 82% rename from QEfficient/lora/auto.py rename to QEfficient/peft/lora/auto.py index a80388ad2..aa85aad7a 100644 --- a/QEfficient/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -16,7 +16,7 @@ import QEfficient from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform +from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable from QEfficient.utils.logging_utils import logger @@ -29,17 +29,11 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): Args: :model (nn.Module): PyTorch model - :base_model_name (str): Model card name for base model - :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping - :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping - :max_num_adapters (int): Total number of active adapters that to be exported and compiled - :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping - :lora_rank (int): The consistent lora rank across all active adapters - :target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters + :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. .. code-block:: python - from QEfficient.lora import QEffAutoLoraModelForCausalLM + from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM m = QEffAutoPeftModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") m.load_adapter("predibase/gsm8k", "gsm8k") @@ -53,14 +47,13 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs) -> None: super().__init__(model, continuous_batching) - assert ( - type(self.model).__name__ == "QEffMistralForCausalLM" or type(self.model).__name__ == "QEffLlamaForCausalLM" - ), f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {type(self.model).__name__}" + if self.model.__class__.__name__ not in ["QEffMistralForCausalLM", "QEffLlamaForCausalLM"]: + raise NotImplementedError( + f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {self.model.__class__.__name__}" + ) - self.base_model_name = self.model.model.config._name_or_path self.adapter_weights = {} self.adapter_configs = {} - self.max_num_adapters = 0 self.active_adapter_to_id = {} self.lora_rank = 0 @@ -101,11 +94,15 @@ def download_adapter( adapter_weight: Optional[dict] = None, adapter_config: Optional[PeftConfig] = None, ): - """Loads a new adapter from huggingface hub or local path into CPU cache + """ + Loads a new adapter from huggingface hub or local path into CPU cache - Args: + ``Mandatory`` Args: :adapter_model_id (str): Adapter model ID from huggingface hub or local path - :adapter_name (str): Adapter name to be used to set this adapter as current + :adapter_name (str): Adapter name to be used to downloaded this adapter + ``Optional`` Args: + :adapter_weight (dict): Adapter weight tensors in dictionary format + :adapter_config (PeftConfig): Adapter config in the format of PeftConfig """ # check if adapter name already loaded @@ -128,7 +125,16 @@ def load_adapter( adapter_weight: Optional[dict] = None, adapter_config: Optional[PeftConfig] = None, ): - "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters" + """ + Load adapter into CPU cache and set it as active + + ``Mandatory`` Args: + :adapter_model_id (str): Adapter model ID from huggingface hub or local path + :adapter_name (str): Adapter name to be used to load this adapter + ``Optional`` Args: + :adapter_weight (dict): Adapter weight tensors in dictionary format + :adapter_config (PeftConfig): Adapter config in the format of PeftConfig + """ # check if adapter name already exist and activated if adapter_name in self.active_adapter_to_id.keys(): @@ -151,22 +157,23 @@ def load_adapter( # set active adapter id to current max if adapter_name is new if adapter_name not in self.active_adapter_to_id.keys(): - self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1 # reserve 0 for base - - # add active adapter to set - self.max_num_adapters = len(self.active_adapter_to_id) + self.active_adapter_to_id[adapter_name] = len(self.active_adapter_to_id) + 1 # reserve 0 for base return self.active_adapter_to_id[adapter_name] def unload_adapter(self, adapter_name: str): - "Deactivate adpater and remove it from CPU cache" + """ + Deactivate adpater and remove it from CPU cache + + ``Mandatory`` Args: + :adapter_name (str): Adapter name to be unloaded + """ # step1: remove from active list if it's there if adapter_name not in self.active_adapter_to_id.keys(): logger.info(f"Adapter name {adapter_name} is not set active yet") return False - self.max_num_adapters -= 1 self.active_adapter_to_id.pop(adapter_name) # renumbering of active adapter id @@ -197,9 +204,9 @@ def _load_adapter_weights_to_model(self): for i in range(num_hidden_layers): for target_module in self.target_modules_for_all_adapters: # stack all adapters weights - a_tensor_list = list(range(self.max_num_adapters + 1)) - b_tensor_list = list(range(self.max_num_adapters + 1)) - s_tensor_list = list(range(self.max_num_adapters + 1)) + a_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) + b_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) + s_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) for lora_name, lora_id in self.active_adapter_to_id.items(): if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]: @@ -256,10 +263,6 @@ def _load_adapter_weights_to_model(self): def _init_adapter_model(self): "Initialize the fixed lora model with multiple adapter weigths standby" - # assume all adapters have same target_modules and ranks - if self.max_num_adapters != len(self.active_adapter_to_id): - raise ValueError("Inconsistent max_num_adapters and active adapters") - # set lora rank self.lora_rank = list(self.adapter_configs.values())[0].r @@ -268,7 +271,7 @@ def _init_adapter_model(self): self.target_modules_for_all_adapters = list(self.adapter_configs.values())[0].target_modules _, transformed = TargetModulesTransform.apply( - self.model, self.target_modules_for_all_adapters, self.lora_rank, self.max_num_adapters + self.model, self.target_modules_for_all_adapters, self.lora_rank, len(self.active_adapter_to_id) ) # load_weight to model @@ -287,7 +290,11 @@ def export(self, export_dir: Optional[str] = None) -> str: """ # initialize the adapter model - assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" + if len(self.active_adapter_to_id) == 0: + raise ValueError( + "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" + ) + self._init_adapter_model() bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE @@ -333,6 +340,7 @@ def generate( tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], prompts: List[str], device_id: List[int] = None, + prompt_to_adapter_mapping: List[str] = None, runtime: str = "AI_100", **kwargs, ): @@ -342,18 +350,28 @@ def generate( If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped. ``Mandatory`` Args: + :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference :prompts (List[str]): List of prompts to run the execution. :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + :prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts."base" for base model (no adapter). ``optional`` Args: :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100". - :prompt_to_adapter_mapping (List[str]): A list of adapter names that maps to the prompts, specifying which adapter the prompt wants to apply. "base" for base model (no adapter). + """ if runtime != "AI_100": raise ValueError("Only AI_100 runtime is supported right now via generate API") if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) - prompt_to_adapter_mapping = kwargs.pop("prompt_to_adapter_mapping", ["base" for _ in range(len(prompts))]) + + if not prompt_to_adapter_mapping: + prompt_to_adapter_mapping = ["base" for _ in range(len(prompts))] + + if len(prompt_to_adapter_mapping) != len(prompts): + raise RuntimeError( + f"Number of prompts should match number of prompt_to_adapter_mapping, got len(prompts) = {len(prompts)}, len(prompt_to_adapter_mapping) = {len(prompt_to_adapter_mapping)}" + ) + return QEfficient.cloud_ai_100_exec_kv( tokenizer, self.qpc_path, diff --git a/QEfficient/lora/layers.py b/QEfficient/peft/lora/layers.py similarity index 100% rename from QEfficient/lora/layers.py rename to QEfficient/peft/lora/layers.py diff --git a/QEfficient/lora/lora_model.py b/QEfficient/peft/lora/lora_model.py similarity index 100% rename from QEfficient/lora/lora_model.py rename to QEfficient/peft/lora/lora_model.py diff --git a/QEfficient/lora/pytorch_transforms.py b/QEfficient/peft/lora/pytorch_transforms.py similarity index 92% rename from QEfficient/lora/pytorch_transforms.py rename to QEfficient/peft/lora/pytorch_transforms.py index db70a984d..5e7463b97 100644 --- a/QEfficient/lora/pytorch_transforms.py +++ b/QEfficient/peft/lora/pytorch_transforms.py @@ -10,8 +10,8 @@ from torch import nn from QEfficient.base.pytorch_transforms import ModuleMappingTransform -from QEfficient.lora.layers import LinearBase, LinearMultiLoRA -from QEfficient.lora.lora_model import QEffLoraModelLlamaForCausalLM, QEffLoraModelMistralForCausalLM +from QEfficient.peft.lora.layers import LinearBase, LinearMultiLoRA +from QEfficient.peft.lora.lora_model import QEffLoraModelLlamaForCausalLM, QEffLoraModelMistralForCausalLM from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralForCausalLM diff --git a/examples/lora_models.py b/examples/lora_models.py index 596f784e6..b4a1cd921 100644 --- a/examples/lora_models.py +++ b/examples/lora_models.py @@ -22,7 +22,9 @@ ) # (alternative) non-cb compilation -# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True) +# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( +# "predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True +# ) ## STEP 2 -- load adapter adapter qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") @@ -47,13 +49,15 @@ ) # (alternative) non-cb compilation -# qpc_path = qeff_model.compile(batch_size=2, -# prefill_seq_len=seq_len, -# ctx_len=ctx_len, -# num_devices=len(device_group), -# num_cores=16, -# mxfp6_matmul=True, -# mxint8_kv_cache=True) +# qpc_path = qeff_model.compile( +# batch_size=2, +# prefill_seq_len=seq_len, +# ctx_len=ctx_len, +# num_devices=len(device_group), +# num_cores=16, +# mxfp6_matmul=True, +# mxint8_kv_cache=True, +# ) ## STEP 4 -- run inference on the generate function prompts = [ diff --git a/tests/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py similarity index 94% rename from tests/lora/test_lora_model.py rename to tests/peft/lora/test_lora_model.py index 7d5cb65ed..141a2d946 100644 --- a/tests/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -13,7 +13,7 @@ from transformers import AutoConfig, AutoModelForCausalLM from QEfficient import QEffAutoPeftModelForCausalLM -from QEfficient.lora import QEffAutoLoraModelForCausalLM +from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM from QEfficient.utils import load_hf_tokenizer configs = [ @@ -56,10 +56,8 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) qeff_model = QEffAutoLoraModelForCausalLM(model_hf) - assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 assert len(qeff_model.adapter_configs) == 0 - assert qeff_model.max_num_adapters == 0 assert len(qeff_model.active_adapter_to_id) == 0 @@ -68,10 +66,8 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name) - assert qeff_model.base_model_name == base_model_name assert len(qeff_model.adapter_weights) == 0 assert len(qeff_model.adapter_configs) == 0 - assert qeff_model.max_num_adapters == 0 assert len(qeff_model.active_adapter_to_id) == 0 @@ -80,21 +76,28 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(adapter_id_0, "id_0", finite_adapters=True) - assert qeff_model.base_model_name == base_model_name + assert isinstance(qeff_model, QEffAutoLoraModelForCausalLM) assert len(qeff_model.adapter_weights) == 1 assert len(qeff_model.adapter_configs) == 1 - assert qeff_model.max_num_adapters == 1 assert len(qeff_model.active_adapter_to_id) == 1 + # test pass without adapter name + with pytest.raises(TypeError): + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, finite_adapters=True) + + # test pass with adapter name as integer + with pytest.raises(TypeError): + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, 0, finite_adapters=True) + # test the init assertion for models that are not supported @pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"]) def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name): model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): QEffAutoLoraModelForCausalLM(model_hf) - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) From e31568bd234de8feb204cccc14da860681de5e31 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 19 Nov 2024 13:18:05 +0530 Subject: [PATCH 8/9] allow adapter_name passed as keyword argument, updated all finite lora tests to use single layer models Signed-off-by: Onkar Chougule --- QEfficient/peft/auto.py | 7 +++++++ QEfficient/peft/lora/auto.py | 3 +++ tests/peft/lora/test_lora_model.py | 19 ++++++++++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index d2c62ef88..bcdd79bdf 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -82,6 +82,9 @@ def __init__(self, model: nn.Module): for adapter_name in model.peft_config } + def __repr__(self) -> str: + return self.__class__.__name__ + "\n" + self.model.__repr__() + @property def model_name(self) -> str: mname = self.model.get_base_model().__class__.__name__ + "-lora" @@ -148,6 +151,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): Args: :pretrained_name_or_path (str): Model card name from huggingface or local path to model directory. :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification. + :adapter_name (str): Name used to identify loaded adapter. :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM. """ if kwargs.get("full_batch_size"): @@ -163,6 +167,9 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): ).base_model_name_or_path, **kwargs, ) + if adapter_name := kwargs.pop("adapter_name", None): + obj.load_adapter(pretrained_name_or_path, adapter_name=adapter_name) + return obj if len(args) == 0 or not isinstance(list(args)[0], str): raise TypeError("Required adapter name argument in string format") obj.load_adapter(pretrained_name_or_path, list(args)[0]) diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index aa85aad7a..2ccfac12a 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -59,6 +59,9 @@ def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs self.lora_rank = 0 self.target_modules_for_all_adapters = [] + def __repr__(self) -> str: + return self.__class__.__name__ + "\n" + self.model.__repr__() + @property def model_hash(self) -> str: mhash = hashlib.sha256() diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 141a2d946..0473cc7e4 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -53,7 +53,7 @@ def create_lora_base_model(base_config): # test model initialization using __init__ approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapter_id_1): - model_hf = AutoModelForCausalLM.from_pretrained(base_model_name) + model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) qeff_model = QEffAutoLoraModelForCausalLM(model_hf) assert len(qeff_model.adapter_weights) == 0 @@ -64,7 +64,9 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt # test model initialization using from_pretrained approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): - qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name) + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=base_model_name, num_hidden_layers=1 + ) assert len(qeff_model.adapter_weights) == 0 assert len(qeff_model.adapter_configs) == 0 @@ -74,8 +76,15 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_ # test peft model initialization using from_pretrained approach @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): - qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(adapter_id_0, "id_0", finite_adapters=True) + qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( + adapter_id_0, "id_0", finite_adapters=True, num_hidden_layers=1 + ) + qeff_model_tmp = QEffAutoPeftModelForCausalLM.from_pretrained( + adapter_id_0, adapter_name="id_0", finite_adapters=True, num_hidden_layers=1 + ) + assert qeff_model.active_adapter_to_id == qeff_model_tmp.active_adapter_to_id + del qeff_model_tmp assert isinstance(qeff_model, QEffAutoLoraModelForCausalLM) assert len(qeff_model.adapter_weights) == 1 assert len(qeff_model.adapter_configs) == 1 @@ -83,11 +92,11 @@ def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_ # test pass without adapter name with pytest.raises(TypeError): - QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, finite_adapters=True) + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, finite_adapters=True, num_hidden_layers=1) # test pass with adapter name as integer with pytest.raises(TypeError): - QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, 0, finite_adapters=True) + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, 0, finite_adapters=True, num_hidden_layers=1) # test the init assertion for models that are not supported From 42a240d41e5a484106dba1273c0c1584e255b3fc Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 19 Nov 2024 13:43:30 +0530 Subject: [PATCH 9/9] added pytest on_qaic marker for lora test using AI_100 device Signed-off-by: Onkar Chougule --- tests/peft/lora/test_lora_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 0473cc7e4..a91555b3a 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -196,6 +196,7 @@ def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adap # test the export, export caching, compile, generate workflow +@pytest.mark.on_qaic @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path): qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1)