llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp
Aart Bik 836411b99f
[mlir][sparse] add lvlToDim field to sparse tensor encoding (#67194)
Note the new surface syntax allows for defining a dimToLvl and lvlToDim
map at once (where usually the latter can be inferred from the former,
but not always). This revision adds storage for the latter, together
with some initial boilerplate. The actual support (inference, validation,
printing, etc.) is still TBD of course.
2023-09-22 15:51:25 -07:00

88 lines
3.8 KiB
C++

//===- SparseTensor.cpp - C API for SparseTensor dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LLVM.h"
using namespace llvm;
using namespace mlir::sparse_tensor;
// Defines the C-API dialect-registration entry points for the `sparse_tensor`
// dialect, backed by the C++ class mlir::sparse_tensor::SparseTensorDialect.
// The registration macro comes from the mlir/CAPI/Registration.h include above;
// the exact names it generates are determined by that macro (see its header).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
mlir::sparse_tensor::SparseTensorDialect)
// The C-API level-type enumerators must carry exactly the same integer values
// as their C++ DimLevelType counterparts, because the conversions in this file
// are plain static_casts. Each pair is checked on its own so that a mismatch is
// reported at the offending enumerator rather than for the whole set at once.
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
                  static_cast<int>(DimLevelType::Dense),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
                  static_cast<int>(DimLevelType::Compressed),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
                  static_cast<int>(DimLevelType::CompressedNu),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
                  static_cast<int>(DimLevelType::CompressedNo),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
                  static_cast<int>(DimLevelType::CompressedNuNo),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
                  static_cast<int>(DimLevelType::Singleton),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
                  static_cast<int>(DimLevelType::SingletonNu),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
                  static_cast<int>(DimLevelType::SingletonNo),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
                  static_cast<int>(DimLevelType::SingletonNuNo),
              "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
/// Returns true when the given C-API attribute wraps a
/// SparseTensorEncodingAttr, false otherwise.
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
  auto attribute = unwrap(attr);
  return isa<SparseTensorEncodingAttr>(attribute);
}
/// Builds a SparseTensorEncodingAttr from C-API inputs.
///
/// \param ctx       Context that owns the resulting attribute.
/// \param lvlRank   Number of entries read from \p lvlTypes.
/// \param lvlTypes  Array of `lvlRank` C-API level types.
/// \param dimToLvl  Dimension-to-level map forwarded unchanged.
/// \param posWidth  Position bit-width forwarded unchanged.
/// \param crdWidth  Coordinate bit-width forwarded unchanged.
/// \returns The encoding attribute wrapped as an MlirAttribute.
MlirAttribute
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
                                MlirSparseTensorDimLevelType const *lvlTypes,
                                MlirAffineMap dimToLvl, int posWidth,
                                int crdWidth) {
  // The C-API and C++ level-type enums share integer values (checked by the
  // static_assert near the top of this file), so a plain cast converts them.
  SmallVector<DimLevelType> levelTypes;
  levelTypes.reserve(lvlRank);
  for (intptr_t lvl = 0; lvl < lvlRank; ++lvl) {
    const MlirSparseTensorDimLevelType cLvlType = lvlTypes[lvl];
    levelTypes.push_back(static_cast<DimLevelType>(cLvlType));
  }

  // TODO: the level-to-dimension map is not yet exposed by the C API; a
  // default-constructed AffineMap is passed in its place for now.
  mlir::AffineMap lvlToDim;

  auto *context = unwrap(ctx);
  auto encoding =
      SparseTensorEncodingAttr::get(context, levelTypes, unwrap(dimToLvl),
                                    lvlToDim, posWidth, crdWidth);
  return wrap(encoding);
}
/// Returns the dimension-to-level map stored in the encoding attribute.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  return wrap(encoding.getDimToLvl());
}
/// Returns the level-to-dimension map stored in the encoding attribute.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  return wrap(encoding.getLvlToDim());
}
/// Returns the level rank reported by the encoding attribute.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  return encoding.getLvlRank();
}
/// Returns the level type of level \p lvl of the encoding, converted to the
/// C-API enum by a value-preserving cast.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
MlirSparseTensorDimLevelType
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  const auto lvlType = encoding.getLvlType(lvl);
  return static_cast<MlirSparseTensorDimLevelType>(lvlType);
}
/// Returns the position bit-width recorded in the encoding attribute.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  return encoding.getPosWidth();
}
/// Returns the coordinate bit-width recorded in the encoding attribute.
/// \p attr must wrap a SparseTensorEncodingAttr (checked by `cast`).
int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
  auto encoding = cast<SparseTensorEncodingAttr>(unwrap(attr));
  return encoding.getCrdWidth();
}