Commit 6310948

Add full, full_like, zeros, zeros_like, ones, ones_like (pytorch#1943)
1 parent 4c254c0 commit 6310948

27 files changed, +479 -58 lines

benchmarks/cpp/nvfuser/timm.cpp

Lines changed: 4 additions & 4 deletions

@@ -115,7 +115,7 @@ static void setup_vit_base_patch16_224_bcast5(Fusion* fusion, void* null) {
   auto t6 = set(t5);
   auto t7 = broadcast(t6, bcast_pattern0);
   auto t8 = add(t4, t7);
-  auto t9 = randlike(t8);
+  auto t9 = rand_like(t8);
   auto d34 =
       sub(IrBuilder::create<Double>(1.0), IrBuilder::create<Double>(0.0));
   auto t10 = lt(t9, d34);
@@ -289,7 +289,7 @@ static void setup_vit_base_patch16_224_norm_inner3(Fusion* fusion, void* null) {
   auto t10 = broadcast(t9, {false, false, false, true});
   auto t11 = reciprocal(t10);
   auto t12 = mul(t8, t11);
-  auto t13 = randlike(t12);
+  auto t13 = rand_like(t12);
   auto d79 = sub(IrBuilder::create<Double>(1), IrBuilder::create<Double>(0));
   auto t14 = lt(t13, d79);
   auto t15 = castOp(DataType::Float, t14);
@@ -367,7 +367,7 @@ static void setup_vit_base_patch16_224_bcast_outer6(
   auto t9 = add(IrBuilder::create<Double>(1), t8);
   auto t10 = mul(IrBuilder::create<Double>(0.5), t9);
   auto t11 = mul(t6, t10);
-  auto t12 = randlike(t11);
+  auto t12 = rand_like(t11);
   auto d66 = sub(IrBuilder::create<Double>(1), IrBuilder::create<Double>(0));
   auto t13 = lt(t12, d66);
   auto t14 = castOp(DataType::Float, t13);
@@ -456,7 +456,7 @@ static void setup_vit_base_patch16_224_bcast_inner6(
   auto t9 = add(IrBuilder::create<Double>(1), t8);
   auto t10 = mul(IrBuilder::create<Double>(0.5), t9);
   auto t11 = mul(t6, t10);
-  auto t12 = randlike(t11);
+  auto t12 = rand_like(t11);
   auto d66 = sub(IrBuilder::create<Double>(1), IrBuilder::create<Double>(0));
   auto t13 = lt(t12, d66);
   auto t14 = castOp(DataType::Float, t13);

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 71 additions & 19 deletions

@@ -449,10 +449,79 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
                  .contiguity(std::vector<bool>(n, true))
                  .shape(shape)
                  .build();
-  IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
+  IrBuilder::create<RNGOp>(RNGOpType::Uniform, out, dtype);
   return out;
 }
 
+TensorView* rand_like(TensorView* v) {
+  TORCH_CHECK(
+      isFloatingPointType(v->dtype()),
+      "input must have floating point type, but got ",
+      v->dtype());
+  std::vector<Val*> shape;
+  shape.reserve(v->getMaybeRFactorDomain().size());
+  for (auto id : v->getMaybeRFactorDomain()) {
+    shape.emplace_back(id->getMaybeExpandedExtent());
+  }
+  return rand(shape, v->dtype());
+}
+
+Val* rand_like(Val* v) {
+  return rand_like(v->as<TensorView>());
+}
+
+TensorView* full(
+    const std::vector<Val*>& shape,
+    Val* fill_value,
+    DataType dtype) {
+  auto n = shape.size();
+  auto out = TensorViewBuilder()
+                 .ndims(n)
+                 .dtype(dtype)
+                 .contiguity(std::vector<bool>(n, true))
+                 .shape(shape)
+                 .build();
+  IrBuilder::create<FullOp>(out, fill_value, dtype);
+  return out;
+}
+
+TensorView* full_like(TensorView* tv, Val* fill_value) {
+  std::vector<Val*> shape;
+  shape.reserve(tv->getMaybeRFactorDomain().size());
+  for (auto id : tv->getMaybeRFactorDomain()) {
+    shape.emplace_back(id->getMaybeExpandedExtent());
+  }
+  return full(shape, fill_value, tv->dtype());
+}
+
+Val* full_like(Val* v, Val* fill_value) {
+  return full_like(v->as<TensorView>(), fill_value);
+}
+
+TensorView* zeros(const std::vector<Val*>& shape, DataType dtype) {
+  return full(shape, FusionGuard::getCurFusion()->zeroVal(), dtype);
+}
+
+TensorView* zeros_like(TensorView* tv) {
+  return full_like(tv, FusionGuard::getCurFusion()->zeroVal());
+}
+
+Val* zeros_like(Val* v) {
+  return zeros_like(v->as<TensorView>());
+}
+
+TensorView* ones(const std::vector<Val*>& shape, DataType dtype) {
+  return full(shape, FusionGuard::getCurFusion()->oneVal(), dtype);
+}
+
+TensorView* ones_like(TensorView* tv) {
+  return full_like(tv, FusionGuard::getCurFusion()->oneVal());
+}
+
+Val* ones_like(Val* v) {
+  return ones_like(v->as<TensorView>());
+}
+
 TensorView* arange(Val* end, DataType dtype) {
   return arange(FusionGuard::getCurFusion()->zeroVal(), end, dtype);
 }
@@ -480,7 +549,7 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
                  .contiguity({true})
                  .shape({size})
                  .build();
-  IrBuilder::create<ARangeOp>(out, start, end, step);
+  IrBuilder::create<ARangeOp>(out, start, end, step, dtype);
   return out;
 }
 
@@ -506,23 +575,6 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
 NVFUSER_DEFINE_UNARY_OP(print, Print)
 #undef NVFUSER_DEFINE_UNARY_OP
 
-TensorView* randlike(TensorView* v) {
-  TORCH_CHECK(
-      isFloatingPointType(v->dtype()),
-      "input must have floating point type, but got ",
-      v->dtype());
-  std::vector<Val*> shape;
-  shape.reserve(v->getMaybeRFactorDomain().size());
-  for (auto id : v->getMaybeRFactorDomain()) {
-    shape.emplace_back(id->getMaybeExpandedExtent());
-  }
-  return rand(shape, v->dtype());
-}
-
-Val* randlike(Val* v) {
-  return randlike(v->as<TensorView>());
-}
-
 Val* bitwise_not(Val* v) {
   TORCH_CHECK(
       isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,

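Taken together, the functions added above give nvFuser a factory-op family mirroring torch.full / torch.zeros / torch.ones, plus a snake_case rename of randlike to rand_like. A minimal usage sketch under the usual Fusion/FusionGuard setup; the extents, fill value, and tensor names below are illustrative, not taken from this commit:

// Hedged usage sketch of the new factory ops; shapes and values are made up.
Fusion fusion;
FusionGuard fg(&fusion);

// full(): a 4x8 float tensor filled with 7.0.
std::vector<Val*> shape = {
    IrBuilder::create<Int>(4), IrBuilder::create<Int>(8)};
TensorView* tv0 = full(shape, IrBuilder::create<Double>(7.0), DataType::Float);

// zeros()/ones() are thin wrappers that pass the fusion's cached
// zeroVal()/oneVal() to full().
TensorView* tv1 = zeros(shape, DataType::Float);

// The *_like variants copy shape (using expanded extents where present)
// and dtype from an existing TensorView.
TensorView* tv2 = ones_like(tv0);
fusion.addOutput(tv2);
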
torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 18 additions & 4 deletions

@@ -125,7 +125,24 @@ TORCH_CUDA_CU_API WelfordResult Welford(
 TORCH_CUDA_CU_API TensorView* rand(
     const std::vector<Val*>& shape,
     DataType dtype);
-
+TORCH_CUDA_CU_API Val* rand_like(Val*);
+TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);
+TORCH_CUDA_CU_API TensorView* full(
+    const std::vector<Val*>& shape,
+    Val* fill_value,
+    DataType dtype);
+TORCH_CUDA_CU_API TensorView* full_like(TensorView* tv, Val* fill_value);
+TORCH_CUDA_CU_API Val* full_like(Val* tv, Val* fill_value);
+TORCH_CUDA_CU_API TensorView* zeros(
+    const std::vector<Val*>& shape,
+    DataType dtype);
+TORCH_CUDA_CU_API TensorView* zeros_like(TensorView*);
+TORCH_CUDA_CU_API Val* zeros_like(Val*);
+TORCH_CUDA_CU_API TensorView* ones(
+    const std::vector<Val*>& shape,
+    DataType dtype);
+TORCH_CUDA_CU_API TensorView* ones_like(TensorView*);
+TORCH_CUDA_CU_API Val* ones_like(Val*);
 //! WARNING: giving invalid combinations of the start, end and step
 //! arguments can result in undefined behavior. Specifically, the
 //! signs of `end - start` and step must be the same.
@@ -204,9 +221,6 @@ TORCH_CUDA_CU_API TensorView* log2(TensorView*);
 // neg
 TORCH_CUDA_CU_API Val* neg(Val*);
 TORCH_CUDA_CU_API TensorView* neg(TensorView*);
-// randlike
-TORCH_CUDA_CU_API Val* randlike(Val*);
-TORCH_CUDA_CU_API TensorView* randlike(TensorView*);
 // real
 TORCH_CUDA_CU_API Val* real(Val*);
 TORCH_CUDA_CU_API TensorView* real(TensorView*);

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 8 additions & 6 deletions

@@ -560,10 +560,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
              << "&" << gen(ldst->in()) << ");\n";
   }
 
+  void handle(const FullOp* fop) final {
+    indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")"
+             << gen(fop->getFillValue()) << ";\n";
+  }
+
   void handle(const ARangeOp* aop) final {
     auto index = genTensorIndex(aop->getLinearIndex()->as<kir::TensorIndex>());
-    indent() << gen(aop->output(0)) << " = arange<" << aop->output(0)->dtype()
-             << ">";
+    indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
     code_ << "(" << index << ", " << gen(aop->start()) << ", "
           << gen(aop->step()) << ");\n";
   }
@@ -759,9 +763,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   void handle(const RNGOp* rop) final {
     // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
     // innermost ID of size 4 (float) or size 2 (double)?
-    auto out_tv = rop->output(0)->as<kir::TensorIndex>()->view();
     auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
-    int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
+    int multiple = rop->dtype() == DataType::Double ? 2 : 4;
     indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
              << ";\n";
     indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
@@ -780,8 +783,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << "}\n";
     auto op_type = rop->getRNGOpType();
     indent() << gen(rop->output(0)) << " = " << op_type;
-    if (needFloatSuffix(op_type) &&
-        rop->output(0)->dtype() == DataType::Float) {
+    if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
       code_ << "f";
     }
     code_ << "(rng_result, rng_component" << rop->name() << ");\n";

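Since the FullOp handler above lowers to a single cast-and-assign per element, the generated kernel line is trivial. A hedged sketch of plausible output for a float zeros() tensor; the tensor and index names, and the exact literal formatting produced by gen(), are guesses rather than output captured from this commit:

// Hypothetical generated CUDA line for a FullOp filling a float tensor with 0;
// T2 and i21 stand in for whatever names the kernel generator assigns.
T2[i21] = (float)0;
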
torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 15 additions & 0 deletions

@@ -95,6 +95,9 @@ void Val::dispatch(T handler, Val* val) {
 template <typename T>
 void Expr::dispatch(T handler, Expr* expr) {
   switch (*(expr->getExprType())) {
+    case ExprType::FullOp:
+      ptr(handler)->handle(expr->as<FullOp>());
+      return;
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
@@ -281,6 +284,9 @@ void Val::constDispatch(T handler, const Val* val) {
 template <typename T>
 void Expr::constDispatch(T handler, const Expr* expr) {
   switch (*(expr->getExprType())) {
+    case ExprType::FullOp:
+      ptr(handler)->handle(expr->as<FullOp>());
+      return;
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
@@ -475,6 +481,9 @@ void Val::mutatorDispatch(T mutator, Val* val) {
 template <typename T>
 void Expr::mutatorDispatch(T mutator, Expr* expr) {
   switch (*(expr->getExprType())) {
+    case ExprType::FullOp:
+      ptr(mutator)->mutate(expr->as<FullOp>());
+      return;
     case ExprType::ARangeOp:
       ptr(mutator)->mutate(expr->as<ARangeOp>());
       return;
@@ -734,6 +743,9 @@ void OptOutConstDispatch::handle(const kir::IntPair* stmt) {
 }
 
 // Exprs
+void OptOutConstDispatch::handle(const FullOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const ARangeOp* stmt) {
   unhandled(stmt);
 }
@@ -890,6 +902,9 @@ void OptOutDispatch::handle(kir::IntPair* stmt) {
 }
 
 // Exprs
+void OptOutDispatch::handle(FullOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(ARangeOp* stmt) {
   unhandled(stmt);
 }

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 4 additions & 0 deletions

@@ -68,6 +68,7 @@ class ComplexDouble;
 class NamedScalar;
 
 // Exprs
+class FullOp;
 class ARangeOp;
 class UnaryOp;
 class BinaryOp;
@@ -144,6 +145,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const kir::IntPair*);
 
   // Exprs
+  virtual void handle(const FullOp* stmt);
   virtual void handle(const ARangeOp* stmt);
   virtual void handle(const UnaryOp* stmt);
   virtual void handle(const BinaryOp* stmt);
@@ -211,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(kir::IntPair*);
 
   // Exprs
+  virtual void handle(FullOp* stmt);
   virtual void handle(ARangeOp* stmt);
   virtual void handle(UnaryOp* stmt);
   virtual void handle(BinaryOp* stmt);
@@ -319,6 +322,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(kir::IntPair*);
 
   // Exprs
+  virtual void mutate(FullOp*);
   virtual void mutate(ARangeOp*);
   virtual void mutate(UnaryOp*);
   virtual void mutate(BinaryOp*);

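As the two files above show, every new Expr type must be registered in each dispatch switch and given a default handler; otherwise passes built on these visitor base classes would never see it. A hedged sketch of how a downstream pass could then opt in to the new node (the CountFullOps class is hypothetical, not part of this commit):

// Hypothetical const visitor that only overrides the FullOp handler;
// all other exprs keep the OptOutConstDispatch default of unhandled().
class CountFullOps : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle; // keep the other overloads visible
  void handle(const FullOp* fop) override {
    ++count_;
  }
  int count_ = 0;
};
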
torch/csrc/jit/codegen/cuda/ir_builder.cpp

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ IR_BUILDER_INSTANTIATE(ShiftOp)
 IR_BUILDER_INSTANTIATE(GatherOp)
 IR_BUILDER_INSTANTIATE(ViewAsScalar)
 IR_BUILDER_INSTANTIATE(ViewOp)
+IR_BUILDER_INSTANTIATE(FullOp)
 IR_BUILDER_INSTANTIATE(ARangeOp)
 IR_BUILDER_INSTANTIATE(UnaryOp)
 IR_BUILDER_INSTANTIATE(BinaryOp)

torch/csrc/jit/codegen/cuda/ir_cloner.cpp

Lines changed: 4 additions & 0 deletions

@@ -88,6 +88,10 @@ void IrCloner::handle(const TensorView* tv) {
   clone_ = IrBuilder::clone(tv, this);
 }
 
+void IrCloner::handle(const FullOp* op) {
+  clone_ = IrBuilder::clone(op, this);
+}
+
 void IrCloner::handle(const ARangeOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }

torch/csrc/jit/codegen/cuda/ir_cloner.h

Lines changed: 1 addition & 0 deletions

@@ -68,6 +68,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const ComplexDouble*) override;
   void handle(const NamedScalar*) override;
 
+  void handle(const FullOp*) override;
   void handle(const ARangeOp*) override;
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;

torch/csrc/jit/codegen/cuda/ir_graphviz.cpp

Lines changed: 15 additions & 6 deletions

@@ -407,15 +407,24 @@ void IrGraphGenerator::handle(const TensorView* tv) {
   tensor_views_.push_back(tv);
 }
 
-void IrGraphGenerator::handle(const ARangeOp* uop) {
+void IrGraphGenerator::handle(const FullOp* fop) {
   // node
-  printExpr(uop, "arange");
+  printExpr(fop, "full");
 
   // inputs & outputs
-  addArc(uop->start(), uop);
-  addArc(uop->end(), uop);
-  addArc(uop->step(), uop);
-  addArc(uop, uop->output(0));
+  addArc(fop->getFillValue(), fop);
+  addArc(fop, fop->output(0));
+}
+
+void IrGraphGenerator::handle(const ARangeOp* aop) {
+  // node
+  printExpr(aop, "arange");
+
+  // inputs & outputs
+  addArc(aop->start(), aop);
+  addArc(aop->end(), aop);
+  addArc(aop->step(), aop);
+  addArc(aop, aop->output(0));
 }
 
 void IrGraphGenerator::handle(const UnaryOp* uop) {
