add causal-conv1d in Triton and integrate into vLLM with test code #18206

Closed
wants to merge 2 commits into from
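
For orientation, the PR adds two Triton entry points: causal_conv1d_update_triton (single-token decode) and causal_conv1d_fn_triton (packed variable-length prefill). The following is a minimal usage sketch, assuming only the signatures and the channel-last layout exercised by the tests below; the tensor names and sizes here are illustrative and not part of the diff:

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn_triton, causal_conv1d_update_triton)

device, dtype = "cuda", torch.bfloat16
batch, dim, width, cache_lines = 2, 64, 4, 8

weight = torch.randn(dim, width, device=device, dtype=dtype)
bias = torch.randn(dim, device=device, dtype=dtype)
# Channel-last state: allocate (cache_lines, width - 1, dim), then transpose
# so the dim axis is contiguous, as the tests below do.
conv_state = torch.randn(cache_lines, width - 1, dim,
                         device=device, dtype=dtype).transpose(1, 2)
state_indices = torch.tensor([3, 5], dtype=torch.int32, device=device)

# Decode path: one new token per sequence, cache lines gathered via indices;
# entries equal to PAD_SLOT_ID are skipped.
x = torch.randn(batch, 1, dim, device=device, dtype=dtype).transpose(1, 2)
out = causal_conv1d_update_triton(x, conv_state, weight, bias,
                                  activation="silu",
                                  conv_state_indices=state_indices,
                                  pad_slot_id=PAD_SLOT_ID)

# Prefill path: two packed sequences of lengths 5 and 3 (8 tokens total),
# described by cumulative start offsets.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device=device)
x_packed = torch.randn(1, 8, dim, device=device,
                       dtype=dtype).transpose(1, 2).squeeze(0)
has_initial_states = torch.zeros(batch, dtype=torch.bool, device=device)
out = causal_conv1d_fn_triton(x_packed, weight, bias,
                              conv_states=conv_state,
                              query_start_loc=cu_seqlens,
                              cache_indices=state_indices,
                              has_initial_states=has_initial_states,
                              activation="silu",
                              pad_slot_id=PAD_SLOT_ID)
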
225 changes: 225 additions & 0 deletions tests/kernels/mamba/test_causal_conv1d.py
@@ -5,12 +5,15 @@
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn_triton, causal_conv1d_update_triton)
from vllm.platforms import current_platform


@@ -435,3 +438,225 @@
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness when a subset of the sequences is padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
                                                      dim, width, seqlen,
                                                      has_bias, silu_activation,
                                                      itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2

# set seed
current_platform.seed_everything(0)

padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size

channel_last = True
if not channel_last:
    x = torch.randn(padded_batch_size, dim, seqlen,
                    device=device, dtype=itype)
else:
    # x is (batch, dim, seqlen), contiguous along the dim axis
    x = torch.randn(padded_batch_size, seqlen, dim,
                    device=device, dtype=itype).transpose(1, 2)

x_ref = x.clone()

conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)

if not channel_last:
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
else:
# conv_state is (cache_lines, dim, state_len),
# contiguous along the dim axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)

conv_state_for_padding_test = conv_state.clone()

weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"

out = causal_conv1d_update_triton(x,
                                  conv_state,
                                  weight,
                                  bias,
                                  activation=activation,
                                  conv_state_indices=padded_state_indices,
                                  pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)

assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
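

The channel-last branches above allocate the logical (batch, dim, seqlen) tensors as (batch, seqlen, dim) and then transpose, so the dim axis ends up with stride 1. A standalone illustration of that layout trick, using only plain PyTorch and nothing beyond what the test comments describe:

import torch

batch, dim, seqlen = 3, 8, 5
x = torch.randn(batch, seqlen, dim).transpose(1, 2)
print(x.shape)            # torch.Size([3, 8, 5]) -> logical (batch, dim, seqlen)
print(x.stride())         # (40, 1, 8) -> stride 1 along the dim axis
print(x.is_contiguous())  # False: channel-last, not row-major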


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [8, 16, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
@pytest.mark.parametrize('with_padding', [True, False])
@pytest.mark.parametrize('batch', [4])
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width,
                                   has_bias, silu_activation, itype):
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
seqlens = []
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1

eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])

total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
channel_last = True
if not channel_last:
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(torch.randn(1, seqlen, 4096 + dim + 64, device=device,
dtype=itype), "b s d -> b d s")[:, 4096:4096 + dim, :]

weight = torch.randn(dim, width, device=device, dtype=itype)

bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
if not channel_last:
final_states = torch.randn(total_entries,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
else:
final_states = torch.randn(total_entries,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=x.device)[:batch_size]
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn_triton(
x.squeeze(0),
weight,
bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_states=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)

out_ref = []
out_ref_b = []

splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[
padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].
unsqueeze(0) if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)

try:
    assert torch.allclose(final_states[state_indices],
                          final_states_ref[state_indices],
                          rtol=rtol,
                          atol=atol)
    print("Passed conv_state")
except Exception as e:
    print("FAILED conv_state")
    raise e
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
try:
    assert torch.allclose(unpadded_out,
                          out_ref_tensor,
                          rtol=rtol,
                          atol=atol)
except Exception as e:
    # Debugging aid: report where the outputs diverge before re-raising.
    print("Passed conv_state, but failed output")
    nz = out_ref_tensor.squeeze(0) - unpadded_out
    non_zero_indices = torch.nonzero(nz)
    print('nonzero indices :', non_zero_indices)
    raise e
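
The variable-length batch in the test above is built by sampling EOS positions and differencing them. A small standalone example of that construction, with split points chosen by hand for illustration:

import torch

seqlen, nsplits = 16, 3
# e.g. eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
eos_pos = torch.tensor([3, 7, 12])
lens = torch.diff(
    torch.cat([torch.tensor([-1]), eos_pos,
               torch.tensor([seqlen - 1])]))
print(lens.tolist())  # [4, 4, 5, 3] -> nsplits + 1 chunks summing to seqlen
assert lens.sum().item() == seqlen and (lens > 0).all()
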
48 changes: 48 additions & 0 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -9,7 +9,9 @@
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
import numpy as np

from typing import Optional

@dataclass
class Mamba2Metadata:
@@ -22,6 +24,31 @@
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor

num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: Optional[int] = 1024
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None


def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
@@ -62,6 +89,7 @@
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
mamba2_metadata: Optional[Mamba2Metadata] = None,
) -> Mamba2Metadata:

# compute number of prefill and decode requests
@@ -78,7 +106,13 @@

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
    # NOTE: currently it is assumed prefill requests come before decode
    # requests -> we can use ':num_prefills' slicing
    # TODO: maybe revert back to the original code (below) if the above
    # no longer holds
    # has_initial_states = attn_metadata.context_lens_tensor > 0
    # zero_init_indices = mamba_cache_params.state_indices_tensor[
    #     ~has_initial_states]
    # mamba_cache_params.ssm_state[zero_init_indices] = 0
    # initial_states = mamba_cache_params.ssm_state[
    #     mamba_cache_params.state_indices_tensor]
    if (isinstance(attn_metadata,
                   (FlashAttentionMetadata, XFormersMetadata,
                    PlaceholderAttentionMetadata))
            and attn_metadata.context_lens_tensor is not None):
@@ -103,6 +137,20 @@
_query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens)

if mamba2_metadata is not None:
    mamba2_metadata.has_initial_states = has_initial_states
    mamba2_metadata.prep_initial_states = prep_initial_states
    mamba2_metadata.chunk_size = chunk_size
    mamba2_metadata.seq_idx = seq_idx
    mamba2_metadata.chunk_indices = chunk_indices
    mamba2_metadata.chunk_offsets = chunk_offsets
    # We use 2 reset flags:
    # * mamba2_metadata.width is None
    #   -> update config at the first run (never changes for the whole
    #      session of a given model; becomes available at the first
    #      layer, e.g. conv_weights)
    # * mamba2_metadata.cu_seqlen is None
    #   -> update config specific to each input (becomes available at
    #      the first layer, e.g. conv_weights)
    mamba2_metadata.cu_seqlen = None  # supposed to be updated at each input
    return mamba2_metadata
return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
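
A closing note on the two reset flags described in the comments above: mamba2_metadata.width is None gates one-time, per-session configuration, while mamba2_metadata.cu_seqlen is None gates per-input configuration, since prepare_mamba2_metadata sets cu_seqlen back to None on every call. The following is a hypothetical caller-side sketch of that pattern; the attribute names come from the Mamba2Metadata fields above, but the helper itself and the values it assigns are assumptions for illustration, not code from this PR:

import torch

def maybe_refresh_conv_metadata(meta, conv_weights: torch.Tensor,
                                x: torch.Tensor):
    # Hypothetical helper, not part of the PR: shows how the two
    # "is None" reset flags could be consumed by a conv1d layer.
    if meta.width is None:
        # Per-session setup: fixed for a given model, derived from the
        # conv weights the first time a layer runs.
        meta.width = conv_weights.shape[-1]
        meta.stride_w_dim = conv_weights.stride(0)
        meta.stride_w_width = conv_weights.stride(1)
    if meta.cu_seqlen is None:
        # Per-input setup: recomputed every step because
        # prepare_mamba2_metadata resets cu_seqlen to None.
        meta.dim = x.shape[0]
        meta.cu_seqlen = x.shape[-1]
        meta.stride_x_dim = x.stride(0)
        meta.stride_x_token = x.stride(1)
    return meta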