My Project
DialectInterface.h
Go to the documentation of this file.
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/STLExtras.h"
13 #include "llvm/ADT/DenseSet.h"
14 
15 namespace mlir {
16 class Dialect;
17 class MLIRContext;
18 class Operation;
19 
20 //===----------------------------------------------------------------------===//
21 // DialectInterface
22 //===----------------------------------------------------------------------===//
23 namespace detail {
26 template <typename ConcreteType, typename BaseT>
27 class DialectInterfaceBase : public BaseT {
28 public:
30 
32  static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
33 
34 protected:
35  DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
36 };
37 } // end namespace detail
38 
// NOTE(review): the class header for `DialectInterface` (doxygen line 40) is
// missing from this extraction. The class name is taken from the destructor
// below; the doxygen index describes it as "an interface overridden for a
// single dialect".
41 public:
// Declared virtual so that derived interfaces are destroyed correctly when
// deleted through a `DialectInterface *`; the definition is out of line (not
// shown here).
42  virtual ~DialectInterface();
43 
// NOTE(review): the declaration this template header belongs to (doxygen line
// 47) is missing from this extraction. As shown, the header would attach to
// the next declaration instead. It is presumably an alias template naming
// `detail::DialectInterfaceBase<ConcreteType, DialectInterface>`; confirm
// against the real header.
46  template <typename ConcreteType>
48 
50  Dialect *getDialect() const { return dialect; }
51 
53  ClassID *getID() const { return interfaceID; }
54 
55 protected:
57  : dialect(dialect), interfaceID(id) {}
58 
59 private:
// The dialect this interface represents; returned by getDialect().
61  Dialect *dialect;
62 
// Unique id of the derived interface type; returned by getID().
64  ClassID *interfaceID;
65 };
66 
67 //===----------------------------------------------------------------------===//
68 // DialectInterfaceCollection
69 //===----------------------------------------------------------------------===//
70 
71 namespace detail {
76  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
78 
79  static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
80  static unsigned getHashValue(const DialectInterface *key) {
81  return getHashValue(key->getDialect());
82  }
83 
84  static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
85  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
86  return false;
87  return lhs == rhs->getDialect();
88  }
89  };
90 
// Vector type used for `orderedInterfaces`, which the interface iterators walk.
// NOTE(review): doxygen lines 91-92 are missing from this extraction. They
// presumably define the `InterfaceSetT` alias used by the `interfaces` member;
// confirm.
93  using InterfaceVectorT = std::vector<const DialectInterface *>;
94 
95 public:
// NOTE(review): doxygen lines 96-97 are missing here. The derived collection
// calls a `(MLIRContext *, ClassID *)` constructor of this class, which is
// presumably declared there; confirm.
98 
99 protected:
// Look up the interface for the dialect of `op`. The definition is out of
// line (Dialect.cpp according to the doxygen index), so its not-found result
// cannot be confirmed here.
102  const DialectInterface *getInterfaceFor(Operation *op) const;
103 
105  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
106  auto it = interfaces.find_as(dialect);
107  return it == interfaces.end() ? nullptr : *it;
108  }
109 
112  template <typename InterfaceT>
113  class iterator : public llvm::mapped_iterator<
114  InterfaceVectorT::const_iterator,
115  const InterfaceT &(*)(const DialectInterface *)> {
116  static const InterfaceT &remapIt(const DialectInterface *interface) {
117  return *static_cast<const InterfaceT *>(interface);
118  }
119 
120  iterator(InterfaceVectorT::const_iterator it)
121  : llvm::mapped_iterator<
122  InterfaceVectorT::const_iterator,
123  const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
124 
127  };
128 
130  template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
131  return iterator<InterfaceT>(orderedInterfaces.begin());
132  }
133  template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
134  return iterator<InterfaceT>(orderedInterfaces.end());
135  }
136 
137 private:
// Registered interfaces, searched by dialect via `find_as` in
// getInterfaceFor(Dialect *).
139  InterfaceSetT interfaces;
// Interfaces kept in a vector for iteration: interface_begin()/interface_end()
// walk this container, so iteration order is the vector's order. It is
// presumably populated alongside `interfaces`, but the population code is not
// visible here.
142  // NOTE: SetVector does not provide find access, so it can't be used here.
143  InterfaceVectorT orderedInterfaces;
144 };
145 } // namespace detail
146 
// NOTE(review): the class header (doxygen lines 150-151) is missing from this
// extraction. Judging by the constructor in the doxygen index, this is the
// `DialectInterfaceCollection` class template, built on
// `detail::DialectInterfaceCollectionBase`. `InterfaceType` must provide a
// static `getInterfaceID()`.
149 template <typename InterfaceType>
152 public:
// NOTE(review): doxygen line 153 is missing here.
154 
157  : detail::DialectInterfaceCollectionBase(
158  ctx, InterfaceType::getInterfaceID()) {}
159 
162  template <typename Object>
163  const InterfaceType *getInterfaceFor(Object *obj) const {
164  return static_cast<const InterfaceType *>(
166  }
167 
169  using iterator =
171  iterator begin() const { return interface_begin<InterfaceType>(); }
172  iterator end() const { return interface_end<InterfaceType>(); }
173 
174 private:
// NOTE(review): the contents of this private section (doxygen lines 175-176)
// are missing from this extraction.
177 };
178 
179 } // namespace mlir
180 
181 #endif
Definition: InferTypeOpInterface.cpp:20
iterator end() const
Definition: DialectInterface.h:172
Definition: STLExtras.h:95
Definition: Operation.h:27
Definition: DialectInterface.h:27
static ClassID * getInterfaceID()
Get a unique id for the derived interface type.
Definition: DialectInterface.h:32
Definition: LLVM.h:46
Definition: LLVM.h:45
iterator< InterfaceT > interface_end() const
Definition: DialectInterface.h:133
Dialect * getDialect() const
Return the dialect that this interface represents.
Definition: DialectInterface.h:50
Definition: DialectInterface.h:150
const InterfaceType * getInterfaceFor(Object *obj) const
Definition: DialectInterface.h:163
const DialectInterface * getInterfaceFor(Dialect *dialect) const
Get the interface for the given dialect.
Definition: DialectInterface.h:105
Definition: Dialect.h:39
DialectInterfaceBase(Dialect *dialect)
Definition: DialectInterface.h:35
Definition: DialectInterface.h:74
inline ::llvm::hash_code hash_value(AffineExpr arg)
Make AffineExpr hashable.
Definition: AffineExpr.h:201
iterator< InterfaceT > interface_begin() const
Iterator access to the held interfaces.
Definition: DialectInterface.h:130
DialectInterfaceCollection(MLIRContext *ctx)
Collect the registered dialect interfaces within the provided context.
Definition: DialectInterface.h:156
Definition: MLIRContext.h:34
ClassID * getID() const
Return the derived interface id.
Definition: DialectInterface.h:53
const DialectInterface * getInterfaceFor(Operation *op) const
Definition: Dialect.cpp:154
DialectInterface(Dialect *dialect, ClassID *id)
Definition: DialectInterface.h:56
iterator begin() const
Definition: DialectInterface.h:171
This class represents an interface overridden for a single dialect.
Definition: DialectInterface.h:40