[MLIR][test] Fixup for checking for ml_dtypes (#123240)

In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]:
https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
This commit is contained in:
Konrad Kleine 2025-01-17 16:25:08 +01:00 committed by GitHub
parent 63b0ab8425
commit ba44d7ba1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -566,13 +566,15 @@ def testBF16Memref():
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
# test to-numpy utility # test to-numpy utility
# CHECK: [0.5] x = ranked_memref_to_numpy(arg2_memref_ptr[0])
npout = ranked_memref_to_numpy(arg2_memref_ptr[0]) assert len(x) == 1
log(npout) assert x[0] == 0.5
if HAS_ML_DTYPES: if HAS_ML_DTYPES:
run(testBF16Memref) run(testBF16Memref)
else:
log("TEST: testBF16Memref")
# Test f8E5M2 memrefs # Test f8E5M2 memrefs
@ -606,13 +608,15 @@ def testF8E5M2Memref():
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
# test to-numpy utility # test to-numpy utility
# CHECK: [0.5] x = ranked_memref_to_numpy(arg2_memref_ptr[0])
npout = ranked_memref_to_numpy(arg2_memref_ptr[0]) assert len(x) == 1
log(npout) assert x[0] == 0.5
if HAS_ML_DTYPES: if HAS_ML_DTYPES:
run(testF8E5M2Memref) run(testF8E5M2Memref)
else:
log("TEST: testF8E5M2Memref")
# Test addition of two 2d_memref # Test addition of two 2d_memref