Skip to content

Commit 48b2f9c

Browse files
jukofyork and slaren authored
Fixed save_imatrix to match old behaviour for MoE (#7099)
* Fixed save_imatrix to match old behaviour for MoE

  This fix is simple and clear, but unnecessarily doubles the memory overhead.

* Fixed missing idx variable

* Unconditionally increment ncall

  Co-authored-by: slaren <[email protected]>

* Fixed 2 bugs in save_imatrix()

  - Fixed a segfault: the counts vector needed to be created.
  - Fixed a pre-existing bug where the "--combine" option didn't actually add to the counts.

* ncall needs summing too

* Trailing whitespace

---------

Co-authored-by: slaren <[email protected]>
1 parent af0a5b6 commit 48b2f9c

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
struct Stats {
2121
std::vector<float> values;
22+
std::vector<int> counts;
2223
int ncall = 0;
2324
};
2425

@@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
121122
auto & e = m_stats[wname];
122123

123124
++e.ncall;
124-
// NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
125-
// using the following line, we can correct for that if needed by replacing the line above with:
126-
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
127125

128126
if (e.values.empty()) {
129127
e.values.resize(src1->ne[0]*n_as, 0);
128+
e.counts.resize(src1->ne[0]*n_as, 0);
130129
}
131130
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
132131
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
@@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
153152

154153
for (int j = 0; j < (int)src1->ne[0]; ++j) {
155154
e.values[e_start + j] += x[j]*x[j];
155+
e.counts[e_start + j]++;
156156
}
157157
}
158158
}
@@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
170170
auto& e = m_stats[wname];
171171
if (e.values.empty()) {
172172
e.values.resize(src1->ne[0], 0);
173+
e.counts.resize(src1->ne[0], 0);
173174
}
174175
else if (e.values.size() != (size_t)src1->ne[0]) {
175176
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
@@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
183184
const float * x = data + row * src1->ne[0];
184185
for (int j = 0; j < (int)src1->ne[0]; ++j) {
185186
e.values[j] += x[j]*x[j];
187+
e.counts[j]++;
186188
}
187189
}
188190
if (e.ncall > m_last_call) {
@@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
222224
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
223225
int nval = p.second.values.size();
224226
out.write((const char *) &nval, sizeof(nval));
225-
if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
227+
if (nval > 0) {
228+
std::vector<float> tmp(nval);
229+
for (int i = 0; i < nval; i++) {
230+
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
231+
}
232+
out.write((const char*)tmp.data(), nval*sizeof(float));
233+
}
226234
}
227235

228236
// Write the number of calls the matrix was computed with
@@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma
270278
imatrix_data = {};
271279
return false;
272280
}
273-
e.values.resize(nval);
274-
in.read((char*)e.values.data(), nval*sizeof(float));
281+
282+
// When re-called from load_imatrix() with add set, this will already be created.
283+
if (e.values.empty()) {
284+
e.values.resize(nval, 0);
285+
e.counts.resize(nval, 0);
286+
}
287+
288+
std::vector<float> tmp(nval);
289+
in.read((char*)tmp.data(), nval*sizeof(float));
275290
if (in.fail()) {
276291
printf("%s: failed reading data for entry %d\n",__func__,i);
277292
imatrix_data = {};
278293
return false;
279294
}
280-
e.ncall = ncall;
295+
296+
// Recreate the state as expected by save_imatrix(), and correct for weighted sum.
297+
for (int i = 0; i < nval; i++) {
298+
e.values[i] += tmp[i];
299+
e.counts[i] += ncall;
300+
}
301+
e.ncall += ncall;
302+
281303
}
282304
return true;
283305
}

0 commit comments

Comments
 (0)