Aart Bik fbe611309e [mlir][sparse] refactored codegen environment into its own file
Also, as a proof of concept, all functionality related to reductions
has been refactored into private fields and a clean public API. As a
result, some dead code was found as well. This approach also simplifies
asserting on a proper environment state for each call.

NOTE: making all other fields private and migrating more methods into
      this new class is still TBD in yet another upcoming revision!

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140443
2022-12-20 16:58:59 -08:00

70 lines
2.1 KiB
C++

//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "CodegenEnv.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
// Code generation environment constructor and setup
//===----------------------------------------------------------------------===//
/// Constructs a code generation environment for the generic op `linop`
/// using the sparsification options `opts`. The merger is sized by
/// `numTensors`, `numLoops`, and `numFilterLoops`, and `topSort` is
/// value-initialized. The loop emitter and the sparse output operand start
/// unset (`nullptr`), and there is no reduction value yet (`redVal`).
/// Both reduction markers (`redExp` for a regular reduction and `redCustom`
/// for a custom reduction) start at the `-1u` sentinel. The reduction
/// methods in this file treat that sentinel as "no reduction in progress".
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numTensors, unsigned numLoops,
unsigned numFilterLoops)
: linalgOp(linop), options(opts), topSort(),
merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {}
/// Begins the emission phase by installing the loop emitter `le`. This is
/// allowed only once per environment. If a sparse output operand has been
/// recorded, `insChain` is seeded with that operand's current value and the
/// merger is flagged via `setHasSparseOut(true)`.
void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) {
  assert(!loopEmitter && "must only start emitting once");
  loopEmitter = le;
  // Kernels without a sparse output need no further setup.
  if (!sparseOut)
    return;
  insChain = sparseOut->get();
  merger.setHasSparseOut(true);
}
//===----------------------------------------------------------------------===//
// Code generation environment methods
//===----------------------------------------------------------------------===//
/// Starts a reduction on expression `exp` with initial value `val`.
/// Requires that no reduction is currently active and that `exp` is a real
/// expression index rather than the `-1u` sentinel.
void CodegenEnv::startReduc(unsigned exp, Value val) {
  // Separate asserts with messages so a failure identifies which
  // precondition was violated.
  assert(redExp == -1u && "a reduction is already in progress");
  assert(exp != -1u && "invalid reduction expression");
  redExp = exp;
  // updateReduc() records `val` both in `redVal` and in the expression's
  // `val` field.
  updateReduc(val);
}
/// Replaces the current reduction value with `val`. The active expression's
/// `val` field and `redVal` are kept in sync. Requires an active reduction.
void CodegenEnv::updateReduc(Value val) {
  assert(redExp != -1u);
  exp(redExp).val = val;
  redVal = val;
}
/// Ends the active reduction and returns its final value. Afterwards both
/// `redVal` and the expression's `val` are reset to a null `Value`, and
/// `redExp` is reset to the `-1u` sentinel.
Value CodegenEnv::endReduc() {
  // Check the precondition before reading `redVal`. An unmatched endReduc()
  // is then reported here, with a message, rather than only from inside
  // updateReduc().
  assert(redExp != -1u && "no reduction in progress");
  Value val = redVal;
  updateReduc(Value());
  redExp = -1u;
  return val;
}
/// Marks expression `exp` as the active custom reduction. Requires that no
/// custom reduction is active and that `exp` is not the `-1u` sentinel.
void CodegenEnv::startCustomReduc(unsigned exp) {
  assert(redCustom == -1u && "a custom reduction is already in progress");
  assert(exp != -1u && "invalid custom reduction expression");
  redCustom = exp;
}
/// Returns the identity value of the active custom reduction. The value is
/// read from the expression's `op`, which must be a
/// `sparse_tensor::ReduceOp`.
Value CodegenEnv::getCustomRedId() {
  assert(redCustom != -1u && "no custom reduction in progress");
  // Use `cast` rather than `dyn_cast`: the result is used unconditionally.
  // A type mismatch must then assert, instead of yielding a null op that is
  // then dereferenced.
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}
/// Clears the active custom reduction by resetting `redCustom` to the `-1u`
/// sentinel. Requires a custom reduction to be in progress.
void CodegenEnv::endCustomReduc() {
  assert(redCustom != -1u && "no custom reduction in progress");
  redCustom = -1u;
}