
Commit 8956543

convert_hf : simplify modify_tensors for InternLM2

* convert_lora : lazy conversion
* llama : load and use alpha from LoRA adapters

1 parent: 9d96328

File tree

4 files changed: +124 -67 lines changed

convert_hf_to_gguf.py

Lines changed: 11 additions & 22 deletions
@@ -2222,13 +2222,6 @@ def set_vocab(self):
 
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                .swapaxes(1, 2)
-                .reshape(weights.shape))
-
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -2248,26 +2241,22 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        hidden_size = self.hparams["hidden_size"]
+        n_embd = self.hparams["hidden_size"]
         q_per_kv = num_heads // num_kv_heads
-        head_dim = hidden_size // num_heads
+        head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv
 
-        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
-
-        if re.match(qkv_pattern, name):
-            bid = re.findall(qkv_pattern, name)[0]
+        if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch
-            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
-            q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+
+            qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
+            q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
+
             # The model weights of q and k equire additional reshape.
-            # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
-            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
-            # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
-            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
-            # v = rearrange(v, " o g n i -> o (g n i)").T
-            v = v.reshape((v.shape[0], -1)).T
+            q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
+            k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
+            v = v.reshape((-1, v.shape[-1]))
+
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
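
To make the simplification easier to follow, here is a small standalone sketch (not part of the commit) that exercises the new wqkv split on toy dimensions. The local permute helper reuses the body of the removed _hf_permute_qk shown above, which LlamaModel.permute is assumed to match; all sizes below are made up. The point of the change is that the fused weight is sliced directly in its original row layout, so the transposes and rearrange-style comments are no longer needed.

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4   # toy sizes (assumptions)
n_embd = num_heads * head_dim                 # 32
q_per_kv = num_heads // num_kv_heads          # 4
num_groups = num_heads // q_per_kv            # == num_kv_heads

def permute(weights, n_head, n_head_kv):
    # same body as the removed _hf_permute_qk above (LlamaModel.permute is assumed to match)
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

# fused wqkv weight: rows hold the grouped q/k/v heads, columns are the embedding
wqkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, n_embd)

qkv = wqkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
q, k, v = qkv[:, :q_per_kv], qkv[:, -2], qkv[:, -1]

q = permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
k = permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
v = v.reshape((-1, v.shape[-1]))

assert q.shape == (num_heads * head_dim, n_embd)
assert k.shape == (num_kv_heads * head_dim, n_embd)
assert v.shape == (num_kv_heads * head_dim, n_embd)
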

convert_lora_to_gguf.py

Lines changed: 94 additions & 40 deletions
@@ -8,9 +8,10 @@
 import argparse
 import os
 import sys
+import json
+from math import prod
 from pathlib import Path
-from types import EllipsisType
-from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
 
 import torch
 
@@ -22,7 +23,7 @@
 import gguf
 
 # reuse model definitions from convert_hf_to_gguf.py
-from convert_hf_to_gguf import Model
+from convert_hf_to_gguf import LazyTorchTensor, Model
 
 logger = logging.getLogger("lora-to-gguf")
 
@@ -35,45 +36,53 @@ class PartialLoraTensor:
 
 # magic to support tensor shape modifications and splitting
 class LoraTorchTensor:
-    _lora_A: Tensor
-    _lora_B: Tensor
+    _lora_A: Tensor  # (n_rank, row_size)
+    _lora_B: Tensor  # (col_size, n_rank)
     _rank: int
 
     def __init__(self, A: Tensor, B: Tensor):
         assert len(A.shape) == len(B.shape)
+        assert A.shape[-2] == B.shape[-1]
         if A.dtype != B.dtype:
             A = A.to(torch.float32)
             B = B.to(torch.float32)
         self._lora_A = A
         self._lora_B = B
-        assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
-        self._rank = self._lora_B.shape[-1]
+        self._rank = B.shape[-1]
+
+    def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
+        return (self._lora_A, self._lora_B)
 
     def __getitem__(
         self,
         indices: (
             SupportsIndex
             | slice
-            | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
         ),
     ) -> LoraTorchTensor:
         shape = self.shape
-        if isinstance(indices, (SupportsIndex, slice)):
+        if isinstance(indices, SupportsIndex):
             if len(shape) > 2:
                 return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
             else:
-                raise NotImplementedError
+                raise NotImplementedError  # can't return a vector
+        elif isinstance(indices, slice):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                return LoraTorchTensor(self._lora_A, self._lora_B[indices])
         elif isinstance(indices, tuple):
             assert len(indices) > 0
-            if isinstance(indices[-1], EllipsisType):
+            if indices[-1] is Ellipsis:
                 return self[indices[:-1]]
             # expand ellipsis
             indices = tuple(
                 u
                 for v in (
                     (
                         (slice(None, None) for _ in range(len(indices) - 1))
-                        if isinstance(i, EllipsisType)
+                        if i is Ellipsis
                         else (i,)
                     )
                     for i in indices
@@ -85,19 +94,22 @@ def __getitem__(
             indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
 
             # TODO: make sure this is correct
-            # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
             indices_A = (
                 *(
-                    0 if isinstance(i, SupportsIndex) else slice(None, None)
-                    for i in indices[:-2]
+                    (
+                        j.__index__() % self._lora_A.shape[i]
+                        if isinstance(j, SupportsIndex)
+                        else slice(None, None)
+                    )
+                    for i, j in enumerate(indices[:-2])
                 ),
                 slice(None, None),
                 indices[-1],
             )
             indices_B = indices[:-1]
             return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
         else:
-            raise NotImplementedError
+            raise NotImplementedError  # unknown indice type
 
     @property
     def dtype(self) -> torch.dtype:
@@ -106,23 +118,37 @@ def dtype(self) -> torch.dtype:
 
     @property
     def shape(self) -> tuple[int, ...]:
+        assert len(self._lora_A.shape) == len(self._lora_B.shape)
         return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
 
     def size(self, dim=None):
         assert dim is None
         return self.shape
 
-    def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+    def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
         if isinstance(shape[0], tuple):
-            new_shape: tuple[int] = shape[0]
+            new_shape: tuple[int, ...] = shape[0]
         else:
-            new_shape = cast(tuple[int], shape)
+            new_shape = cast(tuple[int, ...], shape)
         orig_shape = self.shape
+        if len(new_shape) < 2:
+            raise NotImplementedError  # can't become a vector
+
+        # expand -1 in the shape
+        if any(dim == -1 for dim in new_shape):
+            n_elems = prod(orig_shape)
+            n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
+            assert n_elems % n_new_elems == 0
+            new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
+
         if new_shape[-1] != orig_shape[-1]:
-            raise NotImplementedError
+            raise NotImplementedError  # can't reshape the row size trivially
+
+        shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
+        shape_B = (*new_shape[:-1], self._rank)
         return LoraTorchTensor(
-            self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
-            self._lora_B.reshape((*new_shape[:-1], self._rank)),
+            self._lora_A.reshape(shape_A),
+            self._lora_B.reshape(shape_B),
         )
 
     def reshape_as(self, other: Tensor) -> LoraTorchTensor:
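
A brief aside on the reshape logic above: only _lora_B takes the new leading dimensions, while _lora_A keeps its (rank, row_size) block behind broadcastable 1-dims, so the pair still materializes to the same tensor as reshaping the merged product would. A standalone sketch with toy shapes (not from the repo) that mirrors the shape_A/shape_B construction:

import torch

rank, rows, cols = 4, 16, 24
A = torch.randn(rank, rows)        # _lora_A: (n_rank, row_size)
B = torch.randn(cols, rank)        # _lora_B: (col_size, n_rank)
merged = B @ A                     # (cols, rows): the tensor the pair represents

# reshape (cols, rows) -> (3, cols // 3, rows): only B is reshaped,
# A gets a broadcastable leading 1-dim, as in LoraTorchTensor.reshape
shape_A = (1, rank, rows)
shape_B = (3, cols // 3, rank)
merged_reshaped = torch.matmul(B.reshape(shape_B), A.reshape(shape_A))

assert torch.allclose(merged_reshaped, merged.reshape(3, cols // 3, rows))
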
@@ -134,12 +160,15 @@ def view(self, *size: int) -> LoraTorchTensor:
     def permute(self, *dims: int) -> LoraTorchTensor:
         shape = self.shape
         dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
-        if dims[-1] == -2 and dims[-2] == -1:
-            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
-        else:
-            assert dims[-1] == -1
+        if dims[-1] == -1:
+            # TODO: support higher dimensional A shapes bigger than 1
             assert all(dim == 1 for dim in self._lora_A.shape[:-2])
             return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+        if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            # TODO: compose the above two
+            raise NotImplementedError
 
     def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
         shape = self.shape
@@ -181,11 +210,13 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
                     torch.cat([a._lora_A for a in args[0]], dim),
                     torch.cat([b._lora_B for b in args[0]], dim),
                 )
-            else:
+            elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
                 return LoraTorchTensor(
-                    args[0][0]._lora_A,  # TODO: is this correct? (can't cat over the rank)
+                    args[0][0]._lora_A,
                     torch.cat([b._lora_B for b in args[0]], dim),
                 )
+            else:
+                raise NotImplementedError
         else:
             raise NotImplementedError
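
The new elif branch above relies on a block-matrix identity: if every stacked adapter shares the same _lora_A, concatenating only the _lora_B factors along the output dimension yields the same merged tensor as concatenating the materialized products. A toy, standalone check of that assumption:

import torch

rank, rows = 4, 16
A = torch.randn(rank, rows)
B1 = torch.randn(8, rank)
B2 = torch.randn(8, rank)

cat_merged = torch.cat([B1 @ A, B2 @ A], dim=0)    # concatenation of materialized tensors
cat_factored = torch.cat([B1, B2], dim=0) @ A      # concatenation of B factors, shared A

assert torch.allclose(cat_merged, cat_factored)
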

@@ -205,13 +236,17 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
         help="model is executed on big endian machine",
     )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
     parser.add_argument(
         "--verbose", action="store_true",
         help="increase output verbosity",
@@ -237,13 +272,16 @@ def parse_args() -> argparse.Namespace:
         "f16": gguf.LlamaFileType.MOSTLY_F16,
         "bf16": gguf.LlamaFileType.MOSTLY_BF16,
         "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
+        "auto": gguf.LlamaFileType.GUESSED,
     }
+
     ftype = ftype_map[args.outtype]
 
-    dir_base_model = args.base
-    dir_lora = args.lora_path
-    input_json = os.path.join(dir_lora, "adapter_config.json")
-    input_model = os.path.join(dir_lora, "adapter_model.safetensors")
+    dir_base_model: Path = args.base
+    dir_lora: Path = args.lora_path
+    lora_config = dir_lora / "adapter_config.json"
+    input_model = dir_lora / "adapter_model.safetensors"
+
     if args.outfile is not None:
         fname_out = args.outfile
     else:
@@ -276,6 +314,8 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             tensor_map: dict[str, PartialLoraTensor] = {}
 
             for name, tensor in lora_model.items():
+                if self.lazy:
+                    tensor = LazyTorchTensor.from_eager(tensor)
                 base_name = get_base_tensor_name(name)
                 is_lora_a = ".lora_A.weight" in name
                 is_lora_b = ".lora_B.weight" in name
@@ -305,16 +345,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             dest = super().modify_tensors(data_torch, name, bid)
             for dest_name, dest_data in dest:
                 assert isinstance(dest_data, LoraTorchTensor)
-                # logger.info(f"{orig_name} --> {dest_name}")
-                yield (dest_name + ".lora_a", dest_data._lora_A)
-                yield (dest_name + ".lora_b", dest_data._lora_B)
-
-    model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+                lora_a, lora_b = dest_data.get_lora_A_B()
+
+                yield (dest_name + ".lora_a", lora_a)
+                yield (dest_name + ".lora_b", lora_b)
+
+    model_instance = LoraModel(
+        dir_base_model,
+        ftype,
+        fname_out,
+        is_big_endian=args.bigendian,
+        use_temp_file=False,
+        eager=args.no_lazy,
+        model_name=None,
+    )
     logger.info("Set model parameters")
     model_instance.set_gguf_parameters()
 
-    # adapter_config = json.load(input_json)
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
+    alpha = lparams["lora_alpha"]
+
     model_instance.gguf_writer.add_string("training.type", "finetune_lora")
+    model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha))
 
     model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
     logger.info("Exporting model...")

gguf-py/gguf/quants.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
         osize *= dim
     out = np.empty(shape=osize, dtype=otype)
     # compute over groups of 16 rows (arbitrary, but seems good for performance)
-    n_groups = rows.shape[0] // 16
+    n_groups = (rows.shape[0] // 16) or 1
     np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
     return out.reshape(oshape)
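
The one-line quants.py change guards a small-input edge case: with fewer than 16 rows the integer division yields zero groups, and np.array_split refuses zero sections. A standalone illustration (not from the repo):

import numpy as np

rows = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)  # fewer than 16 rows

n_groups = rows.shape[0] // 16          # == 0: np.array_split(rows, 0) raises ValueError
n_groups = (rows.shape[0] // 16) or 1   # == 1: the whole array becomes a single group

groups = np.array_split(rows, n_groups)
assert len(groups) == 1 and groups[0].shape == rows.shape
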
