Skip to content

Commit f262d9c

Browse files
authored
Add support for uniform RNG (pytorch#1986)
1 parent eb1dad1 commit f262d9c

File tree

13 files changed

+149
-7
lines changed

13 files changed

+149
-7
lines changed

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,24 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
453453
return out;
454454
}
455455

456+
// TENSOR FACTORIES
457+
TensorView* uniform(
458+
const std::vector<Val*>& shape,
459+
Val* low,
460+
Val* high,
461+
DataType dtype) {
462+
auto n = shape.size();
463+
auto out = TensorViewBuilder()
464+
.ndims(n)
465+
.dtype(dtype)
466+
.contiguity(std::vector<bool>(n, true))
467+
.shape(shape)
468+
.build();
469+
IrBuilder::create<RNGOp>(
470+
RNGOpType::UniformRange, out, dtype, std::vector<Val*>{low, high});
471+
return out;
472+
}
473+
456474
TensorView* rand_like(TensorView* v) {
457475
TORCH_CHECK(
458476
isFloatingPointType(v->dtype()),

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,20 @@ TORCH_CUDA_CU_API WelfordResult Welford(
121121
// import IrBuilder just for this one interface.
122122
Int* init_N = nullptr);
123123

124-
// TENSOR FACTORIES
124+
// RNG OPERATIONS
125125
TORCH_CUDA_CU_API TensorView* rand(
126126
const std::vector<Val*>& shape,
127127
DataType dtype);
128128
TORCH_CUDA_CU_API Val* rand_like(Val*);
129129
TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);
130+
131+
TORCH_CUDA_CU_API TensorView* uniform(
132+
const std::vector<Val*>& shape,
133+
Val* low,
134+
Val* high,
135+
DataType dtype);
136+
137+
// TENSOR FACTORIES
130138
TORCH_CUDA_CU_API TensorView* full(
131139
const std::vector<Val*>& shape,
132140
Val* fill_value,

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
794794
if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
795795
code_ << "f";
796796
}
797-
code_ << "(rng_result, rng_component" << rop->name() << ");\n";
797+
code_ << "(rng_result, rng_component" << rop->name();
798+
switch (op_type) {
799+
case RNGOpType::UniformRange: {
800+
auto parameters = rop->getParameters();
801+
TORCH_INTERNAL_ASSERT(parameters.size() == 2);
802+
code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
803+
break;
804+
}
805+
default:;
806+
}
807+
code_ << ");\n";
798808
}
799809

800810
std::string genBinaryOp(

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
233233
RNGOpType type,
234234
Val* out,
235235
DataType dtype,
236+
std::vector<Val*> parameters = {},
236237
int rng_offset = 0,
237238
Val* philox_index = nullptr);
238239

@@ -254,6 +255,14 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
254255
rng_offset_ = val;
255256
}
256257

258+
// Extra scalar operands of the RNG op (e.g. {low, high} for
// RNGOpType::UniformRange); empty for parameterless generators.
const std::vector<Val*>& getParameters() const {
  return parameters_;
}
261+
262+
// Root-domain extents of the output tensor, cached at construction time
// (see the RNGOp constructor, which collects id->extent() per root domain).
const std::vector<Val*>& getShape() const {
  return shape_;
}
265+
257266
Val* getPhiloxIndex() const {
258267
return philox_index_;
259268
}
@@ -267,6 +276,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
267276
private:
268277
const RNGOpType rng_op_type_;
269278
const DataType dtype_;
279+
std::vector<Val*> parameters_;
280+
std::vector<Val*> shape_;
270281
int rng_offset_ = -1;
271282
// The index used to feed philox's subsequence and component
272283
Val* philox_index_ = nullptr;

