[mlir][OpDSL] Add default value to index attributes.

Index attributes had no default value, which meant the attribute values always had to be set on the operation. This revision adds a `default` parameter to `IndexAttrDef`. After the change, every index attribute has to define a default value. For example, we may define the following strides attribute:
```
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])
```
When using the operation, the default strides are used if the strides attribute is not set. The mechanism is implemented using `DefaultValuedAttr`.
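
For illustration, both of the following calls are then valid; a minimal sketch reusing the `strided_copy` example from the documentation below (the tensor value names are placeholders):
```python
# Explicit strides override the default and are attached to the operation.
strided_copy(in_tensor, outs=[out_tensor], strides=[2, 2])
# No strides given: the default [1, 1] from IndexAttrDef applies, and no
# strides attribute is materialized on the generated operation.
strided_copy(in_tensor, outs=[out_tensor])
```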

Additionally, the revision uses the name "index attribute" instead of "attribute" more consistently, in preparation for follow-up revisions that will introduce function attributes.

Depends On D119125

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D119126
gysit 2022-02-14 12:12:15 +00:00
parent 880e87580a
commit d50571ab07
13 changed files with 558 additions and 429 deletions


@ -105,7 +105,7 @@ appear in the parameter list of the operation:
copy_and_scale(val, in_tensor, outs=[out_tensor])
```
## Attributes
## Index Attributes
Attributes are compile-time constant parameters only accessible in index
expressions. They can be used to parameterize the access pattern of a structured
@ -118,7 +118,7 @@ The following example demonstrates the use of attributes:
@linalg_structured_op
def strided_copy(I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
strides=IndexAttrDef(S.SH, S.SW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])):
"""Copy a subset of the input tensor elements to the output tensor"""
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
```
@ -129,11 +129,12 @@ the symbols `S.SH` and `S.SW`, which are used to index the input tensor `I`.
When instantiating the operation, the attribute is set using a named argument:
```python
strided_copy(in_tensor, outs=[out_tensor], strides=[1,2])
strided_copy(in_tensor, outs=[out_tensor], strides=[1, 2])
```
The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
index expressions of the operation instance.
index expressions of the operation instance. If no strides are provided, the
`default` vector elements are used instead.
Attributes are currently limited to integer vectors and only accessible in index
expressions. An operation may have multiple attributes, all of them placed at the
@ -157,8 +158,8 @@ def pooling_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
```
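
A hypothetical instantiation of `pooling_poly` relying on the defaults might look as follows (a sketch; the value names are assumptions, not part of the change):
```python
# Override only the strides; dilations fall back to their default [1, 1].
pooling_poly(input, kernel, outs=[init_result], strides=[2, 4])
```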


@ -135,7 +135,7 @@ class OperandKind(Enum):
InputTensor = 0
Scalar = 1
OutputTensor = 2
Attribute = 3
IndexAttr = 3
class OperandDef:
@ -147,16 +147,18 @@ class OperandDef:
def __init__(self,
kind: OperandKind,
type_var: TypeVar,
type_var: Optional[TypeVar] = None,
size_exprs: Optional[Sequence[AffineExprDef]] = None,
index_dims: Optional[Sequence[DimDef]] = None):
if not isinstance(type_var, TypeVar):
index_dims: Optional[Sequence[DimDef]] = None,
default_vals: Optional[Sequence[int]] = None):
if type_var and not isinstance(type_var, TypeVar):
raise ValueError(
f"OperandDef requires a TypeVar but got {repr(type_var)}")
self.owner = None # type: Optional["LinalgOpDef"]
self.type_var = type_var
self.size_exprs = size_exprs
self.index_dims = index_dims
self.default_vals = default_vals
self.kind = kind
self.name = None # type: Optional[str]
self.registered_index = -1 # type: int
@ -174,7 +176,7 @@ class OperandDef:
def __repr__(self):
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
f"index_dims={self.index_dims})")
f"index_dims={self.index_dims}, default_vals={self.default_vals})")
class TensorDef:
@ -202,7 +204,7 @@ class TensorDef:
f"got {index_dims}")
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
self.operand_def = OperandDef(
kind, type_var, size_exprs=shape, index_dims=index_dims)
kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
def __getitem__(self, dims) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
@ -246,7 +248,7 @@ class ScalarDef(TensorExpression):
"""
def __init__(self, type_var: TypeVar):
self.operand_def = OperandDef(OperandKind.Scalar, type_var)
self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
@property
def scalar_name(self) -> str:
@ -259,18 +261,25 @@ class ScalarDef(TensorExpression):
class IndexAttrDef:
"""Index Attribute definition.
"""Index attribute definition.
Index attributes provide a way to define and set symbols that can be used in
indexing expressions. Every attribute specifies a tuple of symbols that at
compile-time are replaced by integer values.
compile-time are replaced by integer values, as well as their default values.
"""
def __init__(self, *sizes: SymbolDef):
def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
if any(not isinstance(size, SymbolDef) for size in sizes):
raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got "
f"{sizes}")
self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
f"but got {sizes}")
if any(not isinstance(default_val, int) for default_val in default):
raise ValueError(f"IndexAttrDef requires default values of type int "
f"but got {default}")
if len(sizes) != len(default):
raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
f"but got {len(default)}")
self.operand_def = OperandDef(
OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
class Comprehension:
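
With this validation in place, malformed attribute definitions fail as soon as the op definition is processed. A minimal sketch of the three error cases checked above (assuming the OpDSL symbols are imported):
```python
# Each of these raises ValueError at construction time:
IndexAttrDef("SH", default=[1])               # sizes must be SymbolDef, not str
IndexAttrDef(S.SH, S.SW, default=[1.0, 1.0])  # default values must be ints
IndexAttrDef(S.SH, S.SW, default=[1])         # one default per size symbol
```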


@ -45,10 +45,10 @@ class OperandDefConfig(YAMLObject):
def __init__(self,
operand_def: OperandDef,
shape_map: Optional[_ir.AffineMap] = None,
attribute_map: Optional[_ir.AffineMap] = None):
index_attr_map: Optional[_ir.AffineMap] = None):
self.operand_def = operand_def
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
self.attribute_map = attribute_map # type: Optional[_ir.AffineMap]
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
self.indexing_map = None # type: Optional[_ir.AffineMap]
@property
@ -61,24 +61,28 @@ class OperandDefConfig(YAMLObject):
@property
def usage(self) -> str:
if self.operand_def.kind == OperandKind.Attribute:
return "IndexAttribute"
if self.operand_def.kind == OperandKind.IndexAttr:
return "IndexAttr"
if self.operand_def.kind == OperandKind.OutputTensor:
return "OutputOperand"
return "InputOperand"
return "Output"
return "Input"
def to_yaml_custom_dict(self):
self_dict = dict(
name=self.name, usage=self.usage, type_var=self.type_var.name)
self_dict = dict(name=self.name, usage=self.usage)
if self.type_var:
self_dict["type_var"] = self.type_var.name
if self.shape_map:
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
if self.attribute_map:
self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
if self.index_attr_map:
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
if self.operand_def.default_vals:
self_dict["default_vals"] = self.operand_def.default_vals
return self_dict
def __repr__(self):
return (f"OperandDefConfig({self.operand_def}, "
f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
f"shape_map={self.shape_map}, "
f"index_attr_map={self.index_attr_map}, "
f"indexing_map={self.indexing_map})")
@ -162,7 +166,7 @@ class LinalgStructuredOpConfig(YAMLObject):
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
if operand.kind == OperandKind.Attribute:
if operand.kind == OperandKind.IndexAttr:
collected_attr_defs.append(operand)
# Collect all tensors with manual indexing annotation.
@ -210,9 +214,9 @@ class LinalgStructuredOpConfig(YAMLObject):
if operand_config.shape_map:
operand_config.shape_map = self._normalize_affine_map(
operand_config.shape_map, with_dims=False)
if operand_config.attribute_map:
operand_config.attribute_map = self._normalize_affine_map(
operand_config.attribute_map, with_dims=False)
if operand_config.index_attr_map:
operand_config.index_attr_map = self._normalize_affine_map(
operand_config.index_attr_map, with_dims=False)
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
@ -245,7 +249,7 @@ class LinalgStructuredOpConfig(YAMLObject):
# Check all registered tensor and scalar operands have an indexing map.
for operand in registered_operands:
if operand.kind == OperandKind.Attribute:
if operand.kind == OperandKind.IndexAttr:
continue
if not (operand in self.operands and self.operands[operand].indexing_map):
raise ValueError(f"Failed to compute an indexing map for operand "
@ -319,9 +323,9 @@ class LinalgStructuredOpConfig(YAMLObject):
assert local_state.local_dim_count == 0
affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
if operand_def.kind == OperandKind.Attribute:
if operand_def.kind == OperandKind.IndexAttr:
self.operands[operand_def] = OperandDefConfig(
operand_def, attribute_map=affine_map)
operand_def, index_attr_map=affine_map)
else:
self.operands[operand_def] = OperandDefConfig(
operand_def, shape_map=affine_map)


@ -39,15 +39,14 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs = op_config.ordered_operands
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
# Verify outs is a sequence or a list of results.
if not isinstance(outs, (Sequence, OpResultList)):
raise ValueError(
f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
)
raise ValueError(f"Expected named argument outs to have type Sequence or "
f"OpResultLis but got {type(outs)}")
# Arity validation.
if len(ins) != len(in_arg_defs):
@ -60,18 +59,19 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
# Compute a replacement list for all attribute symbols.
expressions = [] # type: Sequence[AffineExpr]
replacements = [] # type: Sequence[AffineExpr]
for attr in attr_arg_defs:
if attr.name not in attrs:
raise ValueError(f"Expected named argument for the attribute {attr.name}")
attribute_values = attrs.get(attr.name)
if not all(isinstance(value, int) for value in attribute_values):
raise ValueError(f"Attribute {attr.name} needs to be of type "
f"Sequence[int] but got {type(attribute_values)}")
results = attr.attribute_map.results # type: AffineExprList
if len(attribute_values) != len(results):
raise ValueError(f"Attribute {attr.name} has length {len(results)} "
f"but got {len(attribute_values)} values")
for expr, value in zip(results, attribute_values):
for index_attr in index_attr_arg_defs:
index_attr_vals = index_attr.operand_def.default_vals
if index_attr.name in attrs:
index_attr_vals = attrs.get(index_attr.name)
assert index_attr_vals, "Index attribute has no value"
if not all(isinstance(value, int) for value in index_attr_vals):
raise ValueError(f"Attribute {index_attr.name} needs to be of type "
f"Sequence[int] but got {type(index_attr_vals)}")
results = index_attr.index_attr_map.results # type: AffineExprList
if len(index_attr_vals) != len(results):
raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
f"but got {len(index_attr_vals)} values")
for expr, value in zip(results, index_attr_vals):
expressions.append(expr)
replacements.append(AffineConstantExpr.get(value))
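
Distilled, the loop above prefers values passed by the caller and otherwise substitutes the recorded defaults; a simplified, self-contained sketch of that selection logic (hypothetical helper, not part of the change):
```python
def resolve_index_attr_values(index_attr, attrs):
    # Prefer values passed by the caller; otherwise fall back to the defaults
    # recorded on the OperandDef when the op was defined.
    vals = attrs.get(index_attr.name, index_attr.operand_def.default_vals)
    if not all(isinstance(v, int) for v in vals):
        raise ValueError(f"Attribute {index_attr.name} must be a Sequence[int]")
    return vals
```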
@ -116,22 +116,24 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
iterator_types_attr = ArrayAttr.get(
[StringAttr.get(s) for s in op_config.iterator_types])
# Compute a dictionary storing all index attributes.
index_attributes = {} # type: Dict[str, DenseElementsAttr]
for attr in attr_arg_defs:
attribute_values = attrs.get(attr.name)
array = np.array(attribute_values, dtype=np.int64)
index_attributes[attr.name] = DenseElementsAttr.get(array)
# Compute the index attributes used when emitting a named structured op.
index_attrs = {} # type: Dict[str, DenseElementsAttr]
for index_attr in index_attr_arg_defs:
index_attr_vals = attrs.get(index_attr.name)
# Only forward attributes that are explicitly set by the caller.
if index_attr_vals:
array = np.array(index_attr_vals, dtype=np.int64)
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr,
index_attributes, block_arg_types)
index_attrs, block_arg_types)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# An operation that accesses only scalars and scalar/rank zero tensors is
@ -182,7 +184,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# If we get here, there must exist a builtin class `op_class_name`.
@ -195,7 +197,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
# Set the index attributes used to compute the indexing maps.
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
for name, value in index_attributes.items():
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
linalg.fill_builtin_region(named_op.operation)


@ -224,8 +224,8 @@ def conv_1d_nwc_wcf(
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OW, S.F, output=True),
strides=IndexAttrDef(S.SW),
dilations=IndexAttrDef(S.DW)):
strides=IndexAttrDef(S.SW, default=[1]),
dilations=IndexAttrDef(S.DW, default=[1])):
"""Performs 1-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -244,8 +244,8 @@ def conv_2d_nhwc_hwcf(
S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution.
Layout:
@ -270,8 +270,8 @@ def conv_2d_nhwc_hwcf_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution with zero point offsets.
Layout:
@ -297,8 +297,8 @@ def conv_2d_nchw_fchw(
S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution.
Layout:
@ -321,8 +321,8 @@ def conv_3d_ndhwc_dhwcf(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
strides=IndexAttrDef(S.SD, S.SH, S.SW),
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -341,8 +341,8 @@ def depthwise_conv_1d_nwc_wc(
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KW, S.IC),
O=TensorDef(U, S.N, S.OW, S.IC, output=True),
strides=IndexAttrDef(S.SW),
dilations=IndexAttrDef(S.DW)):
strides=IndexAttrDef(S.SW, default=[1]),
dilations=IndexAttrDef(S.DW, default=[1])):
"""Performs depth-wise 1-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -362,8 +362,8 @@ def depthwise_conv_2d_nhwc_hwc(
S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -385,8 +385,8 @@ def depthwise_conv_2d_nhwc_hwc_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -407,8 +407,8 @@ def depthwise_conv_2d_nhwc_hwcm(
S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -429,8 +429,8 @@ def depthwise_conv_2d_nhwc_hwcm_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -451,8 +451,8 @@ def pooling_nhwc_sum(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs sum pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -470,8 +470,8 @@ def pooling_nhwc_max(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -490,8 +490,8 @@ def pooling_nhwc_max_unsigned(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs unsigned max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -510,8 +510,8 @@ def pooling_nchw_max(
S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -531,8 +531,8 @@ def pooling_nhwc_min(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs min pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -551,8 +551,8 @@ def pooling_nhwc_min_unsigned(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs unsigned min pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -571,8 +571,8 @@ def pooling_ndhwc_sum(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SD, S.SH, S.SW),
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D sum pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -591,8 +591,8 @@ def pooling_ndhwc_max(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SD, S.SH, S.SW),
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@ -612,8 +612,8 @@ def pooling_ndhwc_min(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SD, S.SH, S.SW),
dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D min pooling.
Numeric casting is performed on the input operand, promoting it to the same


@ -97,19 +97,12 @@ func @depthwise_conv_2d_nhwc_hwcm_memref_dilated(%input: memref<2x8x9x2xf32>, %f
// -----
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
}
// -----
func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'dilations'}}
linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>}
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_default_attributes
func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// CHECK: linalg.depthwise_conv_2d_nhwc_hwc
// CHECK-NOT: strides =
// CHECK-NOT: dilations =
linalg.depthwise_conv_2d_nhwc_hwc
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@ -118,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
// expected-error @+1 {{incorrect element type for index attribute 'strides'}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
@ -128,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
// expected-error @+1 {{incorrect shape for index attribute 'strides'}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
@ -566,7 +559,7 @@ func @conv_interface_wrong_input_indexing_map(
%arg0 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
// expected-error @+1 {{unexpected input index map for convolutions}}
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
"linalg.yield"(%2) : (f32) -> ()
@ -583,7 +576,7 @@ func @conv_interface_wrong_num_operands(
%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
// expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
"linalg.yield"(%2) : (f32) -> ()


@ -21,7 +21,7 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
usage: Output
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
@ -95,7 +95,7 @@ structured_op: !LinalgStructuredOpConfig
# @linalg_structured_op
# def test2(I=TensorDef(T, S.M, S.N),
# O=TensorDef(T, S.M, S.N, output=True),
# strides=IndexAttrDef(S.SM, S.SN)):
# strides=IndexAttrDef(S.SM, S.SN, default=[1, 2])):
# """Title.
# Detailed description.
@ -114,19 +114,21 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: InputOperand
usage: Input
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
usage: Output
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
usage: IndexAttr
index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
default_vals:
- 1
- 2
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
@ -145,7 +147,8 @@ structured_op: !LinalgStructuredOpConfig
# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides
# ODS-NEXT: DefaultValuedAttr<RankedI64ElementsAttr<[2]>
# ODS-SAME: "{ static_cast<int64_t>(1), static_cast<int64_t>(2) }">:$strides
# ODS: "Attribute":$strides
# ODS: $_state.addAttribute("strides", strides);
@ -169,8 +172,8 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; }
# IMPL: Test2Op::verifyIndexingMapRequiredAttributes()
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
# IMPL: "missing indexing map required attribute 'strides'"
# IMPL: "incorrect element type for index attribute 'strides'"
# IMPL: "incorrect shape for index attribute 'strides'"
# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
@ -197,11 +200,11 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
usage: InputOperand
usage: Input
type_var: T1
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
usage: Output
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig


@ -7,15 +7,15 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
# CHECK: usage: InputOperand
# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
# CHECK: usage: InputOperand
# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
# CHECK: usage: OutputOperand
# CHECK: usage: Output
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
@linalg_structured_op
@ -30,7 +30,7 @@ def matmul(
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
# CHECK: usage: InputOperand
# CHECK: usage: Input
# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
@ -42,20 +42,22 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
# CHECK-LABEL: strided_copy
# CHECK: args:
# CHECK: name: I
# CHECK: usage: InputOperand
# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
# CHECK: usage: OutputOperand
# CHECK: usage: Output
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
# CHECK: usage: IndexAttribute
# CHECK: type_var: I64
# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# CHECK: usage: IndexAttr
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# CHECK: default_vals:
# CHECK: - 1
# CHECK: - 2
@linalg_structured_op
def strided_copy(
I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
strides=IndexAttrDef(S.SH, S.SW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]


@ -16,8 +16,8 @@ def conv_poly(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@ -51,8 +51,9 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2, 1), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_conv(input, filter, init_result):
# Use default dilations and set non-default strides.
return conv_poly(
input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
input, filter, outs=[init_result], strides=[2, 4])
print(module)


@ -16,8 +16,8 @@ def pooling_max_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
@ -29,8 +29,8 @@ def pooling_max_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
@ -42,8 +42,8 @@ def pooling_min_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
@ -55,8 +55,8 @@ def pooling_min_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(


@ -132,7 +132,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
%c2 = arith.constant 2 : index
memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>
call @pooling_on_buffers(%input, %shape, %output) :
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
@ -421,9 +421,13 @@ def test_min_pooling_builtin():
@builtin.FuncOp.from_py_func(
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
# Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
input,
shape,
outs=[output],
strides=[2, 4])
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
@ -451,13 +455,13 @@ def test_min_pooling_generic():
@builtin.FuncOp.from_py_func(
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
# Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
input,
shape,
outs=[output],
strides=[2, 4],
dilations=[1, 2],
emit_generic=True)
execution_engine = ExecutionEngine(transform(module, pooling_boiler))


@ -61,14 +61,15 @@ struct SerializedAffineMap {
AffineMap affineMap() { return affineMapAttr.getValue(); }
};
enum class LinalgOperandDefUsage { input, output, attribute };
enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
struct LinalgOperandDef {
std::string name;
LinalgOperandDefUsage usage;
std::string typeVar;
Optional<std::string> typeVar;
Optional<SerializedAffineMap> shapeMap;
Optional<SerializedAffineMap> attributeMap;
Optional<SerializedAffineMap> indexAttrMap;
Optional<SmallVector<int64_t>> defaultVals;
};
enum class LinalgIteratorTypeDef {
@ -175,18 +176,21 @@ struct MappingTraits<LinalgStructuredOpConfig> {
/// the argument. Only tensor arguments have a `shape_map`. Each shape must
/// be normalized over the same list of symbols and have no dimension
/// inputs.
/// - `attribute_map`: An optional AffineMap from all op symbols to the
/// attribute symbols. During op creation these symbols are replaced by the
/// corresponding `name` attribute values. Only attribute arguments have
/// an `attribute_map`.
/// - `index_attr_map`: An optional AffineMap from all op symbols to the
/// index attribute symbols. During op creation these symbols are replaced
/// by the corresponding `name` index attribute values. Only index attribute
/// arguments have an `index_attr_map`.
/// - `default_vals`: An optional default initialization for index attribute
/// arguments.
template <>
struct MappingTraits<LinalgOperandDef> {
static void mapping(IO &io, LinalgOperandDef &info) {
io.mapRequired("name", info.name);
io.mapRequired("usage", info.usage);
io.mapRequired("type_var", info.typeVar);
io.mapOptional("type_var", info.typeVar);
io.mapOptional("shape_map", info.shapeMap);
io.mapOptional("attribute_map", info.attributeMap);
io.mapOptional("index_attr_map", info.indexAttrMap);
io.mapOptional("default_vals", info.defaultVals);
}
};
@ -194,9 +198,9 @@ struct MappingTraits<LinalgOperandDef> {
template <>
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute);
io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
}
};
@ -395,7 +399,10 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
// Search all argument types.
for (const auto &it : llvm::enumerate(args)) {
if (it.value().typeVar == typeVar)
if (it.value().usage != LinalgOperandDefUsage::Input &&
it.value().usage != LinalgOperandDefUsage::Output)
continue;
if (it.value().typeVar.getValue() == typeVar)
return llvm::formatv("block.getArgument({0}).getType()", it.index())
.str();
}
@ -674,20 +681,32 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
// Assemble the attribute specific logic required for the op definition.
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage == LinalgOperandDefUsage::attribute;
return arg.usage == LinalgOperandDefUsage::IndexAttr;
})) {
SmallVector<std::string> attrDefs;
SmallVector<std::string> attrParams;
SmallVector<std::string> attrStmts;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::attribute)
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}";
assert(arg.indexAttrMap.hasValue());
assert(arg.defaultVals.hasValue());
size_t size = arg.indexAttrMap->affineMap().getNumResults();
assert(arg.defaultVals.getValue().size() == size);
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
static const char paramFmt[] = "\"Attribute\":${0}";
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
attrDefs.push_back(llvm::formatv(
defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name));
std::string defaultVals;
llvm::raw_string_ostream ss(defaultVals);
ss << "{ ";
llvm::interleave(
arg.defaultVals.getValue(), ss,
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
", ");
ss << " }";
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
ss.str(), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
@ -725,7 +744,7 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
// Compute the number of scalar and tensor arguments.
int64_t numOfArgs =
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage != LinalgOperandDefUsage::attribute;
return arg.usage != LinalgOperandDefUsage::IndexAttr;
});
// An operation that accesses only scalars and scalar/rank zero tensors is
@ -796,11 +815,11 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
)FMT";
// Update all symbol bindings mapped to an attribute.
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::attribute)
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
assert(arg.attributeMap.hasValue());
assert(arg.indexAttrMap.hasValue());
for (auto &en :
llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
symbolBindings[symbol.getPosition()] =
llvm::formatv(structuredOpAccessAttrFormat, arg.name,
@ -889,31 +908,26 @@ std::string {0}::getLibraryCallName() {{
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage == LinalgOperandDefUsage::attribute;
return arg.usage == LinalgOperandDefUsage::IndexAttr;
})) {
std::vector<std::string> attrVerifications;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::attribute)
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
assert(arg.indexAttrMap.hasValue());
// Verify index attribute. Parameters:
// {0}: Attribute name
// {1}: Attribute size
static const char attrFmt[] = R"FMT(
if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
if (!attr.getType().getElementType().isInteger(64))
return op->emitError(
"incorrect element type for indexing map required attribute '{0}'");
return op->emitError("incorrect element type for index attribute '{0}'");
if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
return op->emitError(
"incorrect shape for indexing map required attribute '{0}'");
} else {
return op->emitError(
"missing indexing map required attribute '{0}'");
return op->emitError("incorrect shape for index attribute '{0}'");
}
)FMT";
attrVerifications.push_back(llvm::formatv(
attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults()));
attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
}
// Generates the verifyIndexingMapRequiredAttributes method. Parameters:
@ -953,7 +967,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
int localCounter = 0;
SmallVector<std::string> stmts;
for (LinalgOperandDef &arg : args) {
if (arg.usage != LinalgOperandDefUsage::output)
if (arg.usage != LinalgOperandDefUsage::Output)
continue;
// Find the assignment that correlates with the argument.