9 #ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ 10 #define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/APInt.h" 17 #include "llvm/ADT/APSInt.h" 63 uniformType.getScale(),
64 static_cast<double>(uniformType.getZeroPoint()),
65 static_cast<double>(uniformType.getStorageTypeMin()),
66 static_cast<double>(uniformType.getStorageTypeMax()),
67 uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
68 assert(uniformType.getExpressedType().isa<
FloatType>());
69 assert(uniformType.getStorageType().isa<
IntegerType>());
73 double clampMin,
double clampMax,
74 uint32_t storageBitWidth,
bool isSigned)
75 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
76 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
77 clampMinDouble(clampMin), clampMaxDouble(clampMax),
78 storageBitWidth(storageBitWidth), isSigned(isSigned),
79 roundMode(APFloat::rmNearestTiesToAway) {}
82 APFloat clampMin, APFloat clampMax,
83 uint32_t storageBitWidth,
bool isSigned)
84 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
85 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
86 clampMinDouble(clampMin.convertToDouble()),
87 clampMaxDouble(clampMax.convertToDouble()),
88 storageBitWidth(storageBitWidth), isSigned(isSigned),
89 roundMode(APFloat::rmNearestTiesToAway) {}
96 if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
97 storageBitWidth == 8 &&
98 roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
99 return quantizeF32ToInt8(expressedValue);
103 expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
106 APFloat scaled = (expressedValue / scale);
107 scaled.roundToIntegral(roundMode);
108 scaled.add(zeroPoint, roundMode);
109 APFloat fixedpoint = llvm::minimum(scaled, clampMax);
110 fixedpoint = llvm::maximum(fixedpoint, clampMin);
112 llvm::APSInt result(storageBitWidth, !isSigned);
113 fixedpoint.convertToInteger(result, roundMode, &lossy);
115 return std::move(result);
119 APInt qValue = quantizeFloatToInt(expressedValue);
120 return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
// Specialized quantization path for a single-precision input stored in 8 bits.
// It uses plain double arithmetic on the cached *Double members.
// Returns the quantized value as an APInt of storageBitWidth bits.
// NOTE(review): this listing is an extracted rendering with embedded source
// line numbers and gaps (source lines 142, 145, 147 and the closing brace are
// not shown) - confirm details against the complete header.
128 virtual APInt quantizeF32ToInt8(APFloat expressedValue)
const {
// Preconditions: IEEE single-precision semantics, 8-bit storage, and
// round-to-nearest with ties away from zero.
129 assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
130 assert(storageBitWidth == 8);
131 assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
133 const float realValue = expressedValue.convertToFloat();
// Divide by the scale and add the zero point, both in double precision.
135 const double scaled = realValue / scaleDouble + zeroPointDouble;
// std::round rounds halfway cases away from zero, which matches the rounding
// mode asserted above.
137 const double scaledRounded = std::round(scaled);
// Clamp the rounded value into [clampMinDouble, clampMaxDouble].
138 const double clamped =
139 std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);
// Raw 64-bit word handed to the APInt constructor below.
141 uint64_t signlessResult;
// Two alternative narrowings follow. The first goes through int8_t, widens to
// int64_t, and copies those 64 bits unchanged into signlessResult. The second
// goes through uint8_t and assigns directly.
// NOTE(review): the statements that select between them are missing from this
// listing - presumably a branch on isSigned; verify against the full source.
143 int64_t clampedInt =
static_cast<int8_t
>(clamped);
144 memcpy(&signlessResult, &clampedInt,
sizeof(clampedInt));
146 signlessResult =
static_cast<uint8_t
>(clamped);
// Build the result as an APInt of storageBitWidth bits from the raw word.
148 return APInt(storageBitWidth, signlessResult);
155 const APFloat zeroPoint;
156 const APFloat clampMin;
157 const APFloat clampMax;
159 const double scaleDouble;
160 const double zeroPointDouble;
161 const double clampMinDouble;
162 const double clampMaxDouble;
164 const uint32_t storageBitWidth;
166 const llvm::APFloat::roundingMode roundMode;
177 : scales(uniformType.getScales()),
178 zeroPoints(uniformType.getZeroPoints()),
179 clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
180 clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
181 storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
182 isSigned(uniformType.isSigned()),
183 quantizationDim(uniformType.getQuantizedDimension()) {
184 assert(uniformType.getExpressedType().isa<
FloatType>());
185 assert(uniformType.getStorageType().isa<
IntegerType>());
186 assert(scales.size() == zeroPoints.size());
202 storageBitWidth, isSigned);
208 const APFloat clampMin;
209 const APFloat clampMax;
210 const uint32_t storageBitWidth;
212 int32_t quantizationDim;
218 #endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ Definition: InferTypeOpInterface.cpp:20
Definition: Attributes.h:976
Integer types can have arbitrary bitwidth up to a large fixed limit.
Definition: StandardTypes.h:82
Definition: StandardTypes.h:113
const Type expressedType
Definition: UniformSupport.h:49
Definition: Attributes.h:660
Definition: UniformSupport.h:32
Definition: Attributes.h:53
static const ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.
Definition: UniformSupport.cpp:21
Definition: QuantTypes.h:60
Type convert(QuantizedType elementalType) const
Definition: UniformSupport.cpp:44
const Type inputType
Definition: UniformSupport.h:45
Definition: Attributes.h:559