
Commit 73804b1

Update benchmarks
stack-info: PR: #39, branch: drisspg/stack/2
1 parent 5f2a907 commit 73804b1

File tree

1 file changed: +68 −29 lines


benchmarks/fp8_matmul.py

Lines changed: 68 additions & 29 deletions
@@ -22,8 +22,12 @@
 from datetime import datetime
 from enum import Enum
 import csv
+import logging
 
-torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.cache_size_limit = 10000
+logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
+torch._inductor.config.max_autotune_gemm_backends = "TRITON"
+CHECK = False
 
 
 class FP8Kernel(Enum):
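
Context for the new module-level knobs (a sketch, not part of the commit): raising dynamo's cache_size_limit keeps a large shape sweep from hitting the recompile cap, and max_autotune_gemm_backends="TRITON" restricts inductor's GEMM autotuning to Triton-generated kernels. The shapes and dtypes below are illustrative only.

import logging

import torch

torch._dynamo.config.cache_size_limit = 10000  # avoid recompile-cap evictions across many shapes
torch._inductor.config.max_autotune_gemm_backends = "TRITON"  # autotune GEMMs with Triton only
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)  # surface the library's INFO logs


@torch.compile(mode="max-autotune-no-cudagraphs")
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


if torch.cuda.is_available():
    a = torch.randn(8192, 512, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(512, 8192, device="cuda", dtype=torch.bfloat16)
    out = matmul(a, b)  # first call triggers the autotuned compile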
@@ -80,13 +84,14 @@ class ExperimentConfig:
     scaling_strategy: ScalingStrategy
     fp8_kernel: FP8Kernel
     compile: bool
+    bf16: bool
 
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_time: float
+    bf16_time: Optional[float]
     fp8_time: float
-    bf16_tflops: float
+    bf16_tflops: Optional[float]
     fp8_tflops: float
 
 
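The Optional fields change how downstream consumers must read results; a minimal sketch of the pattern (names reused from the diff, values made up):

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ExperimentResult:
    bf16_time: Optional[float]  # None when the bf16 baseline was skipped
    fp8_time: float
    bf16_tflops: Optional[float]
    fp8_tflops: float


result = ExperimentResult(bf16_time=None, fp8_time=412.3, bf16_tflops=None, fp8_tflops=985.1)
# Guard every ratio on the Optional field, as print_results does below.
speedup = result.bf16_time / result.fp8_time if result.bf16_time is not None else None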
@@ -113,29 +118,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
 
     if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
         bf16_matmul = torch.compile(bf16_matmul)
-        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")
+        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
 
     # Warmup phase
     warmup_iterations = 5
     for _ in range(warmup_iterations):
-        _ = bf16_matmul(A, B)
+        if config.bf16:
+            _ = bf16_matmul(A, B)
         _ = fp8_matmul()
     torch.cuda.synchronize()
 
     # Actual benchmarking
-    bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
+
+    bf16_time = (
+        benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) if config.bf16 else None
+    )
     fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)
 
     # Calculate TFLOPS
-    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
+    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) if bf16_time else None
     fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)
 
     # Baseline fp8_matmul correctness
-    scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
-    out_base = scaled_mm_base()
-    out = fp8_matmul()
-    # Failing on one sample with large N
-    torch.testing.assert_close(out, out_base)
+    if CHECK:
+        scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
+        out_base = scaled_mm_base()
+        out = fp8_matmul()
+        # Failing on one sample with large N
+        torch.testing.assert_close(out, out_base)
 
     return ExperimentResult(
         bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
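
benchmark_cuda_function_in_microseconds comes from transformer_nuggets; a rough stand-in using CUDA events, plus the FLOP accounting calculate_tflops presumably performs (both are assumptions, shown only to make the measurement path concrete):

import torch


def time_cuda_microseconds(fn, iters: int = 10) -> float:
    # CUDA events measure device time; elapsed_time() reports milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # ms -> us, averaged per iteration


def tflops_from_microseconds(M: int, N: int, K: int, time_us: float) -> float:
    # A dense GEMM performs 2*M*N*K floating point operations.
    return (2 * M * N * K) / (time_us * 1e-6) / 1e12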
@@ -161,24 +171,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = None):
     for experiment in experiments:
         config = experiment.config
         result = experiment.result
-        speedup = result.bf16_time / result.fp8_time
-        tflops_ratio = result.fp8_tflops / result.bf16_tflops
+
+        # Format values handling None cases
+        bf16_time = f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A"
+        fp8_time = f"{result.fp8_time:.4f}"
+        bf16_tflops = f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A"
+        fp8_tflops = f"{result.fp8_tflops:.2f}"
+
+        # Calculate ratios only if bf16 results exist
+        if result.bf16_time is not None:
+            speedup = f"{(result.bf16_time / result.fp8_time):.2f}x"
+            tflops_ratio = f"{(result.fp8_tflops / result.bf16_tflops):.2f}x"
+        else:
+            speedup = "N/A"
+            tflops_ratio = "N/A"
+
         rows.append(
             [
                 config.M,
                 config.K,
                 config.N,
-                config.scaling_strategy,
-                config.fp8_kernel,
+                config.scaling_strategy.value,
+                config.fp8_kernel.value,
                 config.compile,
-                f"{result.bf16_time:.4f}",
-                f"{result.fp8_time:.4f}",
-                f"{speedup:.2f}x",
-                f"{result.bf16_tflops:.2f}",
-                f"{result.fp8_tflops:.2f}",
-                f"{tflops_ratio:.2f}x",
+                bf16_time,
+                fp8_time,
+                speedup,
+                bf16_tflops,
+                fp8_tflops,
+                tflops_ratio,
             ]
         )
+
     print(tabulate(rows, headers=headers, floatfmt=".4f"))
 
     if save_path is not None:
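
The switch to .value matters only for display: printing an Enum member includes the class name, while .value yields the bare string. A toy illustration (the member value here is assumed string-typed):

from enum import Enum


class ScalingStrategy(Enum):
    PER_ROW = "PER_ROW"


print(ScalingStrategy.PER_ROW)        # -> ScalingStrategy.PER_ROW (noisy in a table)
print(ScalingStrategy.PER_ROW.value)  # -> PER_ROW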
@@ -189,24 +213,32 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = None):
         print(f"💾 Results saved to: {save_path}")
 
 
-def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]:
+def get_configs_varying_k(
+    M: int = 8192, N: int = 8192, bf16: bool = False
+) -> List[ExperimentConfig]:
     shapes = [(M, K, N) for K in range(512, 8193, 512)]
     scaling_strategies = [ScalingStrategy.PER_ROW]
-    compile_options = [False]
+    compile_options = [True]
     configs = []
     fp8_kernels = [
         FP8Kernel.SCALED_MM,
         # FP8Kernel.PERSISTENT,
-        FP8Kernel.PERSISTENT_TMA,
-        FP8Kernel.DEVICE_TMA,
+        # FP8Kernel.PERSISTENT_TMA,
+        # FP8Kernel.DEVICE_TMA,
     ]
 
     for (M, K, N), strategy, compile, kernel in itertools.product(
         shapes, scaling_strategies, compile_options, fp8_kernels
     ):
         configs.append(
             ExperimentConfig(
-                M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
+                M=M,
+                K=K,
+                N=N,
+                scaling_strategy=strategy,
+                compile=compile,
+                fp8_kernel=kernel,
+                bf16=bf16,
             )
         )
     return configs
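
With the TMA kernels commented out, the sweep collapses to a single kernel; a quick check of the resulting config count (the strings below are placeholders for the real enum members):

import itertools

shapes = [(8192, K, 8192) for K in range(512, 8193, 512)]  # K = 512, 1024, ..., 8192
combos = list(itertools.product(shapes, ["PER_ROW"], [True], ["SCALED_MM"]))
print(len(combos))  # 16 configs: 16 shapes x 1 strategy x 1 compile option x 1 kernel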
@@ -250,17 +282,24 @@ def plot_tflops_comparison(df, save_path: Path):
     print(f"TFLOPS comparison plot saved as {graph_path}")
 
 
-def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
+def main(
+    save_path: Optional[str] = None,
+    M: int = 8192,
+    N: int = 8192,
+    graph: bool = False,
+    bf_16: bool = False,
+):
     """Benchmark FP8 MatMul with different configurations and optionally graph results.
 
     Args:
         save_path (Optional[str], optional): Path to save the results. Defaults to None.
         M (int, optional): Number of rows in the first matrix. Defaults to 8192.
         N (int, optional): Number of columns in the second matrix. Defaults to 8192.
         graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
+        bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
     """
     torch.random.manual_seed(123)
-    configs = get_configs_varying_k(M, N)
+    configs = get_configs_varying_k(M, N, bf16=bf_16)
     results = []
     if save_path is not None:
         save_path = Path(save_path)
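
Assuming the script's usual entry point, the new flag re-enables the bf16 baseline like so (a sketch; the CLI wiring is not shown in this diff, and the save path is illustrative):

if __name__ == "__main__":
    # bf_16=True restores the bf16 baseline timing and the speedup/ratio columns.
    main(save_path="results.csv", M=8192, N=8192, graph=False, bf_16=True)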
