@@ -1890,16 +1890,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
1890
1890
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
1891
1891
1892
1892
if re .match (qkv_pattern , name ):
1893
- from einops import rearrange
1894
-
1895
1893
bid = re .findall (qkv_pattern , name )[0 ]
1896
1894
qkv = data_torch
1897
- qkv = rearrange (qkv .T , " o (g n i) ->o g n i" , g = num_groups , n = q_per_kv + 2 , i = head_dim )
1895
+ # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
1896
+ qkv = qkv .T .reshape ((- 1 , num_groups , q_per_kv + 2 , head_dim ))
1898
1897
q , k , v = qkv [..., : q_per_kv , :], qkv [..., q_per_kv : q_per_kv + 1 , :], qkv [..., q_per_kv + 1 : q_per_kv + 2 , :]
1899
1898
# The model weights of q and k require additional reshape.
1900
- q = self ._hf_permute_qk (rearrange (q , " o g n i -> o (g n i)" ).T , num_heads , num_heads )
1901
- k = self ._hf_permute_qk (rearrange (k , " o g n i -> o (g n i)" ).T , num_heads , num_kv_heads )
1902
- v = rearrange (v , " o g n i -> o (g n i)" ).T
1899
+ # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
1900
+ q = self ._hf_permute_qk (q .reshape ((q .shape [0 ], - 1 )).T , num_heads , num_heads )
1901
+ # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
1902
+ k = self ._hf_permute_qk (k .reshape ((k .shape [0 ], - 1 )).T , num_heads , num_kv_heads )
1903
+ # v = rearrange(v, " o g n i -> o (g n i)").T
1904
+ v = v .reshape ((v .shape [0 ], - 1 )).T
1903
1905
return [
1904
1906
(self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_Q , bid ), q ),
1905
1907
(self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_K , bid ), k ),
@@ -2238,13 +2240,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
2238
2240
class LazyTorchTensor :
2239
2241
_meta : Tensor
2240
2242
_data : Tensor | None
2241
- _args : list [ Any ]
2242
- _func : Callable [[list [ Any ]] , Tensor ] | None = None
2243
+ _args : tuple
2244
+ _func : Callable [[tuple ] , Tensor ] | None
2243
2245
2244
- def __init__ (self , * , meta : Tensor , data : Tensor | None = None , args : list [ Any ] | None = None , func : Callable [[list [ Any ] ], Tensor ] | None = None ):
2246
+ def __init__ (self , * , meta : Tensor , data : Tensor | None = None , args : tuple = () , func : Callable [[tuple ], Tensor ] | None = None ):
2245
2247
self ._meta = meta
2246
2248
self ._data = data
2247
- self ._args = args if args is not None else []
2249
+ self ._args = args
2248
2250
self ._func = func
2249
2251
2250
2252
@staticmethod
@@ -2266,19 +2268,22 @@ def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], Lazy
2266
2268
def wrapped_fn (* args , ** kwargs ):
2267
2269
if kwargs is None :
2268
2270
kwargs = {}
2269
- args_list = ([ self ] if use_self else []) + list ( args )
2271
+ args = (( self ,) if use_self else ()) + args
2270
2272
2271
- meta_args = LazyTorchTensor ._recurse_apply (args_list , lambda t : t ._meta )
2273
+ meta_args = LazyTorchTensor ._recurse_apply (args , lambda t : t ._meta )
2272
2274
2273
- return LazyTorchTensor (meta = fn (* meta_args , ** kwargs ), args = args_list , func = lambda a : fn (* a , ** kwargs ))
2275
+ return LazyTorchTensor (meta = fn (* meta_args , ** kwargs ), args = args , func = lambda a : fn (* a , ** kwargs ))
2274
2276
return wrapped_fn
2275
2277
2276
2278
def __getattr__ (self , __name : str ) -> Any :
2277
2279
meta_attr = getattr (self ._meta , __name )
2278
- if not callable (meta_attr ):
2279
- return meta_attr
2280
- else :
2280
+ if callable (meta_attr ):
2281
2281
return self ._wrap_fn (getattr (torch .Tensor , __name ), use_self = True )
2282
+ elif isinstance (meta_attr , torch .Tensor ):
2283
+ # for things like self.T
2284
+ return self ._wrap_fn (lambda s : getattr (s , __name ))(self )
2285
+ else :
2286
+ return meta_attr
2282
2287
2283
2288
_dtype_map : dict [torch .dtype , type ] = {
2284
2289
torch .float16 : np .float16 ,
@@ -2295,7 +2300,7 @@ def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
2295
2300
2296
2301
@overload
2297
2302
@staticmethod
2298
- def to_eager (t : list [ Tensor | LazyTorchTensor ] ) -> list [ Tensor ] : ...
2303
+ def to_eager (t : tuple ) -> tuple : ...
2299
2304
2300
2305
@staticmethod
2301
2306
def to_eager (t : Any ) -> Any :
0 commit comments