@@ -174,7 +174,8 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int, Optional[int]]:
174
174
raise FileNotFoundError (f"expected specializations.json file at path, { qpc_base_path } " )
175
175
176
176
compilation_batch_size = int (data ["specializations" ][0 ]["batch_size" ])
177
- compilation_ctx_len = int (data ["specializations" ][0 ]["ctx_len" ])
177
+ if compilation_ctx_len := data ["specializations" ][0 ].get ("ctx_len" , None ):
178
+ compilation_ctx_len = int (data ["specializations" ][0 ]["ctx_len" ])
178
179
if compilation_fbs := data ["specializations" ][0 ].get ("full_batch_size" , None ):
179
180
compilation_fbs = int (compilation_fbs )
180
181
return compilation_batch_size , compilation_ctx_len , compilation_fbs
@@ -349,25 +350,25 @@ def cloud_ai_100_exec_kv(
349
350
350
351
def cloud_ai_100_exec_embed (
351
352
tokenizer : Union [PreTrainedTokenizerFast , PreTrainedTokenizer ],
352
- prompt : List [str ],
353
353
qpc_path : str ,
354
- device_id : List [int ] = [0 ],
354
+ prompt : List [str ],
355
+ device_id : List [int ] = [0 ],
355
356
):
356
357
session = QAICInferenceSession (qpc_path , device_ids = device_id )
358
+ batch_size = session .bindings [0 ].dims [0 ]
357
359
seq_len = session .bindings [0 ].dims [1 ]
358
360
inputs = tokenizer (prompt , return_tensors = "pt" , padding = "max_length" , max_length = seq_len )
359
361
360
- prefill_inputs = dict (
362
+ inputs = dict (
361
363
input_ids = inputs ["input_ids" ].numpy (),
362
364
attention_mask = inputs ["attention_mask" ].numpy (),
363
365
)
364
- prefill_logits = {
365
- "output" : np .random .randn (1 , seq_len , session .bindings [2 ].dims [2 ]).astype (np .float32 ),
366
+ output = {
367
+ "output" : np .random .randn (batch_size , seq_len , session .bindings [2 ].dims [2 ]).astype (np .float32 ),
366
368
}
367
- session .set_buffers (prefill_logits )
368
- prefill_outputs = session .run (prefill_inputs )
369
- return prefill_outputs
370
-
369
+ session .set_buffers (output )
370
+ outputs = session .run (inputs )
371
+ return outputs
371
372
372
373
class QEffTextGenerationBase :
373
374
def __init__ (
0 commit comments