Skip to content

Commit 53bb288

Browse files
committed
Update benchmarks
stack-info: PR: #39, branch: drisspg/stack/2
1 parent a4c66bb commit 53bb288

File tree

1 file changed

+81
-37
lines changed

1 file changed

+81
-37
lines changed

benchmarks/fp8_matmul.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,25 @@
1414
preprocess_data,
1515
Float8MMConfig,
1616
)
17-
from transformer_nuggets.fp8.fp8_matmul import (
18-
matmul_persistent,
19-
matmul_tma_persistent,
20-
matmul_device_tma_persistent,
21-
)
17+
try:
18+
from transformer_nuggets.fp8.fp8_matmul import (
19+
matmul_persistent,
20+
matmul_tma_persistent,
21+
matmul_device_tma_persistent,
22+
)
23+
except ModuleNotFoundError:
24+
print("Triton version not new enough")
25+
pass
26+
2227
from datetime import datetime
2328
from enum import Enum
2429
import csv
30+
import logging
2531

26-
torch._dynamo.config.cache_size_limit = 1000
32+
torch._dynamo.config.cache_size_limit = 10000
33+
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
34+
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
35+
CHECK = False
2736

2837

2938
class FP8Kernel(Enum):
@@ -80,13 +89,14 @@ class ExperimentConfig:
8089
scaling_strategy: ScalingStrategy
8190
fp8_kernel: FP8Kernel
8291
compile: bool
92+
bf16: bool
8393

8494

8595
@dataclass(frozen=True)
8696
class ExperimentResult:
87-
bf16_time: float
97+
bf16_time: Optional[float]
8898
fp8_time: float
89-
bf16_tflops: float
99+
bf16_tflops: Optional[float]
90100
fp8_tflops: float
91101

92102

@@ -113,29 +123,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
113123

114124
if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
115125
bf16_matmul = torch.compile(bf16_matmul)
116-
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")
126+
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs", dynamic=False)
117127

118128
# Warmup phase
119129
warmup_iterations = 5
120130
for _ in range(warmup_iterations):
121-
_ = bf16_matmul(A, B)
131+
if config.bf16:
132+
_ = bf16_matmul(A, B)
122133
_ = fp8_matmul()
123134
torch.cuda.synchronize()
124135

125136
# Actual benchmarking
126-
bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
137+
138+
bf16_time = (
139+
benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) if config.bf16 else None
140+
)
127141
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)
128142

129143
# Calculate TFLOPS
130-
bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
144+
bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) if bf16_time else None
131145
fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)
132146

133147
# Baseline fp8_matmul correctness
134-
scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
135-
out_base = scaled_mm_base()
136-
out = fp8_matmul()
137-
# Failing on one sample with large N
138-
torch.testing.assert_close(out, out_base)
148+
if CHECK:
149+
scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
150+
out_base = scaled_mm_base()
151+
out = fp8_matmul()
152+
# Failing on one sample with large N
153+
torch.testing.assert_close(out, out_base)
139154

140155
return ExperimentResult(
141156
bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
@@ -161,24 +176,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
161176
for experiment in experiments:
162177
config = experiment.config
163178
result = experiment.result
164-
speedup = result.bf16_time / result.fp8_time
165-
tflops_ratio = result.fp8_tflops / result.bf16_tflops
179+
180+
# Format values handling None cases
181+
bf16_time = f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A"
182+
fp8_time = f"{result.fp8_time:.4f}"
183+
bf16_tflops = f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A"
184+
fp8_tflops = f"{result.fp8_tflops:.2f}"
185+
186+
# Calculate ratios only if bf16 results exist
187+
if result.bf16_time is not None:
188+
speedup = f"{(result.bf16_time / result.fp8_time):.2f}x"
189+
tflops_ratio = f"{(result.fp8_tflops / result.bf16_tflops):.2f}x"
190+
else:
191+
speedup = "N/A"
192+
tflops_ratio = "N/A"
193+
166194
rows.append(
167195
[
168196
config.M,
169197
config.K,
170198
config.N,
171-
config.scaling_strategy,
172-
config.fp8_kernel,
199+
config.scaling_strategy.value,
200+
config.fp8_kernel.value,
173201
config.compile,
174-
f"{result.bf16_time:.4f}",
175-
f"{result.fp8_time:.4f}",
176-
f"{speedup:.2f}x",
177-
f"{result.bf16_tflops:.2f}",
178-
f"{result.fp8_tflops:.2f}",
179-
f"{tflops_ratio:.2f}x",
202+
bf16_time,
203+
fp8_time,
204+
speedup,
205+
bf16_tflops,
206+
fp8_tflops,
207+
tflops_ratio,
180208
]
181209
)
210+
182211
print(tabulate(rows, headers=headers, floatfmt=".4f"))
183212

184213
if save_path is not None:
@@ -189,33 +218,41 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
189218
print(f"💾 Results saved to: {save_path}")
190219

191220

192-
def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]:
193-
shapes = [(M, K, N) for K in range(512, 8193, 512)]
221+
def get_configs_varying_k(
222+
M: int = 8192, N: int = 8192, bf16: bool = False
223+
) -> List[ExperimentConfig]:
224+
shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
194225
scaling_strategies = [ScalingStrategy.PER_ROW]
195-
compile_options = [False]
226+
compile_options = [True, False]
196227
configs = []
197228
fp8_kernels = [
198229
FP8Kernel.SCALED_MM,
199230
# FP8Kernel.PERSISTENT,
200-
FP8Kernel.PERSISTENT_TMA,
201-
FP8Kernel.DEVICE_TMA,
231+
# FP8Kernel.PERSISTENT_TMA,
232+
# FP8Kernel.DEVICE_TMA,
202233
]
203234

204235
for (M, K, N), strategy, compile, kernel in itertools.product(
205236
shapes, scaling_strategies, compile_options, fp8_kernels
206237
):
207238
configs.append(
208239
ExperimentConfig(
209-
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
240+
M=M,
241+
K=K,
242+
N=N,
243+
scaling_strategy=strategy,
244+
compile=compile,
245+
fp8_kernel=kernel,
246+
bf16=bf16,
210247
)
211248
)
212249
return configs
213250

214251

215252
def load_and_process_data(file_path):
216253
df = pd.read_csv(file_path)
217-
df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
218-
df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
254+
# df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
255+
# df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
219256
return df
220257

221258

@@ -250,17 +287,24 @@ def plot_tflops_comparison(df, save_path: Path):
250287
print(f"TFLOPS comparison plot saved as {graph_path}")
251288

252289

253-
def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
290+
def main(
291+
save_path: Optional[str] = None,
292+
M: int = 8192,
293+
N: int = 8192,
294+
graph: bool = False,
295+
bf_16: bool = False,
296+
):
254297
"""Benchmark FP8 MatMul with different configurations and optionally graph results.
255298
256299
Args:
257300
save_path (Optional[str], optional): Path to save the results. Defaults to None.
258301
M (int, optional): Number of rows in the first matrix. Defaults to 8192.
259302
N (int, optional): Number of columns in the second matrix. Defaults to 8192.
260303
graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
304+
bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
261305
"""
262306
torch.random.manual_seed(123)
263-
configs = get_configs_varying_k(M, N)
307+
configs = get_configs_varying_k(M, N, bf16=bf_16)
264308
results = []
265309
if save_path is not None:
266310
save_path = Path(save_path)

0 commit comments

Comments
 (0)