Skip to content

Commit a51ad9d

Browse files
houj04 authored and co63oc committed
[XPU] add multinomial op (PaddlePaddle#65413)
1 parent 77b1a0f commit a51ad9d

File tree

5 files changed

+190
-1
lines changed

5 files changed

+190
-1
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
3030
set(XPU_XRE_BASE_VERSION "4.32.0.1")
3131
endif()
3232
if(NOT DEFINED XPU_XHPC_BASE_DATE)
33-
set(XPU_XHPC_BASE_DATE "20240601")
33+
set(XPU_XHPC_BASE_DATE "20240621")
3434
endif()
3535
set(XPU_XCCL_BASE_VERSION "1.2.1.2")
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ XPUOpMap& get_kl3_ops() {
684684
{"multi_encoder_xpu",
685685
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
686686
{"multiclass_nms3", XPUKernelSet({phi::DataType::FLOAT32})},
687+
{"multinomial", XPUKernelSet({phi::DataType::FLOAT32})},
687688
{"nearest_interp_v2",
688689
XPUKernelSet({phi::DataType::FLOAT32,
689690
phi::DataType::FLOAT16,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/multinomial_kernel.h"
16+
17+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
22+
template <typename T, typename Context>
23+
void MultinomialKernel(const Context& dev_ctx,
24+
const DenseTensor& x,
25+
const Scalar& num_samples,
26+
bool replacement,
27+
DenseTensor* out) {
28+
auto int_num_samples = num_samples.to<int64_t>();
29+
auto* in_data = x.data<T>();
30+
int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
31+
auto in_dims = x.dims();
32+
int64_t dim_size = in_dims.size();
33+
const int64_t num_categories = in_dims[dim_size - 1];
34+
const int64_t num_distributions = dim_size > 1 ? in_dims[dim_size - 2] : 1;
35+
int64_t seed = dev_ctx.GetGenerator()->Random64();
36+
37+
// int multinomial(Context* ctx, const T* x, TID* y, int64_t num_samples,
38+
// int64_t num_categories, int64_t num_distributions, bool replacement,
39+
// int64_t seed);
40+
int r = xpu::multinomial<T, int64_t>(dev_ctx.x_context(),
41+
in_data,
42+
out_data,
43+
int_num_samples,
44+
num_categories,
45+
num_distributions,
46+
replacement,
47+
seed);
48+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "multinomial");
49+
}
50+
51+
} // namespace phi
52+
53+
// Registers phi::MultinomialKernel as the XPU `multinomial` kernel for float
// input. The first output is declared as INT64 rather than following the
// input dtype.
PD_REGISTER_KERNEL(multinomial,
                   XPU,
                   ALL_LAYOUT,
                   phi::MultinomialKernel,
                   float) {
  kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import unittest
17+
18+
import numpy as np
19+
from get_test_cover_info import (
20+
XPUOpTestWrapper,
21+
create_test_class,
22+
get_xpu_op_support_types,
23+
)
24+
from op_test_xpu import XPUOpTest
25+
26+
import paddle
27+
28+
paddle.enable_static()
29+
30+
31+
def sample_output_one_dimension(out, dim):
    """Turn sampled category indices into an empirical distribution.

    Args:
        out: array of sampled category indices.
        dim: number of categories, i.e. the length of the returned vector.

    Returns:
        A float32 vector of length ``dim`` whose i-th entry is the fraction
        of samples in ``out`` equal to ``i``.
    """
    # Tally how often every category shows up in the samples.
    categories, counts = np.unique(out, return_counts=True)
    frequencies = np.zeros(dim).astype("float32")
    frequencies[categories] = counts
    # Convert raw counts into relative frequencies.
    frequencies /= frequencies.sum()
    return frequencies
38+
39+
40+
def sample_output_two_dimension(out, shape):
    """Turn per-distribution samples into empirical distributions.

    Args:
        out: 2-D array of sampled category indices; it is split along axis 0
            into ``shape[0]`` equal parts, one per distribution.
        shape: ``[num_distributions, num_categories]`` of the result.

    Returns:
        A float32 array of the given shape in which every row holds the
        relative frequency of each category for that distribution.
    """
    row_count = shape[0]
    chunks = np.split(out, row_count, axis=0)
    frequencies = np.zeros(shape).astype("float32")
    for row, chunk in enumerate(chunks):
        # Tally the categories drawn for this distribution.
        categories, counts = np.unique(chunk, return_counts=True)
        frequencies[row][categories] = counts
    # Normalize every row independently so that it sums to one.
    frequencies /= frequencies.sum(axis=-1, keepdims=True)
    return frequencies
49+
50+
51+
class XPUTestMultinomialOp(XPUOpTestWrapper):
    """Groups the XPU test cases for the ``multinomial`` op."""

    def __init__(self):
        self.op_name = 'multinomial'
        self.use_dynamic_create_class = False

    class TestMultinomialOp(XPUOpTest):
        """Sampling with replacement from one 4-category weight vector."""

        def setUp(self):
            self.op_type = "multinomial"
            self.python_api = paddle.multinomial
            self.dtype = self.in_type
            self.place = paddle.XPUPlace(0)
            paddle.enable_static()
            self.init_data()
            self.inputs = {"X": self.input_np}

        def init_data(self):
            # The input is a 1-D weight vector and replacement is enabled.
            sample_count = 100000
            self.input_np = np.random.rand(4).astype(self.dtype)
            self.outputs = {"Out": np.zeros(sample_count).astype("int64")}
            self.attrs = {"num_samples": sample_count, "replacement": True}

        def test_check_output(self):
            # The output is random, so it is checked by verify_output below.
            self.check_output_with_place_customized(
                self.verify_output, self.place
            )

        def sample_output(self, out):
            """Convert the sampled indices into relative frequencies."""
            return sample_output_one_dimension(out, 4)

        def verify_output(self, outs):
            """Check the sampled frequencies against the normalized input."""
            # Normalize the raw weights to obtain the expected probabilities.
            weights = self.input_np
            expected = weights / weights.sum(axis=-1, keepdims=True)
            observed = self.sample_output(np.array(outs[0]))
            np.testing.assert_allclose(
                observed,
                expected,
                rtol=0,
                atol=0.01,
                err_msg=f'sample_prob: {observed!s}\nprob: {expected!s}',
            )

    class TestMultinomialOp2(TestMultinomialOp):
        """Sampling with replacement from a 3 x 4 weight matrix."""

        def init_data(self):
            # The input is a matrix: three distributions of four categories.
            sample_count = 100000
            self.input_np = np.random.rand(3, 4).astype(self.dtype)
            self.outputs = {
                "Out": np.zeros((3, sample_count)).astype("int64")
            }
            self.attrs = {"num_samples": sample_count, "replacement": True}

        def sample_output(self, out):
            """Convert each row of sampled indices into frequencies."""
            return sample_output_two_dimension(out, [3, 4])

    class TestMultinomialOp3(TestMultinomialOp):
        """Sampling without replacement: 100 draws from 1000 categories."""

        def init_data(self):
            # replacement is False, and the number of samples (100) is kept
            # below the number of categories (1000).
            self.input_np = np.random.rand(1000).astype(self.dtype)
            self.outputs = {"Out": np.zeros(100).astype("int64")}
            self.attrs = {"num_samples": 100, "replacement": False}

        def verify_output(self, outs):
            """Check that all 100 sampled indices are distinct."""
            distinct = np.unique(np.array(outs[0]))
            self.assertEqual(
                len(distinct),
                100,
                "replacement is False. categories can't be sampled repeatedly",
            )
120+
121+
122+
# Create one concrete test class per dtype that the XPU backend reports as
# supported for `multinomial`, and register it in this module's globals.
support_types = get_xpu_op_support_types('multinomial')
for stype in support_types:
    create_test_class(globals(), XPUTestMultinomialOp, stype)


if __name__ == "__main__":
    unittest.main()

tools/xpu/pack_paddle_depence.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ function xhpc_prepare() {
7272
cp -r ${XHPC_DIR_NAME}/xdnn/so/libxpuapi.so xpu/lib
7373

7474
check_files ${XHPC_DIR_NAME}/xfa/include/flash_api.h ${XHPC_DIR_NAME}/xfa/so/libxpu_flash_attention.so
75+
76+
# remove '#include "xpu/refactor/core/quant.h"' in flash_api.h
77+
# TODO(houj04): remove this hack when compile issue is resolved in XHPC
78+
sed -i '8d' ${XHPC_DIR_NAME}/xfa/include/flash_api.h
79+
7580
cp -r ${XHPC_DIR_NAME}/xfa/include/* xpu/include/xhpc/xfa
7681
cp -r ${XHPC_DIR_NAME}/xfa/so/libxpu_flash_attention.so xpu/lib/
7782
}

0 commit comments

Comments
 (0)