[mlir][SMT] add python bindings (#135674)

This PR adds "rich" python bindings to SMT dialect.
This commit is contained in:
Maksim Levental 2025-04-16 18:17:09 -04:00 committed by GitHub
parent 7623501c05
commit 697aa9995c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 378 additions and 112 deletions

View File

@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
//===----------------------------------------------------------------------===//
/// Checks if the given type is any non-func SMT value type.
MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);
/// Checks if the given type is any SMT value type.
MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);
/// Checks if the given type is a smt::ArrayType.
MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);
/// Creates an array type with the given domain and range types.
MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
MlirType domainType,
MlirType rangeType);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
MlirType domainType,
MlirType rangeType);
/// Checks if the given type is a smt::BitVectorType.
MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);
/// Creates a smt::BitVectorType with the given width.
MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
int32_t width);
/// Checks if the given type is a smt::BoolType.
MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);
/// Creates a smt::BoolType.
MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);
/// Checks if the given type is a smt::IntType.
MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);
/// Creates a smt::IntType.
MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);
/// Checks if the given type is a smt::FuncType.
MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);
/// Creates a smt::FuncType with the given domain and range types.
MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType);
/// Checks if the given type is a smt::SortType.
MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);
/// Creates a smt::SortType with the given identifier and sort parameters.
MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams);
//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
/// Checks if the given string is a valid smt::BVCmpPredicate.
MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
MlirStringRef str);
/// Checks if the given string is a valid smt::IntPredicate.
MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
MlirStringRef str);
/// Checks if the given attribute is a smt::SMTAttribute.
MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);
/// Creates a smt::BitVectorAttr with the given value and width.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
uint64_t value,
unsigned width);
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
uint64_t value,
unsigned width);
/// Creates a smt::BVCmpPredicateAttr with the given string.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED MlirAttribute
mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);
/// Creates a smt::IntPredicateAttr with the given string.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
MlirStringRef str);
#ifdef __cplusplus
}

View File

@ -21,9 +21,13 @@ extern "C" {
/// Emits SMTLIB for the specified module using the provided callback and user
/// data
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
MlirStringCallback,
void *userData);
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData,
bool inlineSingleUseValues, bool indentLetBody);
MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
MlirOperation, MlirStringCallback, void *userData,
bool inlineSingleUseValues, bool indentLetBody);
#ifdef __cplusplus
}

View File

@ -0,0 +1,83 @@
//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "NanobindUtils.h"
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir-c/Target/ExportSMTLIB.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
void populateDialectSMTSubmodule(nanobind::module_ &m) {
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_classmethod(
"get",
[](const nb::object &, MlirContext context) {
return mlirSMTTypeGetBool(context);
},
"cls"_a, "context"_a.none() = nb::none());
auto smtBitVectorType =
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
.def_classmethod(
"get",
[](const nb::object &, int32_t width, MlirContext context) {
return mlirSMTTypeGetBitVector(context, width);
},
"cls"_a, "width"_a, "context"_a.none() = nb::none());
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
mlir::python::CollectDiagnosticsToStringScope scope(
mlirOperationGetContext(module));
PyPrintAccumulator printAccum;
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
module, printAccum.getCallback(), printAccum.getUserData(),
inlineSingleUseValues, indentLetBody);
if (mlirLogicalResultIsSuccess(result))
return printAccum.join();
throw nb::value_error(
("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
.c_str());
};
m.def(
"export_smtlib",
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
},
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
m.def(
"export_smtlib",
[&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
bool indentLetBody) {
return exportSMTLIB(mlirModuleGetOperation(module),
inlineSingleUseValues, indentLetBody);
},
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
}
NB_MODULE(_mlirDialectsSMT, m) {
m.doc() = "MLIR SMT Dialect";
populateDialectSMTSubmodule(m);
}

View File

