98 changes: 56 additions & 42 deletions tests/test_lion8b.py
@@ -504,45 +504,59 @@ def _set_state_dict_type(model: nn.Module):
 def test_fused_as_fast_as_unfused(N: int,
                                   D: int,
                                   min_elems_traversed: int = 1000000):
-    W = torch.randn((N, D), device='cuda', requires_grad=True)
-    W.grad = torch.randn((N, D), device='cuda', requires_grad=False)
-
-    num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
-    num_iters = min(100, num_iters)  # don't take all day when overhead-bound
-
-    times = {}
-    kwargs = {'weight_decay': .01}
-    combos = [(True, False), (True, True), (False, False), ('NA', False)]
-    for fused, use_errors in combos:
-        if fused == 'NA':
-            opt = Lion8bit([W], quantize=False,
-                           **kwargs)  # type:ignore (reportGeneralTypeIssues)
-        else:
-            opt = Lion8bit([W],
-                           _fused=fused,
-                           error_correction=use_errors,
-                           **kwargs)  # type:ignore (reportGeneralTypeIssues)
-        for _ in range(3):
-            opt.step()  # warmup iters
-        torch.cuda.synchronize()
-        t_start = time.time()
-        for _ in range(num_iters):
-            opt.step()
-        torch.cuda.synchronize()
-        t_end = time.time()
-        dur = (t_end - t_start) / num_iters
-        if use_errors:
-            times['ecc'] = dur
-        else:
-            times[fused] = dur
-
-    atol = 20e-6  # should always be faster, but avoids rare flakiness
-    assert times[True] < times[False] + atol
-    assert times[True] < times['NA'] + atol
-    assert times['ecc'] < times['NA'] + atol
-
-    print('')
-    print('time fused (ms): ', times[True] * 1e3)
-    print('time fused+ecc (ms): ', times['ecc'] * 1e3)
-    print('time unfused (ms): ', times[False] * 1e3)
-    print('time unquantized (ms): ', times['NA'] * 1e3)
+
+    def _time_kernels(N: int, D: int, min_elems_traversed: int):
+        W = torch.randn((N, D), device='cuda', requires_grad=True)
+        W.grad = torch.randn((N, D), device='cuda', requires_grad=False)
+
+        num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
+        num_iters = min(100,
+                        num_iters)  # don't take all day when overhead-bound
+
+        times = {}
+        kwargs = {'weight_decay': .01}
+        combos = [(True, False), (True, True), (False, False), ('NA', False)]
+        for fused, use_errors in combos:
+            if fused == 'NA':
+                opt = Lion8bit(
+                    [W], quantize=False,
+                    **kwargs)  # type:ignore (reportGeneralTypeIssues)
+            else:
+                opt = Lion8bit(
+                    [W], _fused=fused, error_correction=use_errors,
+                    **kwargs)  # type:ignore (reportGeneralTypeIssues)
+            for _ in range(3):
+                opt.step()  # warmup iters
+            torch.cuda.synchronize()
+            t_start = time.time()
+            for _ in range(num_iters):
+                opt.step()
+            torch.cuda.synchronize()
+            t_end = time.time()
+            dur = (t_end - t_start) / num_iters
+            if use_errors:
+                times['ecc'] = dur
+            else:
+                times[fused] = dur
+        return times
+
+    times = _time_kernels(N, D, min_elems_traversed)
+
+    atol = 2e-4  # should always be faster, but atol helps avoid flakiness
+    it = 0
+    while True:
+        try:
+            assert times[True] < times[False] + atol
+            assert times[True] < times['NA'] + atol
+            assert times['ecc'] < times['NA'] + atol
+            print('')
+            print('time fused (ms): ', times[True] * 1e3)
+            print('time fused+ecc (ms): ', times['ecc'] * 1e3)
+            print('time unfused (ms): ', times[False] * 1e3)
+            print('time unquantized (ms): ', times['NA'] * 1e3)
+            break
+        except AssertionError as e:
+            if it >= 2:  # allow 3 retries to avoid flakiness
+                raise e
+            times = _time_kernels(N, D, min_elems_traversed)
+            it += 1
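
For readers skimming the diff: the change moves the CUDA timing loop into a nested `_time_kernels` helper and retries the speed assertions up to two extra times before failing, since single-shot kernel timings are noisy. Below is a minimal, GPU-free sketch of that retry-around-a-noisy-measurement pattern. The helper name `assert_with_retries`, the toy `measure`/`check` callables, and the use of `time.perf_counter` are illustrative assumptions for this sketch, not part of the repository's code.

```python
import time
from typing import Callable, Dict


def assert_with_retries(measure: Callable[[], Dict[str, float]],
                        check: Callable[[Dict[str, float]], None],
                        max_retries: int = 2) -> Dict[str, float]:
    """Re-run a noisy measurement until its assertions pass or retries run out.

    Mirrors the while/try/except structure in the test above: the assertions
    are attempted, and on AssertionError the measurement is repeated up to
    ``max_retries`` more times before the error is re-raised.
    """
    attempt = 0
    times = measure()
    while True:
        try:
            check(times)
            return times
        except AssertionError:
            if attempt >= max_retries:
                raise
            times = measure()  # re-measure and try again
            attempt += 1


if __name__ == '__main__':
    # Toy measurement: the 'fused' stand-in should be faster than 'unfused'.
    def measure() -> Dict[str, float]:
        t0 = time.perf_counter()
        sum(range(10_000))  # stand-in for the fused kernel
        t_fused = time.perf_counter() - t0

        t0 = time.perf_counter()
        sum(range(100_000))  # stand-in for the unfused kernel
        t_unfused = time.perf_counter() - t0
        return {'fused': t_fused, 'unfused': t_unfused}

    def check(times: Dict[str, float]) -> None:
        atol = 2e-4  # same idea as the test: faster, up to timing noise
        assert times['fused'] < times['unfused'] + atol

    print(assert_with_retries(measure, check))
```

Keeping the measurement in a function, as the PR does with `_time_kernels`, is what makes the retry cheap to express: on an `AssertionError` the whole benchmark is simply re-run and re-checked instead of duplicating the timing code inline.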