//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}
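
// Note on the bit encoding used below: each (tensor t, loop index i) pair is
// mapped to a single bit b = numTensors * i + t in a LatPoint's `bits`, and
// tensor(b) / index(b) recover the pair (see addLat() and dumpBits()). As an
// illustrative sketch (assuming two tensors and two loops, values chosen for
// the example only): tensor 1 under loop 0 occupies bit 2 * 0 + 1 = 1, while
// tensor 0 under loop 1 occupies bit 2 * 1 + 0 = 2.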
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}
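
// Illustrative sketch (hypothetical kernel, not taken from the code above):
// for x(i) = a(i) + b(i) with both a and b sparse, takeDisj(kAddF, ...)
// combines the operand lattices into three lattice points,
//   { a, b } -> a(i) + b(i)   // overlap of both iteration spaces
//   { a }    -> a(i)          // a only
//   { b }    -> b(i)          // b only
// which later drive one co-iteration loop plus two "tail" loops.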
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    return tensorExps[e].tensor == t;
  case kInvariant:
  case kIndex:
    return false;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kBinaryBranch:
  case kUnary:
    return false;
  // Binary operations.
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
  case kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}
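
// Illustrative sketch (hypothetical expressions, not part of the class):
// isSingleCondition(t, e) asks whether e is zero wherever tensor t is zero,
// so a single loop condition on t suffices. For instance, x(i) * 2.0 is a
// single condition on x (multiplying by an invariant preserves zeros),
// whereas x(i) + y(i) is one only if both operands are themselves single
// conditions on the same tensor t.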
#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  // Leaf.
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  // Unary operations.
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
  case kBinaryBranch:
  case kUnary:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
  case kBinary:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG
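
// Illustrative sketch of the debug output format (values are hypothetical):
// a lattice point over tensor 0 (sparse) and tensor 1 (dense) in loop 0
// prints as
//   lat( i_0_0_S i_1_0_D : i_0_0_S : (tensor_0 * tensor_1) )
// where the first group lists all condition bits, the second the simplified
// bits, and the last the tensor expression (see dumpLat() and dumpBits()).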
//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    // op y| !y       |  y         |
    // ----+----------+------------+
    //     | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
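  // Illustrative sketch for the custom unary case above: with a non-empty
  // absent region, takeDisj combines the present branch (conditioned on the
  // operand's iteration space) with the absent value, which, as an invariant,
  // is attached to the synthetic tensor so that the full iteration space is
  // still visited even where the operand has no stored entries.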
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    // x op y| !y      |  y           |
    // ------+---------+--------------+
    //   !x  | empty   | right(y)     |
    //    x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp,
                           vtp.getNumScalableDims());
  return dtp;
}
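
// Illustrative sketch for inferType() above (types are hypothetical): a cast
// node with scalar destination type f32 applied to a vectorized source of
// type vector<16xf64> yields vector<16xf32>, so generated casts stay
// consistent under vectorization.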
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}
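
// Illustrative sketch of the recursion above (hypothetical generic op body):
// for a region containing %0 = arith.mulf %a, %b followed by
// linalg.yield %0, buildTensorExp() first maps the block arguments %a and %b
// to kTensor leaves and then combines them into a kMulF expression; any
// operand it cannot classify makes the whole expression unbuildable (None).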
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of the semi-ring region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
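
// Illustrative sketch of the region inlining above: insertYieldOp() clones a
// semi-ring region such as (the region shown is hypothetical)
//   ^bb0(%x: f64):
//     %r = arith.mulf %x, %x : f64
//     sparse_tensor.yield %r : f64
// in place of the current expression, substituting the block argument %x
// with the computed operand value and returning the yielded %r. The
// placeholder constant merely anchors the insertion point and is erased
// afterwards.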
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir