My Project
STLExtras.h
Go to the documentation of this file.
1 //===- STLExtras.h - STL-like extensions that are used by MLIR --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains stuff that should be arguably sunk down to the LLVM
10 // Support/STLExtras.h file over time.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_STLEXTRAS_H
15 #define MLIR_SUPPORT_STLEXTRAS_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 
namespace detail {
/// The element type of `RangeT`: the type produced by dereferencing the
/// iterator that `std::begin` returns for an lvalue of `RangeT`, with the
/// reference stripped off. cv-qualifiers on the element (e.g. `const` for a
/// const container) are preserved.
template <typename RangeT>
using ValueOfRange =
    std::remove_reference_t<decltype(*std::begin(std::declval<RangeT &>()))>;
} // end namespace detail
27 
38 template <typename ForwardIterator, typename UnaryFunctor,
39  typename NullaryFunctor,
40  typename = typename std::enable_if<
41  !std::is_constructible<StringRef, UnaryFunctor>::value &&
42  !std::is_constructible<StringRef, NullaryFunctor>::value>::type>
43 inline void interleave(ForwardIterator begin, ForwardIterator end,
44  UnaryFunctor each_fn, NullaryFunctor between_fn) {
45  if (begin == end)
46  return;
47  each_fn(*begin);
48  ++begin;
49  for (; begin != end; ++begin) {
50  between_fn();
51  each_fn(*begin);
52  }
53 }
54 
55 template <typename Container, typename UnaryFunctor, typename NullaryFunctor,
56  typename = typename std::enable_if<
57  !std::is_constructible<StringRef, UnaryFunctor>::value &&
58  !std::is_constructible<StringRef, NullaryFunctor>::value>::type>
59 inline void interleave(const Container &c, UnaryFunctor each_fn,
60  NullaryFunctor between_fn) {
61  interleave(c.begin(), c.end(), each_fn, between_fn);
62 }
63 
65 template <typename Container, typename UnaryFunctor, typename raw_ostream,
67 inline void interleave(const Container &c, raw_ostream &os,
68  UnaryFunctor each_fn, const StringRef &separator) {
69  interleave(c.begin(), c.end(), each_fn, [&] { os << separator; });
70 }
71 template <typename Container, typename raw_ostream,
73 inline void interleave(const Container &c, raw_ostream &os,
74  const StringRef &separator) {
75  interleave(
76  c, os, [&](const T &a) { os << a; }, separator);
77 }
78 
79 template <typename Container, typename UnaryFunctor, typename raw_ostream,
81 inline void interleaveComma(const Container &c, raw_ostream &os,
82  UnaryFunctor each_fn) {
83  interleave(c, os, each_fn, ", ");
84 }
85 template <typename Container, typename raw_ostream,
87 inline void interleaveComma(const Container &c, raw_ostream &os) {
88  interleaveComma(c, os, [&](const T &a) { os << a; });
89 }
90 
/// An opaque identity token for a C++ type or a single-parameter class
/// template. Each instantiation of `getID` owns one function-local static
/// `ClassID`, and the address of that object serves as the identifier.
/// Because the struct is `alignas(8)`, every returned pointer has its low three
/// bits clear.
/// NOTE(review): identity relies on each function-local static having a single
/// definition program-wide; confirm this holds across shared-library
/// boundaries for the intended build setup.
struct alignas(8) ClassID {
  /// Return the identifier associated with the type `T`.
  template <typename T>
  static ClassID *getID() {
    static ClassID uniqueInstance;
    return &uniqueInstance;
  }

