[MLIR][Python] Use ir.Value directly instead of _SubClassValueT (#82341)
_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT: ... here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.
This commit is contained in:
parent
6d160a49c2
commit
6ce5159945
@ -10,4 +10,4 @@ class _Globals:
|
||||
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
|
||||
|
||||
def register_dialect(dialect_class: type) -> object: ...
|
||||
def register_operation(dialect_class: type) -> object: ...
|
||||
def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
|
||||
|
@ -8,7 +8,6 @@ from typing import (
|
||||
Sequence as _Sequence,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
TypeVar as _TypeVar,
|
||||
Union as _Union,
|
||||
)
|
||||
|
||||
@ -143,12 +142,6 @@ def get_op_result_or_op_results(
|
||||
else op
|
||||
)
|
||||
|
||||
|
||||
# This is the standard way to indicate subclass/inheritance relationship
|
||||
# see the typing.Type doc string.
|
||||
_U = _TypeVar("_U", bound=_cext.ir.Value)
|
||||
SubClassValueT = _Type[_U]
|
||||
|
||||
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
|
||||
ResultValueT = _Union[ResultValueTypeTuple]
|
||||
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
|
||||
|
@ -12,7 +12,6 @@ try:
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
get_op_result_or_op_results as _get_op_result_or_op_results,
|
||||
SubClassValueT as _SubClassValueT,
|
||||
)
|
||||
|
||||
from typing import Any, List, Union
|
||||
@ -81,5 +80,5 @@ class ConstantOp(ConstantOp):
|
||||
|
||||
def constant(
|
||||
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
|
||||
) -> _SubClassValueT:
|
||||
) -> Value:
|
||||
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
|
||||
|
@ -7,7 +7,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
// CHECK: @_ods_cext.register_dialect
|
||||
// CHECK: class _Dialect(_ods_ir.Dialect):
|
||||
// CHECK: DIALECT_NAMESPACE = "test"
|
||||
// CHECK: pass
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "Test";
|
||||
|
@ -3,7 +3,6 @@
|
||||
import gc
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import func
|
||||
from mlir.dialects._ods_common import SubClassValueT
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -270,7 +269,7 @@ def testValueCasters():
|
||||
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
|
||||
|
||||
@register_value_caster(IntegerType.static_typeid)
|
||||
def cast_int(v) -> SubClassValueT:
|
||||
def cast_int(v) -> Value:
|
||||
print("in caster", v.__class__.__name__)
|
||||
if isinstance(v, OpResult):
|
||||
return NOPResult(v)
|
||||
|
@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import (
|
||||
SubClassValueT as _SubClassValueT,
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_result_or_op_results as _get_op_result_or_op_results,
|
||||
@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
|
||||
@_ods_cext.register_dialect
|
||||
class _Dialect(_ods_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "{0}"
|
||||
pass
|
||||
|
||||
)Py";
|
||||
|
||||
constexpr const char *dialectExtensionTemplate = R"Py(
|
||||
@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
|
||||
});
|
||||
std::string nameWithoutDialect =
|
||||
op.getOperationName().substr(op.getOperationName().find('.') + 1);
|
||||
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
||||
op.getCppClassName(),
|
||||
llvm::join(valueBuilderParams, ", "),
|
||||
llvm::join(opBuilderArgs, ", "),
|
||||
(op.getNumResults() > 1
|
||||
? "_Sequence[_SubClassValueT]"
|
||||
: (op.getNumResults() > 0 ? "_SubClassValueT"
|
||||
: "_ods_ir.Operation")));
|
||||
os << llvm::formatv(
|
||||
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
||||
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
|
||||
llvm::join(opBuilderArgs, ", "),
|
||||
(op.getNumResults() > 1
|
||||
? "_Sequence[_ods_ir.Value]"
|
||||
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
|
||||
}
|
||||
|
||||
/// Emits bindings for a specific Op to the given output stream.
|
||||
|
Loading…
x
Reference in New Issue
Block a user