
Commit 090d2e5

Resync with merged transformers model code
Signed-off-by: Alex-Brooks <[email protected]>
1 parent 3eb71a8 commit 090d2e5


vllm/model_executor/models/granite_speech.py

Lines changed: 103 additions & 96 deletions
@@ -209,6 +209,9 @@ def __init__(
                             config.projector_config.hidden_size))
         self.query.data.normal_(mean=0.0, std=1.0)
 
+        # NOTE - this is implemented generically in transformers,
+        # but for now we create the QFormer model directly since
+        # all existing models use this for the projector.
         self.qformer = Blip2QFormerModel(
             config.projector_config,
             quant_config=quant_config,
@@ -242,58 +245,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 ### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
-class GraniteSpeechCTCEncoder(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.config = config
-        self.input_linear = nn.Linear(config.input_dim,
-                                      config.hidden_dim,
-                                      bias=True)
-        self.layers = nn.ModuleList([
-            GraniteSpeechConformerBlock(config)
-            for _ in range(config.num_layers)
-        ])
-
-        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
-        self.out_mid = nn.Linear(config.output_dim,
-                                 config.hidden_dim,
-                                 bias=True)
-        self.num_layers = config.num_layers
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.input_linear(hidden_states)
-        for idx, layer in enumerate(self.layers, start=1):
-            hidden_states = layer(hidden_states)
-            if idx == self.num_layers // 2:
-                hidden_states_mid = hidden_states.clone()
-                hidden_states_mid = self.out(hidden_states_mid)
-                hidden_states += self.out_mid(
-                    nn.Softmax(dim=-1)(hidden_states_mid))
-        return hidden_states
-
-
-class GraniteSpeechConformerBlock(nn.Module):
-    """Conformer block, consisting largely of linear layers,
-    attention, and convolutional layers."""
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ff1 = GraniteSpeechConformerFeedForward(config)
-        self.attn = GraniteSpeechConformerAttention(config)
-        self.conv = GraniteSpeechConformerConvModule(config)
-        self.ff2 = GraniteSpeechConformerFeedForward(config)
-        self.post_norm = nn.LayerNorm(config.hidden_dim)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
-        hidden_states = self.attn(hidden_states) + hidden_states
-        hidden_states = self.conv(hidden_states) + hidden_states
-        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
-        hidden_states = self.post_norm(hidden_states)
-        return hidden_states
-
-
 class GraniteSpeechConformerFeedForward(nn.Module):
     """Feedforward module for conformer encoder blocks."""
 
@@ -315,7 +266,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class GraniteSpeechConformerAttention(nn.Module):
-    """Attention for conformer blocks with shaw's relpos embeddings."""
+    """Attention for conformer blocks using Shaw's relative positional
+    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
+    for more details.
+    """
 
     def __init__(self, config: PretrainedConfig):
         super().__init__()
@@ -338,7 +292,8 @@ def __init__(self, config: PretrainedConfig):
                 "Context size is either less than 0 or exceeds the max_pos_emb"
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor,
+                attention_dists: torch.Tensor) -> torch.Tensor:
         hidden_states = self.pre_norm(hidden_states)
         bsz, num_features, _ = hidden_states.shape
 
@@ -351,22 +306,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         query_states = self.to_q(hidden_states)
         key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
-        query_states, key_states, value_states = [
-            t.reshape(
-                bsz,
-                num_blocks,
-                self.context_size,
-                self.num_heads,
-                -1,
-            ).transpose(2, 3) for t in (query_states, key_states, value_states)
-        ]
+
+        query_states = query_states.reshape(bsz, num_blocks, self.context_size,
+                                            self.num_heads,
+                                            -1).transpose(2, 3)
+        key_states = key_states.reshape(bsz, num_blocks, self.context_size,
+                                        self.num_heads, -1).transpose(2, 3)
+        value_states = value_states.reshape(bsz, num_blocks, self.context_size,
+                                            self.num_heads,
+                                            -1).transpose(2, 3)
 
         # shaw's relative positional embedding
-        seq = torch.arange(self.context_size, device=hidden_states.device)
-        dist = seq.view(-1, 1) - seq.view(1, -1)
-        dist = torch.clamp(dist, -self.context_size,
-                           self.context_size) + self.max_pos_emb
-        rel_pos_emb = self.rel_pos_emb(dist).to(query_states)
+        dist = attention_dists.to(hidden_states.device)
+        rel_pos_emb = self.rel_pos_emb(dist)
         rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
                                                 list(rel_pos_emb.shape))
         pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
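
For context on the relocated Shaw-style relative position logic above, here is a minimal standalone sketch (not part of this commit) of how the clamped distance table that now arrives as `attention_dists` can be built. The concrete values for `context_size` and `max_pos_emb` below are placeholders, not the model's configured ones.

import torch

# Placeholder values standing in for config.context_size / config.max_pos_emb.
context_size = 4
max_pos_emb = 120

# Pairwise signed distances between positions within one attention block.
seq = torch.arange(context_size)
dist = seq.view(-1, 1) - seq.view(1, -1)  # shape (context_size, context_size)

# Clamp to [-context_size, context_size] and shift so every index is
# non-negative, i.e. a valid lookup into the relative position embedding table.
dist = torch.clamp(dist, -context_size, context_size) + max_pos_emb

print(dist)
# tensor([[120, 119, 118, 117],
#         [121, 120, 119, 118],
#         [122, 121, 120, 119],
#         [123, 122, 121, 120]])

Because this table depends only on the configuration, precomputing it once in the encoder (as the new code does) avoids rebuilding it on every forward pass.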
@@ -390,18 +342,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                               attn_mask=pos_attn,
                                               scale=self.scale)
         out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
-        out = self.to_out(out[:, :num_features, :])
-        return out
+        return self.to_out(out[:, :num_features, :])
+
+
+class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
+    """Wrapper for padded 1D pointwise convolution."""
+
+    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
+        super().__init__()
+        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
+        pad = kernel_size // 2
+        pad_offset = (kernel_size + 1) % 2
+        self.padding = (pad, pad - pad_offset)
+
+        self.conv = nn.Conv1d(chan_in,
+                              chan_out,
+                              kernel_size,
+                              groups=chan_in,
+                              bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = F.pad(hidden_states, self.padding)
+        return self.conv(hidden_states)
 
 
 class GraniteSpeechConformerConvModule(nn.Module):
     """Conformer conv module consisting of several 1D/depthwise 1D
-    convolutional layers."""
+    convolutional layers.
+    """
 
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         inner_dim = config.hidden_dim * config.conv_expansion_factor
-        padding = self.calc_same_padding(config.conv_kernel_size)
 
         self.norm = nn.LayerNorm(config.hidden_dim)
         self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
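
To illustrate the padding scheme used by `GraniteSpeechConformerDepthWiseConv1d` above, here is a small standalone sketch (not part of this commit) that reproduces the `(pad, pad - pad_offset)` computation and checks that the padded depthwise convolution preserves sequence length for both odd and even kernel sizes. The channel count, sequence length, and kernel sizes are arbitrary toy values.

import torch
import torch.nn as nn
import torch.nn.functional as F

def same_padding(kernel_size: int) -> tuple[int, int]:
    # Mirrors the padding computed in GraniteSpeechConformerDepthWiseConv1d:
    # symmetric for odd kernels, off by one for even kernels.
    pad = kernel_size // 2
    pad_offset = (kernel_size + 1) % 2
    return (pad, pad - pad_offset)

channels, seq_len = 8, 50  # toy sizes, not taken from the model config
x = torch.randn(1, channels, seq_len)

for kernel_size in (3, 4, 9):
    conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)
    out = conv(F.pad(x, same_padding(kernel_size)))
    print(kernel_size, same_padding(kernel_size), out.shape[-1])
    # 3 (1, 1) 50
    # 4 (2, 1) 50
    # 9 (4, 4) 50

Folding this computation into the module's __init__ (instead of the removed calc_same_padding static method) keeps the padding tuple next to the convolution it belongs to.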
@@ -410,7 +382,7 @@ def __init__(self, config: PretrainedConfig):
             inner_dim,
             inner_dim,
             kernel_size=config.conv_kernel_size,
-            padding=padding)
+        )
         self.silu = nn.SiLU()
         self.batch_norm = nn.BatchNorm1d(inner_dim)
         self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
@@ -424,34 +396,69 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
         return hidden_states
 
-    @staticmethod
-    def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
-        """Calculates symmetric padding for the depthwise 1D convolution."""
-        pad = kernel_size // 2
-        return (pad, pad - (kernel_size + 1) % 2)
 
+class GraniteSpeechConformerBlock(nn.Module):
+    """Conformer block, consisting largely of linear layers,
+    attention, and convolutional layers."""
 
-class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
-    """Wrapper for padded 1D pointwise convolution."""
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.ff1 = GraniteSpeechConformerFeedForward(config)
+        self.attn = GraniteSpeechConformerAttention(config)
+        self.conv = GraniteSpeechConformerConvModule(config)
+        self.ff2 = GraniteSpeechConformerFeedForward(config)
+        self.post_norm = nn.LayerNorm(config.hidden_dim)
 
-    def __init__(
-        self,
-        chan_in: int,
-        chan_out: int,
-        kernel_size: int,
-        padding: Tuple[int, int],
-    ):
+    def forward(self, hidden_states: torch.Tensor,
+                attention_dists: torch.Tensor) -> torch.Tensor:
+        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
+        hidden_states = self.attn(
+            hidden_states, attention_dists=attention_dists) + hidden_states
+        hidden_states = self.conv(hidden_states) + hidden_states
+        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
+        hidden_states = self.post_norm(hidden_states)
+        return hidden_states
+
+
+class GraniteSpeechCTCEncoder(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.padding = padding
-        self.conv = nn.Conv1d(chan_in,
-                              chan_out,
-                              kernel_size,
-                              groups=chan_in,
-                              bias=False)
+        self.config = config
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = F.pad(hidden_states, self.padding)
-        return self.conv(hidden_states)
+        # Precompute clamped relative positional encoding distances
+        seq = torch.arange(config.context_size)
+        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+        self.attention_dists = torch.clamp(
+            relpos_dist, -config.context_size,
+            config.context_size) + config.max_pos_emb
+
+        self.input_linear = nn.Linear(config.input_dim,
+                                      config.hidden_dim,
+                                      bias=True)
+        self.layers = nn.ModuleList([
+            GraniteSpeechConformerBlock(config)
+            for _ in range(config.num_layers)
+        ])
+
+        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
+        self.out_mid = nn.Linear(config.output_dim,
+                                 config.hidden_dim,
+                                 bias=True)
+        self.num_layers = config.num_layers
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.input_linear(hidden_states)
+        for idx, layer in enumerate(self.layers, start=1):
+            hidden_states = layer(hidden_states,
+                                  attention_dists=self.attention_dists)
+
+            if idx == self.num_layers // 2:
+                hidden_states_mid = hidden_states.clone()
+                hidden_states_mid = self.out(hidden_states_mid)
+                hidden_states += self.out_mid(
+                    nn.Softmax(dim=-1)(hidden_states_mid))
+        return hidden_states
 
 
 @MULTIMODAL_REGISTRY.register_processor(
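
For reference, a toy standalone sketch (not part of this commit) of the mid-stack self-conditioning that GraniteSpeechCTCEncoder.forward applies at layer num_layers // 2: project into the output space, softmax, and fold the result back into the hidden stream. Plain nn.Linear layers stand in for the conformer blocks, and all dimensions are made-up toy values.

import torch
import torch.nn as nn

# Toy dimensions; the real values come from the encoder config.
hidden_dim, output_dim, num_layers = 16, 10, 6

layers = nn.ModuleList(
    [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
out = nn.Linear(hidden_dim, output_dim)      # stands in for self.out
out_mid = nn.Linear(output_dim, hidden_dim)  # stands in for self.out_mid

hidden_states = torch.randn(2, 50, hidden_dim)
for idx, layer in enumerate(layers, start=1):
    hidden_states = layer(hidden_states)
    if idx == num_layers // 2:
        # Halfway through the stack, project to the output space, apply a
        # softmax, and add the re-projected result back into the hidden stream.
        probs = nn.Softmax(dim=-1)(out(hidden_states))
        hidden_states = hidden_states + out_mid(probs)

print(hidden_states.shape)  # torch.Size([2, 50, 16])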
