14 #ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H 15 #define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H 26 #include "llvm/ADT/DenseMap.h" 33 class TargetConfiguration;
94 virtual void printLabel(raw_ostream &os)
const;
98 T *ofKind = dyn_cast<T>(child);
100 found.push_back(ofKind);
114 bool isOrphan()
const {
return incoming.empty() && outgoing.empty(); }
153 return uniformMetadata;
157 virtual Value getValue()
const = 0;
166 void printLabel(raw_ostream &os)
const override;
170 Type getTransformedType();
182 :
CAGNode(kind), originalType(originalType) {}
206 void printLabel(raw_ostream &os)
const override;
227 void printLabel(raw_ostream &os)
const override;
267 template <
typename T,
typename... Args>
269 static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
270 "T must be a CAGConstraingNode");
271 T *constraintNode = addNode(std::make_unique<T>(args...));
272 for (
auto *anchor : anchors)
273 anchor->addOutgoing(constraintNode);
274 return constraintNode;
278 template <
typename T,
typename... Args>
282 static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
283 "T must be a CAGConstraingNode");
284 T *constraintNode = addNode(std::make_unique<T>(args...));
286 for (
auto *toAnchor : toAnchors) {
287 constraintNode->addOutgoing(toAnchor);
289 return constraintNode;
292 template <
typename T>
294 static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
295 "T must be a CAGConstraingNode");
297 for (
auto *anchor : anchors) {
298 anchor->findChildrenOfKind<T>(cluster);
302 if (cluster.empty()) {
304 constraintNode = addNode(std::make_unique<T>());
307 constraintNode = cluster[0];
308 for (
size_t i = 1, e = cluster.size(); i < e; ++i) {
309 cluster[i]->replaceIncoming(constraintNode);
312 for (
auto *anchor : anchors) {
313 anchor->addOutgoing(constraintNode);
315 return constraintNode;
326 void enumerateImpliedConnections(
338 template <
typename T>
339 T *addNode(std::unique_ptr<T> node) {
340 node->nodeId = allNodes.size();
341 T *unownedNode = node.release();
342 allNodes.push_back(unownedNode);
347 std::vector<CAGNode *> allNodes;
360 #endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H Definition: ConstraintAnalysisGraph.h:193
Definition: InferTypeOpInterface.cpp:20
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:200
const_iterator begin() const
Iterator over this node's children (outgoing) nodes.
Definition: ConstraintAnalysisGraph.h:79
const CAGUniformMetadata & getUniformMetadata() const
Definition: ConstraintAnalysisGraph.h:152
Definition: ConstraintAnalysisGraph.h:45
int getNodeId() const
Unique id of the node within the slice.
Definition: ConstraintAnalysisGraph.h:71
Definition: Operation.h:27
const_iterator end() const
Definition: ConstraintAnalysisGraph.h:257
T * addUniqueConstraint(ArrayRef< CAGAnchorNode *> anchors, Args... args)
Definition: ConstraintAnalysisGraph.h:268
CAGAnchorNode(Kind kind, Type originalType)
Definition: ConstraintAnalysisGraph.h:181
node_vector::iterator iterator
Definition: ConstraintAnalysisGraph.h:63
Value getOperand(unsigned idx)
Definition: Operation.h:207
virtual void printLabel(raw_ostream &os) const
Prints the node label, suitable for one-line display.
Definition: ConstraintAnalysisGraph.cpp:151
CAGUniformMetadata & getUniformMetadata()
Metadata for solving uniform quantization params.
Definition: ConstraintAnalysisGraph.h:151
friend class CAGSlice
Definition: ConstraintAnalysisGraph.h:126
virtual ~CAGNode()=default
void addOutgoing(CAGNode *toNode)
Definition: ConstraintAnalysisGraph.cpp:32
const_iterator incoming_end() const
Definition: ConstraintAnalysisGraph.h:86
iterator begin()
Definition: ConstraintAnalysisGraph.h:254
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:159
A slice of a CAG (which may be the whole graph).
Definition: ConstraintAnalysisGraph.h:245
const_iterator incoming_begin() const
Iterator over this parents (incoming) nodes.
Definition: ConstraintAnalysisGraph.h:85
Kind getKind() const
Definition: ConstraintAnalysisGraph.h:68
Value getValue() const final
Definition: ConstraintAnalysisGraph.h:204
node_vector::const_iterator const_iterator
Definition: ConstraintAnalysisGraph.h:252
iterator begin()
Definition: ConstraintAnalysisGraph.h:81
virtual void propagate(SolverContext &solverContext, const TargetConfiguration &config)
Definition: ConstraintAnalysisGraph.h:90
iterator end()
Definition: ConstraintAnalysisGraph.h:255
Definition: Configuration.h:43
CAGNode(Kind kind)
Definition: ConstraintAnalysisGraph.h:117
iterator end()
Definition: ConstraintAnalysisGraph.h:82
raw_ostream & operator<<(raw_ostream &os, const CAGNode &node)
Definition: ConstraintAnalysisGraph.h:352
Type getOriginalType() const
Definition: ConstraintAnalysisGraph.h:178
TypeTransformRule
Definition: ConstraintAnalysisGraph.h:134
Definition: ConstraintAnalysisGraph.h:132
Definition: Metadata.h:28
iterator incoming_begin()
Definition: ConstraintAnalysisGraph.h:87
bool isDirty() const
Whether the node is dirty, requiring one or more calls to propagate().
Definition: ConstraintAnalysisGraph.h:74
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:238
Operation * getOp() const final
Definition: ConstraintAnalysisGraph.h:224
void clearDirty()
Definition: ConstraintAnalysisGraph.h:76
static bool classof(const CAGNode *n)
Definition: ConstraintAnalysisGraph.h:220
const_iterator begin() const
Definition: ConstraintAnalysisGraph.h:256
const_iterator end() const
Definition: ConstraintAnalysisGraph.h:80
void markDirty()
Definition: ConstraintAnalysisGraph.h:75
void findChildrenOfKind(SmallVectorImpl< T *> &found)
Definition: ConstraintAnalysisGraph.h:96
TypeTransformRule getTypeTransformRule() const
Definition: ConstraintAnalysisGraph.h:172
Base class for constraint nodes.
Definition: ConstraintAnalysisGraph.h:234
iterator incoming_end()
Definition: ConstraintAnalysisGraph.h:88
std::vector< CAGNode * > node_vector
Definition: ConstraintAnalysisGraph.h:250
T * addUnidirectionalConstraint(CAGAnchorNode *fromAnchor, ArrayRef< CAGAnchorNode *> toAnchors, Args... args)
Adds a unidirectional constraint from a node to an array of target nodes.
Definition: ConstraintAnalysisGraph.h:279
Operation * getOp() const final
Definition: ConstraintAnalysisGraph.h:197
Kind
Definition: ConstraintAnalysisGraph.h:47
bool isOrphan() const
Whether this node is an orphan (has no incoming or outgoing connections).
Definition: ConstraintAnalysisGraph.h:114
unsigned getOperandIdx() const
Definition: ConstraintAnalysisGraph.h:198
void replaceIncoming(CAGNode *otherNode)
Definition: ConstraintAnalysisGraph.cpp:18
node_vector::const_iterator const_iterator
Definition: ConstraintAnalysisGraph.h:64
void setTypeTransformRule(TypeTransformRule r)
Definition: ConstraintAnalysisGraph.h:174
Value getValue() const final
Definition: ConstraintAnalysisGraph.h:225
T * addClusteredConstraint(ArrayRef< CAGAnchorNode *> anchors)
Definition: ConstraintAnalysisGraph.h:293
Definition: ConstraintAnalysisGraph.h:216
CAGConstraintNode(Kind kind)
Definition: ConstraintAnalysisGraph.h:236
node_vector::iterator iterator
Definition: ConstraintAnalysisGraph.h:251