My Project
TypeSwitch.h
Go to the documentation of this file.
1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the TypeSwitch template, which mimics a switch()
10 // statement whose cases are type names.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_TYPESWITCH_H
15 #define MLIR_SUPPORT_TYPESWITCH_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/Support/STLExtras.h"
19 #include "llvm/ADT/Optional.h"
20 
21 namespace mlir {
22 namespace detail {
23 
24 template <typename DerivedT, typename T> class TypeSwitchBase {
25 public:
26  TypeSwitchBase(const T &value) : value(value) {}
28  ~TypeSwitchBase() = default;
29 
31  TypeSwitchBase(const TypeSwitchBase &) = delete;
32  void operator=(const TypeSwitchBase &) = delete;
33  void operator=(TypeSwitchBase &&other) = delete;
34 
36  template <typename CaseT, typename CaseT2, typename... CaseTs,
37  typename CallableT>
38  DerivedT &Case(CallableT &&caseFn) {
39  DerivedT &derived = static_cast<DerivedT &>(*this);
40  return derived.template Case<CaseT>(caseFn)
41  .template Case<CaseT2, CaseTs...>(caseFn);
42  }
43 
48  template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
50  using CaseT = std::remove_cv_t<std::remove_pointer_t<
51  std::remove_reference_t<typename Traits::template arg_t<0>>>>;
52 
53  DerivedT &derived = static_cast<DerivedT &>(*this);
54  return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
55  }
56 
57 protected:
60  template <typename ValueT, typename CastT>
61  using has_dyn_cast_t =
62  decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
63 
66  template <typename CastT, typename ValueT>
67  static auto castValue(
68  ValueT value,
69  typename std::enable_if_t<
71  return value.template dyn_cast<CastT>();
72  }
73 
76  template <typename CastT, typename ValueT>
77  static auto castValue(
78  ValueT value,
79  typename std::enable_if_t<
81  return dyn_cast<CastT>(value);
82  }
83 
85  const T value;
86 };
87 } // end namespace detail
88 
100 template <typename T, typename ResultT = void>
101 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
102 public:
104  using BaseT::BaseT;
105  using BaseT::Case;
106  TypeSwitch(TypeSwitch &&other) = default;
107 
109  template <typename CaseT, typename CallableT>
110  TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
111  if (result)
112  return *this;
113 
114  // Check to see if CaseT applies to 'value'.
115  if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
116  result = caseFn(caseValue);
117  return *this;
118  }
119 
121  template <typename CallableT>
122  LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
123  if (result)
124  return std::move(*result);
125  return defaultFn(this->value);
126  }
127 
128  LLVM_NODISCARD
129  operator ResultT() {
130  assert(result && "Fell off the end of a type-switch");
131  return std::move(*result);
132  }
133 
134 private:
137  Optional<ResultT> result;
138 };
139 
141 template <typename T>
142 class TypeSwitch<T, void>
143  : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
144 public:
146  using BaseT::BaseT;
147  using BaseT::Case;
148  TypeSwitch(TypeSwitch &&other) = default;
149 
151  template <typename CaseT, typename CallableT>
152  TypeSwitch<T, void> &Case(CallableT &&caseFn) {
153  if (foundMatch)
154  return *this;
155 
156  // Check to see if any of the types apply to 'value'.
157  if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
158  caseFn(caseValue);
159  foundMatch = true;
160  }
161  return *this;
162  }
163 
165  template <typename CallableT> void Default(CallableT &&defaultFn) {
166  if (!foundMatch)
167  defaultFn(this->value);
168  }
169 
170 private:
172  bool foundMatch = false;
173 };
174 } // end namespace mlir
175 
176 #endif // MLIR_SUPPORT_TYPESWITCH_H
Definition: InferTypeOpInterface.cpp:20
TypeSwitchBase
Definition: TypeSwitch.h:24
TypeSwitchBase(const T &value)
Definition: TypeSwitch.h:26
static auto castValue(ValueT value, typename std::enable_if_t< !is_detected< has_dyn_cast_t, ValueT, CastT >::value > *=nullptr)
Definition: TypeSwitch.h:77
LLVM_NODISCARD ResultT Default(CallableT &&defaultFn)
As a default, invoke the given callable with the root value.
Definition: TypeSwitch.h:122
Definition: STLExtras.h:350
Definition: LLVM.h:40
TypeSwitchBase(TypeSwitchBase &&other)
Definition: TypeSwitch.h:27
decltype(std::declval< ValueT &>().template dyn_cast< CastT >()) has_dyn_cast_t
Definition: TypeSwitch.h:62
static auto castValue(ValueT value, typename std::enable_if_t< is_detected< has_dyn_cast_t, ValueT, CastT >::value > *=nullptr)
Definition: TypeSwitch.h:67
TypeSwitch< T, void > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: TypeSwitch.h:152
const T value
The root value we are switching on.
Definition: TypeSwitch.h:85
typename detail::detector< void, Op, Args... >::value_t is_detected
Definition: STLExtras.h:125
void operator=(const TypeSwitchBase &)=delete
Definition: TypeSwitch.h:32
TypeSwitch
Definition: TypeSwitch.h:101
DerivedT & Case(CallableT &&caseFn)
Invoke a case on the derived class with multiple case types.
Definition: TypeSwitch.h:38
void Default(CallableT &&defaultFn)
As a default, invoke the given callable with the root value.
Definition: TypeSwitch.h:165
TypeSwitch< T, void >
Specialization of TypeSwitch for void-returning callables.
Definition: TypeSwitch.h:142
TypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: TypeSwitch.h:110
DerivedT & Case(CallableT &&caseFn)
Definition: TypeSwitch.h:48