Commit 2fb41ad

Comments Addressed-1
Signed-off-by: amitraj <[email protected]>
1 parent 74ffc16 commit 2fb41ad

File tree

4 files changed, +105 −55 lines changed

QEfficient/generation/text_generation_inference.py
QEfficient/transformers/models/modeling_auto.py
QEfficient/utils/constants.py
tests/transformers/models/test_causal_lm_models.py

QEfficient/generation/text_generation_inference.py

Lines changed: 21 additions & 4 deletions
@@ -174,8 +174,7 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int, Optional[int]]:
         raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}")
 
     compilation_batch_size = int(data["specializations"][0]["batch_size"])
-    if compilation_ctx_len := data["specializations"][0].get("ctx_len", None):
-        compilation_ctx_len = int(data["specializations"][0]["ctx_len"])
+    compilation_ctx_len = int(data["specializations"][0]["ctx_len"])
     if compilation_fbs := data["specializations"][0].get("full_batch_size", None):
         compilation_fbs = int(compilation_fbs)
     return compilation_batch_size, compilation_ctx_len, compilation_fbs
@@ -352,8 +351,24 @@ def cloud_ai_100_exec_embed(
     tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
     qpc_path: str,
     prompt: List[str],
-    device_id: List[int] = [0],
-):
+    device_id: List[int] = [0],
+) -> dict:
+    """
+    This method generates output by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
+    This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed.
+    If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped.
+
+    ``Mandatory`` Args:
+        :tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Model tokenizer.
+        :qpc_path (str): Path to the saved generated binary file after compilation.
+        :prompt (str): Sample prompt for the model text generation.
+    ``Optional`` Args:
+        :device_id (List[int]): Device IDs to be used for execution. If ``len(device_id) > 1``, it enables multiple card setup. If ``None``, auto-device-picker will be used. ``Defaults to None``.
+
+    Returns:
+        :dict: Output from the ``AI_100`` runtime.
+    """
+
     session = QAICInferenceSession(qpc_path, device_ids=device_id)
     batch_size = session.bindings[0].dims[0]
     seq_len = session.bindings[0].dims[1]
@@ -368,8 +383,10 @@ def cloud_ai_100_exec_embed(
     }
     session.set_buffers(output)
     outputs = session.run(inputs)
+    session.deactivate()
     return outputs
 
+
 class QEffTextGenerationBase:
     def __init__(
         self,

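For context on how the updated cloud_ai_100_exec_embed signature is meant to be called, here is a minimal usage sketch based only on the docstring and test changes in this commit. The model card and QPC path are placeholders, the function is assumed to be importable from the module shown above, and the "output" dict key is taken from the updated test below.

from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_embed

# Placeholder model card and QPC path; substitute a real compiled embedding QPC.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
qpc_path = "/path/to/embedding_qpcs"

outputs = cloud_ai_100_exec_embed(
    tokenizer=tokenizer,
    qpc_path=qpc_path,
    prompt=["My name is"],
    device_id=[0],
)
# The function now returns a dict from the AI_100 runtime; the embeddings live
# under the "output" key, as used in the updated test.
print(outputs["output"].shape)

Note that session.deactivate() is now called before returning, which should release the device between invocations.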
QEfficient/transformers/models/modeling_auto.py

Lines changed: 47 additions & 41 deletions
@@ -66,22 +66,6 @@ def model_name(self) -> str:
             mname = mname[4:]
         return mname
 
-    @property
-    def model_hash(self) -> str:
-        # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
-        # Using same card name will result in same hash. But, using a relative path for one run and
-        # absolute path for another run will result in different hash.
-        # The added complexity to resolve different paths to same location is not worth pursuing.
-        # Instead, advise the user to always provide same relative paths or absolute paths for local models.
-
-        # Compute the hash with: model_config, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(self.model.config.to_diff_dict()))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash.update(to_hashable({"is_tlm": self.is_tlm}))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
-
 
 class QEFFAutoModelForCausalLM(QEFFTransformersBase):
     """
@@ -349,8 +333,9 @@ def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
-        device_id: List[int] = None,
-        runtime: str = "AI_100",
+        device_id: List[int] = [0],
+        runtime_ai100: bool = True,
+        seq_len: int = constants.Constants.CTX_LEN,
         **kwargs,
     ):
         """
@@ -362,21 +347,24 @@ def generate(
             :prompts (List[str]): List of prompts to run the execution.
             :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         ``optional`` Args:
-            :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
+
         """
-        if runtime != "AI_100":
-            raise ValueError("Only AI_100 runtime is supported right now via generate API")
-        if not isinstance(self.qpc_path, Path):
-            raise TypeError("Please run compile API first!")
-        generation_len = kwargs.pop("generation_len", None)
-        return QEfficient.cloud_ai_100_exec_kv(
-            tokenizer,
-            self.qpc_path,
-            prompt=prompts,
-            device_id=device_id,
-            generation_len=generation_len,
-            is_tlm=self.is_tlm,
-        )
+        if runtime_ai100:
+            if not isinstance(self.qpc_path, Path):
+                raise TypeError("Please run compile API first!")
+            generation_len = kwargs.pop("generation_len", None)
+            return QEfficient.cloud_ai_100_exec_kv(
+                tokenizer,
+                self.qpc_path,
+                prompt=prompts,
+                device_id=device_id,
+                generation_len=generation_len,
+                is_tlm=self.is_tlm,
+            )
+        else:
+            inputs = tokenizer(prompts, return_tensors="pt", padding="max_length", max_length=seq_len)
+            return self.model(**inputs)
 
 
 class QEffAutoModel(QEFFTransformersBase):
@@ -405,7 +393,7 @@ def __init__(self, model: nn.Module, **kwargs):
         super().__init__(model)
         self.model.config.use_cache = True
         self.num_layers = model.config.num_hidden_layers
-
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         """
@@ -429,11 +417,26 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         # You can now execute the model
         model.generate(prompts=["Hi there!!"])
         """
-
+
         self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         return self
 
+    @property
+    def model_hash(self) -> str:
+        # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
+        # Using same card name will result in same hash. But, using a relative path for one run and
+        # absolute path for another run will result in different hash.
+        # The added complexity to resolve different paths to same location is not worth pursuing.
+        # Instead, advise the user to always provide same relative paths or absolute paths for local models.
+
+        # Compute the hash with: model_config, transforms
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(self.model.config.to_diff_dict()))
+        mhash.update(to_hashable(self._transform_names()))
+        mhash = mhash.hexdigest()[:16]
+        return mhash
+
     def export(self, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
@@ -470,7 +473,9 @@ def compile(
         *,
         seq_len: int = 32,
         batch_size: int = 1,
+        num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
+        mxfp6_matmul: bool = False,
         **compiler_options,
     ) -> str:
         """
@@ -498,18 +503,20 @@
             compile_only=True,
             specializations=specializations,
             convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
             **compiler_options,
         )
 
     def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
-        prompt: List[str],
+        prompts: List[str],
         device_id: List[int] = [0],
         runtime_ai100: bool = True,
         seq_len: int = constants.Constants.CTX_LEN,
-    ) -> str:
+    ) -> dict:
         """
         This method generates output by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
         This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed.
@@ -519,10 +526,10 @@ def generate(
             :prompts (List[str]): List of prompts to run the execution.
             :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         ``optional`` Args:
-            :runtime_ai100 (bool), optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
 
         Returns:
-            :str: Output from the ``AI_100`` or ``PyTorch`` runtime.
+            :dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
         """
 
         # AI_100 runtime
@@ -531,10 +538,9 @@
                 raise TypeError("Please run compile API first!")
 
             return QEfficient.cloud_ai_100_exec_embed(
-                tokenizer=tokenizer, prompt=prompt, qpc_path=self.qpc_path, device_id=device_id
+                tokenizer=tokenizer, prompt=prompts, qpc_path=self.qpc_path, device_id=device_id
             )
         # PyTorch runtime
         else:
-            inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=seq_len)
+            inputs = tokenizer(prompts, return_tensors="pt", padding="max_length", max_length=seq_len)
             return self.model(**inputs)
-

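Putting the QEffAutoModel changes together, the intended flow appears to be export, compile, then generate on either runtime. The sketch below is inferred only from the signatures in this diff; the model card, core count, and device ids are illustrative, and compile is assumed to pick up the ONNX produced by export().

from transformers import AutoTokenizer

from QEfficient.transformers.models.modeling_auto import QEffAutoModel

model_card = "BAAI/bge-small-en-v1.5"  # illustrative embedding model card
tokenizer = AutoTokenizer.from_pretrained(model_card)
qeff_model = QEffAutoModel.from_pretrained(model_card)

qeff_model.export()
qeff_model.compile(
    seq_len=32,
    batch_size=1,
    num_cores=16,
    num_devices=1,       # new arg, forwarded as mdp_ts_num_devices
    mxfp6_matmul=False,  # new arg, forwarded to the compiler
)

# Default AI_100 runtime: runs the compiled QPC and returns a dict of outputs.
ai100_out = qeff_model.generate(tokenizer=tokenizer, prompts=["Hi there!!"], device_id=[0])

# PyTorch runtime: pads prompts to seq_len and calls the torch model directly.
pt_out = qeff_model.generate(tokenizer=tokenizer, prompts=["Hi there!!"], runtime_ai100=False)

QEFFAutoModelForCausalLM.generate gains the same runtime_ai100 toggle in this commit, falling back to a plain forward pass of the torch model when it is False.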
QEfficient/utils/constants.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 14
 
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]

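The only change here is bumping the export opset from 13 to 14. As a hedged illustration (the export wiring below is assumed, not taken from this repository), opset 14 is the first ONNX version with the Trilu operator, which torch.tril/torch.triu lower to during export:

import torch
import torch.nn as nn

ONNX_EXPORT_OPSET = 14  # value after this commit


class TinyMaskedModel(nn.Module):
    def forward(self, x):
        # torch.tril exports to ONNX Trilu, available from opset 14 onwards.
        mask = torch.tril(torch.ones(x.shape[-1], x.shape[-1]))
        return x * mask


torch.onnx.export(
    TinyMaskedModel(),
    (torch.randn(1, 4, 4),),
    "tiny_masked.onnx",
    opset_version=ONNX_EXPORT_OPSET,
)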
tests/transformers/models/test_causal_lm_models.py

Lines changed: 36 additions & 9 deletions
@@ -15,7 +15,7 @@
 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.transformers.models.modeling_auto import QEffAutoModel, QEFFAutoModelForCausalLM
 from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
-from QEfficient.utils import hf_download
+from QEfficient.utils import hf_download, padding_check_and_fix
 from QEfficient.utils._utils import load_hf_tokenizer
 from QEfficient.utils.constants import Constants
 from QEfficient.utils.device_utils import get_available_device_id
@@ -192,13 +192,26 @@ def check_embed_pytorch_vs_ort_vs_ai100(
 
     # Try to initialize with add_pooling_layer parameter
     try:
-        qeff_model = QEffAutoModel.from_pretrained(pretrained_model_name_or_path=model_path, add_pooling_layer=False)
+        qeff_model = QEffAutoModel.from_pretrained(
+            pretrained_model_name_or_path=model_path,
+            add_pooling_layer=False,
+            num_hidden_layers=n_layer,
+            attn_implementation="eager",
+            trust_remote_code=True,
+        )
     except TypeError:
         # If it fails, initialize without the parameter
-        qeff_model = QEffAutoModel.from_pretrained(pretrained_model_name_or_path=model_path)
-    text = "My name is"
+        qeff_model = QEffAutoModel.from_pretrained(
+            pretrained_model_name_or_path=model_path,
+            num_hidden_layers=n_layer,
+            attn_implementation="eager",
+            trust_remote_code=True,
+        )
+
+    prompt = "My name is"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=seq_len)
+    padding_check_and_fix(tokenizer)
+    inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=seq_len)
 
     pt_outputs = qeff_model.generate(tokenizer=tokenizer, prompt="My name is", runtime_ai100=False)

@@ -214,7 +227,7 @@ def check_embed_pytorch_vs_ort_vs_ai100(
     onnx_embeddings = onnx_outputs[0]
     mad = np.mean(np.abs(pt_embeddings - onnx_embeddings))
     print("Mad for onnx and pytorch is ", mad)
-    assert mad <= 10**-5, f"MAD is too high for onnx and Pytorch: {mad}"
+    assert mad <= 10**-3, f"MAD is too high for onnx and Pytorch: {mad}"
 
     qeff_model.compile(
         num_cores=14,
@@ -224,7 +237,7 @@ def check_embed_pytorch_vs_ort_vs_ai100(
     # Compare ONNX and AI 100 outputs
     mad = np.mean(np.abs(ai100_output["output"] - onnx_outputs[0]))
     print("Mad for onnx and AI 100 output is ", mad)
-    assert mad <= 10**-2, f"MAD is too high for onnx and Pytorch: {mad}"
+    assert mad <= 10**-3, f"MAD is too high for onnx and Pytorch: {mad}"
 
 
 # FIXME: there should be a CB test here
@@ -302,7 +315,21 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1():
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len)
 
 
+embed_test_models = [
+    # model_name, architecture
+    "nomic-ai/nomic-embed-text-v1.5",  # NomicBertModel
+    "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # MPNetForMaskedLM
+    "BAAI/bge-reranker-v2-m3",  # XLMRobertaForSequenceClassification
+    "BAAI/bge-small-en-v1.5",  # BertModel
+    # "intfloat/e5-mistral-7b-instruct",  # MistralModel
+    # "dunzhang/stella_en_1.5B_v5",  # Qwen2ForCausalLM
+]
+
+
 @pytest.mark.on_qaic
-def test_embed_model_pytorch_vs_onnx_vs_ai100():
-    model_name = "BAAI/bge-small-en-v1.5"
+@pytest.mark.parametrize("model_name", embed_test_models)
+def test_embed_model_pytorch_vs_onnx_vs_ai100(model_name):
+    """
+    Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output.
+    """
     check_embed_pytorch_vs_ort_vs_ai100(model_name=model_name, seq_len=32, n_layer=1)

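The embedding test is now parametrized over several model cards and its tolerances are aligned at 1e-3 for both comparisons. A small, self-contained sketch of the same MAD (mean absolute difference) check used above; the helper name and dummy arrays are ours, not the repository's:

import numpy as np


def mean_absolute_difference(reference: np.ndarray, candidate: np.ndarray) -> float:
    # Same metric the test uses to compare PyTorch, ONNX and AI 100 embeddings.
    return float(np.mean(np.abs(reference - candidate)))


reference = np.ones((1, 32, 8), dtype=np.float32)
candidate = reference + 5e-4  # within the relaxed 1e-3 tolerance
assert mean_absolute_difference(reference, candidate) <= 1e-3

A single card from embed_test_models can be exercised with the usual pytest selectors, e.g. pytest tests/transformers/models/test_causal_lm_models.py -m on_qaic -k "bge-small".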