
Commit f188ecd

slaren authored and ggerganov committed
ggml : mul_mat_id use the same tensor for all the experts (ggml-org#6387)
* ggml : update mul_mat_id to use the same tensor for all the experts
* update cuda
* minor
* update metal
* update test-backend-ops
* fix cuda
* Update ggml-metal.m
  Co-authored-by: Georgi Gerganov <[email protected]>
* update convert.py
* update convert-hf-to-gguf.py
* update convert.py for mixtral hf models
* Update convert-hf-to-gguf.py
  Co-authored-by: Georgi Gerganov <[email protected]>
* cuda : support non-pow-2 number of experts
* allow quantize to work for split and merged experts models in the same way
* cleanup + disable mmap automatically with split tensors models
* update imatrix
* test-backend-ops : test qwen argsort
* update grok model loading
* llama : add merged experts tensors to the grok tensor map
* minor
* gguf : bump version
* fix quantizing of merged experts
* convert-hf-to-gguf.py : update grok (untested)
* make linter happy
* cuda/argsort : use shared memory instead of pool memory
* convert : fix grok tensor names
* metal : add support for non-pow-2 argsort
* llama : more loader cleanup, better error checking
* cuda : fix warning
* llama : still use mmap for loading old models, but copy the data to a host buffer
* add review note
* llama : remove ffn tensor counting + add sanity check
  ggml-ci
* convert : fix handling of n_experts == None
  ggml-ci
* imatrix : fix ncall counters
* llama : produce error if imatrix size does not match
* quantize : terminate on errors + trace logs
  ggml-ci
* metal : pad shared memory to 16 bytes

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent d9b0606 commit f188ecd

15 files changed: +744 -876 lines changed
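At its core, the change replaces the per-expert 2D weight tensors with one stacked 3D tensor per MoE weight, which mul_mat_id then indexes by expert. A minimal numpy sketch of that merge, with toy sizes that are not taken from the diff:

```python
import numpy as np

# toy sizes, for illustration only
n_experts, n_ff, n_embd = 8, 6, 4

# before this change: one (n_ff, n_embd) weight matrix per expert
expert_weights = [np.random.rand(n_ff, n_embd).astype(np.float32) for _ in range(n_experts)]

# after: a single (n_experts, n_ff, n_embd) tensor, stacked the same way the converters now do
merged = np.stack(expert_weights, axis=0)
assert merged.shape == (n_experts, n_ff, n_embd)

# an expert selected by the router is just a slice along the first dimension
x = np.random.rand(n_embd).astype(np.float32)
y = merged[3] @ x          # equivalent to expert_weights[3] @ x
assert np.allclose(y, expert_weights[3] @ x)
```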

convert-hf-to-gguf.py

Lines changed: 135 additions & 1 deletion
@@ -1216,6 +1216,8 @@ def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
         n_head = self.hparams.get("num_attention_heads")
         n_kv_head = self.hparams.get("num_key_value_heads")
+        n_experts = self.hparams.get("num_local_experts")
+        experts = dict()
         for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@@ -1236,6 +1238,49 @@ def write_tensors(self):
 
             data = data.squeeze()
 
+            # process the experts separately
+            if name.find("block_sparse_moe.experts") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for wid in range(1, 4):
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
@@ -1249,7 +1294,7 @@ def write_tensors(self):
             if self.ftype == 0 and data_dtype == np.float16:
                 data = data.astype(np.float32)
 
-            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            # 1d tensors need to be converted to float32
             if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                 data = data.astype(np.float32)
 
@@ -1261,6 +1306,9 @@ def write_tensors(self):
 
             self.gguf_writer.add_tensor(new_name, data)
 
+        if len(experts) > 0:
+            raise ValueError(f"Unprocessed experts: {experts.keys()}")
+
 
 @Model.register("GrokForCausalLM")
 class GrokModel(Model):
@@ -1276,6 +1324,92 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_name("Grok")
 
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_experts = self.hparams.get("num_local_experts")
+        experts = dict()
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # process the experts separately
+            if name.find(".moe.") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for wid in ["linear", "linear_1", "linear_v"]:
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
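Both write_tensors implementations above apply the same dtype policy for the regular (non-expert) path, while the merged-experts branch simply demotes the stacked 3d tensor to f16 when ftype is 1. As a quick reference, an illustrative restatement of the regular-path rules as a standalone helper; this is not code from the converter:

```python
import numpy as np

def convert_for_output(data: np.ndarray, ftype: int, is_weight: bool) -> np.ndarray:
    # ftype == 0 requests an f32 model: promote any f16 data
    if ftype == 0 and data.dtype == np.float16:
        return data.astype(np.float32)
    # ftype == 1 requests f16, but 1d tensors are still stored as f32
    if ftype == 1 and data.dtype == np.float16 and data.ndim == 1:
        return data.astype(np.float32)
    # ftype == 1: demote f32 2d weight tensors to f16
    if ftype == 1 and data.dtype == np.float32 and is_weight and data.ndim == 2:
        return data.astype(np.float16)
    return data

# example: a 1d norm weight stays f32 even when an f16 model is requested
norm = np.ones(16, dtype=np.float16)
assert convert_for_output(norm, ftype=1, is_weight=True).dtype == np.float32
```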

convert.py

Lines changed: 25 additions & 0 deletions
@@ -828,6 +828,15 @@ def load() -> Tensor:
     return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
 
 
+def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor:
+    def load() -> Tensor:
+        tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors]
+        return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors]))
+    s = lazy_tensors[0].shape.copy()
+    s.insert(0, len(lazy_tensors))
+    return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors))
+
+
 # Functionality that simulates `torch.load` but where individual tensors are
 # only loaded into memory on demand, not all at once.
 # PyTorch can't do this natively as of time of writing:
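pack_experts_lazy follows the deferred-load pattern convert.py already uses: the packed shape is the per-expert shape with the expert count prepended, and nothing is read from disk until load() runs. A standalone sketch of that pattern with simplified stand-in types, not the converter's own classes:

```python
from typing import Callable
import numpy as np

class LazyArray:
    """Simplified stand-in for convert.py's LazyTensor: shape is known up front, data is loaded on demand."""
    def __init__(self, load: Callable[[], np.ndarray], shape: list[int]):
        self.load = load
        self.shape = shape

def pack_lazy(lazies: list[LazyArray]) -> LazyArray:
    def load() -> np.ndarray:
        # only now are the individual expert tensors materialized and stacked
        return np.stack([lz.load() for lz in lazies], axis=0)
    shape = [len(lazies)] + list(lazies[0].shape)  # prepend the expert dimension
    return LazyArray(load, shape)

# usage: 4 "experts" of shape (3, 2); nothing is allocated until .load()
experts = [LazyArray((lambda i=i: np.full((3, 2), i, dtype=np.float32)), [3, 2]) for i in range(4)]
packed = pack_lazy(experts)
assert packed.shape == [4, 3, 2]
assert packed.load().shape == (4, 3, 2)
```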
@@ -1246,6 +1255,22 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
 
     tmp = model
 
+    # merge experts into one tensor
+    if params.n_experts and params.n_experts > 0:
+        for i_l in range(params.n_layer):
+            for w in range(1, 4):
+                experts = []
+                for e in range(params.n_experts):
+                    if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"])
+                        del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]
+                    elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
+                        del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
+                    else:
+                        raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
+                tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
+
     # HF models permut or pack some of the tensors, so we need to undo that
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:

examples/imatrix/imatrix.cpp

Lines changed: 23 additions & 20 deletions
@@ -98,35 +98,38 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 
     const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
 
+    // this has been adapted to the new format of storing merged experts in a single 3d tensor
+    // ref: https://github.com/ggerganov/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
         const int idx = ((int32_t *) t->op_params)[0];
-        const int n_as = ((int32_t *) t->op_params)[1];
+        const ggml_tensor * ids = t->src[2];
+        const int n_as = src0->ne[2];
 
-        // the top-k selected expert ids are stored in the src0 tensor
-        // for simplicity, always copy src0 to host, because it is small
-        // take into account that src0 is not contiguous!
-        GGML_ASSERT(src0->ne[1] == src1->ne[1]);
-        GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int)));
-        m_ids.resize(ggml_nbytes(src0)/sizeof(int));
-        ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0));
+        // the top-k selected expert ids are stored in the ids tensor
+        // for simplicity, always copy ids to host, because it is small
+        // take into account that ids is not contiguous!
+        GGML_ASSERT(ids->ne[1] == src1->ne[1]);
+        GGML_ASSERT(n_as*ggml_nrows(ids)*sizeof(int) == GGML_PAD(ggml_nbytes(ids), n_as*sizeof(int)));
+        m_ids.resize(ggml_nbytes(ids)/sizeof(int));
+        ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
+
+        auto & e = m_stats[wname];
+
+        ++e.ncall;
+        // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
+        // using the following line, we can correct for that if needed by replacing the line above with:
+        //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
 
         // loop over all possible experts, regardless if they are used or not in the batch
-        // this is necessary to guarantee equal number of "ncall" for each tensor
         for (int ex = 0; ex < n_as; ++ex) {
-            src0 = t->src[2 + ex];
-            wname = filter_tensor_name(src0->name);
-            auto& e = m_stats[wname];
+            size_t e_start = ex*src1->ne[0];
             if (e.values.empty()) {
-                e.values.resize(src1->ne[0], 0);
+                e.values.resize(src1->ne[0]*n_as, 0);
             }
-            else if (e.values.size() != (size_t)src1->ne[0]) {
-                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+            else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
+                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
                 exit(1); //GGML_ASSERT(false);
             }
-            // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
-            // using the following line, we can correct for that if needed
-            //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
-            ++e.ncall;
             if (m_params.verbosity > 1) {
                 printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
             }
@@ -136,7 +139,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
                 if (excur != ex) continue;
                 const float * x = data + row * src1->ne[0];
                 for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                    e.values[j] += x[j]*x[j];
+                    e.values[e_start + j] += x[j]*x[j];
                 }
             }
             if (e.ncall > m_last_call) {
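With the merged format there is a single stats entry per MoE tensor, so the per-expert sums live in one flat vector of size ne[0]*n_as and are addressed with e_start = ex*ne[0]. A small numpy sketch of that bookkeeping, with made-up sizes:

```python
import numpy as np

ne0, n_as = 4, 3                   # hypothetical row width and number of experts
values = np.zeros(ne0 * n_as)      # one flat buffer for all experts, like e.values

def accumulate(ex: int, x: np.ndarray) -> None:
    # mirrors: e.values[e_start + j] += x[j]*x[j], with e_start = ex*ne0
    e_start = ex * ne0
    values[e_start:e_start + ne0] += x * x

accumulate(1, np.array([1.0, 2.0, 3.0, 4.0]))
# experts 0 and 2 stay zero; expert 1's slice holds the squared activations
print(values.reshape(n_as, ne0))
```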

examples/quantize/quantize.cpp

Lines changed: 10 additions & 6 deletions
@@ -116,44 +116,48 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
-        return;
+        exit(1);
     }
     int n_entries;
     in.read((char *)&n_entries, sizeof(n_entries));
     if (in.fail() || n_entries < 1) {
         printf("%s: no data in file %s\n", __func__, imatrix_file.c_str());
-        return;
+        exit(1);
     }
     for (int i = 0; i < n_entries; ++i) {
         int len; in.read((char *)&len, sizeof(len));
         std::vector<char> name_as_vec(len+1);
         in.read((char *)name_as_vec.data(), len);
         if (in.fail()) {
             printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str());
-            return;
+            exit(1);
         }
         name_as_vec[len] = 0;
         std::string name{name_as_vec.data()};
-        auto & e = imatrix_data[std::move(name)];
+        auto & e = imatrix_data[name];
         int ncall;
         in.read((char *)&ncall, sizeof(ncall));
         int nval;
         in.read((char *)&nval, sizeof(nval));
         if (in.fail() || nval < 1) {
             printf("%s: failed reading number of values for entry %d\n", __func__, i);
             imatrix_data = {};
-            return;
+            exit(1);
         }
         e.resize(nval);
         in.read((char *)e.data(), nval*sizeof(float));
         if (in.fail()) {
             printf("%s: failed reading data for entry %d\n", __func__, i);
             imatrix_data = {};
-            return;
+            exit(1);
         }
         if (ncall > 0) {
             for (auto& v : e) v /= ncall;
         }
+
+        if (getenv("LLAMA_TRACE")) {
+            printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str());
+        }
     }
     printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str());
 }
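The loader above also pins down the on-disk imatrix layout that quantize consumes: an int32 entry count, then for each entry a length-prefixed name, an int32 ncall, an int32 nval and nval float32 values; setting the LLAMA_TRACE environment variable makes quantize log each entry as it is loaded. Below is a hedged Python sketch of a reader for that layout, inferred only from the fields read here; the script interface, little-endian byte order and any trailing data in newer files are assumptions, not part of the diff:

```python
import struct
import sys

def read_imatrix(path: str) -> dict[str, list[float]]:
    # parse the fields in the same order load_imatrix() reads them
    entries: dict[str, list[float]] = {}
    with open(path, "rb") as f:
        (n_entries,) = struct.unpack("<i", f.read(4))
        for _ in range(n_entries):
            (name_len,) = struct.unpack("<i", f.read(4))
            name = f.read(name_len).decode("utf-8")
            ncall, nval = struct.unpack("<ii", f.read(8))
            values = list(struct.unpack(f"<{nval}f", f.read(4 * nval)))
            if ncall > 0:
                values = [v / ncall for v in values]   # same averaging as the C++ loader
            entries[name] = values
    return entries

if __name__ == "__main__":
    for name, vals in read_imatrix(sys.argv[1]).items():
        print(f"{name}: {len(vals)} values")
```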
