Skip to content

Commit bb70b92

Browse files
committed
fix comments
1 parent 7ad71a0 commit bb70b92

File tree

13 files changed

+65
-52
lines changed

13 files changed

+65
-52
lines changed

cmake/external/flashattn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

paddle/phi/backends/dynload/flashattn.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2-
Licensed under the Apache License, Version 2.0 (the "License");
3-
you may not use this file except in compliance with the License.
4-
You may obtain a copy of the License at
5-
http://www.apache.org/licenses/LICENSE-2.0
6-
Unless required by applicable law or agreed to in writing, software
7-
distributed under the License is distributed on an "AS IS" BASIS,
8-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
See the License for the specific language governing permissions and
10-
limitations under the License. */
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
1114

1215
#include "paddle/phi/backends/dynload/flashattn.h"
1316

paddle/phi/backends/dynload/flashattn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.

paddle/phi/infermeta/backward.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,23 @@ void CropGradInferMeta(const MetaTensor& out_grad,
198198
}
199199
}
200200

201+
void FlashAttnGradInferMeta(const MetaTensor& q,
202+
const MetaTensor& k,
203+
const MetaTensor& v,
204+
MetaTensor* dq,
205+
MetaTensor* dk,
206+
MetaTensor* dv) {
207+
if (dq) {
208+
dq->share_meta(q);
209+
}
210+
if (dk && k) {
211+
dk->share_meta(k);
212+
}
213+
if (dv && v) {
214+
dv->share_meta(v);
215+
}
216+
}
217+
201218
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
202219
const MetaTensor& softmax,
203220
const MetaTensor& loss_grad,

paddle/phi/infermeta/backward.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad,
168168
int dim2,
169169
MetaTensor* x_grad);
170170

171+
void FlashAttnGradInferMeta(const MetaTensor& q,
172+
const MetaTensor& k,
173+
const MetaTensor& v,
174+
MetaTensor* dq,
175+
MetaTensor* dk,
176+
MetaTensor* dv);
177+
171178
void GatherNdGradInferMeta(const MetaTensor& x,
172179
const MetaTensor& index,
173180
const MetaTensor& out_grad,

paddle/phi/infermeta/ternary.cc

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -267,23 +267,6 @@ void FlashAttnInferMeta(const MetaTensor& q,
267267
out->set_layout(q.layout());
268268
}
269269

270-
void FlashAttnGradInferMeta(const MetaTensor& q,
271-
const MetaTensor& k,
272-
const MetaTensor& v,
273-
MetaTensor* dq,
274-
MetaTensor* dk,
275-
MetaTensor* dv) {
276-
if (dq) {
277-
dq->share_meta(q);
278-
}
279-
if (dk && k) {
280-
dk->share_meta(k);
281-
}
282-
if (dv && v) {
283-
dv->share_meta(v);
284-
}
285-
}
286-
287270
void ArangeInferMeta(const MetaTensor& start,
288271
const MetaTensor& end,
289272
const MetaTensor& step,

paddle/phi/infermeta/ternary.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,6 @@ void FlashAttnInferMeta(const MetaTensor& q,
7171
MetaTensor* softmax,
7272
MetaTensor* seed_offset);
7373

74-
void FlashAttnGradInferMeta(const MetaTensor& q,
75-
const MetaTensor& k,
76-
const MetaTensor& v,
77-
MetaTensor* dq,
78-
MetaTensor* dk,
79-
MetaTensor* dv);
80-
8174
void InstanceNormInferMeta(const MetaTensor& x,
8275
const MetaTensor& scale,
8376
const MetaTensor& bias,

paddle/phi/kernels/flash_attn_grad_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/phi/kernels/flash_attn_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/phi/kernels/gpu/arange_kernel.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ void ArangeNullaryKernel(const Context& dev_ctx,
7272
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
7373
}
7474

75+
template decltype(ArangeNullaryKernel<int64_t, phi::GPUContext>)
76+
ArangeNullaryKernel;
77+
template decltype(ArangeNullaryKernel<int, phi::GPUContext>)
78+
ArangeNullaryKernel;
7579
} // namespace phi
7680

7781
PD_REGISTER_KERNEL(
@@ -80,6 +84,3 @@ PD_REGISTER_KERNEL(
8084
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
8185
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
8286
}
83-
84-
PD_REGISTER_KERNEL(
85-
arange_nullary, GPU, ALL_LAYOUT, phi::ArangeNullaryKernel, int64_t, int) {}

0 commit comments

Comments
 (0)