My Project
AffineExprVisitor.h
Go to the documentation of this file.
1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExprVisitor class template and
// SimpleAffineExprFlattener, a visitor that flattens affine expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H
14 #define MLIR_IR_AFFINE_EXPR_VISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
17 
18 namespace mlir {
19 
56 // expressions: AffineConstantExpr, AffineDimExpr, and
57 // AffineSymbolExpr.
66 
67 template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
68  //===--------------------------------------------------------------------===//
69  // Interface code - This is the public interface of the AffineExprVisitor
70  // that you use to visit affine expressions...
71 public:
72  // Function to walk an AffineExpr (in post order).
73  RetTy walkPostOrder(AffineExpr expr) {
74  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
75  "Must instantiate with a derived type of AffineExprVisitor");
76  switch (expr.getKind()) {
77  case AffineExprKind::Add: {
78  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79  walkOperandsPostOrder(binOpExpr);
80  return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
81  }
82  case AffineExprKind::Mul: {
83  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
84  walkOperandsPostOrder(binOpExpr);
85  return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
86  }
87  case AffineExprKind::Mod: {
88  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
89  walkOperandsPostOrder(binOpExpr);
90  return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
91  }
93  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
94  walkOperandsPostOrder(binOpExpr);
95  return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
96  }
98  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
99  walkOperandsPostOrder(binOpExpr);
100  return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
101  }
103  return static_cast<SubClass *>(this)->visitConstantExpr(
104  expr.cast<AffineConstantExpr>());
106  return static_cast<SubClass *>(this)->visitDimExpr(
107  expr.cast<AffineDimExpr>());
109  return static_cast<SubClass *>(this)->visitSymbolExpr(
110  expr.cast<AffineSymbolExpr>());
111  }
112  }
113 
114  // Function to visit an AffineExpr.
115  RetTy visit(AffineExpr expr) {
116  static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
117  "Must instantiate with a derived type of AffineExprVisitor");
118  switch (expr.getKind()) {
119  case AffineExprKind::Add: {
120  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
121  return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
122  }
123  case AffineExprKind::Mul: {
124  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
125  return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
126  }
127  case AffineExprKind::Mod: {
128  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
129  return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
130  }
132  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
133  return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
134  }
136  auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
137  return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
138  }
140  return static_cast<SubClass *>(this)->visitConstantExpr(
141  expr.cast<AffineConstantExpr>());
143  return static_cast<SubClass *>(this)->visitDimExpr(
144  expr.cast<AffineDimExpr>());
146  return static_cast<SubClass *>(this)->visitSymbolExpr(
147  expr.cast<AffineSymbolExpr>());
148  }
149  llvm_unreachable("Unknown AffineExpr");
150  }
151 
152  //===--------------------------------------------------------------------===//
153  // Visitation functions... these functions provide default fallbacks in case
154  // the user does not specify what to do for a particular instruction type.
155  // The default behavior is to generalize the instruction type to its subtype
156  // and try visiting the subtype. All of this should be inlined perfectly,
157  // because there are no virtual functions to get in the way.
158  //
159 
160  // Default visit methods. Note that the default op-specific binary op visit
161  // methods call the general visitAffineBinaryOpExpr visit method.
164  static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
165  }
167  static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
168  }
170  static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
171  }
173  static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
174  }
176  static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
177  }
181 
182 private:
183  // Walk the operands - each operand is itself walked in post order.
184  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
185  walkPostOrder(expr.getLHS());
186  walkPostOrder(expr.getRHS());
187  }
188 };
189 
190 // This class is used to flatten a pure affine expression (AffineExpr,
191 // which is in a tree form) into a sum of products (w.r.t constants) when
192 // possible, and in that process simplifying the expression. For a modulo,
193 // floordiv, or a ceildiv expression, an additional identifier, called a local
194 // identifier, is introduced to rewrite the expression as a sum of product
195 // affine expression. Each local identifier is always and by construction a
196 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
197 // other local identifiers, in a non-mutually recursive way. Hence, every local
198 // identifier can ultimately always be recovered as an affine function of
199 // dimensional and symbolic identifiers (involving floordiv's); note however
200 // that by AffineExpr construction, some floordiv combinations are converted to
201 // mod's. The result of the flattening is a flattened expression and a set of
202 // constraints involving just the local variables.
203 //
204 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
205 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
206 //
207 // The simplification performed includes the accumulation of contributions for
208 // each dimensional and symbolic identifier together, the simplification of
209 // floordiv/ceildiv/mod expressions and other simplifications that in turn
210 // happen as a result. A simplification that this flattening naturally performs
211 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
212 // folding a modulo expression to a zero, if possible. Three examples are below:
213 //
214 // (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
215 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0
216 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
217 //
218 // The way the flattening works for the second example is as follows: d0 % 4 is
219 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
220 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
221 // zero. Note that an affine expression may not always be expressible purely as
222 // a sum of products involving just the original dimensional and symbolic
223 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
224 // may not be eliminated after simplification; in such cases, the final
225 // expression can be reconstructed by replacing the local identifiers with their
226 // corresponding explicit form stored in 'localExprs' (note that each of the
227 // explicit forms itself would have been simplified).
228 //
229 // The expression walk method here performs a linear time post order walk that
230 // performs the above simplifications through visit methods, with partial
231 // results being stored in 'operandExprStack'. When a parent expr is visited,
232 // the flattened expressions corresponding to its two operands would already be
233 // on the stack - the parent expression looks at the two flattened expressions
234 // and combines the two. It pops off the operand expressions and pushes the
235 // combined result (although this is done in-place on its LHS operand expr).
236 // When the walk is completed, the flattened form of the top-level expression
237 // would be left on the stack.
238 //
239 // A flattener can be repeatedly used for multiple affine expressions that bind
240 // to the same operands, for example, for all result expressions of an
241 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
242 // is more efficient than creating a new flattener for each expression since
243 // common identical div and mod expressions appearing across different
244 // expressions are mapped to the same local identifier (same column position in
245 // 'localVarCst').
247  : public AffineExprVisitor<SimpleAffineExprFlattener> {
248 public:
249  // Flattend expression layout: [dims, symbols, locals, constant]
250  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
251  // In future, consider adding a prepass to determine how big the SmallVector's
252  // will be, and linearize this to std::vector<int64_t> to prevent
253  // SmallVector moves on re-allocation.
254  std::vector<SmallVector<int64_t, 8>> operandExprStack;
255 
256  unsigned numDims;
257  unsigned numSymbols;
258 
259  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
260  unsigned numLocals;
261 
262  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
263  // which new identifiers were introduced; if the latter do not get canceled
264  // out, these expressions can be readily used to reconstruct the AffineExpr
265  // (tree) form. Note that these expressions themselves would have been
266  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
267  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
268  // ceildiv 2 would be the local expression stored for q.
270 
271  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
272 
273  virtual ~SimpleAffineExprFlattener() = default;
274 
275  // Visitor method overrides.
276  void visitMulExpr(AffineBinaryOpExpr expr);
277  void visitAddExpr(AffineBinaryOpExpr expr);
278  void visitDimExpr(AffineDimExpr expr);
283 
284  //
285  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
286  //
287  // A mod expression "expr mod c" is thus flattened by introducing a new local
288  // variable q (= expr floordiv c), such that expr mod c is replaced with
289  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
290  void visitModExpr(AffineBinaryOpExpr expr);
291 
292 protected:
293  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
294  // The local identifier added is always a floordiv of a pure add/mul affine
295  // function of other identifiers, coefficients of which are specified in
296  // dividend and with respect to a positive constant divisor. localExpr is the
297  // simplified tree expression (AffineExpr) corresponding to the quantifier.
298  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
299  AffineExpr localExpr);
300 
301 private:
302  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
303  // A floordiv is thus flattened by introducing a new local variable q, and
304  // replacing that expression with 'q' while adding the constraints
305  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
306  // FlatAffineConstraints::addLocalFloorDiv).
307  //
308  // A ceildiv is similarly flattened:
309  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
310  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
311 
312  int findLocalId(AffineExpr localExpr);
313 
314  inline unsigned getNumCols() const {
315  return numDims + numSymbols + numLocals + 1;
316  }
317  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
318  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
319  inline unsigned getSymbolStartIndex() const { return numDims; }
320  inline unsigned getDimStartIndex() const { return 0; }
321 };
322 
323 } // end namespace mlir
324 
325 #endif // MLIR_IR_AFFINE_EXPR_VISITOR_H
Definition: AffineExpr.h:168
Definition: InferTypeOpInterface.cpp:20
Definition: AffineExprVisitor.h:246
RetTy walkPostOrder(AffineExpr expr)
Definition: AffineExprVisitor.h:73
unsigned numDims
Definition: AffineExprVisitor.h:256
An integer constant appearing in affine expression.
Definition: AffineExpr.h:193
void visitMulExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:166
AffineExpr getRHS() const
Definition: AffineExpr.cpp:232
Definition: AffineExprVisitor.h:67
Definition: LLVM.h:37
void visitSymbolExpr(AffineSymbolExpr expr)
Definition: AffineExprVisitor.h:180
void visitDimExpr(AffineDimExpr expr)
Definition: AffineExprVisitor.h:179
void visitFloorDivExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:172
AffineExpr getLHS() const
Definition: AffineExpr.cpp:229
void visitModExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:169
void visitConstantExpr(AffineConstantExpr expr)
Definition: AffineExprVisitor.h:178
Definition: AffineExpr.h:66
RHS of mul is always a constant or a symbolic expression.
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:162
U cast() const
Definition: AffineExpr.h:247
RHS of floordiv is always a constant or a symbolic expression.
RHS of ceildiv is always a constant or a symbolic expression.
unsigned numLocals
Definition: AffineExprVisitor.h:260
RetTy visit(AffineExpr expr)
Definition: AffineExprVisitor.h:115
Definition: LLVM.h:35
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:23
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:177
void visitAddExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:163
unsigned numSymbols
Definition: AffineExprVisitor.h:257
std::vector< SmallVector< int64_t, 8 > > operandExprStack
Definition: AffineExprVisitor.h:254
void visitCeilDivExpr(AffineBinaryOpExpr expr)
Definition: AffineExprVisitor.h:175
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:185
SmallVector< AffineExpr, 4 > localExprs
Definition: AffineExprVisitor.h:269