Skip to content

Commit 74ffc16

Browse files
committed
Fix-1
Signed-off-by: amitraj <[email protected]>
1 parent 3f95df7 commit 74ffc16

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int, Optional[int]]:
174174
raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}")
175175

176176
compilation_batch_size = int(data["specializations"][0]["batch_size"])
177-
compilation_ctx_len = int(data["specializations"][0]["ctx_len"])
177+
if compilation_ctx_len := data["specializations"][0].get("ctx_len", None):
178+
compilation_ctx_len = int(data["specializations"][0]["ctx_len"])
178179
if compilation_fbs := data["specializations"][0].get("full_batch_size", None):
179180
compilation_fbs = int(compilation_fbs)
180181
return compilation_batch_size, compilation_ctx_len, compilation_fbs
@@ -349,25 +350,25 @@ def cloud_ai_100_exec_kv(
349350

350351
def cloud_ai_100_exec_embed(
351352
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
352-
prompt: List[str],
353353
qpc_path: str,
354-
device_id: List[int] = [0],
354+
prompt: List[str],
355+
device_id: List[int] = [0],
355356
):
356357
session = QAICInferenceSession(qpc_path, device_ids=device_id)
358+
batch_size = session.bindings[0].dims[0]
357359
seq_len = session.bindings[0].dims[1]
358360
inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=seq_len)
359361

360-
prefill_inputs = dict(
362+
inputs = dict(
361363
input_ids=inputs["input_ids"].numpy(),
362364
attention_mask=inputs["attention_mask"].numpy(),
363365
)
364-
prefill_logits = {
365-
"output": np.random.randn(1, seq_len, session.bindings[2].dims[2]).astype(np.float32),
366+
output = {
367+
"output": np.random.randn(batch_size, seq_len, session.bindings[2].dims[2]).astype(np.float32),
366368
}
367-
session.set_buffers(prefill_logits)
368-
prefill_outputs = session.run(prefill_inputs)
369-
return prefill_outputs
370-
369+
session.set_buffers(output)
370+
outputs = session.run(inputs)
371+
return outputs
371372

372373
class QEffTextGenerationBase:
373374
def __init__(

0 commit comments

Comments
 (0)