My Project
CallGraph.h
Go to the documentation of this file.
1 //===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains an analysis for computing the multi-level callgraph from a
10 // given top-level operation. The nodes within this callgraph are defined by
11 // the `CallOpInterface` and `CallableOpInterface` operation interfaces defined
12 // in CallInterface.td.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_ANALYSIS_CALLGRAPH_H
17 #define MLIR_ANALYSIS_CALLGRAPH_H
18 
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/GraphTraits.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/ADT/PointerIntPair.h"
23 #include "llvm/ADT/SetVector.h"
24 
25 namespace mlir {
26 struct CallInterfaceCallable;
27 class Operation;
28 class Region;
29 
30 //===----------------------------------------------------------------------===//
31 // CallGraphNode
32 //===----------------------------------------------------------------------===//
33 
39 public:
41  class Edge {
42  enum class Kind {
43  // An 'Abstract' edge represents an opaque, non-operation, reference
44  // between this node and the target. Edges of this type are only valid
45  // from the external node, as there is no valid connection to an operation
46  // in the module.
47  Abstract,
48 
49  // A 'Call' edge represents a direct reference to the target node via a
50  // call-like operation within the callable region of this node.
51  Call,
52 
53  // A 'Child' edge is used when the region of target node is defined inside
54  // of the callable region of this node. This means that the region of this
55  // node is an ancestor of the region for the target node. As such, this
56  // edge cannot be used on the 'external' node.
57  Child,
58  };
59 
60  public:
62  bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; }
63 
65  bool isCall() const { return targetAndKind.getInt() == Kind::Call; }
66 
68  bool isChild() const { return targetAndKind.getInt() == Kind::Child; }
69 
71  CallGraphNode *getTarget() const { return targetAndKind.getPointer(); }
72 
73  bool operator==(const Edge &edge) const {
74  return targetAndKind == edge.targetAndKind;
75  }
76 
77  private:
78  Edge(CallGraphNode *node, Kind kind) : targetAndKind(node, kind) {}
79  explicit Edge(llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind)
80  : targetAndKind(targetAndKind) {}
81 
83  llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind;
84 
85  // Provide access to the constructor and Kind.
86  friend class CallGraphNode;
87  };
88 
90  bool isExternal() const;
91 
94  Region *getCallableRegion() const;
95 
99  void addAbstractEdge(CallGraphNode *node);
100 
102  void addCallEdge(CallGraphNode *node);
103 
105  void addChildEdge(CallGraphNode *child);
106 
109  iterator begin() const { return edges.begin(); }
110  iterator end() const { return edges.end(); }
111 
113  bool hasChildren() const;
114 
115 private:
117  struct EdgeKeyInfo {
118  using BaseInfo =
120 
121  static Edge getEmptyKey() { return Edge(BaseInfo::getEmptyKey()); }
122  static Edge getTombstoneKey() { return Edge(BaseInfo::getTombstoneKey()); }
123  static unsigned getHashValue(const Edge &edge) {
124  return BaseInfo::getHashValue(edge.targetAndKind);
125  }
126  static bool isEqual(const Edge &lhs, const Edge &rhs) { return lhs == rhs; }
127  };
128 
129  CallGraphNode(Region *callableRegion) : callableRegion(callableRegion) {}
130 
132  void addEdge(CallGraphNode *node, Edge::Kind kind);
133 
137  Region *callableRegion;
138 
140  llvm::SetVector<Edge, SmallVector<Edge, 4>,
141  llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>>
142  edges;
143 
144  // Provide access to private methods.
145  friend class CallGraph;
146 };
147 
148 //===----------------------------------------------------------------------===//
149 // CallGraph
150 //===----------------------------------------------------------------------===//
151 
152 class CallGraph {
153  using NodeMapT = llvm::MapVector<Region *, std::unique_ptr<CallGraphNode>>;
154 
157  class NodeIterator final
158  : public llvm::mapped_iterator<
159  NodeMapT::const_iterator,
160  CallGraphNode *(*)(const NodeMapT::value_type &)> {
161  static CallGraphNode *unwrap(const NodeMapT::value_type &value) {
162  return value.second.get();
163  }
164 
165  public:
167  NodeIterator(NodeMapT::const_iterator it)
168  : llvm::mapped_iterator<
169  NodeMapT::const_iterator,
170  CallGraphNode *(*)(const NodeMapT::value_type &)>(it, &unwrap) {}
171  };
172 
173 public:
174  CallGraph(Operation *op);
175 
179  CallGraphNode *getOrAddNode(Region *region, CallGraphNode *parentNode);
180 
183  CallGraphNode *lookupNode(Region *region) const;
184 
187  return const_cast<CallGraphNode *>(&externalNode);
188  }
189 
194  CallGraphNode *resolveCallable(CallInterfaceCallable callable,
195  Operation *from = nullptr) const;
196 
198  using iterator = NodeIterator;
199  iterator begin() const { return nodes.begin(); }
200  iterator end() const { return nodes.end(); }
201 
203  void dump() const;
204  void print(raw_ostream &os) const;
205 
206 private:
208  NodeMapT nodes;
209 
211  CallGraphNode externalNode;
212 };
213 
214 } // end namespace mlir
215 
216 namespace llvm {
217 // Provide graph traits for traversing call graphs using standard graph
218 // traversals.
219 template <> struct GraphTraits<const mlir::CallGraphNode *> {
221  static NodeRef getEntryNode(NodeRef node) { return node; }
222 
224  return edge.getTarget();
225  }
226 
227  // ChildIteratorType/begin/end - Allow iteration over all nodes in the graph.
228  using ChildIteratorType =
229  mapped_iterator<mlir::CallGraphNode::iterator, decltype(&unwrap)>;
231  return {node->begin(), &unwrap};
232  }
234  return {node->end(), &unwrap};
235  }
236 };
237 
238 template <>
239 struct GraphTraits<const mlir::CallGraph *>
240  : public GraphTraits<const mlir::CallGraphNode *> {
242  static NodeRef getEntryNode(const mlir::CallGraph *cg) {
243  return cg->getExternalNode();
244  }
245 
246  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
248  static nodes_iterator nodes_begin(mlir::CallGraph *cg) { return cg->begin(); }
249  static nodes_iterator nodes_end(mlir::CallGraph *cg) { return cg->end(); }
250 };
251 } // end namespace llvm
252 
253 #endif // MLIR_ANALYSIS_CALLGRAPH_H
Definition: InferTypeOpInterface.cpp:20
void addChildEdge(CallGraphNode *child)
Adds a reference edge to the given child node.
Definition: CallGraph.cpp:56
Definition: Region.h:23
iterator begin() const
Definition: CallGraph.h:199
bool isChild() const
Returns if this edge represents a Child edge.
Definition: CallGraph.h:68
Definition: PassRegistry.cpp:413
Definition: Operation.h:27
CallGraphNode * getExternalNode() const
Return the callgraph node representing the indirect-external callee.
Definition: CallGraph.h:186
Definition: LLVM.h:45
iterator begin() const
Definition: CallGraph.h:109
NodeIterator iterator
An iterator over the nodes of the graph.
Definition: CallGraph.h:198
static ChildIteratorType child_begin(NodeRef node)
Definition: CallGraph.h:230
Definition: LLVM.h:34
bool isAbstract() const
Returns if this edge represents an Abstract edge.
Definition: CallGraph.h:62
static ChildIteratorType child_end(NodeRef node)
Definition: CallGraph.h:233
mapped_iterator< mlir::CallGraphNode::iterator, decltype(&unwrap)> ChildIteratorType
Definition: CallGraph.h:229
static nodes_iterator nodes_begin(mlir::CallGraph *cg)
Definition: CallGraph.h:248
CallGraphNode * getTarget() const
Returns the target node for this edge.
Definition: CallGraph.h:71
static NodeRef getEntryNode(const mlir::CallGraph *cg)
The entry node into the graph is the external node.
Definition: CallGraph.h:242
bool operator==(const Edge &edge) const
Definition: CallGraph.h:73
static NodeRef unwrap(const mlir::CallGraphNode::Edge &edge)
Definition: CallGraph.h:223
Definition: CallInterfaces.h:24
static NodeRef getEntryNode(NodeRef node)
Definition: CallGraph.h:221
void addCallEdge(CallGraphNode *node)
Add an outgoing call edge from this node.
Definition: CallGraph.cpp:51
bool isCall() const
Returns if this edge represents a Call edge.
Definition: CallGraph.h:65
friend class CallGraphNode
Definition: CallGraph.h:86
Definition: CallGraph.h:38
bool isExternal() const
Returns if this node is the external node.
Definition: CallGraph.cpp:34
void print(OpAsmPrinter &p, AffineIfOp op)
Definition: AffineOps.cpp:1671
iterator end() const
Definition: CallGraph.h:200
SmallVectorImpl< Edge >::const_iterator iterator
Iterator over the outgoing edges of this node.
Definition: CallGraph.h:108
void addAbstractEdge(CallGraphNode *node)
Definition: CallGraph.cpp:45
static nodes_iterator nodes_end(mlir::CallGraph *cg)
Definition: CallGraph.h:249
This class represents a directed edge between two nodes in the callgraph.
Definition: CallGraph.h:41
friend class CallGraph
Definition: CallGraph.h:145
Region * getCallableRegion() const
Definition: CallGraph.cpp:38
mlir::CallGraph::iterator nodes_iterator
Definition: CallGraph.h:247
bool hasChildren() const
Returns true if this node has any child edges.
Definition: CallGraph.cpp:61
Definition: CallGraph.h:152
iterator end() const
Definition: CallGraph.h:110