Commit d94a108

support kldiv_loss/kldiv_loss_grad for kunlun
*test=kunlun
1 parent ac2a94c commit d94a108

4 files changed: +250 −0 lines


paddle/fluid/platform/device/xpu/xpu2_op_list.h

Lines changed: 3 additions & 0 deletions
@@ -306,6 +306,9 @@ XPUOpMap& get_kl2_ops() {
     {"huber_loss_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"huber_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"kldiv_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"kldiv_loss_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"iou_similarity",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"index_select",
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void KLDivLossGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& label,
                         const DenseTensor& d_out,
                         const std::string& reduction,
                         DenseTensor* d_x) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  // Only the elementwise ("none") reduction is implemented on XPU;
  // reject anything else before doing any work.
  if ("none" != reduction) {
    PADDLE_THROW(phi::errors::Unavailable(
        "Unsupported reduction [%s] in kldiv_loss_grad", reduction));
  }

  dev_ctx.template Alloc<T>(d_x);
  if (d_x->numel() == 0) {
    return;
  }

  int r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
                               reinterpret_cast<const XPUType*>(label.data<T>()),
                               reinterpret_cast<const XPUType*>(d_out.data<T>()),
                               reinterpret_cast<XPUType*>(d_x->data<T>()),
                               d_x->numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
}

}  // namespace phi

PD_REGISTER_KERNEL(
    kldiv_loss_grad, XPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float) {}
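
For reference, with reduction 'none' the forward op computes an elementwise KL term in which x holds log-probabilities and y is the label, so the backward rule is:

    \ell_i = y_i\,(\log y_i - x_i), \qquad
    \frac{\partial \ell_i}{\partial x_i} = -y_i, \qquad
    dx_i = -y_i \cdot dout_i

This is presumably why xpu::kldiv_loss_grad reads only label and d_out (an assumption based on its argument list): the gradient needs neither x nor the forward output.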
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void KLDivLossKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& label,
                     const std::string& reduction,
                     DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  // Only the elementwise ("none") reduction is implemented on XPU.
  if ("none" != reduction) {
    PADDLE_THROW(phi::errors::Unavailable(
        "Unsupported reduction [%s] in kldiv_loss", reduction));
  }

  dev_ctx.template Alloc<T>(out);
  if (out->numel() == 0) {
    return;
  }

  int r = xpu::kldiv_loss(dev_ctx.x_context(),
                          reinterpret_cast<const XPUType*>(x.data<T>()),
                          reinterpret_cast<const XPUType*>(label.data<T>()),
                          reinterpret_cast<XPUType*>(out->data<T>()),
                          out->numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
}

}  // namespace phi

PD_REGISTER_KERNEL(kldiv_loss, XPU, ALL_LAYOUT, phi::KLDivLossKernel, float) {}
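
A short dygraph sketch that exercises both kernels end to end: forward through kldiv_loss here, backward through kldiv_loss_grad above. Again an assumption that an XPU device is present; the device name and shapes are illustrative:

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu:0')  # assumption: XPU device available

x = paddle.rand([4, 5])
x.stop_gradient = False  # track gradients w.r.t. x
target = paddle.rand([4, 5])

loss = F.kl_div(x, target, reduction='none')  # forward: kldiv_loss kernel
loss.sum().backward()                         # backward: kldiv_loss_grad kernel
print(x.grad.shape)  # [4, 5]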
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

sys.path.append("..")
import paddle
import unittest
import numpy as np
from paddle.nn.functional import kl_div
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    create_test_class,
    get_xpu_op_support_types,
    XPUOpTestWrapper,
)

paddle.enable_static()


def kldiv_loss(x, target, reduction):
    output = target * (np.log(target) - x)
    loss = np.where(target >= 0, output, np.zeros_like(x))

    if reduction == "batchmean":
        if len(x.shape) > 0:
            return loss.sum() / x.shape[0]
        else:
            return loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()

    return loss


class XPUTestKLDivLossOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'kldiv_loss'
        self.use_dynamic_create_class = False

    class TestKLDivLossOp(XPUOpTest):
        def setUp(self):
            self.initTestCase()
            self.op_type = 'kldiv_loss'
            self.dtype = np.float32
            self.__class__.use_xpu = True
            self.python_api = kl_div
            x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
            target = np.random.uniform(-10, 10, self.x_shape).astype('float32')

            self.attrs = {"reduction": self.reduction}

            self.inputs = {
                'X': x,
                'Target': target,
            }
            loss = kldiv_loss(x, target, self.reduction)
            self.outputs = {'Loss': loss.astype('float32')}

        def test_check_output(self):
            self.check_output(check_eager=True)

        def test_check_grad(self):
            self.check_grad_with_place(
                paddle.XPUPlace(0),
                ['X'],
                'Loss',
                no_grad_set=set(["Target"]),
                check_eager=True,
            )

        def initTestCase(self):
            self.x_shape = (4, 5, 5)
            self.reduction = 'none'

    class TestKLDivLossOp2(TestKLDivLossOp):
        def initTestCase(self):
            self.x_shape = (3, 2, 7, 7)
            self.reduction = 'none'

    class TestKLDivLossOp3(TestKLDivLossOp):
        def initTestCase(self):
            self.x_shape = (2, 3, 5, 7, 9)
            self.reduction = 'none'

    class TestKLDivLossOp4(TestKLDivLossOp):
        def initTestCase(self):
            self.x_shape = (5, 20)
            self.reduction = 'none'


class TestKLDivLossDygraph(unittest.TestCase):
    def run_kl_loss(self, reduction, shape=(5, 20)):
        x = np.random.uniform(-10, 10, shape).astype('float32')
        target = np.random.uniform(-10, 10, shape).astype('float32')
        gt_loss = kldiv_loss(x, target, reduction)

        with paddle.fluid.dygraph.guard():
            kldiv_criterion = paddle.nn.KLDivLoss(reduction)
            pred_loss = kldiv_criterion(
                paddle.to_tensor(x), paddle.to_tensor(target)
            )
            np.testing.assert_allclose(
                pred_loss.numpy(), gt_loss, rtol=1e-05
            )

    def test_kl_loss_none(self):
        self.run_kl_loss('none')

    def test_kl_loss_static_api(self):
        input = paddle.fluid.data(name='input', shape=[5, 20])
        label = paddle.fluid.data(name='label', shape=[5, 20])

        paddle.nn.functional.kl_div(input, label)


class TestKLDivLossTypePromotion(unittest.TestCase):
    def test_kl_div_promotion(self):
        with paddle.fluid.dygraph.guard():
            x1 = paddle.rand([5, 20], dtype='float32')
            target1 = paddle.rand([5, 20], dtype='float32')

            kldiv_criterion = paddle.nn.KLDivLoss()
            pred_loss1 = kldiv_criterion(x1, target1)

            x2 = paddle.rand([5, 20], dtype='float32')
            target2 = paddle.rand([5, 20], dtype='float32')
            pred_loss2 = paddle.nn.functional.kl_div(x2, target2)


support_types = get_xpu_op_support_types('kldiv_loss')
for stype in support_types:
    create_test_class(globals(), XPUTestKLDivLossOp, stype)

if __name__ == "__main__":
    unittest.main()
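
A quick numeric sanity check of the NumPy reference kldiv_loss defined in this test (values are arbitrary; run it with that function in scope):

import numpy as np

x = np.array([[0.5, -1.0]], dtype='float32')      # log-probabilities
target = np.array([[0.2, 0.7]], dtype='float32')  # probabilities

print(kldiv_loss(x, target, 'none'))       # elementwise: target * (log(target) - x)
print(kldiv_loss(x, target, 'batchmean'))  # loss.sum() / x.shape[0]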
