
Commit f772264

Add granite speech model test
Signed-off-by: Alex-Brooks <[email protected]>
1 parent b309dc5 commit f772264

1 file changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import Optional

import pytest
from transformers import AutoModelForSpeechSeq2Seq

from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close

HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501


def vllm_to_hf_output(
    vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
) -> tuple[list[int], str, Optional[SampleLogprobs]]:
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = output_str + "<|end_of_text|>"

    return output_ids, hf_output_str, out_logprobs


MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
# The audio LoRA co-exists directly in the model directory, but
# currently still needs to be passed directly to vLLM.
audio_lora_path = MODEL_NAME
models = [MODEL_NAME]


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], PromptAudioInput]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the audio fixtures for the test are from AUDIO_ASSETS.
    For the huggingface runner, we provide the audio as input.
    For the vllm runner, we provide MultiModalDataDict objects
    and the corresponding MultiModalConfig as input.
    Note that the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be comparable with hf.
    """
    # NOTE: take care of the order: run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # If we run HF first, cuda initialization will already have been done,
    # and it will hurt the multiprocessing backend with the fork method
    # (the default method).
    # max_model_len should be greater than the audio feature size.
    with vllm_runner(
            model,
            task="generate",
            max_model_len=max_model_len,
            max_num_seqs=1,
            dtype=dtype,
            limit_mm_per_prompt={"audio": 1},
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enable_lora=True,
            max_lora_rank=64,
            enforce_eager=True,
    ) as vllm_model:
        lora_request = LoRARequest("audio", 1, audio_lora_path)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                audios=audios,
                                                lora_request=lora_request)
            for prompts, audios in inputs
        ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
        hf_processor = hf_model.processor
        eos_token_id = hf_processor.tokenizer.eos_token_id

        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    audios=[audios],
                                                    eos_token_id=eos_token_id)
            for prompts, audios in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(output) for output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, model, audio_assets: _AudioAssets,
                dtype: str, max_model_len: int, max_tokens: int,
                num_logprobs: int) -> None:
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    audio, sr = audio_assets[0].audio_and_sample_rate
    # This model expects a 16kHz sample rate, which our test audio
    # already has; if this changes, it may break this test,
    # so we check it directly.
    assert sr == 16000
    run_test(
        hf_runner,
        vllm_runner,
        [[[HF_AUDIO_PROMPT], [audio]]],
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
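
For context on the LoRA handling this test exercises: the audio LoRA ships inside the ibm-granite/granite-speech-3.3-8b repository, but, as the comment above audio_lora_path notes, it still has to be passed to vLLM explicitly. Below is a minimal, untested sketch of the same setup against vLLM's offline LLM API; the prompt text, the AudioAsset name, and the sampling settings are illustrative assumptions, not part of this commit.

# Illustrative sketch (not part of this commit): pass the in-repo audio LoRA
# to vLLM's offline LLM API, mirroring the arguments used by vllm_runner above.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest

MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"

llm = LLM(
    model=MODEL_NAME,
    max_model_len=2048,
    max_num_seqs=1,
    dtype="bfloat16",
    limit_mm_per_prompt={"audio": 1},
    enable_lora=True,
    max_lora_rank=64,  # the audio LoRA bundled with this checkpoint uses rank 64
    enforce_eager=True,
)

# The LoRA path is simply the model repo itself, since the adapter lives there;
# it still has to be attached per request.
audio_lora = LoRARequest("audio", 1, MODEL_NAME)

# Any 16 kHz waveform works; "mary_had_lamb" is one of vLLM's bundled test assets.
audio, sr = AudioAsset("mary_had_lamb").audio_and_sample_rate

# Shortened user turn of the Granite chat template used by the test prompt.
prompt = (
    "<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech "
    "into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": (audio, sr)}},
    SamplingParams(temperature=0.0, max_tokens=128),
    lora_request=audio_lora,
)
print(outputs[0].outputs[0].text)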

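The test asserts that the bundled audio asset is already sampled at 16 kHz rather than resampling it. If a different fixture were ever introduced, a small helper along these lines could normalize the rate before that assertion; this is only a sketch assuming librosa is available, not part of the commit.

# Hypothetical helper (not part of this commit): resample a waveform to the
# 16 kHz rate the Granite speech encoder expects, leaving 16 kHz input untouched.
import librosa
import numpy as np


def to_16khz(waveform: np.ndarray, orig_sr: int) -> tuple[np.ndarray, int]:
    target_sr = 16000
    if orig_sr != target_sr:
        waveform = librosa.resample(waveform, orig_sr=orig_sr, target_sr=target_sr)
    return waveform, target_sr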