Skip to content

Commit 4efa813

Browse files
committed
Added support for embedding models
Signed-off-by: amitraj <[email protected]>
1 parent 881a766 commit 4efa813

File tree

2 files changed: 9 additions and 2 deletions

QEfficient/generation/text_generation_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def cloud_ai_100_exec_embed(
367367
return prefill_outputs
368368

369369

370-
class TextGeneration:
370+
class QEffTextGenerationBase:
371371
def __init__(
372372
self,
373373
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],

QEfficient/transformers/models/modeling_auto.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
213213
example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
214214
dynamic_axes["batch_index"] = {0: "batch_size"}
215215

216+
return self._export(
217+
example_inputs,
218+
output_names,
219+
dynamic_axes,
220+
export_dir=export_dir,
221+
)
222+
216223
def compile(
217224
self,
218225
onnx_path: Optional[str] = None,
@@ -381,7 +388,7 @@ def generate(
381388
device_id: List[int] = [0],
382389
runtime_ai100: bool = True,
383390
seq_len: int = constants.Constants.CTX_LEN,
384-
):
391+
) -> str:
385392
if runtime_ai100:
386393
if not isinstance(self.qpc_path, Path):
387394
raise TypeError("Please run compile API first!")

0 commit comments

Comments
 (0)