@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
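The two hunks above follow the same allocate-or-validate pattern for optionally preallocated buffers: allocate when the caller passes nothing, otherwise check shape and last-dim contiguity and reuse the caller's tensor. A minimal standalone sketch of that pattern (the helper name and arguments are illustrative, not part of this file):

```python
import torch

def _get_or_alloc(buf, like, dtype=None):
    # Hypothetical helper showing the allocate-or-validate pattern used above:
    # allocate a fresh tensor when no buffer is supplied, otherwise verify the
    # caller's buffer matches the expected shape and has a contiguous last dim.
    if buf is None:
        buf = torch.empty_like(like, dtype=dtype if dtype is not None else like.dtype)
    else:
        assert buf.shape == like.shape
    assert buf.stride(-1) == 1
    return buf

x = torch.randn(4, 8)
y = _get_or_alloc(None, x)                      # allocated internally
y_pre = _get_or_alloc(torch.empty_like(x), x)   # caller-provided buffer reused
```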
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ def forward(
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ def forward(
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
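One point the hunk above relies on but does not state: reshaping a contiguous tensor to `(-1, last_dim)` returns a view over the same storage, so the kernel's writes through the reshaped `out` / `residual_out` still land in the caller's preallocated buffers. A small sketch, assuming contiguous inputs:

```python
import torch

buf = torch.empty(2, 3, 5)              # caller-preallocated, e.g. (batch, seqlen, hidden)
view = buf.reshape(-1, buf.shape[-1])   # 2D view over the same storage when contiguous
assert view.data_ptr() == buf.data_ptr()
view.fill_(1.0)                         # writes through the view...
assert bool((buf == 1.0).all())         # ...are visible in the original buffer
```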
@@ -759,6 +773,8 @@ def forward(
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ def backward(ctx, dy, *args):
             None,
             None,
             None,
+            None,
+            None,
         )


@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )


@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )

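As a usage sketch of the new keywords: only `out=` and `residual_out=` come from this commit; the import path and the leading `(x, weight, bias)` arguments are assumptions based on how these entry points are typically called, so adjust them to the actual module.

```python
# Preallocate the output so the fused kernel writes into a caller-owned buffer
# instead of allocating a new tensor on every call (useful for buffer reuse /
# CUDA-graph-style workflows).
import torch
from flash_attn.ops.triton.layer_norm import rms_norm_fn  # import path assumed

x = torch.randn(2, 1024, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)  # reused across calls

y = rms_norm_fn(x, weight, None, out=out)  # bias=None; result is written into `out`
```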