Skip to content

Conversation

kuke
Copy link
Contributor

@kuke kuke commented Jan 17, 2018

Resolve #7430

__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
const T* tokens, const int tokens_len,
int* num_erased) {
__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why in_len use int64_t while tokens_len is size_t?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They have different data type.

for (int i = 0; i < tokens_len; ++i) {
for (size_t i = 0; i < tokens_len; ++i) {
if (in_dat[index] == tokens[i]) {
erased = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a break here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

// Count number of elements to be erased
thrust::device_vector<size_t> num_erased(in_len + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set num_erased[0]=0 here to avoid checking if index==0 in every threads,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

template <typename T>
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ensure that Vector in LoD must be thrust::host_vector in .cu file. Is it necessary converting device_vector to std::vector?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

template <typename T, typename Vector>
thrust::device_vector<T> set_device_vector(Vector& vector) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can have a try like this:

device_vector<int> D(vector.begin(), vector.end());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works

int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
thrust::device_vector<size_t> dev_in_lod =
set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thrust::device_vector<size_t> dev_in_lod(lod0.begin(), lod0.end());

This should work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

auto tokens_len = tokens.size();
auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
auto in_dat = in->data<T>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, We should registry an int64_t kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

class TestSequenceEraseOpEmpty(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test for int64_t input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor Author

@kuke kuke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thx

__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
const T* tokens, const int tokens_len,
int* num_erased) {
__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They have different data type.

for (int i = 0; i < tokens_len; ++i) {
for (size_t i = 0; i < tokens_len; ++i) {
if (in_dat[index] == tokens[i]) {
erased = 1;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

template <typename T, typename Vector>
thrust::device_vector<T> set_device_vector(Vector& vector) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works

}

template <typename T>
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

auto tokens_len = tokens.size();
auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
auto in_dat = in->data<T>();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

// Count number of elements to be erased
thrust::device_vector<size_t> num_erased(in_len + 1);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
thrust::device_vector<size_t> dev_in_lod =
set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

class TestSequenceEraseOpEmpty(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@wanghaoshuang wanghaoshuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuke kuke merged commit a1c281f into PaddlePaddle:develop Jan 19, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants