[mlir] Fix new clang-tidy warning llvm-type-switch-case-types. NFC. (#178487)

Pre-commiting this before landing the new check in
https://github.com/llvm/llvm-project/pull/177892
This commit is contained in:
Jakub Kuderski 2026-01-28 14:13:47 -05:00 committed by GitHub
parent 37e93811a6
commit 59e44799bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
48 changed files with 267 additions and 277 deletions

View File

@ -424,11 +424,11 @@ static Value getOriginalVectorValue(Value value) {
Value current = value;
while (Operation *definingOp = current.getDefiningOp()) {
bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
.Case<vector::ShapeCastOp>([&current](auto op) {
.Case([&current](vector::ShapeCastOp op) {
current = op.getSource();
return true;
})
.Case<vector::BroadcastOp>([&current](auto op) {
.Case([&current](vector::BroadcastOp op) {
current = op.getSource();
return false;
})

View File

@ -267,10 +267,10 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
static std::optional<StringRef> getTypeMangling(Type type) {
return TypeSwitch<Type, std::optional<StringRef>>(type)
.Case<Float16Type>([](auto) { return "Dhj"; })
.Case<Float32Type>([](auto) { return "fj"; })
.Case<Float64Type>([](auto) { return "dj"; })
.Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
.Case([](Float16Type) { return "Dhj"; })
.Case([](Float32Type) { return "fj"; })
.Case([](Float64Type) { return "dj"; })
.Case([](IntegerType intTy) -> std::optional<StringRef> {
switch (intTy.getWidth()) {
case 8:
return "cj";

View File

@ -812,7 +812,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() {
module.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
.Case([&](math::CountLeadingZerosOp op) {
if (!convertCtlz || !isConvertible(op))
return;
Type resultType = getElementTypeOrSelf(op.getResult().getType());
@ -824,7 +824,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() {
if (entry.second)
entry.first->second = createCtlzFunc(&module, resultType);
})
.Case<math::IPowIOp>([&](math::IPowIOp op) {
.Case([&](math::IPowIOp op) {
if (!isConvertible(op))
return;
@ -837,7 +837,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() {
if (entry.second)
entry.first->second = createElementIPowIFunc(&module, resultType);
})
.Case<math::FPowIOp>([&](math::FPowIOp op) {
.Case([&](math::FPowIOp op) {
if (!isFPowIConvertible(op))
return;

View File

@ -350,7 +350,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
})
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
.Case([&](pdl::ApplyNativeConstraintOp constraintOp) {
getConstraintPredicates(constraintOp, predList, builder, inputs);
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
@ -471,7 +471,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
// We intentionally do not traverse attributes and types, because those
// are expensive to join on.
TypeSwitch<Operation *>(entry.value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
.Case([&](pdl::OperationOp operationOp) {
OperandRange operands = operationOp.getOperandValues();
// Special case when we pass all the operands in one range.
// For those, the index is empty.
@ -544,7 +544,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
Position *&pos, unsigned rootID) {
Value value = opIndex.parent;
TypeSwitch<Operation *>(value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
.Case([&](pdl::OperationOp operationOp) {
LDBG() << " * Value: " << value;
// Get users and iterate over them.
@ -583,7 +583,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
// Update the position
pos = opPos;
})
.Case<pdl::ResultOp>([&](auto resultOp) {
.Case([&](pdl::ResultOp resultOp) {
// Traverse up an individual result.
auto *opPos = dyn_cast<OperationPosition>(pos);
assert(opPos && "operations and results must be interleaved");
@ -592,7 +592,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
// Insert the result position in case we have not visited it yet.
valueToPosition.try_emplace(value, pos);
})
.Case<pdl::ResultsOp>([&](auto resultOp) {
.Case([&](pdl::ResultsOp resultOp) {
// Traverse up a group of results.
auto *opPos = dyn_cast<OperationPosition>(pos);
assert(opPos && "operations and results must be interleaved");

View File

@ -1099,10 +1099,10 @@ namespace {
StringRef getTypeMangling(Type type, bool isSigned) {
return llvm::TypeSwitch<Type, StringRef>(type)
.Case<Float16Type>([](auto) { return "Dh"; })
.Case<Float32Type>([](auto) { return "f"; })
.Case<Float64Type>([](auto) { return "d"; })
.Case<IntegerType>([isSigned](IntegerType intTy) {
.Case([](Float16Type) { return "Dh"; })
.Case([](Float32Type) { return "f"; })
.Case([](Float64Type) { return "d"; })
.Case([isSigned](IntegerType intTy) {
switch (intTy.getWidth()) {
case 1:
return "b";

View File

@ -42,10 +42,8 @@ static bool isZeroConstant(Value val) {
return false;
return TypeSwitch<Attribute, bool>(constant.getValue())
.Case<FloatAttr>(
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
.Case<IntegerAttr>(
[](auto intAttr) { return intAttr.getValue().isZero(); })
.Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
.Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
.Default(false);
}

View File

@ -37,7 +37,7 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
return failure();
}
return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
.Case([&](memref::SubViewOp subviewOp) {
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, loc, subviewOp.getMixedOffsets(),
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
@ -45,19 +45,18 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
memrefBase = subviewOp.getSource();
return success();
})
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
.Case([&](memref::ExpandShapeOp expandShapeOp) {
mlir::memref::resolveSourceIndicesExpandShape(
loc, rewriter, expandShapeOp, indices, resolvedIndices, false);
memrefBase = expandShapeOp.getViewSource();
return success();
})
.Case<memref::CollapseShapeOp>(
[&](memref::CollapseShapeOp collapseShapeOp) {
mlir::memref::resolveSourceIndicesCollapseShape(
loc, rewriter, collapseShapeOp, indices, resolvedIndices);
memrefBase = collapseShapeOp.getViewSource();
return success();
})
.Case([&](memref::CollapseShapeOp collapseShapeOp) {
mlir::memref::resolveSourceIndicesCollapseShape(
loc, rewriter, collapseShapeOp, indices, resolvedIndices);
memrefBase = collapseShapeOp.getViewSource();
return success();
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(
op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "

View File

@ -155,17 +155,17 @@ public:
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
TypeSwitch<Operation *>(extOp)
.Case<arith::ExtFOp>([&](auto) {
.Case([&](arith::ExtFOp) {
rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
.Case<arith::ExtSIOp>([&](auto) {
.Case([&](arith::ExtSIOp) {
rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
.Case<arith::ExtUIOp>([&](auto) {
.Case([&](arith::ExtUIOp) {
rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
@ -173,17 +173,17 @@ public:
.DefaultUnreachable("unexpected extend op!");
} else if (kind == arm_sme::CombiningKind::Sub) {
TypeSwitch<Operation *>(extOp)
.Case<arith::ExtFOp>([&](auto) {
.Case([&](arith::ExtFOp) {
rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
.Case<arith::ExtSIOp>([&](auto) {
.Case([&](arith::ExtSIOp) {
rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
.Case<arith::ExtUIOp>([&](auto) {
.Case([&](arith::ExtUIOp) {
rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());

View File

@ -417,11 +417,11 @@ static void forEachPredecessorTileValue(BlockArgument blockArg,
unsigned argNumber = blockArg.getArgNumber();
for (Block *pred : block->getPredecessors()) {
TypeSwitch<Operation *>(pred->getTerminator())
.Case<cf::BranchOp>([&](auto branch) {
.Case([&](cf::BranchOp branch) {
Value predecessorOperand = branch.getDestOperands()[argNumber];
callback(predecessorOperand);
})
.Case<cf::CondBranchOp>([&](auto condBranch) {
.Case([&](cf::CondBranchOp condBranch) {
if (condBranch.getFalseDest() == block) {
Value predecessorOperand =
condBranch.getFalseDestOperands()[argNumber];

View File

@ -368,7 +368,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
.Case<SparseSpGEMMOpHandleType>([&](Type) {
os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
})
.Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
.Case([&](MMAMatrixType fragTy) {
os << "mma_matrix<";
auto shape = fragTy.getShape();
for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)

View File

@ -243,7 +243,7 @@ private:
SmallVector<Value, 1> tokens;
tokens.reserve(asyncTokens.size());
TypeSwitch<Operation *>(op)
.Case<async::AwaitOp>([&](auto awaitOp) {
.Case([&](async::AwaitOp awaitOp) {
// Add async.await ops to wait for the !gpu.async.tokens.
builder.setInsertionPointAfter(op);
for (auto asyncToken : asyncTokens)
@ -252,7 +252,7 @@ private:
// Set `it` after the inserted async.await ops.
it = builder.getInsertionPoint();
})
.Case<async::ExecuteOp>([&](auto executeOp) {
.Case([&](async::ExecuteOp executeOp) {
// Set `it` to the beginning of the region and add asyncTokens to the
// async.execute operands.
it = executeOp.getBody()->begin();

View File

@ -322,7 +322,7 @@ static Value getBase(Value v) {
v = op.getSource();
return true;
})
.Case<memref::TransposeOp>([&](auto op) {
.Case([&](memref::TransposeOp op) {
v = op.getIn();
return true;
})

View File

@ -143,12 +143,9 @@ LogicalResult OperationOp::verifyRegions() {
for (Operation &op : getBody().getOps()) {
TypeSwitch<Operation *>(&op)
.Case<OperandsOp>(
[&](OperandsOp op) { insertNames("operands", op.getNames()); })
.Case<ResultsOp>(
[&](ResultsOp op) { insertNames("results", op.getNames()); })
.Case<RegionsOp>(
[&](RegionsOp op) { insertNames("regions", op.getNames()); });
.Case([&](OperandsOp op) { insertNames("operands", op.getNames()); })
.Case([&](ResultsOp op) { insertNames("results", op.getNames()); })
.Case([&](RegionsOp op) { insertNames("regions", op.getNames()); });
}
// Verify that no two operand, result or region share the same name.

View File

@ -779,7 +779,7 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos,
return success();
return TypeSwitch<Type, LogicalResult>(baseGEPType)
.Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
.Case([&](LLVMStructType structType) -> LogicalResult {
auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
if (!attr)
return emitOpError() << "expected index " << indexPos
@ -3253,13 +3253,13 @@ LogicalResult LLVMFuncOp::verify() {
return WalkResult::advance();
};
return TypeSwitch<Operation *, WalkResult>(op)
.Case<LandingpadOp>([&](auto landingpad) {
.Case([&](LandingpadOp landingpad) {
constexpr StringLiteral errorMessage =
"'llvm.landingpad' should have a consistent result type "
"inside a function";
return checkType(landingpad.getType(), errorMessage);
})
.Case<ResumeOp>([&](auto resume) {
.Case([&](ResumeOp resume) {
constexpr StringLiteral errorMessage =
"'llvm.resume' should have a consistent input type inside a "
"function";

View File

@ -763,26 +763,24 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
bool result =
llvm::TypeSwitch<Type, bool>(type)
.Case<LLVMStructType>([&](auto structType) {
.Case([&](LLVMStructType structType) {
return llvm::all_of(structType.getBody(), isCompatible);
})
.Case<LLVMFunctionType>([&](auto funcType) {
.Case([&](LLVMFunctionType funcType) {
return isCompatible(funcType.getReturnType()) &&
llvm::all_of(funcType.getParams(), isCompatible);
})
.Case<IntegerType>([](auto intType) { return intType.isSignless(); })
.Case<VectorType>([&](auto vecType) {
.Case([](IntegerType intType) { return intType.isSignless(); })
.Case([&](VectorType vecType) {
return vecType.getRank() == 1 &&
isCompatible(vecType.getElementType());
})
.Case<LLVMPointerType>([&](auto pointerType) { return true; })
.Case<LLVMTargetExtType>([&](auto extType) {
.Case([&](LLVMPointerType pointerType) { return true; })
.Case([&](LLVMTargetExtType extType) {
return llvm::all_of(extType.getTypeParams(), isCompatible);
})
// clang-format off
.Case<
LLVMArrayType
>([&](auto containerType) {
.Case([&](LLVMArrayType containerType) {
return isCompatible(containerType.getElementType());
})
.Case<
@ -895,12 +893,12 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
.Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
.Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
.Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
.Case<IntegerType>([](IntegerType intTy) {
.Case([](IntegerType intTy) {
return llvm::TypeSize::getFixed(intTy.getWidth());
})
.Case<LLVMPPCFP128Type>(
[](Type) { return llvm::TypeSize::getFixed(128); })
.Case<VectorType>([](VectorType t) {
.Case([](VectorType t) {
assert(isCompatibleVectorType(t) &&
"unexpected incompatible with LLVM vector type");
llvm::TypeSize elementSize =

View File

@ -4117,10 +4117,10 @@ ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
bool hasRelu) {
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
.Case([&](mlir::Float6E2M3FNType) {
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
})
.Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
.Case([&](mlir::Float6E3M2FNType) {
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
})
.Default([](mlir::Type) {
@ -4145,13 +4145,13 @@ ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
.Case([&](mlir::Float8E4M3FNType) {
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
.Case([&](mlir::Float8E5M2Type) {
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
.Case([&](mlir::Float8E8M0FNUType) {
if (hasRoundingModeRZ)
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
else if (hasRoundingModeRP)
@ -4172,10 +4172,10 @@ ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
bool hasRelu) {
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
.Case([&](mlir::Float8E4M3FNType) {
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
.Case([&](mlir::Float8E5M2Type) {
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
})
.Default([](mlir::Type) {
@ -4210,11 +4210,11 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
.Case([&](Float8E4M3FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
})
.Case<Float8E5M2Type>([&](Float8E5M2Type type) {
.Case([&](Float8E5M2Type type) {
return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
})
@ -4250,11 +4250,11 @@ NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
.Case([&](Float6E2M3FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
})
.Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
.Case([&](Float6E3M2FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
})
@ -4278,7 +4278,7 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
.Case([&](Float4E2M1FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
})
@ -4483,11 +4483,11 @@ llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
bool hasRelu = getRelu();
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
.Case([&](mlir::Float8E4M3FNType) {
return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
: llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
.Case([&](mlir::Float8E5M2Type) {
return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
: llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
})
@ -4502,11 +4502,11 @@ llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
bool hasRelu = getRelu();
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
.Case([&](mlir::Float6E2M3FNType) {
return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
: llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
})
.Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
.Case([&](mlir::Float6E3M2FNType) {
return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
: llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
})
@ -4521,7 +4521,7 @@ llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
bool hasRelu = getRelu();
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
.Case([&](mlir::Float4E2M1FNType) {
return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
: llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
})

View File

@ -63,10 +63,10 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
return getAsOpFoldResult(
TypeSwitch<Type, Value>(v.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
.Case([&](RankedTensorType t) -> Value {
return tensor::DimOp::create(builder, loc, v, dim);
})
.Case<MemRefType>([&](MemRefType t) -> Value {
.Case([&](MemRefType t) -> Value {
return memref::DimOp::create(builder, loc, v, dim);
}));
}
@ -78,11 +78,11 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
return TypeSwitch<Type, Operation *>(source.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
.Case([&](RankedTensorType t) -> Operation * {
return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
strides);
})
.Case<MemRefType>([&](MemRefType type) -> Operation * {
.Case([&](MemRefType type) -> Operation * {
return memref::SubViewOp::create(b, loc, source, offsets, sizes,
strides);
})

View File

@ -840,7 +840,7 @@ static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
ExpansionInfo &expansionInfo) {
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
.Case([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp,
expandedOpOperands[0], outputs[0],
expansionInfo);

View File

@ -76,13 +76,13 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
SmallVector<NamedAttribute> preservedAttrs;
Operation *newConv =
TypeSwitch<Operation *, Operation *>(operation)
.Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
.Case([&](DepthwiseConv2DNhwcHwcmOp op) {
preservedAttrs = getPrunedAttributeList(op);
return DepthwiseConv2DNhwcHwcOp::create(
rewriter, loc, newInitTy, ValueRange{input, collapsedKernel},
ValueRange{collapsedInit}, stride, dilation);
})
.Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
.Case([&](DepthwiseConv2DNhwcHwcmQOp op) {
preservedAttrs = getPrunedAttributeList(op);
return DepthwiseConv2DNhwcHwcQOp::create(
rewriter, loc, newInitTy,

View File

@ -56,7 +56,7 @@ using namespace mlir::linalg;
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
Operation *op) {
return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
.Case<scf::ForOp>([&](scf::ForOp forOp) {
.Case([&](scf::ForOp forOp) {
scf::ForOp partialIteration;
if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
partialIteration)))

View File

@ -644,19 +644,19 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
.Case<arith::AddIOp, arith::AddFOp>(
[&](auto op) { return CombiningKind::ADD; })
.Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
.Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
.Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
.Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
.Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
.Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
.Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
.Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
.Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
.Case([&](arith::AndIOp op) { return CombiningKind::AND; })
.Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
.Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
.Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
.Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
.Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
.Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
.Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
.Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
.Case<arith::MulIOp, arith::MulFOp>(
[&](auto op) { return CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
.Case([&](arith::OrIOp op) { return CombiningKind::OR; })
.Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
.Default(std::nullopt);
}
@ -2684,21 +2684,21 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
return failure();
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
.Case([&](linalg::LinalgOp linalgOp) {
return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
vectorizeNDExtract,
flatten1DDepthwiseConv);
})
.Case<tensor::PadOp>([&](auto padOp) {
.Case([&](tensor::PadOp padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
.Case<linalg::PackOp>([&](auto packOp) {
.Case([&](linalg::PackOp packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
.Case<linalg::UnPackOp>([&](auto unpackOp) {
.Case([&](linalg::UnPackOp unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
.Case([&](tensor::InsertSliceOp sliceOp) {
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
})
.Default(failure());
@ -2755,7 +2755,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
SmallVector<Value> results;
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
.Case([&](linalg::LinalgOp linalgOp) {
// Check for both named as well as generic convolution ops.
if (isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
@ -2789,20 +2789,20 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
// notified and we will end up with read-after-free issues!
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
})
.Case<tensor::PadOp>([&](auto padOp) {
.Case([&](tensor::PadOp padOp) {
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
.Case<linalg::PackOp>([&](auto packOp) {
.Case([&](linalg::PackOp packOp) {
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Case<linalg::UnPackOp>([&](auto unpackOp) {
.Case([&](linalg::UnPackOp unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes,
inputScalableVecDims, results);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
.Case([&](tensor::InsertSliceOp sliceOp) {
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
results);
})

View File

@ -124,67 +124,67 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
Value offset) {
Location loc = op->getLoc();
llvm::TypeSwitch<Operation *>(op.getOperation())
.template Case<memref::AllocOp>([&](auto oper) {
.Case([&](memref::AllocOp oper) {
auto newAlloc = memref::AllocOp::create(
rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloc, loc, rewriter);
})
.template Case<memref::AllocaOp>([&](auto oper) {
.Case([&](memref::AllocaOp oper) {
auto newAlloca = memref::AllocaOp::create(
rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloca, loc, rewriter);
})
.template Case<memref::LoadOp>([&](auto op) {
.Case([&](memref::LoadOp op) {
auto newLoad =
memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<memref::StoreOp>([&](auto op) {
.Case([&](memref::StoreOp op) {
auto newStore =
memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::LoadOp>([&](auto op) {
.Case([&](vector::LoadOp op) {
auto newLoad =
vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<vector::StoreOp>([&](auto op) {
.Case([&](vector::StoreOp op) {
auto newStore =
vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::MaskedLoadOp>([&](auto op) {
.Case([&](vector::MaskedLoadOp op) {
auto newMaskedLoad = vector::MaskedLoadOp::create(
rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
op.getMask(), op.getPassThru());
newMaskedLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedLoad.getResult());
})
.template Case<vector::MaskedStoreOp>([&](auto op) {
.Case([&](vector::MaskedStoreOp op) {
auto newMaskedStore = vector::MaskedStoreOp::create(
rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
op.getValueToStore());
newMaskedStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedStore);
})
.template Case<vector::TransferReadOp>([&](auto op) {
.Case([&](vector::TransferReadOp op) {
auto newTransferRead = vector::TransferReadOp::create(
rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
op.getPadding());
rewriter.replaceOp(op, newTransferRead.getResult());
})
.template Case<vector::TransferWriteOp>([&](auto op) {
.Case([&](vector::TransferWriteOp op) {
auto newTransferWrite = vector::TransferWriteOp::create(
rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);

View File

@ -359,7 +359,7 @@ public:
collectGlobalsFromDeviceRegion(
accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
})
.Case<FunctionOpInterface>([&](auto func) {
.Case([&](FunctionOpInterface func) {
if ((acc::isAccRoutine(func) ||
acc::isSpecializedAccRoutine(func)) &&
!func.isExternal())
@ -367,13 +367,13 @@ public:
globalsToAccDeclare, accSupport,
symTab);
})
.Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) {
.Case([&](acc::GlobalVariableOpInterface globalVarOp) {
if (globalVarOp->getAttr(acc::getDeclareAttrName()))
if (Region *initRegion = globalVarOp.getInitRegion())
collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare,
accSupport, symTab);
})
.Case<acc::PrivateRecipeOp>([&](auto privateRecipe) {
.Case([&](acc::PrivateRecipeOp privateRecipe) {
if (hasRelevantRecipeUse(privateRecipe, mod)) {
collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(),
globalsToAccDeclare, accSupport,
@ -383,7 +383,7 @@ public:
symTab);
}
})
.Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) {
.Case([&](acc::FirstprivateRecipeOp firstprivateRecipe) {
if (hasRelevantRecipeUse(firstprivateRecipe, mod)) {
collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(),
globalsToAccDeclare, accSupport,
@ -396,7 +396,7 @@ public:
symTab);
}
})
.Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) {
.Case([&](acc::ReductionRecipeOp reductionRecipe) {
if (hasRelevantRecipeUse(reductionRecipe, mod)) {
collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(),
globalsToAccDeclare, accSupport,

View File

@ -70,7 +70,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
// Traverse the operands / parent.
TypeSwitch<Operation *>(op)
.Case<OperationOp>([&visited](auto operation) {
.Case([&visited](OperationOp operation) {
for (Value operand : operation.getOperandValues())
visit(operand.getDefiningOp(), visited);
})

View File

@ -48,9 +48,8 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
using tensor::InsertSliceOp;
value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
.template Case<InsertSliceOp>(
[&](InsertSliceOp op) { return op.getDest(); })
.template Case<ForOp>([&](ForOp forOp) {
.Case([&](InsertSliceOp op) { return op.getDest(); })
.Case([&](ForOp forOp) {
return isShapePreserving(forOp, opResult.getResultNumber())
? forOp.getInitArgs()[opResult.getResultNumber()]
: Value();

View File

@ -341,11 +341,11 @@ LogicalResult spirv::CompositeConstructOp::verify() {
// 3. Arrays (1 constituent for each array element)
// 4. Vectors (1 constituent (sub-)element for each vector element)
auto coopElementType =
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType>(
[](auto coopType) { return coopType.getElementType(); })
.Default(nullptr);
auto coopElementType = llvm::TypeSwitch<Type, Type>(getType())
.Case([](spirv::CooperativeMatrixType coopType) {
return coopType.getElementType();
})
.Default(nullptr);
// Case 1. -- matrices.
if (coopElementType) {

View File

@ -50,10 +50,10 @@ public:
[this](auto concreteType) { addConcrete(concreteType); })
.Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
.Case<SampledImageType>([this](SampledImageType concreteType) {
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
.Case<StructType>([this](StructType concreteType) {
.Case([this](StructType concreteType) {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
@ -97,13 +97,13 @@ public:
.Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
[this](auto concreteType) { addConcrete(concreteType); })
.Case<ArrayType>([this](ArrayType concreteType) {
.Case([this](ArrayType concreteType) {
add(concreteType.getElementType());
})
.Case<SampledImageType>([this](SampledImageType concreteType) {
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
.Case<StructType>([this](StructType concreteType) {
.Case([this](StructType concreteType) {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
@ -195,9 +195,8 @@ Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
TensorArmType>([](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
.Case([](MatrixType type) { return type.getColumnType(); })
.Case([index](StructType type) { return type.getElementType(index); })
.DefaultUnreachable("Invalid composite type");
}
@ -205,7 +204,7 @@ unsigned CompositeType::getNumElements() const {
return TypeSwitch<SPIRVType, unsigned>(*this)
.Case<ArrayType, StructType, TensorArmType, VectorType>(
[](auto type) { return type.getNumElements(); })
.Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
.Case([](MatrixType type) { return type.getNumColumns(); })
.DefaultUnreachable("Invalid type for number of elements query");
}
@ -704,7 +703,7 @@ void SPIRVType::getCapabilities(
std::optional<int64_t> SPIRVType::getSizeInBytes() {
return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
.Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
.Case([](ScalarType type) -> std::optional<int64_t> {
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with
// boolean type. If they are stored (in conjunction with OpVariable),
@ -717,7 +716,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {
return std::nullopt;
return bitWidth / 8;
})
.Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
.Case([](ArrayType type) -> std::optional<int64_t> {
// Since array type may have an explicit stride declaration (in bytes),
// we also include it in the calculation.
auto elementType = cast<SPIRVType>(type.getElementType());

View File

@ -628,9 +628,9 @@ static spirv::Dim convertRank(int64_t rank) {
static spirv::ImageFormat getImageFormat(Type elementType) {
return TypeSwitch<Type, spirv::ImageFormat>(elementType)
.Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
.Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
.Case<IntegerType>([](IntegerType intType) {
.Case([](Float16Type) { return spirv::ImageFormat::R16f; })
.Case([](Float32Type) { return spirv::ImageFormat::R32f; })
.Case([](IntegerType intType) {
auto const isSigned = intType.isSigned() || intType.isSignless();
#define BIT_WIDTH_CASE(BIT_WIDTH) \
case BIT_WIDTH: \

View File

@ -2658,13 +2658,13 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
int64_t numPayloads =
llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
.Case<TransformHandleTypeInterface>([&](auto x) {
.Case([&](TransformHandleTypeInterface x) {
return llvm::range_size(state.getPayloadOps(getHandle()));
})
.Case<TransformValueHandleTypeInterface>([&](auto x) {
.Case([&](TransformValueHandleTypeInterface x) {
return llvm::range_size(state.getPayloadValues(getHandle()));
})
.Case<TransformParamTypeInterface>([&](auto x) {
.Case([&](TransformParamTypeInterface x) {
return llvm::range_size(state.getParams(getHandle()));
})
.DefaultUnreachable("unknown transform dialect type interface");

View File

@ -113,8 +113,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
std::optional<Operation *> newMask =
TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
.Case<vector::CreateMaskOp>(
[&](auto createMaskOp) -> std::optional<Operation *> {
.Case(
[&](vector::CreateMaskOp createMaskOp)
-> std::optional<Operation *> {
OperandRange maskOperands = createMaskOp.getOperands();
// The `vector.create_mask` op creates a mask arrangement
// without any zeros at the front. Also, because
@ -134,8 +135,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
newMaskOperands);
})
.Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-> std::optional<Operation *> {
.Case([&](vector::ConstantMaskOp constantMaskOp)
-> std::optional<Operation *> {
// Take the shape of mask, compress its trailing dimension:
SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
int64_t &maskIndex = maskDimSizes.back();
@ -144,8 +145,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
maskDimSizes);
})
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
.Case([&](arith::ConstantOp constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
if (maskShape.size() != 1)
return std::nullopt;

View File

@ -881,7 +881,7 @@ static bool isLinearizable(Operation *op) {
// As type legalization is done with vector.shape_cast, shape_cast
// itself cannot be linearized (will create new shape_casts to linearize
// ad infinitum).
.Case<vector::ShapeCastOp>([&](auto) { return false; })
.Case([&](vector::ShapeCastOp) { return false; })
// The operations
// - vector.extract_strided_slice
// - vector.extract
@ -891,18 +891,16 @@ static bool isLinearizable(Operation *op) {
// vector.shuffle only supports fixed size vectors, so it is impossible to
// use this approach to linearize these ops if they operate on scalable
// vectors.
.Case<vector::ExtractStridedSliceOp>(
[&](vector::ExtractStridedSliceOp extractOp) {
return !extractOp.getType().isScalable();
})
.Case<vector::InsertStridedSliceOp>(
[&](vector::InsertStridedSliceOp insertOp) {
return !insertOp.getType().isScalable();
})
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
.Case([&](vector::ExtractStridedSliceOp extractOp) {
return !extractOp.getType().isScalable();
})
.Case([&](vector::InsertStridedSliceOp insertOp) {
return !insertOp.getType().isScalable();
})
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
.Case([&](vector::InsertOp insertOp) {
return !insertOp.getType().isScalable();
})
.Case([&](vector::ExtractOp extractOp) {
return !extractOp.getSourceVectorType().isScalable();
})
.Default([&](auto) { return true; });

View File

@ -300,11 +300,12 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
RewriterBase &rewriter) {
auto loc = xfer->getLoc();
Value base = TypeSwitch<Operation *, Value>(xfer)
.Case<vector::TransferReadOp>(
[&](auto readOp) { return readOp.getBase(); })
.Case<vector::TransferWriteOp>(
[&](auto writeOp) { return writeOp.getOperand(1); });
Value base =
TypeSwitch<Operation *, Value>(xfer)
.Case([&](vector::TransferReadOp readOp) { return readOp.getBase(); })
.Case([&](vector::TransferWriteOp writeOp) {
return writeOp.getOperand(1);
});
SmallVector<OpFoldResult> mixedSourceDims =
hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)

View File

@ -460,43 +460,45 @@ LogicalResult LayoutInfoPropagation::visitOperation(
Operation *op, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
TypeSwitch<Operation *>(op)
.Case<xegpu::DpasOp>(
[&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
.Case<xegpu::StoreNdOp>(
[&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
.Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
.Case(
[&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
.Case([&](xegpu::StoreNdOp storeNdOp) {
visitStoreNdOp(storeNdOp, operands, results);
})
.Case([&](xegpu::StoreScatterOp storeScatterOp) {
visitStoreScatterOp(storeScatterOp, operands, results);
})
.Case<xegpu::LoadNdOp>(
[&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
.Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
.Case([&](xegpu::LoadNdOp loadNdOp) {
visitLoadNdOp(loadNdOp, operands, results);
})
.Case([&](xegpu::LoadGatherOp loadGatherOp) {
visitLoadGatherOp(loadGatherOp, operands, results);
})
.Case<xegpu::CreateDescOp>([&](auto createDescOp) {
.Case([&](xegpu::CreateDescOp createDescOp) {
visitCreateDescOp(createDescOp, operands, results);
})
.Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
.Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
})
.Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
.Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
visitPrefetchNdOp(prefetchNdOp, operands, results);
})
.Case<vector::TransposeOp>([&](auto transposeOp) {
.Case([&](vector::TransposeOp transposeOp) {
visitTransposeOp(transposeOp, operands, results);
})
.Case<vector::BitCastOp>([&](auto bitcastOp) {
.Case([&](vector::BitCastOp bitcastOp) {
visitVectorBitcastOp(bitcastOp, operands, results);
})
.Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
.Case([&](vector::MultiDimReductionOp reductionOp) {
visitVectorMultiReductionOp(reductionOp, operands, results);
})
.Case<vector::BroadcastOp>([&](auto broadcastOp) {
.Case([&](vector::BroadcastOp broadcastOp) {
visitVectorBroadCastOp(broadcastOp, operands, results);
})
.Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
.Case([&](vector::ShapeCastOp shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
.Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
.Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
visitStoreMatrixOp(storeMatrixOp, operands, results);
})
// All other ops.
@ -1641,16 +1643,14 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
TypeSwitch<Operation *>(&op)
.Case<mlir::RegionBranchTerminatorOpInterface>(
[&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
getXeGPULayoutForValue);
})
.Case<mlir::FunctionOpInterface>(
[&](mlir::FunctionOpInterface funcOp) {
r = updateFunctionOpInterface(builder, funcOp,
getXeGPULayoutForValue);
})
.Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
getXeGPULayoutForValue);
})
.Case([&](mlir::FunctionOpInterface funcOp) {
r = updateFunctionOpInterface(builder, funcOp,
getXeGPULayoutForValue);
})
.Default([&](Operation *op) {
r = updateOp(builder, op, getXeGPULayoutForValue);
});

View File

@ -2157,16 +2157,16 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
return;
TypeSwitch<LocationAttr>(loc)
.Case<OpaqueLoc>([&](OpaqueLoc loc) {
.Case([&](OpaqueLoc loc) {
printLocationInternal(loc.getFallbackLocation(), pretty);
})
.Case<UnknownLoc>([&](UnknownLoc loc) {
.Case([&](UnknownLoc loc) {
if (pretty)
os << "[unknown]";
else
os << "unknown";
})
.Case<FileLineColRange>([&](FileLineColRange loc) {
.Case([&](FileLineColRange loc) {
if (pretty)
os << loc.getFilename().getValue();
else
@ -2184,7 +2184,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to "
<< loc.getEndLine() << ':' << loc.getEndColumn();
})
.Case<NameLoc>([&](NameLoc loc) {
.Case([&](NameLoc loc) {
printEscapedString(loc.getName());
// Print the child if it isn't unknown.
@ -2195,7 +2195,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << ')';
}
})
.Case<CallSiteLoc>([&](CallSiteLoc loc) {
.Case([&](CallSiteLoc loc) {
Location caller = loc.getCaller();
Location callee = loc.getCallee();
if (!pretty)
@ -2218,7 +2218,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
if (!pretty)
os << ")";
})
.Case<FusedLoc>([&](FusedLoc loc) {
.Case([&](FusedLoc loc) {
if (!pretty)
os << "fused";
if (Attribute metadata = loc.getMetadata()) {
@ -2744,7 +2744,7 @@ void AsmPrinter::Impl::printType(Type type) {
void AsmPrinter::Impl::printTypeImpl(Type type) {
TypeSwitch<Type>(type)
.Case<OpaqueType>([&](OpaqueType opaqueTy) {
.Case([&](OpaqueType opaqueTy) {
printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
opaqueTy.getTypeData());
})
@ -2767,14 +2767,14 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<Float64Type>([&](Type) { os << "f64"; })
.Case<Float80Type>([&](Type) { os << "f80"; })
.Case<Float128Type>([&](Type) { os << "f128"; })
.Case<IntegerType>([&](IntegerType integerTy) {
.Case([&](IntegerType integerTy) {
if (integerTy.isSigned())
os << 's';
else if (integerTy.isUnsigned())
os << 'u';
os << 'i' << integerTy.getWidth();
})
.Case<FunctionType>([&](FunctionType funcTy) {
.Case([&](FunctionType funcTy) {
os << '(';
interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
os << ") -> ";
@ -2787,7 +2787,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << ')';
}
})
.Case<VectorType>([&](VectorType vectorTy) {
.Case([&](VectorType vectorTy) {
auto scalableDims = vectorTy.getScalableDims();
os << "vector<";
auto vShape = vectorTy.getShape();
@ -2804,7 +2804,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(vectorTy.getElementType());
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
.Case([&](RankedTensorType tensorTy) {
os << "tensor<";
printDimensionList(tensorTy.getShape());
if (!tensorTy.getShape().empty())
@ -2817,12 +2817,12 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
os << '>';
})
.Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
.Case([&](UnrankedTensorType tensorTy) {
os << "tensor<*x";
printType(tensorTy.getElementType());
os << '>';
})
.Case<MemRefType>([&](MemRefType memrefTy) {
.Case([&](MemRefType memrefTy) {
os << "memref<";
printDimensionList(memrefTy.getShape());
if (!memrefTy.getShape().empty())
@ -2840,7 +2840,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
os << '>';
})
.Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
.Case([&](UnrankedMemRefType memrefTy) {
os << "memref<*x";
printType(memrefTy.getElementType());
// Only print the memory space if it is the non-default one.
@ -2850,19 +2850,19 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
os << '>';
})
.Case<ComplexType>([&](ComplexType complexTy) {
.Case([&](ComplexType complexTy) {
os << "complex<";
printType(complexTy.getElementType());
os << '>';
})
.Case<TupleType>([&](TupleType tupleTy) {
.Case([&](TupleType tupleTy) {
os << "tuple<";
interleaveComma(tupleTy.getTypes(),
[&](Type type) { printType(type); });
os << '>';
})
.Case<NoneType>([&](Type) { os << "none"; })
.Case<GraphType>([&](GraphType graphTy) {
.Case([&](GraphType graphTy) {
os << '(';
interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
os << ") -> ";

View File

@ -389,7 +389,7 @@ collectParentLayouts(Operation *leaf,
for (Operation *parent = leaf->getParentOp(); parent != nullptr;
parent = parent->getParentOp()) {
llvm::TypeSwitch<Operation *>(parent)
.Case<ModuleOp>([&](ModuleOp op) {
.Case([&](ModuleOp op) {
// Skip top-level module op unless it has a layout. Top-level module
// without layout is most likely the one implicitly added by the
// parser and it doesn't have location. Top-level null specification
@ -401,7 +401,7 @@ collectParentLayouts(Operation *leaf,
if (opLocations)
opLocations->push_back(op.getLoc());
})
.Case<DataLayoutOpInterface>([&](DataLayoutOpInterface op) {
.Case([&](DataLayoutOpInterface op) {
specs.push_back(op.getDataLayoutSpec());
if (opLocations)
opLocations->push_back(op.getLoc());

View File

@ -404,7 +404,7 @@ struct ByteCodeWriter {
[](Type) { return PDLValue::Kind::Attribute; })
.Case<pdl::OperationType>(
[](Type) { return PDLValue::Kind::Operation; })
.Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
.Case([](pdl::RangeType rangeTy) {
if (isa<pdl::TypeType>(rangeTy.getElementType()))
return PDLValue::Kind::TypeRange;
return PDLValue::Kind::ValueRange;

View File

@ -852,7 +852,7 @@ bool Operator::hasAssemblyFormat() const {
StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<const Init *, StringRef>(def.getValueInit("assemblyFormat"))
.Case<StringInit>([&](auto *init) { return init->getValue(); });
.Case([&](const StringInit *init) { return init->getValue(); });
}
void Operator::print(llvm::raw_ostream &os) const {

View File

@ -52,7 +52,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
return std::nullopt;
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
builderCall->getValue())
.Case<llvm::StringInit>([&](auto *init) {
.Case([&](const llvm::StringInit *init) {
StringRef value = init->getValue();
return value.empty() ? std::optional<StringRef>() : value;
})

View File

@ -70,19 +70,19 @@ static inline LogicalResult interleaveCommaWithError(const Container &c,
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
.Case<emitc::AddressOfOp>([&](auto op) { return 15; })
.Case<emitc::AddOp>([&](auto op) { return 12; })
.Case<emitc::ApplyOp>([&](auto op) { return 15; })
.Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
.Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
.Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
.Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
.Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
.Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
.Case<emitc::CallOp>([&](auto op) { return 16; })
.Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
.Case<emitc::CastOp>([&](auto op) { return 15; })
.Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
.Case([&](emitc::AddressOfOp op) { return 15; })
.Case([&](emitc::AddOp op) { return 12; })
.Case([&](emitc::ApplyOp op) { return 15; })
.Case([&](emitc::BitwiseAndOp op) { return 7; })
.Case([&](emitc::BitwiseLeftShiftOp op) { return 11; })
.Case([&](emitc::BitwiseNotOp op) { return 15; })
.Case([&](emitc::BitwiseOrOp op) { return 5; })
.Case([&](emitc::BitwiseRightShiftOp op) { return 11; })
.Case([&](emitc::BitwiseXorOp op) { return 6; })
.Case([&](emitc::CallOp op) { return 16; })
.Case([&](emitc::CallOpaqueOp op) { return 16; })
.Case([&](emitc::CastOp op) { return 15; })
.Case([&](emitc::CmpOp op) -> FailureOr<int> {
switch (op.getPredicate()) {
case emitc::CmpPredicate::eq:
case emitc::CmpPredicate::ne:
@ -97,18 +97,18 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
}
return op->emitError("unsupported cmp predicate");
})
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
.Case<emitc::ConstantOp>([&](auto op) { return 17; })
.Case<emitc::DivOp>([&](auto op) { return 13; })
.Case<emitc::LoadOp>([&](auto op) { return 16; })
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
.Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
.Case<emitc::MulOp>([&](auto op) { return 13; })
.Case<emitc::RemOp>([&](auto op) { return 13; })
.Case<emitc::SubOp>([&](auto op) { return 12; })
.Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
.Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
.Case([&](emitc::ConditionalOp op) { return 2; })
.Case([&](emitc::ConstantOp op) { return 17; })
.Case([&](emitc::DivOp op) { return 13; })
.Case([&](emitc::LoadOp op) { return 16; })
.Case([&](emitc::LogicalAndOp op) { return 4; })
.Case([&](emitc::LogicalNotOp op) { return 15; })
.Case([&](emitc::LogicalOrOp op) { return 3; })
.Case([&](emitc::MulOp op) { return 13; })
.Case([&](emitc::RemOp op) { return 13; })
.Case([&](emitc::SubOp op) { return 12; })
.Case([&](emitc::UnaryMinusOp op) { return 15; })
.Case([&](emitc::UnaryPlusOp op) { return 15; })
.Default([](auto op) { return op->emitError("unsupported operation"); });
}
@ -1835,7 +1835,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
LogicalResult status =
llvm::TypeSwitch<Operation *, LogicalResult>(&op)
// Builtin ops.
.Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
.Case([&](ModuleOp op) { return printOperation(*this, op); })
// CF ops.
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })

View File

@ -647,11 +647,10 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
dialect.walk([&](mlir::Operation *op) {
res =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case<irdl::DialectOp>(([](irdl::DialectOp) { return success(); }))
.Case<irdl::OperationOp>(
([](irdl::OperationOp) { return success(); }))
.Case<irdl::TypeOp>(([](irdl::TypeOp) { return success(); }))
.Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
.Case(([](irdl::DialectOp) { return success(); }))
.Case(([](irdl::OperationOp) { return success(); }))
.Case(([](irdl::TypeOp) { return success(); }))
.Case(([](irdl::OperandsOp op) -> LogicalResult {
if (llvm::all_of(
op.getVariadicity(), [](irdl::VariadicityAttr attr) {
return attr.getValue() == irdl::Variadicity::single;
@ -660,7 +659,7 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
return op.emitError("IRDL C++ translation does not yet support "
"variadic operations");
}))
.Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
.Case(([](irdl::ResultsOp op) -> LogicalResult {
if (llvm::all_of(
op.getVariadicity(), [](irdl::VariadicityAttr attr) {
return attr.getValue() == irdl::Variadicity::single;
@ -669,9 +668,9 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
return op.emitError(
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
.Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
.Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Case(([](irdl::AnyOp) { return success(); }))
.Case(([](irdl::RegionOp) { return success(); }))
.Case(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")

View File

@ -27,10 +27,10 @@ void registerIRDLToCppTranslation() {
"irdl-to-cpp", "translate IRDL dialect definitions to C++ definitions",
[](Operation *op, raw_ostream &output) {
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<irdl::DialectOp>([&](irdl::DialectOp dialectOp) {
.Case([&](irdl::DialectOp dialectOp) {
return irdl::translateIRDLDialectToCpp(dialectOp, output);
})
.Case<ModuleOp>([&](ModuleOp moduleOp) {
.Case([&](ModuleOp moduleOp) {
for (Operation &op : moduleOp.getBody()->getOperations())
if (auto dialectOp = llvm::dyn_cast<irdl::DialectOp>(op))
if (failed(

View File

@ -284,7 +284,7 @@ DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) {
llvm::DINode *result =
TypeSwitch<DIRecursiveTypeAttrInterface, llvm::DINode *>(attr)
.Case<DICompositeTypeAttr>([&](auto attr) {
.Case([&](DICompositeTypeAttr attr) {
auto temporary = translateTemporaryImpl(attr);
setRecursivePlaceholder(temporary.get());
// Must call `translateImpl` directly instead of `translate` to
@ -293,7 +293,7 @@ DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) {
temporary->replaceAllUsesWith(concrete);
return concrete;
})
.Case<DISubprogramAttr>([&](auto attr) {
.Case([&](DISubprogramAttr attr) {
auto temporary = translateTemporaryImpl(attr);
setRecursivePlaceholder(temporary.get());
// Must call `translateImpl` directly instead of `translate` to

View File

@ -333,16 +333,16 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
for (auto flagAttr : flags.getAsRange<ModuleFlagAttr>()) {
llvm::Metadata *valueMetadata =
llvm::TypeSwitch<Attribute, llvm::Metadata *>(flagAttr.getValue())
.Case<StringAttr>([&](auto strAttr) {
.Case([&](StringAttr strAttr) {
return llvm::MDString::get(builder.getContext(),
strAttr.getValue());
})
.Case<IntegerAttr>([&](auto intAttr) {
.Case([&](IntegerAttr intAttr) {
return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
llvm::Type::getInt32Ty(builder.getContext()),
intAttr.getInt()));
})
.Case<ArrayAttr>([&](auto arrayAttr) {
.Case([&](ArrayAttr arrayAttr) {
return convertModuleFlagValue(flagAttr.getKey().getValue(),
arrayAttr, builder,
moduleTranslation);

View File

@ -448,7 +448,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
.Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
[&](auto op) { checkDepend(op, result); })
.Case<omp::TargetUpdateOp>([&](auto op) { checkDepend(op, result); })
.Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
.Case([&](omp::TargetOp op) {
checkAllocate(op, result);
checkBare(op, result);

View File

@ -420,18 +420,18 @@ public:
return translateTypeOffsetOp(typeOffsetOp, builder,
moduleTranslation);
})
.Case<GatherOp>([&](GatherOp gatherOp) {
.Case([&](GatherOp gatherOp) {
return translateGatherOp(gatherOp, builder, moduleTranslation);
})
.Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
.Case([&](MaskedLoadOp maskedLoadOp) {
return translateMaskedLoadOp(maskedLoadOp, builder,
moduleTranslation);
})
.Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
.Case([&](MaskedStoreOp maskedStoreOp) {
return translateMaskedStoreOp(maskedStoreOp, builder,
moduleTranslation);
})
.Case<ScatterOp>([&](ScatterOp scatterOp) {
.Case([&](ScatterOp scatterOp) {
return translateScatterOp(scatterOp, builder, moduleTranslation);
})
.Default([&](Operation *op) {

View File

@ -1741,20 +1741,20 @@ static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder,
ModuleTranslation &moduleTranslation,
Location loc) {
return llvm::TypeSwitch<Attribute, LogicalResult>(namedAttr.getValue())
.Case<TypeAttr>([&](auto typeAttr) {
.Case([&](TypeAttr typeAttr) {
attrBuilder.addTypeAttr(
llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
return success();
})
.Case<IntegerAttr>([&](auto intAttr) {
.Case([&](IntegerAttr intAttr) {
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
return success();
})
.Case<UnitAttr>([&](auto) {
.Case([&](UnitAttr) {
attrBuilder.addAttribute(llvmKind);
return success();
})
.Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
.Case([&](LLVM::ConstantRangeAttr rangeAttr) {
attrBuilder.addConstantRangeAttr(
llvmKind,
llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));

View File

@ -25,7 +25,7 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
if (!op)
return;
TypeSwitch<Operation *, void>(op)
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
.Case([&](linalg::LinalgOp linalgOp) {
SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
operandSet.insert_range(inputOperands);
})

View File

@ -81,14 +81,15 @@ private:
/// `CustomDirective` with a single parameter argument or `RefDirective`.
static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
return TypeSwitch<FormatElement *, ParameterElement *>(el)
.Case<CustomDirective>([&](auto custom) {
.Case([&](CustomDirective *custom) {
FailureOr<ParameterElement *> maybeParam =
custom->template getFrontAs<ParameterElement>();
return *maybeParam;
})
.Case<ParameterElement>([&](auto param) { return param; })
.Case<RefDirective>(
[&](auto ref) { return cast<ParameterElement>(ref->getArg()); })
.Case([&](ParameterElement *param) { return param; })
.Case([&](RefDirective *ref) {
return cast<ParameterElement>(ref->getArg());
})
.DefaultUnreachable("unexpected struct element type");
}