Krzysztof Drewniak 31aa7f34e0
[mlir][Affine] Let affine.[de]linearize_index omit outer bounds (#116103)
The affine.delinearize_index and affine.linearize_index operations, as
currently defined, require providing a length N basis to [de]linearize N
values. The first value in this basis is never used during lowering.
(Note that, even though it isn't used during lowering, it can still be
used to, for example, remove length-1 outputs from a delinearize.)

This dead value makes sense in the original context of these operations,
which is linearizing or de-linearizing indexes to memref<>s, vector<>s,
and other shaped types, where that outer bound is available and may be
useful for analysis.

However, other use cases exist where the outer bound is not known. For
example:

    %thread_id_x = gpu.thread_id x : index
    %0:3 = affine.delinearize_index %thread_id_x into (4, 16) : index, index, index

In this code, we don't know the upper bound of the thread ID, but we do
want to construct the ?x4x16 grid of delinearized values in order to
further partition the GPU threads.

In order to support such use cases, we broaden the definition of
affine.delinearize_index and affine.linearize_index to make the outer
bound optional.

In the case of affine.delinearize_index, where the number of results is
a function of the size of the passed-in basis, we augment all existing
builders with a `hasOuterBound` argument, which, for backwards
compatibility and to preserve the natural usage of the op, defaults to
`true`. If this flag is true, the op returns one result per basis
element; if it is false, it returns one extra result in position 0.

We also update existing canonicalization patterns (and move one of them
into the folder) to handle these cases. Note that disagreements about
the outer bound no longer prevent delinearize/linearize
cancellations.
2024-11-18 15:41:54 -06:00

336 lines
13 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import func
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
import mlir.extras.types as T
def constructAndPrintInModule(f):
    """Decorator: run ``f`` once inside a fresh MLIR module, then print it.

    A ``TEST: <name>`` banner is printed first so FileCheck label directives
    can anchor on it. ``f`` is invoked with an active Context, an unknown
    Location, and an insertion point on the new module's body; the
    populated module is printed afterwards.

    Returns ``f`` itself, so the decorated name still refers to the original
    function (which has already run once at decoration time).
    """
    print(f"\nTEST: {f.__name__}")
    with Context(), Location.unknown():
        test_module = Module.create()
        body_ip = InsertionPoint(test_module.body)
        with body_ip:
            f()
        # Print while the Context is still active.
        print(test_module)
    return f
# CHECK-LABEL: TEST: testAffineStoreOp
@constructAndPrintInModule
def testAffineStoreOp():
    """Store through a (d0)[s0] affine map into a freshly allocated 12x12 memref."""
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type_out = MemRefType.get([12, 12], f32)
    # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
    @func.FuncOp.from_py_func(index_type)
    def affine_store_test(arg0):
        # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
        mem = memref.AllocOp(memref_type_out, [], []).result
        d0 = AffineDimExpr.get(0)
        s0 = AffineSymbolExpr.get(0)
        # One dim and one symbol; the same SSA value (arg0) feeds both.
        # Named store_map to avoid shadowing the builtin ``map``.
        store_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
        # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
        a1 = arith.ConstantOp(f32, 2.1)
        # The destination refers to the alloc captured above rather than a
        # hard-coded printed name, so the pattern survives SSA renaming.
        # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
        affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=store_map)
        return mem
# CHECK-LABEL: TEST: testAffineDelinearizeInfer
@constructAndPrintInModule
def testAffineDelinearizeInfer():
    """Build affine.delinearize_index with two index results and static basis (2, 3)."""
    # CHECK: %[[C1:.*]] = arith.constant 1 : index
    c1 = arith.ConstantOp(T.index(), 1)
    # Reference the constant captured above (not a fresh capture) so the
    # delinearized operand is actually verified.
    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1]] into (2, 3) : index, index
    affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [2, 3])
