Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions paddle/fluid/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@ namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");

std::normal_distribution<T> dist(mean, std);
auto shape = GetShape(context);
tensor->Resize(shape);
int64_t size = tensor->numel();
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);

for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
}; // namespace operators

template <typename T>
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
Expand Down Expand Up @@ -194,8 +173,6 @@ Used to initialize tensors with gaussian random generator.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);
Expand Down
52 changes: 0 additions & 52 deletions paddle/fluid/operators/gaussian_random_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,53 +52,6 @@ struct GaussianGenerator {
}
};

template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
auto shape = GetShape(context);
tensor->Resize(shape);

auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();
T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());

int64_t size = tensor->numel();

int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
distribution::normal_distribution<MT> dist;
distribution::normal_transform<MT> trans(mean, std);
distribution::distribution_and_transform<T>(dev_cxt, tensor, dist,
trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
} else {
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};

template <typename T>
class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -136,11 +89,6 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(
gaussian_random,
paddle::operators::GPUGaussianRandomKernel<paddle::platform::float16>,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(
gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
Expand Down
53 changes: 53 additions & 0 deletions paddle/phi/kernels/cpu/gaussian_random_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gaussian_random_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/fluid/framework/generator.h"

namespace phi {

template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;

std::normal_distribution<T> dist(mean, std);

tensor->Resize(phi::make_ddim(shape.GetData()));
int64_t size = tensor->numel();
T* data = dev_ctx.template Alloc<T>(tensor);
auto engine = paddle::framework::GetCPURandomEngine(seed);

for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}

} // namespace phi

PD_REGISTER_KERNEL(gaussian_random,
CPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
float,
double) {}
32 changes: 32 additions & 0 deletions paddle/phi/kernels/gaussian_random_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out);

} // namespace phi
111 changes: 111 additions & 0 deletions paddle/phi/kernels/gpu/gaussian_random_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gaussian_random_kernel.h"

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"

#include "paddle/fluid/framework/generator.h"

DECLARE_bool(use_curand);

namespace phi {

template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;

__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}

__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
thrust::normal_distribution<MT> dist(mean_, std_);
unsigned int new_n = n + offset_;
rng.discard(new_n);
MT out = dist(rng);
return static_cast<T>(out);
}
};

template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;

bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}

tensor->Resize(phi::make_ddim(shape.GetData()));

T* data = dev_ctx.template Alloc<T>(tensor);

int64_t size = tensor->numel();

int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);

using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
funcs::normal_distribution<MT> dist;
funcs::normal_transform<MT> trans(mean, std);
funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<MT>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
}
} else {
auto func = GaussianGenerator<MT>(mean, std, seed);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
}
}

} // namespace phi

PD_REGISTER_KERNEL(gaussian_random,
GPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
phi::dtype::float16,
float,
double) {}
45 changes: 45 additions & 0 deletions paddle/phi/ops/compat/gaussian_random_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature GaussianRandomOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("gaussian_random",
{},
{"ShapeTensorList", "mean", "std", "seed", "dtype"},
{"Out"});
}

const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.HasInput("ShapeTensor") && shape.empty()) {
return KernelSignature("gaussian_random",
{},
{"ShapeTensor", "mean", "std", "seed", "dtype"},
{"Out"});
}

return KernelSignature("gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
}

} // namespace phi

PD_REGISTER_ARG_MAPPING_FN(gaussian_random,
phi::GaussianRandomOpArgumentMapping);