127 changes: 3 additions & 124 deletions paddle/fluid/platform/for_range.h
@@ -13,136 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace paddle {
namespace platform {

template <typename DeviceContext>
struct ForRange {
ForRange(const DeviceContext& dev_ctx, size_t limit);

template <typename Function>
void operator()(Function func) const;
};

// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}

template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}

size_t limit_;
};

template <>
struct ForRange<phi::CPUContext> {
ForRange(const phi::CPUContext& dev_ctx, size_t limit) : limit_(limit) {}

template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}

size_t limit_;
};

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
func(idx);
}

template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
func(idx);
}
}

// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct ForRange<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<size_t>(limit)) {}

template <typename Function>
inline void operator()(Function func) const {
#ifdef __HIPCC__
// HIP will throw core dump when threads > 256
constexpr int num_threads = 256;
#elif WITH_NV_JETSON
// JETSON_NANO will throw core dump when threads > 128
int num_thread = 256;
platform::ChangeThreadNum(dev_ctx_, &num_thread, 128);
const int num_threads = num_thread;
#else
constexpr int num_threads = 1024;
#endif
size_t block_size = limit_ <= num_threads ? limit_ : num_threads;
size_t grid_size = (limit_ + num_threads - 1) / num_threads;

if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}

const CUDADeviceContext& dev_ctx_;
size_t limit_;
};

template <>
struct ForRange<phi::GPUContext> {
ForRange(const phi::GPUContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<size_t>(limit)) {}

template <typename Function>
inline void operator()(Function func) const {
#ifdef __HIPCC__
// HIP will throw core dump when threads > 256
constexpr int num_threads = 256;
#elif WITH_NV_JETSON
// JETSON_NANO will throw core dump when threads > 128
int num_thread = 256;
platform::ChangeThreadNum(dev_ctx_, &num_thread, 128);
const int num_threads = num_thread;
#else
constexpr int num_threads = 1024;
#endif
size_t block_size = limit_ <= num_threads ? limit_ : num_threads;
size_t grid_size = (limit_ + num_threads - 1) / num_threads;

if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}

const phi::GPUContext& dev_ctx_;
size_t limit_;
};

#endif
template <typename DeviceContext>
using ForRange = phi::funcs::ForRange<DeviceContext>;

} // namespace platform
} // namespace paddle
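With this patch, paddle::platform::ForRange is reduced to an alias of the phi implementation, so existing fluid call sites keep compiling while the logic lives in phi::funcs::ForRange. Below is a minimal usage sketch, not part of the patch: it assumes Paddle's HOSTDEVICE macro (from paddle/phi/core/hostdevice.h), and SquareFunctor is a hypothetical functor used only to illustrate the API.

#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/hostdevice.h"

// Hypothetical element-wise functor, shown only to illustrate the ForRange API.
struct SquareFunctor {
  const float* in;
  float* out;
  // ForRange calls this once for every index i in [0, limit).
  HOSTDEVICE void operator()(size_t i) const { out[i] = in[i] * in[i]; }
};

template <typename DeviceContext>
void SquareExample(const DeviceContext& dev_ctx,
                   const float* in,
                   float* out,
                   size_t n) {
  // paddle::platform::ForRange<DeviceContext> now resolves to
  // phi::funcs::ForRange<DeviceContext>, so legacy call sites are unchanged.
  paddle::platform::ForRange<DeviceContext> for_range(dev_ctx, n);
  for_range(SquareFunctor{in, out});
}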
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/abs_kernel.cc
@@ -13,11 +13,11 @@
// limitations under the License.

#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

@@ -29,7 +29,7 @@ void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
out, size_t(x.numel() * sizeof(phi::funcs::Real<T>)));
auto* out_data = out->data<phi::funcs::Real<T>>();

paddle::platform::ForRange<Context> for_range(ctx, numel);
phi::funcs::ForRange<Context> for_range(ctx, numel);
phi::funcs::AbsFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/diagonal.h
@@ -22,8 +22,8 @@

#include <algorithm>

#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {
namespace funcs {
@@ -118,7 +118,7 @@ DenseTensor Diagonal(const DeviceContext& context,
#endif

// auto& dev_ctx = context.template device_context<DeviceContext>();
paddle::platform::ForRange<DeviceContext> for_range(context, diag.numel());
phi::funcs::ForRange<DeviceContext> for_range(context, diag.numel());
DiagonalFunctor<T> functor(
input_data, diag_arr, ret_arr, pos, dim_size, diag_data);
for_range(functor);
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_base.h
@@ -14,11 +14,11 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
@@ -418,7 +418,7 @@ void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
DX_OP dx_op,
DY_OP dy_op) {
size_t N = static_cast<size_t>(phi::product(x_dim));
paddle::platform::ForRange<DeviceContext> for_range(dev_ctx, N);
phi::funcs::ForRange<DeviceContext> for_range(dev_ctx, N);
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
x.data<T>(),
y.data<T>(),
129 changes: 129 additions & 0 deletions paddle/phi/kernels/funcs/for_range.h
@@ -0,0 +1,129 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Review comment (Contributor): 2016->2022
Reply (Contributor Author): fix in next pr
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"

namespace phi {
namespace funcs {

template <typename Context>
struct ForRange {
ForRange(const Context& dev_ctx, size_t limit);

template <typename Function>
void operator()(Function func) const;
};

template <>
struct ForRange<phi::CPUContext> {
ForRange(const phi::CPUContext& dev_ctx, size_t limit) : limit_(limit) {}

template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}

size_t limit_;
};

// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct ForRange<paddle::platform::CPUDeviceContext> {
ForRange(const paddle::platform::CPUDeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(limit) {}

template <typename Function>
void operator()(Function func) const {
phi::funcs::ForRange<phi::CPUContext> for_range(dev_ctx_, limit_);
for_range(func);
}

const paddle::platform::CPUDeviceContext& dev_ctx_;
size_t limit_;
};

#if defined(__NVCC__) || defined(__HIPCC__)

template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
func(idx);
}

template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
func(idx);
}
}

template <>
struct ForRange<phi::GPUContext> {
ForRange(const phi::GPUContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(limit) {}

template <typename Function>
inline void operator()(Function func) const {
#ifdef __HIPCC__
// HIP will throw core dump when threads > 256
constexpr int num_threads = 256;
#elif WITH_NV_JETSON
// JETSON_NANO will throw core dump when threads > 128
int num_thread = 256;
backends::gpu::ChangeThreadNum(dev_ctx_, &num_thread, 128);
const int num_threads = num_thread;
#else
constexpr int num_threads = 1024;
#endif
size_t block_size = limit_ <= num_threads ? limit_ : num_threads;
size_t grid_size = (limit_ + num_threads - 1) / num_threads;

if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}

const phi::GPUContext& dev_ctx_;
size_t limit_;
};

// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct ForRange<paddle::platform::CUDADeviceContext> {
ForRange(const paddle::platform::CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(limit) {}

template <typename Function>
inline void operator()(Function func) const {
phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx_, limit_);
for_range(func);
}

const paddle::platform::CUDADeviceContext& dev_ctx_;
size_t limit_;
};

#endif

} // namespace funcs
} // namespace phi
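The new header keeps the same dispatch behaviour: the phi::CPUContext specialization runs a plain loop, while the phi::GPUContext specialization launches one thread per element with block_size = min(limit, num_threads) and grid_size = (limit + num_threads - 1) / num_threads; for example, limit = 5000 with the default 1024 threads gives block_size = 1024 and grid_size = 5. A minimal, illustrative phi-side caller follows under those assumptions; ScaleFunctor and ScaleByTwoKernel are hypothetical names, not part of this PR.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

// Hypothetical functor: writes out[i] = 2 * x[i] for every index handed to it.
template <typename T>
struct ScaleFunctor {
  const T* x;
  T* out;
  HOSTDEVICE void operator()(size_t i) const {
    out[i] = static_cast<T>(2) * x[i];
  }
};

// Hypothetical kernel body showing the intended call pattern after migration.
template <typename T, typename Context>
void ScaleByTwoKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  const T* x_data = x.data<T>();
  ctx.template Alloc<T>(out, static_cast<size_t>(x.numel() * sizeof(T)));
  T* out_data = out->data<T>();

  // ForRange picks the CPU loop or the CUDA/HIP launch from the Context type.
  funcs::ForRange<Context> for_range(ctx, static_cast<size_t>(x.numel()));
  for_range(ScaleFunctor<T>{x_data, out_data});
}

}  // namespace phi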
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/poisson_kernel.cu
@@ -19,9 +19,9 @@ limitations under the License. */
#include <hiprand_kernel.h>
#endif

#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/poisson_kernel.h"

namespace phi {
@@ -65,7 +65,7 @@ void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;

paddle::platform::ForRange<Context> for_range(ctx, size);
phi::funcs::ForRange<Context> for_range(ctx, size);

PoissonCudaFunctor<T> functor(x_data, out_data, seed, offset);
for_range(functor);
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/abs_grad_kernel_impl.h
@@ -14,10 +14,10 @@

#pragma once

#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

@@ -53,7 +53,7 @@ void AbsGradKernel(const Context& ctx,
ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
auto* dx_data = dx->data<T>();

paddle::platform::ForRange<Context> for_range(ctx, numel);
phi::funcs::ForRange<Context> for_range(ctx, numel);
phi::funcs::AbsGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
@@ -70,7 +70,7 @@ void AbsDoubleGradKernel(const Context& ctx,
ctx.template Alloc<T>(ddout, static_cast<size_t>(numel * sizeof(T)));
auto* ddout_data = ddout->data<T>();

paddle::platform::ForRange<Context> for_range(ctx, numel);
phi::funcs::ForRange<Context> for_range(ctx, numel);
phi::funcs::AbsGradGradFunctor<T> functor(
ddx_data, x_data, ddout_data, numel);
for_range(functor);