My Project
UniformKernelUtils.h
Go to the documentation of this file.
1 //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
10 #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
11 
15 #include "mlir/IR/Operation.h"
16 
17 #include <cmath>
18 
19 namespace mlir {
20 namespace fxpmath {
21 namespace detail {
22 
26 }
27 
29  ArrayRef<unsigned> checkWidths) {
30  unsigned w = t.getStorageType().getIntOrFloatBitWidth();
31  for (unsigned checkWidth : checkWidths) {
32  if (w == checkWidth)
33  return true;
34  }
35  return false;
36 }
37 
/// Determines whether `x` is, within a small tolerance, an integral power of
/// two.
///
/// `log2Result` receives log2(x) rounded to the nearest integer; it is written
/// even when the function returns false. For inputs that have no finite
/// logarithm (zero, negative, NaN or infinite values) `log2Result` is set to 0
/// and false is returned.
///
/// Returns true iff log2(x) lies within 1e-6 of an integer.
template <typename F> bool integralLog2(F x, int &log2Result) {
  // log2 is only finite for finite, strictly positive inputs. Reject anything
  // else up front: converting NaN or +/-inf to int below would be undefined
  // behavior. The negated comparison also catches NaN.
  if (!(x > F(0)) || !std::isfinite(x)) {
    log2Result = 0;
    return false;
  }

  const F xLog2 = std::log2(x);
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}
49 
55  : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
56  lhsType(getUniformElementType(lhs->getType())),
57  rhsType(getUniformElementType(rhs->getType())),
58  resultType(getUniformElementType(*op->result_type_begin())),
59  lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
60  rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
62  quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
63  }
64 
// Returns whether this info is valid (all types defined, etc).
// NOTE(review): source line 68 is absent from this listing, so the end of the
// conjunction below is not visible here.
66  bool isValid() const {
67  return lhsType && rhsType && resultType && lhsStorageType &&
69  }
70 
73 
75  bool isSameStorageType() const {
76  return lhsType.getStorageType() == rhsType.getStorageType() &&
77  lhsType.getStorageType() == resultType.getStorageType();
78  }
79 
// Runs integralLog2 on the scales of lhsType, rhsType and resultType (in that
// order), which writes the rounded log2 of each scale into the matching
// out-parameter.
// Returns false as soon as one scale fails the integralLog2 check (the
// remaining out-parameters are then not written) and true when all three
// pass.
// NOTE(review): source lines 84-85, which open the block containing the first
// `return false`, are absent from this listing.
82  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
83  int &resultLog2Scale) const {
86  return false;
87  }
88 
89  if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
90  !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
91  !integralLog2(resultType.getScale(), resultLog2Scale)) {
92  return false;
93  }
94 
95  return true;
96  }
97 
// Gets the result integer clamp range given the result quantized type
99  // and any explicit clamp provided as attributes.
// The range starts as [getStorageTypeMin(), getStorageTypeMax()] of
// resultType. When clampMin / clampMax are set, each is passed through
// conv.quantizeFloatToInt64 and can only narrow the range (std::max on the
// lower bound, std::min on the upper bound).
// Returns {min, max} as IntegerAttrs of type `ty`.
// NOTE(review): source line 105, which presumably declares `conv`, is absent
// from this listing.
100  std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
101  int64_t typeMin = resultType.getStorageTypeMin();
102  int64_t typeMax = resultType.getStorageTypeMax();
103 
104  if (clampMin || clampMax) {
106  if (clampMin) {
107  typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
108  }
109  if (clampMax) {
110  typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
111  }
112  }
113 
114  // The quantized, integral ops expect clamps as 32bit ints.
115  return {
116  IntegerAttr::get(ty, typeMin),
117  IntegerAttr::get(ty, typeMax),
118  };
119  }
120 
126 
127  // Element UniformQuantizedType for operands/result.
131 
132  // Full storage-based types.
136 };
137 
141  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
142  assert(realMultiplier < 1.0);
143  assert(realMultiplier > 0.0);
144 
145  const double q = std::frexp(realMultiplier, &exponent);
146  auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
147  assert(qFixed <= (1ll << 31));
148  if (qFixed == (1ll << 31)) {
149  qFixed /= 2;
150  ++exponent;
151  }
152  assert(qFixed <= std::numeric_limits<int32_t>::max());
153  multiplier = static_cast<int32_t>(qFixed);
154  }
155 
156  int32_t multiplier;
157  int exponent;
158 };
159 
// Casts an integer or floating point based shaped type to a new element type.
// For a ShapedType the visible returns rebuild a VectorType, RankedTensorType,
// UnrankedTensorType or MemRefType around `newElementType`, reusing the shape
// of `t` (and, for MemRefType, its affine maps).
// Otherwise `t` is asserted to be an int or float type and `newElementType`
// itself is returned.
// NOTE(review): the `case` labels of the switch (source lines 164, 166, 168
// and 170) are absent from this listing.
161 inline Type castElementType(Type t, Type newElementType) {
162  if (auto st = t.dyn_cast<ShapedType>()) {
163  switch (st.getKind()) {
165  return VectorType::get(st.getShape(), newElementType);
167  return RankedTensorType::get(st.getShape(), newElementType);
169  return UnrankedTensorType::get(newElementType);
171  return MemRefType::get(st.getShape(), newElementType,
172  st.cast<MemRefType>().getAffineMaps());
173  }
174  }
175  assert(t.isIntOrFloat());
176  return newElementType;
177 }
178 
181 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
182  if (auto st = t.dyn_cast<ShapedType>()) {
183  assert(st.getElementType().isa<IntegerType>());
184  return DenseElementsAttr::get(st,
185  IntegerAttr::get(st.getElementType(), value));
186  }
187 
188  auto integerType = t.cast<IntegerType>();
189  assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
190  return IntegerAttr::get(integerType, value);
191 }
192 
195 inline APFloat convertFloatToType(FloatType ft, APFloat value) {
196  bool losesInfo;
197  auto status = value.convert(ft.getFloatSemantics(),
198  APFloat::rmNearestTiesToEven, &losesInfo);
199  (void)status; // unused in opt mode
200  assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
201  "could not convert to float const");
202  return value;
203 }
204 
208  if (auto st = t.dyn_cast<ShapedType>()) {
209  FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
210  assert(floatElementType &&
211  "float broadcast element type must be float like");
212  APFloat apValue = convertFloatToType(floatElementType, value);
213  return DenseElementsAttr::get(st,
214  FloatAttr::get(st.getElementType(), apValue));
215  } else {
216  auto floatType = t.dyn_cast<FloatType>();
217  assert(floatType && "float broadcast must be of float type");
218  APFloat apValue = convertFloatToType(floatType, value);
219  return FloatAttr::get(floatType, apValue);
220  }
221 }
222 
223 } // namespace detail
224 } // namespace fxpmath
225 } // namespace mlir
226 
227 #endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
Definition: InferTypeOpInterface.cpp:20
Type lhsStorageType
Definition: UniformKernelUtils.h:133
unsigned getIntOrFloatBitWidth()
Definition: StandardTypes.cpp:103
quant::UniformQuantizedType rhsType
Definition: UniformKernelUtils.h:129
bool isValid() const
Returns whether this info is valid (all types defined, etc).
Definition: UniformKernelUtils.h:66
bool isSameStorageType() const
Returns whether the storage type of all operands is identical.
Definition: UniformKernelUtils.h:75
bool integralLog2(F x, int &log2Result)
Definition: UniformKernelUtils.h:40
Type getStorageType() const
Definition: QuantTypes.cpp:58
Definition: Operation.h:27
APFloat convertFloatToType(FloatType ft, APFloat value)
Definition: UniformKernelUtils.h:195
Integer types can have arbitrary bitwidth up to a large fixed limit.
Definition: StandardTypes.h:82
QuantizedMultiplierSmallerThanOneExp(double realMultiplier)
Definition: UniformKernelUtils.h:141
Definition: StandardTypes.h:59
static UnrankedTensorType get(Type elementType)
Definition: StandardTypes.cpp:279
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Definition: QuantTypes.cpp:89
Definition: QuantTypes.h:270
Definition: UniformKernelUtils.h:52
Definition: UniformSupport.h:59
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Definition: Attributes.cpp:572
static MemRefType get(ArrayRef< int64_t > shape, Type elementType, ArrayRef< AffineMap > affineMapComposition={}, unsigned memorySpace=0)
Definition: StandardTypes.cpp:303
Definition: StandardTypes.h:113
std::pair< IntegerAttr, IntegerAttr > getClampMinMax(IntegerType ty) const
Gets the result integer clamp range given the result quantized type.
Definition: UniformKernelUtils.h:100
bool isIntOrFloat()
Return true if this is an integer or a float type.
Definition: StandardTypes.cpp:45
Definition: LLVM.h:40
int64_t quantizeFloatToInt64(APFloat expressedValue) const
Definition: UniformSupport.h:118
bool hasStorageBitWidth(quant::QuantizedType t, ArrayRef< unsigned > checkWidths)
Definition: UniformKernelUtils.h:28
Definition: StandardTypes.h:178
Attribute broadcastScalarConstIntValue(Type t, int64_t value)
Definition: UniformKernelUtils.h:181
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
Definition: StandardTypes.cpp:86
Definition: StandardTypes.h:56
Operation * op
Definition: UniformKernelUtils.h:121
Type castElementType(Type t, Type newElementType)
Casts an integer or floating point based shaped type to a new element type.
Definition: UniformKernelUtils.h:161
Definition: LLVM.h:37
Value rhs
Definition: UniformKernelUtils.h:123
Attribute broadcastScalarConstFloatValue(Type t, APFloat value)
Definition: UniformKernelUtils.h:207
U dyn_cast_or_null() const
Definition: Types.h:261
Definition: StandardTypes.h:58
U dyn_cast() const
Definition: Types.h:258
Definition: Attributes.h:53
int32_t multiplier
Definition: UniformKernelUtils.h:156
static IntegerAttr get(Type type, int64_t value)
Definition: Attributes.cpp:271
Type getQuantizedResultType() const
Gets the final quantized result type of the result.
Definition: UniformKernelUtils.h:72
Definition: StandardTypes.h:390
int exponent
Definition: UniformKernelUtils.h:157
Type resultStorageType
Definition: UniformKernelUtils.h:135
double getScale() const
Definition: QuantTypes.cpp:287
Definition: Types.h:84
Definition: Value.h:38
static RankedTensorType get(ArrayRef< int64_t > shape, Type elementType)
Definition: StandardTypes.cpp:248
static VectorType get(ArrayRef< int64_t > shape, Type elementType)
Definition: StandardTypes.cpp:199
Definition: StandardTypes.h:57
Value lhs
Definition: UniformKernelUtils.h:122
result_type_iterator result_type_begin()
Definition: Operation.h:262
Definition: QuantTypes.h:60
Optional< APFloat > clampMax
Definition: UniformKernelUtils.h:125
UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs, Optional< APFloat > clampMin, Optional< APFloat > clampMax)
Definition: UniformKernelUtils.h:53
static FloatAttr get(Type type, double value)
Definition: Attributes.cpp:175
Optional< APFloat > clampMin
Definition: UniformKernelUtils.h:124
Type rhsStorageType
Definition: UniformKernelUtils.h:134
bool isa() const
Definition: Types.h:254
bool isFixedPoint() const
Definition: QuantTypes.h:314
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale, int &resultLog2Scale) const
Definition: UniformKernelUtils.h:82
quant::UniformQuantizedType getUniformElementType(Type t)
Definition: UniformKernelUtils.h:23
quant::UniformQuantizedType resultType
Definition: UniformKernelUtils.h:130
U cast() const
Definition: Types.h:264
quant::UniformQuantizedType lhsType
Definition: UniformKernelUtils.h:128