  /// Return the identifier associated with the class template `Trait`.
  template <template <typename T> class Trait>
  static ClassID *getID() {
    static ClassID uniqueInstance;
    return &uniqueInstance;
  }
};
105 
namespace detail {
/// Collapse any list of types to `void`. Substituting an ill-formed type into
/// it is a substitution failure, which is what drives the detector below.
template <typename...> using void_t = void;

/// Fallback: selected whenever `Probe<Ts...>` cannot be formed.
template <typename AlwaysVoid, template <typename...> class Probe,
          typename... Ts>
struct detector {
  using value_t = std::false_type;
};

/// Preferred partial specialization: only viable when `Probe<Ts...>` names a
/// valid type, in which case the first argument reduces to `void` and matches.
template <template <typename...> class Probe, typename... Ts>
struct detector<void_t<Probe<Ts...>>, Probe, Ts...> {
  using value_t = std::true_type;
};
} // end namespace detail
123 
124 template <template <class...> class Op, class... Args>
125 using is_detected = typename detail::detector<void, Op, Args...>::value_t;
126 
128 namespace detail {
129 template <typename Callable, typename... Args>
130 using is_invocable =
131  decltype(std::declval<Callable &>()(std::declval<Args>()...));
132 } // namespace detail
133 
134 template <typename Callable, typename... Args>
135 using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
136 
137 //===----------------------------------------------------------------------===//
138 // Extra additions to <iterator>
139 //===----------------------------------------------------------------------===//
140 
143 template <typename DerivedT, typename BaseT, typename T,
144  typename PointerT = T *, typename ReferenceT = T &>
146  : public llvm::iterator_facade_base<DerivedT,
147  std::random_access_iterator_tag, T,
148  std::ptrdiff_t, PointerT, ReferenceT> {
149 public:
150  ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
151  assert(base == rhs.base && "incompatible iterators");
152  return index - rhs.index;
153  }
154  bool operator==(const indexed_accessor_iterator &rhs) const {
155  return base == rhs.base && index == rhs.index;
156  }
157  bool operator<(const indexed_accessor_iterator &rhs) const {
158  assert(base == rhs.base && "incompatible iterators");
159  return index < rhs.index;
160  }
161 
162  DerivedT &operator+=(ptrdiff_t offset) {
163  this->index += offset;
164  return static_cast<DerivedT &>(*this);
165  }
166  DerivedT &operator-=(ptrdiff_t offset) {
167  this->index -= offset;
168  return static_cast<DerivedT &>(*this);
169  }
170 
172  ptrdiff_t getIndex() const { return index; }
173 
175  const BaseT &getBase() const { return base; }
176 
177 protected:
178  indexed_accessor_iterator(BaseT base, ptrdiff_t index)
179  : base(base), index(index) {}
180  BaseT base;
181  ptrdiff_t index;
182 };
183 
184 namespace detail {
194 template <typename DerivedT, typename BaseT, typename T,
195  typename PointerT = T *, typename ReferenceT = T &>
197 public:
198  using RangeBaseT =
200 
202  class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
203  PointerT, ReferenceT> {
204  public:
205  // Index into this iterator, invoking a static method on the derived type.
206  ReferenceT operator*() const {
207  return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
208  }
209 
210  private:
211  iterator(BaseT owner, ptrdiff_t curIndex)
213  owner, curIndex) {}
214 
216  friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
217  ReferenceT>;
218  };
219 
220  indexed_accessor_range_base(iterator begin, iterator end)
221  : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())),
222  count(end.getIndex() - begin.getIndex()) {}
224  : indexed_accessor_range_base(range.begin(), range.end()) {}
225  indexed_accessor_range_base(BaseT base, ptrdiff_t count)
226  : base(base), count(count) {}
227 
228  iterator begin() const { return iterator(base, 0); }
229  iterator end() const { return iterator(base, count); }
230  ReferenceT operator[](unsigned index) const {
231  assert(index < size() && "invalid index for value range");
232  return DerivedT::dereference_iterator(base, index);
233  }
234 
236  size_t size() const { return count; }
237 
239  bool empty() const { return size() == 0; }
240 
242  DerivedT slice(size_t n, size_t m) const {
243  assert(n + m <= size() && "invalid size specifiers");
244  return DerivedT(DerivedT::offset_base(base, n), m);
245  }
246 
248  DerivedT drop_front(size_t n = 1) const {
249  assert(size() >= n && "Dropping more elements than exist");
250  return slice(n, size() - n);
251  }
253  DerivedT drop_back(size_t n = 1) const {
254  assert(size() >= n && "Dropping more elements than exist");
255  return DerivedT(base, size() - n);
256  }
257 
259  DerivedT take_front(size_t n = 1) const {
260  return n < size() ? drop_back(size() - n)
261  : static_cast<const DerivedT &>(*this);
262  }
263 
267  template <typename SVT, unsigned N> operator SmallVector<SVT, N>() const {
268  return {begin(), end()};
269  }
270 
271 protected:
275  operator=(const indexed_accessor_range_base &) = default;
276 
278  BaseT base;
280  ptrdiff_t count;
281 };
282 } // end namespace detail
283 
291 template <typename DerivedT, typename BaseT, typename T,
292  typename PointerT = T *, typename ReferenceT = T &>
295  DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
296 public:
297  indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
298  : detail::indexed_accessor_range_base<
299  DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
300  std::make_pair(base, startIndex), count) {}
302  DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT,
303  ReferenceT>::indexed_accessor_range_base;
304 
306  const BaseT &getBase() const { return this->base.first; }
307 
309  ptrdiff_t getStartIndex() const { return this->base.second; }
310 
312  static std::pair<BaseT, ptrdiff_t>
313  offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
314  // We encode the internal base as a pair of the derived base and a start
315  // index into the derived base.
316  return std::make_pair(base.first, base.second + index);
317  }
319  static ReferenceT
320  dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
321  ptrdiff_t index) {
322  return DerivedT::dereference(base.first, base.second + index);
323  }
324 };
325 
327 template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
328  return llvm::map_range(
329  std::forward<ContainerTy>(c),
330  [](decltype((*std::begin(c))) elt) -> decltype((elt.second)) {
331  return elt.second;
332  });
333 }
334 
/// Returns true if the given range contains exactly one element. At most one
/// iterator increment is performed, so the cost does not depend on the
/// range's length.
template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
  auto first = std::begin(c);
  auto last = std::end(c);
  if (first == last)
    return false; // Empty range.
  return ++first == last;
}
340 
341 //===----------------------------------------------------------------------===//
342 // Extra additions to <type_traits>
343 //===----------------------------------------------------------------------===//
344 
/// Provides trait information about a callable:
///   * the number of arguments:   FunctionTraits<T>::num_args
///   * the type of an argument:   FunctionTraits<T>::arg_t<Index>
///   * the type of the result:    FunctionTraits<T>::result_t
///
/// Class types (lambdas and other functors) are handled through the type of
/// `&T::operator()`, so that call operator must be neither overloaded nor a
/// template.
template <typename T, bool isClass = std::is_class<T>::value>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};

