Skip to content

Commit 2e5f2fb

Browse files
committed
use fla's get_err_ratio for err computation
1 parent 019683b commit 2e5f2fb

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

tests/unit_tests/test_moe.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,9 @@ def test_moe_ffn_equivalence(self, iteration: int = 0) -> tuple[float, float]:
174174
"moe_old vs ffn", out_ffn, out_moe_old, self.assert_close_ratio
175175
)
176176
assert_close("moe vs ffn", out_ffn, out_moe, self.assert_close_ratio)
177-
# Compute symmetric mean percentage errors
178-
moe_old_rel_err = (
179-
2
180-
* ((out_moe_old - out_ffn).abs() / (out_ffn.abs() + out_moe.abs()))
181-
.mean()
182-
.item()
183-
)
184-
moe_rel_err = (
185-
2
186-
* ((out_moe - out_ffn).abs() / (out_ffn.abs() + out_moe.abs()))
187-
.mean()
188-
.item()
189-
)
177+
178+
moe_old_rel_err = get_err_ratio(out_ffn, out_moe_old)
179+
moe_rel_err = get_err_ratio(out_ffn, out_moe)
190180
return moe_old_rel_err, moe_rel_err
191181

192182
def test_perf(
@@ -233,7 +223,7 @@ def test_perf(
233223
mean_moe_old_rel_err = torch.tensor(moe_old_rel_errs)
234224
mean_moe_rel_err = torch.tensor(moe_rel_errs)
235225

236-
print(f"\nACCURACY VS FFN: {accuracy_iters} iterations")
226+
print(f"\nACCURACY VS FFN: {accuracy_iters} iterations\n")
237227
print(f"{mean_moe_old_rel_err.mean()=}, {mean_moe_old_rel_err.std()=}")
238228
print(f"{mean_moe_rel_err.mean()=}, {mean_moe_rel_err.std()=}")
239229
print(f"{mean_moe_old_rel_err.mean()/mean_moe_rel_err.mean()=}")
@@ -242,7 +232,7 @@ def test_perf(
242232
perf_seqlen = 4096
243233
perf_bsz = 4
244234
print(
245-
f"\nTRITON BENCH: {perf_seqlen=} {perf_bsz=} warmups={t.perf_warmups} repeats={t.perf_reps}"
235+
f"\nTRITON BENCH: {perf_seqlen=} {perf_bsz=} warmups={t.perf_warmups} repeats={t.perf_reps}\n"
246236
)
247237
t.test_perf(bsz=perf_bsz, seqlen=perf_seqlen)
248238

0 commit comments

Comments
 (0)