#include "llvm/Support/Debug.h"

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.hpp"
+#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"

@@ -121,6 +122,7 @@ void estimateTimeForMatMulOp(Operation *op, Value a, Value b, bool aTransposed,
  assert(aType && aType.hasRank() && "expected shaped type with A rank");
  int64_t aRank = aType.getRank();
  llvm::ArrayRef<int64_t> aShape = aType.getShape();
+  // a => matrix A; B => the Batch dims (aka all but the last 2 dims).
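+  // Illustration (assuming summarizeHigherDims multiplies the leading dims):
+  // for aShape = [2, 3, 128, 64], the batch dims [2, 3] collapse to aB = 6.
+  // The same convention is used for bB below.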
  bool aBDynamic;
  int64_t aB = summarizeHigherDims(aShape, aRank - 2, aBDynamic);
  int64_t aNIndex = aTransposed ? aRank - 1 : aRank - 2;
@@ -132,6 +134,7 @@ void estimateTimeForMatMulOp(Operation *op, Value a, Value b, bool aTransposed,
  assert(bType && bType.hasRank() && "expected shaped type with B rank");
  int64_t bRank = bType.getRank();
  llvm::ArrayRef<int64_t> bShape = bType.getShape();
+  // b => matrix B; B => the Batch dims (aka all but the last 2 dims).
  bool bBDynamic;
  int64_t bB = summarizeHigherDims(bShape, bRank - 2, bBDynamic);
  int64_t bMIndex = bTransposed ? bRank - 1 : bRank - 2;
@@ -312,6 +315,15 @@ void estimateTimeForOp<ONNXExpOp>(ONNXExpOp op, const DimAnalysis *dimAnalysis,
      cpuEstimatedTime, nnpaEstimatedTime);
}

+template <>
+void estimateTimeForOp<ONNXGeluOp>(ONNXGeluOp op,
+    const DimAnalysis *dimAnalysis, double &cpuEstimatedTime,
+    double &nnpaEstimatedTime) {
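+  // Gelu is costed like the other unary elementwise ops: the fitted CPU/NNPA
+  // "_3ds" cost functions are evaluated on the collapsed operand shape
+  // (assumed to be generated alongside the other per-op cost models).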
+  estimateTimeForElementwiseOp(op.getOperation(), op.getOperand(), dimAnalysis,
+      estimatedTimeForCPU_Gelu_3ds, estimatedTimeForNNPA_Gelu_3ds,
+      cpuEstimatedTime, nnpaEstimatedTime);
+}
+
template <>
void estimateTimeForOp<ONNXLogOp>(ONNXLogOp op, const DimAnalysis *dimAnalysis,
    double &cpuEstimatedTime, double &nnpaEstimatedTime) {
@@ -401,15 +413,33 @@ double estimateTimeForStickOp(Value oper) {
  int64_t e4, e3, e2, e1;
  std::string msg;
  processDim(oper, e4, e3, e2, e1, msg);
-  return estimatedTimeForNNPA_Stick_3ds(e4 * e3, e2, e1);
+  // Arch 14 (M14): no NNPA support, use the CPU estimate.
+  if (isLessEqualNNPALevel(NNPALevel::M14))
+    return arch14_estimatedTimeForCPU_Stick_3ds(e4 * e3, e2, e1);
+  // Otherwise, return the minimum of the CPU and NNPA estimates.
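+  // Hypothetical example: if the CPU model yields 12.0 and the NNPA model
+  // yields 8.5 (in the model's time units) for this shape, 8.5 is returned.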
+  if (isLessEqualNNPALevel(NNPALevel::M15)) {
+    double cpuTime = arch15_estimatedTimeForCPU_Stick_3ds(e4 * e3, e2, e1);
+    double nnpaTime = arch15_estimatedTimeForNNPA_Stick_3ds(e4 * e3, e2, e1);
+    return cpuTime < nnpaTime ? cpuTime : nnpaTime;
+  }
+  llvm_unreachable("add new NNPA architecture model here");
}

double estimateTimeForUnstickOp(Value oper) {
  // Process dim (collapse and handle dynamic sizes).
  int64_t e4, e3, e2, e1;
  std::string msg;
  processDim(oper, e4, e3, e2, e1, msg);
-  return estimatedTimeForNNPA_Unstick_3ds(e4 * e3, e2, e1);
+  // Arch 14 (M14): no NNPA support, use the CPU estimate.
+  if (isLessEqualNNPALevel(NNPALevel::M14))
+    return arch14_estimatedTimeForCPU_Unstick_3ds(e4 * e3, e2, e1);
+  // Otherwise, return the minimum of the CPU and NNPA estimates.
+  if (isLessEqualNNPALevel(NNPALevel::M15)) {
+    double cpuTime = arch15_estimatedTimeForCPU_Unstick_3ds(e4 * e3, e2, e1);
+    double nnpaTime = arch15_estimatedTimeForNNPA_Unstick_3ds(e4 * e3, e2, e1);
+    return cpuTime < nnpaTime ? cpuTime : nnpaTime;
+  }
+  llvm_unreachable("add new NNPA architecture model here");
}

bool estimateTimeForOpWithModel(Operation *op, const DimAnalysis *dimAnalysis,
@@ -432,6 +462,8 @@ bool estimateTimeForOpWithModel(Operation *op, const DimAnalysis *dimAnalysis,
  // Unary elementwise NNPA candidate ops.
  else if (auto expOp = mlir::dyn_cast<ONNXExpOp>(op))
    estimateTimeForOp(expOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
+  else if (auto geluOp = mlir::dyn_cast<ONNXGeluOp>(op))
+    estimateTimeForOp(geluOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
  else if (auto logOp = mlir::dyn_cast<ONNXLogOp>(op))
    estimateTimeForOp(logOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
  else if (auto reluOp = mlir::dyn_cast<ONNXReluOp>(op))