Skip to content

Commit 8e41c3b

Browse files
kunalspathakmichaelgsharp
authored andcommitted
Handle case where falseValue is contained (dotnet#101515)
* Handle case where falseValue is contained * Handle cases where Abs() is wrapped in conditional with AllBitsSet mask * Add a missing case for Abs() handling * jit format * Review comments * Review feedback * Another review feedback
1 parent 2ef654f commit 8e41c3b

File tree

3 files changed

+74
-3
lines changed

3 files changed

+74
-3
lines changed

src/coreclr/jit/gentree.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,7 @@ struct GenTree
17701770
inline bool IsVectorZero() const;
17711771
inline bool IsVectorCreate() const;
17721772
inline bool IsVectorAllBitsSet() const;
1773+
inline bool IsMaskAllBitsSet() const;
17731774
inline bool IsVectorConst();
17741775

17751776
inline uint64_t GetIntegralVectorConstElement(size_t index, var_types simdBaseType);
@@ -9238,6 +9239,32 @@ inline bool GenTree::IsVectorAllBitsSet() const
92389239
return false;
92399240
}
92409241

9242+
inline bool GenTree::IsMaskAllBitsSet() const
9243+
{
9244+
#ifdef TARGET_ARM64
9245+
static_assert_no_msg(AreContiguous(NI_Sve_CreateTrueMaskByte, NI_Sve_CreateTrueMaskDouble,
9246+
NI_Sve_CreateTrueMaskInt16, NI_Sve_CreateTrueMaskInt32,
9247+
NI_Sve_CreateTrueMaskInt64, NI_Sve_CreateTrueMaskSByte,
9248+
NI_Sve_CreateTrueMaskSingle, NI_Sve_CreateTrueMaskUInt16,
9249+
NI_Sve_CreateTrueMaskUInt32, NI_Sve_CreateTrueMaskUInt64));
9250+
9251+
if (OperIsHWIntrinsic())
9252+
{
9253+
NamedIntrinsic id = AsHWIntrinsic()->GetHWIntrinsicId();
9254+
if (id == NI_Sve_ConvertMaskToVector)
9255+
{
9256+
GenTree* op1 = AsHWIntrinsic()->Op(1);
9257+
assert(op1->OperIsHWIntrinsic());
9258+
id = op1->AsHWIntrinsic()->GetHWIntrinsicId();
9259+
}
9260+
return ((id == NI_Sve_CreateTrueMaskAll) ||
9261+
((id >= NI_Sve_CreateTrueMaskByte) && (id <= NI_Sve_CreateTrueMaskUInt64)));
9262+
}
9263+
9264+
#endif
9265+
return false;
9266+
}
9267+
92419268
//-------------------------------------------------------------------
92429269
// IsVectorConst: returns true if this node is a HWIntrinsic that represents a constant.
92439270
//

src/coreclr/jit/hwintrinsic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
16221622
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
16231623
if (intrinsic == NI_Sve_ConditionalSelect)
16241624
{
1625-
if (op1->IsVectorAllBitsSet())
1625+
if (op1->IsVectorAllBitsSet() || op1->IsMaskAllBitsSet())
16261626
{
16271627
return retNode->AsHWIntrinsic()->Op(2);
16281628
}

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
406406
// Handle case where op2 is operation that needs embedded mask
407407
GenTree* op2 = intrin.op2;
408408
assert(intrin.id == NI_Sve_ConditionalSelect);
409-
assert(op2->isContained());
410409
assert(op2->OperIsHWIntrinsic());
410+
assert(op2->isContained());
411411

412412
// Get the registers and intrinsics that needs embedded mask
413413
const HWIntrinsic intrinEmbMask(op2->AsHWIntrinsic());
@@ -439,10 +439,54 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
439439
{
440440
case 1:
441441
assert(!instrIsRMW);
442+
442443
if (targetReg != falseReg)
443444
{
444-
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg);
445+
// If targetReg is not the same as `falseReg` then need to move
446+
// the `falseReg` to `targetReg`.
447+
448+
if (intrin.op3->isContained())
449+
{
450+
assert(intrin.op3->IsVectorZero());
451+
if (intrin.op1->isContained())
452+
{
453+
// We already skip importing ConditionalSelect if op1 == trueAll, however
454+
// if we still see it here, it is because we wrapped the predicated instruction
455+
// inside ConditionalSelect.
456+
// As such, no need to move the `falseReg` to `targetReg`
457+
// because the predicated instruction will eventually set it.
458+
assert(intrin.op1->IsMaskAllBitsSet());
459+
}
460+
else
461+
{
462+
// If falseValue is zero, just zero out those lanes of targetReg using `movprfx`
463+
// and /Z
464+
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, targetReg,
465+
opt);
466+
}
467+
}
468+
else if (targetReg == embMaskOp1Reg)
469+
{
470+
// target != falseValue, but we do not want to overwrite target with `embMaskOp1Reg`.
471+
// We will first do the predicate operation and then do conditionalSelect inactive
472+
// elements from falseValue
473+
474+
// We cannot use use `movprfx` here to move falseReg to targetReg because that will
475+
// overwrite the value of embMaskOp1Reg which is present in targetReg.
476+
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
477+
478+
GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
479+
falseReg, opt, INS_SCALABLE_OPTS_UNPREDICATED);
480+
break;
481+
}
482+
else
483+
{
484+
// At this point, target != embMaskOp1Reg != falseReg, so just go ahead
485+
// and move the falseReg unpredicated into targetReg.
486+
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg);
487+
}
445488
}
489+
446490
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
447491
break;
448492

0 commit comments

Comments
 (0)