diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index cfdcd9cc2d86..3372faf4b16c 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -424,11 +424,11 @@ static Value getOriginalVectorValue(Value value) { Value current = value; while (Operation *definingOp = current.getDefiningOp()) { bool skipOp = llvm::TypeSwitch(definingOp) - .Case([¤t](auto op) { + .Case([¤t](vector::ShapeCastOp op) { current = op.getSource(); return true; }) - .Case([¤t](auto op) { + .Case([¤t](vector::BroadcastOp op) { current = op.getSource(); return false; }) diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index a0eed19f01a8..8c92b7e2b718 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -267,10 +267,10 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { static std::optional getTypeMangling(Type type) { return TypeSwitch>(type) - .Case([](auto) { return "Dhj"; }) - .Case([](auto) { return "fj"; }) - .Case([](auto) { return "dj"; }) - .Case([](auto intTy) -> std::optional { + .Case([](Float16Type) { return "Dhj"; }) + .Case([](Float32Type) { return "fj"; }) + .Case([](Float64Type) { return "dj"; }) + .Case([](IntegerType intTy) -> std::optional { switch (intTy.getWidth()) { case 8: return "cj"; diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index cde23403ad9f..2774adb071c9 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -812,7 +812,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() { module.walk([&](Operation *op) { TypeSwitch(op) - .Case([&](math::CountLeadingZerosOp op) { + .Case([&](math::CountLeadingZerosOp op) { if (!convertCtlz || !isConvertible(op)) return; Type resultType = getElementTypeOrSelf(op.getResult().getType()); @@ -824,7 +824,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() { if (entry.second) entry.first->second = createCtlzFunc(&module, resultType); }) - .Case([&](math::IPowIOp op) { + .Case([&](math::IPowIOp op) { if (!isConvertible(op)) return; @@ -837,7 +837,7 @@ void ConvertMathToFuncsPass::generateOpImplementations() { if (entry.second) entry.first->second = createElementIPowIFunc(&module, resultType); }) - .Case([&](math::FPowIOp op) { + .Case([&](math::FPowIOp op) { if (!isFPowIConvertible(op)) return; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 39d4815dc73b..4490c326d864 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -350,7 +350,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern, .Case([&](pdl::AttributeOp attrOp) { getAttributePredicates(attrOp, predList, builder, inputs); }) - .Case([&](auto constraintOp) { + .Case([&](pdl::ApplyNativeConstraintOp constraintOp) { getConstraintPredicates(constraintOp, predList, builder, inputs); }) .Case([&](auto resultOp) { @@ -471,7 +471,7 @@ static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, // We intentionally do not traverse attributes and types, because those // are expensive to join on. TypeSwitch(entry.value.getDefiningOp()) - .Case([&](auto operationOp) { + .Case([&](pdl::OperationOp operationOp) { OperandRange operands = operationOp.getOperandValues(); // Special case when we pass all the operands in one range. // For those, the index is empty. @@ -544,7 +544,7 @@ static void visitUpward(std::vector &predList, Position *&pos, unsigned rootID) { Value value = opIndex.parent; TypeSwitch(value.getDefiningOp()) - .Case([&](auto operationOp) { + .Case([&](pdl::OperationOp operationOp) { LDBG() << " * Value: " << value; // Get users and iterate over them. @@ -583,7 +583,7 @@ static void visitUpward(std::vector &predList, // Update the position pos = opPos; }) - .Case([&](auto resultOp) { + .Case([&](pdl::ResultOp resultOp) { // Traverse up an individual result. auto *opPos = dyn_cast(pos); assert(opPos && "operations and results must be interleaved"); @@ -592,7 +592,7 @@ static void visitUpward(std::vector &predList, // Insert the result position in case we have not visited it yet. valueToPosition.try_emplace(value, pos); }) - .Case([&](auto resultOp) { + .Case([&](pdl::ResultsOp resultOp) { // Traverse up a group of results. auto *opPos = dyn_cast(pos); assert(opPos && "operations and results must be interleaved"); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 02b61bd98936..d9144d0c5e22 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1099,10 +1099,10 @@ namespace { StringRef getTypeMangling(Type type, bool isSigned) { return llvm::TypeSwitch(type) - .Case([](auto) { return "Dh"; }) - .Case([](auto) { return "f"; }) - .Case([](auto) { return "d"; }) - .Case([isSigned](IntegerType intTy) { + .Case([](Float16Type) { return "Dh"; }) + .Case([](Float32Type) { return "f"; }) + .Case([](Float64Type) { return "d"; }) + .Case([isSigned](IntegerType intTy) { switch (intTy.getWidth()) { case 1: return "b"; diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 7d7f0a23848a..c81bb4b455b9 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -42,10 +42,8 @@ static bool isZeroConstant(Value val) { return false; return TypeSwitch(constant.getValue()) - .Case( - [](auto floatAttr) { return floatAttr.getValue().isZero(); }) - .Case( - [](auto intAttr) { return intAttr.getValue().isZero(); }) + .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); }) + .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); }) .Default(false); } diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp index fb2b096df9c3..865182647d12 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp @@ -37,7 +37,7 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, return failure(); } return llvm::TypeSwitch(defOp) - .Case([&](memref::SubViewOp subviewOp) { + .Case([&](memref::SubViewOp subviewOp) { mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices, @@ -45,19 +45,18 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, memrefBase = subviewOp.getSource(); return success(); }) - .Case([&](memref::ExpandShapeOp expandShapeOp) { + .Case([&](memref::ExpandShapeOp expandShapeOp) { mlir::memref::resolveSourceIndicesExpandShape( loc, rewriter, expandShapeOp, indices, resolvedIndices, false); memrefBase = expandShapeOp.getViewSource(); return success(); }) - .Case( - [&](memref::CollapseShapeOp collapseShapeOp) { - mlir::memref::resolveSourceIndicesCollapseShape( - loc, rewriter, collapseShapeOp, indices, resolvedIndices); - memrefBase = collapseShapeOp.getViewSource(); - return success(); - }) + .Case([&](memref::CollapseShapeOp collapseShapeOp) { + mlir::memref::resolveSourceIndicesCollapseShape( + loc, rewriter, collapseShapeOp, indices, resolvedIndices); + memrefBase = collapseShapeOp.getViewSource(); + return success(); + }) .Default([&](Operation *op) { return rewriter.notifyMatchFailure( op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or " diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index cb7c3d711efd..f1c36edac609 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -155,17 +155,17 @@ public: arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { TypeSwitch(extOp) - .Case([&](auto) { + .Case([&](arith::ExtFOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Case([&](auto) { + .Case([&](arith::ExtSIOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Case([&](auto) { + .Case([&](arith::ExtUIOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); @@ -173,17 +173,17 @@ public: .DefaultUnreachable("unexpected extend op!"); } else if (kind == arm_sme::CombiningKind::Sub) { TypeSwitch(extOp) - .Case([&](auto) { + .Case([&](arith::ExtFOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Case([&](auto) { + .Case([&](arith::ExtSIOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Case([&](auto) { + .Case([&](arith::ExtUIOp) { rewriter.replaceOpWithNewOp( op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index a651710ec315..4e67a2c84842 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -417,11 +417,11 @@ static void forEachPredecessorTileValue(BlockArgument blockArg, unsigned argNumber = blockArg.getArgNumber(); for (Block *pred : block->getPredecessors()) { TypeSwitch(pred->getTerminator()) - .Case([&](auto branch) { + .Case([&](cf::BranchOp branch) { Value predecessorOperand = branch.getDestOperands()[argNumber]; callback(predecessorOperand); }) - .Case([&](auto condBranch) { + .Case([&](cf::CondBranchOp condBranch) { if (condBranch.getFalseDest() == block) { Value predecessorOperand = condBranch.getFalseDestOperands()[argNumber]; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index c029a49f2625..8bd508b364e3 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -368,7 +368,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { .Case([&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp); }) - .Case([&](MMAMatrixType fragTy) { + .Case([&](MMAMatrixType fragTy) { os << "mma_matrix<"; auto shape = fragTy.getShape(); for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim) diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index 70d2e113ea33..e2e63abe0a11 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -243,7 +243,7 @@ private: SmallVector tokens; tokens.reserve(asyncTokens.size()); TypeSwitch(op) - .Case([&](auto awaitOp) { + .Case([&](async::AwaitOp awaitOp) { // Add async.await ops to wait for the !gpu.async.tokens. builder.setInsertionPointAfter(op); for (auto asyncToken : asyncTokens) @@ -252,7 +252,7 @@ private: // Set `it` after the inserted async.await ops. it = builder.getInsertionPoint(); }) - .Case([&](auto executeOp) { + .Case([&](async::ExecuteOp executeOp) { // Set `it` to the beginning of the region and add asyncTokens to the // async.execute operands. it = executeOp.getBody()->begin(); diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 0758984ab451..ba4703550a3d 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -322,7 +322,7 @@ static Value getBase(Value v) { v = op.getSource(); return true; }) - .Case([&](auto op) { + .Case([&](memref::TransposeOp op) { v = op.getIn(); return true; }) diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp index 197a5907fc4c..d5d2ca33d8a5 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -143,12 +143,9 @@ LogicalResult OperationOp::verifyRegions() { for (Operation &op : getBody().getOps()) { TypeSwitch(&op) - .Case( - [&](OperandsOp op) { insertNames("operands", op.getNames()); }) - .Case( - [&](ResultsOp op) { insertNames("results", op.getNames()); }) - .Case( - [&](RegionsOp op) { insertNames("regions", op.getNames()); }); + .Case([&](OperandsOp op) { insertNames("operands", op.getNames()); }) + .Case([&](ResultsOp op) { insertNames("results", op.getNames()); }) + .Case([&](RegionsOp op) { insertNames("regions", op.getNames()); }); } // Verify that no two operand, result or region share the same name. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 27d6355c9e22..41c6b43473aa 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -779,7 +779,7 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos, return success(); return TypeSwitch(baseGEPType) - .Case([&](LLVMStructType structType) -> LogicalResult { + .Case([&](LLVMStructType structType) -> LogicalResult { auto attr = dyn_cast(indices[indexPos]); if (!attr) return emitOpError() << "expected index " << indexPos @@ -3253,13 +3253,13 @@ LogicalResult LLVMFuncOp::verify() { return WalkResult::advance(); }; return TypeSwitch(op) - .Case([&](auto landingpad) { + .Case([&](LandingpadOp landingpad) { constexpr StringLiteral errorMessage = "'llvm.landingpad' should have a consistent result type " "inside a function"; return checkType(landingpad.getType(), errorMessage); }) - .Case([&](auto resume) { + .Case([&](ResumeOp resume) { constexpr StringLiteral errorMessage = "'llvm.resume' should have a consistent input type inside a " "function"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 5dc4fa2b2d82..e24615c8d304 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -763,26 +763,24 @@ static bool isCompatibleImpl(Type type, DenseSet &compatibleTypes) { bool result = llvm::TypeSwitch(type) - .Case([&](auto structType) { + .Case([&](LLVMStructType structType) { return llvm::all_of(structType.getBody(), isCompatible); }) - .Case([&](auto funcType) { + .Case([&](LLVMFunctionType funcType) { return isCompatible(funcType.getReturnType()) && llvm::all_of(funcType.getParams(), isCompatible); }) - .Case([](auto intType) { return intType.isSignless(); }) - .Case([&](auto vecType) { + .Case([](IntegerType intType) { return intType.isSignless(); }) + .Case([&](VectorType vecType) { return vecType.getRank() == 1 && isCompatible(vecType.getElementType()); }) - .Case([&](auto pointerType) { return true; }) - .Case([&](auto extType) { + .Case([&](LLVMPointerType pointerType) { return true; }) + .Case([&](LLVMTargetExtType extType) { return llvm::all_of(extType.getTypeParams(), isCompatible); }) // clang-format off - .Case< - LLVMArrayType - >([&](auto containerType) { + .Case([&](LLVMArrayType containerType) { return isCompatible(containerType.getElementType()); }) .Case< @@ -895,12 +893,12 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { .Case([](Type) { return llvm::TypeSize::getFixed(64); }) .Case([](Type) { return llvm::TypeSize::getFixed(80); }) .Case([](Type) { return llvm::TypeSize::getFixed(128); }) - .Case([](IntegerType intTy) { + .Case([](IntegerType intTy) { return llvm::TypeSize::getFixed(intTy.getWidth()); }) .Case( [](Type) { return llvm::TypeSize::getFixed(128); }) - .Case([](VectorType t) { + .Case([](VectorType t) { assert(isCompatibleVectorType(t) && "unexpected incompatible with LLVM vector type"); llvm::TypeSize elementSize = diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 76ec8b8b7cfd..bd40c5951c5a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -4117,10 +4117,10 @@ ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, bool hasRelu) { return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float6E2M3FNType) { + .Case([&](mlir::Float6E2M3FNType) { return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); }) - .Case([&](mlir::Float6E3M2FNType) { + .Case([&](mlir::Float6E3M2FNType) { return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); }) .Default([](mlir::Type) { @@ -4145,13 +4145,13 @@ ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float8E4M3FNType) { + .Case([&](mlir::Float8E4M3FNType) { return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); }) - .Case([&](mlir::Float8E5M2Type) { + .Case([&](mlir::Float8E5M2Type) { return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); }) - .Case([&](mlir::Float8E8M0FNUType) { + .Case([&](mlir::Float8E8M0FNUType) { if (hasRoundingModeRZ) return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); else if (hasRoundingModeRP) @@ -4172,10 +4172,10 @@ ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, bool hasRelu) { return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float8E4M3FNType) { + .Case([&](mlir::Float8E4M3FNType) { return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); }) - .Case([&](mlir::Float8E5M2Type) { + .Case([&](mlir::Float8E5M2Type) { return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); }) .Default([](mlir::Type) { @@ -4210,11 +4210,11 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( llvm::Intrinsic::ID intId = llvm::TypeSwitch(curOp.getSrcType()) - .Case([&](Float8E4M3FNType type) { + .Case([&](Float8E4M3FNType type) { return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; }) - .Case([&](Float8E5M2Type type) { + .Case([&](Float8E5M2Type type) { return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; }) @@ -4250,11 +4250,11 @@ NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs( llvm::Intrinsic::ID intId = llvm::TypeSwitch(curOp.getSrcType()) - .Case([&](Float6E2M3FNType type) { + .Case([&](Float6E2M3FNType type) { return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn; }) - .Case([&](Float6E3M2FNType type) { + .Case([&](Float6E3M2FNType type) { return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn; }) @@ -4278,7 +4278,7 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( llvm::Intrinsic::ID intId = llvm::TypeSwitch(curOp.getSrcType()) - .Case([&](Float4E2M1FNType type) { + .Case([&](Float4E2M1FNType type) { return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn; }) @@ -4483,11 +4483,11 @@ llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { bool hasRelu = getRelu(); return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float8E4M3FNType) { + .Case([&](mlir::Float8E4M3FNType) { return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite; }) - .Case([&](mlir::Float8E5M2Type) { + .Case([&](mlir::Float8E5M2Type) { return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite; }) @@ -4502,11 +4502,11 @@ llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() { bool hasRelu = getRelu(); return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float6E2M3FNType) { + .Case([&](mlir::Float6E2M3FNType) { return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite; }) - .Case([&](mlir::Float6E3M2FNType) { + .Case([&](mlir::Float6E3M2FNType) { return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite; }) @@ -4521,7 +4521,7 @@ llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() { bool hasRelu = getRelu(); return llvm::TypeSwitch(dstTy) - .Case([&](mlir::Float4E2M1FNType) { + .Case([&](mlir::Float4E2M1FNType) { return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite; }) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 0f0e308bba78..fab543be4cf3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -63,10 +63,10 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, return getAsOpFoldResult( TypeSwitch(v.getType()) - .Case([&](RankedTensorType t) -> Value { + .Case([&](RankedTensorType t) -> Value { return tensor::DimOp::create(builder, loc, v, dim); }) - .Case([&](MemRefType t) -> Value { + .Case([&](MemRefType t) -> Value { return memref::DimOp::create(builder, loc, v, dim); })); } @@ -78,11 +78,11 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source, ArrayRef sizes, ArrayRef strides) { return TypeSwitch(source.getType()) - .Case([&](RankedTensorType t) -> Operation * { + .Case([&](RankedTensorType t) -> Operation * { return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes, strides); }) - .Case([&](MemRefType type) -> Operation * { + .Case([&](MemRefType type) -> Operation * { return memref::SubViewOp::create(b, loc, source, offsets, sizes, strides); }) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 72acd02d0d13..2384986d49c5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -840,7 +840,7 @@ static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, ExpansionInfo &expansionInfo) { return TypeSwitch(linalgOp.getOperation()) - .Case([&](TransposeOp transposeOp) { + .Case([&](TransposeOp transposeOp) { return createExpandedTransposeOp(rewriter, transposeOp, expandedOpOperands[0], outputs[0], expansionInfo); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 6becc1f29afb..c55a0bf7ef9a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -76,13 +76,13 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, SmallVector preservedAttrs; Operation *newConv = TypeSwitch(operation) - .Case([&](auto op) { + .Case([&](DepthwiseConv2DNhwcHwcmOp op) { preservedAttrs = getPrunedAttributeList(op); return DepthwiseConv2DNhwcHwcOp::create( rewriter, loc, newInitTy, ValueRange{input, collapsedKernel}, ValueRange{collapsedInit}, stride, dilation); }) - .Case([&](auto op) { + .Case([&](DepthwiseConv2DNhwcHwcmQOp op) { preservedAttrs = getPrunedAttributeList(op); return DepthwiseConv2DNhwcHwcQOp::create( rewriter, loc, newInitTy, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 48ebd1644bbe..44c37d29e87d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -56,7 +56,7 @@ using namespace mlir::linalg; SmallVector mlir::linalg::peelLoop(RewriterBase &rewriter, Operation *op) { return llvm::TypeSwitch>(op) - .Case([&](scf::ForOp forOp) { + .Case([&](scf::ForOp forOp) { scf::ForOp partialIteration; if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration))) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a2e38c9f572e..74ebcee47d20 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -644,19 +644,19 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) { return llvm::TypeSwitch>(combinerOp) .Case( [&](auto op) { return CombiningKind::ADD; }) - .Case([&](auto op) { return CombiningKind::AND; }) - .Case([&](auto op) { return CombiningKind::MAXSI; }) - .Case([&](auto op) { return CombiningKind::MAXUI; }) - .Case([&](auto op) { return CombiningKind::MAXIMUMF; }) - .Case([&](auto op) { return CombiningKind::MAXNUMF; }) - .Case([&](auto op) { return CombiningKind::MINSI; }) - .Case([&](auto op) { return CombiningKind::MINUI; }) - .Case([&](auto op) { return CombiningKind::MINIMUMF; }) - .Case([&](auto op) { return CombiningKind::MINNUMF; }) + .Case([&](arith::AndIOp op) { return CombiningKind::AND; }) + .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; }) + .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; }) + .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; }) + .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; }) + .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; }) + .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; }) + .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; }) + .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; }) .Case( [&](auto op) { return CombiningKind::MUL; }) - .Case([&](auto op) { return CombiningKind::OR; }) - .Case([&](auto op) { return CombiningKind::XOR; }) + .Case([&](arith::OrIOp op) { return CombiningKind::OR; }) + .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; }) .Default(std::nullopt); } @@ -2684,21 +2684,21 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition( return failure(); return TypeSwitch(op) - .Case([&](auto linalgOp) { + .Case([&](linalg::LinalgOp linalgOp) { return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, vectorizeNDExtract, flatten1DDepthwiseConv); }) - .Case([&](auto padOp) { + .Case([&](tensor::PadOp padOp) { return vectorizePadOpPrecondition(padOp, inputVectorSizes); }) - .Case([&](auto packOp) { + .Case([&](linalg::PackOp packOp) { return vectorizePackOpPrecondition(packOp, inputVectorSizes); }) - .Case([&](auto unpackOp) { + .Case([&](linalg::UnPackOp unpackOp) { return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes); }) - .Case([&](auto sliceOp) { + .Case([&](tensor::InsertSliceOp sliceOp) { return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes); }) .Default(failure()); @@ -2755,7 +2755,7 @@ FailureOr mlir::linalg::vectorize( SmallVector results; auto vectorizeResult = TypeSwitch(op) - .Case([&](auto linalgOp) { + .Case([&](linalg::LinalgOp linalgOp) { // Check for both named as well as generic convolution ops. if (isaConvolutionOpInterface(linalgOp)) { FailureOr convOr = vectorizeConvolution( @@ -2789,20 +2789,20 @@ FailureOr mlir::linalg::vectorize( // notified and we will end up with read-after-free issues! return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results); }) - .Case([&](auto padOp) { + .Case([&](tensor::PadOp padOp) { return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, results); }) - .Case([&](auto packOp) { + .Case([&](linalg::PackOp packOp) { return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results); }) - .Case([&](auto unpackOp) { + .Case([&](linalg::UnPackOp unpackOp) { return vectorizeAsTensorUnpackOp(rewriter, unpackOp, inputVectorSizes, inputScalableVecDims, results); }) - .Case([&](auto sliceOp) { + .Case([&](tensor::InsertSliceOp sliceOp) { return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, results); }) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index e6850890bf8f..32244728ff33 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -124,67 +124,67 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Value offset) { Location loc = op->getLoc(); llvm::TypeSwitch(op.getOperation()) - .template Case([&](auto oper) { + .Case([&](memref::AllocOp oper) { auto newAlloc = memref::AllocOp::create( rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloc, loc, rewriter); }) - .template Case([&](auto oper) { + .Case([&](memref::AllocaOp oper) { auto newAlloca = memref::AllocaOp::create( rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloca, loc, rewriter); }) - .template Case([&](auto op) { + .Case([&](memref::LoadOp op) { auto newLoad = memref::LoadOp::create(rewriter, loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) - .template Case([&](auto op) { + .Case([&](memref::StoreOp op) { auto newStore = memref::StoreOp::create(rewriter, loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) - .template Case([&](auto op) { + .Case([&](vector::LoadOp op) { auto newLoad = vector::LoadOp::create(rewriter, loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) - .template Case([&](auto op) { + .Case([&](vector::StoreOp op) { auto newStore = vector::StoreOp::create(rewriter, loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) - .template Case([&](auto op) { + .Case([&](vector::MaskedLoadOp op) { auto newMaskedLoad = vector::MaskedLoadOp::create( rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), op.getPassThru()); newMaskedLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedLoad.getResult()); }) - .template Case([&](auto op) { + .Case([&](vector::MaskedStoreOp op) { auto newMaskedStore = vector::MaskedStoreOp::create( rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(), op.getValueToStore()); newMaskedStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedStore); }) - .template Case([&](auto op) { + .Case([&](vector::TransferReadOp op) { auto newTransferRead = vector::TransferReadOp::create( rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); rewriter.replaceOp(op, newTransferRead.getResult()); }) - .template Case([&](auto op) { + .Case([&](vector::TransferWriteOp op) { auto newTransferWrite = vector::TransferWriteOp::create( rewriter, loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp index 8cab2234ec37..e99a27a7bb89 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp @@ -359,7 +359,7 @@ public: collectGlobalsFromDeviceRegion( accOp.getRegion(), globalsToAccDeclare, accSupport, symTab); }) - .Case([&](auto func) { + .Case([&](FunctionOpInterface func) { if ((acc::isAccRoutine(func) || acc::isSpecializedAccRoutine(func)) && !func.isExternal()) @@ -367,13 +367,13 @@ public: globalsToAccDeclare, accSupport, symTab); }) - .Case([&](auto globalVarOp) { + .Case([&](acc::GlobalVariableOpInterface globalVarOp) { if (globalVarOp->getAttr(acc::getDeclareAttrName())) if (Region *initRegion = globalVarOp.getInitRegion()) collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare, accSupport, symTab); }) - .Case([&](auto privateRecipe) { + .Case([&](acc::PrivateRecipeOp privateRecipe) { if (hasRelevantRecipeUse(privateRecipe, mod)) { collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(), globalsToAccDeclare, accSupport, @@ -383,7 +383,7 @@ public: symTab); } }) - .Case([&](auto firstprivateRecipe) { + .Case([&](acc::FirstprivateRecipeOp firstprivateRecipe) { if (hasRelevantRecipeUse(firstprivateRecipe, mod)) { collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(), globalsToAccDeclare, accSupport, @@ -396,7 +396,7 @@ public: symTab); } }) - .Case([&](auto reductionRecipe) { + .Case([&](acc::ReductionRecipeOp reductionRecipe) { if (hasRelevantRecipeUse(reductionRecipe, mod)) { collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(), globalsToAccDeclare, accSupport, diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index 8af93335ca96..a28b365cdb4f 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -70,7 +70,7 @@ static void visit(Operation *op, DenseSet &visited) { // Traverse the operands / parent. TypeSwitch(op) - .Case([&visited](auto operation) { + .Case([&visited](OperationOp operation) { for (Value operand : operation.getOperandValues()) visit(operand.getDefiningOp(), visited); }) diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index d380c46f7fbe..6eabf6becf5d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -48,9 +48,8 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) { using tensor::InsertSliceOp; value = llvm::TypeSwitch(opResult.getOwner()) - .template Case( - [&](InsertSliceOp op) { return op.getDest(); }) - .template Case([&](ForOp forOp) { + .Case([&](InsertSliceOp op) { return op.getDest(); }) + .Case([&](ForOp forOp) { return isShapePreserving(forOp, opResult.getResultNumber()) ? forOp.getInitArgs()[opResult.getResultNumber()] : Value(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 1962538d804a..34e06bf52f70 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -341,11 +341,11 @@ LogicalResult spirv::CompositeConstructOp::verify() { // 3. Arrays (1 constituent for each array element) // 4. Vectors (1 constituent (sub-)element for each vector element) - auto coopElementType = - llvm::TypeSwitch(getType()) - .Case( - [](auto coopType) { return coopType.getElementType(); }) - .Default(nullptr); + auto coopElementType = llvm::TypeSwitch(getType()) + .Case([](spirv::CooperativeMatrixType coopType) { + return coopType.getElementType(); + }) + .Default(nullptr); // Case 1. -- matrices. if (coopElementType) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 53a48abe5ad0..342a47cdefbf 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -50,10 +50,10 @@ public: [this](auto concreteType) { addConcrete(concreteType); }) .Case( [this](auto concreteType) { add(concreteType.getElementType()); }) - .Case([this](SampledImageType concreteType) { + .Case([this](SampledImageType concreteType) { add(concreteType.getImageType()); }) - .Case([this](StructType concreteType) { + .Case([this](StructType concreteType) { for (Type elementType : concreteType.getElementTypes()) add(elementType); }) @@ -97,13 +97,13 @@ public: .Case( [this](auto concreteType) { addConcrete(concreteType); }) - .Case([this](ArrayType concreteType) { + .Case([this](ArrayType concreteType) { add(concreteType.getElementType()); }) - .Case([this](SampledImageType concreteType) { + .Case([this](SampledImageType concreteType) { add(concreteType.getImageType()); }) - .Case([this](StructType concreteType) { + .Case([this](StructType concreteType) { for (Type elementType : concreteType.getElementTypes()) add(elementType); }) @@ -195,9 +195,8 @@ Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) .Case([](auto type) { return type.getElementType(); }) - .Case([](MatrixType type) { return type.getColumnType(); }) - .Case( - [index](StructType type) { return type.getElementType(index); }) + .Case([](MatrixType type) { return type.getColumnType(); }) + .Case([index](StructType type) { return type.getElementType(index); }) .DefaultUnreachable("Invalid composite type"); } @@ -205,7 +204,7 @@ unsigned CompositeType::getNumElements() const { return TypeSwitch(*this) .Case( [](auto type) { return type.getNumElements(); }) - .Case([](MatrixType type) { return type.getNumColumns(); }) + .Case([](MatrixType type) { return type.getNumColumns(); }) .DefaultUnreachable("Invalid type for number of elements query"); } @@ -704,7 +703,7 @@ void SPIRVType::getCapabilities( std::optional SPIRVType::getSizeInBytes() { return TypeSwitch>(*this) - .Case([](ScalarType type) -> std::optional { + .Case([](ScalarType type) -> std::optional { // According to the SPIR-V spec: // "There is no physical size or bit pattern defined for values with // boolean type. If they are stored (in conjunction with OpVariable), @@ -717,7 +716,7 @@ std::optional SPIRVType::getSizeInBytes() { return std::nullopt; return bitWidth / 8; }) - .Case([](ArrayType type) -> std::optional { + .Case([](ArrayType type) -> std::optional { // Since array type may have an explicit stride declaration (in bytes), // we also include it in the calculation. auto elementType = cast(type.getElementType()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 816226749463..973c16e62bb1 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -628,9 +628,9 @@ static spirv::Dim convertRank(int64_t rank) { static spirv::ImageFormat getImageFormat(Type elementType) { return TypeSwitch(elementType) - .Case([](Float16Type) { return spirv::ImageFormat::R16f; }) - .Case([](Float32Type) { return spirv::ImageFormat::R32f; }) - .Case([](IntegerType intType) { + .Case([](Float16Type) { return spirv::ImageFormat::R16f; }) + .Case([](Float32Type) { return spirv::ImageFormat::R32f; }) + .Case([](IntegerType intType) { auto const isSigned = intType.isSigned() || intType.isSignless(); #define BIT_WIDTH_CASE(BIT_WIDTH) \ case BIT_WIDTH: \ diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 31b004abb015..ab6cdcd848c4 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2658,13 +2658,13 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, transform::TransformState &state) { int64_t numPayloads = llvm::TypeSwitch(getHandle().getType()) - .Case([&](auto x) { + .Case([&](TransformHandleTypeInterface x) { return llvm::range_size(state.getPayloadOps(getHandle())); }) - .Case([&](auto x) { + .Case([&](TransformValueHandleTypeInterface x) { return llvm::range_size(state.getPayloadValues(getHandle())); }) - .Case([&](auto x) { + .Case([&](TransformParamTypeInterface x) { return llvm::range_size(state.getParams(getHandle())); }) .DefaultUnreachable("unknown transform dialect type interface"); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 3a3231d51336..ccb3c01669f1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -113,8 +113,9 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type()); std::optional newMask = TypeSwitch>(maskOp) - .Case( - [&](auto createMaskOp) -> std::optional { + .Case( + [&](vector::CreateMaskOp createMaskOp) + -> std::optional { OperandRange maskOperands = createMaskOp.getOperands(); // The `vector.create_mask` op creates a mask arrangement // without any zeros at the front. Also, because @@ -134,8 +135,8 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, return vector::CreateMaskOp::create(rewriter, loc, newMaskType, newMaskOperands); }) - .Case([&](auto constantMaskOp) - -> std::optional { + .Case([&](vector::ConstantMaskOp constantMaskOp) + -> std::optional { // Take the shape of mask, compress its trailing dimension: SmallVector maskDimSizes(constantMaskOp.getMaskDimSizes()); int64_t &maskIndex = maskDimSizes.back(); @@ -144,8 +145,8 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, return vector::ConstantMaskOp::create(rewriter, loc, newMaskType, maskDimSizes); }) - .Case([&](auto constantOp) - -> std::optional { + .Case([&](arith::ConstantOp constantOp) + -> std::optional { // TODO: Support multiple dimensions. if (maskShape.size() != 1) return std::nullopt; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index ea93085849e0..c17d3862c0ea 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -881,7 +881,7 @@ static bool isLinearizable(Operation *op) { // As type legalization is done with vector.shape_cast, shape_cast // itself cannot be linearized (will create new shape_casts to linearize // ad infinitum). - .Case([&](auto) { return false; }) + .Case([&](vector::ShapeCastOp) { return false; }) // The operations // - vector.extract_strided_slice // - vector.extract @@ -891,18 +891,16 @@ static bool isLinearizable(Operation *op) { // vector.shuffle only supports fixed size vectors, so it is impossible to // use this approach to linearize these ops if they operate on scalable // vectors. - .Case( - [&](vector::ExtractStridedSliceOp extractOp) { - return !extractOp.getType().isScalable(); - }) - .Case( - [&](vector::InsertStridedSliceOp insertOp) { - return !insertOp.getType().isScalable(); - }) - .Case([&](vector::InsertOp insertOp) { + .Case([&](vector::ExtractStridedSliceOp extractOp) { + return !extractOp.getType().isScalable(); + }) + .Case([&](vector::InsertStridedSliceOp insertOp) { return !insertOp.getType().isScalable(); }) - .Case([&](vector::ExtractOp extractOp) { + .Case([&](vector::InsertOp insertOp) { + return !insertOp.getType().isScalable(); + }) + .Case([&](vector::ExtractOp extractOp) { return !extractOp.getSourceVectorType().isScalable(); }) .Default([&](auto) { return true; }); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c307fb441e3a..e123f9e21bbe 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -300,11 +300,12 @@ SmallVector vector::getMixedSizesXfer(bool hasTensorSemantics, RewriterBase &rewriter) { auto loc = xfer->getLoc(); - Value base = TypeSwitch(xfer) - .Case( - [&](auto readOp) { return readOp.getBase(); }) - .Case( - [&](auto writeOp) { return writeOp.getOperand(1); }); + Value base = + TypeSwitch(xfer) + .Case([&](vector::TransferReadOp readOp) { return readOp.getBase(); }) + .Case([&](vector::TransferWriteOp writeOp) { + return writeOp.getOperand(1); + }); SmallVector mixedSourceDims = hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index b46f6c7e751a..9a88310ccd3c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -460,43 +460,45 @@ LogicalResult LayoutInfoPropagation::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { TypeSwitch(op) - .Case( - [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); }) - .Case( - [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); }) - .Case([&](auto storeScatterOp) { + .Case( + [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); }) + .Case([&](xegpu::StoreNdOp storeNdOp) { + visitStoreNdOp(storeNdOp, operands, results); + }) + .Case([&](xegpu::StoreScatterOp storeScatterOp) { visitStoreScatterOp(storeScatterOp, operands, results); }) - .Case( - [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); }) - .Case([&](auto loadGatherOp) { + .Case([&](xegpu::LoadNdOp loadNdOp) { + visitLoadNdOp(loadNdOp, operands, results); + }) + .Case([&](xegpu::LoadGatherOp loadGatherOp) { visitLoadGatherOp(loadGatherOp, operands, results); }) - .Case([&](auto createDescOp) { + .Case([&](xegpu::CreateDescOp createDescOp) { visitCreateDescOp(createDescOp, operands, results); }) - .Case([&](auto updateNdOffsetOp) { + .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) { visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results); }) - .Case([&](auto prefetchNdOp) { + .Case([&](xegpu::PrefetchNdOp prefetchNdOp) { visitPrefetchNdOp(prefetchNdOp, operands, results); }) - .Case([&](auto transposeOp) { + .Case([&](vector::TransposeOp transposeOp) { visitTransposeOp(transposeOp, operands, results); }) - .Case([&](auto bitcastOp) { + .Case([&](vector::BitCastOp bitcastOp) { visitVectorBitcastOp(bitcastOp, operands, results); }) - .Case([&](auto reductionOp) { + .Case([&](vector::MultiDimReductionOp reductionOp) { visitVectorMultiReductionOp(reductionOp, operands, results); }) - .Case([&](auto broadcastOp) { + .Case([&](vector::BroadcastOp broadcastOp) { visitVectorBroadCastOp(broadcastOp, operands, results); }) - .Case([&](auto shapeCastOp) { + .Case([&](vector::ShapeCastOp shapeCastOp) { visitShapeCastOp(shapeCastOp, operands, results); }) - .Case([&](auto storeMatrixOp) { + .Case([&](xegpu::StoreMatrixOp storeMatrixOp) { visitStoreMatrixOp(storeMatrixOp, operands, results); }) // All other ops. @@ -1641,16 +1643,14 @@ void XeGPUPropagateLayoutPass::runOnOperation() { for (mlir::Operation &op : llvm::reverse(block->getOperations())) { LogicalResult r = success(); TypeSwitch(&op) - .Case( - [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) { - r = updateControlFlowOps(builder, branchTermOp, - getXeGPULayoutForValue); - }) - .Case( - [&](mlir::FunctionOpInterface funcOp) { - r = updateFunctionOpInterface(builder, funcOp, - getXeGPULayoutForValue); - }) + .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) { + r = updateControlFlowOps(builder, branchTermOp, + getXeGPULayoutForValue); + }) + .Case([&](mlir::FunctionOpInterface funcOp) { + r = updateFunctionOpInterface(builder, funcOp, + getXeGPULayoutForValue); + }) .Default([&](Operation *op) { r = updateOp(builder, op, getXeGPULayoutForValue); }); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e2f36ff37883..81455699421c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2157,16 +2157,16 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, return; TypeSwitch(loc) - .Case([&](OpaqueLoc loc) { + .Case([&](OpaqueLoc loc) { printLocationInternal(loc.getFallbackLocation(), pretty); }) - .Case([&](UnknownLoc loc) { + .Case([&](UnknownLoc loc) { if (pretty) os << "[unknown]"; else os << "unknown"; }) - .Case([&](FileLineColRange loc) { + .Case([&](FileLineColRange loc) { if (pretty) os << loc.getFilename().getValue(); else @@ -2184,7 +2184,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " << loc.getEndLine() << ':' << loc.getEndColumn(); }) - .Case([&](NameLoc loc) { + .Case([&](NameLoc loc) { printEscapedString(loc.getName()); // Print the child if it isn't unknown. @@ -2195,7 +2195,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, os << ')'; } }) - .Case([&](CallSiteLoc loc) { + .Case([&](CallSiteLoc loc) { Location caller = loc.getCaller(); Location callee = loc.getCallee(); if (!pretty) @@ -2218,7 +2218,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, if (!pretty) os << ")"; }) - .Case([&](FusedLoc loc) { + .Case([&](FusedLoc loc) { if (!pretty) os << "fused"; if (Attribute metadata = loc.getMetadata()) { @@ -2744,7 +2744,7 @@ void AsmPrinter::Impl::printType(Type type) { void AsmPrinter::Impl::printTypeImpl(Type type) { TypeSwitch(type) - .Case([&](OpaqueType opaqueTy) { + .Case([&](OpaqueType opaqueTy) { printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), opaqueTy.getTypeData()); }) @@ -2767,14 +2767,14 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { .Case([&](Type) { os << "f64"; }) .Case([&](Type) { os << "f80"; }) .Case([&](Type) { os << "f128"; }) - .Case([&](IntegerType integerTy) { + .Case([&](IntegerType integerTy) { if (integerTy.isSigned()) os << 's'; else if (integerTy.isUnsigned()) os << 'u'; os << 'i' << integerTy.getWidth(); }) - .Case([&](FunctionType funcTy) { + .Case([&](FunctionType funcTy) { os << '('; interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); os << ") -> "; @@ -2787,7 +2787,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { os << ')'; } }) - .Case([&](VectorType vectorTy) { + .Case([&](VectorType vectorTy) { auto scalableDims = vectorTy.getScalableDims(); os << "vector<"; auto vShape = vectorTy.getShape(); @@ -2804,7 +2804,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { printType(vectorTy.getElementType()); os << '>'; }) - .Case([&](RankedTensorType tensorTy) { + .Case([&](RankedTensorType tensorTy) { os << "tensor<"; printDimensionList(tensorTy.getShape()); if (!tensorTy.getShape().empty()) @@ -2817,12 +2817,12 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { } os << '>'; }) - .Case([&](UnrankedTensorType tensorTy) { + .Case([&](UnrankedTensorType tensorTy) { os << "tensor<*x"; printType(tensorTy.getElementType()); os << '>'; }) - .Case([&](MemRefType memrefTy) { + .Case([&](MemRefType memrefTy) { os << "memref<"; printDimensionList(memrefTy.getShape()); if (!memrefTy.getShape().empty()) @@ -2840,7 +2840,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { } os << '>'; }) - .Case([&](UnrankedMemRefType memrefTy) { + .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. @@ -2850,19 +2850,19 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { } os << '>'; }) - .Case([&](ComplexType complexTy) { + .Case([&](ComplexType complexTy) { os << "complex<"; printType(complexTy.getElementType()); os << '>'; }) - .Case([&](TupleType tupleTy) { + .Case([&](TupleType tupleTy) { os << "tuple<"; interleaveComma(tupleTy.getTypes(), [&](Type type) { printType(type); }); os << '>'; }) .Case([&](Type) { os << "none"; }) - .Case([&](GraphType graphTy) { + .Case([&](GraphType graphTy) { os << '('; interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); }); os << ") -> "; diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 782384999c70..2b73001bb55e 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -389,7 +389,7 @@ collectParentLayouts(Operation *leaf, for (Operation *parent = leaf->getParentOp(); parent != nullptr; parent = parent->getParentOp()) { llvm::TypeSwitch(parent) - .Case([&](ModuleOp op) { + .Case([&](ModuleOp op) { // Skip top-level module op unless it has a layout. Top-level module // without layout is most likely the one implicitly added by the // parser and it doesn't have location. Top-level null specification @@ -401,7 +401,7 @@ collectParentLayouts(Operation *leaf, if (opLocations) opLocations->push_back(op.getLoc()); }) - .Case([&](DataLayoutOpInterface op) { + .Case([&](DataLayoutOpInterface op) { specs.push_back(op.getDataLayoutSpec()); if (opLocations) opLocations->push_back(op.getLoc()); diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 159aa5468603..ede7d8a4006f 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -404,7 +404,7 @@ struct ByteCodeWriter { [](Type) { return PDLValue::Kind::Attribute; }) .Case( [](Type) { return PDLValue::Kind::Operation; }) - .Case([](pdl::RangeType rangeTy) { + .Case([](pdl::RangeType rangeTy) { if (isa(rangeTy.getElementType())) return PDLValue::Kind::TypeRange; return PDLValue::Kind::ValueRange; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 926ffd0e363a..82dfbcbfa4d4 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -852,7 +852,7 @@ bool Operator::hasAssemblyFormat() const { StringRef Operator::getAssemblyFormat() const { return TypeSwitch(def.getValueInit("assemblyFormat")) - .Case([&](auto *init) { return init->getValue(); }); + .Case([&](const StringInit *init) { return init->getValue(); }); } void Operator::print(llvm::raw_ostream &os) const { diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index 0f1bf83d1987..39d6e64c91d0 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -52,7 +52,7 @@ std::optional TypeConstraint::getBuilderCall() const { return std::nullopt; return TypeSwitch>( builderCall->getValue()) - .Case([&](auto *init) { + .Case([&](const llvm::StringInit *init) { StringRef value = init->getValue(); return value.empty() ? std::optional() : value; }) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index c05ec2e411d5..fd96395f7723 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -70,19 +70,19 @@ static inline LogicalResult interleaveCommaWithError(const Container &c, /// imply higher precedence. static FailureOr getOperatorPrecedence(Operation *operation) { return llvm::TypeSwitch>(operation) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) { return 12; }) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) { return 7; }) - .Case([&](auto op) { return 11; }) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) { return 5; }) - .Case([&](auto op) { return 11; }) - .Case([&](auto op) { return 6; }) - .Case([&](auto op) { return 16; }) - .Case([&](auto op) { return 16; }) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) -> FailureOr { + .Case([&](emitc::AddressOfOp op) { return 15; }) + .Case([&](emitc::AddOp op) { return 12; }) + .Case([&](emitc::ApplyOp op) { return 15; }) + .Case([&](emitc::BitwiseAndOp op) { return 7; }) + .Case([&](emitc::BitwiseLeftShiftOp op) { return 11; }) + .Case([&](emitc::BitwiseNotOp op) { return 15; }) + .Case([&](emitc::BitwiseOrOp op) { return 5; }) + .Case([&](emitc::BitwiseRightShiftOp op) { return 11; }) + .Case([&](emitc::BitwiseXorOp op) { return 6; }) + .Case([&](emitc::CallOp op) { return 16; }) + .Case([&](emitc::CallOpaqueOp op) { return 16; }) + .Case([&](emitc::CastOp op) { return 15; }) + .Case([&](emitc::CmpOp op) -> FailureOr { switch (op.getPredicate()) { case emitc::CmpPredicate::eq: case emitc::CmpPredicate::ne: @@ -97,18 +97,18 @@ static FailureOr getOperatorPrecedence(Operation *operation) { } return op->emitError("unsupported cmp predicate"); }) - .Case([&](auto op) { return 2; }) - .Case([&](auto op) { return 17; }) - .Case([&](auto op) { return 13; }) - .Case([&](auto op) { return 16; }) - .Case([&](auto op) { return 4; }) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) { return 3; }) - .Case([&](auto op) { return 13; }) - .Case([&](auto op) { return 13; }) - .Case([&](auto op) { return 12; }) - .Case([&](auto op) { return 15; }) - .Case([&](auto op) { return 15; }) + .Case([&](emitc::ConditionalOp op) { return 2; }) + .Case([&](emitc::ConstantOp op) { return 17; }) + .Case([&](emitc::DivOp op) { return 13; }) + .Case([&](emitc::LoadOp op) { return 16; }) + .Case([&](emitc::LogicalAndOp op) { return 4; }) + .Case([&](emitc::LogicalNotOp op) { return 15; }) + .Case([&](emitc::LogicalOrOp op) { return 3; }) + .Case([&](emitc::MulOp op) { return 13; }) + .Case([&](emitc::RemOp op) { return 13; }) + .Case([&](emitc::SubOp op) { return 12; }) + .Case([&](emitc::UnaryMinusOp op) { return 15; }) + .Case([&](emitc::UnaryPlusOp op) { return 15; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); } @@ -1835,7 +1835,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { LogicalResult status = llvm::TypeSwitch(&op) // Builtin ops. - .Case([&](auto op) { return printOperation(*this, op); }) + .Case([&](ModuleOp op) { return printOperation(*this, op); }) // CF ops. .Case( [&](auto op) { return printOperation(*this, op); }) diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index 253f3172aff9..046c7dd0fccf 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -647,11 +647,10 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { dialect.walk([&](mlir::Operation *op) { res = llvm::TypeSwitch(op) - .Case(([](irdl::DialectOp) { return success(); })) - .Case( - ([](irdl::OperationOp) { return success(); })) - .Case(([](irdl::TypeOp) { return success(); })) - .Case(([](irdl::OperandsOp op) -> LogicalResult { + .Case(([](irdl::DialectOp) { return success(); })) + .Case(([](irdl::OperationOp) { return success(); })) + .Case(([](irdl::TypeOp) { return success(); })) + .Case(([](irdl::OperandsOp op) -> LogicalResult { if (llvm::all_of( op.getVariadicity(), [](irdl::VariadicityAttr attr) { return attr.getValue() == irdl::Variadicity::single; @@ -660,7 +659,7 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { return op.emitError("IRDL C++ translation does not yet support " "variadic operations"); })) - .Case(([](irdl::ResultsOp op) -> LogicalResult { + .Case(([](irdl::ResultsOp op) -> LogicalResult { if (llvm::all_of( op.getVariadicity(), [](irdl::VariadicityAttr attr) { return attr.getValue() == irdl::Variadicity::single; @@ -669,9 +668,9 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { return op.emitError( "IRDL C++ translation does not yet support variadic results"); })) - .Case(([](irdl::AnyOp) { return success(); })) - .Case(([](irdl::RegionOp) { return success(); })) - .Case(([](irdl::RegionsOp) { return success(); })) + .Case(([](irdl::AnyOp) { return success(); })) + .Case(([](irdl::RegionOp) { return success(); })) + .Case(([](irdl::RegionsOp) { return success(); })) .Default([](mlir::Operation *op) -> LogicalResult { return op->emitError("IRDL C++ translation does not yet support " "translation of ") diff --git a/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp b/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp index 2a991662738b..55563e348086 100644 --- a/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp +++ b/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp @@ -27,10 +27,10 @@ void registerIRDLToCppTranslation() { "irdl-to-cpp", "translate IRDL dialect definitions to C++ definitions", [](Operation *op, raw_ostream &output) { return TypeSwitch(op) - .Case([&](irdl::DialectOp dialectOp) { + .Case([&](irdl::DialectOp dialectOp) { return irdl::translateIRDLDialectToCpp(dialectOp, output); }) - .Case([&](ModuleOp moduleOp) { + .Case([&](ModuleOp moduleOp) { for (Operation &op : moduleOp.getBody()->getOperations()) if (auto dialectOp = llvm::dyn_cast(op)) if (failed( diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp index 64fcb5862912..bad6ce1e63cb 100644 --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -284,7 +284,7 @@ DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) { llvm::DINode *result = TypeSwitch(attr) - .Case([&](auto attr) { + .Case([&](DICompositeTypeAttr attr) { auto temporary = translateTemporaryImpl(attr); setRecursivePlaceholder(temporary.get()); // Must call `translateImpl` directly instead of `translate` to @@ -293,7 +293,7 @@ DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) { temporary->replaceAllUsesWith(concrete); return concrete; }) - .Case([&](auto attr) { + .Case([&](DISubprogramAttr attr) { auto temporary = translateTemporaryImpl(attr); setRecursivePlaceholder(temporary.get()); // Must call `translateImpl` directly instead of `translate` to diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 31636abfff27..eb3a48a213de 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -333,16 +333,16 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, for (auto flagAttr : flags.getAsRange()) { llvm::Metadata *valueMetadata = llvm::TypeSwitch(flagAttr.getValue()) - .Case([&](auto strAttr) { + .Case([&](StringAttr strAttr) { return llvm::MDString::get(builder.getContext(), strAttr.getValue()); }) - .Case([&](auto intAttr) { + .Case([&](IntegerAttr intAttr) { return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( llvm::Type::getInt32Ty(builder.getContext()), intAttr.getInt())); }) - .Case([&](auto arrayAttr) { + .Case([&](ArrayAttr arrayAttr) { return convertModuleFlagValue(flagAttr.getKey().getValue(), arrayAttr, builder, moduleTranslation); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index ccefe1715604..672e87790456 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -448,7 +448,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); }) .Case( [&](auto op) { checkDepend(op, result); }) - .Case([&](auto op) { checkDepend(op, result); }) + .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); }) .Case([&](omp::TargetOp op) { checkAllocate(op, result); checkBare(op, result); diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp index 8d6fffcca45f..ca3301aa509a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -420,18 +420,18 @@ public: return translateTypeOffsetOp(typeOffsetOp, builder, moduleTranslation); }) - .Case([&](GatherOp gatherOp) { + .Case([&](GatherOp gatherOp) { return translateGatherOp(gatherOp, builder, moduleTranslation); }) - .Case([&](MaskedLoadOp maskedLoadOp) { + .Case([&](MaskedLoadOp maskedLoadOp) { return translateMaskedLoadOp(maskedLoadOp, builder, moduleTranslation); }) - .Case([&](MaskedStoreOp maskedStoreOp) { + .Case([&](MaskedStoreOp maskedStoreOp) { return translateMaskedStoreOp(maskedStoreOp, builder, moduleTranslation); }) - .Case([&](ScatterOp scatterOp) { + .Case([&](ScatterOp scatterOp) { return translateScatterOp(scatterOp, builder, moduleTranslation); }) .Default([&](Operation *op) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 79d65ffe4e55..9ddb0c1f1322 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1741,20 +1741,20 @@ static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, ModuleTranslation &moduleTranslation, Location loc) { return llvm::TypeSwitch(namedAttr.getValue()) - .Case([&](auto typeAttr) { + .Case([&](TypeAttr typeAttr) { attrBuilder.addTypeAttr( llvmKind, moduleTranslation.convertType(typeAttr.getValue())); return success(); }) - .Case([&](auto intAttr) { + .Case([&](IntegerAttr intAttr) { attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); return success(); }) - .Case([&](auto) { + .Case([&](UnitAttr) { attrBuilder.addAttribute(llvmKind); return success(); }) - .Case([&](auto rangeAttr) { + .Case([&](LLVM::ConstantRangeAttr rangeAttr) { attrBuilder.addConstantRangeAttr( llvmKind, llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 81f4c881bacc..ee0b19006fc6 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -25,7 +25,7 @@ static void addOperands(Operation *op, SetVector &operandSet) { if (!op) return; TypeSwitch(op) - .Case([&](linalg::LinalgOp linalgOp) { + .Case([&](linalg::LinalgOp linalgOp) { SmallVector inputOperands = linalgOp.getDpsInputs(); operandSet.insert_range(inputOperands); }) diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index d3c0f68d8efa..348026eb99b0 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -81,14 +81,15 @@ private: /// `CustomDirective` with a single parameter argument or `RefDirective`. static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) { return TypeSwitch(el) - .Case([&](auto custom) { + .Case([&](CustomDirective *custom) { FailureOr maybeParam = custom->template getFrontAs(); return *maybeParam; }) - .Case([&](auto param) { return param; }) - .Case( - [&](auto ref) { return cast(ref->getArg()); }) + .Case([&](ParameterElement *param) { return param; }) + .Case([&](RefDirective *ref) { + return cast(ref->getArg()); + }) .DefaultUnreachable("unexpected struct element type"); }