
Commit d4a3f91

Parametrize qkv bias

1 parent: 42d3bfe

File tree

3 files changed (+13, -5 lines)

examples/models/llama/attention.py

Lines changed: 10 additions & 4 deletions
@@ -175,10 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_batch_size = args.max_batch_size
         self.max_context_len = args.max_context_len
         self.dim = args.dim
-        # TODO: parametrize bias for attention and feedforward.
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=True)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True)
+        self.attention_qkv_bias = args.attention_qkv_bias
+        self.wq = nn.Linear(
+            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wk = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wv = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
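For context, a minimal standalone sketch of what the new flag controls; the dimensions below are made up for illustration and the real values come from ModelArgs. With `bias=False`, torch leaves the layer's `bias` attribute as `None` and the corresponding key is absent from the state_dict, so the setting has to match the checkpoint being loaded.

```python
import torch.nn as nn

# Illustrative dimensions only; the real values come from ModelArgs.
dim, n_heads, head_dim = 1536, 12, 128

for attention_qkv_bias in (False, True):
    wq = nn.Linear(dim, n_heads * head_dim, bias=attention_qkv_bias)
    # With bias=False the bias attribute is None and no "wq.bias" entry
    # appears in the state_dict, so checkpoints must match the setting.
    print(attention_qkv_bias, wq.bias is None)
```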

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ class ModelArgs:
     num_experts: int = 8 # Number of experts
     num_activated_experts: int = 2 # Number of experts to activate
     attention_type: str = "mha" # Attention type, registered in attention.py
+    attention_qkv_bias: bool = False
     use_kv_cache: bool = False # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (
         False # Use custom sdpa op that updates kv cache in-place
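A short usage sketch, assuming the remaining ModelArgs fields all have defaults like the ones shown above: models with biased q/k/v projections opt in explicitly, while llama-style models keep the new default of False.

```python
from examples.models.llama.model_args import ModelArgs

# Qwen2.5-style models opt in; llama-style models keep the default of False.
args = ModelArgs(attention_qkv_bias=True)
assert args.attention_qkv_bias
```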

examples/models/qwen2_5/1_5b_config.json

Lines changed: 2 additions & 1 deletion
@@ -8,5 +8,6 @@
   "norm_eps": 1e-06,
   "rope_theta": 1000000.0,
   "use_scaled_rope": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "attention_qkv_bias": true
 }
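One way such a JSON config could map onto ModelArgs is a plain json.load followed by keyword expansion; this is a sketch under that assumption, not necessarily how the example runner actually constructs the model.

```python
import json

from examples.models.llama.model_args import ModelArgs

# Hypothetical loading path for illustration only.
with open("examples/models/qwen2_5/1_5b_config.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
assert args.attention_qkv_bias  # Qwen2.5 enables the q/k/v bias
```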
