Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions paddle/phi/kernels/funcs/gather_scatter_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ limitations under the License. */
namespace phi {
namespace funcs {

#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t( \
func, double) Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, \
phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t(func, unsigned char)

#define Instantiate_Template_Function_index_t(func, tensor_t) \
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/index_add_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,6 @@ PD_REGISTER_KERNEL(index_add_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/index_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,6 @@ PD_REGISTER_KERNEL(index_add,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/put_along_axis_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ PD_REGISTER_KERNEL(put_along_axis,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/take_along_axis_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ PD_REGISTER_KERNEL(take_along_axis,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
67 changes: 65 additions & 2 deletions python/paddle/fluid/tests/unittests/test_index_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import Program
from paddle.fluid import Program, core


def compute_index_add_ref(
Expand Down Expand Up @@ -99,6 +99,69 @@ def test_check_grad_normal(self):
self.check_grad(['X', 'AddValue'], 'Out')


class TestIndexAddFP16Op(TestIndexAddOp):
    """float16 configuration of the index_add operator test."""

    def init_dtype_type(self):
        """Set dtypes, shapes and axis used by the inherited test fixture."""
        # Element types: half-precision data, 64-bit indices.
        self.dtype = self.x_type = np.float16
        self.index_type = np.int64
        # Shapes of the input and of the value that is added to it.
        self.x_shape, self.add_value_shape = (101, 3), (3, 3)
        # Number of indices, and the axis they refer to.
        self.index_size = 3
        self.axis = 0


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestIndexAddBF16Op(OpTest):
    """index_add operator test for bfloat16 on a CUDA place.

    Data is generated as float32, the reference result is computed from the
    float32 arrays with ``compute_index_add_ref``, and ``X``, ``AddValue``
    and ``Out`` are passed through ``convert_float_to_uint16`` before being
    handed to the test harness.
    """

    def setUp(self):
        """Build the operator inputs, attributes and expected output."""
        self.python_api = raw_index_add
        self.op_type = "index_add"
        self.init_dtype_type()
        # Cast to the configured index dtype: without the cast the dtype is
        # whatever np.random.randint defaults to on the current platform and
        # `self.index_type` would be silently ignored.
        index_np = np.random.randint(
            low=0, high=self.x_shape[self.axis], size=self.index_size
        ).astype(self.index_type)
        x_np = np.random.random(self.x_shape).astype(self.x_type)
        add_value_np = np.random.random(self.add_value_shape).astype(
            self.x_type
        )

        self.inputs = {
            'X': convert_float_to_uint16(x_np),
            'Index': index_np,
            'AddValue': convert_float_to_uint16(add_value_np),
        }
        self.attrs = {'axis': self.axis}
        # The reference is computed on the unconverted float32 arrays.
        out = compute_index_add_ref(
            self.axis,
            self.x_shape,
            x_np,
            self.add_value_shape,
            add_value_np,
            self.index_size,
            index_np,
        )
        self.outputs = {'Out': convert_float_to_uint16(out)}
        self.place = core.CUDAPlace(0)

    def init_dtype_type(self):
        """Set dtypes, shapes and axis consumed by ``setUp``."""
        self.axis = 0
        # Host-side data is float32; it is converted to uint16 in setUp.
        self.x_type = np.float32
        self.index_type = np.int64
        self.x_shape = (101, 3)
        self.index_size = 3
        self.add_value_shape = (3, 3)
        self.dtype = np.uint16

    def test_check_output(self):
        # Forward check on the CUDA place selected in setUp.
        self.check_output_with_place(self.place)

    def test_check_grad_normal(self):
        # Gradient check w.r.t. both differentiable inputs.
        self.check_grad_with_place(self.place, ['X', 'AddValue'], 'Out')


class TestIndexAddAPI(unittest.TestCase):
def setUp(self):
self.setType()
Expand Down
46 changes: 45 additions & 1 deletion python/paddle/fluid/tests/unittests/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
from paddle.fluid import core


class TestIndexSampleOp(OpTest):
Expand Down Expand Up @@ -121,6 +122,49 @@ def config(self):
self.index_type = "int64"


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestIndexSampleBF16Op(OpTest):
    """index_sample operator test for bfloat16 on a CUDA place.

    The input is drawn as float32, the expected output is gathered from it
    row by row, and both are passed through ``convert_float_to_uint16``.
    """

    def config(self):
        """
        For multi-dimension input
        """
        self.x_shape = (10, 20)
        self.x_type = "float32"
        self.dtype = np.uint16
        self.index_shape = (10, 10)
        self.index_type = "int32"

    def setUp(self):
        """Build the operator inputs and the expected output."""
        self.op_type = "index_sample"
        self.python_api = paddle.index_sample
        self.config()

        x_fp32 = np.random.random(self.x_shape).astype(self.x_type)
        picks = np.random.randint(
            low=0, high=self.x_shape[1], size=self.index_shape
        ).astype(self.index_type)

        # For every row, pick the elements of that row named by the indices.
        gathered = [
            x_fp32[row, col]
            for row in range(self.index_shape[0])
            for col in picks[row]
        ]
        expected = (
            np.array(gathered).astype(self.x_type).reshape(self.index_shape)
        )

        self.inputs = {'X': convert_float_to_uint16(x_fp32), 'Index': picks}
        self.outputs = {'Out': convert_float_to_uint16(expected)}
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        # Forward check on the CUDA place selected in setUp.
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        # Gradient check w.r.t. the sampled input.
        self.check_grad_with_place(self.place, ['X'], 'Out')


class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
Expand Down
77 changes: 71 additions & 6 deletions python/paddle/fluid/tests/unittests/test_put_along_axis_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.framework import core
Expand All @@ -28,19 +28,18 @@ class TestPutAlongAxisOp(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
self.dtype = 'float64'
self.op_type = "put_along_axis"
self.python_api = paddle.tensor.put_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
# numpy put_along_axis is an inplace opearion.
# numpy put_along_axis is an inplace operation.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.braodcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.braodcast_shape)
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
Expand All @@ -56,6 +55,7 @@ def test_check_grad(self):
self.check_grad(["Input", "Value"], "Result")

def init_data(self):
self.dtype = 'float64'
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
Expand All @@ -66,6 +66,71 @@ def init_data(self):
self.axis_type = "int64"


class TestPutAlongAxisFP16Op(TestPutAlongAxisOp):
    """float16 configuration of the put_along_axis operator test."""

    def init_data(self):
        """Set dtypes, shapes, index/value arrays and axis for the fixture."""
        # Element types.
        self.dtype = np.float16
        self.x_type = "float16"
        self.value_type = "float16"
        self.index_type = "int32"
        self.axis_type = "int64"
        # Input shape and the axis to write along.
        self.x_shape = (10, 10, 10)
        self.axis = 1
        # A single value written at a single index.
        self.value = np.array([99], dtype=self.value_type)
        self.index = np.array([[[0]]], dtype=self.index_type)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestPutAlongAxisBF16Op(OpTest):
    """put_along_axis operator test for bfloat16 on a CUDA place.

    The input is drawn as float32 and the expected result is produced with
    ``np.put_along_axis`` on a copy of it; ``Input``, ``Value`` and
    ``Result`` are then passed through ``convert_float_to_uint16``.
    """

    def setUp(self):
        """Build the operator inputs, attributes and expected output."""
        self.init_data()
        self.reduce_op = "assign"
        self.op_type = "put_along_axis"
        self.python_api = paddle.tensor.put_along_axis
        self.xnp = np.random.random(self.x_shape).astype(self.x_type)
        # numpy put_along_axis is an inplace operation, so work on a copy
        # to keep the original input intact.
        self.xnp_result = copy.deepcopy(self.xnp)
        np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
        self.target = self.xnp_result
        # Index and value are broadcast to the input shape with the target
        # axis collapsed to length 1.
        broadcast_shape_list = list(self.x_shape)
        broadcast_shape_list[self.axis] = 1
        self.broadcast_shape = tuple(broadcast_shape_list)
        self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
        self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
        self.inputs = {
            'Input': self.xnp,
            'Index': self.index_broadcast,
            'Value': self.value_broadcast,
        }
        self.attrs = {'Axis': self.axis, 'Reduce': self.reduce_op}
        self.outputs = {'Result': self.target}

        # Convert the floating-point tensors last, after the float32
        # reference has been computed.
        self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
        self.inputs['Value'] = convert_float_to_uint16(self.inputs['Value'])
        self.outputs['Result'] = convert_float_to_uint16(self.outputs['Result'])
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        # Forward check on the CUDA place selected in setUp.
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        # Gradient check w.r.t. both floating-point inputs.
        self.check_grad_with_place(self.place, ["Input", "Value"], "Result")

    def init_data(self):
        """Set dtypes, shapes, index/value arrays and axis used by setUp."""
        self.dtype = np.uint16
        # Host-side data is float32; it is converted to uint16 in setUp.
        self.x_type = "float32"
        self.x_shape = (10, 10, 10)
        self.value_type = "float32"
        self.value = np.array([99]).astype(self.value_type)
        self.index_type = "int32"
        self.index = np.array([[[0]]]).astype(self.index_type)
        self.axis = 1
        self.axis_type = "int64"


class TestPutAlongAxisAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
Expand Down
Loading