Reland "[RISCV] Refactor X60 scheduling model helper classes. NFC." (#152336)

This PR fixes the issue that caused undefined behavior (UB) in PR #151472.

The issue was a shl call taking a negative shift amount (posDiff). The
result was never used, but tablegen would perform the calculation
anyway. The fix was to replace the shl call with just multiplications
with constants.

Original PR description:

This patch improves the helper classes in the SpacemiT-X60 vector
scheduling model. The new helpers will be used in follow-up PRs.

There are now two functions to map LMUL to values:
* ConstValueUntilLMULThenDoubleBase: returns BaseValue for LMUL values
before startLMUL, Value for startLMUL, then doubles Value for each
subsequent LMUL. Useful for cases where fractional LMULs have constant
cycles, and integer LMULs double as they increase.
* GetLMULValue: takes a list of cycle counts ordered by LMUL, plus an LMUL,
and returns the cycle count for that LMUL. Useful for cases we can't easily
cover with ConstValueUntilLMULThenDoubleBase.

This PR also adds some useful simplified versions of
ConstValueUntilLMULThenDoubleBase, e.g.: ConstValueUntilLMULThenDouble
(when BaseValue == Value), or ConstOneUntilMF4ThenDouble (when cycles
are 1 up to MF4 and double for each LMUL after that, i.e. starting at MF2).
This commit is contained in:
Mikhail R. Gadelha 2025-08-07 16:27:46 +02:00 committed by GitHub
parent 82f5bd68d0
commit f3db0cb4d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,6 +13,99 @@
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Helpers
// Looks up the entry of Values that belongs to the given LMUL string.
// Values is ordered from the smallest to the largest LMUL:
//   index: 0    1    2    3   4   5   6
//   LMUL:  MF8  MF4  MF2  M1  M2  M4  M8
// Values may hold fewer than seven entries (e.g., widening instructions
// don't work on M8). Requesting an LMUL whose index lies beyond the end of
// Values trips the assert below.
class GetLMULValue<list<int> Values, string LMUL> {
  // Position of LMUL in the MF8..M8 ordering; there is no default case.
  defvar Pos = !cond(
    !eq(LMUL, "MF8"): 0,
    !eq(LMUL, "MF4"): 1,
    !eq(LMUL, "MF2"): 2,
    !eq(LMUL, "M1"): 3,
    !eq(LMUL, "M2"): 4,
    !eq(LMUL, "M4"): 5,
    !eq(LMUL, "M8"): 6
  );
  // The list must be long enough to contain the requested position.
  assert !gt(!size(Values), Pos),
         "Missing LMUL value for '" # LMUL # "'. " #
         "Expected at least " # !add(Pos, 1) # " elements, but got " #
         !size(Values) # ".";
  int c = Values[Pos];
}
// Yields BaseValue when currentLMUL comes before startLMUL, Value when the
// two are equal, and Value doubled once for every LMUL step that currentLMUL
// lies past startLMUL.
// Example: with startLMUL = "M1", BaseValue = 2 and Value = 4, the result
// for each possible currentLMUL is:
// MF8->2, MF4->2, MF2->2, M1->4, M2->8, M4->16, M8->32
// This is useful for modeling scheduling parameters that scale with LMUL.
class ConstValueUntilLMULThenDoubleBase<string startLMUL, int BaseValue, int Value, string currentLMUL> {
  assert !ge(Value, BaseValue), "BaseValue must be less-equal to Value";
  // Translate both LMUL strings to their 0..6 position (MF8 = 0 ... M8 = 6).
  defvar StartIdx = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], startLMUL>.c;
  defvar CurIdx = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], currentLMUL>.c;
  // Number of LMUL steps past startLMUL; negative when currentLMUL is smaller.
  defvar Steps = !sub(CurIdx, StartIdx);
  // 2^Steps picked from constants (no shift operator is applied to Steps,
  // which may be negative); 0 marks "currentLMUL is before startLMUL".
  defvar Factor = !cond(
    !eq(Steps, 0) : 1,
    !eq(Steps, 1) : 2,
    !eq(Steps, 2) : 4,
    !eq(Steps, 3) : 8,
    !eq(Steps, 4) : 16,
    !eq(Steps, 5) : 32,
    !eq(Steps, 6) : 64,
    true : 0
  );
  int c = !if(!eq(Factor, 0), BaseValue, !mul(Value, Factor));
}
// ConstValueUntilLMULThenDoubleBase with BaseValue equal to Value: Value is
// used for every LMUL up to and including startLMUL, and is doubled for each
// LMUL after it.
class ConstValueUntilLMULThenDouble<string startLMUL, int Value, string currentLMUL> {
  defvar Cycles =
      ConstValueUntilLMULThenDoubleBase<startLMUL, Value, Value, currentLMUL>.c;
  int c = Cycles;
}
// 1 up to and including MF4, doubled for every LMUL after that:
// MF8->1, MF4->1, MF2->2, M1->4, M2->8, M4->16, M8->32
class ConstOneUntilMF4ThenDouble<string mx> {
  int c = ConstValueUntilLMULThenDoubleBase<"MF4", 1, 1, mx>.c;
}
// 1 up to and including MF2, doubled for every LMUL after that:
// MF8->1, MF4->1, MF2->1, M1->2, M2->4, M4->8, M8->16
class ConstOneUntilMF2ThenDouble<string mx> {
  int c = ConstValueUntilLMULThenDoubleBase<"MF2", 1, 1, mx>.c;
}
// 1 up to and including M1, doubled for every LMUL after that:
// MF8->1, MF4->1, MF2->1, M1->1, M2->2, M4->4, M8->8
class ConstOneUntilM1ThenDouble<string mx> {
  int c = ConstValueUntilLMULThenDoubleBase<"M1", 1, 1, mx>.c;
}
//===----------------------------------------------------------------------===//
// Latency helper classes
// Latency table 4/4/5/8 for M1/M2/M4/M8; fractional LMULs use 4.
// Used for: arithmetic (add/sub/min/max), saturating/averaging, FP add/sub/min/max
class Get4458Latency<string mx> {
  defvar PerLMUL = [/*MF8=*/4, /*MF4=*/4, /*MF2=*/4,
                    /*M1=*/4, /*M2=*/4, /*M4=*/5, /*M8=*/8];
  int c = GetLMULValue<PerLMUL, mx>.c;
}
// Latency table 4/5/8 for M1/M2/M4; fractional LMULs use 4.
// The table has no M8 entry, so GetLMULValue's assert fires for mx == "M8".
// Used for: widening operations (no M8)
class Get4588Latency<string mx> {
  defvar PerLMUL = [/*MF8=*/4, /*MF4=*/4, /*MF2=*/4,
                    /*M1=*/4, /*M2=*/5, /*M4=*/8];
  int c = GetLMULValue<PerLMUL, mx>.c;
}
// Latency table 4/6/10/18 for M1/M2/M4/M8; fractional LMULs use 4.
// Used for: mask-producing comparisons, carry ops with mask, FP comparisons
class Get461018Latency<string mx> {
  defvar PerLMUL = [/*MF8=*/4, /*MF4=*/4, /*MF2=*/4,
                    /*M1=*/4, /*M2=*/6, /*M4=*/10, /*M8=*/18];
  int c = GetLMULValue<PerLMUL, mx>.c;
}
//===----------------------------------------------------------------------===//
class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
string LLMUL = LargestLMUL<MxList>.r;
bit c = !eq(mx, LLMUL);
@ -27,64 +120,6 @@ class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0
defvar SMX60VLEN = 256;
defvar SMX60DLEN = !div(SMX60VLEN, 2);
// Latency 2/4/8 for M2/M4/M8, and 1 for any other mx.
class Get1248Latency<string mx> {
  int c = !if(!eq(mx, "M2"), 2,
          !if(!eq(mx, "M4"), 4,
          !if(!eq(mx, "M8"), 8,
              1)));
}
// Latency 8/16 for M4/M8, and 4 for any other mx.
// Used for: logical ops, shifts, sign ext, merge/move, FP sign/recip/convert, mask ops, slides
class Get4816Latency<string mx> {
  int c = !if(!eq(mx, "M4"), 8,
          !if(!eq(mx, "M8"), 16,
              4));
}
// Latency 5/8 for M4/M8, and 4 for any other mx.
// Used for: arithmetic (add/sub/min/max), saturating/averaging, FP add/sub/min/max
class Get458Latency<string mx> {
  int c = !if(!eq(mx, "M4"), 5,
          !if(!eq(mx, "M8"), 8,
              4));
}
// Widening scaling pattern (4,4,4,4,5,8,8): plateaus at higher LMULs.
// Latency 5/8/8 for M2/M4/M8, and 4 for any other mx.
// Used for: widening operations
class Get4588Latency<string mx> {
  int c = !if(!eq(mx, "M2"), 5,
          !if(!eq(mx, "M4"), 8,
          // M8 not supported for most widening, fallback
          !if(!eq(mx, "M8"), 8,
              4)));
}
// Latency 6/10/18 for M2/M4/M8, and 4 for any other mx.
// Used for: mask-producing comparisons, carry ops with mask, FP comparisons
class Get461018Latency<string mx> {
  int c = !if(!eq(mx, "M2"), 6,
          !if(!eq(mx, "M4"), 10,
          !if(!eq(mx, "M8"), 18,
              4)));
}
// Latency 8/16/32 for M2/M4/M8, and 7 for any other mx.
// Used for: e64 multiply pattern, complex ops
class Get781632Latency<string mx> {
  int c = !if(!eq(mx, "M2"), 8,
          !if(!eq(mx, "M4"), 16,
          !if(!eq(mx, "M8"), 32,
              7)));
}
def SpacemitX60Model : SchedMachineModel {
let IssueWidth = 2; // dual-issue
let MicroOpBufferSize = 0; // in-order
@ -383,12 +418,13 @@ foreach LMul = [1, 2, 4, 8] in {
foreach mx = SchedMxList in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxList>.c;
let Latency = Get458Latency<mx>.c, ReleaseAtCycles = [4] in {
let Latency = Get4458Latency<mx>.c, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVIMinMaxV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMinMaxX", [SMX60_VIEU], mx, IsWorstCase>;
}
let Latency = Get4816Latency<mx>.c, ReleaseAtCycles = [4] in {
defvar VIALULat = ConstValueUntilLMULThenDouble<"M2", 4, mx>.c;
let Latency = VIALULat, ReleaseAtCycles = [4] in {
// Pattern of vadd, vsub, vrsub: 4/4/5/8
// Pattern of vand, vor, vxor: 4/4/8/16
// They are grouped together, so we used the worst case 4/4/8/16
@ -425,7 +461,7 @@ foreach mx = SchedMxList in {
// Pattern of vmacc, vmadd, vmul, vmulh, etc.: e8/e16 = 4/4/5/8, e32 = 5,5,5,8,
// e64 = 7,8,16,32. We use the worst-case until we can split the SEW.
// TODO: change WriteVIMulV, etc to be defined with LMULSEWSchedWrites
let Latency = Get781632Latency<mx>.c, ReleaseAtCycles = [7] in {
let Latency = ConstValueUntilLMULThenDoubleBase<"M2", 7, 8, mx>.c, ReleaseAtCycles = [7] in {
defm "" : LMULWriteResMX<"WriteVIMulV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulAddV", [SMX60_VIEU], mx, IsWorstCase>;
@ -461,15 +497,8 @@ foreach mx = SchedMxList in {
foreach sew = SchedSEWSet<mx>.val in {
defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxList>.c;
// Slightly reduced for fractional LMULs
defvar Multiplier = !cond(
!eq(mx, "MF8") : 12,
!eq(mx, "MF4") : 12,
!eq(mx, "MF2") : 12,
true: 24
);
let Latency = !mul(Get1248Latency<mx>.c, Multiplier), ReleaseAtCycles = [12] in {
defvar VIDivLat = ConstValueUntilLMULThenDouble<"MF2", 12, mx>.c;
let Latency = VIDivLat, ReleaseAtCycles = [12] in {
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivV", [SMX60_VIEU], mx, sew, IsWorstCase>;
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivX", [SMX60_VIEU], mx, sew, IsWorstCase>;
}
@ -480,14 +509,8 @@ foreach mx = SchedMxList in {
foreach mx = SchedMxListW in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxListW>.c;
// Slightly increased for integer LMULs
defvar Multiplier = !cond(
!eq(mx, "M2") : 2,
!eq(mx, "M4") : 2,
true: 1
);
let Latency = !mul(Get4816Latency<mx>.c, Multiplier), ReleaseAtCycles = [4] in {
defvar VNarrowingLat = ConstValueUntilLMULThenDouble<"M1", 4, mx>.c;
let Latency = VNarrowingLat, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVNShiftV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftI", [SMX60_VIEU], mx, IsWorstCase>;