[mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below and further modifications to GPUBase.td and OpenMPOpsInterfaces.td. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy 2. Run clang-tidy over your entire codebase while disabling all checks and enabling the one relevant one. Run on all header files also. 3. Delete .inc files that were also modified, so the next build rebuilds them to a pure state. ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D151542
This commit is contained in:
parent
7c52520c8d
commit
68f58812e3
@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
mlir::LogicalResult ConstantOp::verify() {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return emitOpError("return type must match the one of the attached value "
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
mlir::LogicalResult ConstantOp::verify() {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return emitOpError("return type must match the one of the attached value "
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
mlir::LogicalResult ConstantOp::verify() {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return emitOpError("return type must match the one of the attached value "
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
|||||||
if (inputs.size() != 1 || outputs.size() != 1)
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
return false;
|
return false;
|
||||||
// The inputs must be Tensors with the same element type.
|
// The inputs must be Tensors with the same element type.
|
||||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||||
if (!input || !output || input.getElementType() != output.getElementType())
|
if (!input || !output || input.getElementType() != output.getElementType())
|
||||||
return false;
|
return false;
|
||||||
// The shape is required to match if both types are ranked.
|
// The shape is required to match if both types are ranked.
|
||||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransposeOp::inferShapes() {
|
void TransposeOp::inferShapes() {
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
|||||||
/// operands inferred.
|
/// operands inferred.
|
||||||
static bool allOperandsInferred(Operation *op) {
|
static bool allOperandsInferred(Operation *op) {
|
||||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||||
return operandType.isa<RankedTensorType>();
|
return llvm::isa<RankedTensorType>(operandType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
|||||||
/// shaped result.
|
/// shaped result.
|
||||||
static bool returnsDynamicShape(Operation *op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||||
return !resultType.isa<RankedTensorType>();
|
return !llvm::isa<RankedTensorType>(resultType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
mlir::LogicalResult ConstantOp::verify() {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return emitOpError("return type must match the one of the attached value "
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
|||||||
if (inputs.size() != 1 || outputs.size() != 1)
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
return false;
|
return false;
|
||||||
// The inputs must be Tensors with the same element type.
|
// The inputs must be Tensors with the same element type.
|
||||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||||
if (!input || !output || input.getElementType() != output.getElementType())
|
if (!input || !output || input.getElementType() != output.getElementType())
|
||||||
return false;
|
return false;
|
||||||
// The shape is required to match if both types are ranked.
|
// The shape is required to match if both types are ranked.
|
||||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransposeOp::inferShapes() {
|
void TransposeOp::inferShapes() {
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
|||||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter,
|
||||||
LoopIterationFn processIteration) {
|
LoopIterationFn processIteration) {
|
||||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||||||
|
|
||||||
// When lowering the constant operation, we allocate and assign the constant
|
// When lowering the constant operation, we allocate and assign the constant
|
||||||
// values to a corresponding memref allocation.
|
// values to a corresponding memref allocation.
|
||||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||||
|
|
||||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
|||||||
target.addIllegalDialect<toy::ToyDialect>();
|
target.addIllegalDialect<toy::ToyDialect>();
|
||||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||||
return llvm::none_of(op->getOperandTypes(),
|
return llvm::none_of(op->getOperandTypes(),
|
||||||
[](Type type) { return type.isa<TensorType>(); });
|
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||||
});
|
});
|
||||||
|
|
||||||
// Now that the conversion target has been defined, we just need to provide
|
// Now that the conversion target has been defined, we just need to provide
|
||||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
|||||||
/// operands inferred.
|
/// operands inferred.
|
||||||
static bool allOperandsInferred(Operation *op) {
|
static bool allOperandsInferred(Operation *op) {
|
||||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||||
return operandType.isa<RankedTensorType>();
|
return llvm::isa<RankedTensorType>(operandType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
|||||||
/// shaped result.
|
/// shaped result.
|
||||||
static bool returnsDynamicShape(Operation *op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||||
return !resultType.isa<RankedTensorType>();
|
return !llvm::isa<RankedTensorType>(resultType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
mlir::LogicalResult ConstantOp::verify() {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return emitOpError("return type must match the one of the attached value "
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
|||||||
if (inputs.size() != 1 || outputs.size() != 1)
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
return false;
|
return false;
|
||||||
// The inputs must be Tensors with the same element type.
|
// The inputs must be Tensors with the same element type.
|
||||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||||
if (!input || !output || input.getElementType() != output.getElementType())
|
if (!input || !output || input.getElementType() != output.getElementType())
|
||||||
return false;
|
return false;
|
||||||
// The shape is required to match if both types are ranked.
|
// The shape is required to match if both types are ranked.
|
||||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransposeOp::inferShapes() {
|
void TransposeOp::inferShapes() {
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
|||||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter,
|
||||||
LoopIterationFn processIteration) {
|
LoopIterationFn processIteration) {
|
||||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||||||
|
|
||||||
// When lowering the constant operation, we allocate and assign the constant
|
// When lowering the constant operation, we allocate and assign the constant
|
||||||
// values to a corresponding memref allocation.
|
// values to a corresponding memref allocation.
|
||||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||||
|
|
||||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
|||||||
target.addIllegalDialect<toy::ToyDialect>();
|
target.addIllegalDialect<toy::ToyDialect>();
|
||||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||||
return llvm::none_of(op->getOperandTypes(),
|
return llvm::none_of(op->getOperandTypes(),
|
||||||
[](Type type) { return type.isa<TensorType>(); });
|
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||||
});
|
});
|
||||||
|
|
||||||
// Now that the conversion target has been defined, we just need to provide
|
// Now that the conversion target has been defined, we just need to provide
|
||||||
|
@ -61,7 +61,7 @@ public:
|
|||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
|||||||
/// operands inferred.
|
/// operands inferred.
|
||||||
static bool allOperandsInferred(Operation *op) {
|
static bool allOperandsInferred(Operation *op) {
|
||||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||||
return operandType.isa<RankedTensorType>();
|
return llvm::isa<RankedTensorType>(operandType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
|||||||
/// shaped result.
|
/// shaped result.
|
||||||
static bool returnsDynamicShape(Operation *op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||||
return !resultType.isa<RankedTensorType>();
|
return !llvm::isa<RankedTensorType>(resultType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -101,7 +101,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
|||||||
|
|
||||||
// If the type is a function type, it contains the input and result types of
|
// If the type is a function type, it contains the input and result types of
|
||||||
// this operation.
|
// this operation.
|
||||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@ -179,9 +179,9 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
|||||||
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
||||||
mlir::Attribute opaqueValue,
|
mlir::Attribute opaqueValue,
|
||||||
mlir::Operation *op) {
|
mlir::Operation *op) {
|
||||||
if (type.isa<mlir::TensorType>()) {
|
if (llvm::isa<mlir::TensorType>(type)) {
|
||||||
// Check that the value is an elements attribute.
|
// Check that the value is an elements attribute.
|
||||||
auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
|
auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
|
||||||
if (!attrValue)
|
if (!attrValue)
|
||||||
return op->emitError("constant of TensorType must be initialized by "
|
return op->emitError("constant of TensorType must be initialized by "
|
||||||
"a DenseFPElementsAttr, got ")
|
"a DenseFPElementsAttr, got ")
|
||||||
@ -189,13 +189,13 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
|||||||
|
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = type.dyn_cast<mlir::RankedTensorType>();
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the
|
// Check that the rank of the attribute type matches the rank of the
|
||||||
// constant result type.
|
// constant result type.
|
||||||
auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
|
auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op->emitOpError("return type must match the one of the attached "
|
return op->emitOpError("return type must match the one of the attached "
|
||||||
"value attribute: ")
|
"value attribute: ")
|
||||||
@ -213,11 +213,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
|||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
auto resultType = type.cast<StructType>();
|
auto resultType = llvm::cast<StructType>(type);
|
||||||
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
|
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
|
||||||
|
|
||||||
// Verify that the initializer is an Array.
|
// Verify that the initializer is an Array.
|
||||||
auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
|
auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
|
||||||
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
|
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
|
||||||
return op->emitError("constant of StructType must be initialized by an "
|
return op->emitError("constant of StructType must be initialized by an "
|
||||||
"ArrayAttr with the same number of elements, got ")
|
"ArrayAttr with the same number of elements, got ")
|
||||||
@ -283,8 +283,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
|||||||
if (inputs.size() != 1 || outputs.size() != 1)
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
return false;
|
return false;
|
||||||
// The inputs must be Tensors with the same element type.
|
// The inputs must be Tensors with the same element type.
|
||||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||||
if (!input || !output || input.getElementType() != output.getElementType())
|
if (!input || !output || input.getElementType() != output.getElementType())
|
||||||
return false;
|
return false;
|
||||||
// The shape is required to match if both types are ranked.
|
// The shape is required to match if both types are ranked.
|
||||||
@ -426,8 +426,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
@ -442,7 +442,7 @@ mlir::LogicalResult ReturnOp::verify() {
|
|||||||
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
||||||
mlir::Value input, size_t index) {
|
mlir::Value input, size_t index) {
|
||||||
// Extract the result type from the input type.
|
// Extract the result type from the input type.
|
||||||
StructType structTy = input.getType().cast<StructType>();
|
StructType structTy = llvm::cast<StructType>(input.getType());
|
||||||
assert(index < structTy.getNumElementTypes());
|
assert(index < structTy.getNumElementTypes());
|
||||||
mlir::Type resultType = structTy.getElementTypes()[index];
|
mlir::Type resultType = structTy.getElementTypes()[index];
|
||||||
|
|
||||||
@ -451,7 +451,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult StructAccessOp::verify() {
|
mlir::LogicalResult StructAccessOp::verify() {
|
||||||
StructType structTy = getInput().getType().cast<StructType>();
|
StructType structTy = llvm::cast<StructType>(getInput().getType());
|
||||||
size_t indexValue = getIndex();
|
size_t indexValue = getIndex();
|
||||||
if (indexValue >= structTy.getNumElementTypes())
|
if (indexValue >= structTy.getNumElementTypes())
|
||||||
return emitOpError()
|
return emitOpError()
|
||||||
@ -474,14 +474,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransposeOp::inferShapes() {
|
void TransposeOp::inferShapes() {
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult TransposeOp::verify() {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
@ -598,7 +598,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
// Check that the type is either a TensorType or another StructType.
|
// Check that the type is either a TensorType or another StructType.
|
||||||
if (!elementType.isa<mlir::TensorType, StructType>()) {
|
if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
|
||||||
parser.emitError(typeLoc, "element type for a struct must either "
|
parser.emitError(typeLoc, "element type for a struct must either "
|
||||||
"be a TensorType or a StructType, got: ")
|
"be a TensorType or a StructType, got: ")
|
||||||
<< elementType;
|
<< elementType;
|
||||||
@ -619,7 +619,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
|
|||||||
void ToyDialect::printType(mlir::Type type,
|
void ToyDialect::printType(mlir::Type type,
|
||||||
mlir::DialectAsmPrinter &printer) const {
|
mlir::DialectAsmPrinter &printer) const {
|
||||||
// Currently the only toy type is a struct type.
|
// Currently the only toy type is a struct type.
|
||||||
StructType structType = type.cast<StructType>();
|
StructType structType = llvm::cast<StructType>(type);
|
||||||
|
|
||||||
// Print the struct type according to the parser format.
|
// Print the struct type according to the parser format.
|
||||||
printer << "struct<";
|
printer << "struct<";
|
||||||
@ -653,9 +653,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
|
|||||||
mlir::Attribute value,
|
mlir::Attribute value,
|
||||||
mlir::Type type,
|
mlir::Type type,
|
||||||
mlir::Location loc) {
|
mlir::Location loc) {
|
||||||
if (type.isa<StructType>())
|
if (llvm::isa<StructType>(type))
|
||||||
return builder.create<StructConstantOp>(loc, type,
|
return builder.create<StructConstantOp>(loc, type,
|
||||||
value.cast<mlir::ArrayAttr>());
|
llvm::cast<mlir::ArrayAttr>(value));
|
||||||
return builder.create<ConstantOp>(loc, type,
|
return builder.create<ConstantOp>(loc, type,
|
||||||
value.cast<mlir::DenseElementsAttr>());
|
llvm::cast<mlir::DenseElementsAttr>(value));
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
|||||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter,
|
||||||
LoopIterationFn processIteration) {
|
LoopIterationFn processIteration) {
|
||||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||||||
|
|
||||||
// When lowering the constant operation, we allocate and assign the constant
|
// When lowering the constant operation, we allocate and assign the constant
|
||||||
// values to a corresponding memref allocation.
|
// values to a corresponding memref allocation.
|
||||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||||
|
|
||||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
|||||||
target.addIllegalDialect<toy::ToyDialect>();
|
target.addIllegalDialect<toy::ToyDialect>();
|
||||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||||
return llvm::none_of(op->getOperandTypes(),
|
return llvm::none_of(op->getOperandTypes(),
|
||||||
[](Type type) { return type.isa<TensorType>(); });
|
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||||
});
|
});
|
||||||
|
|
||||||
// Now that the conversion target has been defined, we just need to provide
|
// Now that the conversion target has been defined, we just need to provide
|
||||||
|
@ -61,7 +61,7 @@ public:
|
|||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
|||||||
/// operands inferred.
|
/// operands inferred.
|
||||||
static bool allOperandsInferred(Operation *op) {
|
static bool allOperandsInferred(Operation *op) {
|
||||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||||
return operandType.isa<RankedTensorType>();
|
return llvm::isa<RankedTensorType>(operandType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
|||||||
/// shaped result.
|
/// shaped result.
|
||||||
static bool returnsDynamicShape(Operation *op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||||
return !resultType.isa<RankedTensorType>();
|
return !llvm::isa<RankedTensorType>(resultType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -31,7 +31,8 @@ OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
|
|||||||
|
|
||||||
/// Fold simple struct access operations that access into a constant.
|
/// Fold simple struct access operations that access into a constant.
|
||||||
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
|
||||||
auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
|
auto structAttr =
|
||||||
|
llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
|
||||||
if (!structAttr)
|
if (!structAttr)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -62,19 +62,19 @@ class FileLineColLocBreakpointManager
|
|||||||
public:
|
public:
|
||||||
Breakpoint *match(const Action &action) const override {
|
Breakpoint *match(const Action &action) const override {
|
||||||
for (const IRUnit &unit : action.getContextIRUnits()) {
|
for (const IRUnit &unit : action.getContextIRUnits()) {
|
||||||
if (auto *op = unit.dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
|
||||||
if (auto match = matchFromLocation(op->getLoc()))
|
if (auto match = matchFromLocation(op->getLoc()))
|
||||||
return *match;
|
return *match;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto *block = unit.dyn_cast<Block *>()) {
|
if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
|
||||||
for (auto &op : block->getOperations()) {
|
for (auto &op : block->getOperations()) {
|
||||||
if (auto match = matchFromLocation(op.getLoc()))
|
if (auto match = matchFromLocation(op.getLoc()))
|
||||||
return *match;
|
return *match;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (Region *region = unit.dyn_cast<Region *>()) {
|
if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
|
||||||
if (auto match = matchFromLocation(region->getLoc()))
|
if (auto match = matchFromLocation(region->getLoc()))
|
||||||
return *match;
|
return *match;
|
||||||
continue;
|
continue;
|
||||||
|
@ -112,25 +112,25 @@ class MMAMatrixOf<list<Type> allowedTypes> :
|
|||||||
// Types for all sparse handles.
|
// Types for all sparse handles.
|
||||||
def GPU_SparseEnvHandle :
|
def GPU_SparseEnvHandle :
|
||||||
DialectType<GPU_Dialect,
|
DialectType<GPU_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">,
|
CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
|
||||||
"sparse environment handle type">,
|
"sparse environment handle type">,
|
||||||
BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
|
BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
|
||||||
|
|
||||||
def GPU_SparseDnVecHandle :
|
def GPU_SparseDnVecHandle :
|
||||||
DialectType<GPU_Dialect,
|
DialectType<GPU_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">,
|
CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
|
||||||
"dense vector handle type">,
|
"dense vector handle type">,
|
||||||
BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
|
BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
|
||||||
|
|
||||||
def GPU_SparseDnMatHandle :
|
def GPU_SparseDnMatHandle :
|
||||||
DialectType<GPU_Dialect,
|
DialectType<GPU_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">,
|
CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
|
||||||
"dense matrix handle type">,
|
"dense matrix handle type">,
|
||||||
BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
|
BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
|
||||||
|
|
||||||
def GPU_SparseSpMatHandle :
|
def GPU_SparseSpMatHandle :
|
||||||
DialectType<GPU_Dialect,
|
DialectType<GPU_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">,
|
CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
|
||||||
"sparse matrix handle type">,
|
"sparse matrix handle type">,
|
||||||
BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
|
BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
|
|||||||
/*methodName=*/"getDeclareTargetDeviceType",
|
/*methodName=*/"getDeclareTargetDeviceType",
|
||||||
(ins), [{}], [{
|
(ins), [{}], [{
|
||||||
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
||||||
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
|
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
|
||||||
return dAttr.getDeviceType().getValue();
|
return dAttr.getDeviceType().getValue();
|
||||||
return {};
|
return {};
|
||||||
}]>,
|
}]>,
|
||||||
@ -108,7 +108,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
|
|||||||
/*methodName=*/"getDeclareTargetCaptureClause",
|
/*methodName=*/"getDeclareTargetCaptureClause",
|
||||||
(ins), [{}], [{
|
(ins), [{}], [{
|
||||||
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
||||||
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
|
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
|
||||||
return dAttr.getCaptureClause().getValue();
|
return dAttr.getCaptureClause().getValue();
|
||||||
return {};
|
return {};
|
||||||
}]>
|
}]>
|
||||||
|
@ -115,7 +115,7 @@ public:
|
|||||||
static bool classof(Type type);
|
static bool classof(Type type);
|
||||||
|
|
||||||
/// Allow implicit conversion to ShapedType.
|
/// Allow implicit conversion to ShapedType.
|
||||||
operator ShapedType() const { return cast<ShapedType>(); }
|
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -169,7 +169,7 @@ public:
|
|||||||
unsigned getMemorySpaceAsInt() const;
|
unsigned getMemorySpaceAsInt() const;
|
||||||
|
|
||||||
/// Allow implicit conversion to ShapedType.
|
/// Allow implicit conversion to ShapedType.
|
||||||
operator ShapedType() const { return cast<ShapedType>(); }
|
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -217,13 +217,15 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool isEmptyKey(mlir::TypeRange range) {
|
static bool isEmptyKey(mlir::TypeRange range) {
|
||||||
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
|
if (const auto *type =
|
||||||
|
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
|
||||||
return type == getEmptyKeyPointer();
|
return type == getEmptyKeyPointer();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isTombstoneKey(mlir::TypeRange range) {
|
static bool isTombstoneKey(mlir::TypeRange range) {
|
||||||
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
|
if (const auto *type =
|
||||||
|
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
|
||||||
return type == getTombstoneKeyPointer();
|
return type == getTombstoneKeyPointer();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -163,12 +163,12 @@ public:
|
|||||||
|
|
||||||
/// Return the value the effect is applied on, or nullptr if there isn't a
|
/// Return the value the effect is applied on, or nullptr if there isn't a
|
||||||
/// known value being affected.
|
/// known value being affected.
|
||||||
Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
|
Value getValue() const { return value ? llvm::dyn_cast_if_present<Value>(value) : Value(); }
|
||||||
|
|
||||||
/// Return the symbol reference the effect is applied on, or nullptr if there
|
/// Return the symbol reference the effect is applied on, or nullptr if there
|
||||||
/// isn't a known smbol being affected.
|
/// isn't a known smbol being affected.
|
||||||
SymbolRefAttr getSymbolRef() const {
|
SymbolRefAttr getSymbolRef() const {
|
||||||
return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
|
return value ? llvm::dyn_cast_if_present<SymbolRefAttr>(value) : SymbolRefAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the resource that the effect applies to.
|
/// Return the resource that the effect applies to.
|
||||||
|
@ -254,7 +254,7 @@ struct NestedAnalysisMap {
|
|||||||
/// Returns the parent analysis map for this analysis map, or null if this is
|
/// Returns the parent analysis map for this analysis map, or null if this is
|
||||||
/// the top-level map.
|
/// the top-level map.
|
||||||
const NestedAnalysisMap *getParent() const {
|
const NestedAnalysisMap *getParent() const {
|
||||||
return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
|
return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a pass instrumentation object for the current operation. This
|
/// Returns a pass instrumentation object for the current operation. This
|
||||||
|
@ -89,7 +89,7 @@ void SparseConstantPropagation::visitOperation(
|
|||||||
|
|
||||||
// Merge in the result of the fold, either a constant or a value.
|
// Merge in the result of the fold, either a constant or a value.
|
||||||
OpFoldResult foldResult = std::get<1>(it);
|
OpFoldResult foldResult = std::get<1>(it);
|
||||||
if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
|
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
|
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
|
||||||
propagateIfChanged(lattice,
|
propagateIfChanged(lattice,
|
||||||
lattice->join(ConstantValue(attr, op->getDialect())));
|
lattice->join(ConstantValue(attr, op->getDialect())));
|
||||||
|
@ -31,7 +31,7 @@ void Executable::print(raw_ostream &os) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Executable::onUpdate(DataFlowSolver *solver) const {
|
void Executable::onUpdate(DataFlowSolver *solver) const {
|
||||||
if (auto *block = point.dyn_cast<Block *>()) {
|
if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
|
||||||
// Re-invoke the analyses on the block itself.
|
// Re-invoke the analyses on the block itself.
|
||||||
for (DataFlowAnalysis *analysis : subscribers)
|
for (DataFlowAnalysis *analysis : subscribers)
|
||||||
solver->enqueue({block, analysis});
|
solver->enqueue({block, analysis});
|
||||||
@ -39,7 +39,7 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
|
|||||||
for (DataFlowAnalysis *analysis : subscribers)
|
for (DataFlowAnalysis *analysis : subscribers)
|
||||||
for (Operation &op : *block)
|
for (Operation &op : *block)
|
||||||
solver->enqueue({&op, analysis});
|
solver->enqueue({&op, analysis});
|
||||||
} else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
|
} else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
|
||||||
// Re-invoke the analysis on the successor block.
|
// Re-invoke the analysis on the successor block.
|
||||||
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
|
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
|
||||||
for (DataFlowAnalysis *analysis : subscribers)
|
for (DataFlowAnalysis *analysis : subscribers)
|
||||||
@ -219,7 +219,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
|
|||||||
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
|
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
|
||||||
if (point.is<Block *>())
|
if (point.is<Block *>())
|
||||||
return success();
|
return success();
|
||||||
auto *op = point.dyn_cast<Operation *>();
|
auto *op = llvm::dyn_cast_if_present<Operation *>(point);
|
||||||
if (!op)
|
if (!op)
|
||||||
return emitError(point.getLoc(), "unknown program point kind");
|
return emitError(point.getLoc(), "unknown program point kind");
|
||||||
|
|
||||||
|
@ -33,9 +33,9 @@ LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
|
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
|
||||||
if (auto *op = point.dyn_cast<Operation *>())
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||||
processOperation(op);
|
processOperation(op);
|
||||||
else if (auto *block = point.dyn_cast<Block *>())
|
else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
|
||||||
visitBlock(block);
|
visitBlock(block);
|
||||||
else
|
else
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -181,7 +181,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
|||||||
if (auto bound =
|
if (auto bound =
|
||||||
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
|
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
|
||||||
return bound.getValue();
|
return bound.getValue();
|
||||||
} else if (auto value = loopBound->dyn_cast<Value>()) {
|
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
|
||||||
const IntegerValueRangeLattice *lattice =
|
const IntegerValueRangeLattice *lattice =
|
||||||
getLatticeElementFor(op, value);
|
getLatticeElementFor(op, value);
|
||||||
if (lattice != nullptr)
|
if (lattice != nullptr)
|
||||||
|
@ -66,9 +66,9 @@ AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
|
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
|
||||||
if (Operation *op = point.dyn_cast<Operation *>())
|
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||||
visitOperation(op);
|
visitOperation(op);
|
||||||
else if (Block *block = point.dyn_cast<Block *>())
|
else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
|
||||||
visitBlock(block);
|
visitBlock(block);
|
||||||
else
|
else
|
||||||
return failure();
|
return failure();
|
||||||
@ -238,7 +238,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
|
|||||||
|
|
||||||
unsigned firstIndex = 0;
|
unsigned firstIndex = 0;
|
||||||
if (inputs.size() != lattices.size()) {
|
if (inputs.size() != lattices.size()) {
|
||||||
if (point.dyn_cast<Operation *>()) {
|
if (llvm::dyn_cast_if_present<Operation *>(point)) {
|
||||||
if (!inputs.empty())
|
if (!inputs.empty())
|
||||||
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
|
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
|
||||||
visitNonControlFlowArgumentsImpl(
|
visitNonControlFlowArgumentsImpl(
|
||||||
@ -316,9 +316,9 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
|
|||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
|
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
|
||||||
if (Operation *op = point.dyn_cast<Operation *>())
|
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||||
visitOperation(op);
|
visitOperation(op);
|
||||||
else if (point.dyn_cast<Block *>())
|
else if (llvm::dyn_cast_if_present<Block *>(point))
|
||||||
// For backward dataflow, we don't have to do any work for the blocks
|
// For backward dataflow, we don't have to do any work for the blocks
|
||||||
// themselves. CFG edges between blocks are processed by the BranchOp
|
// themselves. CFG edges between blocks are processed by the BranchOp
|
||||||
// logic in `visitOperation`, and entry blocks for functions are tied
|
// logic in `visitOperation`, and entry blocks for functions are tied
|
||||||
|
@ -39,21 +39,21 @@ void ProgramPoint::print(raw_ostream &os) const {
|
|||||||
os << "<NULL POINT>";
|
os << "<NULL POINT>";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
|
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
|
||||||
return programPoint->print(os);
|
return programPoint->print(os);
|
||||||
if (auto *op = dyn_cast<Operation *>())
|
if (auto *op = llvm::dyn_cast<Operation *>(*this))
|
||||||
return op->print(os);
|
return op->print(os);
|
||||||
if (auto value = dyn_cast<Value>())
|
if (auto value = llvm::dyn_cast<Value>(*this))
|
||||||
return value.print(os);
|
return value.print(os);
|
||||||
return get<Block *>()->print(os);
|
return get<Block *>()->print(os);
|
||||||
}
|
}
|
||||||
|
|
||||||
Location ProgramPoint::getLoc() const {
|
Location ProgramPoint::getLoc() const {
|
||||||
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
|
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
|
||||||
return programPoint->getLoc();
|
return programPoint->getLoc();
|
||||||
if (auto *op = dyn_cast<Operation *>())
|
if (auto *op = llvm::dyn_cast<Operation *>(*this))
|
||||||
return op->getLoc();
|
return op->getLoc();
|
||||||
if (auto value = dyn_cast<Value>())
|
if (auto value = llvm::dyn_cast<Value>(*this))
|
||||||
return value.getLoc();
|
return value.getLoc();
|
||||||
return get<Block *>()->getParent()->getLoc();
|
return get<Block *>()->getParent()->getLoc();
|
||||||
}
|
}
|
||||||
|
@ -2060,7 +2060,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
|
|||||||
if (parseToken(Token::r_paren, "expected ')' in location"))
|
if (parseToken(Token::r_paren, "expected ')' in location"))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (auto *op = opOrArgument.dyn_cast<Operation *>())
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
|
||||||
op->setLoc(directLoc);
|
op->setLoc(directLoc);
|
||||||
else
|
else
|
||||||
opOrArgument.get<BlockArgument>().setLoc(directLoc);
|
opOrArgument.get<BlockArgument>().setLoc(directLoc);
|
||||||
|
@ -47,7 +47,7 @@ SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
|
|||||||
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
|
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
|
||||||
DictionaryAttr attributeDict;
|
DictionaryAttr attributeDict;
|
||||||
if (!mlirAttributeIsNull(attributes))
|
if (!mlirAttributeIsNull(attributes))
|
||||||
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
|
attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
|
||||||
return attributeDict;
|
return attributeDict;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1190,9 +1190,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||||||
// TODO: safer and more flexible to store data type in actual op instead?
|
// TODO: safer and more flexible to store data type in actual op instead?
|
||||||
static Type getSpMatElemType(Value spMat) {
|
static Type getSpMatElemType(Value spMat) {
|
||||||
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
|
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
|
||||||
return op.getValues().getType().cast<MemRefType>().getElementType();
|
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||||
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
|
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
|
||||||
return op.getValues().getType().cast<MemRefType>().getElementType();
|
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||||
llvm_unreachable("cannot find spmat def");
|
llvm_unreachable("cannot find spmat def");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1235,7 +1235,7 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||||||
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
||||||
if (!getTypeConverter()->useOpaquePointers())
|
if (!getTypeConverter()->useOpaquePointers())
|
||||||
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
|
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
|
||||||
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
|
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
|
||||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||||
dType.getIntOrFloatBitWidth());
|
dType.getIntOrFloatBitWidth());
|
||||||
auto handle =
|
auto handle =
|
||||||
@ -1271,7 +1271,7 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||||||
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
||||||
if (!getTypeConverter()->useOpaquePointers())
|
if (!getTypeConverter()->useOpaquePointers())
|
||||||
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
|
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
|
||||||
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
|
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
|
||||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||||
dType.getIntOrFloatBitWidth());
|
dType.getIntOrFloatBitWidth());
|
||||||
auto handle =
|
auto handle =
|
||||||
@ -1315,8 +1315,8 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||||||
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
||||||
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
||||||
}
|
}
|
||||||
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
|
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
|
||||||
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
|
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||||
auto iw = rewriter.create<LLVM::ConstantOp>(
|
auto iw = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
|
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
|
||||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||||
@ -1350,9 +1350,9 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||||||
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
||||||
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
||||||
}
|
}
|
||||||
Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
|
Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
|
||||||
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
|
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
|
||||||
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
|
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||||
auto pw = rewriter.create<LLVM::ConstantOp>(
|
auto pw = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
|
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
|
||||||
auto iw = rewriter.create<LLVM::ConstantOp>(
|
auto iw = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
@ -405,7 +405,7 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
|
|||||||
return failure();
|
return failure();
|
||||||
if (!(*converted)) // Conversion to default is 0.
|
if (!(*converted)) // Conversion to default is 0.
|
||||||
return 0;
|
return 0;
|
||||||
if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
|
if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
|
||||||
return explicitSpace.getInt();
|
return explicitSpace.getInt();
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -671,7 +671,7 @@ struct GlobalMemrefOpLowering
|
|||||||
|
|
||||||
Attribute initialValue = nullptr;
|
Attribute initialValue = nullptr;
|
||||||
if (!global.isExternal() && !global.isUninitialized()) {
|
if (!global.isExternal() && !global.isUninitialized()) {
|
||||||
auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
|
auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
|
||||||
initialValue = elementsAttr;
|
initialValue = elementsAttr;
|
||||||
|
|
||||||
// For scalar memrefs, the global variable created is of the element type,
|
// For scalar memrefs, the global variable created is of the element type,
|
||||||
|
@ -412,10 +412,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||||||
auto *ans = cast<TypeAnswer>(answer);
|
auto *ans = cast<TypeAnswer>(answer);
|
||||||
if (isa<pdl::RangeType>(val.getType()))
|
if (isa<pdl::RangeType>(val.getType()))
|
||||||
builder.create<pdl_interp::CheckTypesOp>(
|
builder.create<pdl_interp::CheckTypesOp>(
|
||||||
loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
|
loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
|
||||||
else
|
else
|
||||||
builder.create<pdl_interp::CheckTypeOp>(
|
builder.create<pdl_interp::CheckTypeOp>(
|
||||||
loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
|
loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Predicates::AttributeQuestion: {
|
case Predicates::AttributeQuestion: {
|
||||||
|
@ -300,7 +300,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
|
|||||||
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
|
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
|
||||||
|
|
||||||
// tosa::ErfOp
|
// tosa::ErfOp
|
||||||
if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
|
if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
|
||||||
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
|
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
|
||||||
|
|
||||||
// tosa::GreaterOp
|
// tosa::GreaterOp
|
||||||
@ -1885,7 +1885,7 @@ public:
|
|||||||
|
|
||||||
auto addDynamicDimension = [&](Value source, int64_t dim) {
|
auto addDynamicDimension = [&](Value source, int64_t dim) {
|
||||||
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
|
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
|
||||||
if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
|
if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
|
||||||
results.push_back(dimValue);
|
results.push_back(dimValue);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -121,11 +121,11 @@ void mlirDebuggerCursorSelectParentIRUnit() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
IRUnit *unit = &state.cursor;
|
IRUnit *unit = &state.cursor;
|
||||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||||
state.cursor = op->getBlock();
|
state.cursor = op->getBlock();
|
||||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||||
state.cursor = region->getParentOp();
|
state.cursor = region->getParentOp();
|
||||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||||
state.cursor = block->getParent();
|
state.cursor = block->getParent();
|
||||||
} else {
|
} else {
|
||||||
llvm::outs() << "Current cursor is not a valid IRUnit";
|
llvm::outs() << "Current cursor is not a valid IRUnit";
|
||||||
@ -142,14 +142,14 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
IRUnit *unit = &state.cursor;
|
IRUnit *unit = &state.cursor;
|
||||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||||
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
|
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
|
||||||
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
|
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
|
||||||
<< " but got " << index << "\n";
|
<< " but got " << index << "\n";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
state.cursor = &op->getRegion(index);
|
state.cursor = &op->getRegion(index);
|
||||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||||
auto block = region->begin();
|
auto block = region->begin();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (block != region->end() && count != index) {
|
while (block != region->end() && count != index) {
|
||||||
@ -163,7 +163,7 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
state.cursor = &*block;
|
state.cursor = &*block;
|
||||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||||
auto op = block->begin();
|
auto op = block->begin();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (op != block->end() && count != index) {
|
while (op != block->end() && count != index) {
|
||||||
@ -192,14 +192,14 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
IRUnit *unit = &state.cursor;
|
IRUnit *unit = &state.cursor;
|
||||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||||
Operation *previous = op->getPrevNode();
|
Operation *previous = op->getPrevNode();
|
||||||
if (!previous) {
|
if (!previous) {
|
||||||
llvm::outs() << "No previous operation in the current block\n";
|
llvm::outs() << "No previous operation in the current block\n";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
state.cursor = previous;
|
state.cursor = previous;
|
||||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||||
llvm::outs() << "Has region\n";
|
llvm::outs() << "Has region\n";
|
||||||
Operation *parent = region->getParentOp();
|
Operation *parent = region->getParentOp();
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
@ -212,7 +212,7 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
|
|||||||
}
|
}
|
||||||
state.cursor =
|
state.cursor =
|
||||||
®ion->getParentOp()->getRegion(region->getRegionNumber() - 1);
|
®ion->getParentOp()->getRegion(region->getRegionNumber() - 1);
|
||||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||||
Block *previous = block->getPrevNode();
|
Block *previous = block->getPrevNode();
|
||||||
if (!previous) {
|
if (!previous) {
|
||||||
llvm::outs() << "No previous block in the current region\n";
|
llvm::outs() << "No previous block in the current region\n";
|
||||||
@ -234,14 +234,14 @@ void mlirDebuggerCursorSelectNextIRUnit() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
IRUnit *unit = &state.cursor;
|
IRUnit *unit = &state.cursor;
|
||||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||||
Operation *next = op->getNextNode();
|
Operation *next = op->getNextNode();
|
||||||
if (!next) {
|
if (!next) {
|
||||||
llvm::outs() << "No next operation in the current block\n";
|
llvm::outs() << "No next operation in the current block\n";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
state.cursor = next;
|
state.cursor = next;
|
||||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||||
Operation *parent = region->getParentOp();
|
Operation *parent = region->getParentOp();
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
llvm::outs() << "No parent operation for the current region\n";
|
llvm::outs() << "No parent operation for the current region\n";
|
||||||
@ -253,7 +253,7 @@ void mlirDebuggerCursorSelectNextIRUnit() {
|
|||||||
}
|
}
|
||||||
state.cursor =
|
state.cursor =
|
||||||
®ion->getParentOp()->getRegion(region->getRegionNumber() + 1);
|
®ion->getParentOp()->getRegion(region->getRegionNumber() + 1);
|
||||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||||
Block *next = block->getNextNode();
|
Block *next = block->getNextNode();
|
||||||
if (!next) {
|
if (!next) {
|
||||||
llvm::outs() << "No next block in the current region\n";
|
llvm::outs() << "No next block in the current region\n";
|
||||||
|
@ -1212,7 +1212,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
|
|||||||
actualValues.reserve(values.size());
|
actualValues.reserve(values.size());
|
||||||
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
|
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
|
||||||
for (OpFoldResult ofr : values) {
|
for (OpFoldResult ofr : values) {
|
||||||
if (auto value = ofr.dyn_cast<Value>()) {
|
if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||||
actualValues.push_back(value);
|
actualValues.push_back(value);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -4599,7 +4599,7 @@ void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
|
|||||||
if (staticDim.has_value())
|
if (staticDim.has_value())
|
||||||
return builder.create<arith::ConstantIndexOp>(result.location,
|
return builder.create<arith::ConstantIndexOp>(result.location,
|
||||||
*staticDim);
|
*staticDim);
|
||||||
return ofr.dyn_cast<Value>();
|
return llvm::dyn_cast_if_present<Value>(ofr);
|
||||||
});
|
});
|
||||||
result.addOperands(basisValues);
|
result.addOperands(basisValues);
|
||||||
}
|
}
|
||||||
|
@ -808,7 +808,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (matchPattern(getRhs(), m_Zero()))
|
if (matchPattern(getRhs(), m_Zero()))
|
||||||
return getLhs();
|
return getLhs();
|
||||||
/// or(x, <all ones>) -> <all ones>
|
/// or(x, <all ones>) -> <all ones>
|
||||||
if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
|
if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
|
||||||
if (rhsAttr.getValue().isAllOnes())
|
if (rhsAttr.getValue().isAllOnes())
|
||||||
return rhsAttr;
|
return rhsAttr;
|
||||||
|
|
||||||
@ -1249,7 +1249,7 @@ LogicalResult arith::ExtSIOp::verify() {
|
|||||||
|
|
||||||
/// Always fold extension of FP constants.
|
/// Always fold extension of FP constants.
|
||||||
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
|
||||||
auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
|
auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
|
||||||
if (!constOperand)
|
if (!constOperand)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -1702,7 +1702,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
|
|||||||
|
|
||||||
// We are moving constants to the right side; So if lhs is constant rhs is
|
// We are moving constants to the right side; So if lhs is constant rhs is
|
||||||
// guaranteed to be a constant.
|
// guaranteed to be a constant.
|
||||||
if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
|
if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
|
||||||
return constFoldBinaryOp<IntegerAttr>(
|
return constFoldBinaryOp<IntegerAttr>(
|
||||||
adaptor.getOperands(), getI1SameShape(lhs.getType()),
|
adaptor.getOperands(), getI1SameShape(lhs.getType()),
|
||||||
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
|
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
|
||||||
@ -1772,8 +1772,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
|
||||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
|
auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
|
||||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
|
auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
|
||||||
|
|
||||||
// If one operand is NaN, making them both NaN does not change the result.
|
// If one operand is NaN, making them both NaN does not change the result.
|
||||||
if (lhs && lhs.getValue().isNaN())
|
if (lhs && lhs.getValue().isNaN())
|
||||||
@ -2193,11 +2193,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
|
|||||||
// Constant-fold constant operands over non-splat constant condition.
|
// Constant-fold constant operands over non-splat constant condition.
|
||||||
// select %cst_vec, %cst0, %cst1 => %cst2
|
// select %cst_vec, %cst0, %cst1 => %cst2
|
||||||
if (auto cond =
|
if (auto cond =
|
||||||
adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
|
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
|
||||||
if (auto lhs =
|
if (auto lhs =
|
||||||
adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
|
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
|
||||||
if (auto rhs =
|
if (auto rhs =
|
||||||
adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
|
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
|
||||||
SmallVector<Attribute> results;
|
SmallVector<Attribute> results;
|
||||||
results.reserve(static_cast<size_t>(cond.getNumElements()));
|
results.reserve(static_cast<size_t>(cond.getNumElements()));
|
||||||
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
|
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
|
||||||
|
@ -184,7 +184,7 @@ struct SelectOpInterface
|
|||||||
|
|
||||||
// If the buffers have different types, they differ only in their layout
|
// If the buffers have different types, they differ only in their layout
|
||||||
// map.
|
// map.
|
||||||
auto memrefType = trueType->cast<MemRefType>();
|
auto memrefType = llvm::cast<MemRefType>(*trueType);
|
||||||
return getMemRefTypeWithFullyDynamicLayout(
|
return getMemRefTypeWithFullyDynamicLayout(
|
||||||
RankedTensorType::get(memrefType.getShape(),
|
RankedTensorType::get(memrefType.getShape(),
|
||||||
memrefType.getElementType()),
|
memrefType.getElementType()),
|
||||||
|
@ -33,8 +33,8 @@ LogicalResult mlir::foldDynamicIndexList(Builder &b,
|
|||||||
if (ofr.is<Attribute>())
|
if (ofr.is<Attribute>())
|
||||||
continue;
|
continue;
|
||||||
// Newly static, move from Value to constant.
|
// Newly static, move from Value to constant.
|
||||||
if (auto cstOp =
|
if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
|
||||||
ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
|
.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||||
ofr = b.getIndexAttr(cstOp.value());
|
ofr = b.getIndexAttr(cstOp.value());
|
||||||
valuesChanged = true;
|
valuesChanged = true;
|
||||||
}
|
}
|
||||||
@ -56,9 +56,9 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
|
|||||||
|
|
||||||
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
|
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
|
||||||
OpFoldResult ofr) {
|
OpFoldResult ofr) {
|
||||||
if (auto value = ofr.dyn_cast<Value>())
|
if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
|
||||||
return value;
|
return value;
|
||||||
auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
|
auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
|
||||||
assert(attr && "expect the op fold result casts to an integer attribute");
|
assert(attr && "expect the op fold result casts to an integer attribute");
|
||||||
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
|
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
|
|||||||
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
|
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
|
||||||
}
|
}
|
||||||
FailureOr<Value> alloc = options.createAlloc(
|
FailureOr<Value> alloc = options.createAlloc(
|
||||||
rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
|
rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
|
||||||
if (failed(alloc))
|
if (failed(alloc))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -59,7 +59,8 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
/// Return the func::FuncOp called by `callOp`.
|
/// Return the func::FuncOp called by `callOp`.
|
||||||
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
SymbolRefAttr sym =
|
||||||
|
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||||
if (!sym)
|
if (!sym)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return dyn_cast_or_null<func::FuncOp>(
|
return dyn_cast_or_null<func::FuncOp>(
|
||||||
|
@ -80,7 +80,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
|
|||||||
|
|
||||||
/// Return the FuncOp called by `callOp`.
|
/// Return the FuncOp called by `callOp`.
|
||||||
static FuncOp getCalledFunction(CallOpInterface callOp) {
|
static FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||||
if (!sym)
|
if (!sym)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return dyn_cast_or_null<FuncOp>(
|
return dyn_cast_or_null<FuncOp>(
|
||||||
|
@ -995,7 +995,7 @@ static void annotateOpsWithAliasSets(Operation *op,
|
|||||||
op->walk([&](Operation *op) {
|
op->walk([&](Operation *op) {
|
||||||
SmallVector<Attribute> aliasSets;
|
SmallVector<Attribute> aliasSets;
|
||||||
for (OpResult opResult : op->getOpResults()) {
|
for (OpResult opResult : op->getOpResults()) {
|
||||||
if (opResult.getType().isa<TensorType>()) {
|
if (llvm::isa<TensorType>(opResult.getType())) {
|
||||||
SmallVector<Attribute> aliases;
|
SmallVector<Attribute> aliases;
|
||||||
state.applyOnAliases(opResult, [&](Value alias) {
|
state.applyOnAliases(opResult, [&](Value alias) {
|
||||||
std::string buffer;
|
std::string buffer;
|
||||||
|
@ -238,7 +238,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
|
|||||||
|
|
||||||
/// Return the func::FuncOp called by `callOp`.
|
/// Return the func::FuncOp called by `callOp`.
|
||||||
static func::FuncOp getCalledFunction(func::CallOp callOp) {
|
static func::FuncOp getCalledFunction(func::CallOp callOp) {
|
||||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||||
if (!sym)
|
if (!sym)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return dyn_cast_or_null<func::FuncOp>(
|
return dyn_cast_or_null<func::FuncOp>(
|
||||||
|
@ -90,7 +90,8 @@ OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
||||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
ArrayAttr arrayAttr =
|
||||||
|
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
|
||||||
if (arrayAttr && arrayAttr.size() == 2)
|
if (arrayAttr && arrayAttr.size() == 2)
|
||||||
return arrayAttr[1];
|
return arrayAttr[1];
|
||||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||||
@ -103,7 +104,8 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
|
||||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
ArrayAttr arrayAttr =
|
||||||
|
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
|
||||||
if (arrayAttr && arrayAttr.size() == 2)
|
if (arrayAttr && arrayAttr.size() == 2)
|
||||||
return arrayAttr[0];
|
return arrayAttr[0];
|
||||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||||
|
@ -94,7 +94,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
|
|||||||
|
|
||||||
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
|
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
|
||||||
os << DataLayoutEntryAttr::kAttrKeyword << "<";
|
os << DataLayoutEntryAttr::kAttrKeyword << "<";
|
||||||
if (auto type = getKey().dyn_cast<Type>())
|
if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
|
||||||
os << type;
|
os << type;
|
||||||
else
|
else
|
||||||
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
|
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
|
||||||
@ -151,7 +151,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||||||
DenseSet<Type> types;
|
DenseSet<Type> types;
|
||||||
DenseSet<StringAttr> ids;
|
DenseSet<StringAttr> ids;
|
||||||
for (DataLayoutEntryInterface entry : entries) {
|
for (DataLayoutEntryInterface entry : entries) {
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||||
if (!types.insert(type).second)
|
if (!types.insert(type).second)
|
||||||
return emitError() << "repeated layout entry key: " << type;
|
return emitError() << "repeated layout entry key: " << type;
|
||||||
} else {
|
} else {
|
||||||
|
@ -493,7 +493,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
|
|||||||
// error. All other canonicalization is done in the fold method.
|
// error. All other canonicalization is done in the fold method.
|
||||||
bool requiresConst = !rawConstantIndices.empty() &&
|
bool requiresConst = !rawConstantIndices.empty() &&
|
||||||
currType.isa_and_nonnull<LLVMStructType>();
|
currType.isa_and_nonnull<LLVMStructType>();
|
||||||
if (Value val = iter.dyn_cast<Value>()) {
|
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
|
||||||
APInt intC;
|
APInt intC;
|
||||||
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
|
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
|
||||||
intC.isSignedIntN(kGEPConstantBitWidth)) {
|
intC.isSignedIntN(kGEPConstantBitWidth)) {
|
||||||
@ -598,7 +598,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
|
|||||||
llvm::interleaveComma(
|
llvm::interleaveComma(
|
||||||
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
|
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
|
||||||
[&](PointerUnion<IntegerAttr, Value> cst) {
|
[&](PointerUnion<IntegerAttr, Value> cst) {
|
||||||
if (Value val = cst.dyn_cast<Value>())
|
if (Value val = llvm::dyn_cast_if_present<Value>(cst))
|
||||||
printer.printOperand(val);
|
printer.printOperand(val);
|
||||||
else
|
else
|
||||||
printer << cst.get<IntegerAttr>().getInt();
|
printer << cst.get<IntegerAttr>().getInt();
|
||||||
@ -2495,7 +2495,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
|
|||||||
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
|
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
|
||||||
|
|
||||||
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
|
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
|
||||||
if (Value val = existing.dyn_cast<Value>())
|
if (Value val = llvm::dyn_cast_if_present<Value>(existing))
|
||||||
gepArgs.emplace_back(val);
|
gepArgs.emplace_back(val);
|
||||||
else
|
else
|
||||||
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
|
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
|
||||||
|
@ -261,7 +261,7 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
|
|||||||
|
|
||||||
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
|
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
|
||||||
return llvm::all_of(gepOp.getIndices(), [](auto index) {
|
return llvm::all_of(gepOp.getIndices(), [](auto index) {
|
||||||
auto indexAttr = index.template dyn_cast<IntegerAttr>();
|
auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
|
||||||
return indexAttr && indexAttr.getValue() == 0;
|
return indexAttr && indexAttr.getValue() == 0;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -289,7 +289,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
|
|||||||
// Ensures all indices are static and fetches them.
|
// Ensures all indices are static and fetches them.
|
||||||
SmallVector<IntegerAttr> indices;
|
SmallVector<IntegerAttr> indices;
|
||||||
for (auto index : gep.getIndices()) {
|
for (auto index : gep.getIndices()) {
|
||||||
IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
|
IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
|
||||||
if (!indexInt)
|
if (!indexInt)
|
||||||
return {};
|
return {};
|
||||||
indices.push_back(indexInt);
|
indices.push_back(indexInt);
|
||||||
@ -310,7 +310,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
|
|||||||
for (IntegerAttr index : llvm::drop_begin(indices)) {
|
for (IntegerAttr index : llvm::drop_begin(indices)) {
|
||||||
// Ensure the structure of the type being indexed can be reasoned about.
|
// Ensure the structure of the type being indexed can be reasoned about.
|
||||||
// This includes rejecting any potential typed pointer.
|
// This includes rejecting any potential typed pointer.
|
||||||
auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
|
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
|
||||||
if (!destructurable)
|
if (!destructurable)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -343,7 +343,7 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
|
|||||||
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
||||||
SmallPtrSetImpl<Attribute> &usedIndices,
|
SmallPtrSetImpl<Attribute> &usedIndices,
|
||||||
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
|
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
|
||||||
auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
|
auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
|
||||||
if (!basePtrType)
|
if (!basePtrType)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@ -359,7 +359,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
|||||||
return false;
|
return false;
|
||||||
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
|
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
|
||||||
assert(slot.elementPtrs.contains(firstLevelIndex));
|
assert(slot.elementPtrs.contains(firstLevelIndex));
|
||||||
if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
|
if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
|
||||||
return false;
|
return false;
|
||||||
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
|
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
|
||||||
usedIndices.insert(firstLevelIndex);
|
usedIndices.insert(firstLevelIndex);
|
||||||
@ -369,7 +369,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
|||||||
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
|
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
|
||||||
DenseMap<Attribute, MemorySlot> &subslots,
|
DenseMap<Attribute, MemorySlot> &subslots,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
|
IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
|
||||||
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
|
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
|
||||||
|
|
||||||
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
|
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
|
||||||
@ -414,7 +414,7 @@ LLVM::LLVMStructType::getSubelementIndexMap() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
|
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
|
||||||
auto indexAttr = index.dyn_cast<IntegerAttr>();
|
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
|
||||||
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
||||||
return {};
|
return {};
|
||||||
int32_t indexInt = indexAttr.getInt();
|
int32_t indexInt = indexAttr.getInt();
|
||||||
@ -439,7 +439,7 @@ LLVM::LLVMArrayType::getSubelementIndexMap() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
|
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
|
||||||
auto indexAttr = index.dyn_cast<IntegerAttr>();
|
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
|
||||||
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
||||||
return {};
|
return {};
|
||||||
int32_t indexInt = indexAttr.getInt();
|
int32_t indexInt = indexAttr.getInt();
|
||||||
|
@ -354,7 +354,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
|
|||||||
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
|
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
|
||||||
const auto *it =
|
const auto *it =
|
||||||
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||||
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
|
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
|
||||||
newType.getAddressSpace();
|
newType.getAddressSpace();
|
||||||
}
|
}
|
||||||
@ -362,7 +362,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
|
|||||||
});
|
});
|
||||||
if (it == oldLayout.end()) {
|
if (it == oldLayout.end()) {
|
||||||
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||||
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
|
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -2368,7 +2368,7 @@ transform::TileOp::apply(TransformResults &transformResults,
|
|||||||
sizes.reserve(tileSizes.size());
|
sizes.reserve(tileSizes.size());
|
||||||
unsigned dynamicIdx = 0;
|
unsigned dynamicIdx = 0;
|
||||||
for (OpFoldResult ofr : getMixedSizes()) {
|
for (OpFoldResult ofr : getMixedSizes()) {
|
||||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||||
continue;
|
continue;
|
||||||
@ -2794,7 +2794,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
|
|||||||
sizes.reserve(tileSizes.size());
|
sizes.reserve(tileSizes.size());
|
||||||
unsigned dynamicIdx = 0;
|
unsigned dynamicIdx = 0;
|
||||||
for (OpFoldResult ofr : getMixedSizes()) {
|
for (OpFoldResult ofr : getMixedSizes()) {
|
||||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||||
} else {
|
} else {
|
||||||
|
@ -1447,7 +1447,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
|
|||||||
cast<LinalgOp>(genericOp.getOperation())
|
cast<LinalgOp>(genericOp.getOperation())
|
||||||
.createLoopRanges(rewriter, genericOp.getLoc());
|
.createLoopRanges(rewriter, genericOp.getLoc());
|
||||||
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
|
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
|
||||||
if (auto attr = ofr.dyn_cast<Attribute>())
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
|
||||||
return cast<IntegerAttr>(attr).getInt() == value;
|
return cast<IntegerAttr>(attr).getInt() == value;
|
||||||
llvm::APInt actual;
|
llvm::APInt actual;
|
||||||
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
|
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
|
||||||
|
@ -229,7 +229,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
|
|||||||
// to look for the bound.
|
// to look for the bound.
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
|
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
|
||||||
Value size;
|
Value size;
|
||||||
if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
|
||||||
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
|
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
|
||||||
} else {
|
} else {
|
||||||
Value materializedSize =
|
Value materializedSize =
|
||||||
|
@ -92,7 +92,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
|
|||||||
rewriter, op.getLoc(), d0 + d1 - d2,
|
rewriter, op.getLoc(), d0 + d1 - d2,
|
||||||
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
|
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
|
||||||
minSplitPoint});
|
minSplitPoint});
|
||||||
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
|
||||||
if (cast<IntegerAttr>(attr).getValue().isZero())
|
if (cast<IntegerAttr>(attr).getValue().isZero())
|
||||||
return {op, TilingInterface()};
|
return {op, TilingInterface()};
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ using namespace mlir::scf;
|
|||||||
static bool isZero(OpFoldResult v) {
|
static bool isZero(OpFoldResult v) {
|
||||||
if (!v)
|
if (!v)
|
||||||
return false;
|
return false;
|
||||||
if (auto attr = v.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
|
||||||
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
return intAttr && intAttr.getValue().isZero();
|
return intAttr && intAttr.getValue().isZero();
|
||||||
}
|
}
|
||||||
@ -104,7 +104,7 @@ void mlir::linalg::transformIndexOps(
|
|||||||
/// checked at runtime.
|
/// checked at runtime.
|
||||||
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
|
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
|
||||||
OpFoldResult value) {
|
OpFoldResult value) {
|
||||||
if (auto attr = value.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
|
||||||
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
|
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
|
||||||
"expected strictly positive tile size and divisor");
|
"expected strictly positive tile size and divisor");
|
||||||
return;
|
return;
|
||||||
|
@ -1135,7 +1135,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
|
|||||||
PatternRewriter &rewriter) const {
|
PatternRewriter &rewriter) const {
|
||||||
// Given an OpFoldResult, return an index-typed value.
|
// Given an OpFoldResult, return an index-typed value.
|
||||||
auto getIdxValue = [&](OpFoldResult ofr) {
|
auto getIdxValue = [&](OpFoldResult ofr) {
|
||||||
if (auto val = ofr.dyn_cast<Value>())
|
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
|
||||||
return val;
|
return val;
|
||||||
return rewriter
|
return rewriter
|
||||||
.create<arith::ConstantIndexOp>(
|
.create<arith::ConstantIndexOp>(
|
||||||
|
@ -1646,7 +1646,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
|
|||||||
ArrayRef<OpFoldResult> ofrs) {
|
ArrayRef<OpFoldResult> ofrs) {
|
||||||
SmallVector<Value> result;
|
SmallVector<Value> result;
|
||||||
for (auto o : ofrs) {
|
for (auto o : ofrs) {
|
||||||
if (auto val = o.template dyn_cast<Value>()) {
|
if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
|
||||||
result.push_back(val);
|
result.push_back(val);
|
||||||
} else {
|
} else {
|
||||||
result.push_back(rewriter.create<arith::ConstantIndexOp>(
|
result.push_back(rewriter.create<arith::ConstantIndexOp>(
|
||||||
@ -1954,8 +1954,8 @@ struct PadOpVectorizationWithTransferWritePattern
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Other cases: Take a deeper look at defining ops of values.
|
// Other cases: Take a deeper look at defining ops of values.
|
||||||
auto v1 = size1.dyn_cast<Value>();
|
auto v1 = llvm::dyn_cast_if_present<Value>(size1);
|
||||||
auto v2 = size2.dyn_cast<Value>();
|
auto v2 = llvm::dyn_cast_if_present<Value>(size2);
|
||||||
if (!v1 || !v2)
|
if (!v1 || !v2)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@ -970,7 +970,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
|
|||||||
auto dim = it.index();
|
auto dim = it.index();
|
||||||
auto size = it.value();
|
auto size = it.value();
|
||||||
curr.push_back(dim);
|
curr.push_back(dim);
|
||||||
auto attr = size.dyn_cast<Attribute>();
|
auto attr = llvm::dyn_cast_if_present<Attribute>(size);
|
||||||
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
|
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
|
||||||
continue;
|
continue;
|
||||||
reassociation.emplace_back(ReassociationIndices{});
|
reassociation.emplace_back(ReassociationIndices{});
|
||||||
|
@ -64,7 +64,7 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static bool isSupportedElementType(Type type) {
|
static bool isSupportedElementType(Type type) {
|
||||||
return type.isa<MemRefType>() ||
|
return llvm::isa<MemRefType>(type) ||
|
||||||
OpBuilder(type.getContext()).getZeroAttr(type);
|
OpBuilder(type.getContext()).getZeroAttr(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
|
|||||||
SmallVector<DestructurableMemorySlot>
|
SmallVector<DestructurableMemorySlot>
|
||||||
memref::AllocaOp::getDestructurableSlots() {
|
memref::AllocaOp::getDestructurableSlots() {
|
||||||
MemRefType memrefType = getType();
|
MemRefType memrefType = getType();
|
||||||
auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
|
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
|
||||||
if (!destructurable)
|
if (!destructurable)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
|
|||||||
|
|
||||||
DenseMap<Attribute, MemorySlot> slotMap;
|
DenseMap<Attribute, MemorySlot> slotMap;
|
||||||
|
|
||||||
auto memrefType = getType().cast<DestructurableTypeInterface>();
|
auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
|
||||||
for (Attribute usedIndex : usedIndices) {
|
for (Attribute usedIndex : usedIndices) {
|
||||||
Type elemType = memrefType.getTypeAtIndex(usedIndex);
|
Type elemType = memrefType.getTypeAtIndex(usedIndex);
|
||||||
MemRefType elemPtr = MemRefType::get({}, elemType);
|
MemRefType elemPtr = MemRefType::get({}, elemType);
|
||||||
@ -281,7 +281,7 @@ struct MemRefDestructurableTypeExternalModel
|
|||||||
MemRefDestructurableTypeExternalModel, MemRefType> {
|
MemRefDestructurableTypeExternalModel, MemRefType> {
|
||||||
std::optional<DenseMap<Attribute, Type>>
|
std::optional<DenseMap<Attribute, Type>>
|
||||||
getSubelementIndexMap(Type type) const {
|
getSubelementIndexMap(Type type) const {
|
||||||
auto memrefType = type.cast<MemRefType>();
|
auto memrefType = llvm::cast<MemRefType>(type);
|
||||||
constexpr int64_t maxMemrefSizeForDestructuring = 16;
|
constexpr int64_t maxMemrefSizeForDestructuring = 16;
|
||||||
if (!memrefType.hasStaticShape() ||
|
if (!memrefType.hasStaticShape() ||
|
||||||
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
|
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
|
||||||
@ -298,15 +298,15 @@ struct MemRefDestructurableTypeExternalModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
Type getTypeAtIndex(Type type, Attribute index) const {
|
Type getTypeAtIndex(Type type, Attribute index) const {
|
||||||
auto memrefType = type.cast<MemRefType>();
|
auto memrefType = llvm::cast<MemRefType>(type);
|
||||||
auto coordArrAttr = index.dyn_cast<ArrayAttr>();
|
auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
|
||||||
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
|
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
Type indexType = IndexType::get(memrefType.getContext());
|
Type indexType = IndexType::get(memrefType.getContext());
|
||||||
for (const auto &[coordAttr, dimSize] :
|
for (const auto &[coordAttr, dimSize] :
|
||||||
llvm::zip(coordArrAttr, memrefType.getShape())) {
|
llvm::zip(coordArrAttr, memrefType.getShape())) {
|
||||||
auto coord = coordAttr.dyn_cast<IntegerAttr>();
|
auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
|
||||||
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
|
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
|
||||||
coord.getInt() >= dimSize)
|
coord.getInt() >= dimSize)
|
||||||
return {};
|
return {};
|
||||||
|
@ -970,7 +970,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
|
|||||||
return unusedDims;
|
return unusedDims;
|
||||||
|
|
||||||
for (const auto &dim : llvm::enumerate(sizes))
|
for (const auto &dim : llvm::enumerate(sizes))
|
||||||
if (auto attr = dim.value().dyn_cast<Attribute>())
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
|
||||||
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
|
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
|
||||||
unusedDims.set(dim.index());
|
unusedDims.set(dim.index());
|
||||||
|
|
||||||
@ -1042,7 +1042,7 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
|
|||||||
|
|
||||||
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
||||||
// All forms of folding require a known index.
|
// All forms of folding require a known index.
|
||||||
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
|
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
|
||||||
if (!index)
|
if (!index)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||||||
// Because we only support input strides of 1, the output stride is also
|
// Because we only support input strides of 1, the output stride is also
|
||||||
// always 1.
|
// always 1.
|
||||||
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
|
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
|
||||||
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
|
Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
|
||||||
return attr && cast<IntegerAttr>(attr).getInt() == 1;
|
return attr && cast<IntegerAttr>(attr).getInt() == 1;
|
||||||
})) {
|
})) {
|
||||||
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
|
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
|
||||||
@ -86,8 +86,9 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sizes.push_back(opSize);
|
sizes.push_back(opSize);
|
||||||
Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
|
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
|
||||||
sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
|
sourceOffsetAttr =
|
||||||
|
llvm::dyn_cast_if_present<Attribute>(sourceOffset);
|
||||||
|
|
||||||
if (opOffsetAttr && sourceOffsetAttr) {
|
if (opOffsetAttr && sourceOffsetAttr) {
|
||||||
// If both offsets are static we can simply calculate the combined
|
// If both offsets are static we can simply calculate the combined
|
||||||
@ -101,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||||||
AffineExpr expr = rewriter.getAffineConstantExpr(0);
|
AffineExpr expr = rewriter.getAffineConstantExpr(0);
|
||||||
SmallVector<Value> affineApplyOperands;
|
SmallVector<Value> affineApplyOperands;
|
||||||
for (auto valueOrAttr : {opOffset, sourceOffset}) {
|
for (auto valueOrAttr : {opOffset, sourceOffset}) {
|
||||||
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
|
||||||
expr = expr + cast<IntegerAttr>(attr).getInt();
|
expr = expr + cast<IntegerAttr>(attr).getInt();
|
||||||
} else {
|
} else {
|
||||||
expr =
|
expr =
|
||||||
|
@ -520,7 +520,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
|
|||||||
<< operandName << " operand appears more than once";
|
<< operandName << " operand appears more than once";
|
||||||
|
|
||||||
mlir::Type varType = operand.getType();
|
mlir::Type varType = operand.getType();
|
||||||
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
|
auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
|
||||||
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
|
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
|
||||||
if (!decl)
|
if (!decl)
|
||||||
return op->emitOpError()
|
return op->emitOpError()
|
||||||
|
@ -802,10 +802,10 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
|
|||||||
for (const auto &mapTypeOp : *map_types) {
|
for (const auto &mapTypeOp : *map_types) {
|
||||||
int64_t mapTypeBits = 0x00;
|
int64_t mapTypeBits = 0x00;
|
||||||
|
|
||||||
if (!mapTypeOp.isa<mlir::IntegerAttr>())
|
if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
|
mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
|
||||||
|
|
||||||
bool to =
|
bool to =
|
||||||
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
|
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
|
||||||
|
@ -381,7 +381,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
|
|||||||
// map.
|
// map.
|
||||||
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
|
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
auto iterRanked = initArgBufferType->cast<MemRefType>();
|
auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
|
||||||
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
|
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
|
||||||
"expected same shape");
|
"expected same shape");
|
||||||
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
|
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
|
||||||
@ -802,7 +802,7 @@ struct WhileOpInterface
|
|||||||
if (!isa<TensorType>(bbArg.getType()))
|
if (!isa<TensorType>(bbArg.getType()))
|
||||||
return bbArg.getType();
|
return bbArg.getType();
|
||||||
// TODO: error handling
|
// TODO: error handling
|
||||||
return bufferization::getBufferType(bbArg, options)->cast<Type>();
|
return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Construct a new scf.while op with memref instead of tensor values.
|
// Construct a new scf.while op with memref instead of tensor values.
|
||||||
|
@ -88,10 +88,10 @@ LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
unsigned dimIv = cstr.appendDimVar(iv);
|
unsigned dimIv = cstr.appendDimVar(iv);
|
||||||
auto lbv = lb.dyn_cast<Value>();
|
auto lbv = llvm::dyn_cast_if_present<Value>(lb);
|
||||||
unsigned symLb =
|
unsigned symLb =
|
||||||
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
|
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
|
||||||
auto ubv = ub.dyn_cast<Value>();
|
auto ubv = llvm::dyn_cast_if_present<Value>(ub);
|
||||||
unsigned symUb =
|
unsigned symUb =
|
||||||
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
|
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
|
|||||||
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
|
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
|
||||||
if (getIndices().size() == 1 &&
|
if (getIndices().size() == 1 &&
|
||||||
constructOp.getConstituents().size() == type.getNumElements()) {
|
constructOp.getConstituents().size() == type.getNumElements()) {
|
||||||
auto i = getIndices().begin()->cast<IntegerAttr>();
|
auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
|
||||||
return constructOp.getConstituents()[i.getValue().getSExtValue()];
|
return constructOp.getConstituents()[i.getValue().getSExtValue()];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult spirv::ConvertPtrToUOp::verify() {
|
LogicalResult spirv::ConvertPtrToUOp::verify() {
|
||||||
auto operandType = getPointer().getType().cast<spirv::PointerType>();
|
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||||
auto resultType = getResult().getType().cast<spirv::ScalarType>();
|
auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
|
||||||
if (!resultType || !resultType.isSignlessInteger())
|
if (!resultType || !resultType.isSignlessInteger())
|
||||||
return emitError("result must be a scalar type of unsigned integer");
|
return emitError("result must be a scalar type of unsigned integer");
|
||||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||||
@ -1583,8 +1583,8 @@ LogicalResult spirv::ConvertPtrToUOp::verify() {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult spirv::ConvertUToPtrOp::verify() {
|
LogicalResult spirv::ConvertUToPtrOp::verify() {
|
||||||
auto operandType = getOperand().getType().cast<spirv::ScalarType>();
|
auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
|
||||||
auto resultType = getResult().getType().cast<spirv::PointerType>();
|
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||||
if (!operandType || !operandType.isSignlessInteger())
|
if (!operandType || !operandType.isSignlessInteger())
|
||||||
return emitError("result must be a scalar type of unsigned integer");
|
return emitError("result must be a scalar type of unsigned integer");
|
||||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||||
|
@ -125,23 +125,23 @@ Type CompositeType::getElementType(unsigned index) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned CompositeType::getNumElements() const {
|
unsigned CompositeType::getNumElements() const {
|
||||||
if (auto arrayType = dyn_cast<ArrayType>())
|
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
|
||||||
return arrayType.getNumElements();
|
return arrayType.getNumElements();
|
||||||
if (auto matrixType = dyn_cast<MatrixType>())
|
if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
|
||||||
return matrixType.getNumColumns();
|
return matrixType.getNumColumns();
|
||||||
if (auto structType = dyn_cast<StructType>())
|
if (auto structType = llvm::dyn_cast<StructType>(*this))
|
||||||
return structType.getNumElements();
|
return structType.getNumElements();
|
||||||
if (auto vectorType = dyn_cast<VectorType>())
|
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
|
||||||
return vectorType.getNumElements();
|
return vectorType.getNumElements();
|
||||||
if (isa<CooperativeMatrixNVType>()) {
|
if (llvm::isa<CooperativeMatrixNVType>(*this)) {
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
||||||
}
|
}
|
||||||
if (isa<JointMatrixINTELType>()) {
|
if (llvm::isa<JointMatrixINTELType>(*this)) {
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"invalid to query number of elements of spirv::JointMatrix type");
|
"invalid to query number of elements of spirv::JointMatrix type");
|
||||||
}
|
}
|
||||||
if (isa<RuntimeArrayType>()) {
|
if (llvm::isa<RuntimeArrayType>(*this)) {
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||||
}
|
}
|
||||||
@ -149,8 +149,8 @@ unsigned CompositeType::getNumElements() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool CompositeType::hasCompileTimeKnownNumElements() const {
|
bool CompositeType::hasCompileTimeKnownNumElements() const {
|
||||||
return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
|
return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
|
||||||
RuntimeArrayType>();
|
RuntimeArrayType>(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CompositeType::getExtensions(
|
void CompositeType::getExtensions(
|
||||||
@ -188,11 +188,11 @@ void CompositeType::getCapabilities(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::optional<int64_t> CompositeType::getSizeInBytes() {
|
std::optional<int64_t> CompositeType::getSizeInBytes() {
|
||||||
if (auto arrayType = dyn_cast<ArrayType>())
|
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
|
||||||
return arrayType.getSizeInBytes();
|
return arrayType.getSizeInBytes();
|
||||||
if (auto structType = dyn_cast<StructType>())
|
if (auto structType = llvm::dyn_cast<StructType>(*this))
|
||||||
return structType.getSizeInBytes();
|
return structType.getSizeInBytes();
|
||||||
if (auto vectorType = dyn_cast<VectorType>()) {
|
if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
|
||||||
std::optional<int64_t> elementSize =
|
std::optional<int64_t> elementSize =
|
||||||
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
|
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
|
||||||
if (!elementSize)
|
if (!elementSize)
|
||||||
@ -680,7 +680,7 @@ void ScalarType::getCapabilities(
|
|||||||
capabilities.push_back(ref); \
|
capabilities.push_back(ref); \
|
||||||
} break
|
} break
|
||||||
|
|
||||||
if (auto intType = dyn_cast<IntegerType>()) {
|
if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
|
||||||
switch (bitwidth) {
|
switch (bitwidth) {
|
||||||
WIDTH_CASE(Int, 8);
|
WIDTH_CASE(Int, 8);
|
||||||
WIDTH_CASE(Int, 16);
|
WIDTH_CASE(Int, 16);
|
||||||
@ -692,7 +692,7 @@ void ScalarType::getCapabilities(
|
|||||||
llvm_unreachable("invalid bitwidth to getCapabilities");
|
llvm_unreachable("invalid bitwidth to getCapabilities");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
assert(isa<FloatType>());
|
assert(llvm::isa<FloatType>(*this));
|
||||||
switch (bitwidth) {
|
switch (bitwidth) {
|
||||||
WIDTH_CASE(Float, 16);
|
WIDTH_CASE(Float, 16);
|
||||||
WIDTH_CASE(Float, 64);
|
WIDTH_CASE(Float, 64);
|
||||||
@ -735,22 +735,22 @@ bool SPIRVType::classof(Type type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool SPIRVType::isScalarOrVector() {
|
bool SPIRVType::isScalarOrVector() {
|
||||||
return isIntOrFloat() || isa<VectorType>();
|
return isIntOrFloat() || llvm::isa<VectorType>(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||||
std::optional<StorageClass> storage) {
|
std::optional<StorageClass> storage) {
|
||||||
if (auto scalarType = dyn_cast<ScalarType>()) {
|
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
|
||||||
scalarType.getExtensions(extensions, storage);
|
scalarType.getExtensions(extensions, storage);
|
||||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
|
||||||
compositeType.getExtensions(extensions, storage);
|
compositeType.getExtensions(extensions, storage);
|
||||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
|
||||||
imageType.getExtensions(extensions, storage);
|
imageType.getExtensions(extensions, storage);
|
||||||
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
|
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
|
||||||
sampledImageType.getExtensions(extensions, storage);
|
sampledImageType.getExtensions(extensions, storage);
|
||||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
|
||||||
matrixType.getExtensions(extensions, storage);
|
matrixType.getExtensions(extensions, storage);
|
||||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
|
||||||
ptrType.getExtensions(extensions, storage);
|
ptrType.getExtensions(extensions, storage);
|
||||||
} else {
|
} else {
|
||||||
llvm_unreachable("invalid SPIR-V Type to getExtensions");
|
llvm_unreachable("invalid SPIR-V Type to getExtensions");
|
||||||
@ -760,17 +760,17 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
|||||||
void SPIRVType::getCapabilities(
|
void SPIRVType::getCapabilities(
|
||||||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||||
std::optional<StorageClass> storage) {
|
std::optional<StorageClass> storage) {
|
||||||
if (auto scalarType = dyn_cast<ScalarType>()) {
|
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
|
||||||
scalarType.getCapabilities(capabilities, storage);
|
scalarType.getCapabilities(capabilities, storage);
|
||||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
|
||||||
compositeType.getCapabilities(capabilities, storage);
|
compositeType.getCapabilities(capabilities, storage);
|
||||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
|
||||||
imageType.getCapabilities(capabilities, storage);
|
imageType.getCapabilities(capabilities, storage);
|
||||||
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
|
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
|
||||||
sampledImageType.getCapabilities(capabilities, storage);
|
sampledImageType.getCapabilities(capabilities, storage);
|
||||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
|
||||||
matrixType.getCapabilities(capabilities, storage);
|
matrixType.getCapabilities(capabilities, storage);
|
||||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
|
||||||
ptrType.getCapabilities(capabilities, storage);
|
ptrType.getCapabilities(capabilities, storage);
|
||||||
} else {
|
} else {
|
||||||
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
|
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
|
||||||
@ -778,9 +778,9 @@ void SPIRVType::getCapabilities(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::optional<int64_t> SPIRVType::getSizeInBytes() {
|
std::optional<int64_t> SPIRVType::getSizeInBytes() {
|
||||||
if (auto scalarType = dyn_cast<ScalarType>())
|
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
|
||||||
return scalarType.getSizeInBytes();
|
return scalarType.getSizeInBytes();
|
||||||
if (auto compositeType = dyn_cast<CompositeType>())
|
if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
|
||||||
return compositeType.getSizeInBytes();
|
return compositeType.getSizeInBytes();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
@ -856,9 +856,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (!adaptor.getLhs() || !adaptor.getRhs())
|
if (!adaptor.getLhs() || !adaptor.getRhs())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto lhsShape = llvm::to_vector<6>(
|
auto lhsShape = llvm::to_vector<6>(
|
||||||
adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
|
||||||
auto rhsShape = llvm::to_vector<6>(
|
auto rhsShape = llvm::to_vector<6>(
|
||||||
adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
|
||||||
SmallVector<int64_t, 6> resultShape;
|
SmallVector<int64_t, 6> resultShape;
|
||||||
resultShape.append(lhsShape.begin(), lhsShape.end());
|
resultShape.append(lhsShape.begin(), lhsShape.end());
|
||||||
resultShape.append(rhsShape.begin(), rhsShape.end());
|
resultShape.append(rhsShape.begin(), rhsShape.end());
|
||||||
@ -989,7 +989,7 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (!operand)
|
if (!operand)
|
||||||
return false;
|
return false;
|
||||||
extents.push_back(llvm::to_vector<6>(
|
extents.push_back(llvm::to_vector<6>(
|
||||||
operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
|
llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
|
||||||
}
|
}
|
||||||
return OpTrait::util::staticallyKnownBroadcastable(extents);
|
return OpTrait::util::staticallyKnownBroadcastable(extents);
|
||||||
}())
|
}())
|
||||||
@ -1132,10 +1132,10 @@ LogicalResult mlir::shape::DimOp::verify() {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
|
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
|
||||||
if (!lhs)
|
if (!lhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
|
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
|
||||||
if (!rhs)
|
if (!rhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
@ -1346,7 +1346,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
|
||||||
auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
|
auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
|
||||||
if (!elements)
|
if (!elements)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
std::optional<int64_t> dim = getConstantDim();
|
std::optional<int64_t> dim = getConstantDim();
|
||||||
@ -1490,7 +1490,7 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
|
||||||
auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
|
auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
|
||||||
if (!shape)
|
if (!shape)
|
||||||
return {};
|
return {};
|
||||||
int64_t rank = shape.getNumElements();
|
int64_t rank = shape.getNumElements();
|
||||||
@ -1671,10 +1671,10 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
|
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
|
||||||
if (!lhs)
|
if (!lhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
|
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
|
||||||
if (!rhs)
|
if (!rhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
APInt folded = lhs.getValue() * rhs.getValue();
|
APInt folded = lhs.getValue() * rhs.getValue();
|
||||||
@ -1864,9 +1864,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
|
|||||||
if (!adaptor.getOperand() || !adaptor.getIndex())
|
if (!adaptor.getOperand() || !adaptor.getIndex())
|
||||||
return failure();
|
return failure();
|
||||||
auto shapeVec = llvm::to_vector<6>(
|
auto shapeVec = llvm::to_vector<6>(
|
||||||
adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
|
||||||
auto shape = llvm::ArrayRef(shapeVec);
|
auto shape = llvm::ArrayRef(shapeVec);
|
||||||
auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
|
auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
|
||||||
// Verify that the split point is in the correct range.
|
// Verify that the split point is in the correct range.
|
||||||
// TODO: Constant fold to an "error".
|
// TODO: Constant fold to an "error".
|
||||||
int64_t rank = shape.size();
|
int64_t rank = shape.size();
|
||||||
@ -1889,7 +1889,7 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
|
|||||||
return OpFoldResult();
|
return OpFoldResult();
|
||||||
Builder builder(getContext());
|
Builder builder(getContext());
|
||||||
auto shape = llvm::to_vector<6>(
|
auto shape = llvm::to_vector<6>(
|
||||||
adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
|
||||||
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
|
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
|
||||||
builder.getIndexType());
|
builder.getIndexType());
|
||||||
return DenseIntElementsAttr::get(type, shape);
|
return DenseIntElementsAttr::get(type, shape);
|
||||||
|
@ -815,7 +815,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
|
|||||||
Level cooStartLvl = getCOOStart(stt.getEncoding());
|
Level cooStartLvl = getCOOStart(stt.getEncoding());
|
||||||
if (cooStartLvl < stt.getLvlRank()) {
|
if (cooStartLvl < stt.getLvlRank()) {
|
||||||
// We only supports trailing COO for now, must be the last input.
|
// We only supports trailing COO for now, must be the last input.
|
||||||
auto cooTp = lvlTps.back().cast<ShapedType>();
|
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
|
||||||
// The coordinates should be in shape of <? x rank>
|
// The coordinates should be in shape of <? x rank>
|
||||||
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
|
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
|
||||||
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
|
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
|
||||||
@ -844,7 +844,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
|
|||||||
inputTp = lvlTps[idx++];
|
inputTp = lvlTps[idx++];
|
||||||
}
|
}
|
||||||
// The input element type and expected element type should match.
|
// The input element type and expected element type should match.
|
||||||
Type inpElemTp = inputTp.cast<TensorType>().getElementType();
|
Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
|
||||||
Type expElemTp = getFieldElemType(stt, fKind);
|
Type expElemTp = getFieldElemType(stt, fKind);
|
||||||
if (inpElemTp != expElemTp) {
|
if (inpElemTp != expElemTp) {
|
||||||
misMatch = true;
|
misMatch = true;
|
||||||
|
@ -188,7 +188,7 @@ static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
|
|||||||
/// Generates a memref from tensor operation.
|
/// Generates a memref from tensor operation.
|
||||||
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
|
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
auto tensorType = tensor.getType().cast<ShapedType>();
|
auto tensorType = llvm::cast<ShapedType>(tensor.getType());
|
||||||
auto memrefType =
|
auto memrefType =
|
||||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
|
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
|
||||||
|
@ -414,7 +414,7 @@ public:
|
|||||||
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
|
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
|
||||||
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
|
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
|
||||||
OpBuilder &builder, Location loc) {
|
OpBuilder &builder, Location loc) {
|
||||||
const SparseTensorType stt(rtp.cast<RankedTensorType>());
|
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
|
||||||
const Level lvlRank = stt.getLvlRank();
|
const Level lvlRank = stt.getLvlRank();
|
||||||
// Extract fields and coordinates from args.
|
// Extract fields and coordinates from args.
|
||||||
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
|
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
|
||||||
@ -466,7 +466,7 @@ public:
|
|||||||
// The mangled name of the function has this format:
|
// The mangled name of the function has this format:
|
||||||
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
|
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
|
||||||
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
|
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
|
||||||
const SparseTensorType stt(rtp.cast<RankedTensorType>());
|
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
|
||||||
|
|
||||||
SmallString<32> nameBuffer;
|
SmallString<32> nameBuffer;
|
||||||
llvm::raw_svector_ostream nameOstream(nameBuffer);
|
llvm::raw_svector_ostream nameOstream(nameBuffer);
|
||||||
@ -541,14 +541,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
|
|||||||
|
|
||||||
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
|
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
auto tTp = tensor.getType().cast<TensorType>();
|
auto tTp = llvm::cast<TensorType>(tensor.getType());
|
||||||
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
|
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
|
||||||
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
|
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
|
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
|
||||||
auto elemTp = mem.getType().cast<MemRefType>().getElementType();
|
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
|
||||||
return builder
|
return builder
|
||||||
.create<memref::SubViewOp>(
|
.create<memref::SubViewOp>(
|
||||||
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
|
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
|
||||||
|
@ -180,7 +180,7 @@ struct ReifyPadOp
|
|||||||
AffineExpr expr = b.getAffineDimExpr(0);
|
AffineExpr expr = b.getAffineDimExpr(0);
|
||||||
unsigned numSymbols = 0;
|
unsigned numSymbols = 0;
|
||||||
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
|
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
|
||||||
if (Value v = valueOrAttr.dyn_cast<Value>()) {
|
if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
|
||||||
expr = expr + b.getAffineSymbolExpr(numSymbols++);
|
expr = expr + b.getAffineSymbolExpr(numSymbols++);
|
||||||
mapOperands.push_back(v);
|
mapOperands.push_back(v);
|
||||||
return;
|
return;
|
||||||
|
@ -501,7 +501,7 @@ Speculation::Speculatability DimOp::getSpeculatability() {
|
|||||||
|
|
||||||
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
||||||
// All forms of folding require a known index.
|
// All forms of folding require a known index.
|
||||||
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
|
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
|
||||||
if (!index)
|
if (!index)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -764,7 +764,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
|
|||||||
OpFoldResult currDim = std::get<1>(it);
|
OpFoldResult currDim = std::get<1>(it);
|
||||||
// Case 1: The empty tensor dim is static. Check that the tensor cast
|
// Case 1: The empty tensor dim is static. Check that the tensor cast
|
||||||
// result dim matches.
|
// result dim matches.
|
||||||
if (auto attr = currDim.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
|
||||||
if (ShapedType::isDynamic(newDim) ||
|
if (ShapedType::isDynamic(newDim) ||
|
||||||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
|
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
|
||||||
// Something is off, the cast result shape cannot be more dynamic
|
// Something is off, the cast result shape cannot be more dynamic
|
||||||
@ -2106,7 +2106,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
|
||||||
if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
|
if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
|
||||||
auto resultType = llvm::cast<ShapedType>(getResult().getType());
|
auto resultType = llvm::cast<ShapedType>(getResult().getType());
|
||||||
if (resultType.hasStaticShape())
|
if (resultType.hasStaticShape())
|
||||||
return splat.resizeSplat(resultType);
|
return splat.resizeSplat(resultType);
|
||||||
@ -3558,7 +3558,7 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
|
|||||||
SmallVector<int64_t> result;
|
SmallVector<int64_t> result;
|
||||||
for (auto o : ofrs) {
|
for (auto o : ofrs) {
|
||||||
// Have to do this first, as getConstantIntValue special-cases constants.
|
// Have to do this first, as getConstantIntValue special-cases constants.
|
||||||
if (o.dyn_cast<Value>())
|
if (llvm::dyn_cast_if_present<Value>(o))
|
||||||
result.push_back(ShapedType::kDynamic);
|
result.push_back(ShapedType::kDynamic);
|
||||||
else
|
else
|
||||||
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
|
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
|
||||||
|
@ -76,7 +76,7 @@ struct CastOpInterface
|
|||||||
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
|
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
|
||||||
return MemRefType::get(
|
return MemRefType::get(
|
||||||
rankedResultType.getShape(), rankedResultType.getElementType(),
|
rankedResultType.getShape(), rankedResultType.getElementType(),
|
||||||
maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
|
llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
@ -139,7 +139,7 @@ struct CollapseShapeOpInterface
|
|||||||
collapseShapeOp.getSrc(), options, fixedTypes);
|
collapseShapeOp.getSrc(), options, fixedTypes);
|
||||||
if (failed(maybeSrcBufferType))
|
if (failed(maybeSrcBufferType))
|
||||||
return failure();
|
return failure();
|
||||||
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
|
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||||
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
|
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
|
||||||
srcBufferType, collapseShapeOp.getReassociationIndices());
|
srcBufferType, collapseShapeOp.getReassociationIndices());
|
||||||
|
|
||||||
@ -303,7 +303,7 @@ struct ExpandShapeOpInterface
|
|||||||
expandShapeOp.getSrc(), options, fixedTypes);
|
expandShapeOp.getSrc(), options, fixedTypes);
|
||||||
if (failed(maybeSrcBufferType))
|
if (failed(maybeSrcBufferType))
|
||||||
return failure();
|
return failure();
|
||||||
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
|
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||||
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
|
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
|
||||||
srcBufferType, expandShapeOp.getResultType().getShape(),
|
srcBufferType, expandShapeOp.getResultType().getShape(),
|
||||||
expandShapeOp.getReassociationIndices());
|
expandShapeOp.getReassociationIndices());
|
||||||
@ -369,7 +369,7 @@ struct ExtractSliceOpInterface
|
|||||||
if (failed(resultMemrefType))
|
if (failed(resultMemrefType))
|
||||||
return failure();
|
return failure();
|
||||||
Value subView = rewriter.create<memref::SubViewOp>(
|
Value subView = rewriter.create<memref::SubViewOp>(
|
||||||
loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
|
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
|
||||||
mixedSizes, mixedStrides);
|
mixedSizes, mixedStrides);
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, subView);
|
replaceOpWithBufferizedValues(rewriter, op, subView);
|
||||||
@ -389,7 +389,7 @@ struct ExtractSliceOpInterface
|
|||||||
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
||||||
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
||||||
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
|
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
|
||||||
extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
|
extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
|
||||||
mixedOffsets, mixedSizes, mixedStrides));
|
mixedOffsets, mixedSizes, mixedStrides));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -548,8 +548,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
|
|||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto resultETy = resultTy.getElementType();
|
auto resultETy = resultTy.getElementType();
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
|
|
||||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||||
return getInput1();
|
return getInput1();
|
||||||
@ -573,8 +573,8 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
|||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto resultETy = resultTy.getElementType();
|
auto resultETy = resultTy.getElementType();
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
if (lhsAttr && lhsAttr.isSplat()) {
|
if (lhsAttr && lhsAttr.isSplat()) {
|
||||||
if (llvm::isa<IntegerType>(resultETy) &&
|
if (llvm::isa<IntegerType>(resultETy) &&
|
||||||
lhsAttr.getSplatValue<APInt>().isZero())
|
lhsAttr.getSplatValue<APInt>().isZero())
|
||||||
@ -642,8 +642,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
|||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto resultETy = resultTy.getElementType();
|
auto resultETy = resultTy.getElementType();
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
|
|
||||||
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
|
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
|
||||||
if (rhsTy == resultTy) {
|
if (rhsTy == resultTy) {
|
||||||
@ -670,8 +670,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
|
|||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto resultETy = resultTy.getElementType();
|
auto resultETy = resultTy.getElementType();
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
|
|
||||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||||
return getInput1();
|
return getInput1();
|
||||||
@ -713,8 +713,8 @@ struct APIntFoldGreaterEqual {
|
|||||||
|
|
||||||
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
|
|
||||||
if (!lhsAttr || !rhsAttr)
|
if (!lhsAttr || !rhsAttr)
|
||||||
return {};
|
return {};
|
||||||
@ -725,8 +725,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
|||||||
|
|
||||||
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
|
|
||||||
if (!lhsAttr || !rhsAttr)
|
if (!lhsAttr || !rhsAttr)
|
||||||
return {};
|
return {};
|
||||||
@ -738,8 +738,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
|||||||
|
|
||||||
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||||
Value lhs = getInput1();
|
Value lhs = getInput1();
|
||||||
Value rhs = getInput2();
|
Value rhs = getInput2();
|
||||||
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
|
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
|
||||||
@ -763,7 +763,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (getInput().getType() == getType())
|
if (getInput().getType() == getType())
|
||||||
return getInput();
|
return getInput();
|
||||||
|
|
||||||
auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
|
auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
|
||||||
if (!operand)
|
if (!operand)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -852,7 +852,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (inputTy == outputTy)
|
if (inputTy == outputTy)
|
||||||
return getInput1();
|
return getInput1();
|
||||||
|
|
||||||
auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||||
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
|
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
|
||||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||||
}
|
}
|
||||||
@ -863,7 +863,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
|||||||
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
|
||||||
// If the pad is all zeros we can fold this operation away.
|
// If the pad is all zeros we can fold this operation away.
|
||||||
if (adaptor.getPadding()) {
|
if (adaptor.getPadding()) {
|
||||||
auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
|
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
|
||||||
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
|
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
|
||||||
return getInput1();
|
return getInput1();
|
||||||
}
|
}
|
||||||
@ -907,7 +907,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
|
|||||||
auto operand = getInput();
|
auto operand = getInput();
|
||||||
auto operandTy = llvm::cast<ShapedType>(operand.getType());
|
auto operandTy = llvm::cast<ShapedType>(operand.getType());
|
||||||
auto axis = getAxis();
|
auto axis = getAxis();
|
||||||
auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
|
auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
|
||||||
if (operandAttr)
|
if (operandAttr)
|
||||||
return operandAttr;
|
return operandAttr;
|
||||||
|
|
||||||
@ -936,7 +936,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
|
|||||||
!outputTy.getElementType().isIntOrIndexOrFloat())
|
!outputTy.getElementType().isIntOrIndexOrFloat())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto operand = adaptor.getInput().cast<ElementsAttr>();
|
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
|
||||||
if (operand.isSplat() && outputTy.hasStaticShape()) {
|
if (operand.isSplat() && outputTy.hasStaticShape()) {
|
||||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||||
}
|
}
|
||||||
@ -955,7 +955,7 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (getOnTrue() == getOnFalse())
|
if (getOnTrue() == getOnFalse())
|
||||||
return getOnTrue();
|
return getOnTrue();
|
||||||
|
|
||||||
auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
|
auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
|
||||||
if (!predicate)
|
if (!predicate)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
@ -977,7 +977,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
|
|||||||
auto resultTy = llvm::cast<ShapedType>(getType());
|
auto resultTy = llvm::cast<ShapedType>(getType());
|
||||||
|
|
||||||
// Transposing splat values just means reshaping.
|
// Transposing splat values just means reshaping.
|
||||||
if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
|
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||||
if (input.isSplat() && resultTy.hasStaticShape() &&
|
if (input.isSplat() && resultTy.hasStaticShape() &&
|
||||||
inputTy.getElementType() == resultTy.getElementType())
|
inputTy.getElementType() == resultTy.getElementType())
|
||||||
return input.reshape(resultTy);
|
return input.reshape(resultTy);
|
||||||
|
@ -63,9 +63,9 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
|
|||||||
// Verify the rank agrees with the output type if the output type is ranked.
|
// Verify the rank agrees with the output type if the output type is ranked.
|
||||||
if (outputType) {
|
if (outputType) {
|
||||||
if (outputType.getRank() !=
|
if (outputType.getRank() !=
|
||||||
input1_copy.getType().cast<RankedTensorType>().getRank() ||
|
llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
|
||||||
outputType.getRank() !=
|
outputType.getRank() !=
|
||||||
input2_copy.getType().cast<RankedTensorType>().getRank())
|
llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
loc, "the reshaped type doesn't agrees with the ranked output type");
|
loc, "the reshaped type doesn't agrees with the ranked output type");
|
||||||
}
|
}
|
||||||
|
@ -103,8 +103,8 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
|
|||||||
|
|
||||||
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
||||||
Value &input1, Value &input2) {
|
Value &input1, Value &input2) {
|
||||||
auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
|
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
|
||||||
auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
|
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
|
||||||
|
|
||||||
if (!input1Ty || !input2Ty) {
|
if (!input1Ty || !input2Ty) {
|
||||||
return failure();
|
return failure();
|
||||||
@ -126,9 +126,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
|||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> higherRankShape =
|
ArrayRef<int64_t> higherRankShape =
|
||||||
higherTensorValue.getType().cast<RankedTensorType>().getShape();
|
llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
|
||||||
ArrayRef<int64_t> lowerRankShape =
|
ArrayRef<int64_t> lowerRankShape =
|
||||||
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
|
llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
|
||||||
|
|
||||||
SmallVector<int64_t, 4> reshapeOutputShape;
|
SmallVector<int64_t, 4> reshapeOutputShape;
|
||||||
|
|
||||||
@ -136,7 +136,8 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
|||||||
.failed())
|
.failed())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
|
auto reshapeInputType =
|
||||||
|
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
|
||||||
auto reshapeOutputType = RankedTensorType::get(
|
auto reshapeOutputType = RankedTensorType::get(
|
||||||
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
|
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
|||||||
SmallVector<Operation *> operations;
|
SmallVector<Operation *> operations;
|
||||||
operations.reserve(values.size());
|
operations.reserve(values.size());
|
||||||
for (transform::MappedValue value : values) {
|
for (transform::MappedValue value : values) {
|
||||||
if (auto *op = value.dyn_cast<Operation *>()) {
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
|
||||||
operations.push_back(op);
|
operations.push_back(op);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -135,7 +135,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
|||||||
SmallVector<Value> payloadValues;
|
SmallVector<Value> payloadValues;
|
||||||
payloadValues.reserve(values.size());
|
payloadValues.reserve(values.size());
|
||||||
for (transform::MappedValue value : values) {
|
for (transform::MappedValue value : values) {
|
||||||
if (auto v = value.dyn_cast<Value>()) {
|
if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
|
||||||
payloadValues.push_back(v);
|
payloadValues.push_back(v);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
|||||||
SmallVector<transform::Param> parameters;
|
SmallVector<transform::Param> parameters;
|
||||||
parameters.reserve(values.size());
|
parameters.reserve(values.size());
|
||||||
for (transform::MappedValue value : values) {
|
for (transform::MappedValue value : values) {
|
||||||
if (auto attr = value.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
|
||||||
parameters.push_back(attr);
|
parameters.push_back(attr);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ namespace mlir {
|
|||||||
bool isZeroIndex(OpFoldResult v) {
|
bool isZeroIndex(OpFoldResult v) {
|
||||||
if (!v)
|
if (!v)
|
||||||
return false;
|
return false;
|
||||||
if (auto attr = v.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
|
||||||
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
return intAttr && intAttr.getValue().isZero();
|
return intAttr && intAttr.getValue().isZero();
|
||||||
}
|
}
|
||||||
@ -51,7 +51,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
|
|||||||
void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||||
SmallVectorImpl<Value> &dynamicVec,
|
SmallVectorImpl<Value> &dynamicVec,
|
||||||
SmallVectorImpl<int64_t> &staticVec) {
|
SmallVectorImpl<int64_t> &staticVec) {
|
||||||
auto v = ofr.dyn_cast<Value>();
|
auto v = llvm::dyn_cast_if_present<Value>(ofr);
|
||||||
if (!v) {
|
if (!v) {
|
||||||
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
|
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
|
||||||
staticVec.push_back(apInt.getSExtValue());
|
staticVec.push_back(apInt.getSExtValue());
|
||||||
@ -116,14 +116,14 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
|
|||||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||||
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||||
// Case 1: Check for Constant integer.
|
// Case 1: Check for Constant integer.
|
||||||
if (auto val = ofr.dyn_cast<Value>()) {
|
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||||
APSInt intVal;
|
APSInt intVal;
|
||||||
if (matchPattern(val, m_ConstantInt(&intVal)))
|
if (matchPattern(val, m_ConstantInt(&intVal)))
|
||||||
return intVal.getSExtValue();
|
return intVal.getSExtValue();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
// Case 2: Check for IntegerAttr.
|
// Case 2: Check for IntegerAttr.
|
||||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
|
||||||
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
||||||
return intAttr.getValue().getSExtValue();
|
return intAttr.getValue().getSExtValue();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@ -143,7 +143,8 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
|
|||||||
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
||||||
if (cst1 && cst2 && *cst1 == *cst2)
|
if (cst1 && cst2 && *cst1 == *cst2)
|
||||||
return true;
|
return true;
|
||||||
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
|
auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
|
||||||
|
v2 = llvm::dyn_cast_if_present<Value>(ofr2);
|
||||||
return v1 && v1 == v2;
|
return v1 && v1 == v2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1154,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
|||||||
OpaqueProperties properties, RegionRange,
|
OpaqueProperties properties, RegionRange,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
ExtractOp::Adaptor op(operands, attributes, properties);
|
ExtractOp::Adaptor op(operands, attributes, properties);
|
||||||
auto vectorType = op.getVector().getType().cast<VectorType>();
|
auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
|
||||||
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
|
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
|
||||||
inferredReturnTypes.push_back(vectorType.getElementType());
|
inferredReturnTypes.push_back(vectorType.getElementType());
|
||||||
} else {
|
} else {
|
||||||
@ -2003,9 +2003,9 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
|
|||||||
if (!adaptor.getSource())
|
if (!adaptor.getSource())
|
||||||
return {};
|
return {};
|
||||||
auto vectorType = getResultVectorType();
|
auto vectorType = getResultVectorType();
|
||||||
if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
|
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
|
||||||
return DenseElementsAttr::get(vectorType, adaptor.getSource());
|
return DenseElementsAttr::get(vectorType, adaptor.getSource());
|
||||||
if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
|
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
|
||||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
@ -2090,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
|||||||
OpaqueProperties properties, RegionRange,
|
OpaqueProperties properties, RegionRange,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
ShuffleOp::Adaptor op(operands, attributes, properties);
|
ShuffleOp::Adaptor op(operands, attributes, properties);
|
||||||
auto v1Type = op.getV1().getType().cast<VectorType>();
|
auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
|
||||||
auto v1Rank = v1Type.getRank();
|
auto v1Rank = v1Type.getRank();
|
||||||
// Construct resulting type: leading dimension matches mask
|
// Construct resulting type: leading dimension matches mask
|
||||||
// length, all trailing dimensions match the operands.
|
// length, all trailing dimensions match the operands.
|
||||||
@ -4951,7 +4951,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
|
|||||||
|
|
||||||
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
|
||||||
// Eliminate splat constant transpose ops.
|
// Eliminate splat constant transpose ops.
|
||||||
if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
|
if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
|
||||||
if (attr.isSplat())
|
if (attr.isSplat())
|
||||||
return attr.reshape(getResultVectorType());
|
return attr.reshape(getResultVectorType());
|
||||||
|
|
||||||
|
@ -3642,7 +3642,7 @@ void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
|
|||||||
if (auto *op = getDefiningOp())
|
if (auto *op = getDefiningOp())
|
||||||
return op->print(os, flags);
|
return op->print(os, flags);
|
||||||
// TODO: Improve BlockArgument print'ing.
|
// TODO: Improve BlockArgument print'ing.
|
||||||
BlockArgument arg = this->cast<BlockArgument>();
|
BlockArgument arg = llvm::cast<BlockArgument>(*this);
|
||||||
os << "<block argument> of type '" << arg.getType()
|
os << "<block argument> of type '" << arg.getType()
|
||||||
<< "' at index: " << arg.getArgNumber();
|
<< "' at index: " << arg.getArgNumber();
|
||||||
}
|
}
|
||||||
@ -3656,7 +3656,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
|
|||||||
return op->print(os, state);
|
return op->print(os, state);
|
||||||
|
|
||||||
// TODO: Improve BlockArgument print'ing.
|
// TODO: Improve BlockArgument print'ing.
|
||||||
BlockArgument arg = this->cast<BlockArgument>();
|
BlockArgument arg = llvm::cast<BlockArgument>(*this);
|
||||||
os << "<block argument> of type '" << arg.getType()
|
os << "<block argument> of type '" << arg.getType()
|
||||||
<< "' at index: " << arg.getArgNumber();
|
<< "' at index: " << arg.getArgNumber();
|
||||||
}
|
}
|
||||||
@ -3693,10 +3693,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
|
|||||||
|
|
||||||
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
|
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
|
||||||
Operation *op;
|
Operation *op;
|
||||||
if (auto result = dyn_cast<OpResult>()) {
|
if (auto result = llvm::dyn_cast<OpResult>(*this)) {
|
||||||
op = result.getOwner();
|
op = result.getOwner();
|
||||||
} else {
|
} else {
|
||||||
op = cast<BlockArgument>().getOwner()->getParentOp();
|
op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
|
||||||
if (!op) {
|
if (!op) {
|
||||||
os << "<<UNKNOWN SSA VALUE>>";
|
os << "<<UNKNOWN SSA VALUE>>";
|
||||||
return;
|
return;
|
||||||
|
@ -347,14 +347,14 @@ BlockRange::BlockRange(SuccessorRange successors)
|
|||||||
|
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
|
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
|
||||||
if (auto *operand = object.dyn_cast<BlockOperand *>())
|
if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
|
||||||
return {operand + index};
|
return {operand + index};
|
||||||
return {object.dyn_cast<Block *const *>() + index};
|
return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
||||||
if (const auto *operand = object.dyn_cast<BlockOperand *>())
|
if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
|
||||||
return operand[index].get();
|
return operand[index].get();
|
||||||
return object.dyn_cast<Block *const *>()[index];
|
return llvm::dyn_cast_if_present<Block *const *>(object)[index];
|
||||||
}
|
}
|
||||||
|
@ -483,7 +483,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
|||||||
Type expectedType = std::get<1>(it);
|
Type expectedType = std::get<1>(it);
|
||||||
|
|
||||||
// Normal values get pushed back directly.
|
// Normal values get pushed back directly.
|
||||||
if (auto value = std::get<0>(it).dyn_cast<Value>()) {
|
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
|
||||||
if (value.getType() != expectedType)
|
if (value.getType() != expectedType)
|
||||||
return cleanupFailure();
|
return cleanupFailure();
|
||||||
|
|
||||||
|
@ -1247,12 +1247,12 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
|
|||||||
DenseElementsAttr
|
DenseElementsAttr
|
||||||
DenseElementsAttr::mapValues(Type newElementType,
|
DenseElementsAttr::mapValues(Type newElementType,
|
||||||
function_ref<APInt(const APInt &)> mapping) const {
|
function_ref<APInt(const APInt &)> mapping) const {
|
||||||
return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
|
return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||||
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
|
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
|
||||||
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
|
return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedType DenseElementsAttr::getType() const {
|
ShapedType DenseElementsAttr::getType() const {
|
||||||
|
@ -88,45 +88,45 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
unsigned FloatType::getWidth() {
|
unsigned FloatType::getWidth() {
|
||||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
|
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
|
||||||
return 8;
|
return 8;
|
||||||
if (isa<Float16Type, BFloat16Type>())
|
if (llvm::isa<Float16Type, BFloat16Type>(*this))
|
||||||
return 16;
|
return 16;
|
||||||
if (isa<Float32Type>())
|
if (llvm::isa<Float32Type>(*this))
|
||||||
return 32;
|
return 32;
|
||||||
if (isa<Float64Type>())
|
if (llvm::isa<Float64Type>(*this))
|
||||||
return 64;
|
return 64;
|
||||||
if (isa<Float80Type>())
|
if (llvm::isa<Float80Type>(*this))
|
||||||
return 80;
|
return 80;
|
||||||
if (isa<Float128Type>())
|
if (llvm::isa<Float128Type>(*this))
|
||||||
return 128;
|
return 128;
|
||||||
llvm_unreachable("unexpected float type");
|
llvm_unreachable("unexpected float type");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the floating semantics for the given type.
|
/// Returns the floating semantics for the given type.
|
||||||
const llvm::fltSemantics &FloatType::getFloatSemantics() {
|
const llvm::fltSemantics &FloatType::getFloatSemantics() {
|
||||||
if (isa<Float8E5M2Type>())
|
if (llvm::isa<Float8E5M2Type>(*this))
|
||||||
return APFloat::Float8E5M2();
|
return APFloat::Float8E5M2();
|
||||||
if (isa<Float8E4M3FNType>())
|
if (llvm::isa<Float8E4M3FNType>(*this))
|
||||||
return APFloat::Float8E4M3FN();
|
return APFloat::Float8E4M3FN();
|
||||||
if (isa<Float8E5M2FNUZType>())
|
if (llvm::isa<Float8E5M2FNUZType>(*this))
|
||||||
return APFloat::Float8E5M2FNUZ();
|
return APFloat::Float8E5M2FNUZ();
|
||||||
if (isa<Float8E4M3FNUZType>())
|
if (llvm::isa<Float8E4M3FNUZType>(*this))
|
||||||
return APFloat::Float8E4M3FNUZ();
|
return APFloat::Float8E4M3FNUZ();
|
||||||
if (isa<Float8E4M3B11FNUZType>())
|
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
|
||||||
return APFloat::Float8E4M3B11FNUZ();
|
return APFloat::Float8E4M3B11FNUZ();
|
||||||
if (isa<BFloat16Type>())
|
if (llvm::isa<BFloat16Type>(*this))
|
||||||
return APFloat::BFloat();
|
return APFloat::BFloat();
|
||||||
if (isa<Float16Type>())
|
if (llvm::isa<Float16Type>(*this))
|
||||||
return APFloat::IEEEhalf();
|
return APFloat::IEEEhalf();
|
||||||
if (isa<Float32Type>())
|
if (llvm::isa<Float32Type>(*this))
|
||||||
return APFloat::IEEEsingle();
|
return APFloat::IEEEsingle();
|
||||||
if (isa<Float64Type>())
|
if (llvm::isa<Float64Type>(*this))
|
||||||
return APFloat::IEEEdouble();
|
return APFloat::IEEEdouble();
|
||||||
if (isa<Float80Type>())
|
if (llvm::isa<Float80Type>(*this))
|
||||||
return APFloat::x87DoubleExtended();
|
return APFloat::x87DoubleExtended();
|
||||||
if (isa<Float128Type>())
|
if (llvm::isa<Float128Type>(*this))
|
||||||
return APFloat::IEEEquad();
|
return APFloat::IEEEquad();
|
||||||
llvm_unreachable("non-floating point type used");
|
llvm_unreachable("non-floating point type used");
|
||||||
}
|
}
|
||||||
@ -269,21 +269,21 @@ Type TensorType::getElementType() const {
|
|||||||
[](auto type) { return type.getElementType(); });
|
[](auto type) { return type.getElementType(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
|
bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
|
||||||
|
|
||||||
ArrayRef<int64_t> TensorType::getShape() const {
|
ArrayRef<int64_t> TensorType::getShape() const {
|
||||||
return cast<RankedTensorType>().getShape();
|
return llvm::cast<RankedTensorType>(*this).getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||||
Type elementType) const {
|
Type elementType) const {
|
||||||
if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
|
if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
|
||||||
if (shape)
|
if (shape)
|
||||||
return RankedTensorType::get(*shape, elementType);
|
return RankedTensorType::get(*shape, elementType);
|
||||||
return UnrankedTensorType::get(elementType);
|
return UnrankedTensorType::get(elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rankedTy = cast<RankedTensorType>();
|
auto rankedTy = llvm::cast<RankedTensorType>(*this);
|
||||||
if (!shape)
|
if (!shape)
|
||||||
return RankedTensorType::get(rankedTy.getShape(), elementType,
|
return RankedTensorType::get(rankedTy.getShape(), elementType,
|
||||||
rankedTy.getEncoding());
|
rankedTy.getEncoding());
|
||||||
@ -356,15 +356,15 @@ Type BaseMemRefType::getElementType() const {
|
|||||||
[](auto type) { return type.getElementType(); });
|
[](auto type) { return type.getElementType(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
|
bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
|
||||||
|
|
||||||
ArrayRef<int64_t> BaseMemRefType::getShape() const {
|
ArrayRef<int64_t> BaseMemRefType::getShape() const {
|
||||||
return cast<MemRefType>().getShape();
|
return llvm::cast<MemRefType>(*this).getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||||
Type elementType) const {
|
Type elementType) const {
|
||||||
if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
|
if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
|
||||||
if (!shape)
|
if (!shape)
|
||||||
return UnrankedMemRefType::get(elementType, getMemorySpace());
|
return UnrankedMemRefType::get(elementType, getMemorySpace());
|
||||||
MemRefType::Builder builder(*shape, elementType);
|
MemRefType::Builder builder(*shape, elementType);
|
||||||
@ -372,7 +372,7 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
|||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
MemRefType::Builder builder(cast<MemRefType>());
|
MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
|
||||||
if (shape)
|
if (shape)
|
||||||
builder.setShape(*shape);
|
builder.setShape(*shape);
|
||||||
builder.setElementType(elementType);
|
builder.setElementType(elementType);
|
||||||
@ -389,15 +389,15 @@ MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Attribute BaseMemRefType::getMemorySpace() const {
|
Attribute BaseMemRefType::getMemorySpace() const {
|
||||||
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
|
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
|
||||||
return rankedMemRefTy.getMemorySpace();
|
return rankedMemRefTy.getMemorySpace();
|
||||||
return cast<UnrankedMemRefType>().getMemorySpace();
|
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
|
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
|
||||||
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
|
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
|
||||||
return rankedMemRefTy.getMemorySpaceAsInt();
|
return rankedMemRefTy.getMemorySpaceAsInt();
|
||||||
return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
|
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -626,17 +626,17 @@ ValueRange::ValueRange(ResultRange values)
|
|||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
|
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
|
||||||
ptrdiff_t index) {
|
ptrdiff_t index) {
|
||||||
if (const auto *value = owner.dyn_cast<const Value *>())
|
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
|
||||||
return {value + index};
|
return {value + index};
|
||||||
if (auto *operand = owner.dyn_cast<OpOperand *>())
|
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||||
return {operand + index};
|
return {operand + index};
|
||||||
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
||||||
}
|
}
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
|
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
|
||||||
if (const auto *value = owner.dyn_cast<const Value *>())
|
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
|
||||||
return value[index];
|
return value[index];
|
||||||
if (auto *operand = owner.dyn_cast<OpOperand *>())
|
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||||
return operand[index].get();
|
return operand[index].get();
|
||||||
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
||||||
}
|
}
|
||||||
|
@ -267,18 +267,18 @@ RegionRange::RegionRange(ArrayRef<Region *> regions)
|
|||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
|
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
|
||||||
ptrdiff_t index) {
|
ptrdiff_t index) {
|
||||||
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
|
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
|
||||||
return region + index;
|
return region + index;
|
||||||
if (auto **region = owner.dyn_cast<Region **>())
|
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
|
||||||
return region + index;
|
return region + index;
|
||||||
return &owner.get<Region *>()[index];
|
return &owner.get<Region *>()[index];
|
||||||
}
|
}
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
Region *RegionRange::dereference_iterator(const OwnerT &owner,
|
Region *RegionRange::dereference_iterator(const OwnerT &owner,
|
||||||
ptrdiff_t index) {
|
ptrdiff_t index) {
|
||||||
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
|
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
|
||||||
return region[index].get();
|
return region[index].get();
|
||||||
if (auto **region = owner.dyn_cast<Region **>())
|
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
|
||||||
return region[index];
|
return region[index];
|
||||||
return &owner.get<Region *>()[index];
|
return &owner.get<Region *>()[index];
|
||||||
}
|
}
|
||||||
|
@ -551,7 +551,7 @@ struct SymbolScope {
|
|||||||
typename llvm::function_traits<CallbackT>::result_t,
|
typename llvm::function_traits<CallbackT>::result_t,
|
||||||
void>::value> * = nullptr>
|
void>::value> * = nullptr>
|
||||||
std::optional<WalkResult> walk(CallbackT cback) {
|
std::optional<WalkResult> walk(CallbackT cback) {
|
||||||
if (Region *region = limit.dyn_cast<Region *>())
|
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
|
||||||
return walkSymbolUses(*region, cback);
|
return walkSymbolUses(*region, cback);
|
||||||
return walkSymbolUses(limit.get<Operation *>(), cback);
|
return walkSymbolUses(limit.get<Operation *>(), cback);
|
||||||
}
|
}
|
||||||
@ -571,7 +571,7 @@ struct SymbolScope {
|
|||||||
/// traversing into any nested symbol tables.
|
/// traversing into any nested symbol tables.
|
||||||
template <typename CallbackT>
|
template <typename CallbackT>
|
||||||
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
|
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
|
||||||
if (Region *region = limit.dyn_cast<Region *>())
|
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
|
||||||
return ::walkSymbolTable(*region, cback);
|
return ::walkSymbolTable(*region, cback);
|
||||||
return ::walkSymbolTable(limit.get<Operation *>(), cback);
|
return ::walkSymbolTable(limit.get<Operation *>(), cback);
|
||||||
}
|
}
|
||||||
|
@ -27,9 +27,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
|
|||||||
if (count == 0)
|
if (count == 0)
|
||||||
return;
|
return;
|
||||||
ValueRange::OwnerT owner = values.begin().getBase();
|
ValueRange::OwnerT owner = values.begin().getBase();
|
||||||
if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
|
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
|
||||||
this->base = result;
|
this->base = result;
|
||||||
else if (auto *operand = owner.dyn_cast<OpOperand *>())
|
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||||
this->base = operand;
|
this->base = operand;
|
||||||
else
|
else
|
||||||
this->base = owner.get<const Value *>();
|
this->base = owner.get<const Value *>();
|
||||||
@ -37,22 +37,22 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
|
|||||||
|
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
|
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
|
||||||
if (const auto *value = object.dyn_cast<const Value *>())
|
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
|
||||||
return {value + index};
|
return {value + index};
|
||||||
if (auto *operand = object.dyn_cast<OpOperand *>())
|
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
|
||||||
return {operand + index};
|
return {operand + index};
|
||||||
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
|
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||||
return {result->getNextResultAtOffset(index)};
|
return {result->getNextResultAtOffset(index)};
|
||||||
return {object.dyn_cast<const Type *>() + index};
|
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||||
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
||||||
if (const auto *value = object.dyn_cast<const Value *>())
|
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
|
||||||
return (value + index)->getType();
|
return (value + index)->getType();
|
||||||
if (auto *operand = object.dyn_cast<OpOperand *>())
|
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
|
||||||
return (operand + index)->get().getType();
|
return (operand + index)->get().getType();
|
||||||
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
|
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||||
return result->getNextResultAtOffset(index)->getType();
|
return result->getNextResultAtOffset(index)->getType();
|
||||||
return object.dyn_cast<const Type *>()[index];
|
return llvm::dyn_cast_if_present<const Type *>(object)[index];
|
||||||
}
|
}
|
||||||
|
@ -34,84 +34,94 @@ Type AbstractType::replaceImmediateSubElements(Type type,
|
|||||||
|
|
||||||
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
||||||
|
|
||||||
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
|
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
|
||||||
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
|
bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
|
||||||
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
|
bool Type::isFloat8E5M2FNUZ() const {
|
||||||
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
|
return llvm::isa<Float8E5M2FNUZType>(*this);
|
||||||
bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
|
}
|
||||||
bool Type::isBF16() const { return isa<BFloat16Type>(); }
|
bool Type::isFloat8E4M3FNUZ() const {
|
||||||
bool Type::isF16() const { return isa<Float16Type>(); }
|
return llvm::isa<Float8E4M3FNUZType>(*this);
|
||||||
bool Type::isF32() const { return isa<Float32Type>(); }
|
}
|
||||||
bool Type::isF64() const { return isa<Float64Type>(); }
|
bool Type::isFloat8E4M3B11FNUZ() const {
|
||||||
bool Type::isF80() const { return isa<Float80Type>(); }
|
return llvm::isa<Float8E4M3B11FNUZType>(*this);
|
||||||
bool Type::isF128() const { return isa<Float128Type>(); }
|
}
|
||||||
|
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
|
||||||
|
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
|
||||||
|
bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
|
||||||
|
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
|
||||||
|
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
|
||||||
|
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
|
||||||
|
|
||||||
bool Type::isIndex() const { return isa<IndexType>(); }
|
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
|
||||||
|
|
||||||
/// Return true if this is an integer type with the specified width.
|
/// Return true if this is an integer type with the specified width.
|
||||||
bool Type::isInteger(unsigned width) const {
|
bool Type::isInteger(unsigned width) const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.getWidth() == width;
|
return intTy.getWidth() == width;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignlessInteger() const {
|
bool Type::isSignlessInteger() const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isSignless();
|
return intTy.isSignless();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignlessInteger(unsigned width) const {
|
bool Type::isSignlessInteger(unsigned width) const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isSignless() && intTy.getWidth() == width;
|
return intTy.isSignless() && intTy.getWidth() == width;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignedInteger() const {
|
bool Type::isSignedInteger() const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isSigned();
|
return intTy.isSigned();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignedInteger(unsigned width) const {
|
bool Type::isSignedInteger(unsigned width) const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isSigned() && intTy.getWidth() == width;
|
return intTy.isSigned() && intTy.getWidth() == width;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isUnsignedInteger() const {
|
bool Type::isUnsignedInteger() const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isUnsigned();
|
return intTy.isUnsigned();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isUnsignedInteger(unsigned width) const {
|
bool Type::isUnsignedInteger(unsigned width) const {
|
||||||
if (auto intTy = dyn_cast<IntegerType>())
|
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intTy.isUnsigned() && intTy.getWidth() == width;
|
return intTy.isUnsigned() && intTy.getWidth() == width;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignlessIntOrIndex() const {
|
bool Type::isSignlessIntOrIndex() const {
|
||||||
return isSignlessInteger() || isa<IndexType>();
|
return isSignlessInteger() || llvm::isa<IndexType>(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignlessIntOrIndexOrFloat() const {
|
bool Type::isSignlessIntOrIndexOrFloat() const {
|
||||||
return isSignlessInteger() || isa<IndexType, FloatType>();
|
return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isSignlessIntOrFloat() const {
|
bool Type::isSignlessIntOrFloat() const {
|
||||||
return isSignlessInteger() || isa<FloatType>();
|
return isSignlessInteger() || llvm::isa<FloatType>(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
|
bool Type::isIntOrIndex() const {
|
||||||
|
return llvm::isa<IntegerType>(*this) || isIndex();
|
||||||
|
}
|
||||||
|
|
||||||
bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
|
bool Type::isIntOrFloat() const {
|
||||||
|
return llvm::isa<IntegerType, FloatType>(*this);
|
||||||
|
}
|
||||||
|
|
||||||
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
|
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
|
||||||
|
|
||||||
unsigned Type::getIntOrFloatBitWidth() const {
|
unsigned Type::getIntOrFloatBitWidth() const {
|
||||||
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
|
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
|
||||||
if (auto intType = dyn_cast<IntegerType>())
|
if (auto intType = llvm::dyn_cast<IntegerType>(*this))
|
||||||
return intType.getWidth();
|
return intType.getWidth();
|
||||||
return cast<FloatType>().getWidth();
|
return llvm::cast<FloatType>(*this).getWidth();
|
||||||
}
|
}
|
||||||
|
@ -48,11 +48,11 @@ static void printBlock(llvm::raw_ostream &os, Block *block,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
|
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
|
||||||
if (auto *op = this->dyn_cast<Operation *>())
|
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
|
||||||
return printOp(os, op, flags);
|
return printOp(os, op, flags);
|
||||||
if (auto *region = this->dyn_cast<Region *>())
|
if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
|
||||||
return printRegion(os, region, flags);
|
return printRegion(os, region, flags);
|
||||||
if (auto *block = this->dyn_cast<Block *>())
|
if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
|
||||||
return printBlock(os, block, flags);
|
return printBlock(os, block, flags);
|
||||||
llvm_unreachable("unknown IRUnit");
|
llvm_unreachable("unknown IRUnit");
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ using namespace mlir::detail;
|
|||||||
/// If this value is the result of an Operation, return the operation that
|
/// If this value is the result of an Operation, return the operation that
|
||||||
/// defines it.
|
/// defines it.
|
||||||
Operation *Value::getDefiningOp() const {
|
Operation *Value::getDefiningOp() const {
|
||||||
if (auto result = dyn_cast<OpResult>())
|
if (auto result = llvm::dyn_cast<OpResult>(*this))
|
||||||
return result.getOwner();
|
return result.getOwner();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -27,28 +27,28 @@ Location Value::getLoc() const {
|
|||||||
if (auto *op = getDefiningOp())
|
if (auto *op = getDefiningOp())
|
||||||
return op->getLoc();
|
return op->getLoc();
|
||||||
|
|
||||||
return cast<BlockArgument>().getLoc();
|
return llvm::cast<BlockArgument>(*this).getLoc();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Value::setLoc(Location loc) {
|
void Value::setLoc(Location loc) {
|
||||||
if (auto *op = getDefiningOp())
|
if (auto *op = getDefiningOp())
|
||||||
return op->setLoc(loc);
|
return op->setLoc(loc);
|
||||||
|
|
||||||
return cast<BlockArgument>().setLoc(loc);
|
return llvm::cast<BlockArgument>(*this).setLoc(loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the Region in which this Value is defined.
|
/// Return the Region in which this Value is defined.
|
||||||
Region *Value::getParentRegion() {
|
Region *Value::getParentRegion() {
|
||||||
if (auto *op = getDefiningOp())
|
if (auto *op = getDefiningOp())
|
||||||
return op->getParentRegion();
|
return op->getParentRegion();
|
||||||
return cast<BlockArgument>().getOwner()->getParent();
|
return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the Block in which this Value is defined.
|
/// Return the Block in which this Value is defined.
|
||||||
Block *Value::getParentBlock() {
|
Block *Value::getParentBlock() {
|
||||||
if (Operation *op = getDefiningOp())
|
if (Operation *op = getDefiningOp())
|
||||||
return op->getBlock();
|
return op->getBlock();
|
||||||
return cast<BlockArgument>().getOwner();
|
return llvm::cast<BlockArgument>(*this).getOwner();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -241,7 +241,7 @@ mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
|
|||||||
TypeID typeID) {
|
TypeID typeID) {
|
||||||
return llvm::to_vector<4>(llvm::make_filter_range(
|
return llvm::to_vector<4>(llvm::make_filter_range(
|
||||||
entries, [typeID](DataLayoutEntryInterface entry) {
|
entries, [typeID](DataLayoutEntryInterface entry) {
|
||||||
auto type = entry.getKey().dyn_cast<Type>();
|
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
|
||||||
return type && type.getTypeID() == typeID;
|
return type && type.getTypeID() == typeID;
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
@ -521,7 +521,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
|
|||||||
DenseMap<TypeID, DataLayoutEntryList> &types,
|
DenseMap<TypeID, DataLayoutEntryList> &types,
|
||||||
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
|
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
|
||||||
for (DataLayoutEntryInterface entry : getEntries()) {
|
for (DataLayoutEntryInterface entry : getEntries()) {
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>())
|
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
|
||||||
types[type.getTypeID()].push_back(entry);
|
types[type.getTypeID()].push_back(entry);
|
||||||
else
|
else
|
||||||
ids[entry.getKey().get<StringAttr>()] = entry;
|
ids[entry.getKey().get<StringAttr>()] = entry;
|
||||||
|
@ -68,7 +68,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
|
|||||||
bool ShapeAdaptor::hasRank() const {
|
bool ShapeAdaptor::hasRank() const {
|
||||||
if (val.isNull())
|
if (val.isNull())
|
||||||
return false;
|
return false;
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).hasRank();
|
return cast<ShapedType>(t).hasRank();
|
||||||
if (val.is<Attribute>())
|
if (val.is<Attribute>())
|
||||||
return true;
|
return true;
|
||||||
@ -78,7 +78,7 @@ bool ShapeAdaptor::hasRank() const {
|
|||||||
Type ShapeAdaptor::getElementType() const {
|
Type ShapeAdaptor::getElementType() const {
|
||||||
if (val.isNull())
|
if (val.isNull())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).getElementType();
|
return cast<ShapedType>(t).getElementType();
|
||||||
if (val.is<Attribute>())
|
if (val.is<Attribute>())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -87,10 +87,10 @@ Type ShapeAdaptor::getElementType() const {
|
|||||||
|
|
||||||
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
|
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
|
||||||
assert(hasRank());
|
assert(hasRank());
|
||||||
if (auto t = val.dyn_cast<Type>()) {
|
if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
|
||||||
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
|
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
|
||||||
res.assign(vals.begin(), vals.end());
|
res.assign(vals.begin(), vals.end());
|
||||||
} else if (auto attr = val.dyn_cast<Attribute>()) {
|
} else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||||
res.clear();
|
res.clear();
|
||||||
res.reserve(dattr.size());
|
res.reserve(dattr.size());
|
||||||
@ -110,9 +110,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
|
|||||||
|
|
||||||
int64_t ShapeAdaptor::getDimSize(int index) const {
|
int64_t ShapeAdaptor::getDimSize(int index) const {
|
||||||
assert(hasRank());
|
assert(hasRank());
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).getDimSize(index);
|
return cast<ShapedType>(t).getDimSize(index);
|
||||||
if (auto attr = val.dyn_cast<Attribute>())
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
|
||||||
return cast<DenseIntElementsAttr>(attr)
|
return cast<DenseIntElementsAttr>(attr)
|
||||||
.getValues<APInt>()[index]
|
.getValues<APInt>()[index]
|
||||||
.getSExtValue();
|
.getSExtValue();
|
||||||
@ -122,9 +122,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
|
|||||||
|
|
||||||
int64_t ShapeAdaptor::getRank() const {
|
int64_t ShapeAdaptor::getRank() const {
|
||||||
assert(hasRank());
|
assert(hasRank());
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).getRank();
|
return cast<ShapedType>(t).getRank();
|
||||||
if (auto attr = val.dyn_cast<Attribute>())
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
|
||||||
return cast<DenseIntElementsAttr>(attr).size();
|
return cast<DenseIntElementsAttr>(attr).size();
|
||||||
return val.get<ShapedTypeComponents *>()->getDims().size();
|
return val.get<ShapedTypeComponents *>()->getDims().size();
|
||||||
}
|
}
|
||||||
@ -133,9 +133,9 @@ bool ShapeAdaptor::hasStaticShape() const {
|
|||||||
if (!hasRank())
|
if (!hasRank())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).hasStaticShape();
|
return cast<ShapedType>(t).hasStaticShape();
|
||||||
if (auto attr = val.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||||
for (auto index : dattr.getValues<APInt>())
|
for (auto index : dattr.getValues<APInt>())
|
||||||
if (ShapedType::isDynamic(index.getSExtValue()))
|
if (ShapedType::isDynamic(index.getSExtValue()))
|
||||||
@ -149,10 +149,10 @@ bool ShapeAdaptor::hasStaticShape() const {
|
|||||||
int64_t ShapeAdaptor::getNumElements() const {
|
int64_t ShapeAdaptor::getNumElements() const {
|
||||||
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
|
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
|
||||||
|
|
||||||
if (auto t = val.dyn_cast<Type>())
|
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||||
return cast<ShapedType>(t).getNumElements();
|
return cast<ShapedType>(t).getNumElements();
|
||||||
|
|
||||||
if (auto attr = val.dyn_cast<Attribute>()) {
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||||
int64_t num = 1;
|
int64_t num = 1;
|
||||||
for (auto index : dattr.getValues<APInt>()) {
|
for (auto index : dattr.getValues<APInt>()) {
|
||||||
|
@ -26,14 +26,14 @@ namespace mlir {
|
|||||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||||
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||||
// Case 1: Check for Constant integer.
|
// Case 1: Check for Constant integer.
|
||||||
if (auto val = ofr.dyn_cast<Value>()) {
|
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||||
APSInt intVal;
|
APSInt intVal;
|
||||||
if (matchPattern(val, m_ConstantInt(&intVal)))
|
if (matchPattern(val, m_ConstantInt(&intVal)))
|
||||||
return intVal.getSExtValue();
|
return intVal.getSExtValue();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
// Case 2: Check for IntegerAttr.
|
// Case 2: Check for IntegerAttr.
|
||||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
|
||||||
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
||||||
return intAttr.getValue().getSExtValue();
|
return intAttr.getValue().getSExtValue();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@ -99,7 +99,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
|
|||||||
}
|
}
|
||||||
|
|
||||||
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
|
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
|
||||||
if (Value value = ofr.dyn_cast<Value>())
|
if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
|
||||||
return getExpr(value, /*dim=*/std::nullopt);
|
return getExpr(value, /*dim=*/std::nullopt);
|
||||||
auto constInt = getConstantIntValue(ofr);
|
auto constInt = getConstantIntValue(ofr);
|
||||||
assert(constInt.has_value() && "expected Integer constant");
|
assert(constInt.has_value() && "expected Integer constant");
|
||||||
|
@ -26,7 +26,8 @@ struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
|
|||||||
const Pass &getPass() const { return pass; }
|
const Pass &getPass() const { return pass; }
|
||||||
Operation *getOp() const {
|
Operation *getOp() const {
|
||||||
ArrayRef<IRUnit> irUnits = getContextIRUnits();
|
ArrayRef<IRUnit> irUnits = getContextIRUnits();
|
||||||
return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
|
return irUnits.empty() ? nullptr
|
||||||
|
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -384,7 +384,7 @@ void Operator::populateTypeInferenceInfo(
|
|||||||
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
||||||
// Check for a non-variable length operand to use as the type anchor.
|
// Check for a non-variable length operand to use as the type anchor.
|
||||||
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
|
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
|
||||||
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
|
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
|
||||||
return operand && !operand->isVariableLength();
|
return operand && !operand->isVariableLength();
|
||||||
});
|
});
|
||||||
if (operandI == arguments.end())
|
if (operandI == arguments.end())
|
||||||
@ -824,7 +824,7 @@ StringRef Operator::getAssemblyFormat() const {
|
|||||||
void Operator::print(llvm::raw_ostream &os) const {
|
void Operator::print(llvm::raw_ostream &os) const {
|
||||||
os << "op '" << getOperationName() << "'\n";
|
os << "op '" << getOperationName() << "'\n";
|
||||||
for (Argument arg : arguments) {
|
for (Argument arg : arguments) {
|
||||||
if (auto *attr = arg.dyn_cast<NamedAttribute *>())
|
if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
|
||||||
os << "[attribute] " << attr->name << '\n';
|
os << "[attribute] " << attr->name << '\n';
|
||||||
else
|
else
|
||||||
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
|
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
|
||||||
|
@ -131,7 +131,7 @@ convertBranchWeights(std::optional<ElementsAttr> weights,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
SmallVector<uint32_t> weightValues;
|
SmallVector<uint32_t> weightValues;
|
||||||
weightValues.reserve(weights->size());
|
weightValues.reserve(weights->size());
|
||||||
for (APInt weight : weights->cast<DenseIntElementsAttr>())
|
for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
|
||||||
weightValues.push_back(weight.getLimitedValue());
|
weightValues.push_back(weight.getLimitedValue());
|
||||||
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
|
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
|
||||||
.createBranchWeights(weightValues);
|
.createBranchWeights(weightValues);
|
||||||
@ -330,7 +330,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
|
|||||||
auto *ty = llvm::cast<llvm::IntegerType>(
|
auto *ty = llvm::cast<llvm::IntegerType>(
|
||||||
moduleTranslation.convertType(switchOp.getValue().getType()));
|
moduleTranslation.convertType(switchOp.getValue().getType()));
|
||||||
for (auto i :
|
for (auto i :
|
||||||
llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
|
llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
|
||||||
switchOp.getCaseDestinations()))
|
switchOp.getCaseDestinations()))
|
||||||
switchInst->addCase(
|
switchInst->addCase(
|
||||||
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
|
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user