Sandeep Dasgupta 81d7eef134
Sub-channel quantized type implementation (#120172)
This is an implementation for [RFC: Supporting Sub-Channel Quantization
in
MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).

In order to make the review process easier, the PR has been divided into
the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class
design for `UniformQuantizedSubChannelType`, printer/parser and bytecode
read/write support. The existing types (per-tensor and per-axis) are
unaltered.
2. **Add implementation for sub-channel type:** Lowering of
`quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python Apis:** We first define the C-APIs and build the
Python-APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes
sub-channel quantized types to per-tensor or per-axis types, if possible.


A design note:
- **Explicitly storing the `quantized_dimensions`, even when they can be
derived for ranked tensor.**
While it's possible to infer quantized dimensions from the static shape
of the scales (or zero-points) tensor for ranked
data tensors
([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3)
for background), there are cases where this can lead to ambiguity and
issues with round-tripping.

```
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
```

The shape of the scales tensor is [1, 2], which might suggest that only
axis 1 is quantized. While this inference is technically correct, as the
block size for axis 0 is a degenerate case (equal to the dimension
size), it can cause problems with round-tripping. Therefore, even for
ranked tensors, we are explicitly storing the quantized dimensions.
Suggestions welcome!


PS: I understand that the upcoming holidays may impact your schedule, so
please take your time with the review. There's no rush.
2025-03-23 07:37:55 -05:00

274 lines
10 KiB
C++

//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
using namespace mlir;
// Expands the C API dialect-registration macro for the `quant` dialect, tying
// it to the C++ class quant::QuantDialect.
// NOTE(review): the macro is presumably provided by mlir/CAPI/Registration.h
// and is expected to define the dialect-handle accessor declared in
// mlir-c/Dialect/Quant.h — confirm against those headers.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
//===---------------------------------------------------------------------===//
// QuantizedType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a quant::QuantizedType.
bool mlirTypeIsAQuantizedType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::QuantizedType>(cppType);
}
/// Exposes quant::QuantizationFlags::Signed as a plain unsigned value.
unsigned mlirQuantizedTypeGetSignedFlag() {
  const unsigned signedFlag = quant::QuantizationFlags::Signed;
  return signedFlag;
}
/// Forwards to quant::QuantizedType::getDefaultMinimumForInteger for the given
/// signedness and integral bit width and returns its result.
int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
                                                     unsigned integralWidth) {
  const int64_t defaultMinimum =
      quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
                                                        integralWidth);
  return defaultMinimum;
}
/// Forwards to quant::QuantizedType::getDefaultMaximumForInteger for the given
/// signedness and integral bit width and returns its result.
int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
                                                     unsigned integralWidth) {
  const int64_t defaultMaximum =
      quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
                                                        integralWidth);
  return defaultMaximum;
}
/// Returns the expressed type of `type`, which must be a quant::QuantizedType
/// (enforced by cast<>).
MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return wrap(quantizedType.getExpressedType());
}
/// Returns the flags of `type`, which must be a quant::QuantizedType
/// (enforced by cast<>).
unsigned mlirQuantizedTypeGetFlags(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return quantizedType.getFlags();
}
/// Returns the result of isSigned() on `type`, which must be a
/// quant::QuantizedType (enforced by cast<>).
bool mlirQuantizedTypeIsSigned(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return quantizedType.isSigned();
}
/// Returns the storage type of `type`, which must be a quant::QuantizedType
/// (enforced by cast<>).
MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return wrap(quantizedType.getStorageType());
}
/// Returns getStorageTypeMin() of `type`, which must be a
/// quant::QuantizedType (enforced by cast<>).
int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return quantizedType.getStorageTypeMin();
}
/// Returns getStorageTypeMax() of `type`, which must be a
/// quant::QuantizedType (enforced by cast<>).
int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return quantizedType.getStorageTypeMax();
}
/// Returns getStorageTypeIntegralWidth() of `type`, which must be a
/// quant::QuantizedType (enforced by cast<>).
unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  return quantizedType.getStorageTypeIntegralWidth();
}
/// Asks the quantized `type` whether `candidate` is a compatible expressed
/// type. `type` must be a quant::QuantizedType (enforced by cast<>).
bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
                                                MlirType candidate) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  auto candidateType = unwrap(candidate);
  return quantizedType.isCompatibleExpressedType(candidateType);
}
/// Wraps the result of quant::QuantizedType::getQuantizedElementType applied
/// to the unwrapped `type`. No cast is performed on the input here.
MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
  auto inputType = unwrap(type);
  return wrap(quant::QuantizedType::getQuantizedElementType(inputType));
}
/// Calls castFromStorageType(candidate) on the quantized `type` and wraps the
/// result. `type` must be a quant::QuantizedType (enforced by cast<>).
MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
                                              MlirType candidate) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  auto candidateType = unwrap(candidate);
  return wrap(quantizedType.castFromStorageType(candidateType));
}
/// Converts `type` to its storage-type-based counterpart by forwarding to the
/// static quant::QuantizedType::castToStorageType and wrapping the result.
///
/// The unwrapped type is handed over as-is, the same way
/// mlirQuantizedTypeCastToExpressedType forwards to its static counterpart.
/// No cast<quant::QuantizedType> is applied first: a hard cast would assert on
/// any input that is not itself a QuantizedType before the callee gets a
/// chance to inspect it.
/// NOTE(review): the callee is expected to return a null type for inputs it
/// cannot convert — confirm against QuantTypes.h.
MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
  return wrap(quant::QuantizedType::castToStorageType(unwrap(type)));
}
/// Calls castFromExpressedType(candidate) on the quantized `type` and wraps
/// the result. `type` must be a quant::QuantizedType (enforced by cast<>).
MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
                                                MlirType candidate) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  auto candidateType = unwrap(candidate);
  return wrap(quantizedType.castFromExpressedType(candidateType));
}
/// Wraps the result of the static quant::QuantizedType::castToExpressedType
/// applied to the unwrapped `type`. No cast is performed on the input here.
MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
  auto inputType = unwrap(type);
  return wrap(quant::QuantizedType::castToExpressedType(inputType));
}
/// Calls castExpressedToStorageType(candidate) on the quantized `type` and
/// wraps the result. `type` must be a quant::QuantizedType (enforced by
/// cast<>).
MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
                                                     MlirType candidate) {
  auto quantizedType = cast<quant::QuantizedType>(unwrap(type));
  auto candidateType = unwrap(candidate);
  return wrap(quantizedType.castExpressedToStorageType(candidateType));
}
//===---------------------------------------------------------------------===//
// AnyQuantizedType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a quant::AnyQuantizedType.
bool mlirTypeIsAAnyQuantizedType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::AnyQuantizedType>(cppType);
}
/// Builds a quant::AnyQuantizedType from the given flags, storage type,
/// expressed type and storage-value bounds, and returns it wrapped.
MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
                                 MlirType expressedType, int64_t storageTypeMin,
                                 int64_t storageTypeMax) {
  auto storage = unwrap(storageType);
  auto expressed = unwrap(expressedType);
  return wrap(quant::AnyQuantizedType::get(flags, storage, expressed,
                                           storageTypeMin, storageTypeMax));
}
//===---------------------------------------------------------------------===//
// UniformQuantizedType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a quant::UniformQuantizedType.
bool mlirTypeIsAUniformQuantizedType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::UniformQuantizedType>(cppType);
}
/// Builds a quant::UniformQuantizedType from the given flags, storage type,
/// expressed type, scale, zero point and storage-value bounds, and returns it
/// wrapped.
MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
                                     MlirType expressedType, double scale,
                                     int64_t zeroPoint, int64_t storageTypeMin,
                                     int64_t storageTypeMax) {
  auto storage = unwrap(storageType);
  auto expressed = unwrap(expressedType);
  return wrap(quant::UniformQuantizedType::get(flags, storage, expressed, scale,
                                               zeroPoint, storageTypeMin,
                                               storageTypeMax));
}
/// Returns the scale of `type`, which must be a quant::UniformQuantizedType
/// (enforced by cast<>).
double mlirUniformQuantizedTypeGetScale(MlirType type) {
  auto uniformType = cast<quant::UniformQuantizedType>(unwrap(type));
  return uniformType.getScale();
}
/// Returns the zero point of `type`, which must be a
/// quant::UniformQuantizedType (enforced by cast<>).
int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
  auto uniformType = cast<quant::UniformQuantizedType>(unwrap(type));
  return uniformType.getZeroPoint();
}
/// Returns the result of isFixedPoint() on `type`, which must be a
/// quant::UniformQuantizedType (enforced by cast<>).
bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
  auto uniformType = cast<quant::UniformQuantizedType>(unwrap(type));
  return uniformType.isFixedPoint();
}
//===---------------------------------------------------------------------===//
// UniformQuantizedPerAxisType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a
/// quant::UniformQuantizedPerAxisType.
bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::UniformQuantizedPerAxisType>(cppType);
}
/// Builds a quant::UniformQuantizedPerAxisType and returns it wrapped.
///
/// `scales` and `zeroPoints` are each viewed as arrays of `nDims` elements;
/// the remaining arguments are forwarded unchanged to the C++ builder.
MlirType mlirUniformQuantizedPerAxisTypeGet(
    unsigned flags, MlirType storageType, MlirType expressedType,
    intptr_t nDims, double *scales, int64_t *zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // Non-owning views over the caller-provided buffers.
  llvm::ArrayRef<double> scaleValues(scales, nDims);
  llvm::ArrayRef<int64_t> zeroPointValues(zeroPoints, nDims);

  auto storage = unwrap(storageType);
  auto expressed = unwrap(expressedType);
  return wrap(quant::UniformQuantizedPerAxisType::get(
      flags, storage, expressed, scaleValues, zeroPointValues,
      quantizedDimension, storageTypeMin, storageTypeMax));
}
/// Returns the number of scales stored in `type`, which must be a
/// quant::UniformQuantizedPerAxisType (enforced by cast<>).
intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
  auto perAxisType = cast<quant::UniformQuantizedPerAxisType>(unwrap(type));
  return perAxisType.getScales().size();
}
/// Returns the scale at index `pos` of `type`, which must be a
/// quant::UniformQuantizedPerAxisType (enforced by cast<>). No bounds check
/// on `pos` is performed here.
double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
  auto perAxisType = cast<quant::UniformQuantizedPerAxisType>(unwrap(type));
  return perAxisType.getScales()[pos];
}
/// Returns the zero point at index `pos` of `type`, which must be a
/// quant::UniformQuantizedPerAxisType (enforced by cast<>). No bounds check
/// on `pos` is performed here.
int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
                                                    intptr_t pos) {
  auto perAxisType = cast<quant::UniformQuantizedPerAxisType>(unwrap(type));
  return perAxisType.getZeroPoints()[pos];
}
/// Returns the quantized dimension of `type`, which must be a
/// quant::UniformQuantizedPerAxisType (enforced by cast<>).
int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
  auto perAxisType = cast<quant::UniformQuantizedPerAxisType>(unwrap(type));
  return perAxisType.getQuantizedDimension();
}
/// Returns the result of isFixedPoint() on `type`, which must be a
/// quant::UniformQuantizedPerAxisType (enforced by cast<>).
bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
  auto perAxisType = cast<quant::UniformQuantizedPerAxisType>(unwrap(type));
  return perAxisType.isFixedPoint();
}
//===---------------------------------------------------------------------===//
// UniformQuantizedSubChannelType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a
/// quant::UniformQuantizedSubChannelType.
bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::UniformQuantizedSubChannelType>(cppType);
}
/// Builds a quant::UniformQuantizedSubChannelType and returns it wrapped.
///
/// `scalesAttr` and `zeroPointsAttr` must both be DenseElementsAttr values; if
/// either handle is null or holds a different attribute kind, a null MlirType
/// is returned instead. `quantizedDimensions` and `blockSizes` are each viewed
/// as arrays of `nDims` elements. All other arguments are forwarded unchanged
/// to the C++ builder.
MlirType mlirUniformQuantizedSubChannelTypeGet(
    unsigned flags, MlirType storageType, MlirType expressedType,
    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
    int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // dyn_cast<> asserts when given a non-present (null) value, so reject null
  // attribute handles up front and report them the same way as a kind
  // mismatch: with a null result type.
  auto rawScales = unwrap(scalesAttr);
  auto rawZeroPoints = unwrap(zeroPointsAttr);
  if (!rawScales || !rawZeroPoints)
    return {};

  auto scales = dyn_cast<mlir::DenseElementsAttr>(rawScales);
  auto zeroPoints = dyn_cast<mlir::DenseElementsAttr>(rawZeroPoints);
  if (!scales || !zeroPoints)
    return {};

  return wrap(quant::UniformQuantizedSubChannelType::get(
      flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints,
      llvm::ArrayRef<int32_t>(quantizedDimensions, nDims),
      llvm::ArrayRef<int64_t>(blockSizes, nDims), storageTypeMin,
      storageTypeMax));
}
/// Returns the number of block sizes stored in `type`, which must be a
/// quant::UniformQuantizedSubChannelType (enforced by cast<>).
intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return subChannelType.getBlockSizes().size();
}
/// Returns the quantized dimension at index `pos` of `type`, which must be a
/// quant::UniformQuantizedSubChannelType (enforced by cast<>). No bounds check
/// on `pos` is performed here.
int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
                                                                intptr_t pos) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return subChannelType.getQuantizedDimensions()[pos];
}
/// Returns the block size at index `pos` of `type`, which must be a
/// quant::UniformQuantizedSubChannelType (enforced by cast<>). No bounds check
/// on `pos` is performed here.
int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type,
                                                       intptr_t pos) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return subChannelType.getBlockSizes()[pos];
}
/// Returns the scales of `type` as a wrapped attribute. `type` must be a
/// quant::UniformQuantizedSubChannelType (enforced by cast<>).
MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return wrap(subChannelType.getScales());
}
/// Returns the zero points of `type` as a wrapped attribute. `type` must be a
/// quant::UniformQuantizedSubChannelType (enforced by cast<>).
MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return wrap(subChannelType.getZeroPoints());
}
//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
/// Returns true when the unwrapped `type` is a
/// quant::CalibratedQuantizedType.
bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<quant::CalibratedQuantizedType>(cppType);
}
/// Builds a quant::CalibratedQuantizedType from the expressed type and the
/// `min`/`max` values, and returns it wrapped.
MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
                                        double max) {
  auto expressed = unwrap(expressedType);
  return wrap(quant::CalibratedQuantizedType::get(expressed, min, max));
}
/// Returns getMin() of `type`, which must be a quant::CalibratedQuantizedType
/// (enforced by cast<>).
double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
  auto calibratedType = cast<quant::CalibratedQuantizedType>(unwrap(type));
  return calibratedType.getMin();
}
/// Returns getMax() of `type`, which must be a quant::CalibratedQuantizedType
/// (enforced by cast<>).
double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
  auto calibratedType = cast<quant::CalibratedQuantizedType>(unwrap(type));
  return calibratedType.getMax();
}