[MLIR][test] Fixup for checking for ml_dtypes (#123240)
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
This commit is contained in:
parent
63b0ab8425
commit
ba44d7ba1f
@ -566,13 +566,15 @@ def testBF16Memref():
|
|||||||
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
||||||
|
|
||||||
# test to-numpy utility
|
# test to-numpy utility
|
||||||
# CHECK: [0.5]
|
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
||||||
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
assert len(x) == 1
|
||||||
log(npout)
|
assert x[0] == 0.5
|
||||||
|
|
||||||
|
|
||||||
if HAS_ML_DTYPES:
|
if HAS_ML_DTYPES:
|
||||||
run(testBF16Memref)
|
run(testBF16Memref)
|
||||||
|
else:
|
||||||
|
log("TEST: testBF16Memref")
|
||||||
|
|
||||||
|
|
||||||
# Test f8E5M2 memrefs
|
# Test f8E5M2 memrefs
|
||||||
@ -606,13 +608,15 @@ def testF8E5M2Memref():
|
|||||||
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
||||||
|
|
||||||
# test to-numpy utility
|
# test to-numpy utility
|
||||||
# CHECK: [0.5]
|
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
||||||
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
assert len(x) == 1
|
||||||
log(npout)
|
assert x[0] == 0.5
|
||||||
|
|
||||||
|
|
||||||
if HAS_ML_DTYPES:
|
if HAS_ML_DTYPES:
|
||||||
run(testF8E5M2Memref)
|
run(testF8E5M2Memref)
|
||||||
|
else:
|
||||||
|
log("TEST: testF8E5M2Memref")
|
||||||
|
|
||||||
|
|
||||||
# Test addition of two 2d_memref
|
# Test addition of two 2d_memref
|
||||||
|
Loading…
x
Reference in New Issue
Block a user