Commit bcd918f

[LayerNorm] Add option to write result to out and residual_out
1 parent bd82d6c commit bcd918f


flash_attn/ops/triton/layer_norm.py

Lines changed: 35 additions & 9 deletions
@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ def forward(
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ def forward(
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ def forward(
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ def backward(ctx, dy, *args):
             None,
             None,
             None,
+            None,
+            None,
         )


@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )


@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
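Taken together, the new `out` / `residual_out` keyword arguments let a caller hand in preallocated buffers for the normalized output and the updated residual stream, instead of having `_layer_norm_fwd` allocate fresh tensors on every call. Below is a minimal usage sketch; the shapes, dtype, `eps`, and the prenorm residual setup are illustrative rather than part of this commit, and it assumes a CUDA device with flash-attn's Triton kernels available.

import torch

from flash_attn.ops.triton.layer_norm import layer_norm_fn

batch, seqlen, dim = 2, 128, 1024  # illustrative shapes
x = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(dim, device="cuda", dtype=torch.float16)
bias = torch.zeros(dim, device="cuda", dtype=torch.float16)

# Preallocate once; per the asserts added in this commit, each buffer
# must match x's shape and be contiguous in the last dimension.
out = torch.empty_like(x)
residual_out = torch.empty_like(x)

y, new_residual = layer_norm_fn(
    x,
    weight,
    bias,
    residual=residual,
    eps=1e-6,
    prenorm=True,  # also return the updated residual
    out=out,  # new in this commit: the result is written into `out`
    residual_out=residual_out,  # new: x + residual is written here
)

# The returned tensors are views over the supplied buffers.
assert y.data_ptr() == out.data_ptr()
assert new_residual.data_ptr() == residual_out.data_ptr()

Because the results land at caller-controlled addresses, the same buffers can be reused across iterations; one plausible use is a decode loop or CUDA-graph capture, where stable output pointers avoid a per-call allocation.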