 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, SlidingWindowSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
+                                        KVCacheConfig, KVCacheSpec,
+                                        SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
@@ -148,6 +149,7 @@ def __init__(
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
+        self.vocab_size = model_config.get_vocab_size()

         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -178,7 +180,7 @@ def __init__(
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            vocab_size=self.vocab_size,
         )

         # Cached torch/numpy tensor
@@ -221,6 +223,20 @@ def __init__(
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

+        # tensors for structured decoding
+        self.grammar_bitmask_cpu = torch.zeros(
+            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.require_structured_out_cpu = torch.zeros(
+            (self.max_num_reqs, 1),
+            dtype=torch.bool,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.structured_decode_arange = torch.arange(
+            0, 32, device="cpu", pin_memory=self.pin_memory)
+
         # Get maximum number of mm items per modality (batch size).
         self.max_num_mm_items_by_modality = dict()
         if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
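
Note on the sizing above: the grammar bitmask packs one allow/deny bit per vocabulary token into 32-bit words, so each request's row holds `cdiv(vocab_size, 32)` int32 values, and bit j of word i covers token `i * 32 + j`. A minimal standalone sketch of that packing arithmetic (the `vocab_size` of 100 and the token id are made up for illustration):

```python
import torch

vocab_size = 100                  # hypothetical; the real value comes from the model config
num_words = -(-vocab_size // 32)  # cdiv(vocab_size, 32) -> 4 int32 words

# One request's packed bitmask: bit j of word i covers token i * 32 + j.
packed = torch.zeros(num_words, dtype=torch.int32)
token_id = 70                     # mark token 70 as allowed
packed[token_id // 32] |= 1 << (token_id % 32)

# Unpacking mirrors structured_decode below, with arange = 0..31.
arange = torch.arange(0, 32)
allowed = ((packed[:, None] >> arange[None, :]) & 1).reshape(-1)[:vocab_size]
assert allowed[70] == 1 and allowed[0] == 0
```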
@@ -762,9 +778,16 @@ def execute_model(
         )
         hidden_states = self.select_hidden_states(hidden_states,
                                                   logits_indices)
+        logits = self.compute_logits(hidden_states)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
             from_input_batch(self.input_batch, padded_num_reqs, self.device)
-        selected_token_ids = self.sample_from_hidden(hidden_states,
+        if scheduler_output.grammar_bitmask is not None:
+            require_struct_decoding, grammar_bitmask_padded, arange = \
+                self.prepare_structured_decoding_input(logits, scheduler_output)
+            logits = self.structured_decode(require_struct_decoding,
+                                            grammar_bitmask_padded, logits,
+                                            arange)
+        selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
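
The new `structured_decode` step (defined in the last hunk below) applies the grammar mask to every row and then uses `torch.where` to keep the raw logits for rows that don't require structured output; the per-request flag has shape `[num_reqs, 1]` and broadcasts over the vocab dimension, so no dynamic shapes enter the XLA graph. A runnable toy sketch of that selection (sizes and the stand-in masking are made up):

```python
import torch

# Toy stand-in for structured_decode's torch.where selection.
num_reqs, vocab_size = 3, 5
logits = torch.randn(num_reqs, vocab_size)
masked_logits = logits.clone()
masked_logits[:, 1:] = -float("inf")   # stand-in for apply_grammar_bitmask
require_mask = torch.tensor([[True], [False], [True]])  # [num_reqs, 1]

out = torch.where(require_mask, masked_logits, logits)
assert torch.equal(out[1], logits[1])  # row without a grammar is untouched
assert out[0, 1] == -float("inf")      # constrained row is masked
```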
@@ -997,7 +1020,7 @@ def _precompile_backbone(self) -> None:
             self._dummy_run(num_tokens)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("model backbone")

     def _precompile_select_hidden_states(self) -> None:
@@ -1026,19 +1049,59 @@ def _precompile_select_hidden_states(self) -> None:
                 break
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("select_hidden_states")

-    def _precompile_sample_from_hidden(self) -> None:
-        logger.info("Compiling sampling with different num_reqs.")
+    def _precompile_compute_logits(self) -> None:
+        logger.info("Compiling compute_logits with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_reqs in self.num_reqs_paddings:
             dummy_hidden = torch.zeros((num_reqs, hsize),
                                        device=self.device,
                                        dtype=self._hidden_states_dtype)
-            # The first dimension of dummy_hidden cannot be mark_dynamic because
-            # some operations in the sampler require it to be static.
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            self.compute_logits(dummy_hidden)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("compute_logits")
+
+    def _precompile_structured_decoding(self) -> None:
+        logger.info(
+            "Compiling structured_decoding with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_require_struct_decoding = \
+                self.require_structured_out_cpu[:num_reqs].to(self.device)
+            dummy_grammar_bitmask = \
+                self.grammar_bitmask_cpu[:num_reqs].to(self.device)
+            # The first dimension of the above 3 dummy tensors cannot be
+            # mark_dynamic because some operations in structured_decode require
+            # them to be static.
+            arange = self.structured_decode_arange.to(self.device)
+            self.structured_decode(dummy_require_struct_decoding,
+                                   dummy_grammar_bitmask, dummy_logits, arange)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("structured_decoding")
+
+    def _precompile_sample_from_logits(self) -> None:
+        logger.info(
+            "Compiling sample_from_logits with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            # The first dimension of dummy_logits cannot be mark_dynamic
+            # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
                 sampling_metadata = (
@@ -1049,12 +1112,12 @@ def _precompile_sample_from_hidden(self) -> None:
                     generate_params_if_all_greedy,
                 ))
                 sampling_metadata.all_greedy = all_greedy
-                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+                self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("sampling")
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sample_from_logits")

     def capture_model(self) -> None:
         """
@@ -1063,7 +1126,9 @@ def capture_model(self) -> None:
         self._precompile_mm_encoder()
         self._precompile_backbone()
         self._precompile_select_hidden_states()
-        self._precompile_sample_from_hidden()
+        self._precompile_compute_logits()
+        self._precompile_structured_decoding()
+        self._precompile_sample_from_logits()

     def profile_run(
         self,
@@ -1144,7 +1209,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
             num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
-            if isinstance(kv_cache_spec, FullAttentionSpec):
+            if isinstance(kv_cache_spec, AttentionSpec):
                 kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
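
The widened `isinstance` check lets sliding-window layers share the same Pallas KV-cache allocation path. My understanding of the class relationship in `vllm.v1.kv_cache_interface` (worth double-checking) is that both concrete specs derive from `AttentionSpec`:

```python
# Assumed hierarchy; both specs carry block_size/num_kv_heads/head_size,
# which is all get_kv_cache_shape needs.
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        SlidingWindowSpec)

assert issubclass(FullAttentionSpec, AttentionSpec)
assert issubclass(SlidingWindowSpec, AttentionSpec)
```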
@@ -1179,29 +1244,86 @@ def select_hidden_states(self, hidden_states, indices_do_sample):
         return hidden_states[indices_do_sample]

     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
-    def sample_from_hidden(
-        self,
-        sample_hidden_states: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata,
-    ) -> torch.Tensor:
-        """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
-        logits = self.model.compute_logits(sample_hidden_states, None)
+    def compute_logits(self,
+                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.model.compute_logits(sample_hidden_states, None)
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def sample_from_logits(
+            self, logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens

+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def structured_decode(self, require_struct_decoding: torch.Tensor,
+                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
+                          arange: torch.Tensor) -> torch.Tensor:
+        return torch.where(
+            require_struct_decoding,
+            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
+            logits)
+
+    def apply_grammar_bitmask(self, logits: torch.Tensor,
+                              grammar_bitmask: torch.Tensor,
+                              arange: torch.Tensor):
+        assert (logits.shape[0] == grammar_bitmask.shape[0])
+        logits_cloned = logits.clone()
+        for i in range(logits.shape[0]):
+            unpacked_bitmask = (torch.bitwise_right_shift(
+                grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0
+            unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size]
+            logits_cloned[i] = logits_cloned[i].masked_fill(
+                unpacked_bitmask, -float("inf"))
+        return logits_cloned
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)

     def get_input_embeddings(self, *args, **kwargs):
         return self.model.get_input_embeddings(*args, **kwargs)

+    def prepare_structured_decoding_input(
+        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        assert grammar_bitmask is not None
+        num_reqs, _ = logits.shape
+
+        # Reset pre-allocated tensors
+        self.grammar_bitmask_cpu.zero_()
+        self.require_structured_out_cpu.zero_()
+
+        # We receive the structured output bitmask from the scheduler, but the
+        # indices of the requests in the batch may not match the indices of
+        # the bitmask since the scheduler doesn't know how the tpu runner is
+        # ordering the requests in the batch. We need to match the order of
+        # bitmask with the order of requests
+        struct_out_indices: list[int] = []
+        mask_indices: list[int] = []
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            struct_out_indices.append(batch_index)
+            mask_indices.append(mask_index)
+        self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
+            grammar_bitmask[mask_indices])
+        # It's not guaranteed that all requests in this batch require
+        # structured output, so create a bool tensor to represent
+        # the requests that need structured output.
+        struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
+        self.require_structured_out_cpu[struct_out_indices] = True
+        return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
+            self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
+            self.structured_decode_arange.to(logits.device)
+
     def _get_mm_dummy_batch(self, modality: str,
                             batch_size: int) -> BatchedTensorInputs:
         # Dummy data for pre-compiling multimodal models.
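
To make the masking semantics of `apply_grammar_bitmask` concrete, here is a self-contained check on a toy batch (a vocab of 8 tokens so one 32-bit word per row; the sizes are hypothetical, but the unpack logic mirrors the loop above):

```python
import torch

vocab_size = 8
logits = torch.zeros(2, vocab_size)
arange = torch.arange(0, 32)

# Request 0 may only emit tokens 1 and 3; request 1 may emit anything.
bitmask = torch.tensor([[(1 << 1) | (1 << 3)], [0xFF]], dtype=torch.int32)

masked = logits.clone()
for i in range(logits.shape[0]):
    # A zero bit means "disallowed"; those positions are filled with -inf.
    unpacked = ((bitmask[i][:, None] >> arange[None, :]) & 1) == 0
    masked[i] = masked[i].masked_fill(unpacked.reshape(-1)[:vocab_size],
                                      -float("inf"))

assert masked[0].isfinite().nonzero().flatten().tolist() == [1, 3]
assert masked[1].isfinite().all()
```

Since `structured_decode` wraps this in `torch.where(require_struct_decoding, ...)`, rows of the batch that carry no grammar keep their raw logits, which keeps every tensor shape static for the XLA graph.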