diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md index bf639b81dd64..5ba7adc2463a 100644 --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -16,6 +16,14 @@ are expected to closely match the corresponding LLVM IR instructions and intrinsics. This minimizes the dependency on LLVM IR libraries in MLIR as well as reduces the churn in case of changes. +Note that many different dialects can be lowered to LLVM but are provided as +different sets of patterns and have different passes available to mlir-opt. +However, this is primarily useful for testing and prototyping, and using the +collection of patterns together is highly recommended. One place this is +important and visible is the ControlFlow dialect's branching operations which +will fail to apply if their types mismatch with the blocks they jump to in the +parent op. + SPIR-V to LLVM dialect conversion has a [dedicated document](SPIRVToLLVMDialectConversion.md). diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index cc97ef73d7bf..89012704541d 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" #include using namespace mlir; @@ -71,34 +72,108 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { } }; -// Base class for LLVM IR lowering terminator operations with successors. -template -struct OneToOneLLVMTerminatorLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Base = OneToOneLLVMTerminatorLowering; +/// The cf->LLVM lowerings for branching ops require that the blocks they jump +/// to first have updated types which should be handled by a pattern operating +/// on the parent op. +static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, + ValueRange operands, + ValueRange blockArgs, Location loc, + llvm::StringRef messagePrefix) { + for (const auto &idxAndTypes : + llvm::enumerate(llvm::zip(blockArgs, operands))) { + int64_t i = idxAndTypes.index(); + Value argValue = + rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); + Type operandType = std::get<1>(idxAndTypes.value()).getType(); + // In the case of an invalid jump, the block argument will have been + // remapped to an UnrealizedConversionCast. In the case of a valid jump, + // there might still be a no-op conversion cast with both types being equal. + // Consider both of these details to see if the jump would be invalid. + if (auto op = dyn_cast_or_null( + argValue.getDefiningOp())) { + if (op.getOperandTypes().front() != operandType) { + return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { + diag << messagePrefix; + diag << "mismatched types from operand # " << i << " "; + diag << operandType; + diag << " not compatible with destination block argument type "; + diag << argValue.getType(); + diag << " which should be converted with the parent op."; + }); + } + } + } + return success(); +} + +/// Ensure that all block types were updated and then create an LLVM::BrOp +struct BranchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), - op->getSuccessors(), op->getAttrs()); + if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), + op.getSuccessor()->getArguments(), + op.getLoc(), + /*messagePrefix=*/""))) + return failure(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); return success(); } }; -// FIXME: this should be tablegen'ed as well. -struct BranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; +/// Ensure that all block types were updated and then create an LLVM::CondBrOp +struct CondBranchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, + typename cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), + op.getFalseDest()->getArguments(), + op.getLoc(), "in false case branch "))) + return failure(); + if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), + op.getTrueDest()->getArguments(), + op.getLoc(), "in true case branch "))) + return failure(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); + return success(); + } }; -struct CondBranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; -}; -struct SwitchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; + +/// Ensure that all block types were updated and then create an LLVM::SwitchOp +struct SwitchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), + op.getDefaultDestination()->getArguments(), + op.getLoc(), "in switch default case "))) + return failure(); + + for (const auto &i : llvm::enumerate( + llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { + if (failed(verifyMatchingValues( + rewriter, std::get<0>(i.value()), + std::get<1>(i.value())->getArguments(), op.getLoc(), + "in switch case " + std::to_string(i.index()) + " "))) { + return failure(); + } + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); + return success(); + } }; } // namespace diff --git a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir new file mode 100644 index 000000000000..a2afa233a26e --- /dev/null +++ b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s + +func.func @name(%flag: i32, %pred: i1){ + // Test cf.br lowering failure with type mismatch + // CHECK: cf.br + %c0 = arith.constant 0 : index + cf.br ^bb1(%c0 : index) + + // Test cf.cond_br lowering failure with type mismatch in false_dest + // CHECK: cf.cond_br + ^bb1(%0: index): // 2 preds: ^bb0, ^bb2 + %c1 = arith.constant 1 : i1 + %c2 = arith.constant 1 : index + cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index) + + // Test cf.cond_br lowering failure with type mismatch in true_dest + // CHECK: cf.cond_br + ^bb2(%1: i1): + %c3 = arith.constant 1 : i1 + %c4 = arith.constant 1 : index + cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1) + + // Test cf.switch lowering failure with type mismatch in default case + // CHECK: cf.switch + ^bb3(%2: index): // pred: ^bb1 + %c5 = arith.constant 1 : i1 + %c6 = arith.constant 1 : index + cf.switch %flag : i32, [ + default: ^bb1(%c6 : index), + 42: ^bb4(%c5 : i1) + ] + + // Test cf.switch lowering failure with type mismatch in non-default case + // CHECK: cf.switch + ^bb4(%3: i1): // pred: ^bb1 + %c7 = arith.constant 1 : i1 + %c8 = arith.constant 1 : index + cf.switch %flag : i32, [ + default: ^bb2(%c7 : i1), + 41: ^bb1(%c8 : index) + ] + }