Skip to content

Commit e3b6d18

Browse files
committed
update
1 parent 7473d37 commit e3b6d18

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

test/transformers/test_layer_norm.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,19 +152,31 @@ def test_liger_layer_norm_weird_shapes(
152152
torch_output = torch_ln(torch_x)
153153

154154
assert torch.allclose(
155-
liger_output, torch_output, atol=atol, rtol=rtol
155+
liger_output,
156+
torch_output,
157+
atol=atol,
158+
rtol=rtol,
156159
), f"Forward pass mismatch for shape {shape}"
157160

158161
grad_output = torch.randn_like(x)
159162
liger_output.backward(grad_output, retain_graph=True)
160163
torch_output.backward(grad_output, retain_graph=True)
161164

162165
assert torch.allclose(
163-
liger_x.grad, torch_x.grad, atol=atol, rtol=rtol
166+
liger_x.grad,
167+
torch_x.grad,
168+
atol=atol,
169+
rtol=rtol,
164170
), f"Input gradient mismatch for shape {shape}"
165171
assert torch.allclose(
166-
liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol
172+
liger_ln.weight.grad,
173+
torch_ln.weight.grad,
174+
atol=atol,
175+
rtol=rtol,
167176
), f"Weight gradient mismatch for shape {shape}"
168177
assert torch.allclose(
169-
liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol
178+
liger_ln.bias.grad,
179+
torch_ln.bias.grad,
180+
atol=atol,
181+
rtol=rtol,
170182
), f"Bias gradient mismatch for shape {shape}"

0 commit comments

Comments
 (0)