My Project
Configuration.h
Go to the documentation of this file.
1 //===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The quantizer is relatively agnostic to source and target dialects, with
10 // the specifics represented by configuration policy objects derived from
11 // classes in this file.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
16 #define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
17 
18 #include <functional>
19 
21 #include "mlir/IR/Identifier.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/StringSet.h"
28 
29 namespace mlir {
30 class Operation;
31 
32 namespace quantizer {
33 
34 class CAGSlice;
35 
44 public:
45  static constexpr size_t MaxSchemeIndex = 31;
46  using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>;
47 
49  virtual ~TargetConfiguration() = default;
50 
52  unsigned addCandidateType(quant::AnyQuantizedType quantizedType,
54  unsigned ordinal = candidateTypes.size();
55  assert(allCandidateTypesMask.size() == ordinal);
56  CandidateQuantizedType ct{ordinal, quantizedType, scheme};
57  candidateTypes.push_back(ct);
58  allCandidateTypesMask.push_back(true);
59  return ordinal;
60  }
61 
63  const CandidateQuantizedType &getCandidateType(unsigned index) const {
64  assert(index < candidateTypes.size());
65  return candidateTypes[index];
66  }
67 
69  return candidateTypes;
70  }
71 
73  llvm::SmallBitVector getAllCandidateTypesMask() const {
74  return allCandidateTypesMask;
75  }
76 
78  llvm::SmallBitVector
80  llvm::SmallBitVector disabled(allCandidateTypesMask);
81  for (unsigned ordinal : exceptOrdinals) {
82  disabled.reset(ordinal);
83  }
84  return disabled;
85  }
86 
88  template <typename OpTy>
90  addOpHandlerByName(OpTy::getOperationName(), fn);
91  }
92 
/// Marks `OpTy` as an op that requires stats, registering it by operation
/// name via addRequireStatsOpByName().
/// NOTE(review): the name is presumably recorded in requireStatsOpNames and
/// consulted by isRequireStatsOp(); both are defined out of line — confirm.
template <typename OpTy>
void addRequireStatsOp() {
  addRequireStatsOpByName(OpTy::getOperationName());
}
100 
102  bool isRequireStatsOp(Operation *op) const;
103 
/// Registers `OpTy` as a value-identity op by forwarding its operation name
/// to the target-specific addValueIdentityOpByName() hook.
/// NOTE(review): "value identity" presumably means the op passes values
/// through without changing them numerically — confirm with implementers.
template <typename OpTy>
void addValueIdentityOp() {
  addValueIdentityOpByName(OpTy::getOperationName());
}
112 
114  void handleOp(Operation *op, CAGSlice &cag) const;
115 
117  virtual void finalizeAnchors(CAGSlice &cag) const {}
118 
120  virtual bool isHandledType(Type t) const = 0;
121 
122 protected:
123  virtual void addValueIdentityOpByName(StringRef opName) = 0;
124  void addOpHandlerByName(StringRef name, OpHandlerFn fn);
125 
126 private:
127  void addRequireStatsOpByName(StringRef opName);
128 
130  std::vector<CandidateQuantizedType> candidateTypes;
131 
132  // A SmallBoolVector with bits set for all known candidate types.
133  llvm::SmallBitVector allCandidateTypesMask;
134 
136  llvm::StringMap<OpHandlerFn> opHandlers;
137 
140  llvm::StringSet<> requireStatsOpNames;
141 };
142 
143 } // namespace quantizer
144 } // namespace mlir
145 
146 #endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
Definition: InferTypeOpInterface.cpp:20
ArrayRef< CandidateQuantizedType > getCandidateTypes() const
Definition: Configuration.h:68
Definition: Operation.h:27
TargetConfiguration(SolverContext &context)
Definition: Configuration.cpp:20
void addValueIdentityOp()
Definition: Configuration.h:109
virtual void finalizeAnchors(CAGSlice &cag) const
Finalizes the CAG after all anchors have been added.
Definition: Configuration.h:117
static constexpr size_t MaxSchemeIndex
Definition: Configuration.h:45
std::function< void(Operation *op, CAGSlice &cag)> OpHandlerFn
Definition: Configuration.h:46
llvm::SmallBitVector getCandidateTypeDisabledExceptMask(ArrayRef< unsigned > exceptOrdinals) const
Gets a mask with every candidate type set except those whose ordinals are given.
Definition: Configuration.h:79
A slice of a CAG (which may be the whole graph).
Definition: ConstraintAnalysisGraph.h:245
llvm::SmallBitVector getAllCandidateTypesMask() const
Gets a mask of all enabled candidate types by ordinal.
Definition: Configuration.h:73
Definition: LLVM.h:37
Definition: Configuration.h:43
Definition: QuantTypes.h:209
void handleOp(Operation *op, CAGSlice &cag) const
Handles the operation if a handler is defined for it.
Definition: Configuration.cpp:35
void addOpHandler(OpHandlerFn fn)
Adds an op handler.
Definition: Configuration.h:89
const CandidateQuantizedType & getCandidateType(unsigned index) const
Gets a prototype scheme by index.
Definition: Configuration.h:63
Definition: Metadata.h:28
void addOpHandlerByName(StringRef name, OpHandlerFn fn)
Definition: Configuration.cpp:22
Definition: Types.h:84
unsigned addCandidateType(quant::AnyQuantizedType quantizedType, CandidateQuantizedType::Scheme scheme)
Adds a candidate type, returning its ordinal.
Definition: Configuration.h:52
virtual void addValueIdentityOpByName(StringRef opName)=0
bool isRequireStatsOp(Operation *op) const
Returns whether op is a RequireStatsOp.
Definition: Configuration.cpp:30
virtual bool isHandledType(Type t) const =0
Whether an operand or result type is subject to analysis by this config.
void addRequireStatsOp()
Definition: Configuration.h:97
Candidate for a quantized type conversion.
Definition: Metadata.h:47