File tree Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -152,19 +152,31 @@ def test_liger_layer_norm_weird_shapes(
152
152
torch_output = torch_ln (torch_x )
153
153
154
154
assert torch .allclose (
155
- liger_output , torch_output , atol = atol , rtol = rtol
155
+ liger_output ,
156
+ torch_output ,
157
+ atol = atol ,
158
+ rtol = rtol ,
156
159
), f"Forward pass mismatch for shape { shape } "
157
160
158
161
grad_output = torch .randn_like (x )
159
162
liger_output .backward (grad_output , retain_graph = True )
160
163
torch_output .backward (grad_output , retain_graph = True )
161
164
162
165
assert torch .allclose (
163
- liger_x .grad , torch_x .grad , atol = atol , rtol = rtol
166
+ liger_x .grad ,
167
+ torch_x .grad ,
168
+ atol = atol ,
169
+ rtol = rtol ,
164
170
), f"Input gradient mismatch for shape { shape } "
165
171
assert torch .allclose (
166
- liger_ln .weight .grad , torch_ln .weight .grad , atol = atol , rtol = rtol
172
+ liger_ln .weight .grad ,
173
+ torch_ln .weight .grad ,
174
+ atol = atol ,
175
+ rtol = rtol ,
167
176
), f"Weight gradient mismatch for shape { shape } "
168
177
assert torch .allclose (
169
- liger_ln .bias .grad , torch_ln .bias .grad , atol = atol , rtol = rtol
178
+ liger_ln .bias .grad ,
179
+ torch_ln .bias .grad ,
180
+ atol = atol ,
181
+ rtol = rtol ,
170
182
), f"Bias gradient mismatch for shape { shape } "
You can’t perform that action at this time.
0 commit comments