My Project
DialectHooks.h
Go to the documentation of this file.
1 //===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the abstraction and registration mechanism for dialect hooks.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_HOOKS_H
14 #define MLIR_IR_DIALECT_HOOKS_H
15 
16 #include "mlir/IR/Dialect.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 namespace mlir {
/// Callback that receives an MLIRContext and installs hooks on a dialect
/// registered in that context. Instances are handed to
/// registerDialectHooksSetter (DialectHooksRegistration builds one per
/// registration).
20 using DialectHooksSetter = std::function<void(MLIRContext *)>;
21 
/// Source of a dialect's hooks. DialectHooksRegistration<ConcreteHooks>
/// default-constructs a ConcreteHooks object and copies onto the dialect
/// (constantFoldHook, decodeHook, extractElementHook) each hook whose getter
/// returns a non-empty std::function; an empty function leaves the dialect's
/// corresponding hook unchanged.
/// NOTE(review): presumably concrete providers derive from this class and
/// override only the getters they need — confirm against users.
30 class DialectHooks {
31 public:
32  // Returns the hook used to constant-fold an operation (empty = no hook).
34  // Returns the hook used to decode an opaque constant tensor (empty = no hook).
36  // Returns the hook used to extract one element of an opaque constant tensor
    // (empty = no hook).
38 };
39 
43 
50 template <typename ConcreteHooks> struct DialectHooksRegistration {
51  DialectHooksRegistration(StringRef dialectName) {
52  registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
53  Dialect *dialect = ctx->getRegisteredDialect(dialectName);
54  if (!dialect) {
55  llvm::errs() << "error: cannot register hooks for unknown dialect '"
56  << dialectName << "'\n";
57  abort();
58  }
59  // Set hooks.
60  ConcreteHooks hooks;
61  if (auto h = hooks.getConstantFoldHook())
62  dialect->constantFoldHook = h;
63  if (auto h = hooks.getDecodeHook())
64  dialect->decodeHook = h;
65  if (auto h = hooks.getExtractElementHook())
66  dialect->extractElementHook = h;
67  });
68  }
69 };
70 
71 } // namespace mlir
72 
73 #endif
Definition: InferTypeOpInterface.cpp:20
Definition: DialectHooks.h:30
void registerDialectHooksSetter(const DialectHooksSetter &function)
Definition: Dialect.cpp:47
std::function< LogicalResult(Operation *, ArrayRef< Attribute >, SmallVectorImpl< Attribute > &)> DialectConstantFoldHook
Definition: Dialect.h:28
std::function< Attribute(const OpaqueElementsAttr, ArrayRef< uint64_t >)> DialectExtractElementHook
Definition: Dialect.h:30
DialectConstantFoldHook constantFoldHook
Definition: Dialect.h:72
Dialect * getRegisteredDialect(StringRef name)
Definition: MLIRContext.cpp:315
std::function< void(MLIRContext *)> DialectHooksSetter
Definition: DialectHooks.h:20
DialectConstantFoldHook getConstantFoldHook()
Definition: DialectHooks.h:33
Definition: Dialect.h:39
std::function< bool(const OpaqueElementsAttr, ElementsAttr &)> DialectConstantDecodeHook
Definition: Dialect.h:26
DialectConstantDecodeHook decodeHook
Definition: Dialect.h:81
DialectConstantDecodeHook getDecodeHook()
Definition: DialectHooks.h:35
Definition: DialectHooks.h:50
Definition: MLIRContext.h:34
DialectHooksRegistration(StringRef dialectName)
Definition: DialectHooks.h:51
DialectExtractElementHook extractElementHook
Definition: Dialect.h:88
DialectExtractElementHook getExtractElementHook()
Definition: DialectHooks.h:37