Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions paddle/operators/roi_pool_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/roi_pool_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kROISize = 5;

class ROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
"Input(ROIs) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
"Output(Argmax) of ROIPoolOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");

PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needs to check rois_dims[1] == kROISize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");

int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must greater than 0");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0");

auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;

ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Argmax", out_dims);
}

protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};

class ROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}

protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};

class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), "
"the input of ROIPoolOp. "
"The format of input tensor is NCHW. Where N is batch size, "
"C is the number of input channels, "
"H is the height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(Tensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …]. "
"Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddOutput("Out",
"(Tensor), "
"The output of ROIPoolOp is a 4-D tensor with shape "
"(num_rois, channels, pooled_h, pooled_w).");
AddOutput("Argmax",
"(Tensor), "
"Argmaxes corresponding to indices in X used "
"for gradient computation. Only output "
"if arg “is_test” is false.").AsIntermediate();
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"The pooled output height.")
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"The pooled output width.")
.SetDefault(1);
AddComment(R"DOC(
ROIPool operator

ROI Pooling for Faster-RCNN. The link below is a further introduction:
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
232 changes: 232 additions & 0 deletions paddle/operators/roi_pool_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/roi_pool_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static constexpr int kROISize = 5;

static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}

template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);

int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);

T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
}

template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads,
const int64_t* input_rois,
const T* output_grad,
const int64_t* argmax_data,
const int num_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;

int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(offset_input_grad + argmax,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}


template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto* argmax = ctx.Output<Tensor>("Argmax");

auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");

auto in_dims = in->dims();
auto in_stride = framework::stride(in_dims);
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];

size_t rois_num = rois->dims()[0];
if (rois_num== 0) return;

int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;

GPUROIPoolForward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size,
in->data<T>(),
rois->data<int64_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};

template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* argmax = ctx.Input<Tensor>("Argmax");

auto* out_grad =
ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<Tensor>(framework::GradVarName("X"));

auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");

size_t rois_num = rois->dims()[0];
int channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];

if (x_grad) {
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, there is no need to set zero here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checked, bp needs to set zero.


int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;

if (output_grad_size > 0) {
GPUROIPoolBackward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size,
rois->data<int64_t>(),
out_grad->data<T>(),
argmax->data<int64_t>(),
rois_num,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
roi_pool,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
roi_pool_grad,
ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
Loading