My Project
Dialect.h
Go to the documentation of this file.
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
17 
18 namespace mlir {
19 class DialectAsmParser;
20 class DialectAsmPrinter;
21 class DialectInterface;
22 class OpBuilder;
23 class Type;
24 
26  std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
27 using DialectConstantFoldHook = std::function<LogicalResult(
30  std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
31 
39 class Dialect {
40 public:
/// Virtual destructor. NOTE(review): presumably virtual because dialects are
/// owned by the MLIRContext and destroyed through a Dialect pointer — confirm
/// in Dialect.cpp.
41  virtual ~Dialect();
42 
/// Returns true if `str` is acceptable as a dialect namespace. The exact rules
/// are implemented out of line (Dialect.cpp) and are not visible here.
45  static bool isValidNamespace(StringRef str);
46 
/// Returns the MLIRContext held by this dialect (the `context` member).
47  MLIRContext *getContext() const { return context; }
48 
/// Returns this dialect's namespace (the `name` member).
49  StringRef getNamespace() const { return name; }
50 
/// Returns true if unregistered operations are supported by this dialect.
/// Backed by `unknownOpsAllowed`, which defaults to false and is set through
/// allowUnknownOperations().
54  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
55 
/// Returns true if unregistered types are supported by this dialect.
/// Backed by `unknownTypesAllowed`, which defaults to false and is set through
/// allowUnknownTypes().
59  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
60 
61  //===--------------------------------------------------------------------===//
62  // Constant Hooks
63  //===--------------------------------------------------------------------===//
64 
73  [](Operation *op, ArrayRef<Attribute> operands,
74  SmallVectorImpl<Attribute> &results) { return failure(); };
75 
82  [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
83 
89  [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
90  return Attribute();
91  };
92 
101  Type type, Location loc) {
102  return nullptr;
103  }
104 
105  //===--------------------------------------------------------------------===//
106  // Parsing Hooks
107  //===--------------------------------------------------------------------===//
108 
/// Parse an attribute registered to this dialect from `parser`. `type` is
/// forwarded from the caller; whether it may be null is decided by the
/// out-of-line default in Dialect.cpp — TODO confirm before relying on it.
111  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
112 
/// Print an attribute registered to this dialect. The default implementation
/// hits llvm_unreachable, so any dialect whose attributes can be printed must
/// override this hook.
116  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
117  llvm_unreachable("dialect has no registered attribute printing hook");
118  }
119 
/// Parse a type registered to this dialect from `parser`. The default
/// implementation is defined out of line in Dialect.cpp.
121  virtual Type parseType(DialectAsmParser &parser) const;
122 
/// Print a type registered to this dialect. The default implementation hits
/// llvm_unreachable, so any dialect whose types can be printed must override
/// this hook.
124  virtual void printType(Type, DialectAsmPrinter &) const {
125  llvm_unreachable("dialect has no registered type printing hook");
126  }
127 
128  //===--------------------------------------------------------------------===//
129  // Verification Hooks
130  //===--------------------------------------------------------------------===//
131 
137  unsigned regionIndex,
138  unsigned argIndex,
140 
146  unsigned regionIndex,
147  unsigned resultIndex,
149 
153  return success();
154  }
155 
156  //===--------------------------------------------------------------------===//
157  // Interfaces
158  //===--------------------------------------------------------------------===//
159 
163  auto it = registeredInterfaces.find(interfaceID);
164  return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
165  }
/// Typed lookup of a registered dialect interface.
/// Finds the interface registered under InterfaceT::getInterfaceID() and casts
/// it to InterfaceT. Returns null when nothing is registered under that ID,
/// because the untyped overload returns nullptr on a miss and static_cast keeps
/// it null. The cast is unchecked: correctness relies on the ID identifying
/// exactly InterfaceT.
166  template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
167  return static_cast<const InterfaceT *>(
168  getRegisteredInterface(InterfaceT::getInterfaceID()));
169  }
170 
171 protected:
/// Constructs a dialect named `name` inside `context`. It is protected, so only
/// derived (concrete) dialects can be instantiated. Defined in Dialect.cpp;
/// presumably it stores the arguments into the `name`/`context` members that
/// getNamespace()/getContext() return — confirm there.
179  Dialect(StringRef name, MLIRContext *context);
180 
183  template <typename... Args> void addOperations() {
185  }
186 
187  // It would be nice to define this as variadic functions instead of a nested
188  // variadic type, but we can't do that: function template partial
189  // specialization is not allowed, and we can't define an overload set because
190  // we don't have any arguments of the types we are pushing around.
191  template <typename First, typename... Rest> class VariadicOperationAdder {
192  public:
193  static void addToSet(Dialect &dialect) {
194  dialect.addOperation(AbstractOperation::get<First>(dialect));
196  }
197  };
198 
/// Single-type specialization of VariadicOperationAdder. It registers the
/// operation First by building its AbstractOperation description for `dialect`.
/// NOTE(review): presumably the recursion base case of the primary template
/// above, whose recursive line is not visible in this listing.
199  template <typename First> class VariadicOperationAdder<First> {
200  public:
201  static void addToSet(Dialect &dialect) {
202  dialect.addOperation(AbstractOperation::get<First>(dialect));
203  }
204  };
205 
/// Registers one operation description with this dialect. Defined out of line
/// (MLIRContext.cpp, per the cross-reference index).
206  void addOperation(AbstractOperation opInfo);
207 
209  template <typename... Args> void addTypes() {
211  }
212 
214  template <typename... Args> void addAttributes() {
216  }
217 
218  // It would be nice to define this as variadic functions instead of a nested
219  // variadic type, but we can't do that: function template partial
220  // specialization is not allowed, and we can't define an overload set
221  // because we don't have any arguments of the types we are pushing around.
222  template <typename First, typename... Rest> struct VariadicSymbolAdder {
223  static void addToSet(Dialect &dialect) {
226  }
227  };
228 
/// Single-type specialization of VariadicSymbolAdder. It registers First by its
/// unique class identifier (First::getClassID()) through the private addSymbol().
229  template <typename First> struct VariadicSymbolAdder<First> {
230  static void addToSet(Dialect &dialect) {
231  dialect.addSymbol(First::getClassID());
232  }
233  };
234 
/// Enable (or, with `allow == false`, disable) support for unregistered
/// operations in this dialect.
236  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
237 
/// Enable (or, with `allow == false`, disable) support for unregistered types
/// in this dialect.
239  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
240 
/// Register a dialect interface with this dialect instance. Ownership of
/// `interface` is transferred to the dialect.
242  void addInterface(std::unique_ptr<DialectInterface> interface);
243 
/// Register a set of dialect interfaces with this dialect instance. Each type
/// is handled in order, left to right, by peeling off the first type and
/// recursing on the rest.
245  template <typename T, typename T2, typename... Tys> void addInterfaces() {
246  addInterfaces<T>();
247  addInterfaces<T2, Tys...>();
248  }
/// Single-interface case: construct T with this dialect as its constructor
/// argument and hand ownership to addInterface().
249  template <typename T> void addInterfaces() {
250  addInterface(std::make_unique<T>(this));
251  }
252 
253 private:
254  // Register a symbol (e.g. a type) with its unique class identifier.
255  void addSymbol(const ClassID *const classID);
256 
/// Dialects are non-copyable. NOTE(review): the conventional spelling is
/// `Dialect &operator=(const Dialect &) = delete;`. The form below still works:
/// a user-declared `operator=(Dialect &)` is a copy-assignment operator, so the
/// implicit copy- and move-assignment operators are not generated.
257  Dialect(const Dialect &) = delete;
258  void operator=(Dialect &) = delete;
259 
/// NOTE(review): presumably registers this dialect instance with `context`;
/// its definition and callers are not visible in this listing.
262  void registerDialect(MLIRContext *context);
263 
/// The namespace of this dialect; returned by getNamespace().
265  StringRef name;
266 
/// The context this dialect lives in; returned by getContext().
268  MLIRContext *context;
269 
/// Whether unregistered operations are supported. Set by
/// allowUnknownOperations() and read by allowsUnknownOperations().
273  bool unknownOpsAllowed = false;
274 
/// Whether unregistered types are supported. Set by allowUnknownTypes() and
/// read by allowsUnknownTypes().
278  bool unknownTypesAllowed = false;
279 
282 };
283 
/// Callback that, given a context, allocates a dialect in it.
/// NOTE(review): inferred from registerDialect<ConcreteDialect>() below, whose
/// lambda body does `new ConcreteDialect(ctx)` and leaves ownership to the
/// context — confirm against registerDialectAllocator in Dialect.cpp.
284 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
285 
289 
/// Registers all dialects with the specified MLIRContext. NOTE(review):
/// presumably this runs every previously registered DialectAllocatorFunction
/// on `context` — confirm in Dialect.cpp.
291 void registerAllDialects(MLIRContext *context);
292 
295 template <typename ConcreteDialect> void registerDialect() {
297  // Just allocate the dialect, the context takes ownership of it.
298  new ConcreteDialect(ctx);
299  });
300 }
301 
/// Registration helper: constructing a DialectRegistration<ConcreteDialect>
/// calls registerDialect<ConcreteDialect>().
/// NOTE(review): presumably meant to be declared as a static/global object
/// (e.g. `static DialectRegistration<MyDialect> reg;`) so that registration
/// happens during static initialization — confirm against its users.
309 template <typename ConcreteDialect> struct DialectRegistration {
310  DialectRegistration() { registerDialect<ConcreteDialect>(); }
311 };
312 
313 } // namespace mlir
314 
315 #endif
Definition: InferTypeOpInterface.cpp:20
Definition: Attributes.h:1052
void addInterfaces()
Definition: Dialect.h:249
Definition: STLExtras.h:95
Definition: Operation.h:27
void allowUnknownOperations(bool allow=true)
Enable support for unregistered operations.
Definition: Dialect.h:236
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:96
std::function< void(MLIRContext *)> DialectAllocatorFunction
Definition: Dialect.h:284
Definition: Attributes.h:139
static void addToSet(Dialect &dialect)
Definition: Dialect.h:223
Definition: LLVM.h:48
std::function< LogicalResult(Operation *, ArrayRef< Attribute >, SmallVectorImpl< Attribute > &)> DialectConstantFoldHook
Definition: Dialect.h:28
std::function< Attribute(const OpaqueElementsAttr, ArrayRef< uint64_t >)> DialectExtractElementHook
Definition: Dialect.h:30
Dialect(StringRef name, MLIRContext *context)
Definition: Dialect.cpp:69
DialectConstantFoldHook constantFoldHook
Definition: Dialect.h:72
void addInterfaces()
Register a set of dialect interfaces with this dialect instance.
Definition: Dialect.h:245
virtual ~Dialect()
Definition: Dialect.cpp:75
void addOperations()
Definition: Dialect.h:183
Definition: LLVM.h:34
Definition: Location.h:52
std::pair< Identifier, Attribute > NamedAttribute
Definition: Attributes.h:264
static void addToSet(Dialect &dialect)
Definition: Dialect.h:230
LogicalResult success(bool isSuccess=true)
Definition: LogicalResult.h:25
Definition: LogicalResult.h:18
LogicalResult failure(bool isFailure=true)
Definition: LogicalResult.h:32
Definition: LLVM.h:37
const InterfaceT * getRegisteredInterface()
Definition: Dialect.h:166
static void addToSet(Dialect &dialect)
Definition: Dialect.h:193
Definition: Attributes.h:53
DialectRegistration()
Definition: Dialect.h:310
Definition: Dialect.h:39
Definition: OperationSupport.h:83
bool allowsUnknownTypes() const
Definition: Dialect.h:59
MLIRContext * getContext() const
Definition: Dialect.h:47
std::function< bool(const OpaqueElementsAttr, ElementsAttr &)> DialectConstantDecodeHook
Definition: Dialect.h:26
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute)
Definition: Dialect.h:152
Definition: Types.h:84
DialectConstantDecodeHook decodeHook
Definition: Dialect.h:81
static void addToSet(Dialect &dialect)
Definition: Dialect.h:201
void addOperation(AbstractOperation opInfo)
Definition: MLIRContext.cpp:379
void addAttributes()
This method is used by derived classes to add their attributes to the set.
Definition: Dialect.h:214
StringRef getNamespace() const
Definition: Dialect.h:49
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:126
void allowUnknownTypes(bool allow=true)
Enable support for unregistered types.
Definition: Dialect.h:239
Definition: DialectImplementation.h:33
Definition: Dialect.h:222
void addTypes()
This method is used by derived classes to add their types to the set.
Definition: Dialect.h:209
Definition: MLIRContext.h:34
const DialectInterface * getRegisteredInterface(ClassID *interfaceID)
Definition: Dialect.h:162
void registerDialect()
Definition: Dialect.h:295
bool allowsUnknownOperations() const
Definition: Dialect.h:54
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition: Dialect.h:124
Definition: Dialect.h:309
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Definition: Dialect.cpp:81
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Definition: Dialect.h:116
DialectExtractElementHook extractElementHook
Definition: Dialect.h:88
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Definition: Dialect.h:100
void registerDialectAllocator(const DialectAllocatorFunction &function)
Definition: Dialect.cpp:39
Definition: Attributes.h:559
Definition: Builders.h:158
Definition: Dialect.h:191
Definition: DialectImplementation.h:100
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:104
This class represents an interface overridden for a single dialect.
Definition: DialectInterface.h:40
void registerAllDialects(MLIRContext *context)
Registers all dialects with the specified MLIRContext.
Definition: Dialect.cpp:57
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Definition: Dialect.cpp:90
static bool isValidNamespace(StringRef str)
Definition: Dialect.cpp:118