301 lines
11 KiB
Python
301 lines
11 KiB
Python
"""Python2 and 3 test for the MLIR EDSC C API and Python bindings"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import unittest
|
|
|
|
import google_mlir.bindings.python.pybind as E
|
|
|
|
class EdscTest(unittest.TestCase):
    """Checks the EDSC Python bindings: expression printing and MLIR emission.

    Most tests build expressions or statements inside an `E.ContextManager()`
    scope and assert on fragments of their printed form.  The `testMLIR*`
    tests additionally create types/functions through `E.MLIRModule` and emit
    code through `E.MLIRFunctionEmitter`.
    """

    def setUp(self):
        """Create a fresh module and the scalar/index types shared by tests."""
        self.module = E.MLIRModule()
        self.boolType = self.module.make_scalar_type("i", 1)
        self.i32Type = self.module.make_scalar_type("i", 32)
        self.f32Type = self.module.make_scalar_type("f32")
        self.indexType = self.module.make_index_type()

    def _make_exprs(self, mlir_type, count):
        """Return `count` expressions, each wrapping a new bindable.

        All bindables are created first and only then wrapped in `E.Expr`,
        which is the order the tests relied on for the printed `$N` ids.
        """
        bindables = [E.Bindable(mlir_type) for _ in range(count)]
        return [E.Expr(bindable) for bindable in bindables]

    def _make_indexed(self, mlir_type, count):
        """Return `count` `E.Indexed` objects, each wrapping a new bindable."""
        bindables = [E.Bindable(mlir_type) for _ in range(count)]
        return [E.Indexed(bindable) for bindable in bindables]

    def testBindables(self):
        """The first bindable created in a context prints as `$1`."""
        with E.ContextManager():
            i = E.Expr(E.Bindable(self.i32Type))
            self.assertIn("$1", str(i))

    def testOneExpr(self):
        """A nested Mul/Add expression prints fully parenthesized."""
        with E.ContextManager():
            i, lb, ub = self._make_exprs(self.i32Type, 3)
            expr = E.Mul(i, E.Add(lb, ub))
            self.assertIn("($1 * ($2 + $3))", str(expr))

    def testOneLoop(self):
        """A single `For` prints its header and the statement in its body."""
        with E.ContextManager():
            i, lb, ub, step = self._make_exprs(self.indexType, 4)
            loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
            text = str(loop)
            self.assertIn("for($1 = $2 to $3 step $4) {", text)
            self.assertIn(" = ($2 + $3)", text)

    def testTwoLoops(self):
        """A `For` nested in a `For` prints the loop header and inner stmt."""
        with E.ContextManager():
            i, lb, ub, step = self._make_exprs(self.indexType, 4)
            inner = E.For(i, lb, ub, step, [E.Stmt(i)])
            loop = E.For(i, lb, ub, step, [inner])
            text = str(loop)
            # Inner and outer loops are built from the same i/lb/ub/step, so
            # the two identical header assertions collapse into one.
            self.assertIn("for($1 = $2 to $3 step $4) {", text)
            self.assertIn("$5 = $1;", text)

    def testNestedLoops(self):
        """`For` over lists of ivs/bounds/steps prints one header per iv."""
        with E.ContextManager():
            i, lb, ub = self._make_exprs(self.indexType, 3)
            step = E.ConstantInteger(self.indexType, 42)
            ivs = self._make_exprs(self.indexType, 4)
            lbs = self._make_exprs(self.indexType, 4)
            ubs = self._make_exprs(self.indexType, 4)
            steps = self._make_exprs(self.indexType, 4)
            loop = E.For(ivs, lbs, ubs, steps, [
                E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
            ])
            text = str(loop)
            self.assertIn("for($5 = $9 to $13 step $17) {", text)
            self.assertIn("for($6 = $10 to $14 step $18) {", text)
            self.assertIn("for($7 = $11 to $15 step $19) {", text)
            self.assertIn("for($8 = $12 to $16 step $20) {", text)
            self.assertIn("for($1 = $2 to $3 step 42) {", text)
            self.assertIn("= (($3 * 42) + $2 * -1);", text)

    def testMaxMinLoop(self):
        """`E.Max`/`E.Min` loop bounds print as `max(...)` / `min(...)`."""
        with E.ContextManager():
            i = E.Expr(E.Bindable(self.indexType))
            step = E.Expr(E.Bindable(self.indexType))
            lbs = self._make_exprs(self.indexType, 4)
            ubs = self._make_exprs(self.indexType, 3)
            loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
            self.assertIn(
                "for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
                str(loop))

    def testIndexed(self):
        """Storing the product of two loads prints a `store(` statement."""
        with E.ContextManager():
            i, j, k = self._make_exprs(self.indexType, 3)
            memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
            A, B, C = self._make_indexed(memrefType, 3)
            stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
            self.assertIn(" = store(", str(stmt))

    def testMatmul(self):
        """A three-deep loop nest around a multiply-accumulate store."""
        with E.ContextManager():
            ivs = self._make_exprs(self.indexType, 3)
            lbs = self._make_exprs(self.indexType, 3)
            ubs = self._make_exprs(self.indexType, 3)
            steps = self._make_exprs(self.indexType, 3)
            i, j, k = ivs
            memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
            A, B, C = self._make_indexed(memrefType, 3)
            update = C.load([i, j]) + A.load([i, k]) * B.load([k, j])
            loop = E.For(ivs, lbs, ubs, steps, [C.store([i, j], update)])
            text = str(loop)
            self.assertIn("for($1 = $4 to $7 step $10) {", text)
            self.assertIn("for($2 = $5 to $8 step $11) {", text)
            self.assertIn("for($3 = $6 to $9 step $12) {", text)
            self.assertIn(" = store", text)

    def testArithmetic(self):
        """Overloaded `% + * - /` on expressions print with parentheses."""
        with E.ContextManager():
            i, j, k, l = self._make_exprs(self.f32Type, 4)
            stmt = i % j + j * k - l / k
            self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str(stmt))

    def testBoolean(self):
        """Overloaded comparison and `& | ~` operators on expressions."""
        with E.ContextManager():
            i, j, k, l = self._make_exprs(self.i32Type, 4)
            stmt1 = (i < j) & (j >= k)
            stmt2 = ~(stmt1 | (k == l))
            # Note that "a | b" is currently implemented as ~(~a && ~b) and
            # "~a" is currently implemented as "constant 1 - a", which leads
            # to this expression.
            self.assertIn(
                "(1 - (1 - ((1 - (($1 < $2) && ($2 >= $3))) && (1 - ($3 == $4)))))",
                str(stmt2))

    def testSelect(self):
        """`E.Select` prints as `select(cond, a, b)`."""
        with E.ContextManager():
            # Three bindables are created as before; the third is not used.
            i, j, _ = self._make_exprs(self.i32Type, 3)
            stmt = E.Select(i > j, i, j)
            self.assertIn("select(($1 > $2), $1, $2)", str(stmt))

    def testCall(self):
        """Calling an expression with an argument prints as `@callee(arg)`."""
        with E.ContextManager():
            module = E.MLIRModule()
            f32 = module.make_scalar_type("f32")
            func, arg = [E.Expr(E.Bindable(f32)) for _ in range(2)]
            code = func(arg, result=f32)
            self.assertIn("@$1($2)", str(code))

    def testBlock(self):
        """`E.Block` prints a `^bb:` label followed by its statements."""
        with E.ContextManager():
            i, j = self._make_exprs(self.f32Type, 2)
            stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
            text = str(stmt)
            self.assertIn("^bb:", text)
            self.assertIn(" = ($1 + $2)", text)
            self.assertIn(" = ($1 - $2)", text)

    def testMLIRScalarTypes(self):
        """Each scalar type's printed form contains its MLIR spelling."""
        module = E.MLIRModule()
        # (arguments to make_scalar_type, substring expected when printed)
        cases = [
            (("bf16",), "bf16"),
            (("f16",), "f16"),
            (("f32",), "f32"),
            (("f64",), "f64"),
            (("i", 1), "i1"),
            (("i", 8), "i8"),
            (("i", 32), "i32"),
            (("i", 123), "i123"),
            (("index",), "index"),
        ]
        for args, expected in cases:
            t = module.make_scalar_type(*args)
            self.assertIn(expected, str(t))

    def testMLIRFunctionCreation(self):
        """Memref types and function signatures print as expected."""
        module = E.MLIRModule()
        t = module.make_scalar_type("f32")
        self.assertIn("f32", str(t))
        # The -1 extent is printed as `?` in the memref type below.
        m = module.make_memref_type(t, [3, 4, -1, 5])
        self.assertIn("memref<3x4x?x5xf32>", str(m))
        f = module.make_function("copy", [m, m], [])
        self.assertIn(
            "func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {",
            str(f))

        f = module.make_function("sqrtf", [t], [t])
        self.assertIn("func @sqrtf(%arg0: f32) -> f32", str(f))

    def testMLIRConstantEmission(self):
        """Every `bind_constant_*` call shows up as a `constant` in the IR."""
        module = E.MLIRModule()
        f = module.make_function("constants", [], [])
        with E.ContextManager():
            emitter = E.MLIRFunctionEmitter(f)
            emitter.bind_constant_bf16(1.23)
            emitter.bind_constant_f16(1.23)
            emitter.bind_constant_f32(1.23)
            emitter.bind_constant_f64(1.23)
            emitter.bind_constant_int(1, 1)
            emitter.bind_constant_int(123, 8)
            emitter.bind_constant_int(123, 16)
            emitter.bind_constant_int(123, 32)
            emitter.bind_constant_index(123)
            emitter.bind_constant_function(f)
            text = str(f)
            self.assertIn("constant 1.230000e+00 : bf16", text)
            self.assertIn("constant 1.230470e+00 : f16", text)
            self.assertIn("constant 1.230000e+00 : f32", text)
            self.assertIn("constant 1.230000e+00 : f64", text)
            self.assertIn("constant 1 : i1", text)
            self.assertIn("constant 123 : i8", text)
            self.assertIn("constant 123 : i16", text)
            self.assertIn("constant 123 : i32", text)
            self.assertIn("constant 123 : index", text)
            self.assertIn("constant @constants : () -> ()", text)

    def testMLIRBuiltinEmission(self):
        """Calling a declared function through a bound function constant."""
        module = E.MLIRModule()
        m = module.make_memref_type(self.f32Type, [10])  # f32 tensor
        f = module.make_function("call_builtin", [m, m], [])
        with E.ContextManager():
            emitter = E.MLIRFunctionEmitter(f)
            in_ref, out_ref = list(
                map(E.Indexed, emitter.bind_function_arguments()))
            fn = module.declare_function(
                "sqrtf", [self.f32Type], [self.f32Type])
            fn = emitter.bind_constant_function(fn)
            zero = emitter.bind_constant_index(0)
            call = fn(in_ref.load([zero]), result=self.f32Type)
            emitter.emit_inplace(E.Block([out_ref.store([zero], call)]))
            text = str(f)
            self.assertIn("%f = constant @sqrtf : (f32) -> f32", text)
            self.assertIn("call_indirect %f(%0) : (f32) -> f32", text)

    def testMLIRBooleanEmission(self):
        """Boolean expressions inside a loop nest emit and compile."""
        m = self.module.make_memref_type(self.boolType, [10])  # i1 tensor
        f = self.module.make_function("mkbooltensor", [m, m], [])
        with E.ContextManager():
            emitter = E.MLIRFunctionEmitter(f)
            in_ref, out_ref = list(
                map(E.Indexed, emitter.bind_function_arguments()))
            i = E.Expr(E.Bindable(self.indexType))
            j = E.Expr(E.Bindable(self.indexType))
            k = E.Expr(E.Bindable(self.indexType))
            idxs = [i, j, k]
            zero = emitter.bind_constant_index(0)
            one = emitter.bind_constant_index(1)
            ten = emitter.bind_constant_index(10)
            b1 = E.And(i < j, j < k)
            b2 = E.Negate(b1)
            b3 = E.Or(b2, k < j)
            loop = E.Block([
                E.For(idxs, [zero] * 3, [ten] * 3, [one] * 3, [
                    out_ref.store([i], E.And(in_ref.load([i]), b3)),
                ]),
                E.Return(),
            ])
            emitter.emit_inplace(loop)
            # print(f)  # uncomment to see the emitted IR
            self.module.compile()
            self.assertNotEqual(self.module.get_engine_address(), 0)

    def testMLIREmission(self):
        """An element-wise copy loop nest emits a store and compiles."""
        shape = [3, 4, 5]
        m = self.module.make_memref_type(self.f32Type, shape)
        f = self.module.make_function("copy", [m, m], [])

        with E.ContextManager():
            emitter = E.MLIRFunctionEmitter(f)
            zero = emitter.bind_constant_index(0)
            one = emitter.bind_constant_index(1)
            in_ref, out_ref = list(
                map(E.Indexed, emitter.bind_function_arguments()))
            M, N, O = emitter.bind_indexed_shape(in_ref)

            ivs = self._make_exprs(self.indexType, len(shape))
            lbs = [zero, zero, zero]
            ubs = [M, N, O]
            steps = [one, one, one]

            # TODO(ntv): emitter.assertEqual(M, oM) etc
            copy_stmt = out_ref.store(ivs, in_ref.load(ivs))
            loop = E.Block([
                E.For(ivs, lbs, ubs, steps, [copy_stmt]),
                E.Return(),
            ])
            emitter.emit_inplace(loop)

            # print(f)  # uncomment to see the emitted IR
            self.assertIn(
                "store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>",
                str(f))

            self.module.compile()
            self.assertNotEqual(self.module.get_engine_address(), 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to the unittest runner when executed as a script.
    unittest.main()
|