15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H 16 #define MLIR_DIALECT_COMMONFOLDERS_H 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/STLExtras.h" 26 template <
class AttrElementT,
27 class ElementValueT =
typename AttrElementT::ValueType,
29 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
31 const CalculationT &calculate) {
32 assert(operands.size() == 2 &&
"binary op takes two operands");
33 if (!operands[0] || !operands[1])
35 if (operands[0].getType() != operands[1].getType())
38 if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
39 auto lhs = operands[0].cast<AttrElementT>();
40 auto rhs = operands[1].cast<AttrElementT>();
42 return AttrElementT::get(lhs.getType(),
43 calculate(lhs.getValue(), rhs.getValue()));
44 }
else if (operands[0].isa<SplatElementsAttr>() &&
45 operands[1].isa<SplatElementsAttr>()) {
51 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
54 }
else if (operands[0].isa<ElementsAttr>() &&
61 auto lhsIt = lhs.
getValues<ElementValueT>().begin();
62 auto rhsIt = rhs.
getValues<ElementValueT>().begin();
64 elementResults.reserve(lhs.getNumElements());
65 for (
size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
66 elementResults.push_back(calculate(*lhsIt, *rhsIt));
73 #endif // MLIR_DIALECT_COMMONFOLDERS_H Definition: InferTypeOpInterface.cpp:20
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Definition: Attributes.cpp:572
iterator_range< T > getValues() const
Definition: Attributes.h:53
Definition: Attributes.h:1193
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
Definition: CommonFolders.h:30
Attribute getSplatValue() const
Definition: Attributes.h:820
Definition: Attributes.h:559