My Project
UniformSupport.h
Go to the documentation of this file.
1 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
10 #define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
11 
13 #include "mlir/IR/StandardTypes.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/APInt.h"
17 #include "llvm/ADT/APSInt.h"
18 
19 namespace mlir {
20 namespace quant {
21 
35 
38  Type convert(QuantizedType elementalType) const;
39 
41  explicit operator bool() const { return (bool)expressedType; }
42 
45  const Type inputType;
46 
50 };
51 
60 public:
63  uniformType.getScale(),
64  static_cast<double>(uniformType.getZeroPoint()),
65  static_cast<double>(uniformType.getStorageTypeMin()),
66  static_cast<double>(uniformType.getStorageTypeMax()),
67  uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
68  assert(uniformType.getExpressedType().isa<FloatType>());
69  assert(uniformType.getStorageType().isa<IntegerType>());
70  }
71 
72  UniformQuantizedValueConverter(double scale, double zeroPoint,
73  double clampMin, double clampMax,
74  uint32_t storageBitWidth, bool isSigned)
75  : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
76  clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
77  clampMinDouble(clampMin), clampMaxDouble(clampMax),
78  storageBitWidth(storageBitWidth), isSigned(isSigned),
79  roundMode(APFloat::rmNearestTiesToAway) {}
80 
81  UniformQuantizedValueConverter(double scale, double zeroPoint,
82  APFloat clampMin, APFloat clampMax,
83  uint32_t storageBitWidth, bool isSigned)
84  : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
85  clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
86  clampMinDouble(clampMin.convertToDouble()),
87  clampMaxDouble(clampMax.convertToDouble()),
88  storageBitWidth(storageBitWidth), isSigned(isSigned),
89  roundMode(APFloat::rmNearestTiesToAway) {}
90 
91  virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
92  // This function is a performance critical code path in quantization
93  // since it runs for each single float parameter value.
94 
95  // Specialize f32->u8/i8 case to optimize performance.
96  if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
97  storageBitWidth == 8 &&
98  roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
99  return quantizeF32ToInt8(expressedValue);
100  }
101 
102  bool lossy;
103  expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
104  // fixedpoint = clamp(clampMin, clampMax, (
105  // roundHalfToEven(expressed / scale) + zeroPoint))
106  APFloat scaled = (expressedValue / scale);
107  scaled.roundToIntegral(roundMode);
108  scaled.add(zeroPoint, roundMode);
109  APFloat fixedpoint = llvm::minimum(scaled, clampMax);
110  fixedpoint = llvm::maximum(fixedpoint, clampMin);
111 
112  llvm::APSInt result(storageBitWidth, !isSigned);
113  fixedpoint.convertToInteger(result, roundMode, &lossy);
114 
115  return std::move(result);
116  }
117 
118  int64_t quantizeFloatToInt64(APFloat expressedValue) const {
119  APInt qValue = quantizeFloatToInt(expressedValue);
120  return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
121  }
122 
124 
125 private:
126  // An optimized implementation to quantize f32 to i8/u8 with C++ native
127  // arithmetic.
128  virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
129  assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
130  assert(storageBitWidth == 8);
131  assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
132 
133  const float realValue = expressedValue.convertToFloat();
134 
135  const double scaled = realValue / scaleDouble + zeroPointDouble;
136  // Round to nearest integer with halfway cases rounded away from zero.
137  const double scaledRounded = std::round(scaled);
138  const double clamped =
139  std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);
140 
141  uint64_t signlessResult;
142  if (isSigned) {
143  int64_t clampedInt = static_cast<int8_t>(clamped);
144  memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
145  } else {
146  signlessResult = static_cast<uint8_t>(clamped);
147  }
148  return APInt(storageBitWidth, signlessResult);
149  }
150 
151  // Keep both APFloat and double versions of the quantization parameters
152  // around since they will be used in generic and specialized arithmetic,
153  // respectively.
154  const APFloat scale;
155  const APFloat zeroPoint;
156  const APFloat clampMin;
157  const APFloat clampMax;
158 
159  const double scaleDouble;
160  const double zeroPointDouble;
161  const double clampMinDouble;
162  const double clampMaxDouble;
163 
164  const uint32_t storageBitWidth;
165  const bool isSigned;
166  const llvm::APFloat::roundingMode roundMode;
167 };
168 
174 public:
176  UniformQuantizedPerAxisType uniformType)
177  : scales(uniformType.getScales()),
178  zeroPoints(uniformType.getZeroPoints()),
179  clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
180  clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
181  storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
182  isSigned(uniformType.isSigned()),
183  quantizationDim(uniformType.getQuantizedDimension()) {
184  assert(uniformType.getExpressedType().isa<FloatType>());
185  assert(uniformType.getStorageType().isa<IntegerType>());
186  assert(scales.size() == zeroPoints.size());
187  }
188 
191  ElementsAttr convert(Attribute realValue);
192 
193 private:
196 
199  UniformQuantizedValueConverter getPerChunkConverter(int index) const {
200  UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
201  clampMin, clampMax,
202  storageBitWidth, isSigned);
203  return converter;
204  }
205 
206  const ArrayRef<double> scales;
207  const ArrayRef<int64_t> zeroPoints;
208  const APFloat clampMin;
209  const APFloat clampMax;
210  const uint32_t storageBitWidth;
211  const bool isSigned;
212  int32_t quantizationDim;
213 };
214 
215 } // namespace quant
216 } // namespace mlir
217 
218 #endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
Definition: InferTypeOpInterface.cpp:20
Definition: Attributes.h:976
virtual ~UniformQuantizedValueConverter()
Definition: UniformSupport.h:123
Integer types can have arbitrary bitwidth up to a large fixed limit.
Definition: StandardTypes.h:82
Definition: QuantTypes.h:270
Definition: UniformSupport.h:59
Definition: StandardTypes.h:113
UniformQuantizedValueConverter(double scale, double zeroPoint, APFloat clampMin, APFloat clampMax, uint32_t storageBitWidth, bool isSigned)
Definition: UniformSupport.h:81
int64_t quantizeFloatToInt64(APFloat expressedValue) const
Definition: UniformSupport.h:118
Definition: LLVM.h:37
const Type expressedType
Definition: UniformSupport.h:49
Definition: Attributes.h:660
UniformQuantizedValueConverter(UniformQuantizedType uniformType)
Definition: UniformSupport.h:61
Definition: UniformSupport.h:32
Definition: Attributes.h:53
Definition: QuantTypes.h:331
UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)
Definition: UniformSupport.h:175
Definition: Types.h:84
UniformQuantizedValueConverter(double scale, double zeroPoint, double clampMin, double clampMax, uint32_t storageBitWidth, bool isSigned)
Definition: UniformSupport.h:72
static const ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.
Definition: UniformSupport.cpp:21
Definition: QuantTypes.h:60
Type convert(QuantizedType elementalType) const
Definition: UniformSupport.cpp:44
virtual APInt quantizeFloatToInt(APFloat expressedValue) const
Definition: UniformSupport.h:91
const Type inputType
Definition: UniformSupport.h:45
Definition: Attributes.h:559