Add support for param type in transform.structured.tile_using_forall (#72097)
Make `transform.structured.tile_using_forall` able to take tile sizes of param type. Examples:

```mlir
%tile_sizes = transform.param.constant 16 : i64 -> !transform.param<i64>
transform.structured.tile_using_forall %matmul
    tile_sizes [%tile_sizes : !transform.param<i64>, 32]
    ( mapping = [#gpu.block<x>, #gpu.block<y>] )
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```

```mlir
%c10 = transform.param.constant 10 : i64 -> !transform.any_param
%c20 = transform.param.constant 20 : i64 -> !transform.any_param
%tile_sizes = transform.merge_handles %c10, %c20 : !transform.any_param
transform.structured.tile_using_forall %matmul
    tile_sizes *(%tile_sizes : !transform.any_param)
    ( mapping = [#gpu.block<x>, #gpu.block<y>] )
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```
commit d439f3640b (parent 64a849a52e)
@@ -21,10 +21,10 @@ include "mlir/IR/RegionKindInterface.td"
 
 // This is roughly similar to OpFoldResult assuming the handle produces a single
 // value in the payload IR.
-def TransformParamTypeOrAnyHandle : Type<
+def TransformAnyParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
-    "transform 'param' type or any handle type">;
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
 
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
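Both flavors of param satisfy the widened constraint. A minimal sketch (names are illustrative; both forms appear in this commit's message and tests):

```mlir
// A typed param carrying i64 values:
%typed = transform.param.constant 16 : i64 -> !transform.param<i64>
// A type-erased param, accepted now that the predicate is the
// TransformParamTypeInterface rather than the concrete Transform_ParamType:
%erased = transform.param.constant 16 : i64 -> !transform.any_param
```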
@@ -691,9 +691,9 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
                        I64Attr:$dimension,
                        I64Attr:$target_size,
                        DefaultValuedAttr<I64Attr, "1">:$divisor);
-  let results = (outs TransformParamTypeOrAnyHandle:$low_size,
-                      TransformParamTypeOrAnyHandle:$high_size,
-                      TransformParamTypeOrAnyHandle:$split_point);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$low_size,
+                      TransformAnyParamTypeOrAnyHandle:$high_size,
+                      TransformAnyParamTypeOrAnyHandle:$split_point);
   let hasVerifier = 1;
   let assemblyFormat =
       "$target attr-dict `:` custom<MultitileSizesTypes>("
@@ -1408,7 +1408,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
                        I64Attr:$static_split_point);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
@@ -1857,7 +1857,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
@@ -1968,10 +1968,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
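Since `$num_threads` and `$packed_num_threads` receive the same widened type, params should also work for thread counts. The sketch below is an extrapolation from the `tile_sizes` examples; only `tile_sizes` is exercised by this commit's tests:

```mlir
// Hypothetical num_threads counterpart of the tile_sizes form, assuming the
// %matmul handle was produced by an earlier transform.structured.match:
%nt = transform.param.constant 4 : i64 -> !transform.param<i64>
%1:2 = transform.structured.tile_using_forall %matmul
    num_threads [%nt : !transform.param<i64>, 2]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```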
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, return that value.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
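Per the updated comment, each entry handed to this helper may be a static index attr, a param carrying exactly one integer, or a handle to exactly one payload op with a single index result. A condensed sketch of the first two forms, again assuming `%matmul` was matched earlier:

```mlir
%p = transform.param.constant 10 : i64 -> !transform.param<i64>
// 20 is folded as a static size; %p contributes its single integer value.
%1:2 = transform.structured.tile_using_forall %matmul
    tile_sizes [%p : !transform.param<i64>, 20]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```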
@@ -98,12 +99,23 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(params[0]);
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
@@ -123,14 +135,27 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a list of params that are index attrs or a list of OpFoldResults
+// that are either index attrs or op handles, return a list of OpFoldResults
+// of index attrs or a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op.
+// (There must be exactly one parameter associated with the AnyParamType or
+// one mapped payload op which must have exactly one index result.)
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
+  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    for (auto param : params) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure()
+               << "expected the parameter to be associated with an integer "
+                  "attribute";
+      result.push_back(param);
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+
   for (Operation *op : state.getPayloadOps(packedHandle)) {
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
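In the packed form, a param handle may carry several values (e.g. after `transform.merge_handles`), and each of them must be an integer attr. A condensed sketch of the behavior the new tests exercise, with an illustrative `%matmul` handle:

```mlir
%c10 = transform.param.constant 10 : i64 -> !transform.any_param
%c20 = transform.param.constant 20 : i64 -> !transform.any_param
%both = transform.merge_handles %c10, %c20 : !transform.any_param
// Unpacks to tile sizes [10, 20]; a non-integer param value here would fail
// with "expected the parameter to be associated with an integer attribute".
%1:2 = transform.structured.tile_using_forall %matmul
    tile_sizes *(%both : !transform.any_param)
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```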
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
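(The added `-verify-diagnostics` is needed because two of the new test cases below check `expected-error` diagnostics rather than CHECK lines.)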
@@ -451,3 +451,138 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+  // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  // CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  // CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  // CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  // CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  // CHECK: linalg.matmul
+  // CHECK: scf.forall.in_parallel
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
+    %sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
+    // expected-error @below {{requires exactly one parameter associated}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+  // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  // CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  // CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  // CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  // CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  // CHECK: linalg.matmul
+  // CHECK: scf.forall.in_parallel
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.any_param
+    %c20 = transform.param.constant 20 : i64 -> !transform.any_param
+    %sz = transform.merge_handles %c10, %c20 : !transform.any_param
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant "[10 : i64, 20 : i64]" -> !transform.any_param
+    // expected-error @below {{expected the parameter to be associated with an integer attribute}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}