
Commit 9351a59

Chenyaaang authored and lk-chen committed
[Core][V1][TPU] Enable structured decoding on TPU V1 (vllm-project#16499)
Signed-off-by: Chenyaaang <[email protected]>
1 parent fc61ddf commit 9351a59

File tree

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
benchmarks/benchmark_serving_structured_output.py
tests/v1/tpu/test_sampler.py
vllm/platforms/tpu.py
vllm/v1/worker/tpu_model_runner.py

5 files changed (+158, −31 lines)

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ docker run --privileged --net host --shm-size=16G -it \
       && echo TEST_9 \
       && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
       && echo TEST_10 \
-      && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
+      && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
+      && echo TEST_11 \
+      && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling

benchmarks/benchmark_serving_structured_output.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
-from vllm.v1.structured_output.utils import (
+from vllm.v1.structured_output.backend_xgrammar import (
     has_xgrammar_unsupported_json_features)
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000

tests/v1/tpu/test_sampler.py

Lines changed: 5 additions & 2 deletions
@@ -23,7 +23,7 @@ def test_sampler_different(model_name: str):
     different results.
     """
     llm = LLM(model_name,
-              enforce_eager=False,
+              enforce_eager=True,
               max_num_seqs=1,
               max_model_len=512,
               max_num_batched_tokens=512)
@@ -57,4 +57,7 @@ def test_sampler_different(model_name: str):
     # Make sure first two reqs have the same K/P
     sampling_params[0] = sampling_params[1]
     output = llm.generate(p, sampling_params)
-    assert output[0].outputs[0].text == output[1].outputs[0].text
+    # There are natural numerical instabilities that make it difficult
+    # to have deterministic results over many tokens, so only check that
+    # the first ~20 tokens match.
+    assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]

vllm/platforms/tpu.py

Lines changed: 2 additions & 2 deletions
@@ -168,9 +168,9 @@ def validate_request(
     ) -> None:
         """Raises if this request is unsupported on this platform"""
         if isinstance(params, SamplingParams):
-            if params.guided_decoding is not None:
+            if params.guided_decoding is not None and not envs.VLLM_USE_V1:
                 raise ValueError("Structured output is not supported on "
-                                 f"{cls.device_name}.")
+                                 f"{cls.device_name} V0.")
             if params.sampling_type == SamplingType.RANDOM_SEED:
                 raise ValueError(
                     "Torch XLA does not support per-request seed.")

vllm/v1/worker/tpu_model_runner.py

Lines changed: 147 additions & 25 deletions
@@ -30,8 +30,9 @@
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, SlidingWindowSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
+                                        KVCacheConfig, KVCacheSpec,
+                                        SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
@@ -148,6 +149,7 @@ def __init__(
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
+        self.vocab_size = model_config.get_vocab_size()
 
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -178,7 +180,7 @@ def __init__(
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            vocab_size=self.vocab_size,
         )
 
         # Cached torch/numpy tensor
@@ -221,6 +223,20 @@ def __init__(
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
 
+        # tensors for structured decoding
+        self.grammar_bitmask_cpu = torch.zeros(
+            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.require_structured_out_cpu = torch.zeros(
+            (self.max_num_reqs, 1),
+            dtype=torch.bool,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.structured_decode_arange = torch.arange(
+            0, 32, device="cpu", pin_memory=self.pin_memory)
+
         # Get maximum number of mm items per modality (batch size).
         self.max_num_mm_items_by_modality = dict()
         if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
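
The shapes above follow the packed-bitmask convention used by grammar backends such as xgrammar: each request gets one bit per vocabulary token, stored as cdiv(vocab_size, 32) int32 words, and the arange(0, 32) tensor is used to unpack those words on device. A self-contained toy (illustrative numbers, not vLLM code) showing the pack/unpack round trip:

import torch

vocab_size = 70                      # toy vocabulary
num_words = -(-vocab_size // 32)     # cdiv(vocab_size, 32) -> 3
allowed = torch.zeros(vocab_size, dtype=torch.bool)
allowed[[0, 5, 64]] = True           # tokens the grammar permits

# Pack: bit (tok % 32) of word (tok // 32) marks token `tok` as allowed.
packed = torch.zeros(num_words, dtype=torch.int32)
for tok in allowed.nonzero().flatten().tolist():
    packed[tok // 32] |= 1 << (tok % 32)

# Unpack the same way structured_decode does: shift by 0..31 and test bit 0.
arange = torch.arange(0, 32)
unpacked = ((packed[:, None] >> arange[None, :]) & 1).bool().reshape(-1)
assert torch.equal(unpacked[:vocab_size], allowed)
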
@@ -762,9 +778,16 @@ def execute_model(
         )
         hidden_states = self.select_hidden_states(hidden_states,
                                                   logits_indices)
+        logits = self.compute_logits(hidden_states)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
             from_input_batch(self.input_batch, padded_num_reqs, self.device)
-        selected_token_ids = self.sample_from_hidden(hidden_states,
+        if scheduler_output.grammar_bitmask is not None:
+            require_struct_decoding, grammar_bitmask_padded, arange = \
+                self.prepare_structured_decoding_input(logits, scheduler_output)
+            logits = self.structured_decode(require_struct_decoding,
+                                            grammar_bitmask_padded, logits,
+                                            arange)
+        selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
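
The previous single traced function (sample_from_hidden) is split into compute_logits, an optional structured_decode masking step, and sample_from_logits, so each piece can be compiled with fixed shapes. A CPU-runnable toy of the same flow (assumed shapes and a hand-made mask, not vLLM code):

import torch

def compute_logits(hidden, lm_head):
    return hidden @ lm_head.T

def sample_greedy(logits):
    return torch.argmax(logits, dim=-1, keepdim=True)

hidden = torch.randn(4, 16)    # [num_reqs, hidden_size]
lm_head = torch.randn(10, 16)  # [vocab_size, hidden_size]
logits = compute_logits(hidden, lm_head)

# Pretend requests 0 and 2 are structured and may only emit tokens 3 or 7.
grammar_mask = torch.full_like(logits, float("-inf"))
grammar_mask[:, [3, 7]] = 0.0
require = torch.tensor([[True], [False], [True], [False]])
logits = torch.where(require, logits + grammar_mask, logits)

print(sample_greedy(logits).flatten())  # rows 0 and 2 pick token 3 or 7
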
@@ -997,7 +1020,7 @@ def _precompile_backbone(self) -> None:
             self._dummy_run(num_tokens)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("model backbone")
 
     def _precompile_select_hidden_states(self) -> None:
@@ -1026,19 +1049,59 @@ def _precompile_select_hidden_states(self) -> None:
                 break
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("select_hidden_states")
 
-    def _precompile_sample_from_hidden(self) -> None:
-        logger.info("Compiling sampling with different num_reqs.")
+    def _precompile_compute_logits(self) -> None:
+        logger.info("Compiling compute_logits with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_reqs in self.num_reqs_paddings:
             dummy_hidden = torch.zeros((num_reqs, hsize),
                                        device=self.device,
                                        dtype=self._hidden_states_dtype)
-            # The first dimension of dummy_hidden cannot be mark_dynamic because
-            # some operations in the sampler require it to be static.
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            self.compute_logits(dummy_hidden)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("compute_logits")
+
+    def _precompile_structured_decoding(self) -> None:
+        logger.info(
+            "Compiling structured_decoding with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_require_struct_decoding = \
+                self.require_structured_out_cpu[:num_reqs].to(self.device)
+            dummy_grammar_bitmask = \
+                self.grammar_bitmask_cpu[:num_reqs].to(self.device)
+            # The first dimension of the above 3 dummy tensors cannot be
+            # mark_dynamic because some operations in structured_decode require
+            # them to be static.
+            arange = self.structured_decode_arange.to(self.device)
+            self.structured_decode(dummy_require_struct_decoding,
+                                   dummy_grammar_bitmask, dummy_logits, arange)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("structured_decoding")
+
+    def _precompile_sample_from_logits(self) -> None:
+        logger.info(
+            "Compiling sample_from_logits with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            # The first dimension of dummy_logits cannot be mark_dynamic
+            # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
                 sampling_metadata = (
@@ -1049,12 +1112,12 @@ def _precompile_sample_from_hidden(self) -> None:
                         generate_params_if_all_greedy,
                     ))
                 sampling_metadata.all_greedy = all_greedy
-                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+                self.sample_from_logits(dummy_logits, sampling_metadata)
                 logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("sampling")
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sample_from_logits")
 
     def capture_model(self) -> None:
         """
@@ -1063,7 +1126,9 @@ def capture_model(self) -> None:
         self._precompile_mm_encoder()
         self._precompile_backbone()
         self._precompile_select_hidden_states()
-        self._precompile_sample_from_hidden()
+        self._precompile_compute_logits()
+        self._precompile_structured_decoding()
+        self._precompile_sample_from_logits()
 
     def profile_run(
         self,
@@ -1144,7 +1209,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
             num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
-            if isinstance(kv_cache_spec, FullAttentionSpec):
+            if isinstance(kv_cache_spec, AttentionSpec):
                 kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
@@ -1179,29 +1244,86 @@ def select_hidden_states(self, hidden_states, indices_do_sample):
         return hidden_states[indices_do_sample]
 
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
-    def sample_from_hidden(
-        self,
-        sample_hidden_states: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata,
-    ) -> torch.Tensor:
-        """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
-        logits = self.model.compute_logits(sample_hidden_states, None)
+    def compute_logits(self,
+                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.model.compute_logits(sample_hidden_states, None)
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def sample_from_logits(
+            self, logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def structured_decode(self, require_struct_decoding: torch.Tensor,
+                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
+                          arange: torch.Tensor) -> torch.Tensor:
+        return torch.where(
+            require_struct_decoding,
+            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
+            logits)
+
+    def apply_grammar_bitmask(self, logits: torch.Tensor,
+                              grammar_bitmask: torch.Tensor,
+                              arange: torch.Tensor):
+        assert (logits.shape[0] == grammar_bitmask.shape[0])
+        logits_cloned = logits.clone()
+        for i in range(logits.shape[0]):
+            unpacked_bitmask = (torch.bitwise_right_shift(
+                grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0
+            unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size]
+            logits_cloned[i] = logits_cloned[i].masked_fill(
+                unpacked_bitmask, -float("inf"))
+        return logits_cloned
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
 
     def get_input_embeddings(self, *args, **kwargs):
         return self.model.get_input_embeddings(*args, **kwargs)
 
+    def prepare_structured_decoding_input(
+        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        assert grammar_bitmask is not None
+        num_reqs, _ = logits.shape
+
+        # Reset pre-allocated tensors
+        self.grammar_bitmask_cpu.zero_()
+        self.require_structured_out_cpu.zero_()
+
+        # We receive the structured output bitmask from the scheduler, but the
+        # indices of the requests in the batch may not match the indices of
+        # the bitmask since the scheduler doesn't know how the tpu runner is
+        # ordering the requests in the batch. We need to match the order of
+        # bitmask with the order of requests
+        struct_out_indices: list[int] = []
+        mask_indices: list[int] = []
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            struct_out_indices.append(batch_index)
+            mask_indices.append(mask_index)
+        self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
+            grammar_bitmask[mask_indices])
+        # It's not guaranteed that all requests in this batch require
+        # structured output, so create a bool tensor to represent
+        # the requests that need structured output.
+        struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
+        self.require_structured_out_cpu[struct_out_indices] = True
+        return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
+            self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
+            self.structured_decode_arange.to(logits.device)
+
     def _get_mm_dummy_batch(self, modality: str,
                             batch_size: int) -> BatchedTensorInputs:
         # Dummy data for pre-compiling multimodal models.
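
The subtle part of prepare_structured_decoding_input is the index remapping: the scheduler's bitmask rows are keyed by its own request ordering, so they are scattered into the runner's batch order before being moved to the device. A small CPU-only illustration with made-up request IDs (not vLLM code):

import numpy as np
import torch

vocab_words = 3                                          # cdiv(vocab_size, 32)
grammar_bitmask = np.array([[1, 0, 0],                   # bitmask row 0 -> "b"
                            [0, 2, 0]], dtype=np.int32)  # bitmask row 1 -> "d"
structured_output_request_ids = {"b": 0, "d": 1}         # req_id -> bitmask row
req_ids = ["a", "b", "c", "d"]                           # runner's batch order
req_id_to_index = {r: i for i, r in enumerate(req_ids)}

bitmask_cpu = torch.zeros((len(req_ids), vocab_words), dtype=torch.int32)
require_cpu = torch.zeros((len(req_ids), 1), dtype=torch.bool)

struct_out_indices, mask_indices = [], []
for req_id in req_ids:
    mask_index = structured_output_request_ids.get(req_id)
    if mask_index is None:
        continue
    struct_out_indices.append(req_id_to_index[req_id])
    mask_indices.append(mask_index)

bitmask_cpu[struct_out_indices] = torch.from_numpy(grammar_bitmask[mask_indices])
require_cpu[torch.tensor(struct_out_indices, dtype=torch.long)] = True
print(require_cpu.flatten())  # tensor([False,  True, False,  True])
print(bitmask_cpu)            # rows 1 and 3 hold the two scheduler rows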
