[mlir][OpDSL] Add default value to index attributes.
Index attributes had no default value, which means the attribute values had to be set on the operation. This revision adds a default parameter to `IndexAttrDef`. After the change, every index attribute has to define a default value. For example, we may define the following strides attribute: ``` ``` When using the operation the default stride is used if the strides attribute is not set. The mechanism is implemented using `DefaultValuedAttr`. Additionally, the revision uses the naming index attribute instead of attribute more consistently, which is a preparation for follow up revisions that will introduce function attributes. Depends On D119125 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D119126
This commit is contained in:
parent
880e87580a
commit
d50571ab07
@ -105,7 +105,7 @@ appear in the parameter list of the operation:
|
||||
copy_and_scale(val, in_tensor, outs=[out_tensor])
|
||||
```
|
||||
|
||||
## Attributes
|
||||
## Index Attributes
|
||||
|
||||
Attributes are compile-time constant parameters only accessible in index
|
||||
expressions. They can be used to parameterize the access pattern of a structured
|
||||
@ -118,7 +118,7 @@ The following example demonstrates the use of attributes:
|
||||
@linalg_structured_op
|
||||
def strided_copy(I=TensorDef(T, S.IH, S.IW),
|
||||
O=TensorDef(T, S.OH, S.OW, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])):
|
||||
"""Copy a subset of the input tensor elements to the output tensor"""
|
||||
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
|
||||
```
|
||||
@ -129,11 +129,12 @@ the symbols `S.SH` and `S.SW`, which are used to index the input tensor `I`.
|
||||
When instantiating the operation, the attribute is set using a named argument:
|
||||
|
||||
```python
|
||||
strided_copy(in_tensor, outs=[out_tensor], strides=[1,2])
|
||||
strided_copy(in_tensor, outs=[out_tensor], strides=[1, 2])
|
||||
```
|
||||
|
||||
The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
|
||||
index expressions of the operation instance.
|
||||
index expressions of the operation instance. If no strides are provided the
|
||||
`default` vector elements are used instead.
|
||||
|
||||
Attributes are currently limited to integer vectors and only accessible in index
|
||||
expressions. An operation may have multiple attributes all of them placed at the
|
||||
@ -157,8 +158,8 @@ def pooling_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
|
||||
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||
```
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -135,7 +135,7 @@ class OperandKind(Enum):
|
||||
InputTensor = 0
|
||||
Scalar = 1
|
||||
OutputTensor = 2
|
||||
Attribute = 3
|
||||
IndexAttr = 3
|
||||
|
||||
|
||||
class OperandDef:
|
||||
@ -147,16 +147,18 @@ class OperandDef:
|
||||
|
||||
def __init__(self,
|
||||
kind: OperandKind,
|
||||
type_var: TypeVar,
|
||||
type_var: Optional[TypeVar] = None,
|
||||
size_exprs: Optional[Sequence[AffineExprDef]] = None,
|
||||
index_dims: Optional[Sequence[DimDef]] = None):
|
||||
if not isinstance(type_var, TypeVar):
|
||||
index_dims: Optional[Sequence[DimDef]] = None,
|
||||
default_vals : Optional[Sequence[int]] = None):
|
||||
if type_var and not isinstance(type_var, TypeVar):
|
||||
raise ValueError(
|
||||
f"OperandDef requires a TypeVar but got {repr(type_var)}")
|
||||
self.owner = None # type: Optional["LinalgOpDef"]
|
||||
self.type_var = type_var
|
||||
self.size_exprs = size_exprs
|
||||
self.index_dims = index_dims
|
||||
self.default_vals = default_vals
|
||||
self.kind = kind
|
||||
self.name = None # type: Optional[str]
|
||||
self.registered_index = -1 # type: int
|
||||
@ -174,7 +176,7 @@ class OperandDef:
|
||||
def __repr__(self):
|
||||
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
|
||||
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
|
||||
f"index_dims={self.index_dims})")
|
||||
f"index_dims={self.index_dims}, default_vals={self.default_vals})")
|
||||
|
||||
|
||||
class TensorDef:
|
||||
@ -202,7 +204,7 @@ class TensorDef:
|
||||
f"got {index_dims}")
|
||||
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
|
||||
self.operand_def = OperandDef(
|
||||
kind, type_var, size_exprs=shape, index_dims=index_dims)
|
||||
kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
|
||||
|
||||
def __getitem__(self, dims) -> TensorUse:
|
||||
assert self.operand_def.owner, "TensorDef is not attached to an op"
|
||||
@ -246,7 +248,7 @@ class ScalarDef(TensorExpression):
|
||||
"""
|
||||
|
||||
def __init__(self, type_var: TypeVar):
|
||||
self.operand_def = OperandDef(OperandKind.Scalar, type_var)
|
||||
self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
|
||||
|
||||
@property
|
||||
def scalar_name(self) -> str:
|
||||
@ -259,18 +261,25 @@ class ScalarDef(TensorExpression):
|
||||
|
||||
|
||||
class IndexAttrDef:
|
||||
"""Index Attribute definition.
|
||||
"""Index attribute definition.
|
||||
|
||||
Index attributes provide a way to define and set symbols that can be used in
|
||||
indexing expressions. Every attribute specifies a tuple of symbols that at
|
||||
compile-time are replaced by integer values.
|
||||
compile-time are replaced by integer values as well as their default values.
|
||||
"""
|
||||
|
||||
def __init__(self, *sizes: SymbolDef):
|
||||
def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
|
||||
if any(not isinstance(size, SymbolDef) for size in sizes):
|
||||
raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got "
|
||||
f"{sizes}")
|
||||
self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
|
||||
raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
|
||||
f"but got {sizes}")
|
||||
if any(not isinstance(default_val, int) for default_val in default):
|
||||
raise ValueError(f"IndexAttrDef requires default values of type int "
|
||||
f"but got {default}")
|
||||
if len(sizes) != len(default):
|
||||
raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
|
||||
f"but got {len(default)}")
|
||||
self.operand_def = OperandDef(
|
||||
OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
|
||||
|
||||
|
||||
class Comprehension:
|
||||
|
@ -45,10 +45,10 @@ class OperandDefConfig(YAMLObject):
|
||||
def __init__(self,
|
||||
operand_def: OperandDef,
|
||||
shape_map: Optional[_ir.AffineMap] = None,
|
||||
attribute_map: Optional[_ir.AffineMap] = None):
|
||||
index_attr_map: Optional[_ir.AffineMap] = None):
|
||||
self.operand_def = operand_def
|
||||
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
|
||||
self.attribute_map = attribute_map # type: Optional[_ir.AffineMap]
|
||||
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
|
||||
self.indexing_map = None # type: Optional[_ir.AffineMap]
|
||||
|
||||
@property
|
||||
@ -61,24 +61,28 @@ class OperandDefConfig(YAMLObject):
|
||||
|
||||
@property
|
||||
def usage(self) -> str:
|
||||
if self.operand_def.kind == OperandKind.Attribute:
|
||||
return "IndexAttribute"
|
||||
if self.operand_def.kind == OperandKind.IndexAttr:
|
||||
return "IndexAttr"
|
||||
if self.operand_def.kind == OperandKind.OutputTensor:
|
||||
return "OutputOperand"
|
||||
return "InputOperand"
|
||||
return "Output"
|
||||
return "Input"
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
self_dict = dict(
|
||||
name=self.name, usage=self.usage, type_var=self.type_var.name)
|
||||
self_dict = dict(name=self.name, usage=self.usage)
|
||||
if self.type_var:
|
||||
self_dict["type_var"] = self.type_var.name
|
||||
if self.shape_map:
|
||||
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
|
||||
if self.attribute_map:
|
||||
self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
|
||||
if self.index_attr_map:
|
||||
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
|
||||
if self.operand_def.default_vals:
|
||||
self_dict["default_vals"] = self.operand_def.default_vals
|
||||
return self_dict
|
||||
|
||||
def __repr__(self):
|
||||
return (f"OperandDefConfig({self.operand_def}, "
|
||||
f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
|
||||
f"shape_map={self.shape_map}, "
|
||||
f"index_attr_map={self.index_attr_map}, "
|
||||
f"indexing_map={self.indexing_map})")
|
||||
|
||||
|
||||
@ -162,7 +166,7 @@ class LinalgStructuredOpConfig(YAMLObject):
|
||||
# Collect all attribute definitions.
|
||||
collected_attr_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.kind == OperandKind.Attribute:
|
||||
if operand.kind == OperandKind.IndexAttr:
|
||||
collected_attr_defs.append(operand)
|
||||
|
||||
# Collect all tensors with manual indexing annotation.
|
||||
@ -210,9 +214,9 @@ class LinalgStructuredOpConfig(YAMLObject):
|
||||
if operand_config.shape_map:
|
||||
operand_config.shape_map = self._normalize_affine_map(
|
||||
operand_config.shape_map, with_dims=False)
|
||||
if operand_config.attribute_map:
|
||||
operand_config.attribute_map = self._normalize_affine_map(
|
||||
operand_config.attribute_map, with_dims=False)
|
||||
if operand_config.index_attr_map:
|
||||
operand_config.index_attr_map = self._normalize_affine_map(
|
||||
operand_config.index_attr_map, with_dims=False)
|
||||
|
||||
# Now for each write use, propagate the indexing maps from the use to the
|
||||
# tensor, ensuring that there are not conflicts.
|
||||
@ -245,7 +249,7 @@ class LinalgStructuredOpConfig(YAMLObject):
|
||||
|
||||
# Check all registered tensor and scalar operands have an indexing map.
|
||||
for operand in registered_operands:
|
||||
if operand.kind == OperandKind.Attribute:
|
||||
if operand.kind == OperandKind.IndexAttr:
|
||||
continue
|
||||
if not (operand in self.operands and self.operands[operand].indexing_map):
|
||||
raise ValueError(f"Failed to compute an indexing map for operand "
|
||||
@ -319,9 +323,9 @@ class LinalgStructuredOpConfig(YAMLObject):
|
||||
assert local_state.local_dim_count == 0
|
||||
affine_map = _ir.AffineMap.get(
|
||||
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
|
||||
if operand_def.kind == OperandKind.Attribute:
|
||||
if operand_def.kind == OperandKind.IndexAttr:
|
||||
self.operands[operand_def] = OperandDefConfig(
|
||||
operand_def, attribute_map=affine_map)
|
||||
operand_def, index_attr_map=affine_map)
|
||||
else:
|
||||
self.operands[operand_def] = OperandDefConfig(
|
||||
operand_def, shape_map=affine_map)
|
||||
|
@ -39,15 +39,14 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
*ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
all_arg_defs = op_config.ordered_operands
|
||||
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
|
||||
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
|
||||
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
|
||||
in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
|
||||
out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
|
||||
index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
|
||||
|
||||
# Verify outs is a sequence or a list of results.
|
||||
if not isinstance(outs, (Sequence, OpResultList)):
|
||||
raise ValueError(
|
||||
f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
|
||||
)
|
||||
raise ValueError(f"Expected named argument outs to have type Sequence or "
|
||||
f"OpResultLis but got {type(outs)}")
|
||||
|
||||
# Arity validation.
|
||||
if len(ins) != len(in_arg_defs):
|
||||
@ -60,18 +59,19 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
# Compute a replacement list for all attribute symbols.
|
||||
expressions = [] # type: Sequence[AffineExpr]
|
||||
replacements = [] # type: Sequence[AffineExpr]
|
||||
for attr in attr_arg_defs:
|
||||
if attr.name not in attrs:
|
||||
raise ValueError(f"Expected named argument for the attribute {attr.name}")
|
||||
attribute_values = attrs.get(attr.name)
|
||||
if not all(isinstance(value, int) for value in attribute_values):
|
||||
raise ValueError(f"Attribute {attr.name} needs to be of type "
|
||||
f"Sequence[int] but got {type(attribute_values)}")
|
||||
results = attr.attribute_map.results # type: AffineExprList
|
||||
if len(attribute_values) != len(results):
|
||||
raise ValueError(f"Attribute {attr.name} has length {len(results)} "
|
||||
f"but got {len(attribute_values)} values")
|
||||
for expr, value in zip(results, attribute_values):
|
||||
for index_attr in index_attr_arg_defs:
|
||||
index_attr_vals = index_attr.operand_def.default_vals
|
||||
if index_attr.name in attrs:
|
||||
index_attr_vals = attrs.get(index_attr.name)
|
||||
assert index_attr_vals, "Index attribute has no value"
|
||||
if not all(isinstance(value, int) for value in index_attr_vals):
|
||||
raise ValueError(f"Attribute {index_attr.name} needs to be of type "
|
||||
f"Sequence[int] but got {type(index_attr_vals)}")
|
||||
results = index_attr.index_attr_map.results # type: AffineExprList
|
||||
if len(index_attr_vals) != len(results):
|
||||
raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
|
||||
f"but got {len(index_attr_vals)} values")
|
||||
for expr, value in zip(results, index_attr_vals):
|
||||
expressions.append(expr)
|
||||
replacements.append(AffineConstantExpr.get(value))
|
||||
|
||||
@ -116,22 +116,24 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
iterator_types_attr = ArrayAttr.get(
|
||||
[StringAttr.get(s) for s in op_config.iterator_types])
|
||||
|
||||
# Compute a dictionary storing all index attributes.
|
||||
index_attributes = {} # type: Dict[str, DenseElementAttr]
|
||||
for attr in attr_arg_defs:
|
||||
attribute_values = attrs.get(attr.name)
|
||||
array = np.array(attribute_values, dtype=np.int64)
|
||||
index_attributes[attr.name] = DenseElementsAttr.get(array)
|
||||
# Compute the index attributes used when emitting a named structured op.
|
||||
index_attrs = {} # type: Dict[str, DenseElementAttr]
|
||||
for index_attr in index_attr_arg_defs:
|
||||
index_attr_vals = attrs.get(index_attr.name)
|
||||
# Only forward attributes set to a non-default value.
|
||||
if index_attr_vals:
|
||||
array = np.array(index_attr_vals, dtype=np.int64)
|
||||
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
|
||||
|
||||
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
|
||||
type_mapping, indexing_maps_attr, iterator_types_attr,
|
||||
index_attributes, block_arg_types)
|
||||
index_attrs, block_arg_types)
|
||||
|
||||
|
||||
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
||||
outs: ValueList, **attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
||||
# An operation that accesses only scalars and scalar/rank zero tensors is
|
||||
@ -182,7 +184,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
|
||||
op_class_name: str, *ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
||||
# If we get here, there must exist a builtin class `op_class_name`.
|
||||
@ -195,7 +197,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
|
||||
|
||||
# Set the index attributes used to compute the indexing maps.
|
||||
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
|
||||
for name, value in index_attributes.items():
|
||||
for name, value in index_attrs.items():
|
||||
named_op.operation.attributes[name] = value
|
||||
|
||||
linalg.fill_builtin_region(named_op.operation)
|
||||
|
@ -224,8 +224,8 @@ def conv_1d_nwc_wcf(
|
||||
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
|
||||
K=TensorDef(T2, S.KW, S.C, S.F),
|
||||
O=TensorDef(U, S.N, S.OW, S.F, output=True),
|
||||
strides=IndexAttrDef(S.SW),
|
||||
dilations=IndexAttrDef(S.DW)):
|
||||
strides=IndexAttrDef(S.SW, default=[1]),
|
||||
dilations=IndexAttrDef(S.DW, default=[1])):
|
||||
"""Performs 1-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -244,8 +244,8 @@ def conv_2d_nhwc_hwcf(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs 2-D convolution.
|
||||
|
||||
Layout:
|
||||
@ -270,8 +270,8 @@ def conv_2d_nhwc_hwcf_q(
|
||||
IZp=ScalarDef(I32),
|
||||
KZp=ScalarDef(I32),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs 2-D convolution with zero point offsets.
|
||||
|
||||
Layout:
|
||||
@ -297,8 +297,8 @@ def conv_2d_nchw_fchw(
|
||||
S.OW * S.SW + S.KW * S.DW),
|
||||
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
|
||||
O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs 2-D convolution.
|
||||
|
||||
Layout:
|
||||
@ -321,8 +321,8 @@ def conv_3d_ndhwc_dhwcf(
|
||||
S.OW * S.SW + S.KW * S.DW, S.C),
|
||||
K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
|
||||
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
|
||||
"""Performs 3-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -341,8 +341,8 @@ def depthwise_conv_1d_nwc_wc(
|
||||
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
|
||||
K=TensorDef(T2, S.KW, S.IC),
|
||||
O=TensorDef(U, S.N, S.OW, S.IC, output=True),
|
||||
strides=IndexAttrDef(S.SW),
|
||||
dilations=IndexAttrDef(S.DW)):
|
||||
strides=IndexAttrDef(S.SW, default=[1]),
|
||||
dilations=IndexAttrDef(S.DW, default=[1])):
|
||||
"""Performs depth-wise 1-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -362,8 +362,8 @@ def depthwise_conv_2d_nhwc_hwc(
|
||||
S.IC),
|
||||
K=TensorDef(T2, S.KH, S.KW, S.IC),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs depth-wise 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -385,8 +385,8 @@ def depthwise_conv_2d_nhwc_hwc_q(
|
||||
IZp=ScalarDef(I32),
|
||||
KZp=ScalarDef(I32),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs depth-wise 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -407,8 +407,8 @@ def depthwise_conv_2d_nhwc_hwcm(
|
||||
S.IC),
|
||||
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs depth-wise 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -429,8 +429,8 @@ def depthwise_conv_2d_nhwc_hwcm_q(
|
||||
IZp=ScalarDef(I32),
|
||||
KZp=ScalarDef(I32),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs depth-wise 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
@ -451,8 +451,8 @@ def pooling_nhwc_sum(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs sum pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -470,8 +470,8 @@ def pooling_nhwc_max(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -490,8 +490,8 @@ def pooling_nhwc_max_unsigned(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs unsigned max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -510,8 +510,8 @@ def pooling_nchw_max(
|
||||
S.OW * S.SW + S.KW * S.DW),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -531,8 +531,8 @@ def pooling_nhwc_min(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs min pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -551,8 +551,8 @@ def pooling_nhwc_min_unsigned(
|
||||
S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
"""Performs unsigned min pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -571,8 +571,8 @@ def pooling_ndhwc_sum(
|
||||
S.OW * S.SW + S.KW * S.DW, S.C),
|
||||
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
|
||||
"""Performs 3D sum pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -591,8 +591,8 @@ def pooling_ndhwc_max(
|
||||
S.OW * S.SW + S.KW * S.DW, S.C),
|
||||
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
|
||||
"""Performs 3D max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
@ -612,8 +612,8 @@ def pooling_ndhwc_min(
|
||||
S.OW * S.SW + S.KW * S.DW, S.C),
|
||||
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
|
||||
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
|
||||
"""Performs 3D min pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
|
@ -97,19 +97,12 @@ func @depthwise_conv_2d_nhwc_hwcm_memref_dilated(%input: memref<2x8x9x2xf32>, %f
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{missing indexing map required attribute 'dilations'}}
|
||||
linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>}
|
||||
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_default_attributes
|
||||
func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// CHECK: linalg.depthwise_conv_2d_nhwc_hwc
|
||||
// CHECK-NOT: strides =
|
||||
// CHECK-NOT: dilations =
|
||||
linalg.depthwise_conv_2d_nhwc_hwc
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
@ -118,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
|
||||
// expected-error @+1 {{incorrect element type for index attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
@ -128,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
|
||||
// expected-error @+1 {{incorrect shape for index attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
@ -566,7 +559,7 @@ func @conv_interface_wrong_input_indexing_map(
|
||||
%arg0 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
// expected-error @+1 {{unexpected input index map for convolutions}}
|
||||
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
|
||||
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
|
||||
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
|
||||
"linalg.yield"(%2) : (f32) -> ()
|
||||
@ -583,7 +576,7 @@ func @conv_interface_wrong_num_operands(
|
||||
%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
// expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
|
||||
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
|
||||
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
|
||||
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
|
||||
"linalg.yield"(%2) : (f32) -> ()
|
||||
|
@ -21,7 +21,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
usage: Output
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
@ -95,7 +95,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# @linalg_structured_op
|
||||
# def test2(I=TensorDef(T, S.M, S.N),
|
||||
# O=TensorDef(T, S.M, S.N, output=True),
|
||||
# strides=IndexAttrDef(S.SM, S.SN)):
|
||||
# strides=IndexAttrDef(S.SM, S.SN, default=[1, 2])):
|
||||
# """Title.
|
||||
|
||||
# Detailed description.
|
||||
@ -114,19 +114,21 @@ structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: InputOperand
|
||||
usage: Input
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
usage: Output
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: strides
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
|
||||
usage: IndexAttr
|
||||
index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
|
||||
default_vals:
|
||||
- 1
|
||||
- 2
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
|
||||
@ -145,7 +147,8 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# ODS: let arguments =
|
||||
# ODS-NEXT: Variadic<AnyType>:$inputs,
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
|
||||
# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides
|
||||
# ODS-NEXT: DefaultValuedAttr<RankedI64ElementsAttr<[2]>
|
||||
# ODS-SAME: "{ static_cast<int64_t>(1), static_cast<int64_t>(2) }">:$strides
|
||||
|
||||
# ODS: "Attribute":$strides
|
||||
# ODS: $_state.addAttribute("strides", strides);
|
||||
@ -169,8 +172,8 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; }
|
||||
# IMPL: Test2Op::verifyIndexingMapRequiredAttributes()
|
||||
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
|
||||
# IMPL: "missing indexing map required attribute 'strides'"
|
||||
|
||||
# IMPL: "incorrect element type for index attribute 'strides'"
|
||||
# IMPL: "incorrect shape for index attribute 'strides'"
|
||||
# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
|
||||
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
|
||||
|
||||
@ -197,11 +200,11 @@ structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: value
|
||||
usage: InputOperand
|
||||
usage: Input
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
usage: Output
|
||||
type_var: U
|
||||
shape_map: affine_map<() -> ()>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
|
@ -7,15 +7,15 @@ from mlir.dialects.linalg.opdsl.lang import *
|
||||
# CHECK-LABEL: matmul
|
||||
# CHECK: args:
|
||||
# CHECK: name: A
|
||||
# CHECK: usage: InputOperand
|
||||
# CHECK: usage: Input
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
|
||||
# CHECK: name: B
|
||||
# CHECK: usage: InputOperand
|
||||
# CHECK: usage: Input
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
|
||||
# CHECK: name: C
|
||||
# CHECK: usage: OutputOperand
|
||||
# CHECK: usage: Output
|
||||
# CHECK: type_var: U
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
|
||||
@linalg_structured_op
|
||||
@ -30,7 +30,7 @@ def matmul(
|
||||
# CHECK-LABEL: fill
|
||||
# CHECK: args:
|
||||
# CHECK: name: value
|
||||
# CHECK: usage: InputOperand
|
||||
# CHECK: usage: Input
|
||||
# CHECK-NOT: shape_map:
|
||||
# CHECK: type_var: T
|
||||
@linalg_structured_op
|
||||
@ -42,20 +42,22 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
|
||||
# CHECK-LABEL: strided_copy
|
||||
# CHECK: args:
|
||||
# CHECK: name: I
|
||||
# CHECK: usage: InputOperand
|
||||
# CHECK: usage: Input
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
|
||||
# CHECK: name: O
|
||||
# CHECK: usage: OutputOperand
|
||||
# CHECK: usage: Output
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
|
||||
# CHECK: name: strides
|
||||
# CHECK: usage: IndexAttribute
|
||||
# CHECK: type_var: I64
|
||||
# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
|
||||
# CHECK: usage: IndexAttr
|
||||
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
|
||||
# CHECK: default_vals:
|
||||
# CHECK: - 1
|
||||
# CHECK: - 2
|
||||
@linalg_structured_op
|
||||
def strided_copy(
|
||||
I=TensorDef(T, S.IH, S.IW),
|
||||
O=TensorDef(T, S.OH, S.OW, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
|
||||
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
|
||||
|
@ -16,8 +16,8 @@ def conv_poly(
|
||||
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, S.C),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
@ -51,8 +51,9 @@ with Context() as ctx, Location.unknown():
|
||||
RankedTensorType.get((2, 2, 1), f32),
|
||||
RankedTensorType.get((1, 2, 4, 1), i32))
|
||||
def test_f32i32_conv(input, filter, init_result):
|
||||
# Use default dilations and set non-default strides.
|
||||
return conv_poly(
|
||||
input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||
input, filter, outs=[init_result], strides=[2, 4])
|
||||
|
||||
|
||||
print(module)
|
||||
|
@ -16,8 +16,8 @@ def pooling_max_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
@ -29,8 +29,8 @@ def pooling_max_unsigned_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
|
||||
TypeFn.cast_unsigned(
|
||||
@ -42,8 +42,8 @@ def pooling_min_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
@ -55,8 +55,8 @@ def pooling_min_unsigned_poly(
|
||||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW),
|
||||
dilations=IndexAttrDef(S.DH, S.DW)):
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
|
||||
TypeFn.cast_unsigned(
|
||||
|
@ -132,7 +132,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
|
||||
%c2 = arith.constant 2 : index
|
||||
memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
|
||||
memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
|
||||
memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
|
||||
memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>
|
||||
|
||||
call @pooling_on_buffers(%input, %shape, %output) :
|
||||
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
|
||||
@ -421,9 +421,13 @@ def test_min_pooling_builtin():
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||
MemRefType.get((1, 2, 4, 1), i32))
|
||||
# Set the strides and use the default dilations.
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min(
|
||||
input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
|
||||
input,
|
||||
shape,
|
||||
outs=[output],
|
||||
strides=[2, 4])
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
||||
@ -451,13 +455,13 @@ def test_min_pooling_generic():
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||
MemRefType.get((1, 2, 4, 1), i32))
|
||||
# Set the strides and use the default dilations.
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min(
|
||||
input,
|
||||
shape,
|
||||
outs=[output],
|
||||
strides=[2, 4],
|
||||
dilations=[1, 2],
|
||||
emit_generic=True)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
@ -61,14 +61,15 @@ struct SerializedAffineMap {
|
||||
AffineMap affineMap() { return affineMapAttr.getValue(); }
|
||||
};
|
||||
|
||||
enum class LinalgOperandDefUsage { input, output, attribute };
|
||||
enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
|
||||
|
||||
struct LinalgOperandDef {
|
||||
std::string name;
|
||||
LinalgOperandDefUsage usage;
|
||||
std::string typeVar;
|
||||
Optional<std::string> typeVar;
|
||||
Optional<SerializedAffineMap> shapeMap;
|
||||
Optional<SerializedAffineMap> attributeMap;
|
||||
Optional<SerializedAffineMap> indexAttrMap;
|
||||
Optional<SmallVector<int64_t>> defaultVals;
|
||||
};
|
||||
|
||||
enum class LinalgIteratorTypeDef {
|
||||
@ -175,18 +176,21 @@ struct MappingTraits<LinalgStructuredOpConfig> {
|
||||
/// the argument. Only tensor arguments have a `shape_map`. Each shape must
|
||||
/// be normalized over the same list of symbols and have no dimension
|
||||
/// inputs.
|
||||
/// - `attribute_map`: An optional AffineMap from all op symbols to the
|
||||
/// attribute symbols. During op creation these symbols are replaced by the
|
||||
/// corresponding `name` attribute values. Only attribute arguments have
|
||||
/// an `attribute_map`.
|
||||
/// - `index_attr_map`: An optional AffineMap from all op symbols to the
|
||||
/// index attribute symbols. During op creation these symbols are replaced
|
||||
/// by the corresponding `name` index attribue values. Only index attribute
|
||||
/// arguments have an `index_attr_map`.
|
||||
/// - `default_vals`: An optional default initialization for index attribute
|
||||
/// arguments.
|
||||
template <>
|
||||
struct MappingTraits<LinalgOperandDef> {
|
||||
static void mapping(IO &io, LinalgOperandDef &info) {
|
||||
io.mapRequired("name", info.name);
|
||||
io.mapRequired("usage", info.usage);
|
||||
io.mapRequired("type_var", info.typeVar);
|
||||
io.mapOptional("type_var", info.typeVar);
|
||||
io.mapOptional("shape_map", info.shapeMap);
|
||||
io.mapOptional("attribute_map", info.attributeMap);
|
||||
io.mapOptional("index_attr_map", info.indexAttrMap);
|
||||
io.mapOptional("default_vals", info.defaultVals);
|
||||
}
|
||||
};
|
||||
|
||||
@ -194,9 +198,9 @@ struct MappingTraits<LinalgOperandDef> {
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
||||
io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
|
||||
io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
|
||||
io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute);
|
||||
io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
|
||||
io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
|
||||
io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
|
||||
}
|
||||
};
|
||||
|
||||
@ -395,7 +399,10 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
|
||||
|
||||
// Search all argument types.
|
||||
for (const auto &it : llvm::enumerate(args)) {
|
||||
if (it.value().typeVar == typeVar)
|
||||
if (it.value().usage != LinalgOperandDefUsage::Input &&
|
||||
it.value().usage != LinalgOperandDefUsage::Output)
|
||||
continue;
|
||||
if (it.value().typeVar.getValue() == typeVar)
|
||||
return llvm::formatv("block.getArgument({0}).getType()", it.index())
|
||||
.str();
|
||||
}
|
||||
@ -674,20 +681,32 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
||||
|
||||
// Assemble the attribute specific logic required for the op definition.
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage == LinalgOperandDefUsage::attribute;
|
||||
return arg.usage == LinalgOperandDefUsage::IndexAttr;
|
||||
})) {
|
||||
SmallVector<std::string> attrDefs;
|
||||
SmallVector<std::string> attrParams;
|
||||
SmallVector<std::string> attrStmts;
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::attribute)
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
continue;
|
||||
assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
|
||||
static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}";
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
assert(arg.defaultVals.hasValue());
|
||||
size_t size = arg.indexAttrMap->affineMap().getNumResults();
|
||||
assert(arg.defaultVals.getValue().size() == size);
|
||||
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
|
||||
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
|
||||
static const char paramFmt[] = "\"Attribute\":${0}";
|
||||
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
|
||||
attrDefs.push_back(llvm::formatv(
|
||||
defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name));
|
||||
std::string defaultVals;
|
||||
llvm::raw_string_ostream ss(defaultVals);
|
||||
ss << "{ ";
|
||||
llvm::interleave(
|
||||
arg.defaultVals.getValue(), ss,
|
||||
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
|
||||
", ");
|
||||
ss << " }";
|
||||
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
|
||||
ss.str(), arg.name));
|
||||
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
||||
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
||||
}
|
||||
@ -725,7 +744,7 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
|
||||
// Compute the number of scalar and tensor arguments.
|
||||
int64_t numOfArgs =
|
||||
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage != LinalgOperandDefUsage::attribute;
|
||||
return arg.usage != LinalgOperandDefUsage::IndexAttr;
|
||||
});
|
||||
|
||||
// An operation that accesses only scalars and scalar/rank zero tensors is
|
||||
@ -796,11 +815,11 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
|
||||
)FMT";
|
||||
// Update all symbol bindings mapped to an attribute.
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::attribute)
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
continue;
|
||||
assert(arg.attributeMap.hasValue());
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
for (auto &en :
|
||||
llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
|
||||
llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
|
||||
if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
|
||||
symbolBindings[symbol.getPosition()] =
|
||||
llvm::formatv(structuredOpAccessAttrFormat, arg.name,
|
||||
@ -889,31 +908,26 @@ std::string {0}::getLibraryCallName() {{
|
||||
|
||||
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage == LinalgOperandDefUsage::attribute;
|
||||
return arg.usage == LinalgOperandDefUsage::IndexAttr;
|
||||
})) {
|
||||
std::vector<std::string> attrVerifications;
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::attribute)
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
continue;
|
||||
assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
// Verify index attribute. Paramters:
|
||||
// {0}: Attribute name
|
||||
// {1}: Attribute size
|
||||
static const char attrFmt[] = R"FMT(
|
||||
if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
|
||||
if (!attr.getType().getElementType().isInteger(64))
|
||||
return op->emitError(
|
||||
"incorrect element type for indexing map required attribute '{0}'");
|
||||
return op->emitError("incorrect element type for index attribute '{0}'");
|
||||
if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
|
||||
return op->emitError(
|
||||
"incorrect shape for indexing map required attribute '{0}'");
|
||||
} else {
|
||||
return op->emitError(
|
||||
"missing indexing map required attribute '{0}'");
|
||||
return op->emitError("incorrect shape for index attribute '{0}'");
|
||||
}
|
||||
)FMT";
|
||||
attrVerifications.push_back(llvm::formatv(
|
||||
attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults()));
|
||||
attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
|
||||
}
|
||||
|
||||
// Generates the verifyIndexingMapRequiredAttributes method. Parameters:
|
||||
@ -953,7 +967,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
|
||||
int localCounter = 0;
|
||||
SmallVector<std::string> stmts;
|
||||
for (LinalgOperandDef &arg : args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::output)
|
||||
if (arg.usage != LinalgOperandDefUsage::Output)
|
||||
continue;
|
||||
|
||||
// Find the assignment that correlates with the argument.
|
||||
|
Loading…
x
Reference in New Issue
Block a user