My Project
CommonFolders.h
Go to the documentation of this file.
1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/StandardTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 namespace mlir {
26 template <class AttrElementT,
27  class ElementValueT = typename AttrElementT::ValueType,
28  class CalculationT =
29  function_ref<ElementValueT(ElementValueT, ElementValueT)>>
31  const CalculationT &calculate) {
32  assert(operands.size() == 2 && "binary op takes two operands");
33  if (!operands[0] || !operands[1])
34  return {};
35  if (operands[0].getType() != operands[1].getType())
36  return {};
37 
38  if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
39  auto lhs = operands[0].cast<AttrElementT>();
40  auto rhs = operands[1].cast<AttrElementT>();
41 
42  return AttrElementT::get(lhs.getType(),
43  calculate(lhs.getValue(), rhs.getValue()));
44  } else if (operands[0].isa<SplatElementsAttr>() &&
45  operands[1].isa<SplatElementsAttr>()) {
46  // Both operands are splats so we can avoid expanding the values out and
47  // just fold based on the splat value.
48  auto lhs = operands[0].cast<SplatElementsAttr>();
49  auto rhs = operands[1].cast<SplatElementsAttr>();
50 
51  auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
52  rhs.getSplatValue<ElementValueT>());
53  return DenseElementsAttr::get(lhs.getType(), elementResult);
54  } else if (operands[0].isa<ElementsAttr>() &&
55  operands[1].isa<ElementsAttr>()) {
56  // Operands are ElementsAttr-derived; perform an element-wise fold by
57  // expanding the values.
58  auto lhs = operands[0].cast<ElementsAttr>();
59  auto rhs = operands[1].cast<ElementsAttr>();
60 
61  auto lhsIt = lhs.getValues<ElementValueT>().begin();
62  auto rhsIt = rhs.getValues<ElementValueT>().begin();
63  SmallVector<ElementValueT, 4> elementResults;
64  elementResults.reserve(lhs.getNumElements());
65  for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
66  elementResults.push_back(calculate(*lhsIt, *rhsIt));
67  return DenseElementsAttr::get(lhs.getType(), elementResults);
68  }
69  return {};
70 }
71 } // namespace mlir
72 
73 #endif // MLIR_DIALECT_COMMONFOLDERS_H
Definition: InferTypeOpInterface.cpp:20
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Definition: Attributes.cpp:572
iterator_range< T > getValues() const
Definition: LLVM.h:37
Definition: Attributes.h:53
Definition: Attributes.h:1193
Definition: LLVM.h:35
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
Definition: CommonFolders.h:30
Attribute getSplatValue() const
Definition: Attributes.h:820
Definition: Attributes.h:559