# CHECK-LABEL: TEST: testAffineLoadOp
@constructAndPrintInModule
def testAffineLoadOp():
    """Load from a 10x10 f32 memref through a (d0)[s0] affine map and return it."""
    elem_ty = F32Type.get()
    buffer_ty = MemRefType.get([10, 10], elem_ty)
    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
    @func.FuncOp.from_py_func(buffer_ty, IndexType.get())
    def affine_load_test(buf, idx):
        dim0, sym0 = AffineDimExpr.get(0), AffineSymbolExpr.get(0)
        # One dim and one symbol; the same SSA value (idx) feeds both.
        load_map = AffineMap.get(1, 1, [sym0 * 3, dim0 + sym0 + 1])
        # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
        return affine.AffineLoadOp(elem_ty, buf, indices=[idx, idx], map=load_map)
# CHECK-LABEL: TEST: testAffineForOp
@constructAndPrintInModule
def testAffineForOp():
    """Build an affine.for with max/min multi-result bounds and an f32 reduction.

    The loop adds buffer[iv] into a loop-carried f32 value that starts at 0.0.
    The printed IR (and therefore the expected patterns) follows op creation
    order here, so statements should not be reordered.
    """
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type = MemRefType.get([1024], f32)
    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
    @func.FuncOp.from_py_func(memref_type)
    def affine_for_op_test(buffer):
        # CHECK: %[[C1:.*]] = arith.constant 1 : index
        c1 = arith.ConstantOp(index_type, 1)
        # CHECK: %[[C2:.*]] = arith.constant 2 : index
        c2 = arith.ConstantOp(index_type, 2)
        # CHECK: %[[C3:.*]] = arith.constant 3 : index
        c3 = arith.ConstantOp(index_type, 3)
        # CHECK: %[[C9:.*]] = arith.constant 9 : index
        c9 = arith.ConstantOp(index_type, 9)
        ac0 = AffineConstantExpr.get(0)
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        # Lower bound max(0, d0 + s0): 1 dim, 1 symbol (printed as MAP0).
        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
        # Upper bound min(d0 - 2, 32 * d1): 2 dims, 0 symbols (printed as MAP1).
        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
        # The f32 zero below is the loop's initial carried value (AC0); it is
        # unrelated to the affine constant expression ``ac0`` used in ``lb``.
        # CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
        sum_0 = arith.ConstantOp(f32, 0.0)
        # Operands: lb gets (dim c2)[sym c3], ub gets (c9, c1); step is 2.
        # NOTE(review): ``sum`` shadows the builtin within this function.
        # CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
        sum = affine.AffineForOp(
            lb,
            ub,
            2,
            iter_args=[sum_0],
            lower_bound_operands=[c2, c3],
            upper_bound_operands=[c9, c1],
        )
        with InsertionPoint(sum.body):
            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
            tmp = memref.LoadOp(buffer, [sum.induction_variable])
            # Accumulate into the loop-carried value and yield it back.
            sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
            affine.AffineYieldOp([sum_next])
