[mlir][emitc] Restrict integer and float types (#85788)
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
This commit is contained in:
parent
972f65a83f
commit
647d75d3a8
@ -30,6 +30,10 @@
|
||||
namespace mlir {
|
||||
namespace emitc {
|
||||
void buildTerminatedBody(OpBuilder &builder, Location loc);
|
||||
/// Determines whether \p type is a valid integer type in EmitC.
|
||||
bool isSupportedIntegerType(mlir::Type type);
|
||||
/// Determines whether \p type is a valid floating-point type in EmitC.
|
||||
bool isSupportedFloatType(mlir::Type type);
|
||||
} // namespace emitc
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
|
||||
def CExpression : NativeOpTrait<"emitc::CExpression">;
|
||||
|
||||
// Types only used in binary arithmetic operations.
|
||||
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
|
||||
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
|
||||
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
|
||||
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
|
||||
|
||||
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
|
||||
let summary = "Addition operation";
|
||||
|
||||
@ -22,6 +22,12 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
|
||||
// EmitC type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
|
||||
"integer type supported by EmitC">;
|
||||
|
||||
def EmitCFloatType : Type<CPred<"emitc::isSupportedFloatType($_self)">,
|
||||
"floating-point type supported by EmitC">;
|
||||
|
||||
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
|
||||
: TypeDef<EmitC_Dialect, name, traits> {
|
||||
let mnemonic = typeMnemonic;
|
||||
|
||||
@ -54,6 +54,35 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
|
||||
builder.create<emitc::YieldOp>(loc);
|
||||
}
|
||||
|
||||
bool mlir::emitc::isSupportedIntegerType(Type type) {
|
||||
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
case 8:
|
||||
case 16:
|
||||
case 32:
|
||||
case 64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool mlir::emitc::isSupportedFloatType(Type type) {
|
||||
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
|
||||
switch (floatType.getWidth()) {
|
||||
case 32:
|
||||
case 64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Check that the type of the initial value is compatible with the operations
|
||||
/// result type.
|
||||
static LogicalResult verifyInitializationAttribute(Operation *op,
|
||||
|
||||
@ -170,7 +170,7 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
|
||||
// -----
|
||||
|
||||
func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
return
|
||||
}
|
||||
@ -178,7 +178,7 @@ func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// -----
|
||||
|
||||
func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
return
|
||||
}
|
||||
@ -186,7 +186,7 @@ func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// -----
|
||||
|
||||
func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
|
||||
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
return
|
||||
}
|
||||
@ -194,7 +194,7 @@ func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
// -----
|
||||
|
||||
func.func @rem_float(%arg0: f32, %arg1: f32) {
|
||||
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}}
|
||||
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'f32'}}
|
||||
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,3 +81,19 @@ func.func @illegal_array_with_tensor_element_type(
|
||||
%arg0: !emitc.array<4xtensor<4xi32>>
|
||||
) {
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @illegal_integer_type(%arg0: i11, %arg1: i11) -> i11 {
|
||||
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'i11'}}
|
||||
%mul = "emitc.mul" (%arg0, %arg1) : (i11, i11) -> i11
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
|
||||
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'f80'}}
|
||||
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user