torch/csrc/jit/codegen/cuda/ir_iostream.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,14 +481,19 @@ void IrPrinter::handle(const RNGOp* rop) {
481481

482482
os_ << rop->getRNGOpType() << "({";
483483
bool first = true;
484-
for (auto i : rop->inputs()) {
484+
for (auto i : rop->getShape()) {
485485
if (!first) {
486486
os_ << ", ";
487487
}
488488
handle(i);
489489
first = false;
490490
}
491-
os_ << "}, " << rop->dtype() << ")";
491+
os_ << "}";
492+
for (auto i : rop->getParameters()) {
493+
os_ << ", ";
494+
handle(i);
495+
}
496+
os_ << ", " << rop->dtype() << ")";
492497

493498
indent_size_--;
494499

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,25 +441,34 @@ RNGOp::RNGOp(
441441
RNGOpType type,
442442
Val* out,
443443
DataType dtype,
444+
std::vector<Val*> parameters,
444445
int rng_offset,
445446
Val* philox_index)
446447
: Expr(passkey, ExprType::RNGOp),
447448
rng_op_type_(type),
448449
dtype_(dtype),
450+
parameters_(std::move(parameters)),
449451
rng_offset_(rng_offset),
450452
philox_index_(philox_index) {
451453
if (out->isA<TensorView>()) {
452454
for (auto id : out->as<TensorView>()->getRootDomain()) {
453-
addInput(id->extent());
455+
shape_.emplace_back(id->extent());
454456
}
455457
}
458+
for (auto v : shape_) {
459+
addInput(v);
460+
}
461+
for (auto v : parameters_) {
462+
addInput(v);
463+
}
456464
addOutput(out);
457465
}
458466

459467
// Clone constructor: deep-copies the op into the cloner's target container.
RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      rng_op_type_(src->rng_op_type_),
      dtype_(src->dtype()),
      parameters_(ir_cloner->clone(src->parameters_)),
      // BUGFIX: shape_ was not cloned, leaving it empty in the copy; a
      // cloned op would then print an empty shape in IrPrinter::handle
      // and return nothing from getShape(). Clone it like parameters_.
      shape_(ir_cloner->clone(src->shape_)),
      rng_offset_(src->rng_offset_),
      philox_index_(ir_cloner->clone(src->philox_index_)) {}
465474

@@ -477,6 +486,14 @@ bool RNGOp::sameAs(const Statement* other) const {
477486
if (dtype_ != other_op->dtype_) {
478487
return false;
479488
}
489+
if (parameters_.size() != other_op->parameters_.size()) {
490+
return false;
491+
}
492+
for (auto i : c10::irange(parameters_.size())) {
493+
if (!parameters_[i]->sameAs(other_op->parameters_[i])) {
494+
return false;
495+
}
496+
}
480497
if (getRNGOffset() != other_op->getRNGOffset()) {
481498
return false;
482499
}

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,18 @@ struct SubstituteInExpr : public OptInDispatch {
266266
}
267267

268268
// Rebuilds the RNGOp with `substitute_` swapped in wherever `reference_`
// appears among the op's scalar parameters or as its output.
void handle(RNGOp* rng_expr) final {
  // Fix typo: "subsituted" -> "substituted".
  std::vector<Val*> substituted_params;
  substituted_params.reserve(rng_expr->getParameters().size());
  for (auto v : rng_expr->getParameters()) {
    substituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
  }
  auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
                                                     : rng_expr->output(0);
  // Recreate the expression; offset and philox index carry over unchanged.
  expr_ = IrBuilder::create<RNGOp>(
      rng_expr->container(),
      rng_expr->getRNGOpType(),
      out,
      rng_expr->dtype(),
      substituted_params,
      rng_expr->getRNGOffset(),
      rng_expr->getPhiloxIndex());
}

torch/csrc/jit/codegen/cuda/lower_index.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ void IndexLowering::handle(const RNGOp* rop) {
109109
rop->getRNGOpType(),
110110
out,
111111
rop->dtype(),
112+
rop->getParameters(),
112113
rop->getRNGOffset(),
113114
philox_index);
114115

torch/csrc/jit/codegen/cuda/mutator.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,13 @@ void OptOutMutator::mutate(TernaryOp* top) {
214214

215215
void OptOutMutator::mutate(RNGOp* rop) {
216216
Val* out = maybeMutated(rop->output(0));
217+
auto& parameters = rop->getParameters();
218+
std::vector<Val*> mutated_parameters;
219+
for (auto v : parameters) {
220+
mutated_parameters.emplace_back(maybeMutated(v));
221+
}
217222

218-
if (out == rop->output(0)) {
223+
if (out == rop->output(0) && mutated_parameters == parameters) {
219224
return;
220225
}
221226

@@ -227,6 +232,7 @@ void OptOutMutator::mutate(RNGOp* rop) {
227232
rop_type,
228233
out,
229234
rop->dtype(),
235+
mutated_parameters,
230236
rop->getRNGOffset(),
231237
rop->getPhiloxIndex());
232238
}

torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,23 @@ __device__ double rng_uniform(const uint4& rng_result, int rng_component) {
6767
// Extracts one 32-bit lane of the Philox result and converts it to a float
// sample (presumably in [0, 1) — uniformf's definition is elsewhere; confirm).
__device__ float rng_uniformf(const uint4& rng_result, int rng_component) {
  const auto raw_bits = (&rng_result.x)[rng_component];
  return uniformf(raw_bits);
}
70+
71+
// Maps one Philox lane to a double in the caller-specified range
// [from, to) — assuming rng_uniform yields a unit-interval sample; confirm.
__device__ double rng_uniform_range(
    const uint4& rng_result,
    int rng_component,
    double from,
    double to) {
  // Affine transform of the unit sample: from + (to - from) * u.
  return from + (to - from) * rng_uniform(rng_result, rng_component);
}
80+
81+
// Single-precision variant of rng_uniform_range: maps one Philox lane to a
// float in [from, to) — assuming rng_uniformf yields [0, 1); confirm.
__device__ float rng_uniform_rangef(
    const uint4& rng_result,
    int rng_component,
    float from,
    float to) {
  // Affine transform of the unit sample: from + (to - from) * u.
  return from + (to - from) * rng_uniformf(rng_result, rng_component);
}

0 commit comments

Comments
 (0)