[mlir][NFC] update mlir create APIs (34/n) (#150660)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
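The change repeated in every hunk below is mechanical: calls through the builder's member template, `rewriter.create<Op>(...)`, become calls to the static `create` method on the op class, `Op::create(rewriter, ...)`, with the builder or rewriter passed as the first argument. A minimal sketch of the two forms (the op and variable names are illustrative, not taken from this commit):

    // Old builder-centric form; in dependent (templated) code it needs the
    // `template` keyword:
    Value oldStyle = rewriter.create<emitc::CastOp>(loc, dstType, input);
    auto oldGeneric = rewriter.template create<OpTy>(loc, dstType, input);

    // New static form; the builder comes first, and no `template` keyword is
    // needed in templated code:
    Value newStyle = emitc::CastOp::create(rewriter, loc, dstType, input);
    auto newGeneric = OpTy::create(rewriter, loc, dstType, input);

Both forms build the same operation, so the change is NFC; neighbouring lines reflow where clang-format re-wraps the new call shape.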
parent b46527645d
commit 258daf5395
@@ -402,8 +402,8 @@ public:
     Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
 
     // Actual cast (may change bitwidth)
-    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                        castDestType, actualOp);
+    auto cast =
+        emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
 
     // Cast to the expected output type
     auto result = adaptValueType(cast, rewriter, opReturnType);
@@ -507,8 +507,8 @@ public:
     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
 
-    Value arithmeticResult = rewriter.template create<EmitCOp>(
-        op.getLoc(), arithmeticType, lhs, rhs);
+    Value arithmeticResult =
+        EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
 
     Value result = adaptValueType(arithmeticResult, rewriter, type);
 
@@ -547,8 +547,8 @@ public:
     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
 
-    Value arithmeticResult = rewriter.template create<EmitCOp>(
-        op.getLoc(), arithmeticType, lhs, rhs);
+    Value arithmeticResult =
+        EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
 
     Value result = adaptValueType(arithmeticResult, rewriter, type);
 
@@ -748,8 +748,8 @@ public:
     }
     Value fpCastOperand = adaptor.getIn();
     if (actualOperandType != operandType) {
-      fpCastOperand = rewriter.template create<emitc::CastOp>(
-          castOp.getLoc(), actualOperandType, fpCastOperand);
+      fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
+                                            actualOperandType, fpCastOperand);
     }
     rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
 
@@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
 
       scf::YieldOp::create(rewriter, loc, acc);
     };
-    auto size = rewriter
-                    .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
-                                        loopBody)
+    auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
+                                   ValueRange(one), loopBody)
                     .getResult(0);
 
     MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
@@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
     return emitError(loc, "Cannot create unreachable terminator for '")
            << parentOp->getName() << "'";
 
-  return builder
-      .create<func::ReturnOp>(
-          loc, llvm::map_to_vector(funcOp.getResultTypes(),
-                                   [&](Type type) {
-                                     return getUndefValue(loc, builder, type);
-                                   }))
+  return func::ReturnOp::create(
+             builder, loc,
+             llvm::map_to_vector(
+                 funcOp.getResultTypes(),
+                 [&](Type type) { return getUndefValue(loc, builder, type); }))
       .getOperation();
 }
 
@@ -559,8 +559,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
       builder, loc, builder.getI32Type(),
       builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
 
-  return builder
-      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+  return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+                              clusterSizeValue)
      .getResult();
 }
 
@@ -272,9 +272,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 
    // Allocate memory, copy, and free the source if necessary.
    Value memory =
-        toDynamic
-            ? builder
-                  .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
-                  .getResult()
+        toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+                                         allocationSize)
+                        .getResult()
                  : LLVM::AllocaOp::create(builder, loc, getPtrType(),
                                           IntegerType::get(getContext(), 8),
@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
   if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
     ConversionPatternRewriter::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+    ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
   }
   return ret;
 }
@@ -575,8 +575,8 @@ private:
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
-    return rewriter
-        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+    return LLVM::LoadOp::create(rewriter, loc,
+                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }
 
@@ -1493,11 +1493,11 @@ public:
    Value extended;
    if (op2TypeWidth < dstTypeWidth) {
      if (isUnsignedIntegerOrVector(op2Type)) {
-        extended = rewriter.template create<LLVM::ZExtOp>(
-            loc, dstType, adaptor.getOperand2());
+        extended =
+            LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
      } else {
-        extended = rewriter.template create<LLVM::SExtOp>(
-            loc, dstType, adaptor.getOperand2());
+        extended =
+            LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
      }
    } else if (op2TypeWidth == dstTypeWidth) {
      extended = adaptor.getOperand2();
@@ -1505,8 +1505,8 @@ public:
      return failure();
    }
 
-    Value result = rewriter.template create<LLVMOp>(
-        loc, dstType, adaptor.getOperand1(), extended);
+    Value result =
+        LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
    rewriter.replaceOp(op, result);
    return success();
  }
@@ -177,9 +177,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
-            ? rewriter
-                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
-                                           i64)
+            ? tensor::EmptyOp::create(rewriter, loc,
+                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();
@@ -306,10 +305,8 @@ public:
    auto ctx = op.getContext();
    Value commWorld =
        mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                loc,
+    auto rank = mpi::CommRankOp::create(
+                    rewriter, loc,
                    TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                    commWorld)
                    .getRank();
@@ -703,9 +700,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
    // subviews need Index values
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
-        sz =
-            rewriter
-                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
+        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+                                        value)
                 .getResult();
    }
 
@@ -758,9 +754,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split
-      auto tmp = rewriter
-                     .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
-                                                       splitAxes)
+      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+                                                  myMultiIndex, splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {
@@ -569,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
-                  args[0])
+          UnrealizedConversionCastOp::create(
+              rewriter, loc,
+              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
@@ -868,10 +867,9 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
 
  // Emit 'linalg.generic' op
  auto resultTensor =
-      opBuilder
-          .create<linalg::GenericOp>(
-              loc, outputTensor.getType(), operand, outputTensor, affineMaps,
-              getNParallelLoopsAttrs(rank),
+      linalg::GenericOp::create(
+          opBuilder, loc, outputTensor.getType(), operand, outputTensor,
+          affineMaps, getNParallelLoopsAttrs(rank),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            // Emit 'linalg.yield' op
            linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
@@ -1155,10 +1153,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
    inputs.push_back(input);
 
  // First fill the output buffer with the init value.
-  auto emptyTensor =
-      rewriter
-          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
-                                   dynDims)
-          .getResult();
+  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                             resultTy.getElementType(), dynDims)
+                         .getResult();
 
  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
@@ -1167,8 +1163,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
        op, "No initial value found for reduction operation");
 
  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
-  auto filledTensor = rewriter
-                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
+  auto filledTensor =
+      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);
@@ -1186,13 +1182,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
    auto trueAttr = rewriter.getBoolAttr(true);
    auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
    auto emptyBoolTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
-                                     dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape,
+                                trueValue.getType(), dynDims)
            .getResult();
    auto allResultsNaNTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{trueValue},
+        linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                               ValueRange{emptyBoolTensor})
            .result();
    // Note that because the linalg::ReduceOp has two variadic arguments
@@ -1261,21 +1255,18 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, reduceShape,
+        tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();
 
    // Create an empty tensor, non need to fill this since it will be
    // overwritten by the select.
    auto finalEmptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, reduceShape,
+        tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                resultTy.getElementType(), dynDims)
            .getResult();
 
@@ -1503,9 +1494,8 @@ public:
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
          if (valueTy.isUnsignedInteger()) {
-            value = nestedBuilder
-                        .create<UnrealizedConversionCastOp>(
-                            nestedLoc,
+            value = UnrealizedConversionCastOp::create(
+                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
@@ -1557,8 +1547,7 @@ public:
          }
 
          if (outIntType.isUnsignedInteger()) {
-            value = nestedBuilder
-                        .create<UnrealizedConversionCastOp>(nestedLoc,
+            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
@@ -2095,10 +2084,9 @@ public:
    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
 
    // First fill the output buffer with the init value.
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
-                                                    inputTy.getElementType(),
-                                                    ArrayRef<Value>({dynDims}))
+    auto emptyTensor = tensor::EmptyOp::create(
+                           rewriter, loc, inputTy.getShape(),
+                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2241,22 +2229,21 @@ public:
    }
 
    // First fill the output buffer for the index.
-    auto emptyTensorIdx = rewriter
-                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
-                                                       outElementTy, dynDims)
-                              .getResult();
+    auto emptyTensorIdx =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+                                outElementTy, dynDims)
+            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();
 
    // Second fill the output buffer for the running max.
-    auto emptyTensorMax = rewriter
-                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
-                                                       inElementTy, dynDims)
-                              .getResult();
+    auto emptyTensorMax =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
+                                dynDims)
+            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2268,8 +2255,7 @@ public:
    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();
 
@@ -2371,9 +2357,8 @@ public:
 
    auto loc = op.getLoc();
    auto emptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
-                                     dynamicDims)
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+                                resultElementTy, dynamicDims)
            .getResult();
 
    SmallVector<AffineMap, 2> affineMaps = {
@@ -2448,8 +2433,8 @@ public:
      }
    }
 
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
+    auto emptyTensor =
+        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();
 
@@ -2585,8 +2570,8 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
-    auto filledTensor = rewriter
-                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
+    auto filledTensor =
+        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
@@ -64,17 +64,18 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
-          getNParallelLoopsAttrs(resultTy.getRank()),
+  return linalg::GenericOp::create(
+             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
-              biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
+              biasVal =
+                  arith::ExtSIOp::create(builder, loc, resType, biasVal);
            }
-            Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
+            Value added =
+                arith::AddIOp::create(builder, loc, biasVal, args[1]);
            linalg::YieldOp::create(builder, loc, added);
          })
      .getResult(0);
@@ -124,10 +125,9 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
  // Build the broadcast-like operation as a linalg.generic.
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, resultTy, ValueRange({source}), result, indexingMaps,
-          getNParallelLoopsAttrs(resultTy.getRank()),
+  return linalg::GenericOp::create(
+             rewriter, loc, resultTy, ValueRange({source}), result,
+             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
@@ -136,7 +136,8 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                resultTy.getElementType().isFloat()
                    ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                          .getResult()
-                    : arith::ExtSIOp::create(builder, loc, resType, biasVal)
+                    : arith::ExtSIOp::create(builder, loc, resType,
+                                             biasVal)
                          .getResult();
            }
            linalg::YieldOp::create(builder, loc, biasVal);
@@ -397,10 +398,9 @@ public:
    auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
    auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
 
-    Value conv =
-        rewriter
-            .create<LinalgConvQOp>(
-                loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+    Value conv = LinalgConvQOp::create(
+                     rewriter, loc, resultTy,
+                     ValueRange{input, weight, iZpVal, kZpVal},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);
 
@@ -408,9 +408,8 @@ public:
      return success();
    }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, accTy, ValueRange{input, weight},
+    Value conv = LinalgConvOp::create(
+                     rewriter, loc, accTy, ValueRange{input, weight},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);
 
@@ -529,8 +528,7 @@ public:
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();
 
@@ -544,9 +542,8 @@ public:
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
    if (hasNullZps) {
-      Value conv = rewriter
-                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                           loc, linalgConvTy, ValueRange{input, weight},
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
 
@@ -565,11 +562,9 @@ public:
          rewriter, loc, resultTy, conv, reassociationMap);
 
      Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, resultTy, ValueRange({bias, convReshape}),
-                  biasEmptyTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultRank),
+          linalg::GenericOp::create(
+              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+              biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                Value added;
@@ -588,10 +583,9 @@ public:
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
-      Value conv =
-          rewriter
-              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
-                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+                       rewriter, loc, linalgConvTy,
+                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
@@ -639,8 +633,7 @@ public:
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                outputTy.getElementType(), filteredDims);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();
 
@@ -910,8 +903,7 @@ public:
        rewriter, loc, accTy.getShape(), accETy, dynamicDims);
 
    Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
+        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                               ValueRange{poolEmptyTensor})
            .result();
 
@@ -919,9 +911,8 @@ public:
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
 
    // Sum across the pooled region.
-    Value poolingOp = rewriter
-                          .create<linalg::PoolingNhwcSumOp>(
-                              loc, ArrayRef<Type>{accTy},
+    Value poolingOp = linalg::PoolingNhwcSumOp::create(
+                          rewriter, loc, ArrayRef<Type>{accTy},
                          ValueRange{paddedInput, fakeWindowDims},
                          filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);
@@ -1050,10 +1041,9 @@ public:
      Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
 
      auto scaled =
-          rewriter
-              .create<tosa::ApplyScaleOp>(
-                  loc, rewriter.getI32Type(), poolVal, multiplier, shift,
-                  rewriter.getStringAttr("SINGLE_ROUND"))
+          tosa::ApplyScaleOp::create(
+              rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+              shift, rewriter.getStringAttr("SINGLE_ROUND"))
              .getResult();
 
      // If we have quantization information we need to apply output
@@ -482,10 +482,8 @@ struct CombineTransferReadOpTranspose final
        permutationMap.compose(transferReadOp.getPermutationMap());
 
    auto loc = op.getLoc();
-    Value result =
-        rewriter
-            .create<vector::TransferReadOp>(
-                loc, resultType, transferReadOp.getBase(),
+    Value result = vector::TransferReadOp::create(
+                       rewriter, loc, resultType, transferReadOp.getBase(),
                   transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                   transferReadOp.getPadding(), transferReadOp.getMask(),
                   transferReadOp.getInBoundsAttr())
|
|||||||
// TODO: Implement the `convertInstruction` hooks in the
|
// TODO: Implement the `convertInstruction` hooks in the
|
||||||
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
|
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
|
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1626,9 +1627,8 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
  // Convert dso_local_equivalent.
  if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) {
    Type type = convertType(dsoLocalEquivalent->getType());
-    return builder
-        .create<DSOLocalEquivalentOp>(
-            loc, type,
+    return DSOLocalEquivalentOp::create(
+               builder, loc, type,
               FlatSymbolRefAttr::get(
                   builder.getContext(),
                   dsoLocalEquivalent->getGlobalValue()->getName()))
@@ -1736,8 +1736,8 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
        FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
    auto blockTag =
        BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
-    return builder
-        .create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
-                                BlockAddressAttr::get(context, fnSym, blockTag))
+    return BlockAddressOp::create(
+               builder, loc, convertType(blockAddr->getType()),
+               BlockAddressAttr::get(context, fnSym, blockTag))
        .getRes();
  }
@@ -2228,9 +2228,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
    if (!resultTy)
      return failure();
    ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst);
-    return builder
-        .create<InlineAsmOp>(
-            loc, resultTy, *operands,
+    return InlineAsmOp::create(
+               builder, loc, resultTy, *operands,
               builder.getStringAttr(asmI->getAsmString()),
               builder.getStringAttr(asmI->getConstraintString()),
               asmI->hasSideEffects(), asmI->isAlignStack(),
@@ -72,15 +72,14 @@ struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> {
      ShapedType sourceShardShape =
          shardShapedType(op.getResult().getType(), grid, op.getSharding());
      TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
-          builder
-              .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
+          UnrealizedConversionCastOp::create(builder, sourceShardShape,
+                                             op.getSrc())
              ->getResult(0));
      TypedValue<ShapedType> targetShard =
          reshard(builder, grid, op, targetShardOp, sourceShard);
      Value newTargetUnsharded =
-          builder
-              .create<UnrealizedConversionCastOp>(
-                  targetShardOp.getResult().getType(), targetShard)
+          UnrealizedConversionCastOp::create(
+              builder, targetShardOp.getResult().getType(), targetShard)
              ->getResult(0);
      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
                                  newTargetUnsharded);
@@ -1007,8 +1007,7 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
        // This is a 1:N replacement. Insert a test.cast op. (That's what the
        // argument materialization used to do.)
        flattened.push_back(
-            rewriter
-                .create<TestCastOp>(op->getLoc(),
+            TestCastOp::create(rewriter, op->getLoc(),
                               op->getOperand(it.index()).getType(), range)
                .getResult());
      }
@@ -569,8 +569,7 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input,
  Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
  // Parallel reduction using butterfly shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
-    Value shuffled = builder
-                         .create<gpu::ShuffleOp>(loc, laneVal, i,
+    Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                            /*width=*/size,
                                            /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
@@ -650,9 +649,8 @@ struct TestVectorDistribution
          arith::IndexCastOp::create(builder, loc, i32Type, srcIdx);
      Value warpSzI32 = arith::ConstantOp::create(
          builder, loc, builder.getIntegerAttr(i32Type, warpSz));
-      Value result = builder
-                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
-                                                 gpu::ShuffleMode::IDX)
+      Value result = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32,
+                                            warpSzI32, gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };