Commit f98eb31

convert-hf : save memory with lazy evaluation (#7075)
* convert-hf : begin refactoring write_tensor
* convert : upgrade to sentencepiece v0.2.0
* convert-hf : remove unused n_dims in extra_*_tensors
* convert-hf : simplify MoE weights stacking
* convert-hf : flake8 linter doesn't like semicolons
* convert-hf : allow unusual model part names
  For example, loading `model-00001-of-00001.safetensors` now works.
* convert-hf : fix stacking MoE expert tensors
  `torch.stack` and `torch.cat` don't do the same thing (see the sketch after this message).
* convert-hf : fix Mamba conversion
  Tested to work even with a SentencePiece-based tokenizer.
* convert : use a string for the SentencePiece tokenizer path
* convert-hf : display tensor shape
* convert-hf : convert norms to f32 by default
* convert-hf : sort model part names
  `os.listdir` is said to list files in arbitrary order. Sorting the file names should let "model-00009-of-00042.safetensors" be loaded before "model-00010-of-00042.safetensors".
* convert-hf : use an ABC for Model again
  It seems Protocol can't be used as a statically type-checked ABC, because its subclasses also can't be instantiated. (why did it seem to work?) At least there's still a way to throw an error when forgetting to define the `model_arch` property of any registered Model subclasses.
* convert-hf : use a plain class for Model, and forbid direct instantiation
  There are no abstract methods used anyway, so using ABC isn't really necessary.
* convert-hf : more consistent formatting of cmdline args
* convert-hf : align the message logged for converted tensors
* convert-hf : fix Refact conversion
* convert-hf : save memory with lazy evaluation
* convert-hf : flake8 doesn't like lowercase L as a variable name
* convert-hf : remove einops requirement for InternLM2
* convert-hf : faster model parts loading
  Instead of pre-loading them all into a dict, iterate on the tensors in the model parts progressively as needed in Model.write_tensors. Conversion for some architectures relies on checking for the presence of specific tensor names, so for multi-part models, the weight map is read from the relevant JSON file to quickly get these names up-front.
* convert-hf : minor changes for consistency
* gguf-py : add tqdm as a dependency
  It's small, and used for a progress bar in GGUFWriter.write_tensors_to_file.
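
The MoE fix above turns on the difference between `torch.stack` and `torch.cat`. A minimal illustration (not part of the commit; the shapes are invented):

    import torch

    a = torch.ones(4, 8)   # one "expert" weight matrix
    b = torch.zeros(4, 8)  # another expert of the same shape

    print(torch.stack([a, b]).shape)  # torch.Size([2, 4, 8]) -- inserts a new leading dim
    print(torch.cat([a, b]).shape)    # torch.Size([8, 8])    -- joins along existing dim 0

Stacking per-expert 2D matrices yields a single 3D expert tensor; concatenating would silently fuse them along the first axis instead.
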
1 parent bc4bba3 commit f98eb31

14 files changed, +847 -1259 lines

convert-hf-to-gguf.py

Lines changed: 746 additions & 1223 deletions
Large diffs are not rendered by default.

convert.py

Lines changed: 12 additions & 8 deletions
@@ -284,6 +284,7 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         n_experts = None
         n_experts_used = None
         f_rope_freq_base = None
+        n_ff = None
 
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if config.get("moe"):
@@ -308,6 +309,8 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
             n_experts_used = config["moe"]["num_experts_per_tok"]
             f_rope_freq_base = 1e6
 
+        assert n_ff is not None
+
         return Params(
             n_vocab = model["tok_embeddings.weight"].shape[0],
             n_embd = config["dim"],
@@ -462,7 +465,8 @@ def __init__(self, base_path: Path):
                 # not found in alternate location either
                 raise FileNotFoundError('Cannot find tokenizer.model')
 
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        self.sentencepiece_tokenizer = SentencePieceProcessor()
+        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
         vocab_size = self.sentencepiece_tokenizer.vocab_size()
 
         new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@@ -482,23 +486,23 @@ def __init__(self, base_path: Path):
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
+            piece = tokenizer.IdToPiece(i)
             text = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
+            score: float = tokenizer.GetScore(i)
 
             toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+            if tokenizer.IsUnknown(i):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if tokenizer.IsControl(i):
                 toktype = gguf.TokenType.CONTROL
 
             # NOTE: I think added_tokens are user defined.
             # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
             # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
 
-            if tokenizer.is_unused(i):
+            if tokenizer.IsUnused(i):
                 toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
+            if tokenizer.IsByte(i):
                 toktype = gguf.TokenType.BYTE
 
             yield text, score, toktype
@@ -906,7 +910,7 @@ def load() -> UnquantizedTensor:
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
 
-    CLASSES = {
+    CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
        # getattr used here as a workaround for mypy not being smart enough to determine
        # the staticmethods have a __func__ attribute.
        ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
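
The SentencePiece hunks above track the v0.2.0 wrapper API (see the requirements bump at the bottom of this commit). A minimal sketch of the spellings the commit switches to, with a hypothetical model path:

    from sentencepiece import SentencePieceProcessor

    sp = SentencePieceProcessor()
    sp.LoadFromFile("tokenizer.model")  # replaces passing the path to the constructor

    for i in range(sp.vocab_size()):
        piece = sp.IdToPiece(i)                       # was id_to_piece(i)
        score = sp.GetScore(i)                        # was get_score(i)
        special = sp.IsUnknown(i) or sp.IsControl(i)  # was is_unknown / is_control
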

examples/server/tests/features/steps/steps.py

Lines changed: 1 addition & 1 deletion
@@ -939,7 +939,7 @@ async def oai_chat_completions(user_prompt,
         while event_received:
             event_received = False
             async for line_in_bytes in response.content:
-                line = line_in_bytes.decode('utf8')
+                line = line_in_bytes.decode('utf-8')
                 line = line.rstrip('\n').rstrip('\r')
                 if line == '':
                     continue

gguf-py/gguf/constants.py

Lines changed: 1 addition & 1 deletion
@@ -860,7 +860,7 @@ def get_type(val: Any) -> GGUFValueType:
 # Note: Does not support GGML_QKK_64
 QK_K = 256
 # Items here are (block size, type size)
-GGML_QUANT_SIZES = {
+GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F32: (1, 4),
     GGMLQuantizationType.F16: (1, 2),
     GGMLQuantizationType.Q4_0: (32, 2 + 16),
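
These (block size, type size) pairs drive the tensor size computation in gguf_reader.py below (`n_bytes = n_elems * type_size // block_size`). A quick worked example with an invented tensor:

    # Q4_0 packs each block of 32 elements into 2 + 16 = 18 bytes
    # (a 2-byte scale plus 32 four-bit quants).
    block_size, type_size = 32, 2 + 16           # GGML_QUANT_SIZES[GGMLQuantizationType.Q4_0]
    n_elems = 4096 * 4096                        # hypothetical 4096x4096 tensor
    n_bytes = n_elems * type_size // block_size  # 9437184 bytes (9 MiB, vs 64 MiB as F32)
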

gguf-py/gguf/gguf_reader.py

Lines changed: 4 additions & 4 deletions
@@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
 
 class GGUFReader:
     # I - same as host, S - swapped
-    byte_order: Literal['I' | 'S'] = 'I'
+    byte_order: Literal['I'] | Literal['S'] = 'I'
     alignment: int = GGUF_DEFAULT_ALIGNMENT
 
     # Note: Internal helper, API may change.
@@ -83,7 +83,7 @@ class GGUFReader:
         GGUFValueType.BOOL: np.bool_,
     }
 
-    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@@ -128,7 +128,7 @@ def get_tensor(self, idx: int) -> ReaderTensor:
         return self.tensors[idx]
 
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
     ) -> npt.NDArray[Any]:
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
@@ -250,7 +250,7 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
                 raise ValueError(f'Found duplicated tensor with name {tensor_name}')
             tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
-            n_elems = np.prod(dims)
+            n_elems = int(np.prod(dims))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             data_offs = int(start_offs + offset_tensor[0])
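
The Literal rewrites above fix invalid type expressions: `Literal['I' | 'S']` makes Python evaluate `'I' | 'S'`, which type checkers reject (and which raises TypeError if the annotation is actually evaluated). The accepted spellings are a union of Literal types, as the diff uses, or a single multi-argument Literal:

    from typing import Literal

    # form used in the diff (the | union of typing objects needs Python 3.10+,
    # or deferred annotations on older versions)
    byte_order_a: Literal['I'] | Literal['S'] = 'I'
    # equivalent multi-argument form, valid on all supported versions
    byte_order_b: Literal['I', 'S'] = 'I'
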

gguf-py/gguf/gguf_writer.py

Lines changed: 68 additions & 9 deletions
@@ -7,7 +7,7 @@
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,6 +28,47 @@
 logger = logging.getLogger(__name__)
 
 
+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY = auto()
     HEADER = auto()
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -176,7 +217,7 @@ def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool
         if pack_fmt is not None:
             self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
         elif vtype == GGUFValueType.STRING:
-            encoded_val = val.encode("utf8") if isinstance(val, str) else val
+            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
             self.kv_data += self._pack("Q", len(encoded_val))
             self.kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@@ -205,7 +246,7 @@ def add_tensor_info(
             raise ValueError(f'Duplicated tensor name {name}')
         self.ti_names.add(name)
 
-        encoded_name = name.encode("utf8")
+        encoded_name = name.encode("utf-8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
         n_dims = len(tensor_shape)
@@ -237,7 +278,7 @@ def add_tensor_info(
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
 
@@ -272,15 +313,33 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
                 tensor.tofile(self.fout)
@@ -479,7 +538,7 @@ def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
 
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
-        if isinstance(value, list):
+        if not isinstance(value, str):
             template_default = None
             template_names = set()
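
A small usage sketch of the new LazyTensor (not from the commit; `load_array` is a stand-in for whatever reads a shard from disk). Conversions queue up in `functions` and no real data is touched until `tofile`, which is also why `write_tensors_to_file` can compute an accurate tqdm total from `nbytes` before materializing anything:

    import numpy as np
    from gguf.gguf_writer import LazyTensor

    def load_array() -> np.ndarray:
        # stand-in loader; the converter wraps reads of model shards instead
        return np.ones((2, 3), dtype=np.float16)

    t = LazyTensor(load_array, dtype=np.float16, shape=(2, 3))
    t = t.astype(np.float32)       # queued, not applied yet
    assert t.nbytes == 2 * 3 * 4   # known from metadata alone

    with open("tensor.bin", "wb") as f:
        t.tofile(f)                # loaded, converted, and written only here
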

gguf-py/gguf/vocab.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, Sequence, Mapping, Iterable
 
 from .gguf_writer import GGUFWriter
 
@@ -15,11 +15,11 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
-    chat_template: str | None
+    chat_template: str | Sequence[Mapping[str, str]] | None
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
-        special_token_types: tuple[str, ...] | None = None,
+        special_token_types: Iterable[str] | None = None,
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}

gguf-py/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = ">=3.8"
 numpy = ">=1.17"
+tqdm = ">=4.27"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"

gguf-py/scripts/gguf-dump.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
+                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
             elif field.types[0] in reader.gguf_scalar_to_np:
                 log_message += ' = {0}'.format(field.parts[-1][0])
         print(log_message)  # noqa: NP100

gguf-py/scripts/gguf-new-metadata.py

Lines changed: 6 additions & 6 deletions
@@ -7,7 +7,7 @@
 from pathlib import Path
 
 import numpy as np
-from typing import Any, Mapping, Sequence
+from typing import Any, Sequence
 
 # Necessary to load the local gguf package
 if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -34,19 +34,19 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
     return host_endian
 
 
-def decode_field(field: gguf.ReaderField) -> Any:
+def decode_field(field: gguf.ReaderField | None) -> Any:
     if field and field.types:
         main_type = field.types[0]
 
         if main_type == gguf.GGUFValueType.ARRAY:
             sub_type = field.types[-1]
 
             if sub_type == gguf.GGUFValueType.STRING:
-                return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
+                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
             else:
                 return [pv for idx in field.data for pv in field.parts[idx].tolist()]
         if main_type == gguf.GGUFValueType.STRING:
-            return str(bytes(field.parts[-1]), encoding='utf8')
+            return str(bytes(field.parts[-1]), encoding='utf-8')
         else:
             return field.parts[-1][0]
 
@@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)
 
 
-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
 
     for tensor in reader.tensors:
         # Dimensions are written in reverse order, so flip them first
-        shape = np.flipud(tensor.shape)
+        shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 
     writer.write_header_to_file()
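
A tiny illustration of the shape handling above, with invented dimensions: GGUF stores dims in reverse order, and `.tolist()` hands `add_tensor_info` plain Python ints rather than numpy scalars:

    import numpy as np

    dims = (4096, 32000)              # hypothetical tensor.shape from the reader
    shape = np.flipud(dims).tolist()  # -> [32000, 4096], as plain ints
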

pyrightconfig.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+{
+    "extraPaths": ["gguf-py"],
+}
Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0

Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0

requirements/requirements-convert.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 numpy~=1.24.4
-sentencepiece~=0.1.98
+sentencepiece~=0.2.0
 transformers>=4.40.1,<5.0.0
 gguf>=0.1.0
 protobuf>=4.21.0,<5.0.0
