My Project
Helpers.h
Go to the documentation of this file.
1 //===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides helper classes and syntactic sugar for declarative builders.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_EDSC_HELPERS_H_
14 #define MLIR_EDSC_HELPERS_H_
15 
16 #include "mlir/EDSC/Builders.h"
17 #include "mlir/EDSC/Intrinsics.h"
18 
19 namespace mlir {
20 namespace edsc {
21 
22 // A TemplatedIndexedValue provides index notation (e.g. A(i, j)) on top of its
23 // Load and Store template parameters: reads emit Load, assignments emit Store.
24 template <typename Load, typename Store> class TemplatedIndexedValue;
25 
26 // By default, edsc::IndexedValue provides index notation around affine loads
27 // and stores. edsc::StdIndexedValue is the counterpart that uses the standard
28 // load/store operations instead.
29 using IndexedValue =
31 using StdIndexedValue =
33 
34 // Base class for MemRefView and VectorView.
35 class View {
36 public:
37  unsigned rank() const { return lbs.size(); }
38  ValueHandle lb(unsigned idx) { return lbs[idx]; }
39  ValueHandle ub(unsigned idx) { return ubs[idx]; }
40  int64_t step(unsigned idx) { return steps[idx]; }
41  std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
42  return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
43  }
44  void swapRanges(unsigned i, unsigned j) {
45  if (i == j)
46  return;
47  lbs[i].swap(lbs[j]);
48  ubs[i].swap(ubs[j]);
49  std::swap(steps[i], steps[j]);
50  }
51 
55 
56 protected:
60 };
61 
66 // TODO(ntv): Support MemRefs with layoutMaps.
67 class MemRefView : public View { // View (lb/ub/step ranges) over a memref Value; layout maps unsupported per the TODO above.
68 public:
69  explicit MemRefView(Value v); // Defined out of line; presumably populates the View ranges from v's type -- TODO confirm.
70  MemRefView(const MemRefView &) = default;
71  MemRefView &operator=(const MemRefView &) = default; // NOTE(review): user-declared copy ops suppress the implicit move ops, so moves copy View's SmallVectors.
72 
73  unsigned fastestVarying() const { return rank() - 1; } // Index of the last dimension; rank() is unsigned, so this wraps to UINT_MAX when rank() == 0.
74 
75 private:
76  friend IndexedValue; // Only the affine IndexedValue alias may read `base`; StdIndexedValue is not a friend.
77  ValueHandle base; // Handle to the underlying value this view was built from (assumed to be v -- TODO confirm in the .cpp).
78 };
79 
83 class VectorView : public View { // View (lb/ub/step ranges) over a vector Value; the vector counterpart of MemRefView.
84 public:
85  explicit VectorView(Value v); // Defined out of line; presumably populates the View ranges from v's type -- TODO confirm.
86  VectorView(const VectorView &) = default;
87  VectorView &operator=(const VectorView &) = default; // NOTE(review): user-declared copy ops suppress the implicit move ops, so moves copy View's SmallVectors.
88 
89 private:
90  friend IndexedValue; // Only the affine IndexedValue alias may read `base`; StdIndexedValue is not a friend.
91  ValueHandle base; // Handle to the underlying value this view was built from (assumed to be v -- TODO confirm in the .cpp).
92 };
93 
111 template <typename Load, typename Store> class TemplatedIndexedValue {
112 public:
113  explicit TemplatedIndexedValue(Type t) : base(t) {}
116  explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
117 
118  TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
119 
120  TemplatedIndexedValue operator()() { return *this; }
123  TemplatedIndexedValue res(base);
124  res.indices.push_back(index);
125  return res;
126  }
127  template <typename... Args>
128  TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
129  return TemplatedIndexedValue(base, index).append(indices...);
130  }
132  return TemplatedIndexedValue(base, indices);
133  }
135  return TemplatedIndexedValue(
136  base, ArrayRef<ValueHandle>(indices.begin(), indices.end()));
137  }
138 
140  // NOLINTNEXTLINE: unconventional-assign-operator
142  ValueHandle rrhs(rhs);
143  return Store(rrhs, getBase(), {indices.begin(), indices.end()});
144  }
145  // NOLINTNEXTLINE: unconventional-assign-operator
147  return Store(rhs, getBase(), {indices.begin(), indices.end()});
148  }
149 
151  operator ValueHandle() const {
152  return Load(getBase(), {indices.begin(), indices.end()});
153  }
154 
156  Value operator*(void) const {
157  return Load(getBase(), {indices.begin(), indices.end()}).getValue();
158  }
159 
160  ValueHandle getBase() const { return base; }
161 
167  OperationHandle operator+=(ValueHandle e);
168  OperationHandle operator-=(ValueHandle e);
169  OperationHandle operator*=(ValueHandle e);
170  OperationHandle operator/=(ValueHandle e);
172  return *this + static_cast<ValueHandle>(e);
173  }
175  return *this - static_cast<ValueHandle>(e);
176  }
178  return *this * static_cast<ValueHandle>(e);
179  }
181  return *this / static_cast<ValueHandle>(e);
182  }
184  return this->operator+=(static_cast<ValueHandle>(e));
185  }
187  return this->operator-=(static_cast<ValueHandle>(e));
188  }
190  return this->operator*=(static_cast<ValueHandle>(e));
191  }
193  return this->operator/=(static_cast<ValueHandle>(e));
194  }
195 
196 private:
198  : base(base), indices(indices.begin(), indices.end()) {}
199 
200  TemplatedIndexedValue &append() { return *this; }
201 
202  template <typename T, typename... Args>
203  TemplatedIndexedValue &append(T index, Args... indices) {
204  this->indices.push_back(static_cast<ValueHandle>(index));
205  append(indices...);
206  return *this;
207  }
208  ValueHandle base;
210 };
211 
213 template <typename Load, typename Store>
215  using op::operator+;
216  return static_cast<ValueHandle>(*this) + e;
217 }
218 template <typename Load, typename Store>
220  using op::operator-;
221  return static_cast<ValueHandle>(*this) - e;
222 }
223 template <typename Load, typename Store>
225  using op::operator*;
226  return static_cast<ValueHandle>(*this) * e;
227 }
228 template <typename Load, typename Store>
230  using op::operator/;
231  return static_cast<ValueHandle>(*this) / e;
232 }
233 
234 template <typename Load, typename Store>
236  using op::operator+;
237  return Store(*this + e, getBase(), {indices.begin(), indices.end()});
238 }
239 template <typename Load, typename Store>
241  using op::operator-;
242  return Store(*this - e, getBase(), {indices.begin(), indices.end()});
243 }
244 template <typename Load, typename Store>
246  using op::operator*;
247  return Store(*this * e, getBase(), {indices.begin(), indices.end()});
248 }
249 template <typename Load, typename Store>
251  using op::operator/;
252  return Store(*this / e, getBase(), {indices.begin(), indices.end()});
253 }
254 
255 } // namespace edsc
256 } // namespace mlir
257 
258 #endif // MLIR_EDSC_HELPERS_H_
unsigned fastestVarying() const
Definition: Helpers.h:73
Definition: InferTypeOpInterface.cpp:20
TemplatedIndexedValue(Value v)
Definition: Helpers.h:114
TemplatedIndexedValue(ValueHandle v)
Definition: Helpers.h:116
TemplatedIndexedValue operator()(ValueHandle index, Args... indices)
Definition: Helpers.h:128
std::tuple< ValueHandle, ValueHandle, int64_t > range(unsigned idx)
Definition: Helpers.h:41
Definition: Builders.h:290
OperationHandle operator+=(ValueHandle e)
Definition: Helpers.h:235
IntInfty operator+(IntInfty lhs, IntInfty rhs)
Definition: SDBM.h:57
OperationHandle operator*=(TemplatedIndexedValue e)
Definition: Helpers.h:189
Definition: Helpers.h:67
TemplatedIndexedValue operator()(ValueHandle index)
Returns a new TemplatedIndexedValue.
Definition: Helpers.h:122
OperationHandle operator/=(ValueHandle e)
Definition: Helpers.h:250
Definition: Helpers.h:83
Definition: LLVM.h:37
OperationHandle operator=(const TemplatedIndexedValue &rhs)
Emits a store.
Definition: Helpers.h:141
TemplatedIndexedValue operator()(ArrayRef< ValueHandle > indices)
Definition: Helpers.h:131
TemplatedIndexedValue(Type t)
Definition: Helpers.h:113
ValueHandle operator*(TemplatedIndexedValue e)
Definition: Helpers.h:177
ValueHandle operator+(TemplatedIndexedValue e)
Definition: Helpers.h:171
Definition: Helpers.h:35
ValueHandle operator/(ValueHandle lhs, ValueHandle rhs)
Definition: Builders.cpp:383
unsigned rank() const
Definition: Helpers.h:37
ValueHandle operator/(ValueHandle e)
Definition: Helpers.h:229
int64_t step(unsigned idx)
Definition: Helpers.h:40
ValueHandle operator-(TemplatedIndexedValue e)
Definition: Helpers.h:174
ValueHandle operator+(ValueHandle e)
Operator overloadings.
Definition: Helpers.h:214
OperationHandle operator-=(ValueHandle e)
Definition: Helpers.h:240
OperationHandle operator=(ValueHandle rhs)
Definition: Helpers.h:146
Definition: Types.h:84
ValueHandle getBase() const
Definition: Helpers.h:160
void swapRanges(unsigned i, unsigned j)
Definition: Helpers.h:44
ValueHandle ub(unsigned idx)
Definition: Helpers.h:39
Definition: Builders.h:385
Definition: Value.h:38
Definition: LLVM.h:35
ArrayRef< ValueHandle > getUbs()
Definition: Helpers.h:53
OperationHandle operator/=(TemplatedIndexedValue e)
Definition: Helpers.h:192
OperationHandle operator+=(TemplatedIndexedValue e)
Definition: Helpers.h:183
TemplatedIndexedValue operator()()
Definition: Helpers.h:120
SmallVector< int64_t, 8 > steps
Definition: Helpers.h:59
AffineExpr operator-(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:207
Value operator*(void) const
Emits a load when converting to a Value.
Definition: Helpers.h:156
TemplatedIndexedValue operator()(ArrayRef< IndexHandle > indices)
Definition: Helpers.h:134
ValueHandle lb(unsigned idx)
Definition: Helpers.h:38
Definition: Helpers.h:24
SmallVector< ValueHandle, 8 > lbs
Definition: Helpers.h:57
SmallVector< ValueHandle, 8 > ubs
Definition: Helpers.h:58
AffineExpr operator*(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:206
TemplatedIndexedValue< intrinsics::affine_load, intrinsics::affine_store > IndexedValue
Definition: Helpers.h:30
OperationHandle operator-=(TemplatedIndexedValue e)
Definition: Helpers.h:186
ValueHandle operator/(TemplatedIndexedValue e)
Definition: Helpers.h:180
ArrayRef< int64_t > getSteps()
Definition: Helpers.h:54
ValueHandle operator-(ValueHandle e)
Definition: Helpers.h:219
OperationHandle operator*=(ValueHandle e)
Definition: Helpers.h:245
ArrayRef< ValueHandle > getLbs()
Definition: Helpers.h:52