2525
2626void BuildProgram (pir::Builder &builder) { // NOLINT
2727 paddle::dialect::FullOp full_input_op1 =
28- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 1 , 512 , 64 },
29- 1.5 );
28+ builder.Build <paddle::dialect::FullOp>(
29+ std::vector< int64_t >{ 1 , 512 , 64 }, 1.5 , phi::DataType::FLOAT16 );
3030 // linear 1
3131 paddle::dialect::FullOp full_weight_op1 =
32- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 , 64 }, 1.5 );
32+ builder.Build <paddle::dialect::FullOp>(
33+ std::vector<int64_t >{64 , 64 }, 1.5 , phi::DataType::FLOAT16);
3334 paddle::dialect::FullOp full_bias_op1 =
34- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
35+ builder.Build <paddle::dialect::FullOp>(
36+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
3537 paddle::dialect::MatmulOp matmul_op1 =
3638 builder.Build <paddle::dialect::MatmulOp>(full_input_op1.out (),
3739 full_weight_op1.out ());
3840 paddle::dialect::AddOp add_op1 = builder.Build <paddle::dialect::AddOp>(
3941 matmul_op1.out (), full_bias_op1.out ());
4042 // linear 2
4143 paddle::dialect::FullOp full_weight_op2 =
42- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 64 , 128 },
43- 1.5 );
44+ builder.Build <paddle::dialect::FullOp>(
45+ std::vector< int64_t >{ 64 , 128 }, 1.5 , phi::DataType::FLOAT16 );
4446 paddle::dialect::FullOp full_bias_op2 =
45- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{128 }, 1.0 );
47+ builder.Build <paddle::dialect::FullOp>(
48+ std::vector<int64_t >{128 }, 1.0 , phi::DataType::FLOAT16);
4649 paddle::dialect::MatmulOp matmul_op2 =
4750 builder.Build <paddle::dialect::MatmulOp>(add_op1.out (),
4851 full_weight_op2.out ());
@@ -52,10 +55,11 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
5255 builder.Build <paddle::dialect::ReluOp>(add_op2.out ());
5356 // linear 3
5457 paddle::dialect::FullOp full_weight_op3 =
55- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 128 , 64 },
56- 1.5 );
58+ builder.Build <paddle::dialect::FullOp>(
59+ std::vector< int64_t >{ 128 , 64 }, 1.5 , phi::DataType::FLOAT16 );
5760 paddle::dialect::FullOp full_bias_op3 =
58- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
61+ builder.Build <paddle::dialect::FullOp>(
62+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
5963 paddle::dialect::MatmulOp matmul_op3 =
6064 builder.Build <paddle::dialect::MatmulOp>(relu_op.out (),
6165 full_weight_op3.out ());
@@ -65,9 +69,11 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
6569 builder.Build <paddle::dialect::GeluOp>(add_op3.out ());
6670 // linear 4
6771 paddle::dialect::FullOp full_weight_op4 =
68- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 , 64 }, 1.5 );
72+ builder.Build <paddle::dialect::FullOp>(
73+ std::vector<int64_t >{64 , 64 }, 1.5 , phi::DataType::FLOAT16);
6974 paddle::dialect::FullOp full_bias_op4 =
70- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
75+ builder.Build <paddle::dialect::FullOp>(
76+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
7177 paddle::dialect::MatmulOp matmul_op4 =
7278 builder.Build <paddle::dialect::MatmulOp>(gelu_op1.out (),
7379 full_weight_op4.out ());
@@ -78,7 +84,7 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
7884
7985 // backward
8086 paddle::dialect::FullOp full_grad_op = builder.Build <paddle::dialect::FullOp>(
81- std::vector<int64_t >{1 , 512 , 64 }, 1.0 );
87+ std::vector<int64_t >{1 , 512 , 64 }, 1.0 , phi::DataType::FLOAT16 );
8288
8389 paddle::dialect::GeluGradOp gelu_op2_grad =
8490 builder.Build <paddle::dialect::GeluGradOp>(
0 commit comments