llvm-project/mlir/lib/Bindings/Python/DialectQuant.cpp
Sandeep Dasgupta 81d7eef134
Sub-channel quantized type implementation (#120172)
This is an implementation for [RFC: Supporting Sub-Channel Quantization
in
MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).

In order to make the review process easier, the PR has been divided into
the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class
design for `UniformQuantizedSubChannelType`, printer/parser and bytecode
read/write support. The existing types (per-tensor and per-axis) are
unaltered.
2. **Add implementation for sub-channel type:** Lowering of
`quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python APIs:** We first define the C-APIs and build the
Python-APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes
sub-channel quantized types to per-tensor or per-axis types, if possible.


A design note:
- **Explicitly storing the `quantized_dimensions`, even when they can be
derived for ranked tensors.**
While it's possible to infer quantized dimensions from the static shape
of the scales (or zero-points) tensor for ranked
data tensors
([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3)
for background), there are cases where this can lead to ambiguity and
issues with round-tripping.

```
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 0:2}, {{s00:z00, s01:z01}}>>
```

The shape of the scales tensor is [1, 2], which might suggest that only
axis 1 is quantized. While this inference is technically correct, as the
block size for axis 0 is a degenerate case (equal to the dimension
size), it can cause problems with round-tripping. Therefore, even for
ranked tensors, we are explicitly storing the quantized dimensions.
Suggestions welcome!


PS: I understand that the upcoming holidays may impact your schedule, so
please take your time with the review. There's no rush.
2025-03-23 07:37:55 -05:00

390 lines
17 KiB
C++

//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <vector>
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
using namespace mlir;
using namespace mlir::python::nanobind_adaptors;
/// Registers the Python classes for the 'quant' dialect types on module `m`.
///
/// The classes are created in this order: QuantizedType (the base class),
/// then AnyQuantizedType, UniformQuantizedType, UniformQuantizedPerAxisType,
/// UniformQuantizedSubChannelType and CalibratedQuantizedType, each of which
/// is registered with QuantizedType's Python class as its parent. Every
/// binding is a thin lambda that forwards to the matching `mlir*` C API
/// function.
///
/// Errors surfaced to Python:
///  - the `cast_*` methods raise TypeError when the C API returns a null type;
///  - the `get` factories of UniformQuantizedPerAxisType and
///    UniformQuantizedSubChannelType raise ValueError when their paired list
///    arguments differ in length.
static void populateDialectQuantSubmodule(const nb::module_ &m) {
  //===-------------------------------------------------------------------===//
  // QuantizedType
  //===-------------------------------------------------------------------===//

  auto quantizedType =
      mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
  quantizedType.def_staticmethod(
      "default_minimum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default minimum value for the integer with the specified signedness and "
      "bit width.",
      nb::arg("is_signed"), nb::arg("integral_width"));
  quantizedType.def_staticmethod(
      "default_maximum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default maximum value for the integer with the specified signedness and "
      "bit width.",
      nb::arg("is_signed"), nb::arg("integral_width"));
  quantizedType.def_property_readonly(
      "expressed_type",
      [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
      "Type expressed by this quantized type.");
  quantizedType.def_property_readonly(
      "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
      "Flags of this quantized type (named accessors should be preferred to "
      "this)");
  quantizedType.def_property_readonly(
      "is_signed",
      [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
      "Signedness of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type",
      [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
      "Storage type backing this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_min",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
      "The minimum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_max",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
      "The maximum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_integral_width",
      [](MlirType type) {
        return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
      },
      "The bitwidth of the storage type of this quantized type.");
  quantizedType.def(
      "is_compatible_expressed_type",
      [](MlirType type, MlirType candidate) {
        return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
      },
      "Checks whether the candidate type can be expressed by this quantized "
      "type.",
      nb::arg("candidate"));
  quantizedType.def_property_readonly(
      "quantized_element_type",
      [](MlirType type) {
        return mlirQuantizedTypeGetQuantizedElementType(type);
      },
      "Element type of this quantized type expressed as quantized type.");

  // The cast helpers below share one pattern: a null result from the C API is
  // turned into a Python TypeError instead of being returned to the caller.
  quantizedType.def(
      "cast_from_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw nb::type_error("Invalid cast.");
      },
      "Casts from a type based on the storage type of this quantized type to a "
      "corresponding type based on the quantized type. Raises TypeError if the "
      "cast is not valid.",
      nb::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_storage_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw nb::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the storage type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      nb::arg("type"));
  quantizedType.def(
      "cast_from_expressed_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromExpressedType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw nb::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the quantized type. Raises TypeError if "
      "the cast is not valid.",
      nb::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_expressed_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw nb::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the expressed type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      nb::arg("type"));
  quantizedType.def(
      "cast_expressed_to_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw nb::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the storage type. Raises TypeError if the "
      "cast is not valid.",
      nb::arg("candidate"));

  // Expose the signed flag value as a class-level attribute.
  quantizedType.get_class().attr("FLAG_SIGNED") =
      mlirQuantizedTypeGetSignedFlag();

  //===-------------------------------------------------------------------===//
  // AnyQuantizedType
  //===-------------------------------------------------------------------===//

  auto anyQuantizedType =
      mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
                         quantizedType.get_class());
  anyQuantizedType.def_classmethod(
      "get",
      [](nb::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, int64_t storageTypeMin,
         int64_t storageTypeMax) {
        return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
                                           storageTypeMin, storageTypeMax));
      },
      "Gets an instance of AnyQuantizedType in the same context as the "
      "provided storage type.",
      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
      nb::arg("expressed_type"), nb::arg("storage_type_min"),
      nb::arg("storage_type_max"));

  //===-------------------------------------------------------------------===//
  // UniformQuantizedType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedType = mlir_type_subclass(
      m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
      quantizedType.get_class());
  uniformQuantizedType.def_classmethod(
      "get",
      [](nb::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, double scale, int64_t zeroPoint,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        return cls(mlirUniformQuantizedTypeGet(flags, storageType,
                                               expressedType, scale, zeroPoint,
                                               storageTypeMin, storageTypeMax));
      },
      "Gets an instance of UniformQuantizedType in the same context as the "
      "provided storage type.",
      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
      nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
  uniformQuantizedType.def_property_readonly(
      "scale",
      [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
      "The scale designates the difference between the real values "
      "corresponding to consecutive quantized values differing by 1.");
  uniformQuantizedType.def_property_readonly(
      "zero_point",
      [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
      "The storage value corresponding to the real value 0 in the affine "
      "equation.");
  uniformQuantizedType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // UniformQuantizedPerAxisType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedPerAxisType = mlir_type_subclass(
      m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
      quantizedType.get_class());
  uniformQuantizedPerAxisType.def_classmethod(
      "get",
      [](nb::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, std::vector<double> scales,
         std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        // A single count (nDims) is handed to the C API for both arrays, so
        // they must have the same length.
        if (scales.size() != zeroPoints.size())
          throw nb::value_error(
              "Mismatching number of scales and zero points.");
        auto nDims = static_cast<intptr_t>(scales.size());
        return cls(mlirUniformQuantizedPerAxisTypeGet(
            flags, storageType, expressedType, nDims, scales.data(),
            zeroPoints.data(), quantizedDimension, storageTypeMin,
            storageTypeMax));
      },
      "Gets an instance of UniformQuantizedPerAxisType in the same context as "
      "the provided storage type.",
      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
      nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
      nb::arg("storage_type_max"));
  uniformQuantizedPerAxisType.def_property_readonly(
      "scales",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<double> scales;
        scales.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
          scales.push_back(scale);
        }
        return scales;
      },
      "The scales designate the difference between the real values "
      "corresponding to consecutive quantized values differing by 1. The ith "
      "scale corresponds to the ith slice in the quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "zero_points",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<int64_t> zeroPoints;
        zeroPoints.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          int64_t zeroPoint =
              mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
          zeroPoints.push_back(zeroPoint);
        }
        return zeroPoints;
      },
      "the storage values corresponding to the real value 0 in the affine "
      "equation. The ith zero point corresponds to the ith slice in the "
      "quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "quantized_dimension",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
      },
      "Specifies the dimension of the shape that the scales and zero points "
      "correspond to.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
      },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // UniformQuantizedSubChannelType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedSubChannelType = mlir_type_subclass(
      m, "UniformQuantizedSubChannelType",
      mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
  uniformQuantizedSubChannelType.def_classmethod(
      "get",
      [](nb::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
         std::vector<int32_t> quantizedDimensions,
         std::vector<int64_t> blockSizes, int64_t storageTypeMin,
         int64_t storageTypeMax) {
        // Only one element count (blockSizes.size()) is passed to the C API
        // alongside both array pointers. Reject unequal lengths up front:
        // otherwise a shorter quantizedDimensions would be indexed past its
        // end and a longer one would have trailing entries silently ignored.
        if (quantizedDimensions.size() != blockSizes.size())
          throw nb::value_error(
              "Mismatching number of quantized dimensions and block sizes.");
        auto nDims = static_cast<intptr_t>(blockSizes.size());
        return cls(mlirUniformQuantizedSubChannelTypeGet(
            flags, storageType, expressedType, scales, zeroPoints, nDims,
            quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
            storageTypeMax));
      },
      "Gets an instance of UniformQuantizedSubChannel in the same context as "
      "the provided storage type.",
      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
      nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
  uniformQuantizedSubChannelType.def_property_readonly(
      "quantized_dimensions",
      [](MlirType type) {
        // The number of block sizes is also used as the number of quantized
        // dimensions to read back.
        intptr_t nDim =
            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
        std::vector<int32_t> quantizedDimensions;
        quantizedDimensions.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          quantizedDimensions.push_back(
              mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
        }
        return quantizedDimensions;
      },
      "Gets the quantized dimensions. Each element in the returned list "
      "represents an axis of the quantized data tensor that has a specified "
      "block size. The order of elements corresponds to the order of block "
      "sizes returned by 'block_sizes' method. It means that the data tensor "
      "is quantized along the i-th dimension in the returned list using the "
      "i-th block size from block_sizes method.");
  uniformQuantizedSubChannelType.def_property_readonly(
      "block_sizes",
      [](MlirType type) {
        intptr_t nDim =
            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
        std::vector<int64_t> blockSizes;
        blockSizes.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i) {
          blockSizes.push_back(
              mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
        }
        return blockSizes;
      },
      "Gets the block sizes for the quantized dimensions. The i-th element in "
      "the returned list corresponds to the block size for the i-th dimension "
      "in the list returned by quantized_dimensions method.");
  uniformQuantizedSubChannelType.def_property_readonly(
      "scales",
      [](MlirType type) -> MlirAttribute {
        return mlirUniformQuantizedSubChannelTypeGetScales(type);
      },
      "The scales of the quantized type.");
  uniformQuantizedSubChannelType.def_property_readonly(
      "zero_points",
      [](MlirType type) -> MlirAttribute {
        return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
      },
      "The zero points of the quantized type.");

  //===-------------------------------------------------------------------===//
  // CalibratedQuantizedType
  //===-------------------------------------------------------------------===//

  auto calibratedQuantizedType = mlir_type_subclass(
      m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
      quantizedType.get_class());
  calibratedQuantizedType.def_classmethod(
      "get",
      [](nb::object cls, MlirType expressedType, double min, double max) {
        return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
      },
      "Gets an instance of CalibratedQuantizedType in the same context as the "
      "provided expressed type.",
      nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
      nb::arg("max"));
  calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
    return mlirCalibratedQuantizedTypeGetMin(type);
  });
  calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
    return mlirCalibratedQuantizedTypeGetMax(type);
  });
}
// Entry point of the `_mlirDialectsQuant` extension module: sets the module
// docstring, then registers all 'quant' dialect type bindings on it.
NB_MODULE(_mlirDialectsQuant, m) {
  constexpr const char *kModuleDoc = "MLIR Quantization dialect";
  m.doc() = kModuleDoc;

  populateDialectQuantSubmodule(m);
}