from datetime import datetime
from enum import Enum
import csv
+import logging

-torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.cache_size_limit = 10000
+logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
+torch._inductor.config.max_autotune_gemm_backends = "TRITON"
+CHECK = False


class FP8Kernel(Enum):
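A note on the new module-level knobs in this hunk: raising `torch._dynamo.config.cache_size_limit` keeps dynamo from hitting its recompile limit while the benchmark sweeps many shapes and kernels, and setting `torch._inductor.config.max_autotune_gemm_backends = "TRITON"` restricts inductor's max-autotune GEMM search to Triton templates. A minimal standalone sketch of the same setup (the compiled function below is illustrative, not part of this file):

import torch

# Raise dynamo's recompile budget; useful when compiling over many distinct shapes.
torch._dynamo.config.cache_size_limit = 10000

# Only consider Triton GEMM templates when inductor autotunes matmuls.
torch._inductor.config.max_autotune_gemm_backends = "TRITON"


@torch.compile(mode="max-autotune-no-cudagraphs")
def matmul_bf16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


if torch.cuda.is_available():
    a = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    out = matmul_bf16(a, b)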
@@ -80,13 +84,14 @@ class ExperimentConfig:
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool
+    bf16: bool


@dataclass(frozen=True)
class ExperimentResult:
-    bf16_time: float
+    bf16_time: Optional[float]
    fp8_time: float
-    bf16_tflops: float
+    bf16_tflops: Optional[float]
    fp8_tflops: float


@@ -113,29 +118,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

    if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
        bf16_matmul = torch.compile(bf16_matmul)
-        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")
+        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")

    # Warmup phase
    warmup_iterations = 5
    for _ in range(warmup_iterations):
-        _ = bf16_matmul(A, B)
+        if config.bf16:
+            _ = bf16_matmul(A, B)
        _ = fp8_matmul()
    torch.cuda.synchronize()

    # Actual benchmarking
-    bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
+
+    bf16_time = (
+        benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) if config.bf16 else None
+    )
    fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)

    # Calculate TFLOPS
-    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
+    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) if bf16_time else None
    fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)

    # Baseline fp8_matmul correctness
-    scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
-    out_base = scaled_mm_base()
-    out = fp8_matmul()
-    # Failing on one sample with large N
-    torch.testing.assert_close(out, out_base)
+    if CHECK:
+        scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
+        out_base = scaled_mm_base()
+        out = fp8_matmul()
+        # Failing on one sample with large N
+        torch.testing.assert_close(out, out_base)

    return ExperimentResult(
        bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
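For context on the TFLOPS columns: `calculate_tflops` is not shown in this diff, but an (M, K) x (K, N) matmul performs 2·M·N·K floating-point operations, and `benchmark_cuda_function_in_microseconds` reports time in microseconds. A minimal sketch of what the helper is assumed to compute (the real implementation lives elsewhere in the repo and may differ):

def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float:
    # Each of the M * N output elements needs K multiply-adds, i.e. 2 * M * N * K FLOPs total.
    flops = 2 * M * N * K
    seconds = time_us / 1e6
    return flops / seconds / 1e12  # FLOPs per second, scaled to TFLOPS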
@@ -161,24 +171,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = None):
    for experiment in experiments:
        config = experiment.config
        result = experiment.result
-        speedup = result.bf16_time / result.fp8_time
-        tflops_ratio = result.fp8_tflops / result.bf16_tflops
+
+        # Format values handling None cases
+        bf16_time = f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A"
+        fp8_time = f"{result.fp8_time:.4f}"
+        bf16_tflops = f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A"
+        fp8_tflops = f"{result.fp8_tflops:.2f}"
+
+        # Calculate ratios only if bf16 results exist
+        if result.bf16_time is not None:
+            speedup = f"{(result.bf16_time / result.fp8_time):.2f}x"
+            tflops_ratio = f"{(result.fp8_tflops / result.bf16_tflops):.2f}x"
+        else:
+            speedup = "N/A"
+            tflops_ratio = "N/A"
+
        rows.append(
            [
                config.M,
                config.K,
                config.N,
-                config.scaling_strategy,
-                config.fp8_kernel,
+                config.scaling_strategy.value,
+                config.fp8_kernel.value,
                config.compile,
-                f"{result.bf16_time:.4f}",
-                f"{result.fp8_time:.4f}",
-                f"{speedup:.2f}x",
-                f"{result.bf16_tflops:.2f}",
-                f"{result.fp8_tflops:.2f}",
-                f"{tflops_ratio:.2f}x",
+                bf16_time,
+                fp8_time,
+                speedup,
+                bf16_tflops,
+                fp8_tflops,
+                tflops_ratio,
            ]
        )
+
    print(tabulate(rows, headers=headers, floatfmt=".4f"))

    if save_path is not None:
@@ -189,24 +213,32 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = None):
        print(f"💾 Results saved to: {save_path}")


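The body of the `if save_path is not None:` block is unchanged by this commit, so the hunk elides it; presumably it writes `headers` and `rows` using the `csv` module imported at the top of the file. A minimal sketch of that kind of save step, under that assumption (the helper name is illustrative):

import csv
from pathlib import Path


def save_results_csv(save_path: Path, headers: list, rows: list) -> None:
    # Persist the already-formatted table so separate benchmark runs can be compared later.
    with open(save_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(rows)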
-def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]:
+def get_configs_varying_k(
+    M: int = 8192, N: int = 8192, bf16: bool = False
+) -> List[ExperimentConfig]:
    shapes = [(M, K, N) for K in range(512, 8193, 512)]
    scaling_strategies = [ScalingStrategy.PER_ROW]
-    compile_options = [False]
+    compile_options = [True]
    configs = []
    fp8_kernels = [
        FP8Kernel.SCALED_MM,
        # FP8Kernel.PERSISTENT,
-        FP8Kernel.PERSISTENT_TMA,
-        FP8Kernel.DEVICE_TMA,
+        # FP8Kernel.PERSISTENT_TMA,
+        # FP8Kernel.DEVICE_TMA,
    ]

    for (M, K, N), strategy, compile, kernel in itertools.product(
        shapes, scaling_strategies, compile_options, fp8_kernels
    ):
        configs.append(
            ExperimentConfig(
-                M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
+                M=M,
+                K=K,
+                N=N,
+                scaling_strategy=strategy,
+                compile=compile,
+                fp8_kernel=kernel,
+                bf16=bf16,
            )
        )
    return configs
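With the defaults in this hunk, `range(512, 8193, 512)` yields 16 values of K, and the product with one scaling strategy, one compile option, and the single remaining kernel gives 16 configs per sweep. A quick hypothetical sanity check, assuming the module is importable:

configs = get_configs_varying_k(M=8192, N=8192, bf16=True)
assert len(configs) == 16  # 16 K values x 1 strategy x 1 compile option x 1 kernel
assert all(c.bf16 for c in configs)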
@@ -250,17 +282,24 @@ def plot_tflops_comparison(df, save_path: Path):
    print(f"TFLOPS comparison plot saved as {graph_path}")


-def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
+def main(
+    save_path: Optional[str] = None,
+    M: int = 8192,
+    N: int = 8192,
+    graph: bool = False,
+    bf_16: bool = False,
+):
    """Benchmark FP8 MatMul with different configurations and optionally graph results.

    Args:
        save_path (Optional[str], optional): Path to save the results. Defaults to None.
        M (int, optional): Number of rows in the first matrix. Defaults to 8192.
        N (int, optional): Number of columns in the second matrix. Defaults to 8192.
        graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
+        bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
    """
    torch.random.manual_seed(123)
-    configs = get_configs_varying_k(M, N)
+    configs = get_configs_varying_k(M, N, bf16=bf_16)
    results = []
    if save_path is not None:
        save_path = Path(save_path)
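Because `bf_16` defaults to False, existing callers keep the FP8-only behaviour; the BF16 baseline and the speedup/ratio columns only come back when it is passed explicitly. A hypothetical invocation, assuming `main` is called directly rather than through a CLI wrapper (not shown in this diff); the file names are illustrative:

# FP8-only sweep (BF16 columns and ratios print as "N/A"):
main(save_path="fp8_results.csv", M=8192, N=8192)

# Same sweep with the BF16 baseline re-enabled to get speedup and TFLOPS ratios:
main(save_path="fp8_vs_bf16.csv", M=8192, N=8192, graph=True, bf_16=True)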