//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" using namespace mlir; static ParseResult parseAlternativesOpSelectedRegion( OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional &selectedRegionParam); static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam); #define GET_OP_CLASSES #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" #define DEBUG_TYPE "transform-tune" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") //===----------------------------------------------------------------------===// // KnobOp //===----------------------------------------------------------------------===// void transform::tune::KnobOp::getEffects( SmallVectorImpl &effects) { producesHandle(getOperation()->getOpResults(), effects); onlyReadsPayload(effects); } DiagnosedSilenceableFailure transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getSelected()) { results.setParams(llvm::cast(getResult()), *getSelected()); return DiagnosedSilenceableFailure::success(); } return emitDefiniteFailure() << "non-deterministic choice " << getName() << " is only resolved through providing a `selected` attr"; } LogicalResult transform::tune::KnobOp::verify() { if (auto selected = getSelected()) { if (auto optionsArray = dyn_cast(getOptions())) { if (!llvm::is_contained(optionsArray, selected)) return emitOpError("provided `selected` attribute is not an element of " "`options` array of attributes"); } else LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected << " is an element of `options` attribute " << getOptions()); } return success(); } //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// static ParseResult parseAlternativesOpSelectedRegion( OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional &selectedRegionParam) { size_t selectedRegionIdx; OptionalParseResult attrParseRes = parser.parseOptionalInteger(selectedRegionIdx); if (attrParseRes.has_value()) { if (failed(*attrParseRes)) return failure(); selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx); return success(); } OpAsmParser::UnresolvedOperand param; auto paramParseRes = parser.parseOptionalOperand(param); if (paramParseRes.has_value()) { if (failed(*paramParseRes)) return failure(); selectedRegionParam = param; return success(); } return parser.emitError(parser.getCurrentLocation()) << "expected either an integer attribute or a transform.param operand"; } static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam) { if (selectedRegionAttr) printer << selectedRegionAttr.getValue(); if (selectedRegionParam) printer << selectedRegionParam; } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( RegionSuccessor successor) { // No operands will be forwarded to the region(s). return getOperands().slice(0, 0); } void transform::tune::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (point.isParent()) if (auto selectedRegionIdx = getSelectedRegionAttr()) regions.emplace_back( &getAlternatives()[selectedRegionIdx->getSExtValue()]); else for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative); else regions.push_back(RegionSuccessor::parent()); } ValueRange transform::tune::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) { return successor.isParent() ? ValueRange(getOperation()->getResults()) : ValueRange(); } void transform::tune::AlternativesOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; bounds.reserve(getNumRegions()); if (auto selectedRegionIdx = getSelectedRegionAttr()) { bounds.resize(getNumRegions(), InvocationBounds(0, 0)); bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1); } else { bounds.resize(getNumRegions(), InvocationBounds(0, 1)); } } void transform::tune::AlternativesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getSelectedRegionParamMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); // TODO: should effects from regions be forwarded? } DiagnosedSilenceableFailure transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { std::optional selectedRegionIdx; if (auto selectedRegionAttr = getSelectedRegionAttr()) selectedRegionIdx = selectedRegionAttr->getSExtValue(); if (Value selectedRegionParam = getSelectedRegionParam()) { ArrayRef associatedAttrs = state.getParams(selectedRegionParam); IntegerAttr selectedRegionAttr; if (associatedAttrs.size() != 1 || !(selectedRegionAttr = dyn_cast(associatedAttrs[0]))) return emitDefiniteFailure() << "param should hold exactly one integer attribute, got: " << associatedAttrs[0]; selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue(); } if (!selectedRegionIdx) return emitDefiniteFailure() << "non-deterministic choice " << getName() << " is only resolved through providing a " "`selected_region` attr/param"; if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions()) return emitDefiniteFailure() << "'selected_region' attribute/param specifies region at index " << *selectedRegionIdx << " while op has only " << getNumRegions() << " regions"; Region &selectedRegion = getRegion(*selectedRegionIdx); auto scope = state.make_region_scope(selectedRegion); Block &block = selectedRegion.front(); // Apply the region's ops one by one. for (Operation &transform : block.without_terminator()) { DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); if (result.isDefiniteFailure()) return result; if (result.isSilenceableFailure()) { for (const auto &res : getResults()) results.set(res, {}); return result; } } // Forward the operation mapping for values yielded from the region to the // values produced by the alternatives op. transform::detail::forwardTerminatorOperands(&block, state, results); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::tune::AlternativesOp::verify() { for (auto *region : getRegions()) { auto yieldTerminator = llvm::dyn_cast_if_present(region->front().back()); if (!yieldTerminator) return emitOpError() << "expected '" << transform::YieldOp::getOperationName() << "' as terminator"; if (yieldTerminator->getNumOperands() != getNumResults()) return yieldTerminator.emitOpError() << "expected terminator to have as many operands as the parent op " "has results"; for (auto [i, operandType, resultType] : llvm::zip_equal( llvm::seq(0, yieldTerminator->getNumOperands()), yieldTerminator->getOperands().getType(), getResultTypes())) { if (operandType == resultType) continue; return yieldTerminator.emitOpError() << "the type of the terminator operand #" << i << " must match the type of the corresponding parent op result (" << operandType << " vs " << resultType << ")"; } } if (auto selectedRegionAttr = getSelectedRegionAttr()) { int64_t regionIdx = selectedRegionAttr->getSExtValue(); if (regionIdx < 0 || regionIdx >= getNumRegions()) return emitOpError() << "'selected_region' attribute specifies region at index " << regionIdx << " while op has only " << getNumRegions() << " regions"; } return success(); }