My Project
TypeDetail.h
Go to the documentation of this file.
1 //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef TYPE_DETAIL_H_
10 #define TYPE_DETAIL_H_
11 
12 #include "mlir/IR/StandardTypes.h"
13 #include "mlir/IR/TypeSupport.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Hashing.h"
17 #include "llvm/ADT/bit.h"
18 
19 namespace mlir {
20 namespace quant {
21 namespace detail {
22 
25  int64_t storageTypeMin, int64_t storageTypeMax)
26  : flags(flags), storageType(storageType), expressedType(expressedType),
27  storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
28 
30  unsigned flags;
31 
32  // Integral type for the storage point representation.
34 
35  // Floating point type that the quantized type approximates.
37 
38  // The minimum value storageType can take.
39  int64_t storageTypeMin;
40 
41  // The maximum value storageType can take.
42  int64_t storageTypeMax;
43 };
44 
46  struct KeyTy {
48  int64_t storageTypeMin, int64_t storageTypeMax)
49  : flags(flags), storageType(storageType), expressedType(expressedType),
50  storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
51  unsigned flags;
54  int64_t storageTypeMin;
55  int64_t storageTypeMax;
56 
57  // Check for equality of two structures that share KeyTy data members
58  // (by name).
59  template <typename T, typename U>
60  static bool genericIsEqual(const T &lhs, const U &rhs) {
61  return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
62  lhs.expressedType == rhs.expressedType &&
63  lhs.storageTypeMin == rhs.storageTypeMin &&
64  lhs.storageTypeMax == rhs.storageTypeMax;
65  }
66 
67  bool operator==(const KeyTy &other) const {
68  return genericIsEqual(*this, other);
69  }
70 
71  unsigned getHashValue() const {
72  return llvm::hash_combine(flags, storageType, expressedType,
73  storageTypeMin, storageTypeMax);
74  }
75  };
76 
79  key.storageTypeMin, key.storageTypeMax) {}
80 
81  bool operator==(const KeyTy &key) const {
82  return KeyTy::genericIsEqual(*this, key);
83  }
84 
87  const KeyTy &key) {
88  return new (allocator.allocate<AnyQuantizedTypeStorage>())
90  }
91 
92  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
93 };
94 
96  struct KeyTy {
97  KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
98  int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
99  : flags(flags), storageType(storageType), expressedType(expressedType),
100  scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
101  storageTypeMax(storageTypeMax) {}
103  unsigned flags;
104 
105  // Integral type for the storage point representation.
107 
108  // Floating point type that the quantized type approximates.
110 
111  double scale;
112  int64_t zeroPoint;
113  int64_t storageTypeMin;
114  int64_t storageTypeMax;
115 
116  // Check for equality of two structures that share KeyTy data members
117  // (by name).
118  template <typename T, typename U>
119  static bool genericIsEqual(const T &lhs, const U &rhs) {
120  return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
121  lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
122  lhs.zeroPoint == rhs.zeroPoint &&
123  lhs.storageTypeMin == rhs.storageTypeMin &&
124  lhs.storageTypeMax == rhs.storageTypeMax;
125  }
126 
127  bool operator==(const KeyTy &other) const {
128  return genericIsEqual(*this, other);
129  }
130 
131  unsigned getHashValue() const {
132  int64_t scaleBits = llvm::bit_cast<int64_t>(scale);
133  return llvm::hash_combine(flags, storageType, expressedType, scaleBits,
134  zeroPoint, storageTypeMin, storageTypeMax);
135  }
136  };
137 
140  key.storageTypeMin, key.storageTypeMax),
141  scale(key.scale), zeroPoint(key.zeroPoint) {}
142 
143  bool operator==(const KeyTy &key) const {
144  return KeyTy::genericIsEqual(*this, key);
145  }
146 
149  const KeyTy &key) {
150  return new (allocator.allocate<UniformQuantizedTypeStorage>())
152  }
153 
154  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
155 
156  double scale;
157  int64_t zeroPoint;
158 };
159 
161  struct KeyTy {
163  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
164  int32_t quantizedDimension, int64_t storageTypeMin,
165  int64_t storageTypeMax)
166  : flags(flags), storageType(storageType), expressedType(expressedType),
167  scales(scales), zeroPoints(zeroPoints),
168  quantizedDimension(quantizedDimension),
169  storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
171  unsigned flags;
172 
173  // Integral type for the storage point representation.
175 
176  // Floating point type that the quantized type approximates.
178 
182  int64_t storageTypeMin;
183  int64_t storageTypeMax;
184 
185  ArrayRef<double> getScales() const { return scales; }
186 
187  ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
188 
189  // Check for equality of two structures that share KeyTy data members
190  // (by name).
191  template <typename T, typename U>
192  static bool genericIsEqual(const T &lhs, const U &rhs) {
193  return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
194  lhs.expressedType == rhs.expressedType &&
195  lhs.getScales() == rhs.getScales() &&
196  lhs.getZeroPoints() == rhs.getZeroPoints() &&
197  lhs.quantizedDimension == rhs.quantizedDimension &&
198  lhs.storageTypeMin == rhs.storageTypeMin &&
199  lhs.storageTypeMax == rhs.storageTypeMax;
200  }
201 
202  bool operator==(const KeyTy &other) const {
203  return genericIsEqual(*this, other);
204  }
205 
206  unsigned getHashValue() const {
207  int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data());
208  ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
209  return llvm::hash_combine(
210  flags, storageType, expressedType,
211  llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()),
212  llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()),
213  storageTypeMin, storageTypeMax);
214  }
215  };
216 
217  // We pass scales and zeroPoints in directly rather than relying on KeyTy
218  // because we have to create new reallocated versions in `construct` below.
220  ArrayRef<int64_t> zeroPoints)
222  key.storageTypeMin, key.storageTypeMax),
223  scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
224  quantParamsSize(scales.size()),
225  quantizedDimension(key.quantizedDimension) {}
226 
227  bool operator==(const KeyTy &key) const {
228  return KeyTy::genericIsEqual(*this, key);
229  }
230 
233  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
234  ArrayRef<double> scales = allocator.copyInto(key.scales);
235  ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints);
236  return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
237  UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
238  }
239 
240  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
241 
243  return ArrayRef<double>(scaleElements, quantParamsSize);
244  }
245 
247  return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
248  }
249 
250  const double *scaleElements;
251  const int64_t *zeroPointElements;
252  unsigned quantParamsSize;
254 };
255 
256 } // namespace detail
257 } // namespace quant
258 } // namespace mlir
259 
260 #endif // TYPE_DETAIL_H_
Definition: InferTypeOpInterface.cpp:20
UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints)
Definition: TypeDetail.h:219
int64_t storageTypeMax
Definition: TypeDetail.h:114
int64_t zeroPoint
Definition: TypeDetail.h:157
double scale
Definition: TypeDetail.h:156
int64_t storageTypeMin
Definition: TypeDetail.h:113
KeyTy(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: TypeDetail.h:162
int32_t quantizedDimension
Definition: TypeDetail.h:253
Base storage class appearing in a Type.
Definition: TypeSupport.h:33
static bool genericIsEqual(const T &lhs, const U &rhs)
Definition: TypeDetail.h:192
Definition: StorageUniquer.h:89
const double * scaleElements
Definition: TypeDetail.h:250
ArrayRef< int64_t > getZeroPoints() const
Definition: TypeDetail.h:246
static UniformQuantizedPerAxisTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Construction.
Definition: TypeDetail.h:233
static unsigned hashKey(const KeyTy &key)
Definition: TypeDetail.h:92
int64_t storageTypeMax
Definition: TypeDetail.h:42
ArrayRef< int64_t > zeroPoints
Definition: TypeDetail.h:180
static AnyQuantizedTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Construction.
Definition: TypeDetail.h:86
bool operator==(const KeyTy &other) const
Definition: TypeDetail.h:202
unsigned getHashValue() const
Definition: TypeDetail.h:206
Type storageType
Definition: TypeDetail.h:33
unsigned flags
Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
Definition: TypeDetail.h:103
Type storageType
Definition: TypeDetail.h:52
static UniformQuantizedTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Construction.
Definition: TypeDetail.h:148
unsigned flags
Definition: TypeDetail.h:51
T * allocate()
Allocate an instance of the provided type.
Definition: StorageUniquer.h:109
bool operator==(const KeyTy &key) const
Definition: TypeDetail.h:143
ArrayRef< double > scales
Definition: TypeDetail.h:179
Definition: LLVM.h:37
ArrayRef< double > getScales() const
Definition: TypeDetail.h:185
KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: TypeDetail.h:97
unsigned quantParamsSize
Definition: TypeDetail.h:252
unsigned getHashValue() const
Definition: TypeDetail.h:131
UniformQuantizedTypeStorage(const KeyTy &key)
Definition: TypeDetail.h:138
KeyTy(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: TypeDetail.h:47
QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: TypeDetail.h:24
Definition: Types.h:84
bool operator==(const KeyTy &key) const
Definition: TypeDetail.h:227
AnyQuantizedTypeStorage(const KeyTy &key)
Definition: TypeDetail.h:77
unsigned flags
Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
Definition: TypeDetail.h:30
unsigned flags
Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
Definition: TypeDetail.h:171
Type expressedType
Definition: TypeDetail.h:36
int64_t zeroPoint
Definition: TypeDetail.h:112
unsigned getHashValue() const
Definition: TypeDetail.h:71
static bool genericIsEqual(const T &lhs, const U &rhs)
Definition: TypeDetail.h:60
bool operator==(const KeyTy &other) const
Definition: TypeDetail.h:67
ArrayRef< T > copyInto(ArrayRef< T > elements)
Definition: StorageUniquer.h:93
Type expressedType
Definition: TypeDetail.h:53
static unsigned hashKey(const KeyTy &key)
Definition: TypeDetail.h:154
int64_t storageTypeMax
Definition: TypeDetail.h:55
int64_t storageTypeMin
Definition: TypeDetail.h:54
int64_t storageTypeMin
Definition: TypeDetail.h:39
ArrayRef< double > getScales() const
Definition: TypeDetail.h:242
bool operator==(const KeyTy &other) const
Definition: TypeDetail.h:127
const int64_t * zeroPointElements
Definition: TypeDetail.h:251
static unsigned hashKey(const KeyTy &key)
Definition: TypeDetail.h:240
bool operator==(const KeyTy &key) const
Definition: TypeDetail.h:81
static bool genericIsEqual(const T &lhs, const U &rhs)
Definition: TypeDetail.h:119
ArrayRef< int64_t > getZeroPoints() const
Definition: TypeDetail.h:187