@@ -209,6 +209,9 @@ def __init__(
                         config.projector_config.hidden_size))
         self.query.data.normal_(mean=0.0, std=1.0)
 
+        # NOTE - this is implemented generically in transformers,
+        # but for now we create the QFormer model directly since
+        # all existing models use this for the projector.
         self.qformer = Blip2QFormerModel(
             config.projector_config,
             quant_config=quant_config,
@@ -242,58 +245,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 ### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
-class GraniteSpeechCTCEncoder(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.config = config
-        self.input_linear = nn.Linear(config.input_dim,
-                                      config.hidden_dim,
-                                      bias=True)
-        self.layers = nn.ModuleList([
-            GraniteSpeechConformerBlock(config)
-            for _ in range(config.num_layers)
-        ])
-
-        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
-        self.out_mid = nn.Linear(config.output_dim,
-                                 config.hidden_dim,
-                                 bias=True)
-        self.num_layers = config.num_layers
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.input_linear(hidden_states)
-        for idx, layer in enumerate(self.layers, start=1):
-            hidden_states = layer(hidden_states)
-            if idx == self.num_layers // 2:
-                hidden_states_mid = hidden_states.clone()
-                hidden_states_mid = self.out(hidden_states_mid)
-                hidden_states += self.out_mid(
-                    nn.Softmax(dim=-1)(hidden_states_mid))
-        return hidden_states
-
-
-class GraniteSpeechConformerBlock(nn.Module):
-    """Conformer block, consisting largely of linear layers,
-    attention, and convolutional layers."""
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ff1 = GraniteSpeechConformerFeedForward(config)
-        self.attn = GraniteSpeechConformerAttention(config)
-        self.conv = GraniteSpeechConformerConvModule(config)
-        self.ff2 = GraniteSpeechConformerFeedForward(config)
-        self.post_norm = nn.LayerNorm(config.hidden_dim)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
-        hidden_states = self.attn(hidden_states) + hidden_states
-        hidden_states = self.conv(hidden_states) + hidden_states
-        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
-        hidden_states = self.post_norm(hidden_states)
-        return hidden_states
-
-
 class GraniteSpeechConformerFeedForward(nn.Module):
     """Feedforward module for conformer encoder blocks."""
 
@@ -315,7 +266,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class GraniteSpeechConformerAttention(nn.Module):
-    """Attention for conformer blocks with shaw's relpos embeddings."""
+    """Attention for conformer blocks using Shaw's relative positional
+    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
+    for more details.
+    """
 
     def __init__(self, config: PretrainedConfig):
         super().__init__()
@@ -338,7 +292,8 @@ def __init__(self, config: PretrainedConfig):
                 "Context size is either less than 0 or exceeds the max_pos_emb"
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor,
+                attention_dists: torch.Tensor) -> torch.Tensor:
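+        # attention_dists holds the clamped relative position indices,
+        # shape [context_size, context_size], precomputed by the encoder.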
         hidden_states = self.pre_norm(hidden_states)
         bsz, num_features, _ = hidden_states.shape
 
@@ -351,22 +306,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         query_states = self.to_q(hidden_states)
         key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
-        query_states, key_states, value_states = [
-            t.reshape(
-                bsz,
-                num_blocks,
-                self.context_size,
-                self.num_heads,
-                -1,
-            ).transpose(2, 3) for t in (query_states, key_states, value_states)
-        ]
+
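+        # Reshape q/k/v to [bsz, num_blocks, num_heads, context_size, head_dim]
+        # so that attention below is computed within each block.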
+        query_states = query_states.reshape(bsz, num_blocks, self.context_size,
+                                            self.num_heads,
+                                            -1).transpose(2, 3)
+        key_states = key_states.reshape(bsz, num_blocks, self.context_size,
+                                        self.num_heads, -1).transpose(2, 3)
+        value_states = value_states.reshape(bsz, num_blocks, self.context_size,
+                                            self.num_heads,
+                                            -1).transpose(2, 3)
 
         # shaw's relative positional embedding
-        seq = torch.arange(self.context_size, device=hidden_states.device)
-        dist = seq.view(-1, 1) - seq.view(1, -1)
-        dist = torch.clamp(dist, -self.context_size,
-                           self.context_size) + self.max_pos_emb
-        rel_pos_emb = self.rel_pos_emb(dist).to(query_states)
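+        # The distance tensor is now built once in the encoder __init__
+        # instead of being recomputed on every forward pass.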
+        dist = attention_dists.to(hidden_states.device)
+        rel_pos_emb = self.rel_pos_emb(dist)
         rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
                                                 list(rel_pos_emb.shape))
         pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
@@ -390,18 +342,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             attn_mask=pos_attn,
             scale=self.scale)
         out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
-        out = self.to_out(out[:, :num_features, :])
-        return out
+        return self.to_out(out[:, :num_features, :])
+
+
+class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
+    """Wrapper for padded 1D depthwise convolution."""
+
+    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
+        super().__init__()
+        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
+        pad = kernel_size // 2
+        pad_offset = (kernel_size + 1) % 2
+        self.padding = (pad, pad - pad_offset)
+
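+        # groups=chan_in applies one filter per input channel, i.e. the
+        # convolution is depthwise.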
+        self.conv = nn.Conv1d(chan_in,
+                              chan_out,
+                              kernel_size,
+                              groups=chan_in,
+                              bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = F.pad(hidden_states, self.padding)
+        return self.conv(hidden_states)
 
 
 class GraniteSpeechConformerConvModule(nn.Module):
     """Conformer conv module consisting of several 1D/depthwise 1D
-    convolutional layers."""
+    convolutional layers.
+    """
 
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         inner_dim = config.hidden_dim * config.conv_expansion_factor
-        padding = self.calc_same_padding(config.conv_kernel_size)
 
         self.norm = nn.LayerNorm(config.hidden_dim)
         self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
@@ -410,7 +382,7 @@ def __init__(self, config: PretrainedConfig):
             inner_dim,
             inner_dim,
             kernel_size=config.conv_kernel_size,
-            padding=padding)
+        )
         self.silu = nn.SiLU()
         self.batch_norm = nn.BatchNorm1d(inner_dim)
         self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
@@ -424,34 +396,69 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
         return hidden_states
 
-    @staticmethod
-    def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
-        """Calculates symmetric padding for the depthwise 1D convolution."""
-        pad = kernel_size // 2
-        return (pad, pad - (kernel_size + 1) % 2)
 
+class GraniteSpeechConformerBlock(nn.Module):
+    """Conformer block, consisting largely of linear layers,
+    attention, and convolutional layers."""
 
-class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
-    """Wrapper for padded 1D pointwise convolution."""
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.ff1 = GraniteSpeechConformerFeedForward(config)
+        self.attn = GraniteSpeechConformerAttention(config)
+        self.conv = GraniteSpeechConformerConvModule(config)
+        self.ff2 = GraniteSpeechConformerFeedForward(config)
+        self.post_norm = nn.LayerNorm(config.hidden_dim)
 
-    def __init__(
-        self,
-        chan_in: int,
-        chan_out: int,
-        kernel_size: int,
-        padding: Tuple[int, int],
-    ):
+    def forward(self, hidden_states: torch.Tensor,
+                attention_dists: torch.Tensor) -> torch.Tensor:
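+        # Two half-step feedforwards (hence the 0.5 scaling) sandwich the
+        # attention and conv modules, as in the original conformer design.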
+        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
+        hidden_states = self.attn(
+            hidden_states, attention_dists=attention_dists) + hidden_states
+        hidden_states = self.conv(hidden_states) + hidden_states
+        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
+        hidden_states = self.post_norm(hidden_states)
+        return hidden_states
+
+
+class GraniteSpeechCTCEncoder(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.padding = padding
-        self.conv = nn.Conv1d(chan_in,
-                              chan_out,
-                              kernel_size,
-                              groups=chan_in,
-                              bias=False)
+        self.config = config
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = F.pad(hidden_states, self.padding)
-        return self.conv(hidden_states)
+        # Precompute clamped relative positional encoding distances
+        seq = torch.arange(config.context_size)
+        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+        self.attention_dists = torch.clamp(
+            relpos_dist, -config.context_size,
+            config.context_size) + config.max_pos_emb
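+        # Clamping bounds the relative offsets; adding max_pos_emb shifts
+        # them to non-negative indices for the embedding lookup.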
+
+        self.input_linear = nn.Linear(config.input_dim,
+                                      config.hidden_dim,
+                                      bias=True)
+        self.layers = nn.ModuleList([
+            GraniteSpeechConformerBlock(config)
+            for _ in range(config.num_layers)
+        ])
+
+        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
+        self.out_mid = nn.Linear(config.output_dim,
+                                 config.hidden_dim,
+                                 bias=True)
+        self.num_layers = config.num_layers
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.input_linear(hidden_states)
+        for idx, layer in enumerate(self.layers, start=1):
+            hidden_states = layer(hidden_states,
+                                  attention_dists=self.attention_dists)
+
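+            # At the midpoint layer, project to the output dim, take a
+            # softmax, and mix the result back into the hidden states.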
+            if idx == self.num_layers // 2:
+                hidden_states_mid = hidden_states.clone()
+                hidden_states_mid = self.out(hidden_states_mid)
+                hidden_states += self.out_mid(
+                    nn.Softmax(dim=-1)(hidden_states_mid))
+        return hidden_states
 
 
 @MULTIMODAL_REGISTRY.register_processor(