Skip to content

Commit a566952

Browse files
committed
fix-3
Signed-off-by: amitraj <[email protected]>
1 parent 9c00ef7 commit a566952

File tree

6 files changed

+29
-40
lines changed

6 files changed

+29
-40
lines changed

QEfficient/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from QEfficient.base import QEffAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader
99
from QEfficient.compile.compile_helper import compile
1010
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
11-
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_embedd, cloud_ai_100_exec_kv
11+
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_embed, cloud_ai_100_exec_kv
1212
from QEfficient.peft import QEffAutoPeftModelForCausalLM
1313
from QEfficient.transformers.transform import transform
1414

@@ -21,7 +21,7 @@
2121
"export",
2222
"compile",
2323
"cloud_ai_100_exec_kv",
24-
"cloud_ai_100_exec_embedd",
24+
"cloud_ai_100_exec_embed",
2525
"QEffAutoModel",
2626
"QEFFAutoModelForCausalLM",
2727
"QEffAutoPeftModelForCausalLM",

QEfficient/base/modeling_qeff.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,7 @@ class QEFFBaseModel(ABC):
4343

4444
@classmethod
4545
def _transform_names(cls) -> List[str]:
46-
transform_names = []
47-
if hasattr(cls, "_pytorch_transforms") and cls._pytorch_transforms:
48-
transform_names.extend(x.__name__ for x in cls._pytorch_transforms)
49-
if hasattr(cls, "_onnx_transforms") and cls._onnx_transforms:
50-
transform_names.extend(x.__name__ for x in cls._onnx_transforms)
51-
return transform_names
46+
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
5247

5348
def __init__(self, model: torch.nn.Module) -> None:
5449
super().__init__()
@@ -59,11 +54,9 @@ def __init__(self, model: torch.nn.Module) -> None:
5954

6055
# Apply the transformations
6156
any_transformed = False
62-
63-
if hasattr(self, "_pytorch_transforms") and self._pytorch_transforms:
64-
for transform in self._pytorch_transforms:
65-
self.model, transformed = transform.apply(self.model)
66-
any_transformed = any_transformed or transformed
57+
for transform in self._pytorch_transforms:
58+
self.model, transformed = transform.apply(self.model)
59+
any_transformed = any_transformed or transformed
6760

