[mlir][SMT] add python bindings (#135674)
This PR adds "rich" python bindings to SMT dialect.
This commit is contained in:
parent
7623501c05
commit
697aa9995c
@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks if the given type is any non-func SMT value type.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);
|
||||
|
||||
/// Checks if the given type is any SMT value type.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);
|
||||
|
||||
/// Checks if the given type is a smt::ArrayType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);
|
||||
|
||||
/// Creates an array type with the given domain and range types.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
|
||||
MlirType domainType,
|
||||
MlirType rangeType);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
|
||||
MlirType domainType,
|
||||
MlirType rangeType);
|
||||
|
||||
/// Checks if the given type is a smt::BitVectorType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);
|
||||
|
||||
/// Creates a smt::BitVectorType with the given width.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
|
||||
int32_t width);
|
||||
|
||||
/// Checks if the given type is a smt::BoolType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);
|
||||
|
||||
/// Creates a smt::BoolType.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);
|
||||
|
||||
/// Checks if the given type is a smt::IntType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);
|
||||
|
||||
/// Creates a smt::IntType.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);
|
||||
|
||||
/// Checks if the given type is a smt::FuncType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);
|
||||
|
||||
/// Creates a smt::FuncType with the given domain and range types.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
|
||||
size_t numberOfDomainTypes,
|
||||
const MlirType *domainTypes,
|
||||
MlirType rangeType);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
|
||||
size_t numberOfDomainTypes,
|
||||
const MlirType *domainTypes,
|
||||
MlirType rangeType);
|
||||
|
||||
/// Checks if the given type is a smt::SortType.
|
||||
MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);
|
||||
|
||||
/// Creates a smt::SortType with the given identifier and sort parameters.
|
||||
MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
|
||||
MlirIdentifier identifier,
|
||||
size_t numberOfSortParams,
|
||||
const MlirType *sortParams);
|
||||
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
|
||||
MlirIdentifier identifier,
|
||||
size_t numberOfSortParams,
|
||||
const MlirType *sortParams);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks if the given string is a valid smt::BVCmpPredicate.
|
||||
MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
|
||||
/// Checks if the given string is a valid smt::IntPredicate.
|
||||
MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
|
||||
/// Checks if the given attribute is a smt::SMTAttribute.
|
||||
MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);
|
||||
|
||||
/// Creates a smt::BitVectorAttr with the given value and width.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
|
||||
uint64_t value,
|
||||
unsigned width);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
|
||||
uint64_t value,
|
||||
unsigned width);
|
||||
|
||||
/// Creates a smt::BVCmpPredicateAttr with the given string.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);
|
||||
|
||||
/// Creates a smt::IntPredicateAttr with the given string.
|
||||
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
|
||||
MlirStringRef str);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -21,9 +21,13 @@ extern "C" {
|
||||
|
||||
/// Emits SMTLIB for the specified module using the provided callback and user
|
||||
/// data
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
|
||||
MlirStringCallback,
|
||||
void *userData);
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult
|
||||
mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData,
|
||||
bool inlineSingleUseValues, bool indentLetBody);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
|
||||
MlirOperation, MlirStringCallback, void *userData,
|
||||
bool inlineSingleUseValues, bool indentLetBody);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
83
mlir/lib/Bindings/Python/DialectSMT.cpp
Normal file
83
mlir/lib/Bindings/Python/DialectSMT.cpp
Normal file
@ -0,0 +1,83 @@
|
||||
//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "NanobindUtils.h"
|
||||
|
||||
#include "mlir-c/Dialect/SMT.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir-c/Target/ExportSMTLIB.h"
|
||||
#include "mlir/Bindings/Python/Diagnostics.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
using namespace nanobind::literals;
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
using namespace mlir::python::nanobind_adaptors;
|
||||
|
||||
void populateDialectSMTSubmodule(nanobind::module_ &m) {
|
||||
|
||||
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](const nb::object &, MlirContext context) {
|
||||
return mlirSMTTypeGetBool(context);
|
||||
},
|
||||
"cls"_a, "context"_a.none() = nb::none());
|
||||
auto smtBitVectorType =
|
||||
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](const nb::object &, int32_t width, MlirContext context) {
|
||||
return mlirSMTTypeGetBitVector(context, width);
|
||||
},
|
||||
"cls"_a, "width"_a, "context"_a.none() = nb::none());
|
||||
|
||||
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
mlir::python::CollectDiagnosticsToStringScope scope(
|
||||
mlirOperationGetContext(module));
|
||||
PyPrintAccumulator printAccum;
|
||||
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
|
||||
module, printAccum.getCallback(), printAccum.getUserData(),
|
||||
inlineSingleUseValues, indentLetBody);
|
||||
if (mlirLogicalResultIsSuccess(result))
|
||||
return printAccum.join();
|
||||
throw nb::value_error(
|
||||
("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
|
||||
.c_str());
|
||||
};
|
||||
|
||||
m.def(
|
||||
"export_smtlib",
|
||||
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
|
||||
},
|
||||
"module"_a, "inline_single_use_values"_a = false,
|
||||
"indent_let_body"_a = false);
|
||||
m.def(
|
||||
"export_smtlib",
|
||||
[&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
return exportSMTLIB(mlirModuleGetOperation(module),
|
||||
inlineSingleUseValues, indentLetBody);
|
||||
},
|
||||
"module"_a, "inline_single_use_values"_a = false,
|
||||
"indent_let_body"_a = false);
|
||||
}
|
||||
|
||||
NB_MODULE(_mlirDialectsSMT, m) {
|
||||
m.doc() = "MLIR SMT Dialect";
|
||||
|
||||
populateDialectSMTSubmodule(m);
|
||||
}
|
@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
|
||||
// Type API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
|
||||
bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
|
||||
return isAnyNonFuncSMTValueType(unwrap(type));
|
||||
}
|
||||
|
||||
bool smtTypeIsAnySMTValueType(MlirType type) {
|
||||
bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
|
||||
return isAnySMTValueType(unwrap(type));
|
||||
}
|
||||
|
||||
bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
|
||||
bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
|
||||
|
||||
MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
|
||||
MlirType rangeType) {
|
||||
MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
|
||||
MlirType rangeType) {
|
||||
return wrap(
|
||||
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
|
||||
}
|
||||
|
||||
bool smtTypeIsABitVector(MlirType type) {
|
||||
bool mlirSMTTypeIsABitVector(MlirType type) {
|
||||
return isa<BitVectorType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
|
||||
MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
|
||||
return wrap(BitVectorType::get(unwrap(ctx), width));
|
||||
}
|
||||
|
||||
bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
|
||||
bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
|
||||
|
||||
MlirType smtTypeGetBool(MlirContext ctx) {
|
||||
MlirType mlirSMTTypeGetBool(MlirContext ctx) {
|
||||
return wrap(BoolType::get(unwrap(ctx)));
|
||||
}
|
||||
|
||||
bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
|
||||
bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
|
||||
|
||||
MlirType smtTypeGetInt(MlirContext ctx) {
|
||||
MlirType mlirSMTTypeGetInt(MlirContext ctx) {
|
||||
return wrap(IntType::get(unwrap(ctx)));
|
||||
}
|
||||
|
||||
bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
|
||||
bool mlirSMTTypeIsASMTFunc(MlirType type) {
|
||||
return isa<SMTFuncType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
|
||||
const MlirType *domainTypes, MlirType rangeType) {
|
||||
MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
|
||||
const MlirType *domainTypes,
|
||||
MlirType rangeType) {
|
||||
SmallVector<Type> domainTypesVec;
|
||||
domainTypesVec.reserve(numberOfDomainTypes);
|
||||
|
||||
@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
|
||||
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
|
||||
}
|
||||
|
||||
bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
|
||||
bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
|
||||
|
||||
MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
|
||||
size_t numberOfSortParams, const MlirType *sortParams) {
|
||||
MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
|
||||
size_t numberOfSortParams,
|
||||
const MlirType *sortParams) {
|
||||
SmallVector<Type> sortParamsVec;
|
||||
sortParamsVec.reserve(numberOfSortParams);
|
||||
|
||||
@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
|
||||
// Attribute API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
|
||||
}
|
||||
|
||||
bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
return symbolizeIntPredicate(unwrap(str)).has_value();
|
||||
}
|
||||
|
||||
bool smtAttrIsASMTAttribute(MlirAttribute attr) {
|
||||
bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
|
||||
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
|
||||
}
|
||||
|
||||
MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
|
||||
unsigned width) {
|
||||
MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
|
||||
unsigned width) {
|
||||
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
|
||||
}
|
||||
|
||||
MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
|
||||
assert(predicate.has_value() && "invalid predicate");
|
||||
|
||||
return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
|
||||
}
|
||||
|
||||
MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
|
||||
auto predicate = symbolizeIntPredicate(unwrap(str));
|
||||
assert(predicate.has_value() && "invalid predicate");
|
||||
|
||||
|
@ -19,9 +19,24 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
MlirLogicalResult mlirExportSMTLIB(MlirModule module,
|
||||
MlirStringCallback callback,
|
||||
void *userData) {
|
||||
MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
|
||||
MlirStringCallback callback,
|
||||
void *userData,
|
||||
bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
mlir::detail::CallbackOstream stream(callback, userData);
|
||||
smt::SMTEmissionOptions options;
|
||||
options.inlineSingleUseValues = inlineSingleUseValues;
|
||||
options.indentLetBody = indentLetBody;
|
||||
return wrap(smt::exportSMTLIB(unwrap(module), stream));
|
||||
}
|
||||
|
||||
MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module,
|
||||
MlirStringCallback callback,
|
||||
void *userData,
|
||||
bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
|
||||
callback, userData,
|
||||
inlineSingleUseValues, indentLetBody);
|
||||
}
|
||||
|
@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings(
|
||||
"../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/SMTOps.td
|
||||
GEN_ENUM_BINDINGS
|
||||
SOURCES
|
||||
dialects/smt.py
|
||||
DIALECT_NAME smt)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
@ -664,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
|
||||
MLIRCAPILinalg
|
||||
)
|
||||
|
||||
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
|
||||
MODULE_NAME _mlirDialectsSMT
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
|
||||
ROOT_DIR "${PYTHON_SOURCE_DIR}"
|
||||
PYTHON_BINDINGS_LIBRARY nanobind
|
||||
SOURCES
|
||||
DialectSMT.cpp
|
||||
PRIVATE_LINK_LIBS
|
||||
LLVMSupport
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIIR
|
||||
MLIRCAPISMT
|
||||
MLIRCAPIExportSMTLIB
|
||||
)
|
||||
|
||||
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
|
||||
MODULE_NAME _mlirSparseTensorPasses
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
|
||||
|
14
mlir/python/mlir/dialects/SMTOps.td
Normal file
14
mlir/python/mlir/dialects/SMTOps.td
Normal file
@ -0,0 +1,14 @@
|
||||
//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef BINDINGS_PYTHON_SMT_OPS
|
||||
#define BINDINGS_PYTHON_SMT_OPS
|
||||
|
||||
include "mlir/Dialect/SMT/IR/SMT.td"
|
||||
|
||||
#endif // BINDINGS_PYTHON_SMT_OPS
|
33
mlir/python/mlir/dialects/smt.py
Normal file
33
mlir/python/mlir/dialects/smt.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._smt_ops_gen import *
|
||||
|
||||
from .._mlir_libs._mlirDialectsSMT import *
|
||||
from ..extras.meta import region_op
|
||||
|
||||
|
||||
def bool_t():
|
||||
return BoolType.get()
|
||||
|
||||
|
||||
def bv_t(width):
|
||||
return BitVectorType.get(width)
|
||||
|
||||
|
||||
def _solver(
|
||||
inputs=None,
|
||||
results=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
if results is None:
|
||||
results = []
|
||||
|
||||
return SolverOp(results, inputs, loc=loc, ip=ip)
|
||||
|
||||
|
||||
solver = region_op(_solver, terminator=YieldOp)
|
@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) {
|
||||
MlirModule module =
|
||||
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT));
|
||||
|
||||
MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL);
|
||||
MlirLogicalResult result =
|
||||
mlirTranslateModuleToSMTLIB(module, dumpCallback, NULL, false, false);
|
||||
(void)result;
|
||||
assert(mlirLogicalResultIsSuccess(result));
|
||||
|
||||
@ -44,13 +45,13 @@ void testExportSMTLIB(MlirContext ctx) {
|
||||
}
|
||||
|
||||
void testSMTType(MlirContext ctx) {
|
||||
MlirType boolType = smtTypeGetBool(ctx);
|
||||
MlirType intType = smtTypeGetInt(ctx);
|
||||
MlirType arrayType = smtTypeGetArray(ctx, intType, boolType);
|
||||
MlirType bvType = smtTypeGetBitVector(ctx, 32);
|
||||
MlirType boolType = mlirSMTTypeGetBool(ctx);
|
||||
MlirType intType = mlirSMTTypeGetInt(ctx);
|
||||
MlirType arrayType = mlirSMTTypeGetArray(ctx, intType, boolType);
|
||||
MlirType bvType = mlirSMTTypeGetBitVector(ctx, 32);
|
||||
MlirType funcType =
|
||||
smtTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
|
||||
MlirType sortType = smtTypeGetSort(
|
||||
mlirSMTTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
|
||||
MlirType sortType = mlirSMTTypeGetSort(
|
||||
ctx, mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sort")), 0,
|
||||
NULL);
|
||||
|
||||
@ -68,107 +69,107 @@ void testSMTType(MlirContext ctx) {
|
||||
mlirTypeDump(sortType);
|
||||
|
||||
// CHECK: bool_is_any_non_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(boolType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(boolType)
|
||||
? "bool_is_any_non_func_smt_value_type\n"
|
||||
: "bool_is_func_smt_value_type\n");
|
||||
// CHECK: int_is_any_non_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(intType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(intType)
|
||||
? "int_is_any_non_func_smt_value_type\n"
|
||||
: "int_is_func_smt_value_type\n");
|
||||
// CHECK: array_is_any_non_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(arrayType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(arrayType)
|
||||
? "array_is_any_non_func_smt_value_type\n"
|
||||
: "array_is_func_smt_value_type\n");
|
||||
// CHECK: bit_vector_is_any_non_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(bvType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(bvType)
|
||||
? "bit_vector_is_any_non_func_smt_value_type\n"
|
||||
: "bit_vector_is_func_smt_value_type\n");
|
||||
// CHECK: sort_is_any_non_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(sortType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(sortType)
|
||||
? "sort_is_any_non_func_smt_value_type\n"
|
||||
: "sort_is_func_smt_value_type\n");
|
||||
// CHECK: smt_func_is_func_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(funcType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(funcType)
|
||||
? "smt_func_is_any_non_func_smt_value_type\n"
|
||||
: "smt_func_is_func_smt_value_type\n");
|
||||
|
||||
// CHECK: bool_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(boolType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(boolType)
|
||||
? "bool_is_any_smt_value_type\n"
|
||||
: "bool_is_not_any_smt_value_type\n");
|
||||
// CHECK: int_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(intType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(intType)
|
||||
? "int_is_any_smt_value_type\n"
|
||||
: "int_is_not_any_smt_value_type\n");
|
||||
// CHECK: array_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(arrayType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(arrayType)
|
||||
? "array_is_any_smt_value_type\n"
|
||||
: "array_is_not_any_smt_value_type\n");
|
||||
// CHECK: array_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(bvType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(bvType)
|
||||
? "array_is_any_smt_value_type\n"
|
||||
: "array_is_not_any_smt_value_type\n");
|
||||
// CHECK: smt_func_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(funcType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(funcType)
|
||||
? "smt_func_is_any_smt_value_type\n"
|
||||
: "smt_func_is_not_any_smt_value_type\n");
|
||||
// CHECK: sort_is_any_smt_value_type
|
||||
fprintf(stderr, smtTypeIsAnySMTValueType(sortType)
|
||||
fprintf(stderr, mlirSMTTypeIsAnySMTValueType(sortType)
|
||||
? "sort_is_any_smt_value_type\n"
|
||||
: "sort_is_not_any_smt_value_type\n");
|
||||
|
||||
// CHECK: int_type_is_not_a_bool
|
||||
fprintf(stderr, smtTypeIsABool(intType) ? "int_type_is_a_bool\n"
|
||||
: "int_type_is_not_a_bool\n");
|
||||
fprintf(stderr, mlirSMTTypeIsABool(intType) ? "int_type_is_a_bool\n"
|
||||
: "int_type_is_not_a_bool\n");
|
||||
// CHECK: bool_type_is_not_a_int
|
||||
fprintf(stderr, smtTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
|
||||
: "bool_type_is_not_a_int\n");
|
||||
fprintf(stderr, mlirSMTTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
|
||||
: "bool_type_is_not_a_int\n");
|
||||
// CHECK: bv_type_is_not_a_array
|
||||
fprintf(stderr, smtTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
|
||||
: "bv_type_is_not_a_array\n");
|
||||
fprintf(stderr, mlirSMTTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
|
||||
: "bv_type_is_not_a_array\n");
|
||||
// CHECK: array_type_is_not_a_bit_vector
|
||||
fprintf(stderr, smtTypeIsABitVector(arrayType)
|
||||
fprintf(stderr, mlirSMTTypeIsABitVector(arrayType)
|
||||
? "array_type_is_a_bit_vector\n"
|
||||
: "array_type_is_not_a_bit_vector\n");
|
||||
// CHECK: sort_type_is_not_a_smt_func
|
||||
fprintf(stderr, smtTypeIsASMTFunc(sortType)
|
||||
fprintf(stderr, mlirSMTTypeIsASMTFunc(sortType)
|
||||
? "sort_type_is_a_smt_func\n"
|
||||
: "sort_type_is_not_a_smt_func\n");
|
||||
// CHECK: func_type_is_not_a_sort
|
||||
fprintf(stderr, smtTypeIsASort(funcType) ? "func_type_is_a_sort\n"
|
||||
: "func_type_is_not_a_sort\n");
|
||||
fprintf(stderr, mlirSMTTypeIsASort(funcType) ? "func_type_is_a_sort\n"
|
||||
: "func_type_is_not_a_sort\n");
|
||||
}
|
||||
|
||||
void testSMTAttribute(MlirContext ctx) {
|
||||
// CHECK: slt_is_BVCmpPredicate
|
||||
fprintf(stderr,
|
||||
smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))
|
||||
? "slt_is_BVCmpPredicate\n"
|
||||
: "slt_is_not_BVCmpPredicate\n");
|
||||
fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
|
||||
ctx, mlirStringRefCreateFromCString("slt"))
|
||||
? "slt_is_BVCmpPredicate\n"
|
||||
: "slt_is_not_BVCmpPredicate\n");
|
||||
// CHECK: lt_is_not_BVCmpPredicate
|
||||
fprintf(stderr,
|
||||
smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("lt"))
|
||||
? "lt_is_BVCmpPredicate\n"
|
||||
: "lt_is_not_BVCmpPredicate\n");
|
||||
fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
|
||||
ctx, mlirStringRefCreateFromCString("lt"))
|
||||
? "lt_is_BVCmpPredicate\n"
|
||||
: "lt_is_not_BVCmpPredicate\n");
|
||||
// CHECK: slt_is_not_IntPredicate
|
||||
fprintf(stderr,
|
||||
smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("slt"))
|
||||
? "slt_is_IntPredicate\n"
|
||||
: "slt_is_not_IntPredicate\n");
|
||||
fprintf(stderr, mlirSMTAttrCheckIntPredicate(
|
||||
ctx, mlirStringRefCreateFromCString("slt"))
|
||||
? "slt_is_IntPredicate\n"
|
||||
: "slt_is_not_IntPredicate\n");
|
||||
// CHECK: lt_is_IntPredicate
|
||||
fprintf(stderr,
|
||||
smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))
|
||||
? "lt_is_IntPredicate\n"
|
||||
: "lt_is_not_IntPredicate\n");
|
||||
fprintf(stderr, mlirSMTAttrCheckIntPredicate(
|
||||
ctx, mlirStringRefCreateFromCString("lt"))
|
||||
? "lt_is_IntPredicate\n"
|
||||
: "lt_is_not_IntPredicate\n");
|
||||
|
||||
// CHECK: #smt.bv<5> : !smt.bv<32>
|
||||
mlirAttributeDump(smtAttrGetBitVector(ctx, 5, 32));
|
||||
mlirAttributeDump(mlirSMTAttrGetBitVector(ctx, 5, 32));
|
||||
// CHECK: 0 : i64
|
||||
mlirAttributeDump(
|
||||
smtAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
|
||||
mlirSMTAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
|
||||
// CHECK: 0 : i64
|
||||
mlirAttributeDump(
|
||||
smtAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
|
||||
mlirSMTAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
87
mlir/test/python/dialects/smt.py
Normal file
87
mlir/test/python/dialects/smt.py
Normal file
@ -0,0 +1,87 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.dialects import smt, arith
|
||||
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
f(module)
|
||||
print(module)
|
||||
assert module.operation.verify()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: test_smoke
|
||||
@run
|
||||
def test_smoke(_module):
|
||||
true = smt.constant(True)
|
||||
false = smt.constant(False)
|
||||
# CHECK: smt.constant true
|
||||
# CHECK: smt.constant false
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: test_types
|
||||
@run
|
||||
def test_types(_module):
|
||||
bool_t = smt.bool_t()
|
||||
bitvector_t = smt.bv_t(5)
|
||||
# CHECK: !smt.bool
|
||||
print(bool_t)
|
||||
# CHECK: !smt.bv<5>
|
||||
print(bitvector_t)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: test_solver_op
|
||||
@run
|
||||
def test_solver_op(_module):
|
||||
@smt.solver
|
||||
def foo1():
|
||||
true = smt.constant(True)
|
||||
false = smt.constant(False)
|
||||
|
||||
# CHECK: smt.solver() : () -> () {
|
||||
# CHECK: %true = smt.constant true
|
||||
# CHECK: %false = smt.constant false
|
||||
# CHECK: }
|
||||
|
||||
f32 = F32Type.get()
|
||||
|
||||
@smt.solver(results=[f32])
|
||||
def foo2():
|
||||
return arith.ConstantOp(f32, 1.0)
|
||||
|
||||
# CHECK: %{{.*}} = smt.solver() : () -> f32 {
|
||||
# CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
|
||||
# CHECK: smt.yield %[[CST1]] : f32
|
||||
# CHECK: }
|
||||
|
||||
two = arith.ConstantOp(f32, 2.0)
|
||||
# CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
|
||||
print(two)
|
||||
|
||||
@smt.solver(inputs=[two], results=[f32])
|
||||
def foo3(z: f32):
|
||||
return z
|
||||
|
||||
# CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
|
||||
# CHECK: ^bb0(%[[ARG0:.*]]: f32):
|
||||
# CHECK: smt.yield %[[ARG0]] : f32
|
||||
# CHECK: }
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: test_export_smtlib
|
||||
@run
|
||||
def test_export_smtlib(module):
|
||||
@smt.solver
|
||||
def foo1():
|
||||
true = smt.constant(True)
|
||||
smt.assert_(true)
|
||||
|
||||
query = smt.export_smtlib(module.operation)
|
||||
# CHECK: ; solver scope 0
|
||||
# CHECK: (assert true)
|
||||
# CHECK: (reset)
|
||||
print(query)
|
Loading…
x
Reference in New Issue
Block a user