My Project
SDBMExpr.h
Go to the documentation of this file.
1 //===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A striped difference-bound matrix (SDBM) expression is a constant expression,
10 // an identifier, a binary expression with a constant right-hand side and either
11 // the + or the stripe operator, a negation, or a difference expression.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H
16 #define MLIR_DIALECT_SDBM_SDBMEXPR_H
17 
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/DenseMapInfo.h"
20 
21 namespace mlir {
22 
class AffineExpr;
class MLIRContext;

/// Discriminator for the concrete kind of an SDBM expression. Every
/// `isClassFor` predicate below tests against these values.
enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };

namespace detail {
struct SDBMExprStorage;
struct SDBMBinaryExprStorage;
struct SDBMDiffExprStorage;
struct SDBMTermExprStorage;
struct SDBMConstantExprStorage;
struct SDBMNegExprStorage;
} // namespace detail

class SDBMConstantExpr;
class SDBMDialect;
class SDBMDimExpr;
class SDBMSymbolExpr;
class SDBMTermExpr;
88 class SDBMExpr {
89 public:
91  SDBMExpr() : impl(nullptr) {}
92  /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
93 
96  SDBMExpr(const SDBMExpr &) = default;
97  SDBMExpr &operator=(const SDBMExpr &) = default;
98 
100  bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
101  bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
102 
105  explicit operator bool() const { return impl != nullptr; }
106  bool operator!() const { return !static_cast<bool>(*this); }
107 
110 
112  void print(raw_ostream &os) const;
113  void dump() const;
114 
116  template <typename U> bool isa() const { return U::isClassFor(*this); }
117  template <typename U> U dyn_cast() const {
118  if (!isa<U>())
119  return {};
120  return U(const_cast<SDBMExpr *>(this)->impl);
121  }
122  template <typename U> U cast() const {
123  assert(isa<U>() && "cast to incorrect subtype");
124  return U(const_cast<SDBMExpr *>(this)->impl);
125  }
126 
128  ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
129 
131  SDBMExprKind getKind() const;
132 
134  MLIRContext *getContext() const;
135 
137  SDBMDialect *getDialect() const;
138 
141  AffineExpr getAsAffineExpr() const;
142 
148  static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
149 
150 protected:
152 };
153 
155 class SDBMConstantExpr : public SDBMExpr {
156 public:
158 
159  using SDBMExpr::SDBMExpr;
160 
163  static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
164 
165  static bool isClassFor(const SDBMExpr &expr) {
166  return expr.getKind() == SDBMExprKind::Constant;
167  }
168 
169  int64_t getValue() const;
170 };
171 
178 class SDBMVaryingExpr : public SDBMExpr {
179 public:
181  using SDBMExpr::SDBMExpr;
182 
183  static bool isClassFor(const SDBMExpr &expr) {
184  return expr.getKind() == SDBMExprKind::DimId ||
185  expr.getKind() == SDBMExprKind::SymbolId ||
186  expr.getKind() == SDBMExprKind::Neg ||
187  expr.getKind() == SDBMExprKind::Stripe ||
188  expr.getKind() == SDBMExprKind::Add ||
189  expr.getKind() == SDBMExprKind::Diff;
190  }
191 };
192 
198 public:
199  using SDBMVaryingExpr::SDBMVaryingExpr;
200 
203  SDBMTermExpr getTerm();
204 
206  int64_t getConstant();
207 
208  static bool isClassFor(const SDBMExpr &expr) {
209  return expr.getKind() == SDBMExprKind::DimId ||
210  expr.getKind() == SDBMExprKind::SymbolId ||
211  expr.getKind() == SDBMExprKind::Stripe ||
212  expr.getKind() == SDBMExprKind::Add;
213  }
214 };
215 
221 class SDBMTermExpr : public SDBMDirectExpr {
222 public:
223  using SDBMDirectExpr::SDBMDirectExpr;
224 
225  static bool isClassFor(const SDBMExpr &expr) {
226  return expr.getKind() == SDBMExprKind::DimId ||
227  expr.getKind() == SDBMExprKind::SymbolId ||
228  expr.getKind() == SDBMExprKind::Stripe;
229  }
230 };
231 
233 class SDBMSumExpr : public SDBMDirectExpr {
234 public:
236  using SDBMDirectExpr::SDBMDirectExpr;
237 
239  static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
240 
241  static bool isClassFor(const SDBMExpr &expr) {
242  SDBMExprKind kind = expr.getKind();
243  return kind == SDBMExprKind::Add;
244  }
245 
246  SDBMTermExpr getLHS() const;
247  SDBMConstantExpr getRHS() const;
248 };
249 
256 public:
258  using SDBMVaryingExpr::SDBMVaryingExpr;
259 
261  static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
262 
263  static bool isClassFor(const SDBMExpr &expr) {
264  return expr.getKind() == SDBMExprKind::Diff;
265  }
266 
267  SDBMDirectExpr getLHS() const;
268  SDBMTermExpr getRHS() const;
269 };
270 
274 class SDBMStripeExpr : public SDBMTermExpr {
275 public:
277  using SDBMTermExpr::SDBMTermExpr;
278 
279  static bool isClassFor(const SDBMExpr &expr) {
280  return expr.getKind() == SDBMExprKind::Stripe;
281  }
282 
283  static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
284 
285  SDBMDirectExpr getLHS() const;
286  SDBMConstantExpr getStripeFactor() const;
287 };
288 
293 class SDBMInputExpr : public SDBMTermExpr {
294 public:
296  using SDBMTermExpr::SDBMTermExpr;
297 
298  static bool isClassFor(const SDBMExpr &expr) {
299  return expr.getKind() == SDBMExprKind::DimId ||
301  }
302 
303  unsigned getPosition() const;
304 };
305 
308 class SDBMDimExpr : public SDBMInputExpr {
309 public:
311  using SDBMInputExpr::SDBMInputExpr;
312 
315  static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
316 
317  static bool isClassFor(const SDBMExpr &expr) {
318  return expr.getKind() == SDBMExprKind::DimId;
319  }
320 };
321 
325 public:
327  using SDBMInputExpr::SDBMInputExpr;
328 
331  static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
332 
333  static bool isClassFor(const SDBMExpr &expr) {
334  return expr.getKind() == SDBMExprKind::SymbolId;
335  }
336 };
337 
340 class SDBMNegExpr : public SDBMVaryingExpr {
341 public:
343  using SDBMVaryingExpr::SDBMVaryingExpr;
344 
346  static SDBMNegExpr get(SDBMDirectExpr var);
347 
348  static bool isClassFor(const SDBMExpr &expr) {
349  return expr.getKind() == SDBMExprKind::Neg;
350  }
351 
352  SDBMDirectExpr getVar() const;
353 };
354 
357 template <typename Derived, typename Result = void> class SDBMVisitor {
358 public:
360  Result visit(SDBMExpr expr) {
361  auto *derived = static_cast<Derived *>(this);
362  switch (expr.getKind()) {
363  case SDBMExprKind::Add:
364  case SDBMExprKind::Diff:
365  case SDBMExprKind::DimId:
367  case SDBMExprKind::Neg:
369  return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
371  return derived->visitConstant(expr.cast<SDBMConstantExpr>());
372  }
373 
374  llvm_unreachable("unsupported SDBM expression kind");
375  }
376 
379  void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
380 
383  void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
384 
385 protected:
394 
399  auto *derived = static_cast<Derived *>(this);
400  if (auto sum = expr.dyn_cast<SDBMSumExpr>())
401  return derived->visitSum(sum);
402  else
403  return derived->visitTerm(expr.cast<SDBMTermExpr>());
404  }
405 
408  Result visitTerm(SDBMTermExpr expr) {
409  auto *derived = static_cast<Derived *>(this);
410  if (expr.getKind() == SDBMExprKind::Stripe)
411  return derived->visitStripe(expr.cast<SDBMStripeExpr>());
412  else
413  return derived->visitInput(expr.cast<SDBMInputExpr>());
414  }
415 
419  Result visitInput(SDBMInputExpr expr) {
420  auto *derived = static_cast<Derived *>(this);
421  if (expr.getKind() == SDBMExprKind::DimId)
422  return derived->visitDim(expr.cast<SDBMDimExpr>());
423  else
424  return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
425  }
426 
431  auto *derived = static_cast<Derived *>(this);
432  if (auto var = expr.dyn_cast<SDBMDirectExpr>())
433  return derived->visitDirect(var);
434  else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
435  return derived->visitNeg(neg);
436  else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
437  return derived->visitDiff(diff);
438 
439  llvm_unreachable("unhandled subtype of varying SDBM expression");
440  }
441 
442  template <bool isPreorder> void walk(SDBMExpr expr) {
443  if (isPreorder)
444  visit(expr);
445  if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
446  walk<isPreorder>(sumExpr.getLHS());
447  walk<isPreorder>(sumExpr.getRHS());
448  } else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
449  walk<isPreorder>(diffExpr.getLHS());
450  walk<isPreorder>(diffExpr.getRHS());
451  } else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
452  walk<isPreorder>(stripeExpr.getLHS());
453  walk<isPreorder>(stripeExpr.getStripeFactor());
454  } else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
455  walk<isPreorder>(negExpr.getVar());
456  }
457  if (!isPreorder)
458  visit(expr);
459  }
460 };
461 
465 namespace ops_assertions {
466 
471 inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
472  return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
473 }
474 inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
475  return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
476 }
477 
481 inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
482  return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
483 }
484 inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
485  return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
486 }
487 
490 SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
491 inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
492  return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
493 }
494 } // namespace ops_assertions
495 
496 } // end namespace mlir
497 
498 namespace llvm {
499 // SDBMExpr hash just like pointers.
500 template <> struct DenseMapInfo<mlir::SDBMExpr> {
502  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
503  return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
504  }
507  return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
508  }
509  static unsigned getHashValue(mlir::SDBMExpr expr) {
510  return expr.hash_value();
511  }
512  static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
513  return lhs == rhs;
514  }
515 };
516 
517 // SDBMDirectExpr hash just like pointers.
518 template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
520  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
521  return mlir::SDBMDirectExpr(
522  static_cast<mlir::SDBMExpr::ImplType *>(pointer));
523  }
526  return mlir::SDBMDirectExpr(
527  static_cast<mlir::SDBMExpr::ImplType *>(pointer));
528  }
529  static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
530  return expr.hash_value();
531  }
533  return lhs == rhs;
534  }
535 };
536 
537 // SDBMTermExpr hash just like pointers.
538 template <> struct DenseMapInfo<mlir::SDBMTermExpr> {
540  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
541  return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
542  }
545  return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
546  }
547  static unsigned getHashValue(mlir::SDBMTermExpr expr) {
548  return expr.hash_value();
549  }
551  return lhs == rhs;
552  }
553 };
554 
555 // SDBMConstantExpr hash just like pointers.
556 template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
558  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
559  return mlir::SDBMConstantExpr(
560  static_cast<mlir::SDBMExpr::ImplType *>(pointer));
561  }
564  return mlir::SDBMConstantExpr(
565  static_cast<mlir::SDBMExpr::ImplType *>(pointer));
566  }
567  static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
568  return expr.hash_value();
569  }
571  return lhs == rhs;
572  }
573 };
574 } // namespace llvm
575 
576 #endif // MLIR_DIALECT_SDBM_SDBMEXPR_H
Definition: InferTypeOpInterface.cpp:20
SDBMDialect * getDialect() const
Returns the SDBM dialect instance.
Definition: SDBMExpr.cpp:148
void visitDim(SDBMDimExpr)
Definition: SDBMExpr.h:390
Result visit(SDBMExpr expr)
Visit the given SDBM expression, dispatching to kind-specific functions.
Definition: SDBMExpr.h:360
void visitSum(SDBMSumExpr)
Default visitors do nothing.
Definition: SDBMExpr.h:387
Definition: SDBMExprDetail.h:27
Definition: PassRegistry.cpp:413
static mlir::SDBMDirectExpr getTombstoneKey()
Definition: SDBMExpr.h:524
static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value)
Definition: SDBMExpr.cpp:633
SDBMExprKind getKind() const
Returns the kind of the SDBM expression.
Definition: SDBMExpr.cpp:142
bool operator==(const SDBMExpr &other) const
SDBM expressions can be compared straight-forwardly.
Definition: SDBMExpr.h:100
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:241
Definition: LLVM.h:45
Definition: SDBMExpr.h:178
Definition: SDBMExpr.h:293
IntInfty operator+(IntInfty lhs, IntInfty rhs)
Definition: SDBM.h:57
void walkPreorder(SDBMExpr expr)
Definition: SDBMExpr.h:379
static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs)
Definition: SDBMExpr.h:512
static unsigned getHashValue(mlir::SDBMConstantExpr expr)
Definition: SDBMExpr.h:567
void walkPostorder(SDBMExpr expr)
Definition: SDBMExpr.h:383
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:317
Result visitInput(SDBMInputExpr expr)
Definition: SDBMExpr.h:419
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:183
Definition: LLVM.h:40
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:208
static mlir::SDBMConstantExpr getTombstoneKey()
Definition: SDBMExpr.h:562
SDBMExprKind
Definition: SDBMExpr.h:26
void visitSymbol(SDBMSymbolExpr)
Definition: SDBMExpr.h:391
SDBMExpr stripe(SDBMExpr expr, int64_t factor)
Definition: SDBMExpr.h:491
void visitNeg(SDBMNegExpr)
Definition: SDBMExpr.h:392
U cast() const
Definition: SDBMExpr.h:122
SDBM constant expression, wraps a 64-bit integer.
Definition: SDBMExpr.h:155
Definition: SDBMExprDetail.h:57
static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs)
Definition: SDBMExpr.h:532
static mlir::SDBMTermExpr getTombstoneKey()
Definition: SDBMExpr.h:543
Definition: SDBMExprDetail.h:36
bool operator!() const
Definition: SDBMExpr.h:106
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:333
SDBMExpr()
Definition: SDBMExpr.h:91
static mlir::SDBMExpr getTombstoneKey()
Definition: SDBMExpr.h:505
Definition: SDBMExpr.h:324
static mlir::SDBMDirectExpr getEmptyKey()
Definition: SDBMExpr.h:519
Definition: SDBMExprDetail.h:110
Definition: AffineExpr.h:66
Result visitVarying(SDBMVaryingExpr expr)
Definition: SDBMExpr.h:430
SDBMExpr(ImplType *expr)
Definition: SDBMExpr.h:92
Definition: SDBMExpr.h:274
static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs)
Definition: SDBMExpr.h:570
Definition: SDBMExpr.h:308
ImplType * impl
Definition: SDBMExpr.h:151
static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs)
Definition: SDBMExpr.h:550
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:348
static unsigned getHashValue(mlir::SDBMExpr expr)
Definition: SDBMExpr.h:509
Result visitTerm(SDBMTermExpr expr)
Definition: SDBMExpr.h:408
void visitStripe(SDBMStripeExpr)
Definition: SDBMExpr.h:389
inline ::llvm::hash_code hash_value(AffineExpr arg)
Make AffineExpr hashable.
Definition: AffineExpr.h:201
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:298
bool operator!=(const SDBMExpr &other) const
Definition: SDBMExpr.h:101
Definition: SDBMExpr.h:357
Definition: SDBMExpr.h:340
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:279
SDBM sum expression. LHS is a term expression and RHS is a constant.
Definition: SDBMExpr.h:233
Definition: SDBMExpr.h:197
Definition: SDBMExprDetail.h:94
static mlir::SDBMConstantExpr getEmptyKey()
Definition: SDBMExpr.h:557
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:225
Definition: MLIRContext.h:34
void print(OpAsmPrinter &p, AffineIfOp op)
Definition: AffineOps.cpp:1671
AffineExpr operator-(int64_t val, AffineExpr expr)
Definition: AffineExpr.h:207
static mlir::SDBMTermExpr getEmptyKey()
Definition: SDBMExpr.h:539
static unsigned getHashValue(mlir::SDBMDirectExpr expr)
Definition: SDBMExpr.h:529
Definition: SDBMDialect.h:18
void visitConstant(SDBMConstantExpr)
Definition: SDBMExpr.h:393
Definition: SDBMExprDetail.h:78
bool isa() const
LLVM-style casts.
Definition: SDBMExpr.h:116
Result visitDirect(SDBMDirectExpr expr)
Definition: SDBMExpr.h:398
void visitDiff(SDBMDiffExpr)
Definition: SDBMExpr.h:388
static mlir::SDBMExpr getEmptyKey()
Definition: SDBMExpr.h:501
U dyn_cast() const
Definition: SDBMExpr.h:117
Definition: SDBMExpr.h:255
static unsigned getHashValue(mlir::SDBMTermExpr expr)
Definition: SDBMExpr.h:547
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:263
void walk(SDBMExpr expr)
Definition: SDBMExpr.h:442
static bool isClassFor(const SDBMExpr &expr)
Definition: SDBMExpr.h:165
Definition: SDBMExpr.h:221
::llvm::hash_code hash_value() const
Support for LLVM hashing.
Definition: SDBMExpr.h:128
Definition: SDBMExpr.h:88