6861
if not any_transformed:
6962
warnings.warn(f"No transforms applied to model: {self.model_name}. It may be an unsupported model!")
@@ -137,7 +130,6 @@ def _export(
137130
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
138131
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
139132
"""
140-
141133
export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
142134
export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
143135
onnx_path = export_dir / f"{self.model_name}.onnx"
@@ -224,7 +216,6 @@ def _compile(
224216
- aic_num_cores=16 -> -aic-num-cores=16
225217
- convert_to_fp16=True -> -convert-to-fp16
226218
"""
227-
228219
if onnx_path is None and self.onnx_path is None:
229220
self.export()
230221

QEfficient/exporter/export_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def fix_onnx_fp16(
217217
Return:
218218
:str: Updated base name of exported ONNX model.
219219
"""
220-
model = onnx.load("/local/mnt/workspace/amitraj/amit_efficient/efficient-transformers/model_base_name.onnx")
220+
model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
221221
# TODO: Remove this `fix_onnx_fp16` function and replace with this transform
222222
# as we're not utilizing the validations done in this function
223223
model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)

QEfficient/generation/text_generation_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def cloud_ai_100_exec_kv(
308308
return exec_info
309309

310310

311-
def cloud_ai_100_exec_embedd(
311+
def cloud_ai_100_exec_embed(
312312
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
313313
prompt: List[str],
314314
qpc_path: str,

QEfficient/transformers/models/modeling_auto.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -379,17 +379,20 @@ def generate(
379379
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
380380
prompt: List[str],
381381
device_id: List[int] = [0],
382-
runtime: str = "AI_100",
383-
**kwargs,
382+
runtime_ai100: bool = True,
383+
seq_len: int = constants.Constants.CTX_LEN,
384384
):
385-
if runtime != "AI_100":
386-
raise ValueError("Only AI_100 runtime is supported right now via generate API")
387-
if not isinstance(self.qpc_path, Path):
388-
raise TypeError("Please run compile API first!")
385+
if runtime_ai100:
386+
if not isinstance(self.qpc_path, Path):
387+
raise TypeError("Please run compile API first!")
389388

390-
return QEfficient.cloud_ai_100_exec_embedd(
391-
tokenizer=tokenizer, prompt=prompt, qpc_path=self.qpc_path, device_id=device_id
392-
)
389+
return QEfficient.cloud_ai_100_exec_embed(
390+
tokenizer=tokenizer, prompt=prompt, qpc_path=self.qpc_path, device_id=device_id
391+
)
392+
else:
393+
inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=seq_len)
394+
return self.model(**inputs)
395+
393396

394397
@classmethod
395398
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):

tests/transformers/models/test_causal_lm_models.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
180180
), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output."
181181

182182

183-
def check_embedd_pytorch_vs_ort_vs_ai100(
183+
def check_embed_pytorch_vs_ort_vs_ai100(
184184
model_name: str,
185185
seq_len: int = Constants.CTX_LEN,
186186
n_layer: int = 1,
@@ -197,14 +197,12 @@ def check_embedd_pytorch_vs_ort_vs_ai100(
197197
except TypeError:
198198
# If it fails, initialize without the parameter
199199
model = AutoModel.from_pretrained(model_name)
200-
qeff_model = QEffAutoModel.from_pretrained(pretrained_model_name_or_path=model_path, add_pooling_layer=False)
200+
qeff_model = QEffAutoModel.from_pretrained(pretrained_model_name_or_path=model_path)
201201
text = "My name is"
202202
tokenizer = AutoTokenizer.from_pretrained(model_name)
203203
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=seq_len)
204204

205-
# PyTorch output
206-
with torch.no_grad():
207-
pt_outputs = model(**inputs)
205+
pt_outputs=qeff_model.generate(tokenizer=tokenizer, prompt="My name is", runtime_ai100=False)
208206

209207
onnx_model = qeff_model.export()
210208
ort_session = ort.InferenceSession(str(onnx_model))
@@ -213,22 +211,19 @@ def check_embedd_pytorch_vs_ort_vs_ai100(
213211
# Run inference
214212
onnx_outputs = ort_session.run(None, onnx_inputs)
215213

216-
# Extract the embeddings from PyTorch and ONNX outputs
217-
pt_embeddings = pt_outputs[0].numpy()
214+
# Compare PyTorch and ONNX outputs
215+
pt_embeddings = pt_outputs[0].detach().numpy()
218216
onnx_embeddings = onnx_outputs[0]
219-
220-
# Calculate Mean Absolute Deviation (MAD)
221217
mad = np.mean(np.abs(pt_embeddings - onnx_embeddings))
222218
print("Mad for onnx and pytorch is ", mad)
223219
assert mad <= 10**-5, f"MAD is too high for onnx and Pytorch: {mad}"
224220

225-
# Compare with cloud AI100
226-
227221
qeff_model.compile(
228222
num_cores=14,
229223
)
230224
ai100_output = qeff_model.generate(tokenizer=tokenizer, prompt=["My name is"])
231225

226+
# Compare ONNX and AI 100 outputs
232227
mad = np.mean(np.abs(ai100_output["output"] - onnx_outputs[0]))
233228
print("Mad for onnx and AI 100 output is ", mad)
234229
assert mad <= 10**-2, f"MAD is too high for onnx and Pytorch: {mad}"
@@ -290,7 +285,7 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1():
290285

291286
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len)
292287

293-
294-
def test_embedd_model_pytorch_vs_onnx_vs_ai100():
288+
@pytest.mark.on_qaic
289+
def test_embed_model_pytorch_vs_onnx_vs_ai100():
295290
model_name = "BAAI/bge-small-en-v1.5"
296-
check_embedd_pytorch_vs_ort_vs_ai100(model_name=model_name, seq_len=32, n_layer=1)
291+
check_embed_pytorch_vs_ort_vs_ai100(model_name=model_name, seq_len=32, n_layer=1)

0 commit comments

Comments
 (0)