/// Specialization for const member functions, e.g. the call operator of a
/// non-`mutable` lambda.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
  /// The number of arguments to this function.
  enum { num_args = sizeof...(Args) };

  /// The result type of this function.
  using result_t = ReturnType;

  /// The type of an argument to this function.
  template <size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};

/// Specialization for non-const member functions, e.g. the call operator of a
/// `mutable` lambda. Without it such types fall through to the primary
/// template and fail to compile. The traits are identical to the const case,
/// so this forwards to that specialization.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...), false>
    : public FunctionTraits<ReturnType (ClassType::*)(Args...) const> {};

/// Specialization for non-class callables given as function pointers.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (*)(Args...), false> {
  /// The number of arguments to this function.
  enum { num_args = sizeof...(Args) };

  /// The result type of this function.
  using result_t = ReturnType;

  /// The type of an argument to this function.
  template <size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
378 } // end namespace mlir
379 
380 // Allow tuples to be usable as DenseMap keys.
381 // TODO: Move this to upstream LLVM.
382 
/// Mix two `unsigned` hash values into one. `a` is shifted into the upper 32
/// bits and `b` fills the lower bits of a 64-bit key. The key is scrambled by a
/// fixed sequence of shift/add/xor steps and finally truncated to `unsigned`.
/// NOTE(review): appears to mirror LLVM's DenseMapInfo hash-combining routine;
/// confirm and switch to the upstream helper once it is available.
static inline unsigned llvm_combineHashValue(unsigned a, unsigned b) {
  uint64_t mixed = (static_cast<uint64_t>(a) << 32) | static_cast<uint64_t>(b);
  mixed += ~(mixed << 32);
  mixed ^= mixed >> 22;
  mixed += ~(mixed << 13);
  mixed ^= mixed >> 8;
  mixed += mixed << 3;
  mixed ^= mixed >> 15;
  mixed += ~(mixed << 27);
  mixed ^= mixed >> 31;
  return static_cast<unsigned>(mixed);
}
397 
398 namespace llvm {
399 template <typename... Ts> struct DenseMapInfo<std::tuple<Ts...>> {
400  using Tuple = std::tuple<Ts...>;
401 
402  static inline Tuple getEmptyKey() {
404  }
405 
406  static inline Tuple getTombstoneKey() {
408  }
409 
410  template <unsigned I>
411  static unsigned getHashValueImpl(const Tuple &values, std::false_type) {
412  using EltType = typename std::tuple_element<I, Tuple>::type;
413  std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
414  return llvm_combineHashValue(
415  DenseMapInfo<EltType>::getHashValue(std::get<I>(values)),
416  getHashValueImpl<I + 1>(values, atEnd));
417  }
418 
419  template <unsigned I>
420  static unsigned getHashValueImpl(const Tuple &values, std::true_type) {
421  return 0;
422  }
423 
424  static unsigned getHashValue(const std::tuple<Ts...> &values) {
425  std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
426  return getHashValueImpl<0>(values, atEnd);
427  }
428 
429  template <unsigned I>
430  static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type) {
431  using EltType = typename std::tuple_element<I, Tuple>::type;
432  std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
433  return DenseMapInfo<EltType>::isEqual(std::get<I>(lhs), std::get<I>(rhs)) &&
434  isEqualImpl<I + 1>(lhs, rhs, atEnd);
435  }
436 
437  template <unsigned I>
438  static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type) {
439  return true;
440  }
441 
442  static bool isEqual(const Tuple &lhs, const Tuple &rhs) {
443  std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
444  return isEqualImpl<0>(lhs, rhs, atEnd);
445  }
446 };
447 
448 } // end namespace llvm
449 
450 #endif // MLIR_SUPPORT_STLEXTRAS_H
Definition: InferTypeOpInterface.cpp:20
static Tuple getEmptyKey()
Definition: STLExtras.h:402
typename std::tuple_element< i, std::tuple< Args... > >::type arg_t
The type of an argument to this function.
Definition: STLExtras.h:376
bool has_single_element(ContainerTy &&c)
Returns true if the given range contains exactly one element.
Definition: STLExtras.h:336
Definition: STLExtras.h:293
ReturnType result_t
The result type of this function.
Definition: STLExtras.h:372
Definition: STLExtras.h:95
Definition: PassRegistry.cpp:413
DerivedT & operator+=(ptrdiff_t offset)
Definition: STLExtras.h:162
ptrdiff_t count
The size from the owning range.
Definition: STLExtras.h:280
Definition: LLVM.h:45
indexed_accessor_iterator(BaseT base, ptrdiff_t index)
Definition: STLExtras.h:178
const BaseT & getBase() const
Returns the current base of the range.
Definition: STLExtras.h:306
Definition: STLExtras.h:350
static bool isEqual(const Tuple &lhs, const Tuple &rhs)
Definition: STLExtras.h:442
bool empty() const
Return true if the range is empty.
Definition: STLExtras.h:239
typename std::tuple_element< i, std::tuple< Args... > >::type arg_t
The type of an argument to this function.
Definition: STLExtras.h:363
Definition: STLExtras.h:115
std::false_type value_t
Definition: STLExtras.h:116
BaseT base
Definition: STLExtras.h:180
BaseT base
The base that owns the provided range of values.
Definition: STLExtras.h:278
DerivedT take_front(size_t n=1) const
Take the first n elements.
Definition: STLExtras.h:259
ptrdiff_t getIndex() const
Returns the current index of the iterator.
Definition: STLExtras.h:172
static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type)
Definition: STLExtras.h:430
DerivedT drop_front(size_t n=1) const
Drop the first n elements.
Definition: STLExtras.h:248
std::true_type value_t
Definition: STLExtras.h:120
static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type)
Definition: STLExtras.h:438
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
Definition: STLExtras.h:297
bool operator<(const indexed_accessor_iterator &rhs) const
Definition: STLExtras.h:157
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
Definition: STLExtras.h:225
typename detail::detector< void, Op, Args... >::value_t is_detected
Definition: STLExtras.h:125
Definition: StandardTypes.h:62
An iterator element of this range.
Definition: STLExtras.h:202
ReturnType result_t
The result type of this function.
Definition: STLExtras.h:359
void interleaveComma(const Container &c, raw_ostream &os, UnaryFunctor each_fn)
Definition: STLExtras.h:81
static unsigned getHashValueImpl(const Tuple &values, std::true_type)
Definition: STLExtras.h:420
iterator begin() const
Definition: STLExtras.h:228
static std::pair< BaseT, ptrdiff_t > offset_base(const std::pair< BaseT, ptrdiff_t > &base, ptrdiff_t index)
See detail::indexed_accessor_range_base for details.
Definition: STLExtras.h:313
ptrdiff_t getStartIndex() const
Returns the current start index of the range.
Definition: STLExtras.h:309
DerivedT & operator-=(ptrdiff_t offset)
Definition: STLExtras.h:166
indexed_accessor_range_base(iterator begin, iterator end)
Definition: STLExtras.h:220
Definition: STLExtras.h:145
mlir::edsc::intrinsics::ValueBuilder< SliceOp > slice
Definition: Intrinsics.h:24
ptrdiff_t index
Definition: STLExtras.h:181
auto make_second_range(ContainerTy &&c)
Given a container of pairs, return a range over the second elements.
Definition: STLExtras.h:327
Definition: LLVM.h:35
void interleave(ForwardIterator begin, ForwardIterator end, UnaryFunctor each_fn, NullaryFunctor between_fn)
Definition: STLExtras.h:43
NestedPattern Op(FilterFunctionType filter)
Definition: NestedMatcher.cpp:111
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const
Definition: STLExtras.h:150
iterator end() const
Definition: STLExtras.h:229
Definition: LLVM.h:50
decltype(std::declval< Callable & >()(std::declval< Args >()...)) is_invocable
Definition: STLExtras.h:131
void void_t
Definition: STLExtras.h:114
static ReferenceT dereference_iterator(const std::pair< BaseT, ptrdiff_t > &base, ptrdiff_t index)
See detail::indexed_accessor_range_base for details.
Definition: STLExtras.h:320
typename std::remove_reference< decltype(*std::begin(std::declval< RangeT & >()))>::type ValueOfRange
Definition: STLExtras.h:25
static ClassID * getID()
Definition: STLExtras.h:96
std::tuple< Ts... > Tuple
Definition: STLExtras.h:400
bool operator==(const indexed_accessor_iterator &rhs) const
Definition: STLExtras.h:154
static unsigned getHashValueImpl(const Tuple &values, std::false_type)
Definition: STLExtras.h:411
Definition: OpDefinition.h:949
size_t size() const
Return the size of this range.
Definition: STLExtras.h:236
ReferenceT operator*() const
Definition: STLExtras.h:206
indexed_accessor_range_base(const iterator_range< iterator > &range)
Definition: STLExtras.h:223
static unsigned getHashValue(const std::tuple< Ts... > &values)
Definition: STLExtras.h:424
is_detected< detail::is_invocable, Callable, Args... > is_invocable
Definition: STLExtras.h:135
mlir::edsc::intrinsics::ValueBuilder< RangeOp > range
Definition: Intrinsics.h:23
const BaseT & getBase() const
Returns the current base of the iterator.
Definition: STLExtras.h:175
DerivedT slice(size_t n, size_t m) const
Drop the first n elements, and keep the next m elements.
Definition: STLExtras.h:242
ReferenceT operator[](unsigned index) const
Definition: STLExtras.h:230
static Tuple getTombstoneKey()
Definition: STLExtras.h:406
DerivedT drop_back(size_t n=1) const
Drop the last n elements.
Definition: STLExtras.h:253