@@ -174,19 +174,9 @@ def test_moe_ffn_equivalence(self, iteration: int = 0) -> tuple[float, float]:
174174 "moe_old vs ffn" , out_ffn , out_moe_old , self .assert_close_ratio
175175 )
176176 assert_close ("moe vs ffn" , out_ffn , out_moe , self .assert_close_ratio )
177- # Compute symmetric mean percentage errors
178- moe_old_rel_err = (
179- 2
180- * ((out_moe_old - out_ffn ).abs () / (out_ffn .abs () + out_moe .abs ()))
181- .mean ()
182- .item ()
183- )
184- moe_rel_err = (
185- 2
186- * ((out_moe - out_ffn ).abs () / (out_ffn .abs () + out_moe .abs ()))
187- .mean ()
188- .item ()
189- )
177+
178+ moe_old_rel_err = get_err_ratio (out_ffn , out_moe_old )
179+ moe_rel_err = get_err_ratio (out_ffn , out_moe )
190180 return moe_old_rel_err , moe_rel_err
191181
192182 def test_perf (
@@ -233,7 +223,7 @@ def test_perf(
233223 mean_moe_old_rel_err = torch .tensor (moe_old_rel_errs )
234224 mean_moe_rel_err = torch .tensor (moe_rel_errs )
235225
236- print (f"\n ACCURACY VS FFN: { accuracy_iters } iterations" )
226+ print (f"\n ACCURACY VS FFN: { accuracy_iters } iterations\n " )
237227 print (f"{ mean_moe_old_rel_err .mean ()= } , { mean_moe_old_rel_err .std ()= } " )
238228 print (f"{ mean_moe_rel_err .mean ()= } , { mean_moe_rel_err .std ()= } " )
239229 print (f"{ mean_moe_old_rel_err .mean ()/ mean_moe_rel_err .mean ()= } " )
@@ -242,7 +232,7 @@ def test_perf(
242232 perf_seqlen = 4096
243233 perf_bsz = 4
244234 print (
245- f"\n TRITON BENCH: { perf_seqlen = } { perf_bsz = } warmups={ t .perf_warmups } repeats={ t .perf_reps } "
235+ f"\n TRITON BENCH: { perf_seqlen = } { perf_bsz = } warmups={ t .perf_warmups } repeats={ t .perf_reps } \n "
246236 )
247237 t .test_perf (bsz = perf_bsz , seqlen = perf_seqlen )
248238
0 commit comments