
Commit 82478f1

compilade authored and arthw committed
gguf-py : fix some metadata name extraction edge cases (ggml-org#8591)
* gguf-py : fix some metadata name extraction edge cases

* convert_lora : use the lora dir for the model card path

* gguf-py : more metadata edge cases fixes

  Multiple finetune versions are now joined together, and the removal
  of the basename annotation on trailing versions is more robust.

* gguf-py : add more name metadata extraction tests

* convert_lora : fix default filename

  The default filename was previously hardcoded.

* convert_hf : Model.fname_out can no longer be None

* gguf-py : do not use title case for naming convention

  Some models use acronyms in lowercase, which can't be title-cased like
  other words, so it's best to simply use the same case as in the original
  model name. Note that the size label still has an uppercased suffix to
  make it distinguishable from the context size of a finetune.
1 parent 264c283 commit 82478f1

File tree

5 files changed, +112 -44 lines changed

convert_hf_to_gguf.py

Lines changed: 17 additions & 21 deletions

@@ -48,7 +48,7 @@ class Model:
     dir_model: Path
     ftype: gguf.LlamaFileType
-    fname_out: Path | None
+    fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -62,11 +62,12 @@ class Model:
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
+    dir_model_card: Path

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
@@ -90,6 +91,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path |
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
+        self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -345,7 +347,7 @@ def prepare_metadata(self, vocab_only: bool):

         total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()

-        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)

         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
@@ -359,27 +361,22 @@ def prepare_metadata(self, vocab_only: bool):
         output_type: str = self.ftype.name.partition("_")[2]

         # Filename Output
-        # Note: `not is_dir()` is used because `.is_file()` will not detect
-        # file template strings as it doesn't actually exist as a file
-        if self.fname_out is not None and not self.fname_out.is_dir():
-            # Output path is a custom defined templated filename
-
-            # Process templated file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
-        else:
+        if self.fname_out.is_dir():
             # Generate default filename based on model specification and available metadata
             if not vocab_only:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
             else:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")

-            # Check if preferred output directory path was provided
-            if self.fname_out is not None and self.fname_out.is_dir():
-                # output path is a directory
-                self.fname_out = self.fname_out / f"{fname_default}.gguf"
-            else:
-                # output in the same directory as the model by default
-                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+            # Use the default filename
+            self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # Output path is a custom defined templated filename
+            # Note: `not is_dir()` is used because `.is_file()` will not detect
+            # file template strings as it doesn't actually exist as a file
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)

         self.set_type()
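Note: the `else` branch above delegates template substitution to `gguf.fill_templated_filename`. A minimal sketch of the behavior (not the actual implementation; assumes a `{ftype}` placeholder in the filename, and the helper name is hypothetical):

```python
from pathlib import Path

def resolve_outfile(fname_out: Path, default_stem: str, output_type: str) -> Path:
    # Sketch of the branch above: a directory gets the generated default
    # filename, anything else is treated as a filename template.
    if fname_out.is_dir():
        return fname_out / f"{default_stem}.gguf"
    # Template filling is useful with the "auto" ftype, where the final
    # output type is only known after inspecting the tensors.
    return fname_out.parent / fname_out.name.format(ftype=output_type.lower())

# resolve_outfile(Path("out/model-{ftype}.gguf"), "Mistral-Nemo-Instruct-2407", "F16")
# -> Path("out/model-f16.gguf")
```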

@@ -3634,10 +3631,10 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)

-    fname_out = None
-
     if args.outfile is not None:
         fname_out = args.outfile
+    else:
+        fname_out = dir_model

     logger.info(f"Loading model: {dir_model.name}")
@@ -3668,7 +3665,6 @@ def main() -> None:
     else:
         logger.info("Exporting model...")
         model_instance.write()
-        assert model_instance.fname_out is not None
         out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
         logger.info(f"Model successfully exported to {out_path}")

convert_lora_to_gguf.py

Lines changed: 18 additions & 8 deletions

@@ -290,7 +290,7 @@ def parse_args() -> argparse.Namespace:
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
-        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
+        fname_out = dir_lora

     if os.path.exists(input_model):
         # lazy import load_file only if lora is in safetensors format.
@@ -304,12 +304,6 @@ def parse_args() -> argparse.Namespace:
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
-
-    with open(lora_config, "r") as f:
-        lparams: dict[str, Any] = json.load(f)
-
-    alpha: float = lparams["lora_alpha"]
-
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -320,12 +314,21 @@ def parse_args() -> argparse.Namespace:
         class LoraModel(model_class):
             model_arch = model_class.model_arch

+            lora_alpha: float
+
+            def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+
+                super().__init__(*args, **kwargs)
+
+                self.dir_model_card = dir_lora_model
+                self.lora_alpha = float(lora_alpha)
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

             def set_gguf_parameters(self):
-                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                 super().set_gguf_parameters()

             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
@@ -368,6 +371,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)

+        with open(lora_config, "r") as f:
+            lparams: dict[str, Any] = json.load(f)
+
+        alpha: float = lparams["lora_alpha"]
+
         model_instance = LoraModel(
             dir_base_model,
             ftype,
@@ -376,6 +384,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             use_temp_file=False,
             eager=args.no_lazy,
             dry_run=args.dry_run,
+            dir_lora_model=dir_lora,
+            lora_alpha=alpha,
         )

         logger.info("Exporting model...")
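Note: `lora_config` points at the adapter's configuration file. A minimal sketch of where `lora_alpha` comes from, assuming a PEFT-style `adapter_config.json` (the helper name is hypothetical):

```python
import json
from pathlib import Path

def read_lora_alpha(dir_lora: Path) -> float:
    # PEFT-style adapters store their hyperparameters in adapter_config.json;
    # "lora_alpha" scales the low-rank update applied to the base weights.
    with open(dir_lora / "adapter_config.json", "r") as f:
        lparams: dict = json.load(f)
    return float(lparams["lora_alpha"])
```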

gguf-py/gguf/metadata.py

Lines changed: 26 additions & 9 deletions

@@ -54,6 +54,7 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat

         model_card = Metadata.load_model_card(model_path)
         hf_params = Metadata.load_hf_parameters(model_path)
+        # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter

         # heuristics
         metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
@@ -177,6 +178,12 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
             org_component = None

         name_parts: list[str] = model_full_name_component.split('-')
+
+        # Remove empty parts
+        for i in reversed(range(len(name_parts))):
+            if len(name_parts[i]) == 0:
+                del name_parts[i]
+
         name_types: list[
             set[Literal["basename", "size_label", "finetune", "version", "type"]]
         ] = [set() for _ in name_parts]
@@ -223,9 +230,19 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
                 name_parts[i] = part
             # Some easy to recognize finetune names
             elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
-                name_types[i].add("finetune")
-                if part.lower() == "lora":
-                    name_parts[i] = "LoRA"
+                if total_params < 0 and part.lower() == "lora":
+                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
+                    name_types[i].add("type")
+                else:
+                    name_types[i].add("finetune")
+
+        # Ignore word-based size labels when there is at least a number-based one present
+        # TODO: should word-based size labels always be removed instead?
+        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
+            for n, t in zip(name_parts, name_types):
+                if "size_label" in t:
+                    if all(c.isalpha() for c in n):
+                        t.remove("size_label")

         at_start = True
         # Find the basename through the annotated name
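Note: the size-label pruning above keeps purely alphabetic labels (e.g. "mini") only when no digit-bearing label (e.g. "7B") was found. A standalone illustration of the rule:

```python
# If any annotated size label contains a digit, drop the purely
# alphabetic ones ("mini" loses to "7B").
size_labels = ["mini", "7B"]
if any(c.isdecimal() for label in size_labels for c in label):
    size_labels = [label for label in size_labels if not all(c.isalpha() for c in label)]
print(size_labels)  # ['7B']
```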
@@ -240,18 +257,18 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =

         # Remove the basename annotation from trailing version
         for part, t in zip(reversed(name_parts), reversed(name_types)):
-            if "basename" in t:
-                if len(t) > 1:
-                    t.remove("basename")
+            if "basename" in t and len(t) > 1:
+                t.remove("basename")
             else:
                 break

         basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
-        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
+        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
         finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
         # TODO: should the basename version always be excluded?
-        # TODO: should multiple versions be joined together?
-        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+        # NOTE: multiple finetune versions are joined together
+        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None

         if size_label is None and finetune is None and version is None:
             # Too ambiguous, output nothing
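Note: `dict.fromkeys` is used for deduplication because dict keys keep insertion order (guaranteed since Python 3.7), whereas iterating a `set` gives no ordering guarantee. A quick illustration:

```python
parts = ["7B", "mini", "7B"]

# dict keys preserve first-seen order
print("-".join(dict.fromkeys(parts)))  # 7B-mini

# set iteration order is unspecified; the label could come out reordered
print("-".join(set(parts)))            # e.g. mini-7B
```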

gguf-py/gguf/utility.py

Lines changed: 3 additions & 3 deletions

@@ -50,15 +50,15 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention

     if base_name is not None:
-        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = base_name.strip().replace(' ', '-').replace('/', '-')
     elif model_name is not None:
-        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = model_name.strip().replace(' ', '-').replace('/', '-')
     else:
         name = "ggml-model"

     parameters = f"-{size_label}" if size_label is not None else ""

-    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+    finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""

     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
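Note: dropping `.title()` matters because `str.title()` capitalizes the first letter of every word and lowercases the rest, which mangles acronyms and camel-cased names. A quick comparison:

```python
name = "TinyLlama LoRA"

# str.title() lowercases the interior of acronyms
print(name.title().replace(' ', '-'))  # Tinyllama-Lora (mangled)

# keeping the original case preserves them
print(name.replace(' ', '-'))          # TinyLlama-LoRA
```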

gguf-py/tests/test_metadata.py

Lines changed: 48 additions & 3 deletions

@@ -54,7 +54,7 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
                          ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))

-        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        # Non standard naming
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
                          ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))

@@ -71,7 +71,7 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
                          ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))

-        # None standard and not easy to disambiguate
+        # Non standard and not easy to disambiguate
         self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
                          ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))

@@ -123,6 +123,51 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
                          ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))

+        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
+        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
+                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
+
+        # Instruct in a name without a size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
+                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
+
+        # Non-obvious splitting relying on 'chat' keyword
+        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
+                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
+
+        # Multiple versions
+        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
+                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
+
+        # TODO: DPO in the name
+        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
+                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
+
+        # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
+                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
+
+        # Too ambiguous
+        # TODO: should "base" be a 'finetune' or 'size_label'?
+        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
+                         ('Florence-2-base', 'microsoft', None, None, None, None))
+
+        ## Invalid cases ##
+
+        # Starts with a dash and has dashes in a row
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
+                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
+        ## LoRA ##
+
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
+
+        # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
@@ -134,7 +179,7 @@ def test_apply_metadata_heuristic_from_model_card(self):
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         expect = gguf.Metadata()
-        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
         expect.datasets=['teknium/OpenHermes-2.5']
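Note: judging from the assertions above, the tuple returned by `get_model_id_components` unpacks as (name, org, basename, finetune, version, size_label). For example:

```python
# Field order as exercised by the assertions above
name, org, basename, finetune, version, size_label = \
    gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407")
assert (basename, finetune, version) == ('Mistral-Nemo', 'Instruct', '2407')
```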

0 commit comments