
Finite lorax support #153


Merged: 9 commits, Nov 19, 2024
39 changes: 39 additions & 0 deletions QEfficient/generation/text_generation_inference.py
@@ -230,6 +230,7 @@ def cloud_ai_100_exec_kv(
stream: bool = True,
write_io_dir: Optional[str] = None,
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -277,6 +278,7 @@ def cloud_ai_100_exec_kv(
stream=stream,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
prompt_to_lora_id_mapping=prompt_to_lora_id_mapping,
)
if full_batch_size is None:
exec_info = [
@@ -313,6 +315,7 @@ def __init__(
qpc_path: str,
prompt: List[str],
full_batch_size: Optional[int] = None,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
ctx_len: Optional[int] = None,
generation_len: Optional[int] = None,
device_id: Optional[List[int]] = None,
@@ -342,6 +345,16 @@ def __init__(
full_batch_size if full_batch_size else self._fetch_full_batch_size()
) # Check and fetch full batch size if CB is enabled

if prompt_to_lora_id_mapping:
self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping)
if self.full_batch_size:
self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping
else:
self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping)
else:
self.prompt_to_lora_id_mapping_prefill = None
self.prompt_to_lora_id_mapping_decode = None

self.set_tokenizer_params() # set tokenizer params

# Initialize the storage variables.
@@ -461,6 +474,16 @@ def prepare_decode_inputs(self):
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index

if self.prompt_to_lora_id_mapping_decode:
if self.full_batch_size:
first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)]
decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(
self.full_batch_size, 1
)
else:
batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)]
decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

return decode_inputs

def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
@@ -549,6 +572,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
if decode_batch_id is not None:
inputs["batch_index"] = decode_batch_id

if self.prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape(
1, 1
)
else:
batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -625,6 +657,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

self.session.set_buffers({"logits": logits_out_placeholder})
decode_pause_time += perf_counter() - start

if self.prompt_to_lora_id_mapping_decode:
decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[
batch_id_map[decode_batch_id]
]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -636,6 +674,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
)

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time

def run_decode(self, decode_inputs, generation_len):
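
The new prompt_to_lora_id_mapping argument carries one adapter id per prompt from cloud_ai_100_exec_kv into the generation loop: prefill pops ids off a deque one prompt at a time, while decode lays them out per batch slot as an int64 array of shape (batch_size, 1) or (full_batch_size, 1). A minimal usage sketch follows; the model card, QPC path, and prompts are placeholders, and tokenizer, qpc_path, and prompt are assumed keyword parameters of the function, since only the tail of its signature appears in this diff.

# Hedged usage sketch; model card, QPC path, and prompts are placeholders.
# tokenizer, qpc_path, and prompt are assumed keyword parameters of
# cloud_ai_100_exec_kv (only the tail of its signature is shown above).
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("base-model-card")  # placeholder model card

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="path/to/qpc",  # placeholder QPC directory
    prompt=["summarize this", "translate this", "classify this"],
    prompt_to_lora_id_mapping=[0, 1, 0],  # prompts 0 and 2 use adapter 0, prompt 1 uses adapter 1
)
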
26 changes: 24 additions & 2 deletions QEfficient/peft/auto.py
@@ -12,7 +12,7 @@

import numpy as np
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights
from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights
from torch import nn
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
@@ -21,6 +21,7 @@
from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
@@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):

Args:
:model (nn.Module): PyTorch model
:finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification.

.. code-block:: python

@@ -80,6 +82,9 @@ def __init__(self, model: nn.Module):
for adapter_name in model.peft_config
}

def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()

@property
def model_name(self) -> str:
mname = self.model.get_base_model().__class__.__name__ + "-lora"
@@ -145,14 +150,31 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
"""
Args:
:pretrained_name_or_path (str): Model card name from huggingface or local path to model directory.
:finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification.
:adapter_name (str): Name used to identify loaded adapter.
:args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
"""
if kwargs.get("full_batch_size"):
raise NotImplementedError("Continuous batching currently not supported for PEFT models")
if kwargs.get("use_cache") is False:
warnings.warn("Overriding to use_cache=True")
kwargs["use_cache"] = True
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)

if kwargs.pop("finite_adapters", False): # initialize through finite_adapters class
obj = QEffAutoLoraModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=PeftConfig.from_pretrained(
pretrained_name_or_path
).base_model_name_or_path,
**kwargs,
)
if adapter_name := kwargs.pop("adapter_name", None):
obj.load_adapter(pretrained_name_or_path, adapter_name=adapter_name)
return obj
if len(args) == 0 or not isinstance(list(args)[0], str):
raise TypeError("Required adapter name argument in string format")
obj.load_adapter(pretrained_name_or_path, list(args)[0])
else:
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj

def export(self, export_dir: Optional[str] = None) -> str:
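
With finite_adapters=True, from_pretrained resolves the base model name from the adapter's PeftConfig, builds a QEffAutoLoraModelForCausalLM, and registers the adapter either through the adapter_name keyword or through the first positional argument, which must be a string, otherwise a TypeError is raised. A hedged sketch of both call forms; the adapter card name is a placeholder.

# Hedged usage sketch; "org/llama-lora-adapter" is a placeholder adapter card.
from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM

# Keyword form: the adapter is registered under an explicit name.
model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "org/llama-lora-adapter",
    finite_adapters=True,
    adapter_name="adapter_a",
)

# Positional form: the first positional argument is taken as the adapter name.
model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "org/llama-lora-adapter",
    "adapter_a",
    finite_adapters=True,
)
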
12 changes: 12 additions & 0 deletions QEfficient/peft/lora/__init__.py
@@ -0,0 +1,12 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM

__all__ = [
"QEffAutoLoraModelForCausalLM",
]
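
The new package re-exports QEffAutoLoraModelForCausalLM, so the finite-adapter class can be imported directly from QEfficient.peft.lora. A short sketch with placeholder model and adapter cards, based on the from_pretrained and load_adapter calls shown in the auto.py diff above:

# Hedged sketch; "base-model-card" and "org/adapter-a" are placeholders.
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

lora_model = QEffAutoLoraModelForCausalLM.from_pretrained("base-model-card")
lora_model.load_adapter("org/adapter-a", adapter_name="adapter_a")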