[mlir][NFC] update mlir create APIs (34/n) (#150660)

See https://github.com/llvm/llvm-project/pull/147168 for more info.
Maksim Levental 2025-07-25 12:36:54 -05:00 committed by GitHub
parent b46527645d
commit 258daf5395
16 changed files with 214 additions and 254 deletions
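The diff below applies one mechanical change throughout: calls that go through the builder object, `rewriter.create<OpTy>(loc, args...)` (including the `rewriter.template create<OpTy>(...)` spelling required in dependent contexts), become the static form `OpTy::create(rewriter, loc, args...)`, following the convention introduced in the PR linked above. A minimal before/after sketch of the pattern, using an illustrative op and operand names rather than code taken from any one of the files below:

  // Before: the op is built through the rewriter/builder object.
  // (dstType and input are illustrative names, not from this commit.)
  auto cast = rewriter.create<emitc::CastOp>(op.getLoc(), dstType, input);

  // After: the op's static create takes the builder as its first argument;
  // the location and the remaining operands/attributes are unchanged.
  auto cast = emitc::CastOp::create(rewriter, op.getLoc(), dstType, input);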

View File

@@ -402,8 +402,8 @@ public:
     Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
     // Actual cast (may change bitwidth)
-    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                        castDestType, actualOp);
+    auto cast =
+        emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
     // Cast to the expected output type
     auto result = adaptValueType(cast, rewriter, opReturnType);
@@ -507,8 +507,8 @@ public:
     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
-    Value arithmeticResult = rewriter.template create<EmitCOp>(
-        op.getLoc(), arithmeticType, lhs, rhs);
+    Value arithmeticResult =
+        EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
     Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -547,8 +547,8 @@ public:
     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
-    Value arithmeticResult = rewriter.template create<EmitCOp>(
-        op.getLoc(), arithmeticType, lhs, rhs);
+    Value arithmeticResult =
+        EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
     Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -748,8 +748,8 @@ public:
     }
     Value fpCastOperand = adaptor.getIn();
     if (actualOperandType != operandType) {
-      fpCastOperand = rewriter.template create<emitc::CastOp>(
-          castOp.getLoc(), actualOperandType, fpCastOperand);
+      fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
+                                            actualOperandType, fpCastOperand);
     }
     rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);

View File

@@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
       scf::YieldOp::create(rewriter, loc, acc);
     };
-    auto size = rewriter
-                    .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
-                                        loopBody)
-                    .getResult(0);
+    auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
+                                   ValueRange(one), loopBody)
+                    .getResult(0);
     MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},

View File

@@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
     return emitError(loc, "Cannot create unreachable terminator for '")
            << parentOp->getName() << "'";
-  return builder
-      .create<func::ReturnOp>(
-          loc, llvm::map_to_vector(funcOp.getResultTypes(),
-                                   [&](Type type) {
-                                     return getUndefValue(loc, builder, type);
-                                   }))
+  return func::ReturnOp::create(
+             builder, loc,
+             llvm::map_to_vector(
+                 funcOp.getResultTypes(),
+                 [&](Type type) { return getUndefValue(loc, builder, type); }))
       .getOperation();
 }

View File

@@ -559,8 +559,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
       builder, loc, builder.getI32Type(),
       builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
-  return builder
-      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+  return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+                              clusterSizeValue)
       .getResult();
 }

View File

@@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
-        toDynamic
-            ? builder
-                  .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
-                  .getResult()
-            : LLVM::AllocaOp::create(builder, loc, getPtrType(),
-                                     IntegerType::get(getContext(), 8),
-                                     allocationSize,
-                                     /*alignment=*/0);
+        toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+                                         allocationSize)
+                        .getResult()
+                  : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+                                           IntegerType::get(getContext(), 8),
+                                           allocationSize,
+                                           /*alignment=*/0);
     Value source = desc.memRefDescPtr(builder, loc);
     LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
     if (!toDynamic)

View File

@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
   if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
     ConversionPatternRewriter::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+    ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
   }
   return ret;
 }

View File

@@ -575,8 +575,8 @@ private:
     Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                         getTypeConverter()->getIndexType(),
                                         offsetPtr, idxPlusOne);
-    return rewriter
-        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+    return LLVM::LoadOp::create(rewriter, loc,
+                                getTypeConverter()->getIndexType(), sizePtr)
         .getResult();
   }

View File

@@ -1493,11 +1493,11 @@ public:
     Value extended;
     if (op2TypeWidth < dstTypeWidth) {
       if (isUnsignedIntegerOrVector(op2Type)) {
-        extended = rewriter.template create<LLVM::ZExtOp>(
-            loc, dstType, adaptor.getOperand2());
+        extended =
+            LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
       } else {
-        extended = rewriter.template create<LLVM::SExtOp>(
-            loc, dstType, adaptor.getOperand2());
+        extended =
+            LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
       }
     } else if (op2TypeWidth == dstTypeWidth) {
       extended = adaptor.getOperand2();
@@ -1505,8 +1505,8 @@ public:
       return failure();
     }
-    Value result = rewriter.template create<LLVMOp>(
-        loc, dstType, adaptor.getOperand1(), extended);
+    Value result =
+        LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
     rewriter.replaceOp(op, result);
     return success();
   }

View File

@@ -177,9 +177,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     auto type = RankedTensorType::get({nSplits, 2}, i64);
     Value resHaloSizes =
         haloSizes.empty()
-            ? rewriter
-                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
-                                           i64)
+            ? tensor::EmptyOp::create(rewriter, loc,
+                                      std::array<int64_t, 2>{0, 0}, i64)
                   .getResult()
             : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                   .getResult();
@@ -306,13 +305,11 @@ public:
     auto ctx = op.getContext();
     Value commWorld =
         mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                loc,
-                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
-                commWorld)
-            .getRank();
+    auto rank = mpi::CommRankOp::create(
+                    rewriter, loc,
+                    TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+                    commWorld)
+                    .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
     return success();
@@ -703,10 +700,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (auto value = dyn_cast<Value>(sz))
-        sz =
-            rewriter
-                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
-                .getResult();
+        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+                                        value)
+                 .getResult();
     }
     // most of the offset/size/stride data is the same for all dims
@@ -758,9 +754,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
       assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
-      auto tmp = rewriter
-                     .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
-                                                       splitAxes)
+      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+                                                  myMultiIndex, splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {

View File

@@ -569,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     // to UIToFP.
     if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
       auto unrealizedCast =
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
-                  args[0])
+          UnrealizedConversionCastOp::create(
+              rewriter, loc,
+              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
               .getResult(0);
       return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                      unrealizedCast);
@@ -868,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
   // Emit 'linalg.generic' op
   auto resultTensor =
-      opBuilder
-          .create<linalg::GenericOp>(
-              loc, outputTensor.getType(), operand, outputTensor, affineMaps,
-              getNParallelLoopsAttrs(rank),
-              [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
-                // Emit 'linalg.yield' op
-                linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
-              })
+      linalg::GenericOp::create(
+          opBuilder, loc, outputTensor.getType(), operand, outputTensor,
+          affineMaps, getNParallelLoopsAttrs(rank),
+          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+            // Emit 'linalg.yield' op
+            linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
+          })
          .getResult(0);
   // Cast to original operand type if necessary
@@ -1155,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     inputs.push_back(input);
   // First fill the output buffer with the init value.
-  auto emptyTensor =
-      rewriter
-          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
-                                   dynDims)
-          .getResult();
+  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                             resultTy.getElementType(), dynDims)
+                         .getResult();
   auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
   if (!fillValueAttr)
@@ -1167,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
         op, "No initial value found for reduction operation");
   auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
-  auto filledTensor = rewriter
-                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
-                                                  ValueRange{emptyTensor})
-                          .result();
+  auto filledTensor =
+      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+                             ValueRange{emptyTensor})
+          .result();
   outputs.push_back(filledTensor);
   bool isNanIgnoreMode = false;
@@ -1186,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       auto trueAttr = rewriter.getBoolAttr(true);
       auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
       auto emptyBoolTensor =
-          rewriter
-              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
-                                       dynDims)
+          tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
-          rewriter
-              .create<linalg::FillOp>(loc, ValueRange{trueValue},
-                                      ValueRange{emptyBoolTensor})
+          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
+                                 ValueRange{emptyBoolTensor})
              .result();
      // Note that because the linalg::ReduceOp has two variadic arguments
      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
@@ -1261,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, reduceShape,
-                                     resultTy.getElementType(), dynDims)
-            .getResult();
+        tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                resultTy.getElementType(), dynDims)
+            .getResult();
     auto nanFilledTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{nanValue},
-                                    ValueRange{emptyNanTensor})
-            .result();
+        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
+                               ValueRange{emptyNanTensor})
+            .result();
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, reduceShape,
-                                     resultTy.getElementType(), dynDims)
-            .getResult();
+        tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                resultTy.getElementType(), dynDims)
+            .getResult();
     // Do a selection between the tensors akin to:
@@ -1503,12 +1494,11 @@ public:
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
           if (valueTy.isUnsignedInteger()) {
-            value = nestedBuilder
-                        .create<UnrealizedConversionCastOp>(
-                            nestedLoc,
-                            nestedBuilder.getIntegerType(
-                                valueTy.getIntOrFloatBitWidth()),
-                            value)
+            value = UnrealizedConversionCastOp::create(
+                        nestedBuilder, nestedLoc,
+                        nestedBuilder.getIntegerType(
+                            valueTy.getIntOrFloatBitWidth()),
+                        value)
                         .getResult(0);
           }
           if (valueTy.getIntOrFloatBitWidth() < 32) {
@@ -1557,9 +1547,8 @@ public:
           }
           if (outIntType.isUnsignedInteger()) {
-            value = nestedBuilder
-                        .create<UnrealizedConversionCastOp>(nestedLoc,
-                                                            outIntType, value)
+            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
+                                                       outIntType, value)
                         .getResult(0);
           }
           linalg::YieldOp::create(nestedBuilder, loc, value);
@@ -2095,10 +2084,9 @@ public:
     Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
     // First fill the output buffer with the init value.
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
-                                                    inputTy.getElementType(),
-                                                    ArrayRef<Value>({dynDims}))
-                           .getResult();
+    auto emptyTensor = tensor::EmptyOp::create(
+                           rewriter, loc, inputTy.getShape(),
+                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
+                           .getResult();
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2241,23 +2229,22 @@ public:
     }
     // First fill the output buffer for the index.
-    auto emptyTensorIdx = rewriter
-                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
-                                                       outElementTy, dynDims)
-                              .getResult();
+    auto emptyTensorIdx =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+                                outElementTy, dynDims)
+            .getResult();
     auto fillValueIdx = arith::ConstantOp::create(
         rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
     auto filledTensorIdx =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
-                                    ValueRange{emptyTensorIdx})
-            .result();
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
+                               ValueRange{emptyTensorIdx})
+            .result();
     // Second fill the output buffer for the running max.
-    auto emptyTensorMax = rewriter
-                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
-                                                       inElementTy, dynDims)
-                              .getResult();
+    auto emptyTensorMax =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
+                                dynDims)
+            .getResult();
     auto fillValueMaxAttr =
         createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2268,9 +2255,8 @@ public:
     auto fillValueMax =
         arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
     auto filledTensorMax =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
-                                    ValueRange{emptyTensorMax})
-            .result();
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
+                               ValueRange{emptyTensorMax})
+            .result();
     // We need to reduce along the arg-max axis, with parallel operations along
@@ -2371,9 +2357,8 @@ public:
     auto loc = op.getLoc();
     auto emptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
-                                     dynamicDims)
-            .getResult();
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+                                resultElementTy, dynamicDims)
+            .getResult();
     SmallVector<AffineMap, 2> affineMaps = {
@@ -2448,10 +2433,10 @@ public:
       }
     }
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
-                                                    resultElementTy, dynDims)
-                           .getResult();
+    auto emptyTensor =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+                                resultElementTy, dynDims)
+            .getResult();
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank()),
@@ -2585,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
         tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
     auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
     auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
-    auto filledTensor = rewriter
-                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
-                                                    ValueRange{emptyTensor})
-                            .result();
+    auto filledTensor =
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+                               ValueRange{emptyTensor})
+            .result();
     return filledTensor;
   }

View File

@@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                            Value conv, Value result,
                            ArrayRef<AffineMap> indexingMaps) {
   ShapedType resultTy = cast<ShapedType>(conv.getType());
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [](OpBuilder &builder, Location loc, ValueRange args) {
-            Value biasVal = args[0];
-            Type resType = args[1].getType();
-            if (resType != biasVal.getType()) {
-              biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
-            }
-            Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
-            linalg::YieldOp::create(builder, loc, added);
-          })
+  return linalg::GenericOp::create(
+             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+             [](OpBuilder &builder, Location loc, ValueRange args) {
+               Value biasVal = args[0];
+               Type resType = args[1].getType();
+               if (resType != biasVal.getType()) {
+                 biasVal =
+                     arith::ExtSIOp::create(builder, loc, resType, biasVal);
+               }
+               Value added =
+                   arith::AddIOp::create(builder, loc, biasVal, args[1]);
+               linalg::YieldOp::create(builder, loc, added);
+             })
      .getResult(0);
 }
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
   indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
   // Build the broadcast-like operation as a linalg.generic.
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, resultTy, ValueRange({source}), result, indexingMaps,
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
-            Value biasVal = args[0];
-            Type resType = args[1].getType();
-            if (resType != biasVal.getType()) {
-              biasVal =
-                  resultTy.getElementType().isFloat()
-                      ? arith::ExtFOp::create(builder, loc, resType, biasVal)
-                            .getResult()
-                      : arith::ExtSIOp::create(builder, loc, resType, biasVal)
-                            .getResult();
-            }
-            linalg::YieldOp::create(builder, loc, biasVal);
-          })
+  return linalg::GenericOp::create(
+             rewriter, loc, resultTy, ValueRange({source}), result,
+             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+             [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+               Value biasVal = args[0];
+               Type resType = args[1].getType();
+               if (resType != biasVal.getType()) {
+                 biasVal =
+                     resultTy.getElementType().isFloat()
+                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+                               .getResult()
+                         : arith::ExtSIOp::create(builder, loc, resType,
+                                                  biasVal)
+                               .getResult();
+               }
+               linalg::YieldOp::create(builder, loc, biasVal);
+             })
      .getResult(0);
 }
@@ -397,21 +398,19 @@ public:
     auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
     auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
-    Value conv =
-        rewriter
-            .create<LinalgConvQOp>(
-                loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                ValueRange{broadcastBias}, strideAttr, dilationAttr)
-            ->getResult(0);
+    Value conv = LinalgConvQOp::create(
+                     rewriter, loc, resultTy,
+                     ValueRange{input, weight, iZpVal, kZpVal},
+                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
+                     ->getResult(0);
     rewriter.replaceOp(op, conv);
     return success();
   }
-  Value conv = rewriter
-                   .create<LinalgConvOp>(
-                       loc, accTy, ValueRange{input, weight},
-                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
+  Value conv = LinalgConvOp::create(
+                   rewriter, loc, accTy, ValueRange{input, weight},
+                   ValueRange{broadcastBias}, strideAttr, dilationAttr)
                    ->getResult(0);
   // We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
     Value emptyTensor = tensor::EmptyOp::create(
         rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                              ValueRange{emptyTensor})
                            .result();
     Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     if (hasNullZps) {
-      Value conv = rewriter
-                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                           loc, linalgConvTy, ValueRange{input, weight},
-                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
+                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                        .getResult(0);
       // We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
           rewriter, loc, resultTy, conv, reassociationMap);
       Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, resultTy, ValueRange({bias, convReshape}),
-                  biasEmptyTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultRank),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added;
-                    if (llvm::isa<FloatType>(inputETy))
-                      added = arith::AddFOp::create(nestedBuilder, loc, args[0],
-                                                    args[1]);
-                    else
-                      added = arith::AddIOp::create(nestedBuilder, loc, args[0],
-                                                    args[1]);
-                    linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
-                  })
+          linalg::GenericOp::create(
+              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+              biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+              [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                  ValueRange args) {
+                Value added;
+                if (llvm::isa<FloatType>(inputETy))
+                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+                                                args[1]);
+                else
+                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+                                                args[1]);
+                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+              })
              .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
@@ -588,12 +583,11 @@ public:
       IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
       auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
-      Value conv =
-          rewriter
-              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
-                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
-              .getResult(0);
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+                       rewriter, loc, linalgConvTy,
+                       ValueRange{input, weight, iZpVal, kZpVal},
+                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                       .getResult(0);
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
     auto emptyTensor =
         tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                 outputTy.getElementType(), filteredDims);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                              ValueRange{emptyTensor})
                            .result();
     FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
         rewriter, loc, accTy.getShape(), accETy, dynamicDims);
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{poolEmptyTensor})
+        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+                               ValueRange{poolEmptyTensor})
            .result();
     Value fakeWindowDims =
         tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
     // Sum across the pooled region.
-    Value poolingOp = rewriter
-                          .create<linalg::PoolingNhwcSumOp>(
-                              loc, ArrayRef<Type>{accTy},
-                              ValueRange{paddedInput, fakeWindowDims},
-                              filledEmptyTensor, strideAttr, dilationAttr)
+    Value poolingOp = linalg::PoolingNhwcSumOp::create(
+                          rewriter, loc, ArrayRef<Type>{accTy},
+                          ValueRange{paddedInput, fakeWindowDims},
+                          filledEmptyTensor, strideAttr, dilationAttr)
                           .getResult(0);
     // Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
       Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
       auto scaled =
-          rewriter
-              .create<tosa::ApplyScaleOp>(
-                  loc, rewriter.getI32Type(), poolVal, multiplier, shift,
-                  rewriter.getStringAttr("SINGLE_ROUND"))
+          tosa::ApplyScaleOp::create(
+              rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+              shift, rewriter.getStringAttr("SINGLE_ROUND"))
              .getResult();
      // If we have quantization information we need to apply output

View File

@@ -482,14 +482,12 @@ struct CombineTransferReadOpTranspose final
         permutationMap.compose(transferReadOp.getPermutationMap());
     auto loc = op.getLoc();
-    Value result =
-        rewriter
-            .create<vector::TransferReadOp>(
-                loc, resultType, transferReadOp.getBase(),
-                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
-                transferReadOp.getPadding(), transferReadOp.getMask(),
-                transferReadOp.getInBoundsAttr())
-            .getResult();
+    Value result = vector::TransferReadOp::create(
+                       rewriter, loc, resultType, transferReadOp.getBase(),
+                       transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+                       transferReadOp.getPadding(), transferReadOp.getMask(),
+                       transferReadOp.getInBoundsAttr())
+                       .getResult();
     // Fuse through the integer extend op.
     if (extOp) {

View File

@@ -142,6 +142,7 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   // TODO: Implement the `convertInstruction` hooks in the
   // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
+
   return failure();
 }
@@ -1626,12 +1627,11 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
   // Convert dso_local_equivalent.
   if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) {
     Type type = convertType(dsoLocalEquivalent->getType());
-    return builder
-        .create<DSOLocalEquivalentOp>(
-            loc, type,
-            FlatSymbolRefAttr::get(
-                builder.getContext(),
-                dsoLocalEquivalent->getGlobalValue()->getName()))
+    return DSOLocalEquivalentOp::create(
+               builder, loc, type,
+               FlatSymbolRefAttr::get(
+                   builder.getContext(),
+                   dsoLocalEquivalent->getGlobalValue()->getName()))
        .getResult();
   }
@@ -1736,9 +1736,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
         FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
     auto blockTag =
         BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
-    return builder
-        .create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
-                                BlockAddressAttr::get(context, fnSym, blockTag))
+    return BlockAddressOp::create(
+               builder, loc, convertType(blockAddr->getType()),
+               BlockAddressAttr::get(context, fnSym, blockTag))
        .getRes();
   }
@@ -2228,17 +2228,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (!resultTy)
       return failure();
     ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst);
-    return builder
-        .create<InlineAsmOp>(
-            loc, resultTy, *operands,
-            builder.getStringAttr(asmI->getAsmString()),
-            builder.getStringAttr(asmI->getConstraintString()),
-            asmI->hasSideEffects(), asmI->isAlignStack(),
-            convertTailCallKindFromLLVM(callInst->getTailCallKind()),
-            AsmDialectAttr::get(
-                mlirModule.getContext(),
-                convertAsmDialectFromLLVM(asmI->getDialect())),
-            operandAttrs)
+    return InlineAsmOp::create(
+               builder, loc, resultTy, *operands,
+               builder.getStringAttr(asmI->getAsmString()),
+               builder.getStringAttr(asmI->getConstraintString()),
+               asmI->hasSideEffects(), asmI->isAlignStack(),
+               convertTailCallKindFromLLVM(callInst->getTailCallKind()),
+               AsmDialectAttr::get(
+                   mlirModule.getContext(),
+                   convertAsmDialectFromLLVM(asmI->getDialect())),
+               operandAttrs)
        .getOperation();
   }
   bool isIncompatibleCall;

View File

@@ -72,15 +72,14 @@ struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> {
       ShapedType sourceShardShape =
          shardShapedType(op.getResult().getType(), grid, op.getSharding());
      TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
-          builder
-              .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
+          UnrealizedConversionCastOp::create(builder, sourceShardShape,
+                                             op.getSrc())
              ->getResult(0));
      TypedValue<ShapedType> targetShard =
          reshard(builder, grid, op, targetShardOp, sourceShard);
      Value newTargetUnsharded =
-          builder
-              .create<UnrealizedConversionCastOp>(
-                  targetShardOp.getResult().getType(), targetShard)
+          UnrealizedConversionCastOp::create(
+              builder, targetShardOp.getResult().getType(), targetShard)
              ->getResult(0);
      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
                                  newTargetUnsharded);

View File

@@ -1007,9 +1007,8 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
       // This is a 1:N replacement. Insert a test.cast op. (That's what the
       // argument materialization used to do.)
       flattened.push_back(
-          rewriter
-              .create<TestCastOp>(op->getLoc(),
-                                  op->getOperand(it.index()).getType(), range)
+          TestCastOp::create(rewriter, op->getLoc(),
+                             op->getOperand(it.index()).getType(), range)
              .getResult());
     }
     rewriter.replaceOpWithNewOp<TestValidOp>(op, TypeRange(), flattened,

View File

@@ -569,10 +569,9 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input,
   Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < size; i <<= 1) {
-    Value shuffled = builder
-                         .create<gpu::ShuffleOp>(loc, laneVal, i,
-                                                 /*width=*/size,
-                                                 /*mode=*/gpu::ShuffleMode::XOR)
+    Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+                                            /*width=*/size,
+                                            /*mode=*/gpu::ShuffleMode::XOR)
                          .getShuffleResult();
     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
   }
@@ -650,9 +649,8 @@ struct TestVectorDistribution
           arith::IndexCastOp::create(builder, loc, i32Type, srcIdx);
       Value warpSzI32 = arith::ConstantOp::create(
           builder, loc, builder.getIntegerAttr(i32Type, warpSz));
-      Value result = builder
-                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
-                                                 gpu::ShuffleMode::IDX)
+      Value result = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32,
+                                            warpSzI32, gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };