
Commit f8c90cd

ds5t5 and ggerganov authored
llm : add Refact model (#3329)
* add refact model
* resolve comments
* rebase to the latest
* solve alibi cpu error

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f93af02 commit f8c90cd

File tree

4 files changed, +723 -10 lines changed


convert-refact-hf-to-gguf.py

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
#!/usr/bin/env python3
# HF refact --> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import torch
from transformers import AutoTokenizer  # type: ignore[import]

if "NO_LOCAL_GGUF" not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
import gguf


def bytes_to_unicode():
    # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, (chr(n) for n in cs)))


def count_model_parts(dir_model: Path) -> int:
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("pytorch_model-"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")
    return num_parts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a Refact model to a GGML compatible file"
    )
    parser.add_argument(
        "--vocab-only",
        action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile",
        type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "model",
        type=Path,
        help="directory containing model file, or model file itself (*.bin)",
    )
    parser.add_argument(
        "ftype",
        type=int,
        choices=[0, 1],
        default=1,
        nargs="?",
        help="output format - use 0 for float32, 1 for float16",
    )
    return parser.parse_args()


args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f"Error: {args.model} is not a directory", file=sys.stderr)
    sys.exit(1)

# possible tensor data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f"ggml-model-{ftype_str[ftype]}.gguf"

print("gguf: loading model " + dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "GPTRefactForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])

    sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH = gguf.MODEL_ARCH.REFACT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

# Get the Refact feed-forward dimension
hidden_dim = hparams["n_embd"]
inner_dim = 4 * hidden_dim
hidden_dim = int(2 * inner_dim / 3)
multiple_of = 256
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

block_count = hparams["n_layer"]

gguf_writer.add_name("Refact")
# Refact uses ALiBi, so this value comes from config.json and may only reflect the training setup.
gguf_writer.add_context_length(hparams["n_positions"])
gguf_writer.add_embedding_length(hparams["n_embd"])

gguf_writer.add_feed_forward_length(ff_dim)
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_rms_eps(hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(ftype)

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_json_file = dir_model / "tokenizer.json"
if not tokenizer_json_file.is_file():
    print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr)
    sys.exit(1)

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

with open(tokenizer_json_file, "r", encoding="utf-8") as f:
    tokenizer_json = json.load(f)

print("gguf: get gpt2 tokenizer vocab")

# The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running inference.
vocab_size = (
    hparams["vocab_size"]
    if "vocab_size" in hparams
    else len(tokenizer_json["model"]["vocab"])
)

tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(vocab_size):
    if i in reverse_vocab:
        text = reverse_vocab[i]
        try:
            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
        except KeyError:
            text = bytearray()
            for c in reverse_vocab[i]:
                if ord(c) < 256:  # single byte character
                    text.append(byte_decoder[ord(c)])
                else:  # multibyte special token character
                    text.extend(c.encode("utf-8"))
    else:
        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
        pad_token = f"[PAD{i}]".encode("utf8")
        text = bytearray(pad_token)

    tokens.append(text)
    scores.append(0.0)  # dummy
    toktypes.append(gguf.TokenType.NORMAL)  # dummy

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

# params for qkv transform
n_head = hparams["n_head"]
n_head_kv = 1

head_dim = hparams["n_embd"] // n_head

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )
for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(dir_model / part_name, map_location="cpu")

    for i in range(block_count):
        if f"transformer.h.{i}.attn.kv.weight" in model_part:
            data = model_part[f"transformer.h.{i}.attn.kv.weight"]
            model_part[f"model.layers.{i}.self_attn.k_proj.weight"] = data[
                : n_head_kv * head_dim
            ]
            model_part[f"model.layers.{i}.self_attn.v_proj.weight"] = data[
                n_head_kv * head_dim :
            ]
            del model_part[f"transformer.h.{i}.attn.kv.weight"]
        if f"transformer.h.{i}.attn.q.weight" in model_part:
            model_part[f"model.layers.{i}.self_attn.q_proj.weight"] = model_part[
                f"transformer.h.{i}.attn.q.weight"
            ]
            del model_part[f"transformer.h.{i}.attn.q.weight"]
        if f"transformer.h.{i}.mlp.gate_up_proj.weight" in model_part:
            data = model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
            model_part[f"model.layers.{i}.mlp.gate_proj.weight"] = data[:ff_dim]
            model_part[f"model.layers.{i}.mlp.up_proj.weight"] = data[ff_dim:]
            del model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

    for name in model_part.keys():
        data = model_part[name]

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if (
            ftype == 1
            and data_dtype == np.float32
            and name.endswith(".weight")
            and n_dims == 2
        ):
            data = data.astype(np.float16)

        print(
            new_name
            + ", n_dims = "
            + str(n_dims)
            + ", "
            + str(old_dtype)
            + " --> "
            + str(data.dtype)
        )

        gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")

ggml.c

Lines changed: 0 additions & 2 deletions
@@ -13082,7 +13082,6 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -13103,7 +13102,6 @@ static void ggml_compute_forward_alibi_f32(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(ne1 + n_past == ne0);
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
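For context on the two ALiBi parameters the surrounding code still reads from op_params (n_head and max_bias), here is a minimal C sketch of the standard per-head ALiBi slope schedule; it illustrates the technique and is not copied from ggml.c:

#include <math.h>

/* Slope for attention head k of n_head heads, given max_bias (the ALiBi
 * paper uses max_bias = 8). Head counts that are not a power of two get
 * interleaved slopes, following the original recipe. */
static float alibi_slope(int k, int n_head, float max_bias) {
    const int n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return k < n_head_log2 ? powf(m0, (float) (k + 1))
                           : powf(m1, (float) (2 * (k - n_head_log2) + 1));
}

Each head's slope scales a position-dependent bias that is added to the scaled attention scores, which is what the "add alibi to src0 (KQ_scaled)" loop in this function does.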

0 commit comments
