Commit 98db434

convert-hf : remove einops requirement for InternLM2
1 parent 0c38332 commit 98db434

File tree

4 files changed, +22 -20 lines changed


.devops/nix/package.nix

Lines changed: 0 additions & 1 deletion

@@ -86,7 +86,6 @@ let
   # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
   llama-python-extra = python3.withPackages (
     ps: [
-      ps.einops
       ps.numpy
       ps.sentencepiece
       ps.tiktoken

convert-hf-to-gguf.py

Lines changed: 22 additions & 17 deletions

@@ -1890,16 +1890,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
 
         if re.match(qkv_pattern, name):
-            from einops import rearrange
-
             bid = re.findall(qkv_pattern, name)[0]
             qkv = data_torch
-            qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
             q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
             # The model weights of q and k equire additional reshape.
-            q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
-            k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
-            v = rearrange(v, " o g n i -> o (g n i)").T
+            # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
+            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
+            # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
+            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
+            # v = rearrange(v, " o g n i -> o (g n i)").T
+            v = v.reshape((v.shape[0], -1)).T
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
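
The removed rearrange calls are plain row-major reshapes of the fused wqkv weight, which is why torch's own reshape can stand in and the einops dependency can be dropped. A minimal sketch checking the equivalence; the dimensions below are made up for illustration and are not InternLM2's actual configuration:

import torch
from einops import rearrange  # only needed here to verify the old behaviour

num_groups, q_per_kv, head_dim = 8, 4, 64
# fused wqkv weight, shape (out_features, in_features)
qkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, 4096)

# splitting the last axis "(g n i)" into separate axes is a row-major reshape...
old = rearrange(qkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
new = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
assert torch.equal(old, new)

# ...and merging "o g n i -> o (g n i)" back is a reshape as well
q = new[..., :q_per_kv, :]
assert torch.equal(rearrange(q, "o g n i -> o (g n i)"), q.reshape((q.shape[0], -1)))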
@@ -2238,13 +2240,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 class LazyTorchTensor:
     _meta: Tensor
     _data: Tensor | None
-    _args: list[Any]
-    _func: Callable[[list[Any]], Tensor] | None = None
+    _args: tuple
+    _func: Callable[[tuple], Tensor] | None
 
-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
         self._meta = meta
         self._data = data
-        self._args = args if args is not None else []
+        self._args = args
         self._func = func
 
     @staticmethod
@@ -2266,19 +2268,22 @@ def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
-            args_list = ([self] if use_self else []) + list(args)
+            args = ((self,) if use_self else ()) + args
 
-            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
 
-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
         return wrapped_fn
 
     def __getattr__(self, __name: str) -> Any:
         meta_attr = getattr(self._meta, __name)
-        if not callable(meta_attr):
-            return meta_attr
-        else:
+        if callable(meta_attr):
             return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+        elif isinstance(meta_attr, torch.Tensor):
+            # for things like self.T
+            return self._wrap_fn(lambda s: getattr(s, __name))(self)
+        else:
+            return meta_attr
 
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
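
The new elif branch in __getattr__ covers tensor attributes that are properties rather than methods, such as Tensor.T: getattr on the meta tensor then yields another data-less meta tensor, which is not callable, so the old code returned it directly and the deferred computation was dropped. A standalone illustration on a bare meta tensor (not part of the converter):

import torch

meta = torch.empty(4096, 2048, device="meta")  # shape-only placeholder, no data

t = meta.T                  # Tensor.T is a property, so the result is a tensor...
print(callable(t))          # False -> the old branch would return this meta tensor as-is
print(t.shape, t.device)    # torch.Size([2048, 4096]) meta

print(callable(meta.size))  # True -> methods were already wrapped lazily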
@@ -2295,7 +2300,7 @@ def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
 
     @overload
     @staticmethod
-    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+    def to_eager(t: tuple) -> tuple: ...
 
     @staticmethod
     def to_eager(t: Any) -> Any:
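
With _args stored as a tuple, the collection overload of to_eager now takes and returns a tuple as well. Conceptually, materializing a lazy tensor means first materializing any lazy tensors among its saved arguments, then replaying the recorded function; a simplified sketch of that idea, with a hypothetical helper standing in for the real to_eager/_recurse_apply machinery:

def materialize(t):
    # hypothetical helper; the actual code goes through _recurse_apply
    if isinstance(t, tuple):
        return tuple(materialize(x) for x in t)
    if isinstance(t, LazyTorchTensor):
        if t._data is None:
            # evaluate saved args, then replay func (a closure over fn and kwargs)
            t._data = t._func(materialize(t._args))
        return t._data
    return t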
Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
