My Project
SymbolTable.h
Go to the documentation of this file.
1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/OpDefinition.h"
13 #include "llvm/ADT/StringMap.h"
14 
15 namespace mlir {
16 class Identifier;
17 class Operation;
18 
23 class SymbolTable {
24 public:
26  SymbolTable(Operation *symbolTableOp);
27 
30  Operation *lookup(StringRef name) const;
31  template <typename T> T lookup(StringRef name) const {
32  return dyn_cast_or_null<T>(lookup(name));
33  }
34 
36  void erase(Operation *symbol);
37 
41  void insert(Operation *symbol, Block::iterator insertPt = {});
42 
44  static StringRef getSymbolAttrName() { return "sym_name"; }
45 
47  Operation *getOp() const { return symbolTableOp; }
48 
49  //===--------------------------------------------------------------------===//
50  // Symbol Utilities
51  //===--------------------------------------------------------------------===//
52 
56  static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
57 
62  static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
63 
65  class SymbolUse {
66  public:
68  : owner(op), symbolRef(symbolRef) {}
69 
71  Operation *getUser() const { return owner; }
72 
74  SymbolRefAttr getSymbolRef() const { return symbolRef; }
75 
76  private:
78  Operation *owner;
79 
81  SymbolRefAttr symbolRef;
82  };
83 
85  class UseRange {
86  public:
87  UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
88 
89  using iterator = std::vector<SymbolUse>::const_iterator;
90  iterator begin() const { return uses.begin(); }
91  iterator end() const { return uses.end(); }
92 
93  private:
94  std::vector<SymbolUse> uses;
95  };
96 
105 
112  static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
113 
122  static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
123 
132  LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
133  StringRef newSymbol,
134  Operation *from);
135 
136 private:
137  Operation *symbolTableOp;
138 
140  llvm::StringMap<Operation *> symbolTable;
141 
143  unsigned uniquingCounter = 0;
144 };
145 
146 //===----------------------------------------------------------------------===//
147 // SymbolTable Trait Types
148 //===----------------------------------------------------------------------===//
149 
150 namespace OpTrait {
151 namespace impl {
154 } // namespace impl
155 
163 template <typename ConcreteType>
164 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
165 public:
167  return impl::verifySymbolTable(op);
168  }
169 
173  Operation *lookupSymbol(StringRef name) {
174  return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
175  }
176  template <typename T> T lookupSymbol(StringRef name) {
177  return dyn_cast_or_null<T>(lookupSymbol(name));
178  }
179 };
180 
184 template <typename ConcreteType>
185 class Symbol : public TraitBase<ConcreteType, Symbol> {
186 public:
188  return impl::verifySymbol(op);
189  }
190 
192  StringRef getName() {
193  return this->getOperation()
194  ->template getAttrOfType<StringAttr>(
196  .getValue();
197  }
198 
200  void setName(StringRef name) {
201  this->getOperation()->setAttr(
203  StringAttr::get(name, this->getOperation()->getContext()));
204  }
205 
210  return ::mlir::SymbolTable::getSymbolUses(getName(), from);
211  }
212 
217  return ::mlir::SymbolTable::symbolKnownUseEmpty(getName(), from);
218  }
219 
223  LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
224  Operation *from) {
225  return ::mlir::SymbolTable::replaceAllSymbolUses(getName(), newSymbol,
226  from);
227  }
228 };
229 
230 } // end namespace OpTrait
231 } // end namespace mlir
232 
233 #endif // MLIR_IR_SYMBOLTABLE_H
Definition: Attributes.h:456
Definition: InferTypeOpInterface.cpp:20
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:44
Operation * getUser() const
Return the operation user of this symbol reference.
Definition: SymbolTable.h:71
SymbolUse(Operation *op, SymbolRefAttr symbolRef)
Definition: SymbolTable.h:67
Definition: Operation.h:27
static LogicalResult verifyTrait(Operation *op)
Definition: SymbolTable.h:187
static Optional< UseRange > getSymbolUses(Operation *from)
Definition: SymbolTable.cpp:303
T lookup(StringRef name) const
Definition: SymbolTable.h:31
LogicalResult verifySymbol(Operation *op)
Definition: SymbolTable.cpp:172
void setName(StringRef name)
Set the name of this symbol.
Definition: SymbolTable.h:200
Definition: SymbolTable.h:185
static LogicalResult verifyTrait(Operation *op)
Definition: SymbolTable.h:166
Definition: LLVM.h:40
void erase(Operation *symbol)
Erase the given symbol from the table.
Definition: SymbolTable.cpp:53
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:85
Operation * lookupSymbol(StringRef name)
Definition: SymbolTable.h:173
void insert(Operation *symbol, Block::iterator insertPt={})
Definition: SymbolTable.cpp:69
Definition: LogicalResult.h:18
OpListType::iterator iterator
Definition: Block.h:107
Operation * getOp() const
Returns the associated operation.
Definition: SymbolTable.h:47
UseRange(std::vector< SymbolUse > &&uses)
Definition: SymbolTable.h:87
LogicalResult verifySymbolTable(Operation *op)
Definition: SymbolTable.cpp:142
Definition: SymbolTable.h:164
SymbolRefAttr getSymbolRef() const
Return the symbol reference that this use represents.
Definition: SymbolTable.h:74
static Operation * lookupNearestSymbolFrom(Operation *from, StringRef symbol)
Definition: SymbolTable.cpp:125
Definition: OpDefinition.h:386
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from)
Definition: SymbolTable.cpp:340
Operation * lookup(StringRef name) const
Definition: SymbolTable.cpp:48
StringRef getName()
Returns the name of this symbol.
Definition: SymbolTable.h:192
static Operation * lookupSymbolIn(Operation *op, StringRef symbol)
Definition: SymbolTable.cpp:106
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
Definition: SymbolTable.cpp:25
Definition: SymbolTable.h:23
LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol, Operation *from)
Definition: SymbolTable.h:223
iterator begin() const
Definition: SymbolTable.h:90
iterator end() const
Definition: SymbolTable.h:91
static StringAttr get(StringRef bytes, MLIRContext *context)
Get an instance of a StringAttr with the given string.
Definition: Attributes.cpp:334
This class represents a specific symbol use.
Definition: SymbolTable.h:65
std::vector< SymbolUse >::const_iterator iterator
Definition: SymbolTable.h:89
bool symbolKnownUseEmpty(Operation *from)
Definition: SymbolTable.h:216
Optional<::mlir::SymbolTable::UseRange > getSymbolUses(Operation *from)
Definition: SymbolTable.h:209
static LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Operation *from)
Definition: SymbolTable.cpp:406
T lookupSymbol(StringRef name)
Definition: SymbolTable.h:176