@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
// Type API.
//===----------------------------------------------------------------------===//
bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
return isAnyNonFuncSMTValueType(unwrap(type));
}
bool smtTypeIsAnySMTValueType(MlirType type) {
bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
return isAnySMTValueType(unwrap(type));
}
bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
MlirType rangeType) {
MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
MlirType rangeType) {
return wrap(
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
}
bool smtTypeIsABitVector(MlirType type) {
bool mlirSMTTypeIsABitVector(MlirType type) {
return isa<BitVectorType>(unwrap(type));
}
MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
return wrap(BitVectorType::get(unwrap(ctx), width));
}
bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
MlirType smtTypeGetBool(MlirContext ctx) {
MlirType mlirSMTTypeGetBool(MlirContext ctx) {
return wrap(BoolType::get(unwrap(ctx)));
}
bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
MlirType smtTypeGetInt(MlirContext ctx) {
MlirType mlirSMTTypeGetInt(MlirContext ctx) {
return wrap(IntType::get(unwrap(ctx)));
}
bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
bool mlirSMTTypeIsASMTFunc(MlirType type) {
return isa<SMTFuncType>(unwrap(type));
}
MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
const MlirType *domainTypes, MlirType rangeType) {
MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType) {
SmallVector<Type> domainTypesVec;
domainTypesVec.reserve(numberOfDomainTypes);
@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
}
bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
size_t numberOfSortParams, const MlirType *sortParams) {
MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams) {
SmallVector<Type> sortParamsVec;
sortParamsVec.reserve(numberOfSortParams);
@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
// Attribute API.
//===----------------------------------------------------------------------===//
bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
}
bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeIntPredicate(unwrap(str)).has_value();
}
bool smtAttrIsASMTAttribute(MlirAttribute attr) {
bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
}
MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
unsigned width) {
MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
unsigned width) {
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
}
MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");
return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
}
MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeIntPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");

View File

@ -19,9 +19,24 @@
using namespace mlir;
MlirLogicalResult mlirExportSMTLIB(MlirModule module,
MlirStringCallback callback,
void *userData) {
MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
MlirStringCallback callback,
void *userData,
bool inlineSingleUseValues,
bool indentLetBody) {
mlir::detail::CallbackOstream stream(callback, userData);
smt::SMTEmissionOptions options;
options.inlineSingleUseValues = inlineSingleUseValues;
options.indentLetBody = indentLetBody;
return wrap(smt::exportSMTLIB(unwrap(module), stream));
}
MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module,
MlirStringCallback callback,
void *userData,
bool inlineSingleUseValues,
bool indentLetBody) {
return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
callback, userData,
inlineSingleUseValues, indentLetBody);
}

View File

@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings(
"../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SMTOps.td
GEN_ENUM_BINDINGS
SOURCES
dialects/smt.py
DIALECT_NAME smt)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@ -664,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
MLIRCAPILinalg
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
MODULE_NAME _mlirDialectsSMT
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
ROOT_DIR "${PYTHON_SOURCE_DIR}"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectSMT.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPISMT
MLIRCAPIExportSMTLIB
)
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MODULE_NAME _mlirSparseTensorPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor

View File

@ -0,0 +1,14 @@
//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef BINDINGS_PYTHON_SMT_OPS
#define BINDINGS_PYTHON_SMT_OPS
include "mlir/Dialect/SMT/IR/SMT.td"
#endif // BINDINGS_PYTHON_SMT_OPS

View File

@ -0,0 +1,33 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._smt_ops_gen import *
from .._mlir_libs._mlirDialectsSMT import *
from ..extras.meta import region_op
def bool_t():
return BoolType.get()
def bv_t(width):
return BitVectorType.get(width)
def _solver(
inputs=None,
results=None,
loc=None,
ip=None,
):
if inputs is None:
inputs = []
if results is None:
results = []
return SolverOp(results, inputs, loc=loc, ip=ip)
solver = region_op(_solver, terminator=YieldOp)

View File