# CHECK-LABEL: TEST: testAffineForOpErrors
@constructAndPrintInModule
def testAffineForOpErrors():
    """Check that AffineForOp rejects malformed lower-bound arguments.

    Every case must raise ValueError with the exact message shown. The
    ``else`` branches make a case that unexpectedly succeeds fail the test
    instead of passing silently.
    """
    c1 = arith.ConstantOp(T.index(), 1)
    c2 = arith.ConstantOp(T.index(), 2)
    c3 = arith.ConstantOp(T.index(), 3)

    # An op used as a concrete lower bound cannot also take bound operands.
    try:
        affine.AffineForOp(
            c1,
            c2,
            1,
            lower_bound_operands=[c3],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert (
            e.args[0]
            == "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported."
        )
    else:
        raise AssertionError(
            "AffineForOp accepted a concrete lower bound together with operands"
        )

    # A constant AffineMap expects 0 operands, so passing two is an error.
    try:
        affine.AffineForOp(
            AffineMap.get_constant(1),
            c2,
            1,
            lower_bound_operands=[c3, c3],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert (
            e.args[0]
            == "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2."
        )
    else:
        raise AssertionError(
            "AffineForOp accepted the wrong number of lower bound operands"
        )

    # An op with two results cannot serve as a single concrete lower bound.
    try:
        two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [1, 1])
        affine.AffineForOp(
            two_indices,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert e.args[0] == "Only a single concrete value is supported for lower bound."
    else:
        raise AssertionError("AffineForOp accepted a multi-result lower bound")

    # A float is none of int, an SSA result, or an AffineMap.
    try:
        affine.AffineForOp(
            1.0,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert e.args[0] == "lower bound must be int | ResultValueT | AffineMap."
    else:
        raise AssertionError("AffineForOp accepted a float lower bound")
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``affine.for_`` loop sugar across bound kinds.

    Each nested function is one case:
    - range_loop_1: SSA lower and upper bounds.
    - range_loop_2: SSA lower bound, constant upper bound.
    - range_loop_3: constant lower bound, SSA upper bound.
    - range_loop_4: constant lower and upper bounds.
    - range_loop_8: constant bounds plus one value threaded through iter_args.
    """
    memref_t = T.memref(10, T.index())
    # Deliberately shadows the builtin so the loops below read like plain
    # Python ``for ... in range(...)`` loops.
    range = affine.for_
    # SSA bounds are printed through the identity map captured as $ATTR_2.
    # CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>
    # CHECK-LABEL: func.func @range_loop_1(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_1(lb, ub, memref_v):
        # Both bounds are SSA function arguments.
        for i in range(lb, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            # Each loop body is closed with an explicit (empty) yield.
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_2(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_2(lb, ub, memref_v):
        # SSA lower bound, Python-int upper bound (``ub`` is unused here).
        for i in range(lb, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_3(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_3(lb, ub, memref_v):
        # Python-int lower bound, SSA upper bound (``lb`` is unused here).
        for i in range(0, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_4(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_4(lb, ub, memref_v):
        # Both bounds are Python ints (``lb`` and ``ub`` are unused here).
        for i in range(0, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_8(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
    # CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_8(lb, ub, memref_v):
        # With iter_args the sugar yields (induction var, carried value);
        # the carried memref is stored into and yielded back unchanged.
        for i, it in range(0, 10, iter_args=[memref_v]):
            add = arith.addi(i, i)
            memref.store(add, it, [i])
            affine.yield_([it])
# CHECK-LABEL: TEST: testAffineIfWithoutElse
@constructAndPrintInModule
def testAffineIfWithoutElse():
    """Build a result-less affine.if that has only a then-region."""
    index = IndexType.get()
    i32 = IntegerType.get_signless(32)
    d0 = AffineDimExpr.get(0)
    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    # One dim, no symbols, one constraint; the False flag makes it an
    # inequality (printed as ">= 0") rather than an equality.
    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
    # CHECK-LABEL: func.func @simple_affine_if(
    # CHECK-SAME: %[[VAL_0:.*]]: index) {
    # CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
    # CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
    # CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index)
    def simple_affine_if(cond_operands):
        # The function's single index argument is bound to the set's d0.
        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            # Unused in Python; created only so the addi appears in the IR.
            add = arith.AddIOp(one, one)
            affine.AffineYieldOp([])
        return
# CHECK-LABEL: TEST: testAffineIfWithElse
@constructAndPrintInModule
def testAffineIfWithElse():
    """Build an affine.if/else yielding two i32 values, then add the results."""
    index = IndexType.get()
    i32 = IntegerType.get_signless(32)
    d0 = AffineDimExpr.get(0)
    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    # One dim, no symbols, one inequality constraint d0 - 5 >= 0.
    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
    # CHECK-LABEL: func.func @simple_affine_if_else(
    # CHECK-SAME: %[[VAL_0:.*]]: index) {
    # CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
    # CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
    # CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
    # CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
    # CHECK: } else {
    # CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
    # CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
    # CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
    # CHECK: }
    # CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index)
    def simple_affine_if_else(cond_operands):
        # Two i32 results; has_else=True requests the else region.
        if_op = affine.AffineIfOp(
            cond, [i32, i32], cond_operands=[cond_operands], has_else=True
        )
        # Then-region yields (0, 1).
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            affine.AffineYieldOp([x_true, y_true])
        # Else-region yields (2, 3).
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            affine.AffineYieldOp([x_false, y_false])
        # Unused in Python; created so the use of both if results is printed.
        add = arith.AddIOp(if_op.results[0], if_op.results[1])
        return