
Commit 334bee0

add test to compare old vs new rope
1 parent 0a47037 commit 334bee0


tests/test-rope.cpp

Lines changed: 328 additions & 0 deletions
@@ -124,6 +124,332 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
     ggml_graph_compute(graph, &plan);
 }
 
+//
+// test comparing rope and rope_comp
+//
+
+struct test_rope {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int n_dims;
+    int mode;
+    int n_ctx; // used to generate positions
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
+    bool ff;
+    int v; // view (1 : non-contiguous a)
+    bool forward; // unused for now
+    bool inplace;
+
+    bool use_comp = false;
+
+    std::string vars() {
+        char buf[256];
+        snprintf(buf, sizeof(buf),
+            "type=%d ne=(%lld,%lld,%lld,%lld) n_dims=%d mode=%d fs=%f ef=%f af=%f ff=%d v=%d inplace=%d",
+            type, ne_a[0], ne_a[1], ne_a[2], ne_a[3], n_dims, mode, fs, ef, af, ff ? 1 : 0, v, inplace ? 1 : 0);
+        return std::string(buf);
+    }
+
+    test_rope(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
+            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
+
+    ggml_tensor * _ggml_rope_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[GGML_MROPE_SECTIONS],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow) {
+        if (use_comp) {
+            return nullptr;
+        } else {
+            return ggml_rope_multi(
+                ctx, a, b, c, n_dims, sections, mode, n_ctx_orig,
+                freq_base, freq_scale, ext_factor, attn_factor,
+                beta_fast, beta_slow);
+        }
+    }
+
+    struct ggml_tensor * _ggml_rope_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow) {
+        if (use_comp) {
+            b = ggml_cast(ctx, b, GGML_TYPE_F32); // pos must be F32
+            return ggml_rope_comp(
+                ctx, a, b, n_dims,
+                freq_base, GGML_ROPE_ORDERING_NORMAL);
+        } else {
+            return ggml_rope_ext(
+                ctx, a, b, c, n_dims, mode, n_ctx_orig,
+                freq_base, freq_scale, ext_factor, attn_factor,
+                beta_fast, beta_slow);
+        }
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) {
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_name(a, "a");
+        }
+
+        const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
+        ggml_set_name(pos, "pos");
+
+        ggml_tensor * freq = nullptr;
+        if (ff) {
+            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
+            ggml_set_name(freq, "freq");
+        }
+
+        ggml_tensor * out = nullptr;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only uses the first two dimensions for the image (x, y) coordinates
+                if (forward) {
+                    if (inplace) {
+                        //out = _ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = _ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
+                } else {
+                    //out = _ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                if (forward) {
+                    if (inplace) {
+                        //out = _ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = _ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
+                } else {
+                    //out = _ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            }
+        } else {
+            if (forward) {
+                if (inplace) {
+                    //out = _ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = _ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                //out = _ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        }
+
+        if (out) {
+            ggml_set_name(out, "out");
+        }
+
+        return out;
+    }
+
+    void init_tensor_uniform(ggml_tensor * tensor, float fmin = -1.0f, float fmax = 1.0f) {
+        const size_t n_elements = ggml_nelements(tensor);
+        switch (tensor->type) {
+            case GGML_TYPE_F32:
+                {
+                    float * data = (float *) tensor->data;
+                    for (size_t i = 0; i < n_elements; i++) {
+                        data[i] = frand()*(fmax - fmin) + fmin;
+                    }
+                } break;
+            default:
+                assert(false);
+        }
+    }
+
+    void initialize_tensors(ggml_context * ctx) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector<int> data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                memcpy(t->data, data.data(), num_pos_ids * sizeof(int));
+            } else {
+                if (t->ne[0] == n_dims/2) {
+                    // frequency factors in the range [0.9f, 1.1f]
+                    init_tensor_uniform(t, 0.9f, 1.1f);
+                } else {
+                    init_tensor_uniform(t);
+                }
+            }
+        }
+    }
+};
+
+static void test_rope_comp() {
+    ggml_init_params params = {
+        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    std::vector<test_rope *> test_cases;
+
+    bool all = true;
+    bool fw = true;
+    for (float fs : { 1.0f, 1.4245f }) {
+        for (float ef : { 0.0f, 0.7465f }) {
+            for (float af : { 1.0f, 1.4245f }) {
+                for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                    for (bool ff : {false, true}) { // freq_factors
+                        for (float v : { 0, 1 }) {
+                            test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                                test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                                test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                            }
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            }
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1},  20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                                test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
+                            }
+
+                            test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                        }
+                    }
+
+                    all = false;
+                }
+            }
+        }
+    }
+
+    std::vector<test_rope *> comp_cases;
+    for (auto & tc : test_cases) {
+        auto tc_comp = new test_rope(*tc);
+        tc_comp->use_comp = true;
+        comp_cases.push_back(tc_comp);
+    }
+
+    std::vector<uint8_t> work_buffer;
+
+    size_t n_passed = 0;
+
+    for (size_t i = 0; i < test_cases.size(); i++) {
+        test_rope * tc_rope = test_cases[i];
+        test_rope * tc_comp = comp_cases[i];
+
+        ggml_context * ctx0 = ggml_init(params);
+        ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+        ggml_tensor * out0 = tc_rope->build_graph(ctx0);
+        ggml_tensor * out1 = tc_comp->build_graph(ctx0);
+
+        if (out0 == nullptr || out1 == nullptr) {
+            GGML_PRINT("test_rope_comp \x1b[33mSKIPPED\x1b[0m: %s\n", tc_rope->vars().c_str());
+            ggml_free(ctx0);
+            delete tc_comp;
+            delete tc_rope;
+            continue;
+        }
+
+        tc_rope->initialize_tensors(ctx0);
+        tc_comp->initialize_tensors(ctx0);
+
+        // calculate nmse between out0 and out1
+        ggml_tensor * diff    = ggml_sub(ctx0, out0, out1);
+        ggml_tensor * mse_a_b = ggml_sum(ctx0, ggml_sqr(ctx0, diff));
+        ggml_tensor * mse_a_0 = ggml_sum(ctx0, ggml_sqr(ctx0, out0));
+        ggml_tensor * out     = ggml_div(ctx0, mse_a_b, mse_a_0);
+
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_compute_helper(work_buffer, gf, 4);
+
+        float nmse = ((float *) out->data)[0];
+        const float nmse_threshold = 1e-6f;
+        if (nmse > nmse_threshold) {
+            GGML_PRINT("test_rope_comp \x1b[31mFAILED\x1b[0m: nmse=%f > %f for %s\n", nmse, nmse_threshold, tc_rope->vars().c_str());
+        } else {
+            GGML_PRINT("test_rope_comp OK     : nmse=%f <= %f for %s\n", nmse, nmse_threshold, tc_rope->vars().c_str());
+            n_passed++;
+        }
+
+        ggml_free(ctx0);
+        delete tc_comp;
+        delete tc_rope;
+    }
+
+    GGML_ASSERT(n_passed == test_cases.size());
+}
+
 int main(int /*argc*/, const char ** /*argv*/) {
     struct ggml_init_params params = {
         /* .mem_size   = */ 128*1024*1024,
@@ -259,5 +585,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
 
     ggml_free(ctx0);
 
+    test_rope_comp();
+
     return 0;
 }
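The pass/fail criterion in test_rope_comp is the normalized mean squared error (NMSE) between the baseline output out0 (ggml_rope_ext / ggml_rope_multi) and the new output out1 (ggml_rope_comp), computed in-graph with ggml_sub, ggml_sqr, ggml_sum, and ggml_div:

    \mathrm{nmse} = \frac{\sum_i (\mathrm{out0}_i - \mathrm{out1}_i)^2}{\sum_i \mathrm{out0}_i^2}, \qquad \text{pass if } \mathrm{nmse} \le 10^{-6}

Note that m-rope, imrope, and vision cases are currently reported as SKIPPED, since _ggml_rope_multi returns nullptr when use_comp is set. For reference, a minimal standalone sketch of the same metric on raw float buffers (plain C++; this nmse helper is illustrative and not part of the commit):

    #include <cstddef>

    // NMSE: sum of squared differences normalized by the energy of the
    // reference output; mirrors the in-graph sub/sqr/sum/div computation
    static float nmse(const float * ref, const float * cmp, size_t n) {
        double num = 0.0;
        double den = 0.0;
        for (size_t i = 0; i < n; i++) {
            const double d = (double) ref[i] - (double) cmp[i];
            num += d*d;
            den += (double) ref[i] * (double) ref[i];
        }
        return (float) (num/den);
    }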
