1414 preprocess_data ,
1515 Float8MMConfig ,
1616)
17- from transformer_nuggets .fp8 .fp8_matmul import (
18- matmul_persistent ,
19- matmul_tma_persistent ,
20- matmul_device_tma_persistent ,
21- )
17+ try :
18+ from transformer_nuggets .fp8 .fp8_matmul import (
19+ matmul_persistent ,
20+ matmul_tma_persistent ,
21+ matmul_device_tma_persistent ,
22+ )
23+ except ModuleNotFoundError :
24+ print ("Triton version not new enough" )
25+ pass
26+
2227from datetime import datetime
2328from enum import Enum
2429import csv
30+ import logging
2531
26- torch ._dynamo .config .cache_size_limit = 1000
32+ torch ._dynamo .config .cache_size_limit = 10000
33+ logging .getLogger ("transformer_nuggets" ).setLevel (logging .INFO )
34+ torch ._inductor .config .max_autotune_gemm_backends = "TRITON"
35+ CHECK = False
2736
2837
2938class FP8Kernel (Enum ):
@@ -80,13 +89,14 @@ class ExperimentConfig:
8089 scaling_strategy : ScalingStrategy
8190 fp8_kernel : FP8Kernel
8291 compile : bool
92+ bf16 : bool
8393
8494
8595@dataclass (frozen = True )
8696class ExperimentResult :
87- bf16_time : float
97+ bf16_time : Optional [ float ]
8898 fp8_time : float
89- bf16_tflops : float
99+ bf16_tflops : Optional [ float ]
90100 fp8_tflops : float
91101
92102
@@ -113,29 +123,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
113123
114124 if config .compile and config .fp8_kernel == FP8Kernel .SCALED_MM :
115125 bf16_matmul = torch .compile (bf16_matmul )
116- fp8_matmul = torch .compile (fp8_matmul , mode = "max-autotune" )
126+ fp8_matmul = torch .compile (fp8_matmul , mode = "max-autotune-no-cudagraphs" , dynamic = False )
117127
118128 # Warmup phase
119129 warmup_iterations = 5
120130 for _ in range (warmup_iterations ):
121- _ = bf16_matmul (A , B )
131+ if config .bf16 :
132+ _ = bf16_matmul (A , B )
122133 _ = fp8_matmul ()
123134 torch .cuda .synchronize ()
124135
125136 # Actual benchmarking
126- bf16_time = benchmark_cuda_function_in_microseconds (lambda : bf16_matmul (A , B ))
137+
138+ bf16_time = (
139+ benchmark_cuda_function_in_microseconds (lambda : bf16_matmul (A , B )) if config .bf16 else None
140+ )
127141 fp8_time = benchmark_cuda_function_in_microseconds (fp8_matmul )
128142
129143 # Calculate TFLOPS
130- bf16_tflops = calculate_tflops (config .M , config .N , config .K , bf16_time )
144+ bf16_tflops = calculate_tflops (config .M , config .N , config .K , bf16_time ) if bf16_time else None
131145 fp8_tflops = calculate_tflops (config .M , config .N , config .K , fp8_time )
132146
133147 # Baseline fp8_matmul correctness
134- scaled_mm_base = get_fp8_matmul (A , B , config .scaling_strategy , FP8Kernel .SCALED_MM )
135- out_base = scaled_mm_base ()
136- out = fp8_matmul ()
137- # Failing on one sample with large N
138- torch .testing .assert_close (out , out_base )
148+ if CHECK :
149+ scaled_mm_base = get_fp8_matmul (A , B , config .scaling_strategy , FP8Kernel .SCALED_MM )
150+ out_base = scaled_mm_base ()
151+ out = fp8_matmul ()
152+ # Failing on one sample with large N
153+ torch .testing .assert_close (out , out_base )
139154
140155 return ExperimentResult (
141156 bf16_time = bf16_time , fp8_time = fp8_time , bf16_tflops = bf16_tflops , fp8_tflops = fp8_tflops
@@ -161,24 +176,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
161176 for experiment in experiments :
162177 config = experiment .config
163178 result = experiment .result
164- speedup = result .bf16_time / result .fp8_time
165- tflops_ratio = result .fp8_tflops / result .bf16_tflops
179+
180+ # Format values handling None cases
181+ bf16_time = f"{ result .bf16_time :.4f} " if result .bf16_time is not None else "N/A"
182+ fp8_time = f"{ result .fp8_time :.4f} "
183+ bf16_tflops = f"{ result .bf16_tflops :.2f} " if result .bf16_tflops is not None else "N/A"
184+ fp8_tflops = f"{ result .fp8_tflops :.2f} "
185+
186+ # Calculate ratios only if bf16 results exist
187+ if result .bf16_time is not None :
188+ speedup = f"{ (result .bf16_time / result .fp8_time ):.2f} x"
189+ tflops_ratio = f"{ (result .fp8_tflops / result .bf16_tflops ):.2f} x"
190+ else :
191+ speedup = "N/A"
192+ tflops_ratio = "N/A"
193+
166194 rows .append (
167195 [
168196 config .M ,
169197 config .K ,
170198 config .N ,
171- config .scaling_strategy ,
172- config .fp8_kernel ,
199+ config .scaling_strategy . value ,
200+ config .fp8_kernel . value ,
173201 config .compile ,
174- f" { result . bf16_time :.4f } " ,
175- f" { result . fp8_time :.4f } " ,
176- f" { speedup :.2f } x" ,
177- f" { result . bf16_tflops :.2f } " ,
178- f" { result . fp8_tflops :.2f } " ,
179- f" { tflops_ratio :.2f } x" ,
202+ bf16_time ,
203+ fp8_time ,
204+ speedup ,
205+ bf16_tflops ,
206+ fp8_tflops ,
207+ tflops_ratio ,
180208 ]
181209 )
210+
182211 print (tabulate (rows , headers = headers , floatfmt = ".4f" ))
183212
184213 if save_path is not None :
@@ -189,33 +218,41 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
189218 print (f"💾 Results saved to: { save_path } " )
190219
191220
192- def get_configs_varying_k (M : int = 8192 , N : int = 8192 ) -> List [ExperimentConfig ]:
193- shapes = [(M , K , N ) for K in range (512 , 8193 , 512 )]
221+ def get_configs_varying_k (
222+ M : int = 8192 , N : int = 8192 , bf16 : bool = False
223+ ) -> List [ExperimentConfig ]:
224+ shapes = [(M , K , N ) for K in range (1024 , 16385 , 1024 )]
194225 scaling_strategies = [ScalingStrategy .PER_ROW ]
195- compile_options = [False ]
226+ compile_options = [True , False ]
196227 configs = []
197228 fp8_kernels = [
198229 FP8Kernel .SCALED_MM ,
199230 # FP8Kernel.PERSISTENT,
200- FP8Kernel .PERSISTENT_TMA ,
201- FP8Kernel .DEVICE_TMA ,
231+ # FP8Kernel.PERSISTENT_TMA,
232+ # FP8Kernel.DEVICE_TMA,
202233 ]
203234
204235 for (M , K , N ), strategy , compile , kernel in itertools .product (
205236 shapes , scaling_strategies , compile_options , fp8_kernels
206237 ):
207238 configs .append (
208239 ExperimentConfig (
209- M = M , K = K , N = N , scaling_strategy = strategy , compile = compile , fp8_kernel = kernel
240+ M = M ,
241+ K = K ,
242+ N = N ,
243+ scaling_strategy = strategy ,
244+ compile = compile ,
245+ fp8_kernel = kernel ,
246+ bf16 = bf16 ,
210247 )
211248 )
212249 return configs
213250
214251
215252def load_and_process_data (file_path ):
216253 df = pd .read_csv (file_path )
217- df ["Speedup" ] = df ["Speedup" ].str .rstrip ("x" ).astype (float )
218- df ["TFLOPS Ratio" ] = df ["TFLOPS Ratio" ].str .rstrip ("x" ).astype (float )
254+ # df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
255+ # df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
219256 return df
220257
221258
@@ -250,17 +287,24 @@ def plot_tflops_comparison(df, save_path: Path):
250287 print (f"TFLOPS comparison plot saved as { graph_path } " )
251288
252289
253- def main (save_path : Optional [str ] = None , M : int = 8192 , N : int = 8192 , graph : bool = False ):
290+ def main (
291+ save_path : Optional [str ] = None ,
292+ M : int = 8192 ,
293+ N : int = 8192 ,
294+ graph : bool = False ,
295+ bf_16 : bool = False ,
296+ ):
254297 """Benchmark FP8 MatMul with different configurations and optionally graph results.
255298
256299 Args:
257300 save_path (Optional[str], optional): Path to save the results. Defaults to None.
258301 M (int, optional): Number of rows in the first matrix. Defaults to 8192.
259302 N (int, optional): Number of columns in the second matrix. Defaults to 8192.
260303 graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
304+ bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
261305 """
262306 torch .random .manual_seed (123 )
263- configs = get_configs_varying_k (M , N )
307+ configs = get_configs_varying_k (M , N , bf16 = bf_16 )
264308 results = []
265309 if save_path is not None :
266310 save_path = Path (save_path )
0 commit comments