@@ -449,10 +449,79 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
            .contiguity(std::vector<bool>(n, true))
            .shape(shape)
            .build();
-  IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
+  IrBuilder::create<RNGOp>(RNGOpType::Uniform, out, dtype);
   return out;
 }
 
+TensorView* rand_like(TensorView* v) {
+  TORCH_CHECK(
+      isFloatingPointType(v->dtype()),
+      "input must have floating point type, but got ",
+      v->dtype());
+  std::vector<Val*> shape;
+  shape.reserve(v->getMaybeRFactorDomain().size());
+  for (auto id : v->getMaybeRFactorDomain()) {
+    shape.emplace_back(id->getMaybeExpandedExtent());
+  }
+  return rand(shape, v->dtype());
+}
+
+Val* rand_like(Val* v) {
+  return rand_like(v->as<TensorView>());
+}
+
+TensorView* full(
+    const std::vector<Val*>& shape,
+    Val* fill_value,
+    DataType dtype) {
+  auto n = shape.size();
+  auto out = TensorViewBuilder()
+                 .ndims(n)
+                 .dtype(dtype)
+                 .contiguity(std::vector<bool>(n, true))
+                 .shape(shape)
+                 .build();
+  IrBuilder::create<FullOp>(out, fill_value, dtype);
+  return out;
+}
+
+TensorView* full_like(TensorView* tv, Val* fill_value) {
+  std::vector<Val*> shape;
+  shape.reserve(tv->getMaybeRFactorDomain().size());
+  for (auto id : tv->getMaybeRFactorDomain()) {
+    shape.emplace_back(id->getMaybeExpandedExtent());
+  }
+  return full(shape, fill_value, tv->dtype());
+}
+
+Val* full_like(Val* v, Val* fill_value) {
+  return full_like(v->as<TensorView>(), fill_value);
+}
+
+TensorView* zeros(const std::vector<Val*>& shape, DataType dtype) {
+  return full(shape, FusionGuard::getCurFusion()->zeroVal(), dtype);
+}
+
+TensorView* zeros_like(TensorView* tv) {
+  return full_like(tv, FusionGuard::getCurFusion()->zeroVal());
+}
+
+Val* zeros_like(Val* v) {
+  return zeros_like(v->as<TensorView>());
+}
+
+TensorView* ones(const std::vector<Val*>& shape, DataType dtype) {
+  return full(shape, FusionGuard::getCurFusion()->oneVal(), dtype);
+}
+
+TensorView* ones_like(TensorView* tv) {
+  return full_like(tv, FusionGuard::getCurFusion()->oneVal());
+}
+
+Val* ones_like(Val* v) {
+  return ones_like(v->as<TensorView>());
+}
+
 TensorView* arange(Val* end, DataType dtype) {
   return arange(FusionGuard::getCurFusion()->zeroVal(), end, dtype);
 }
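For context between hunks: a minimal usage sketch of the new factory ops, not part of this diff. It assumes the usual nvFuser C++ test setup (Fusion/FusionGuard and the makeSymbolicTensor helper from the test utilities); the shapes, dtypes, and arithmetic are illustrative only.

// Sketch only: assumes the nvFuser test headers are included.
Fusion fusion;
FusionGuard fg(&fusion);

// Illustrative 2D float input (makeSymbolicTensor is the test helper).
TensorView* tv0 = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(tv0);

// rand_like reads tv0's rfactor-domain (possibly expanded) extents,
// so the uniform-random tensor matches tv0's shape and dtype.
TensorView* noise = rand_like(tv0);

// zeros_like/ones_like are thin wrappers over full_like, passing the
// fusion's cached zeroVal()/oneVal() as the fill value.
TensorView* bias = ones_like(tv0);

TensorView* out = add(mul(tv0, noise), bias);
fusion.addOutput(out);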
@@ -480,7 +549,7 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
            .contiguity({true})
            .shape({size})
            .build();
-  IrBuilder::create<ARangeOp>(out, start, end, step);
+  IrBuilder::create<ARangeOp>(out, start, end, step, dtype);
   return out;
 }
 
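Likewise a hedged sketch for the arange change: the requested dtype is now carried on the ARangeOp node itself rather than inferred later, so it survives into the IR. The bounds below are illustrative values, not from this diff.

// Sketch only: a half-open range [0, 100) with step 2, as Int elements.
Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv = arange(
    IrBuilder::create<Int>(0),
    IrBuilder::create<Int>(100),
    IrBuilder::create<Int>(2),
    DataType::Int);
fusion.addOutput(tv);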
@@ -506,23 +575,6 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
 NVFUSER_DEFINE_UNARY_OP(print, Print)
 #undef NVFUSER_DEFINE_UNARY_OP
 
-TensorView* randlike(TensorView* v) {
-  TORCH_CHECK(
-      isFloatingPointType(v->dtype()),
-      "input must have floating point type, but got ",
-      v->dtype());
-  std::vector<Val*> shape;
-  shape.reserve(v->getMaybeRFactorDomain().size());
-  for (auto id : v->getMaybeRFactorDomain()) {
-    shape.emplace_back(id->getMaybeExpandedExtent());
-  }
-  return rand(shape, v->dtype());
-}
-
-Val* randlike(Val* v) {
-  return randlike(v->as<TensorView>());
-}
-
 Val* bitwise_not(Val* v) {
   TORCH_CHECK(
       isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
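The removal above pairs with the first hunk: the randlike overloads reappear there verbatim as rand_like, grouped with the other factory ops, so this reads as a rename. Call sites would change mechanically, e.g.:

// Before this commit:
auto r = randlike(tv0);
// After:
auto r = rand_like(tv0);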