@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) {
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT));
MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL);
MlirLogicalResult result =
mlirTranslateModuleToSMTLIB(module, dumpCallback, NULL, false, false);
(void)result;
assert(mlirLogicalResultIsSuccess(result));
@ -44,13 +45,13 @@ void testExportSMTLIB(MlirContext ctx) {
}
void testSMTType(MlirContext ctx) {
MlirType boolType = smtTypeGetBool(ctx);
MlirType intType = smtTypeGetInt(ctx);
MlirType arrayType = smtTypeGetArray(ctx, intType, boolType);
MlirType bvType = smtTypeGetBitVector(ctx, 32);
MlirType boolType = mlirSMTTypeGetBool(ctx);
MlirType intType = mlirSMTTypeGetInt(ctx);
MlirType arrayType = mlirSMTTypeGetArray(ctx, intType, boolType);
MlirType bvType = mlirSMTTypeGetBitVector(ctx, 32);
MlirType funcType =
smtTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
MlirType sortType = smtTypeGetSort(
mlirSMTTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
MlirType sortType = mlirSMTTypeGetSort(
ctx, mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sort")), 0,
NULL);
@ -68,107 +69,107 @@ void testSMTType(MlirContext ctx) {
mlirTypeDump(sortType);
// CHECK: bool_is_any_non_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(boolType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(boolType)
? "bool_is_any_non_func_smt_value_type\n"
: "bool_is_func_smt_value_type\n");
// CHECK: int_is_any_non_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(intType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(intType)
? "int_is_any_non_func_smt_value_type\n"
: "int_is_func_smt_value_type\n");
// CHECK: array_is_any_non_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(arrayType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(arrayType)
? "array_is_any_non_func_smt_value_type\n"
: "array_is_func_smt_value_type\n");
// CHECK: bit_vector_is_any_non_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(bvType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(bvType)
? "bit_vector_is_any_non_func_smt_value_type\n"
: "bit_vector_is_func_smt_value_type\n");
// CHECK: sort_is_any_non_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(sortType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(sortType)
? "sort_is_any_non_func_smt_value_type\n"
: "sort_is_func_smt_value_type\n");
// CHECK: smt_func_is_func_smt_value_type
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(funcType)
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(funcType)
? "smt_func_is_any_non_func_smt_value_type\n"
: "smt_func_is_func_smt_value_type\n");
// CHECK: bool_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(boolType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(boolType)
? "bool_is_any_smt_value_type\n"
: "bool_is_not_any_smt_value_type\n");
// CHECK: int_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(intType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(intType)
? "int_is_any_smt_value_type\n"
: "int_is_not_any_smt_value_type\n");
// CHECK: array_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(arrayType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(arrayType)
? "array_is_any_smt_value_type\n"
: "array_is_not_any_smt_value_type\n");
// CHECK: array_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(bvType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(bvType)
? "array_is_any_smt_value_type\n"
: "array_is_not_any_smt_value_type\n");
// CHECK: smt_func_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(funcType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(funcType)
? "smt_func_is_any_smt_value_type\n"
: "smt_func_is_not_any_smt_value_type\n");
// CHECK: sort_is_any_smt_value_type
fprintf(stderr, smtTypeIsAnySMTValueType(sortType)
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(sortType)
? "sort_is_any_smt_value_type\n"
: "sort_is_not_any_smt_value_type\n");
// CHECK: int_type_is_not_a_bool
fprintf(stderr, smtTypeIsABool(intType) ? "int_type_is_a_bool\n"
: "int_type_is_not_a_bool\n");
fprintf(stderr, mlirSMTTypeIsABool(intType) ? "int_type_is_a_bool\n"
: "int_type_is_not_a_bool\n");
// CHECK: bool_type_is_not_a_int
fprintf(stderr, smtTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
: "bool_type_is_not_a_int\n");
fprintf(stderr, mlirSMTTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
: "bool_type_is_not_a_int\n");
// CHECK: bv_type_is_not_a_array
fprintf(stderr, smtTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
: "bv_type_is_not_a_array\n");
fprintf(stderr, mlirSMTTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
: "bv_type_is_not_a_array\n");
// CHECK: array_type_is_not_a_bit_vector
fprintf(stderr, smtTypeIsABitVector(arrayType)
fprintf(stderr, mlirSMTTypeIsABitVector(arrayType)
? "array_type_is_a_bit_vector\n"
: "array_type_is_not_a_bit_vector\n");
// CHECK: sort_type_is_not_a_smt_func
fprintf(stderr, smtTypeIsASMTFunc(sortType)
fprintf(stderr, mlirSMTTypeIsASMTFunc(sortType)
? "sort_type_is_a_smt_func\n"
: "sort_type_is_not_a_smt_func\n");
// CHECK: func_type_is_not_a_sort
fprintf(stderr, smtTypeIsASort(funcType) ? "func_type_is_a_sort\n"
: "func_type_is_not_a_sort\n");
fprintf(stderr, mlirSMTTypeIsASort(funcType) ? "func_type_is_a_sort\n"
: "func_type_is_not_a_sort\n");
}
void testSMTAttribute(MlirContext ctx) {
// CHECK: slt_is_BVCmpPredicate
fprintf(stderr,
smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))
? "slt_is_BVCmpPredicate\n"
: "slt_is_not_BVCmpPredicate\n");
fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
ctx, mlirStringRefCreateFromCString("slt"))
? "slt_is_BVCmpPredicate\n"
: "slt_is_not_BVCmpPredicate\n");
// CHECK: lt_is_not_BVCmpPredicate
fprintf(stderr,
smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("lt"))
? "lt_is_BVCmpPredicate\n"
: "lt_is_not_BVCmpPredicate\n");
fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
ctx, mlirStringRefCreateFromCString("lt"))
? "lt_is_BVCmpPredicate\n"
: "lt_is_not_BVCmpPredicate\n");
// CHECK: slt_is_not_IntPredicate
fprintf(stderr,
smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("slt"))
? "slt_is_IntPredicate\n"
: "slt_is_not_IntPredicate\n");
fprintf(stderr, mlirSMTAttrCheckIntPredicate(
ctx, mlirStringRefCreateFromCString("slt"))
? "slt_is_IntPredicate\n"
: "slt_is_not_IntPredicate\n");
// CHECK: lt_is_IntPredicate
fprintf(stderr,
smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))
? "lt_is_IntPredicate\n"
: "lt_is_not_IntPredicate\n");
fprintf(stderr, mlirSMTAttrCheckIntPredicate(
ctx, mlirStringRefCreateFromCString("lt"))
? "lt_is_IntPredicate\n"
: "lt_is_not_IntPredicate\n");
// CHECK: #smt.bv<5> : !smt.bv<32>
mlirAttributeDump(smtAttrGetBitVector(ctx, 5, 32));
mlirAttributeDump(mlirSMTAttrGetBitVector(ctx, 5, 32));
// CHECK: 0 : i64
mlirAttributeDump(
smtAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
mlirSMTAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
// CHECK: 0 : i64
mlirAttributeDump(
smtAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
mlirSMTAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
}
int main(void) {

View File

@ -0,0 +1,87 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.dialects import smt, arith
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
def run(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f(module)
print(module)
assert module.operation.verify()
# CHECK-LABEL: TEST: test_smoke
@run
def test_smoke(_module):
true = smt.constant(True)
false = smt.constant(False)
# CHECK: smt.constant true
# CHECK: smt.constant false
# CHECK-LABEL: TEST: test_types
@run
def test_types(_module):
bool_t = smt.bool_t()
bitvector_t = smt.bv_t(5)
# CHECK: !smt.bool
print(bool_t)
# CHECK: !smt.bv<5>
print(bitvector_t)
# CHECK-LABEL: TEST: test_solver_op
@run
def test_solver_op(_module):
@smt.solver
def foo1():
true = smt.constant(True)
false = smt.constant(False)
# CHECK: smt.solver() : () -> () {
# CHECK: %true = smt.constant true
# CHECK: %false = smt.constant false
# CHECK: }
f32 = F32Type.get()
@smt.solver(results=[f32])
def foo2():
return arith.ConstantOp(f32, 1.0)
# CHECK: %{{.*}} = smt.solver() : () -> f32 {
# CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
# CHECK: smt.yield %[[CST1]] : f32
# CHECK: }
two = arith.ConstantOp(f32, 2.0)
# CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
print(two)
@smt.solver(inputs=[two], results=[f32])
def foo3(z: f32):
return z
# CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
# CHECK: ^bb0(%[[ARG0:.*]]: f32):
# CHECK: smt.yield %[[ARG0]] : f32
# CHECK: }
# CHECK-LABEL: TEST: test_export_smtlib
@run
def test_export_smtlib(module):
@smt.solver
def foo1():
true = smt.constant(True)
smt.assert_(true)
query = smt.export_smtlib(module.operation)
# CHECK: ; solver scope 0
# CHECK: (assert true)
# CHECK: (reset)
print(query)