File tree 1 file changed +8
-22
lines changed
1 file changed +8
-22
lines changed Original file line number Diff line number Diff line change @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16(
2056
2056
continue ;
2057
2057
}
2058
2058
2059
- half4 s4 = 0 .0f ;
2059
+ device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
2060
+ device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
2060
2061
2061
- device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)) ;
2062
+ half4 s4 = 0 .0h ;
2062
2063
2063
2064
for (int64_t d = 0 ; d < D4; ++d) {
2064
2065
s4 += pk4[d] * pq4[d];
2065
2066
}
2066
2067
2067
- half s = s4.x + s4.y + s4.z + s4.w ;
2068
-
2069
- s = s*scale + mv;
2068
+ half s = (s4.x + s4.y + s4.z + s4.w )*scale + mv;
2070
2069
2071
2070
const half Mold = M;
2072
2071
2073
- half ms = 1 .0f ;
2074
- half vs = 1 .0f ;
2075
-
2076
- if (s > M) {
2077
- M = s;
2078
- ms = exp (Mold - M);
2079
-
2080
- // V = V*exp(Mold - M)
2081
- for (int64_t d = 0 ; d < D4; ++d) {
2082
- V16[d] *= ms;
2083
- }
2084
- } else {
2085
- vs = exp (s - M);
2086
- }
2072
+ M = max (M, s);
2087
2073
2088
- device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
2074
+ const half ms = exp (Mold - M);
2075
+ const half vs = exp (s - M);
2089
2076
2090
- // V += v*exp(s - M)
2091
2077
for (int64_t d = 0 ; d < D4; ++d) {
2092
- V16[d] += pv4 [d] * vs;
2078
+ V16[d] = V16 [d]*ms + pv4[d]* vs;
2093
2079
}
2094
2080
2095
2081
S = S*ms + vs;
You can’t perform that action at this time.
0 commit comments