My Project
ConstraintAnalysisGraph.h
Go to the documentation of this file.
1 //===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides graph-based data structures for representing anchors
10 // and constraints between them.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
15 #define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
16 
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
// NOTE(review): source line 25 (another #include) is missing from this
// listing; restore it from the original header.
#include "llvm/ADT/DenseMap.h"
27 
28 namespace mlir {
29 namespace quantizer {
30 
31 class CAGNode;
32 class CAGSlice;
33 class TargetConfiguration;
34 
// Base class of every node in a constraint analysis graph (CAG). A node is
// either an anchor or a constraint (see Kind) and records both its outgoing
// (child) and incoming (parent) edges.
45 class CAGNode {
46 public:
   // Discriminator for LLVM-style RTTI (getKind()/classof()). The subclass
   // classof() implementations in this header test ranges of this enum
   // (e.g. Anchor..LastAnchor), so the enumerator order matters.
   // NOTE(review): source lines 48, 50-52, 54 and 56-58 are missing from this
   // listing. The anchor classes in this header reference Kind::OperandAnchor,
   // Kind::ResultAnchor and Kind::LastAnchor, so those enumerators live in the
   // missing lines; restore them from the original header.
47  enum class Kind {
49  Anchor,
53 
55  Constraint,
59  };
60 
61  // Vector and iterator over nodes.
   // NOTE(review): the node_vector alias used below (source line 62) is
   // missing from this listing; presumably std::vector<CAGNode *>, matching
   // CAGSlice::node_vector — confirm.
63  using iterator = node_vector::iterator;
64  using const_iterator = node_vector::const_iterator;
65 
66  virtual ~CAGNode() = default;
67 
68  Kind getKind() const { return kind; }
69 
   // Unique id of the node within the slice. It is -1 until CAGSlice::addNode
   // assigns it.
71  int getNodeId() const { return nodeId; }
72 
   // Whether the node is dirty, requiring one or more calls to propagate().
74  bool isDirty() const { return dirty; }
75  void markDirty() { dirty = true; }
76  void clearDirty() { dirty = false; }
77 
   // Iterators over this node's children (outgoing) nodes.
79  const_iterator begin() const { return outgoing.begin(); }
80  const_iterator end() const { return outgoing.end(); }
81  iterator begin() { return outgoing.begin(); }
82  iterator end() { return outgoing.end(); }
83 
   // Iterators over this node's parent (incoming) nodes.
85  const_iterator incoming_begin() const { return incoming.begin(); }
86  const_iterator incoming_end() const { return incoming.end(); }
87  iterator incoming_begin() { return incoming.begin(); }
88  iterator incoming_end() { return incoming.end(); }
89 
   // The default implementation does nothing; subclasses override it.
90  virtual void propagate(SolverContext &solverContext,
91  const TargetConfiguration &config) {}
92 
   // Prints the node label, suitable for one-line display.
94  virtual void printLabel(raw_ostream &os) const;
95 
   // Appends every direct child (outgoing node) that dyn_casts to T to
   // `found`. `found` is not cleared first, and duplicates are not filtered.
96  template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) {
97  for (CAGNode *child : *this) {
98  T *ofKind = dyn_cast<T>(child);
99  if (ofKind) {
100  found.push_back(ofKind);
101  }
102  }
103  }
104 
   // Presumably redirects this node's incoming edges to `otherNode`.
   // CAGSlice::addClusteredConstraint uses it to merge constraint nodes.
   // NOTE(review): defined out of line; confirm the exact edge bookkeeping.
107  void replaceIncoming(CAGNode *otherNode);
108 
   // Adds an edge from this node to `toNode`.
   // NOTE(review): defined out of line. Whether it deduplicates edges or
   // updates toNode's incoming list is not visible here.
111  void addOutgoing(CAGNode *toNode);
112 
   // Whether this node is an orphan (has no incoming or outgoing connections).
114  bool isOrphan() const { return incoming.empty() && outgoing.empty(); }
115 
116 protected:
   // Only subclasses construct nodes.
   // NOTE(review): single-argument constructor; consider marking it `explicit`.
117  CAGNode(Kind kind) : kind(kind) {}
118 
119 private:
120  Kind kind;
121  int nodeId = -1;
122  node_vector outgoing;
123  node_vector incoming;
124  bool dirty = false;
125 
   // CAGSlice writes nodeId when it takes ownership of a node (CAGSlice::addNode).
126  friend class CAGSlice;
127 };
128 
132 class CAGAnchorNode : public CAGNode {
133 public:
134  enum class TypeTransformRule {
137  Direct,
138 
142  DirectStorage,
143 
147  ExpressedOnly,
148  };
149 
151  CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; }
153  return uniformMetadata;
154  }
155 
156  virtual Operation *getOp() const = 0;
157  virtual Value getValue() const = 0;
158 
159  static bool classof(const CAGNode *n) {
160  return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
161  }
162 
163  void propagate(SolverContext &solverContext,
164  const TargetConfiguration &config) override;
165 
166  void printLabel(raw_ostream &os) const override;
167 
170  Type getTransformedType();
171 
172  TypeTransformRule getTypeTransformRule() const { return typeTransformRule; }
173 
174  void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; }
175 
178  Type getOriginalType() const { return originalType; }
179 
180 protected:
181  CAGAnchorNode(Kind kind, Type originalType)
182  : CAGNode(kind), originalType(originalType) {}
183 
184 private:
185  CAGUniformMetadata uniformMetadata;
186  Type originalType;
187  TypeTransformRule typeTransformRule = TypeTransformRule::Direct;
188 };
189 
194 public:
195  CAGOperandAnchor(Operation *op, unsigned operandIdx);
196 
197  Operation *getOp() const final { return op; }
198  unsigned getOperandIdx() const { return operandIdx; }
199 
200  static bool classof(const CAGNode *n) {
201  return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
202  }
203 
204  Value getValue() const final { return op->getOperand(operandIdx); }
205 
206  void printLabel(raw_ostream &os) const override;
207 
208 private:
209  Operation *op;
210  unsigned operandIdx;
211 };
212 
217 public:
218  CAGResultAnchor(Operation *op, unsigned resultIdx);
219 
220  static bool classof(const CAGNode *n) {
221  return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
222  }
223 
224  Operation *getOp() const final { return resultValue->getDefiningOp(); }
225  Value getValue() const final { return resultValue; }
226 
227  void printLabel(raw_ostream &os) const override;
228 
229 private:
230  Value resultValue;
231 };
232 
// Base class for constraint nodes.
234 class CAGConstraintNode : public CAGNode {
235 public:
   // NOTE(review): public, non-explicit single-argument constructor; consider
   // marking it `explicit` to prevent implicit Kind -> node conversions.
236  CAGConstraintNode(Kind kind) : CAGNode(kind) {}
237 
   // LLVM-style RTTI: constraint kinds form a range starting at
   // Kind::Constraint.
238  static bool classof(const CAGNode *n) {
239  return n->getKind() >= Kind::Constraint &&
   // NOTE(review): source line 240 (the right-hand operand of `&&`) is missing
   // from this listing. By analogy with CAGAnchorNode::classof it is
   // presumably an upper-bound check on the constraint Kind range; confirm
   // against the original header.
241  }
242 };
243 
245 class CAGSlice {
246 public:
247  CAGSlice(SolverContext &context);
248  ~CAGSlice();
249 
250  using node_vector = std::vector<CAGNode *>;
251  using iterator = node_vector::iterator;
252  using const_iterator = node_vector::const_iterator;
253 
254  iterator begin() { return allNodes.begin(); }
255  iterator end() { return allNodes.end(); }
256  const_iterator begin() const { return allNodes.begin(); }
257  const_iterator end() const { return allNodes.end(); }
258 
260  CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx);
261 
263  CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx);
264 
267  template <typename T, typename... Args>
268  T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) {
269  static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
270  "T must be a CAGConstraingNode");
271  T *constraintNode = addNode(std::make_unique<T>(args...));
272  for (auto *anchor : anchors)
273  anchor->addOutgoing(constraintNode);
274  return constraintNode;
275  }
276 
278  template <typename T, typename... Args>
280  ArrayRef<CAGAnchorNode *> toAnchors,
281  Args... args) {
282  static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
283  "T must be a CAGConstraingNode");
284  T *constraintNode = addNode(std::make_unique<T>(args...));
285  fromAnchor->addOutgoing(constraintNode);
286  for (auto *toAnchor : toAnchors) {
287  constraintNode->addOutgoing(toAnchor);
288  }
289  return constraintNode;
290  }
291 
292  template <typename T>
294  static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
295  "T must be a CAGConstraingNode");
296  SmallVector<T *, 8> cluster;
297  for (auto *anchor : anchors) {
298  anchor->findChildrenOfKind<T>(cluster);
299  }
300 
301  T *constraintNode;
302  if (cluster.empty()) {
303  // Create new.
304  constraintNode = addNode(std::make_unique<T>());
305  } else {
306  // Merge existing.
307  constraintNode = cluster[0];
308  for (size_t i = 1, e = cluster.size(); i < e; ++i) {
309  cluster[i]->replaceIncoming(constraintNode);
310  }
311  }
312  for (auto *anchor : anchors) {
313  anchor->addOutgoing(constraintNode);
314  }
315  return constraintNode;
316  }
317 
326  void enumerateImpliedConnections(
327  std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback);
328 
332  unsigned propagate(const TargetConfiguration &config);
333 
334 private:
338  template <typename T>
339  T *addNode(std::unique_ptr<T> node) {
340  node->nodeId = allNodes.size();
341  T *unownedNode = node.release();
342  allNodes.push_back(unownedNode);
343  return unownedNode;
344  }
345 
346  SolverContext &context;
347  std::vector<CAGNode *> allNodes;
350 };
351 
352 inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) {
353  node.printLabel(os);
354  return os;
355 }
356 
357 } // namespace quantizer
358 } // namespace mlir
359 
360 #endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
Definition: ConstraintAnalysisGraph.h:193
Definition: InferTypeOpInterface.cpp:20
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:200
const_iterator begin() const
Iterator over this node's children (outgoing) nodes.
Definition: ConstraintAnalysisGraph.h:79
const CAGUniformMetadata & getUniformMetadata() const
Definition: ConstraintAnalysisGraph.h:152
Definition: ConstraintAnalysisGraph.h:45
int getNodeId() const
Unique id of the node within the slice.
Definition: ConstraintAnalysisGraph.h:71
Definition: Operation.h:27
const_iterator end() const
Definition: ConstraintAnalysisGraph.h:257
T * addUniqueConstraint(ArrayRef< CAGAnchorNode *> anchors, Args... args)
Definition: ConstraintAnalysisGraph.h:268
Definition: LLVM.h:48
CAGAnchorNode(Kind kind, Type originalType)
Definition: ConstraintAnalysisGraph.h:181
node_vector::iterator iterator
Definition: ConstraintAnalysisGraph.h:63
Value getOperand(unsigned idx)
Definition: Operation.h:207
virtual void printLabel(raw_ostream &os) const
Prints the node label, suitable for one-line display.
Definition: ConstraintAnalysisGraph.cpp:151
CAGUniformMetadata & getUniformMetadata()
Metadata for solving uniform quantization params.
Definition: ConstraintAnalysisGraph.h:151
friend class CAGSlice
Definition: ConstraintAnalysisGraph.h:126
virtual ~CAGNode()=default
Definition: LLVM.h:34
void addOutgoing(CAGNode *toNode)
Definition: ConstraintAnalysisGraph.cpp:32
const_iterator incoming_end() const
Definition: ConstraintAnalysisGraph.h:86
iterator begin()
Definition: ConstraintAnalysisGraph.h:254
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:159
A slice of a CAG (which may be the whole graph).
Definition: ConstraintAnalysisGraph.h:245
const_iterator incoming_begin() const
Iterator over this node's parent (incoming) nodes.
Definition: ConstraintAnalysisGraph.h:85
Kind getKind() const
Definition: ConstraintAnalysisGraph.h:68
Value getValue() const final
Definition: ConstraintAnalysisGraph.h:204
Definition: LLVM.h:37
node_vector::const_iterator const_iterator
Definition: ConstraintAnalysisGraph.h:252
iterator begin()
Definition: ConstraintAnalysisGraph.h:81
virtual void propagate(SolverContext &solverContext, const TargetConfiguration &config)
Definition: ConstraintAnalysisGraph.h:90
iterator end()
Definition: ConstraintAnalysisGraph.h:255
Definition: Metadata.h:67
Definition: Configuration.h:43
CAGNode(Kind kind)
Definition: ConstraintAnalysisGraph.h:117
iterator end()
Definition: ConstraintAnalysisGraph.h:82
raw_ostream & operator<<(raw_ostream &os, const CAGNode &node)
Definition: ConstraintAnalysisGraph.h:352
Type getOriginalType() const
Definition: ConstraintAnalysisGraph.h:178
TypeTransformRule
Definition: ConstraintAnalysisGraph.h:134
Definition: ConstraintAnalysisGraph.h:132
Definition: Metadata.h:28
iterator incoming_begin()
Definition: ConstraintAnalysisGraph.h:87
Definition: Types.h:84
bool isDirty() const
Whether the node is dirty, requiring one or more calls to propagate().
Definition: ConstraintAnalysisGraph.h:74
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:238
Definition: Value.h:38
Operation * getOp() const final
Definition: ConstraintAnalysisGraph.h:224
void clearDirty()
Definition: ConstraintAnalysisGraph.h:76
Definition: LLVM.h:35
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:220
const_iterator begin() const
Definition: ConstraintAnalysisGraph.h:256
const_iterator end() const
Definition: ConstraintAnalysisGraph.h:80
void markDirty()
Definition: ConstraintAnalysisGraph.h:75
void findChildrenOfKind(SmallVectorImpl< T *> &found)
Definition: ConstraintAnalysisGraph.h:96
TypeTransformRule getTypeTransformRule() const
Definition: ConstraintAnalysisGraph.h:172
Base class for constraint nodes.
Definition: ConstraintAnalysisGraph.h:234
iterator incoming_end()
Definition: ConstraintAnalysisGraph.h:88
std::vector< CAGNode * > node_vector
Definition: ConstraintAnalysisGraph.h:250
T * addUnidirectionalConstraint(CAGAnchorNode *fromAnchor, ArrayRef< CAGAnchorNode *> toAnchors, Args... args)
Adds a unidirectional constraint from a node to an array of target nodes.
Definition: ConstraintAnalysisGraph.h:279
Operation * getOp() const final
Definition: ConstraintAnalysisGraph.h:197
Kind
Definition: ConstraintAnalysisGraph.h:47
bool isOrphan() const
Whether this node is an orphan (has no incoming or outgoing connections).
Definition: ConstraintAnalysisGraph.h:114
unsigned getOperandIdx() const
Definition: ConstraintAnalysisGraph.h:198
void replaceIncoming(CAGNode *otherNode)
Definition: ConstraintAnalysisGraph.cpp:18
node_vector::const_iterator const_iterator
Definition: ConstraintAnalysisGraph.h:64
void setTypeTransformRule(TypeTransformRule r)
Definition: ConstraintAnalysisGraph.h:174
Value getValue() const final
Definition: ConstraintAnalysisGraph.h:225
T * addClusteredConstraint(ArrayRef< CAGAnchorNode *> anchors)
Definition: ConstraintAnalysisGraph.h:293
Definition: ConstraintAnalysisGraph.h:216
CAGConstraintNode(Kind kind)
Definition: ConstraintAnalysisGraph.h:236
node_vector::iterator iterator
Definition: ConstraintAnalysisGraph.h:251