@@ -106,32 +106,33 @@ at::Tensor generate_uniform(int64_t size, at::ScalarType dtype) {
106106} // namespace
107107
108108TEST_F (NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
109- for (int64_t size : {16 , 1024 , 10001 , 10002 , 10003 , 100000 , 10000001 }) {
110- for (auto dtype : {kFloat , kDouble }) {
111- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
112- auto fusion = fusion_ptr.get ();
113- FusionGuard fg (fusion);
109+ std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
110+ auto fusion = fusion_ptr.get ();
111+ FusionGuard fg (fusion);
114112
115- Int* size_val = IrBuilder::create<Int>();
116- fusion->addInput (size_val);
117- TensorView* tv0 = rand ({size_val}, aten_to_data_type (dtype));
118- fusion->addOutput (tv0);
113+ Int* size_val = IrBuilder::create<Int>();
114+ fusion->addInput (size_val);
115+ TensorView* tv0 = rand ({size_val}, DataType::Float);
116+ TensorView* tv1 = rand ({size_val}, DataType::Double);
117+ fusion->addOutput (tv0);
118+ fusion->addOutput (tv1);
119119
120- FusionExecutorCache fec (std::move (fusion_ptr));
120+ FusionExecutorCache fec (std::move (fusion_ptr));
121121
122- at::manual_seed ( 0 );
123- auto cg_outputs = fec. runFusionWithInputs ({size} );
124- auto out = cg_outputs[ 0 ] ;
122+ for ( int64_t size : { 16 , 1024 , 10001 , 10002 , 10003 , 100000 , 10000001 }) {
123+ at::manual_seed ( 0 );
124+ auto cg_outputs = fec. runFusionWithInputs ({size}) ;
125125
126- at::manual_seed (0 );
127- auto ref = generate_uniform (size, dtype);
126+ at::manual_seed (0 );
127+ auto ref0 = generate_uniform (size, kFloat );
128+ auto ref1 = generate_uniform (size, kDouble );
128129
129- testValidate (fec. fusion (), {out}, {size}, {ref}, __LINE__, __FILE__);
130- }
130+ testValidate (
131+ fec. fusion (), cg_outputs, {size}, {ref0, ref1}, __LINE__, __FILE__);
131132 }
132133}
133134
134- TEST_F (NVFuserTest, FusionRNGSimpleValidateWithCURand_CUDA ) {
135+ TEST_F (NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA ) {
135136 int64_t size = 128 ;
136137 auto dtype = kFloat ;
137138 std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
0 commit comments