diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h index d076dccce1b0..0ad64746f148 100644 --- a/mlir/include/mlir-c/Dialect/SMT.h +++ b/mlir/include/mlir-c/Dialect/SMT.h @@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt); //===----------------------------------------------------------------------===// /// Checks if the given type is any non-func SMT value type. -MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type); /// Checks if the given type is any SMT value type. -MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type); /// Checks if the given type is a smt::ArrayType. -MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type); /// Creates an array type with the given domain and range types. -MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx, - MlirType domainType, - MlirType rangeType); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx, + MlirType domainType, + MlirType rangeType); /// Checks if the given type is a smt::BitVectorType. -MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type); /// Creates a smt::BitVectorType with the given width. -MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx, + int32_t width); /// Checks if the given type is a smt::BoolType. -MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type); /// Creates a smt::BoolType. -MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx); /// Checks if the given type is a smt::IntType. -MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type); /// Creates a smt::IntType. -MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx); /// Checks if the given type is a smt::FuncType. -MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type); /// Creates a smt::FuncType with the given domain and range types. -MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx, - size_t numberOfDomainTypes, - const MlirType *domainTypes, - MlirType rangeType); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, + size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType); /// Checks if the given type is a smt::SortType. -MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type); /// Creates a smt::SortType with the given identifier and sort parameters. -MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx, - MlirIdentifier identifier, - size_t numberOfSortParams, - const MlirType *sortParams); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx, + MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams); //===----------------------------------------------------------------------===// // Attribute API. //===----------------------------------------------------------------------===// /// Checks if the given string is a valid smt::BVCmpPredicate. -MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, + MlirStringRef str); /// Checks if the given string is a valid smt::IntPredicate. -MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, + MlirStringRef str); /// Checks if the given attribute is a smt::SMTAttribute. -MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr); /// Creates a smt::BitVectorAttr with the given value and width. -MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx, - uint64_t value, - unsigned width); +MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, + uint64_t value, + unsigned width); /// Creates a smt::BVCmpPredicateAttr with the given string. -MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED MlirAttribute +mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str); /// Creates a smt::IntPredicateAttr with the given string. -MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, + MlirStringRef str); #ifdef __cplusplus } diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h index 31f411c4a89c..59beda54d289 100644 --- a/mlir/include/mlir-c/Target/ExportSMTLIB.h +++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h @@ -21,9 +21,13 @@ extern "C" { /// Emits SMTLIB for the specified module using the provided callback and user /// data -MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule, - MlirStringCallback, - void *userData); +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB( + MlirOperation, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); #ifdef __cplusplus } diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp new file mode 100644 index 000000000000..4e7647729fb0 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -0,0 +1,83 @@ +//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "NanobindUtils.h" + +#include "mlir-c/Dialect/SMT.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir-c/Target/ExportSMTLIB.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; + +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectSMTSubmodule(nanobind::module_ &m) { + + auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) + .def_classmethod( + "get", + [](const nb::object &, MlirContext context) { + return mlirSMTTypeGetBool(context); + }, + "cls"_a, "context"_a.none() = nb::none()); + auto smtBitVectorType = + mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) + .def_classmethod( + "get", + [](const nb::object &, int32_t width, MlirContext context) { + return mlirSMTTypeGetBitVector(context, width); + }, + "cls"_a, "width"_a, "context"_a.none() = nb::none()); + + auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(module)); + PyPrintAccumulator printAccum; + MlirLogicalResult result = mlirTranslateOperationToSMTLIB( + module, printAccum.getCallback(), printAccum.getUserData(), + inlineSingleUseValues, indentLetBody); + if (mlirLogicalResultIsSuccess(result)) + return printAccum.join(); + throw nb::value_error( + ("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage()) + .c_str()); + }; + + m.def( + "export_smtlib", + [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(module, inlineSingleUseValues, indentLetBody); + }, + "module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); + m.def( + "export_smtlib", + [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(mlirModuleGetOperation(module), + inlineSingleUseValues, indentLetBody); + }, + "module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); +} + +NB_MODULE(_mlirDialectsSMT, m) { + m.doc() = "MLIR SMT Dialect"; + + populateDialectSMTSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp index 3a4620df8ccd..7e96bbb07153 100644 --- a/mlir/lib/CAPI/Dialect/SMT.cpp +++ b/mlir/lib/CAPI/Dialect/SMT.cpp @@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect) // Type API. //===----------------------------------------------------------------------===// -bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) { +bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) { return isAnyNonFuncSMTValueType(unwrap(type)); } -bool smtTypeIsAnySMTValueType(MlirType type) { +bool mlirSMTTypeIsAnySMTValueType(MlirType type) { return isAnySMTValueType(unwrap(type)); } -bool smtTypeIsAArray(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsAArray(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType, - MlirType rangeType) { +MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType, + MlirType rangeType) { return wrap( ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType))); } -bool smtTypeIsABitVector(MlirType type) { +bool mlirSMTTypeIsABitVector(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) { +MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) { return wrap(BitVectorType::get(unwrap(ctx), width)); } -bool smtTypeIsABool(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsABool(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetBool(MlirContext ctx) { +MlirType mlirSMTTypeGetBool(MlirContext ctx) { return wrap(BoolType::get(unwrap(ctx))); } -bool smtTypeIsAInt(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsAInt(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetInt(MlirContext ctx) { +MlirType mlirSMTTypeGetInt(MlirContext ctx) { return wrap(IntType::get(unwrap(ctx))); } -bool smtTypeIsASMTFunc(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsASMTFunc(MlirType type) { + return isa(unwrap(type)); +} -MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, - const MlirType *domainTypes, MlirType rangeType) { +MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType) { SmallVector domainTypesVec; domainTypesVec.reserve(numberOfDomainTypes); @@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType))); } -bool smtTypeIsASort(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsASort(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier, - size_t numberOfSortParams, const MlirType *sortParams) { +MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams) { SmallVector sortParamsVec; sortParamsVec.reserve(numberOfSortParams); @@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier, // Attribute API. //===----------------------------------------------------------------------===// -bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { +bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { return symbolizeBVCmpPredicate(unwrap(str)).has_value(); } -bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { +bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { return symbolizeIntPredicate(unwrap(str)).has_value(); } -bool smtAttrIsASMTAttribute(MlirAttribute attr) { +bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) { return isa(unwrap(attr)); } -MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value, - unsigned width) { +MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value, + unsigned width) { return wrap(BitVectorAttr::get(unwrap(ctx), value, width)); } -MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { +MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { auto predicate = symbolizeBVCmpPredicate(unwrap(str)); assert(predicate.has_value() && "invalid predicate"); return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value())); } -MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { +MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { auto predicate = symbolizeIntPredicate(unwrap(str)); assert(predicate.has_value() && "invalid predicate"); diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp index c9ac7ce704af..4326f967281e 100644 --- a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp +++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp @@ -19,9 +19,24 @@ using namespace mlir; -MlirLogicalResult mlirExportSMTLIB(MlirModule module, - MlirStringCallback callback, - void *userData) { +MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { mlir::detail::CallbackOstream stream(callback, userData); + smt::SMTEmissionOptions options; + options.inlineSingleUseValues = inlineSingleUseValues; + options.indentLetBody = indentLetBody; return wrap(smt::exportSMTLIB(unwrap(module), stream)); } + +MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { + return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module), + callback, userData, + inlineSingleUseValues, indentLetBody); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index fb115a5f4342..bbf6819608bb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings( "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SMTOps.td + GEN_ENUM_BINDINGS + SOURCES + dialects/smt.py + DIALECT_NAME smt) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -664,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MLIRCAPILinalg ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind + MODULE_NAME _mlirDialectsSMT + ADD_TO_PARENT MLIRPythonSources.Dialects.smt + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectSMT.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPISMT + MLIRCAPIExportSMTLIB +) + declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor diff --git a/mlir/python/mlir/dialects/SMTOps.td b/mlir/python/mlir/dialects/SMTOps.td new file mode 100644 index 000000000000..e143f071eb65 --- /dev/null +++ b/mlir/python/mlir/dialects/SMTOps.td @@ -0,0 +1,14 @@ +//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BINDINGS_PYTHON_SMT_OPS +#define BINDINGS_PYTHON_SMT_OPS + +include "mlir/Dialect/SMT/IR/SMT.td" + +#endif // BINDINGS_PYTHON_SMT_OPS diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py new file mode 100644 index 000000000000..ae7a4c41cbc3 --- /dev/null +++ b/mlir/python/mlir/dialects/smt.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._smt_ops_gen import * + +from .._mlir_libs._mlirDialectsSMT import * +from ..extras.meta import region_op + + +def bool_t(): + return BoolType.get() + + +def bv_t(width): + return BitVectorType.get(width) + + +def _solver( + inputs=None, + results=None, + loc=None, + ip=None, +): + if inputs is None: + inputs = [] + if results is None: + results = [] + + return SolverOp(results, inputs, loc=loc, ip=ip) + + +solver = region_op(_solver, terminator=YieldOp) diff --git a/mlir/test/CAPI/smt.c b/mlir/test/CAPI/smt.c index 77815d4f7965..95a9b55e3209 100644 --- a/mlir/test/CAPI/smt.c +++ b/mlir/test/CAPI/smt.c @@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) { MlirModule module = mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT)); - MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL); + MlirLogicalResult result = + mlirTranslateModuleToSMTLIB(module, dumpCallback, NULL, false, false); (void)result; assert(mlirLogicalResultIsSuccess(result)); @@ -44,13 +45,13 @@ void testExportSMTLIB(MlirContext ctx) { } void testSMTType(MlirContext ctx) { - MlirType boolType = smtTypeGetBool(ctx); - MlirType intType = smtTypeGetInt(ctx); - MlirType arrayType = smtTypeGetArray(ctx, intType, boolType); - MlirType bvType = smtTypeGetBitVector(ctx, 32); + MlirType boolType = mlirSMTTypeGetBool(ctx); + MlirType intType = mlirSMTTypeGetInt(ctx); + MlirType arrayType = mlirSMTTypeGetArray(ctx, intType, boolType); + MlirType bvType = mlirSMTTypeGetBitVector(ctx, 32); MlirType funcType = - smtTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType); - MlirType sortType = smtTypeGetSort( + mlirSMTTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType); + MlirType sortType = mlirSMTTypeGetSort( ctx, mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sort")), 0, NULL); @@ -68,107 +69,107 @@ void testSMTType(MlirContext ctx) { mlirTypeDump(sortType); // CHECK: bool_is_any_non_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(boolType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(boolType) ? "bool_is_any_non_func_smt_value_type\n" : "bool_is_func_smt_value_type\n"); // CHECK: int_is_any_non_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(intType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(intType) ? "int_is_any_non_func_smt_value_type\n" : "int_is_func_smt_value_type\n"); // CHECK: array_is_any_non_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(arrayType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(arrayType) ? "array_is_any_non_func_smt_value_type\n" : "array_is_func_smt_value_type\n"); // CHECK: bit_vector_is_any_non_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(bvType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(bvType) ? "bit_vector_is_any_non_func_smt_value_type\n" : "bit_vector_is_func_smt_value_type\n"); // CHECK: sort_is_any_non_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(sortType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(sortType) ? "sort_is_any_non_func_smt_value_type\n" : "sort_is_func_smt_value_type\n"); // CHECK: smt_func_is_func_smt_value_type - fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(funcType) + fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(funcType) ? "smt_func_is_any_non_func_smt_value_type\n" : "smt_func_is_func_smt_value_type\n"); // CHECK: bool_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(boolType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(boolType) ? "bool_is_any_smt_value_type\n" : "bool_is_not_any_smt_value_type\n"); // CHECK: int_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(intType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(intType) ? "int_is_any_smt_value_type\n" : "int_is_not_any_smt_value_type\n"); // CHECK: array_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(arrayType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(arrayType) ? "array_is_any_smt_value_type\n" : "array_is_not_any_smt_value_type\n"); // CHECK: array_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(bvType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(bvType) ? "array_is_any_smt_value_type\n" : "array_is_not_any_smt_value_type\n"); // CHECK: smt_func_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(funcType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(funcType) ? "smt_func_is_any_smt_value_type\n" : "smt_func_is_not_any_smt_value_type\n"); // CHECK: sort_is_any_smt_value_type - fprintf(stderr, smtTypeIsAnySMTValueType(sortType) + fprintf(stderr, mlirSMTTypeIsAnySMTValueType(sortType) ? "sort_is_any_smt_value_type\n" : "sort_is_not_any_smt_value_type\n"); // CHECK: int_type_is_not_a_bool - fprintf(stderr, smtTypeIsABool(intType) ? "int_type_is_a_bool\n" - : "int_type_is_not_a_bool\n"); + fprintf(stderr, mlirSMTTypeIsABool(intType) ? "int_type_is_a_bool\n" + : "int_type_is_not_a_bool\n"); // CHECK: bool_type_is_not_a_int - fprintf(stderr, smtTypeIsAInt(boolType) ? "bool_type_is_a_int\n" - : "bool_type_is_not_a_int\n"); + fprintf(stderr, mlirSMTTypeIsAInt(boolType) ? "bool_type_is_a_int\n" + : "bool_type_is_not_a_int\n"); // CHECK: bv_type_is_not_a_array - fprintf(stderr, smtTypeIsAArray(bvType) ? "bv_type_is_a_array\n" - : "bv_type_is_not_a_array\n"); + fprintf(stderr, mlirSMTTypeIsAArray(bvType) ? "bv_type_is_a_array\n" + : "bv_type_is_not_a_array\n"); // CHECK: array_type_is_not_a_bit_vector - fprintf(stderr, smtTypeIsABitVector(arrayType) + fprintf(stderr, mlirSMTTypeIsABitVector(arrayType) ? "array_type_is_a_bit_vector\n" : "array_type_is_not_a_bit_vector\n"); // CHECK: sort_type_is_not_a_smt_func - fprintf(stderr, smtTypeIsASMTFunc(sortType) + fprintf(stderr, mlirSMTTypeIsASMTFunc(sortType) ? "sort_type_is_a_smt_func\n" : "sort_type_is_not_a_smt_func\n"); // CHECK: func_type_is_not_a_sort - fprintf(stderr, smtTypeIsASort(funcType) ? "func_type_is_a_sort\n" - : "func_type_is_not_a_sort\n"); + fprintf(stderr, mlirSMTTypeIsASort(funcType) ? "func_type_is_a_sort\n" + : "func_type_is_not_a_sort\n"); } void testSMTAttribute(MlirContext ctx) { // CHECK: slt_is_BVCmpPredicate - fprintf(stderr, - smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")) - ? "slt_is_BVCmpPredicate\n" - : "slt_is_not_BVCmpPredicate\n"); + fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate( + ctx, mlirStringRefCreateFromCString("slt")) + ? "slt_is_BVCmpPredicate\n" + : "slt_is_not_BVCmpPredicate\n"); // CHECK: lt_is_not_BVCmpPredicate - fprintf(stderr, - smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("lt")) - ? "lt_is_BVCmpPredicate\n" - : "lt_is_not_BVCmpPredicate\n"); + fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate( + ctx, mlirStringRefCreateFromCString("lt")) + ? "lt_is_BVCmpPredicate\n" + : "lt_is_not_BVCmpPredicate\n"); // CHECK: slt_is_not_IntPredicate - fprintf(stderr, - smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("slt")) - ? "slt_is_IntPredicate\n" - : "slt_is_not_IntPredicate\n"); + fprintf(stderr, mlirSMTAttrCheckIntPredicate( + ctx, mlirStringRefCreateFromCString("slt")) + ? "slt_is_IntPredicate\n" + : "slt_is_not_IntPredicate\n"); // CHECK: lt_is_IntPredicate - fprintf(stderr, - smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("lt")) - ? "lt_is_IntPredicate\n" - : "lt_is_not_IntPredicate\n"); + fprintf(stderr, mlirSMTAttrCheckIntPredicate( + ctx, mlirStringRefCreateFromCString("lt")) + ? "lt_is_IntPredicate\n" + : "lt_is_not_IntPredicate\n"); // CHECK: #smt.bv<5> : !smt.bv<32> - mlirAttributeDump(smtAttrGetBitVector(ctx, 5, 32)); + mlirAttributeDump(mlirSMTAttrGetBitVector(ctx, 5, 32)); // CHECK: 0 : i64 mlirAttributeDump( - smtAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))); + mlirSMTAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))); // CHECK: 0 : i64 mlirAttributeDump( - smtAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))); + mlirSMTAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))); } int main(void) { diff --git a/mlir/test/python/dialects/smt.py b/mlir/test/python/dialects/smt.py new file mode 100644 index 000000000000..6f0cd8835b65 --- /dev/null +++ b/mlir/test/python/dialects/smt.py @@ -0,0 +1,87 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.dialects import smt, arith +from mlir.ir import Context, Location, Module, InsertionPoint, F32Type + + +def run(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f(module) + print(module) + assert module.operation.verify() + + +# CHECK-LABEL: TEST: test_smoke +@run +def test_smoke(_module): + true = smt.constant(True) + false = smt.constant(False) + # CHECK: smt.constant true + # CHECK: smt.constant false + + +# CHECK-LABEL: TEST: test_types +@run +def test_types(_module): + bool_t = smt.bool_t() + bitvector_t = smt.bv_t(5) + # CHECK: !smt.bool + print(bool_t) + # CHECK: !smt.bv<5> + print(bitvector_t) + + +# CHECK-LABEL: TEST: test_solver_op +@run +def test_solver_op(_module): + @smt.solver + def foo1(): + true = smt.constant(True) + false = smt.constant(False) + + # CHECK: smt.solver() : () -> () { + # CHECK: %true = smt.constant true + # CHECK: %false = smt.constant false + # CHECK: } + + f32 = F32Type.get() + + @smt.solver(results=[f32]) + def foo2(): + return arith.ConstantOp(f32, 1.0) + + # CHECK: %{{.*}} = smt.solver() : () -> f32 { + # CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 + # CHECK: smt.yield %[[CST1]] : f32 + # CHECK: } + + two = arith.ConstantOp(f32, 2.0) + # CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 + print(two) + + @smt.solver(inputs=[two], results=[f32]) + def foo3(z: f32): + return z + + # CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 { + # CHECK: ^bb0(%[[ARG0:.*]]: f32): + # CHECK: smt.yield %[[ARG0]] : f32 + # CHECK: } + + +# CHECK-LABEL: TEST: test_export_smtlib +@run +def test_export_smtlib(module): + @smt.solver + def foo1(): + true = smt.constant(True) + smt.assert_(true) + + query = smt.export_smtlib(module.operation) + # CHECK: ; solver scope 0 + # CHECK: (assert true) + # CHECK: (reset) + print(query)