Commit 6c4b5a6 (parent: 3926d0e)

Addressed comments-1

Signed-off-by: amitraj <[email protected]>
1 file changed: 29 additions (+), 20 deletions (−)

QEfficient/transformers/models/modeling_auto.py
@@ -333,8 +333,9 @@ def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
-        device_id: List[int] = None,
-        runtime: str = "AI_100",
+        device_id: List[int] = [0],
+        runtime_ai100: bool = True,
+        seq_len: int = constants.Constants.CTX_LEN,
         **kwargs,
     ):
         """
@@ -346,21 +347,25 @@ def generate(
         :prompts (List[str]): List of prompts to run the execution.
         :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         ``optional`` Args:
-            :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
+
         """
-        if runtime != "AI_100":
-            raise ValueError("Only AI_100 runtime is supported right now via generate API")
-        if not isinstance(self.qpc_path, Path):
-            raise TypeError("Please run compile API first!")
-        generation_len = kwargs.pop("generation_len", None)
-        return QEfficient.cloud_ai_100_exec_kv(
-            tokenizer,
-            self.qpc_path,
-            prompt=prompts,
-            device_id=device_id,
-            generation_len=generation_len,
-            is_tlm=self.is_tlm,
-        )
+        if runtime_ai100:
+            if not isinstance(self.qpc_path, Path):
+                raise TypeError("Please run compile API first!")
+            generation_len = kwargs.pop("generation_len", None)
+            return QEfficient.cloud_ai_100_exec_kv(
+                tokenizer,
+                self.qpc_path,
+                prompt=prompts,
+                device_id=device_id,
+                generation_len=generation_len,
+                is_tlm=self.is_tlm,
+            )
+        else:
+            inputs = tokenizer(prompts, return_tensors="pt", padding="max_length", max_length=seq_len)
+            return self.model(**inputs)
+


 class QEffAutoModel(QEFFTransformersBase):
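The hunk above replaces the old `runtime: str = "AI_100"` string flag with a `runtime_ai100` boolean and adds an eager PyTorch fallback. A minimal usage sketch follows; the owning class (likely `QEFFAutoModelForCausalLM`, defined earlier in this file but not shown in the hunk) and the checkpoint are assumptions, while the parameters come from the diff itself:

from transformers import AutoTokenizer
from QEfficient import QEFFAutoModelForCausalLM  # assumed: the class owning this generate is not visible in the hunk

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Default AI_100 path: compile() must run first so that self.qpc_path is a Path,
# otherwise generate raises TypeError. generation_len is popped from **kwargs.
model.compile(num_cores=16)
output = model.generate(tokenizer, prompts=["My name is"], device_id=[0], generation_len=32)

# New PyTorch path: no QPC needed; prompts are padded to seq_len and the
# eager model's forward output (not decoded text) is returned.
output = model.generate(tokenizer, prompts=["My name is"], runtime_ai100=False, seq_len=128)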
@@ -469,7 +474,9 @@ def compile(
         *,
         seq_len: int = 32,
         batch_size: int = 1,
+        num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
+        mxfp6_matmul: bool = False,
         **compiler_options,
     ) -> str:
         """
@@ -497,14 +504,16 @@ def compile(
             compile_only=True,
             specializations=specializations,
             convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
             **compiler_options,
         )

     def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
-        prompt: List[str],
+        prompts: List[str],
         device_id: List[int] = [0],
         runtime_ai100: bool = True,
         seq_len: int = constants.Constants.CTX_LEN,
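The two new compile arguments map straight through to the compiler call: `num_devices` is forwarded as `mdp_ts_num_devices` (multi-device tensor slicing) and `mxfp6_matmul` toggles MXFP6 compression of MatMul weights. A hedged sketch of exercising them on a `QEffAutoModel`; the import path and checkpoint are assumptions:

from transformers import AutoTokenizer
from QEfficient import QEffAutoModel  # assumed export; the diff only shows the class definition

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")  # placeholder
model = QEffAutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# num_devices=4 requests a 4-device tensor-sliced QPC (mdp_ts_num_devices=4);
# mxfp6_matmul=True enables MXFP6 matmul weight compression. Both are new here.
qpc_path = model.compile(
    seq_len=32,
    batch_size=1,
    num_devices=4,
    num_cores=16,
    mxfp6_matmul=True,
)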
@@ -518,7 +527,7 @@ def generate(
         :prompts (List[str]): List of prompts to run the execution.
         :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         ``optional`` Args:
-            :runtime_ai100 (bool), optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.

         Returns:
             :dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
@@ -530,9 +539,9 @@ def generate(
                 raise TypeError("Please run compile API first!")

             return QEfficient.cloud_ai_100_exec_embed(
-                tokenizer=tokenizer, prompt=prompt, qpc_path=self.qpc_path, device_id=device_id
+                tokenizer=tokenizer, prompt=prompts, qpc_path=self.qpc_path, device_id=device_id
             )
         # PyTorch runtime
         else:
-            inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=seq_len)
+            inputs = tokenizer(prompts, return_tensors="pt", padding="max_length", max_length=seq_len)
             return self.model(**inputs)
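With the `prompt` → `prompts` rename, `QEffAutoModel.generate` now mirrors the causal-LM signature above while still passing `prompt=prompts` to the executor internally. A usage sketch under the same assumptions as the compile example (reusing its `model` and `tokenizer`):

# AI_100 runtime (default): dispatches to QEfficient.cloud_ai_100_exec_embed
# with the QPC produced by compile(); raises TypeError if compile() never ran.
embeddings = model.generate(tokenizer, prompts=["Embed this sentence."], device_id=[0])

# PyTorch runtime: pads every prompt to seq_len with max_length padding and
# returns the eager model's forward output instead of running the QPC.
embeddings = model.generate(
    tokenizer,
    prompts=["Embed this sentence."],
    runtime_ai100=False,
    seq_len=32,
)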
