Brian Patton d52e631359 Add a test example of calling a builtin function.
PiperOrigin-RevId: 235149430
2019-03-29 16:37:46 -07:00

301 lines
11 KiB
Python

"""Python 2 and 3 tests for the MLIR EDSC C API and Python bindings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import google_mlir.bindings.python.pybind as E
class EdscTest(unittest.TestCase):
  """Tests for the EDSC Python bindings.

  The expression/statement tests check the `$N` placeholder rendering that
  `str()` produces for EDSC objects. The `testMLIR*` tests work directly with
  `E.MLIRModule` types and functions and check the printed IR; two of them
  also compile the module and check that an engine address was produced.
  """

  def setUp(self):
    # A fresh module per test plus commonly used types created through it.
    # NOTE(review): these types come from self.module; MLIRModule instances
    # presumably own separate contexts, so do not combine them with functions
    # built in a different module (TODO confirm against the bindings).
    self.module = E.MLIRModule()
    self.boolType = self.module.make_scalar_type("i", 1)
    self.i32Type = self.module.make_scalar_type("i", 32)
    self.f32Type = self.module.make_scalar_type("f32")
    self.indexType = self.module.make_index_type()

  @staticmethod
  def _bindables(wrap, elem_type, count):
    """Returns `count` fresh Bindables of `elem_type`, each wrapped by `wrap`.

    All Bindables are created first and then wrapped. The `$N` placeholders
    asserted by the tests follow Bindable creation order, so the order of
    calls to this helper within a test is significant.
    """
    return list(map(wrap, [E.Bindable(elem_type) for _ in range(count)]))

  def testBindables(self):
    """A single bindable expression prints as the placeholder `$1`."""
    with E.ContextManager():
      i = E.Expr(E.Bindable(self.i32Type))
      self.assertIn("$1", str(i))

  def testOneExpr(self):
    """Nested arithmetic prints fully parenthesized."""
    with E.ContextManager():
      i, lb, ub = self._bindables(E.Expr, self.i32Type, 3)
      expr = E.Mul(i, E.Add(lb, ub))
      self.assertIn("($1 * ($2 + $3))", str(expr))

  def testOneLoop(self):
    """A single For prints its header and its body statement."""
    with E.ContextManager():
      i, lb, ub, step = self._bindables(E.Expr, self.indexType, 4)
      loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
      s = str(loop)
      self.assertIn("for($1 = $2 to $3 step $4) {", s)
      self.assertIn(" = ($2 + $3)", s)

  def testTwoLoops(self):
    """Two nested For statements that reuse the same bindables."""
    with E.ContextManager():
      i, lb, ub, step = self._bindables(E.Expr, self.indexType, 4)
      loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
      s = str(loop)
      # Outer and inner loop print the same header text, so a repeated
      # assertIn cannot tell whether the inner loop was emitted at all.
      self.assertEqual(s.count("for($1 = $2 to $3 step $4) {"), 2)
      self.assertIn("$5 = $1;", s)

  def testNestedLoops(self):
    """Lists of ivs/bounds/steps build a loop nest around an inner loop."""
    with E.ContextManager():
      i, lb, ub = self._bindables(E.Expr, self.indexType, 3)
      step = E.ConstantInteger(self.indexType, 42)
      ivs = self._bindables(E.Expr, self.indexType, 4)
      lbs = self._bindables(E.Expr, self.indexType, 4)
      ubs = self._bindables(E.Expr, self.indexType, 4)
      steps = self._bindables(E.Expr, self.indexType, 4)
      loop = E.For(ivs, lbs, ubs, steps, [
          E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
      ])
      s = str(loop)
      self.assertIn("for($5 = $9 to $13 step $17) {", s)
      self.assertIn("for($6 = $10 to $14 step $18) {", s)
      self.assertIn("for($7 = $11 to $15 step $19) {", s)
      self.assertIn("for($8 = $12 to $16 step $20) {", s)
      self.assertIn("for($1 = $2 to $3 step 42) {", s)
      self.assertIn("= (($3 * 42) + $2 * -1);", s)

  def testMaxMinLoop(self):
    """E.Max / E.Min bounds print as max(...) / min(...) of their operands."""
    with E.ContextManager():
      i = E.Expr(E.Bindable(self.indexType))
      step = E.Expr(E.Bindable(self.indexType))
      lbs = self._bindables(E.Expr, self.indexType, 4)
      ubs = self._bindables(E.Expr, self.indexType, 3)
      loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
      s = str(loop)
      self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
                    s)

  def testIndexed(self):
    """Storing a product of two loads prints as a store statement."""
    with E.ContextManager():
      i, j, k = self._bindables(E.Expr, self.indexType, 3)
      memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
      A, B, C = self._bindables(E.Indexed, memrefType, 3)
      stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
      self.assertIn(" = store(", str(stmt))

  def testMatmul(self):
    """A 3-deep loop nest around a C += A * B style store."""
    with E.ContextManager():
      ivs = self._bindables(E.Expr, self.indexType, 3)
      lbs = self._bindables(E.Expr, self.indexType, 3)
      ubs = self._bindables(E.Expr, self.indexType, 3)
      steps = self._bindables(E.Expr, self.indexType, 3)
      i, j, k = ivs
      memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
      A, B, C = self._bindables(E.Indexed, memrefType, 3)
      loop = E.For(
          ivs, lbs, ubs, steps,
          [C.store([i, j], C.load([i, j]) + A.load([i, k]) * B.load([k, j]))])
      s = str(loop)
      self.assertIn("for($1 = $4 to $7 step $10) {", s)
      self.assertIn("for($2 = $5 to $8 step $11) {", s)
      self.assertIn("for($3 = $6 to $9 step $12) {", s)
      self.assertIn(" = store", s)

  def testArithmetic(self):
    """Python arithmetic operators on Exprs print as parenthesized infix."""
    with E.ContextManager():
      i, j, k, l = self._bindables(E.Expr, self.f32Type, 4)
      stmt = i % j + j * k - l / k
      self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str(stmt))

  def testBoolean(self):
    """Comparison and &, |, ~ operators on Exprs."""
    with E.ContextManager():
      i, j, k, l = self._bindables(E.Expr, self.i32Type, 4)
      stmt1 = (i < j) & (j >= k)
      stmt2 = ~(stmt1 | (k == l))
      # Note that "a | b" is currently implemented as ~(~a && ~b) and "~a" is
      # currently implemented as "constant 1 - a", which leads to this
      # expression.
      self.assertIn(
          "(1 - (1 - ((1 - (($1 < $2) && ($2 >= $3))) && (1 - ($3 == $4)))))",
          str(stmt2))

  def testSelect(self):
    """E.Select prints as select(cond, true_value, false_value)."""
    with E.ContextManager():
      i, j = self._bindables(E.Expr, self.i32Type, 2)
      stmt = E.Select(i > j, i, j)
      self.assertIn("select(($1 > $2), $1, $2)", str(stmt))

  def testCall(self):
    """Calling an Expr prints as `@callee(args)`."""
    with E.ContextManager():
      module = E.MLIRModule()
      f32 = module.make_scalar_type("f32")
      func, arg = [E.Expr(E.Bindable(f32)) for _ in range(2)]
      code = func(arg, result=f32)
      self.assertIn("@$1($2)", str(code))

  def testBlock(self):
    """An E.Block prints a block label followed by its statements."""
    with E.ContextManager():
      i, j = self._bindables(E.Expr, self.f32Type, 2)
      stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
      s = str(stmt)
      self.assertIn("^bb:", s)
      self.assertIn(" = ($1 + $2)", s)
      self.assertIn(" = ($1 - $2)", s)

  def testMLIRScalarTypes(self):
    """make_scalar_type(...) arguments print under their MLIR type names."""
    module = E.MLIRModule()
    cases = [
        (("bf16",), "bf16"),
        (("f16",), "f16"),
        (("f32",), "f32"),
        (("f64",), "f64"),
        (("i", 1), "i1"),
        (("i", 8), "i8"),
        (("i", 32), "i32"),
        (("i", 123), "i123"),
        (("index",), "index"),
    ]
    for args, expected in cases:
      self.assertIn(expected, str(module.make_scalar_type(*args)),
                    msg="make_scalar_type%r" % (args,))

  def testMLIRFunctionCreation(self):
    """Memref types (with a dynamic dim) and function signatures print."""
    module = E.MLIRModule()
    t = module.make_scalar_type("f32")
    self.assertIn("f32", str(t))
    # -1 marks a dynamic dimension, printed as '?'.
    m = module.make_memref_type(t, [3, 4, -1, 5])
    self.assertIn("memref<3x4x?x5xf32>", str(m))
    f = module.make_function("copy", [m, m], [])
    self.assertIn(
        "func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {",
        str(f))
    f = module.make_function("sqrtf", [t], [t])
    self.assertIn("func @sqrtf(%arg0: f32) -> f32", str(f))

  def testMLIRConstantEmission(self):
    """Each emitter.bind_constant_* call emits a typed `constant` op."""
    module = E.MLIRModule()
    f = module.make_function("constants", [], [])
    with E.ContextManager():
      emitter = E.MLIRFunctionEmitter(f)
      emitter.bind_constant_bf16(1.23)
      emitter.bind_constant_f16(1.23)
      emitter.bind_constant_f32(1.23)
      emitter.bind_constant_f64(1.23)
      emitter.bind_constant_int(1, 1)
      emitter.bind_constant_int(123, 8)
      emitter.bind_constant_int(123, 16)
      emitter.bind_constant_int(123, 32)
      emitter.bind_constant_index(123)
      emitter.bind_constant_function(f)
      ir = str(f)
      self.assertIn("constant 1.230000e+00 : bf16", ir)
      self.assertIn("constant 1.230470e+00 : f16", ir)
      self.assertIn("constant 1.230000e+00 : f32", ir)
      self.assertIn("constant 1.230000e+00 : f64", ir)
      self.assertIn("constant 1 : i1", ir)
      self.assertIn("constant 123 : i8", ir)
      self.assertIn("constant 123 : i16", ir)
      self.assertIn("constant 123 : i32", ir)
      self.assertIn("constant 123 : index", ir)
      self.assertIn("constant @constants : () -> ()", ir)

  def testMLIRBuiltinEmission(self):
    """Calls a declared external function through a function constant."""
    module = E.MLIRModule()
    # Build every type from this local module (as testCall does) rather than
    # reusing self.f32Type, which was created through self.module.
    f32 = module.make_scalar_type("f32")
    m = module.make_memref_type(f32, [10])  # 1-D memref of 10 f32 elements
    f = module.make_function("call_builtin", [m, m], [])
    with E.ContextManager():
      emitter = E.MLIRFunctionEmitter(f)
      in_buf, out_buf = list(map(E.Indexed, emitter.bind_function_arguments()))
      fn = module.declare_function("sqrtf", [f32], [f32])
      fn = emitter.bind_constant_function(fn)
      zero = emitter.bind_constant_index(0)
      emitter.emit_inplace(E.Block([
          out_buf.store([zero], fn(in_buf.load([zero]), result=f32))
      ]))
      ir = str(f)
      self.assertIn("%f = constant @sqrtf : (f32) -> f32", ir)
      self.assertIn("call_indirect %f(%0) : (f32) -> f32", ir)

  def testMLIRBooleanEmission(self):
    """Emits i1 logic inside a loop nest and compiles the module."""
    m = self.module.make_memref_type(self.boolType, [10])  # 1-D i1 memref
    f = self.module.make_function("mkbooltensor", [m, m], [])
    with E.ContextManager():
      emitter = E.MLIRFunctionEmitter(f)
      in_buf, out_buf = list(map(E.Indexed, emitter.bind_function_arguments()))
      idxs = self._bindables(E.Expr, self.indexType, 3)
      i, j, k = idxs
      zero = emitter.bind_constant_index(0)
      one = emitter.bind_constant_index(1)
      ten = emitter.bind_constant_index(10)
      b1 = E.And(i < j, j < k)
      b2 = E.Negate(b1)
      b3 = E.Or(b2, k < j)
      loop = E.Block([
          E.For(idxs, [zero] * 3, [ten] * 3, [one] * 3, [
              out_buf.store([i], E.And(in_buf.load([i]), b3))
          ]),
          E.Return()
      ])
      emitter.emit_inplace(loop)
      self.module.compile()
      self.assertNotEqual(self.module.get_engine_address(), 0)

  def testMLIREmission(self):
    """Emits a copy loop nest over a 3x4x5 memref and compiles the module."""
    shape = [3, 4, 5]
    rank = len(shape)
    m = self.module.make_memref_type(self.f32Type, shape)
    f = self.module.make_function("copy", [m, m], [])
    with E.ContextManager():
      emitter = E.MLIRFunctionEmitter(f)
      zero = emitter.bind_constant_index(0)
      one = emitter.bind_constant_index(1)
      src, dst = list(map(E.Indexed, emitter.bind_function_arguments()))
      M, N, O = emitter.bind_indexed_shape(src)
      ivs = self._bindables(E.Expr, self.indexType, rank)
      lbs = [zero] * rank
      ubs = [M, N, O]
      steps = [one] * rank
      # TODO(ntv): emitter.assertEqual(M, oM) etc
      loop = E.Block([
          E.For(ivs, lbs, ubs, steps, [dst.store(ivs, src.load(ivs))]),
          E.Return()
      ])
      emitter.emit_inplace(loop)
      # print(f)  # uncomment to see the emitted IR
      ir = str(f)
      self.assertIn("store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>", ir)
      self.module.compile()
      self.assertNotEqual(self.module.get_engine_address(), 0)
if __name__ == "__main__":
  # Executed directly (not imported): run every test case in this module.
  unittest.main()