My Project
QuantTypes.h
Go to the documentation of this file.
1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
10 #define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
23 class QuantizedIntegerType;
24 
25 namespace detail {
26 
27 struct QuantizedTypeStorage;
28 struct AnyQuantizedTypeStorage;
29 struct UniformQuantizedTypeStorage;
30 struct UniformQuantizedPerAxisTypeStorage;
31 
32 } // namespace detail
33 
34 namespace QuantizationTypes {
35 enum Kind {
36  Any = Type::FIRST_QUANTIZATION_TYPE,
40 };
41 } // namespace QuantizationTypes
42 
/// Flag bits stored in a quantized type's `flags` word; individual bits are
/// tested with bitwise AND (see QuantizedType::isSigned).
namespace QuantizationFlags {
/// When set, the storage type is interpreted as a signed integer; when clear
/// (the default) it is interpreted as an unsigned value.
enum FlagValue { Signed = 1 };
} // namespace QuantizationFlags
51 
60 class QuantizedType : public Type {
61 public:
63  using Type::Type;
64 
66  static constexpr unsigned MaxStorageBits = 32;
67 
68  static LogicalResult
69  verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
70  unsigned flags, Type storageType,
71  Type expressedType, int64_t storageTypeMin,
72  int64_t storageTypeMax);
73 
75  static bool classof(Type type) {
76  return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
78  }
79 
82  static int64_t getDefaultMinimumForInteger(bool isSigned,
83  unsigned integralWidth) {
84  if (isSigned) {
85  return llvm::minIntN(integralWidth);
86  }
87  return 0;
88  }
89 
92  static int64_t getDefaultMaximumForInteger(bool isSigned,
93  unsigned integralWidth) {
94  if (isSigned) {
95  return llvm::maxIntN(integralWidth);
96  }
97  return llvm::maxUIntN(integralWidth);
98  }
99 
108  Type getExpressedType() const;
109 
112  unsigned getFlags() const;
113 
114  // Convenience helpers.
117  bool isSigned() const {
118  return (getFlags() & QuantizationFlags::Signed) ==
120  }
121 
124  Type getStorageType() const;
125 
127  int64_t getStorageTypeMin() const;
128 
130  int64_t getStorageTypeMax() const;
131 
134  unsigned getStorageTypeIntegralWidth() const;
135 
143  bool isCompatibleExpressedType(Type candidateExpressedType);
144 
151  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
152 
159  Type castFromStorageType(Type candidateType);
160 
164  static Type castToStorageType(Type quantizedType);
165 
172  Type castFromExpressedType(Type candidateType);
173 
177  static Type castToExpressedType(Type quantizedType);
178 
185  Type castExpressedToStorageType(Type candidateType);
186 
187 private:
192  using Type::isBF16;
193  using Type::isF16;
194  using Type::isF32;
195  using Type::isF64;
196  using Type::isIndex;
197  using Type::isInteger;
198 };
199 
210  : public Type::TypeBase<AnyQuantizedType, QuantizedType,
211  detail::AnyQuantizedTypeStorage> {
212 public:
213  using Base::Base;
214 
216  static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; }
217 
220  static AnyQuantizedType get(unsigned flags, Type storageType,
221  Type expressedType, int64_t storageTypeMin,
222  int64_t storageTypeMax);
223 
226  static AnyQuantizedType getChecked(unsigned flags, Type storageType,
227  Type expressedType, int64_t storageTypeMin,
228  int64_t storageTypeMax, Location location);
229 
231  static LogicalResult
232  verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
233  unsigned flags, Type storageType,
234  Type expressedType, int64_t storageTypeMin,
235  int64_t storageTypeMax);
236 };
237 
271  : public Type::TypeBase<UniformQuantizedType, QuantizedType,
272  detail::UniformQuantizedTypeStorage> {
273 public:
274  using Base::Base;
275 
278  static UniformQuantizedType get(unsigned flags, Type storageType,
279  Type expressedType, double scale,
280  int64_t zeroPoint, int64_t storageTypeMin,
281  int64_t storageTypeMax);
282 
285  static UniformQuantizedType
286  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
287  int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
288  Location location);
289 
291  static LogicalResult verifyConstructionInvariants(
292  Optional<Location> loc, MLIRContext *context, unsigned flags,
293  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
294  int64_t storageTypeMin, int64_t storageTypeMax);
295 
297  static bool kindof(unsigned kind) {
299  }
300 
303  double getScale() const;
304 
307  int64_t getZeroPoint() const;
308 
309  // Fixed point values are real numbers divided by a scale.
310  // Currently, only signed storage types are treated as fixed point.
311  // A fixed point value can be obtained from an affine value by subtracting
312  // the zeroPoint.
313  // In the future, this may be explicit versus implied by type and zeroPoint.
314  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
315 };
316 
332  : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
333  detail::UniformQuantizedPerAxisTypeStorage> {
334 public:
335  using Base::Base;
336 
340  get(unsigned flags, Type storageType, Type expressedType,
341  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
342  int32_t quantizedDimension, int64_t storageTypeMin,
343  int64_t storageTypeMax);
344 
348  getChecked(unsigned flags, Type storageType, Type expressedType,
349  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
350  int32_t quantizedDimension, int64_t storageTypeMin,
351  int64_t storageTypeMax, Location location);
352 
354  static LogicalResult verifyConstructionInvariants(
355  Optional<Location> loc, MLIRContext *context, unsigned flags,
356  Type storageType, Type expressedType, ArrayRef<double> scales,
357  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
358  int64_t storageTypeMin, int64_t storageTypeMax);
359 
361  static bool kindof(unsigned kind) {
363  }
364 
369  ArrayRef<double> getScales() const;
370 
374  ArrayRef<int64_t> getZeroPoints() const;
375 
384  int32_t getQuantizedDimension() const;
385 
391  bool isFixedPoint() const {
392  if (!isSigned())
393  return false;
394  return llvm::all_of(getZeroPoints(),
395  [](int64_t zeroPoint) { return zeroPoint != 0; });
396  }
397 };
398 
399 } // namespace quant
400 } // namespace mlir
401 
402 #endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
Definition: InferTypeOpInterface.cpp:20
FlagValue
Definition: QuantTypes.h:45
static bool kindof(unsigned kind)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.h:361
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Definition: QuantTypes.h:82
Definition: QuantTypes.h:270
Definition: QuantTypes.h:36
Definition: LLVM.h:40
Definition: Location.h:52
bool isIndex()
Definition: StandardTypes.cpp:30
static bool kindof(unsigned kind)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.h:297
Definition: LogicalResult.h:18
Definition: LLVM.h:37
unsigned getKind() const
Return the classification for this type.
Definition: Types.cpp:22
bool isFixedPoint() const
Definition: QuantTypes.h:391
Definition: QuantTypes.h:48
bool isF32()
Definition: StandardTypes.cpp:27
Definition: QuantTypes.h:209
bool isBF16()
Definition: StandardTypes.cpp:25
Type()
Definition: Types.h:111
Definition: QuantTypes.h:331
bool isSigned() const
Definition: QuantTypes.h:117
Definition: Types.h:84
Kind
Definition: QuantTypes.h:35
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.h:75
Definition: QuantTypes.h:60
Definition: StorageUniquerSupport.h:30
Definition: MLIRContext.h:34
bool isInteger(unsigned width)
Return true if this is an integer type with the specified width.
Definition: StandardTypes.cpp:33
bool isF16()
Definition: StandardTypes.cpp:26
bool isF64()
Definition: StandardTypes.cpp:28
bool isFixedPoint() const
Definition: QuantTypes.h:314
static bool kindof(unsigned kind)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.h:216
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Definition: QuantTypes.h:92