Skip to content

[V1] Remove num_input_tokens from attn_metadata #17193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
if attn_metadata is not None:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
Expand Down Expand Up @@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
batchsize = num_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ class FlashAttentionMetadata:
scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.

# for local attention
@dataclass
class LocalAttentionMetadata:
Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ class FlashInferMetadata:
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.

@property
def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property.
Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
num_decode_tokens: int
num_prefills: int

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.

# The dimension of the attention heads
head_dim: Optional[int] = None

Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,6 @@ def execute_model(
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens

# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
Expand Down Expand Up @@ -1088,7 +1087,9 @@ def execute_model(

# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
input_ids=input_ids,
positions=positions,
Expand Down
5 changes: 4 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,10 @@ def execute_model(
xm.mark_step()
num_reqs = self.input_batch.num_reqs
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=scheduler_output.total_num_scheduled_tokens):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
Expand Down