[MLIR][NVVM] Improve inline_ptx, add readwrite support (#154358)
Key Features 1. Multiple SSA returns – no struct packing/unpacking required. 2. Automatic struct unpacking – values are directly usable. 3. Readable register mapping * {$rwN} → read-write * {$rN} → read-only * {$wN} → write-only 4. Full read-write support (+ modifier). 5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints. 6. Predicate support: PTX `@p` predication support IR Example: ``` %wo0, %wo1 = nvvm.inline_ptx """ .reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0}, {$r1}, p; selp.s32 {$rw1}, {$r0}, {$r1}, p; selp.s32 {$w0}, {$r0}, {$r1}, p; selp.s32 {$w1}, {$r0}, {$r1}, p; """ ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32 ``` After lowering ``` %0 = llvm.inline_asm has_side_effects asm_dialect = att "{ .reg .pred p;\ setp.ge.s32 p, $4, $5; \ selp.s32 $0, $4, $5, p;\ selp.s32 $1, $4, $5, p;\ selp.s32 $2, $4, $5, p;\ selp.s32 $3, $4, $5, p;\ }" "=r,=r,=f,=f,f,f,0,1" %c500_i32, %c400_i32, %cst, %cst_0 : (i32, i32, f32, f32) -> !llvm.struct<(i32, i32, f32, f32)> %1 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %2 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %3 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %4 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> // Unpacked result from nvvm.inline_ptx %5 = arith.addi %1, %2 : i32 // read only %6 = arith.addf %cst, %cst_0 : f32 // write only %7 = arith.addf %3, %4 : f32 ```
This commit is contained in:
parent
1b0b59ae43
commit
5c36fb3303
@ -26,11 +26,11 @@ namespace NVVM {
|
|||||||
enum class PTXRegisterMod {
|
enum class PTXRegisterMod {
|
||||||
/// Read register with no modifier
|
/// Read register with no modifier
|
||||||
Read = 0,
|
Read = 0,
|
||||||
/// Read register with '+' modifier
|
/// Write register with '=' modifier
|
||||||
Write = 2,
|
Write = 2,
|
||||||
/// Read register with '=' modifier.
|
/// ReadWrite register with '+' modifier.
|
||||||
/// Note that, this is not natively supported by LLVM, but it is possible to
|
/// Note that this is not natively supported by LLVM; the interface does the
|
||||||
/// set read and write for the same operand.
|
/// mapping
|
||||||
ReadWrite = 1,
|
ReadWrite = 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -67,13 +67,19 @@ class PtxBuilder {
|
|||||||
SmallVector<Value> ptxOperands;
|
SmallVector<Value> ptxOperands;
|
||||||
// Register constraints (read, write, readwrite) and register data types
|
// Register constraints (read, write, readwrite) and register data types
|
||||||
std::string registerConstraints;
|
std::string registerConstraints;
|
||||||
|
// Modifiers
|
||||||
|
SmallVector<PTXRegisterMod> registerModifiers;
|
||||||
|
// Has return value as write-only or read-write
|
||||||
bool hasResult = false;
|
bool hasResult = false;
|
||||||
|
// Indicates if the Op will handle the register mapping manually.
|
||||||
|
bool needsManualRegisterMapping = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// Single constructor that only initializes members.
|
/// Single constructor that only initializes members.
|
||||||
PtxBuilder(Operation *op, PatternRewriter &rewriter)
|
PtxBuilder(Operation *op, PatternRewriter &rewriter,
|
||||||
: interfaceOp(op), rewriter(rewriter) {}
|
bool needsManualRegisterMapping = false)
|
||||||
|
: interfaceOp(op), rewriter(rewriter),
|
||||||
|
needsManualRegisterMapping(needsManualRegisterMapping) {}
|
||||||
|
|
||||||
/// Add an operand with the read/write input type.
|
/// Add an operand with the read/write input type.
|
||||||
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
|
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
|
||||||
@ -87,6 +93,16 @@ public:
|
|||||||
void buildAndReplaceOp();
|
void buildAndReplaceOp();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the
|
||||||
|
/// PTX code.
|
||||||
|
void countPlaceholderNumbers(StringRef ptxCode,
|
||||||
|
llvm::SmallDenseSet<unsigned> &seenRW,
|
||||||
|
llvm::SmallDenseSet<unsigned> &seenW,
|
||||||
|
llvm::SmallDenseSet<unsigned> &seenR,
|
||||||
|
llvm::SmallVectorImpl<unsigned> &rwNums,
|
||||||
|
llvm::SmallVectorImpl<unsigned> &wNums,
|
||||||
|
llvm::SmallVectorImpl<unsigned> &rNums);
|
||||||
|
|
||||||
} // namespace NVVM
|
} // namespace NVVM
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
|
|||||||
following this order:
|
following this order:
|
||||||
1) Adds results
|
1) Adds results
|
||||||
2) Adds operands
|
2) Adds operands
|
||||||
3) Adds attributes
|
3) Adds attributes
|
||||||
|
Returns true if the op performs the register mapping itself.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"void",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"getAsmValues",
|
/*methodName=*/"getAsmValues",
|
||||||
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
|
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
|
||||||
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
|
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
|
||||||
|
),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImpl=*/ [{
|
/*defaultImpl=*/ [{
|
||||||
mlir::Operation* op = $_op;
|
mlir::Operation* op = $_op;
|
||||||
|
|
||||||
// Step 1. Add results
|
// Step 1. Add results
|
||||||
for (auto val : op->getResults())
|
for (auto val : op->getResults())
|
||||||
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
|
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
|
||||||
|
|
||||||
// Step 2. Add operands
|
// Step 2. Add operands
|
||||||
for (auto val : op->getOperands())
|
for (auto val : op->getOperands())
|
||||||
@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
|
|||||||
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
|
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false; // No manual mapping needed
|
||||||
}]
|
}]
|
||||||
>
|
>
|
||||||
];
|
];
|
||||||
|
@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
|
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
|
||||||
|
Variadic<AnyType>:$readWriteArgs,
|
||||||
StrAttr:$ptxCode,
|
StrAttr:$ptxCode,
|
||||||
PtxPredicate:$predicate);
|
PtxPredicate:$predicate);
|
||||||
|
|
||||||
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
|
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$ptxCode `(` $readOnlyArgs `)`
|
$ptxCode
|
||||||
(`,` `predicate` `=` $predicate^)? attr-dict
|
( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
|
||||||
`:` type(operands)
|
( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
|
||||||
(`->` type($writeOnlyArgs)^)?
|
(`,` `predicate` `=` $predicate^)?
|
||||||
|
attr-dict
|
||||||
|
( `->` type($writeOnlyArgs)^ )?
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let extraClassDefinition = [{
|
let extraClassDefinition = [{
|
||||||
@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
|
|||||||
return std::string(ptxInstStr.data());
|
return std::string(ptxInstStr.data());
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -3057,8 +3064,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
void getAsmValues(RewriterBase &rewriter,
|
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
|
||||||
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,9 +57,9 @@ struct PtxLowering
|
|||||||
|
|
||||||
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
|
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
|
||||||
LDBG() << op.getPtx();
|
LDBG() << op.getPtx();
|
||||||
PtxBuilder generator(op, rewriter);
|
|
||||||
|
|
||||||
op.getAsmValues(rewriter, asmValues);
|
bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
|
||||||
|
PtxBuilder generator(op, rewriter, needsManualMapping);
|
||||||
for (auto &[asmValue, modifier] : asmValues) {
|
for (auto &[asmValue, modifier] : asmValues) {
|
||||||
LDBG() << asmValue << "\t Modifier : " << modifier;
|
LDBG() << asmValue << "\t Modifier : " << modifier;
|
||||||
generator.insertValue(asmValue, modifier);
|
generator.insertValue(asmValue, modifier);
|
||||||
|
@ -13,7 +13,10 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
|
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/Support/DebugLog.h"
|
#include "llvm/Support/DebugLog.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/Regex.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "ptx-builder"
|
#define DEBUG_TYPE "ptx-builder"
|
||||||
|
|
||||||
@ -59,19 +62,37 @@ static char getRegisterType(Value v) {
|
|||||||
return getRegisterType(v.getType());
|
return getRegisterType(v.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract every element of a struct value.
|
||||||
|
static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
|
||||||
|
Location loc, Value structVal) {
|
||||||
|
auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
|
||||||
|
assert(structTy && "expected LLVM struct");
|
||||||
|
|
||||||
|
SmallVector<Value> elems;
|
||||||
|
for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
|
||||||
|
elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i));
|
||||||
|
|
||||||
|
return elems;
|
||||||
|
}
|
||||||
|
|
||||||
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
|
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
|
||||||
LDBG() << v << "\t Modifier : " << &itype;
|
LDBG() << v << "\t Modifier : " << itype << "\n";
|
||||||
|
registerModifiers.push_back(itype);
|
||||||
|
|
||||||
auto getModifier = [&]() -> const char * {
|
auto getModifier = [&]() -> const char * {
|
||||||
if (itype == PTXRegisterMod::ReadWrite) {
|
switch (itype) {
|
||||||
assert(false && "Read-Write modifier is not supported. Try setting the "
|
case PTXRegisterMod::Read:
|
||||||
"same value as Write and Read separately.");
|
return "";
|
||||||
|
case PTXRegisterMod::Write:
|
||||||
|
return "=";
|
||||||
|
case PTXRegisterMod::ReadWrite:
|
||||||
|
// The Read-Write modifier is not actually supported.
|
||||||
|
// Interface will change it to "=" later and add integer mapping
|
||||||
return "+";
|
return "+";
|
||||||
}
|
}
|
||||||
if (itype == PTXRegisterMod::Write) {
|
llvm_unreachable("Unknown PTX register modifier");
|
||||||
return "=";
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto addValue = [&](Value v) {
|
auto addValue = [&](Value v) {
|
||||||
if (itype == PTXRegisterMod::Read) {
|
if (itype == PTXRegisterMod::Read) {
|
||||||
ptxOperands.push_back(v);
|
ptxOperands.push_back(v);
|
||||||
@ -108,38 +129,247 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Check if the operation needs to pack and unpack results.
|
/// Check if the operation needs to pack and unpack results.
|
||||||
static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
|
static bool
|
||||||
return interfaceOp->getNumResults() > 1;
|
needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
|
||||||
|
bool needsManualRegisterMapping,
|
||||||
|
SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
|
||||||
|
if (needsManualRegisterMapping)
|
||||||
|
return false;
|
||||||
|
const unsigned writeOnlyVals = interfaceOp->getNumResults();
|
||||||
|
const unsigned readWriteVals =
|
||||||
|
llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
|
||||||
|
return m == PTXRegisterMod::ReadWrite;
|
||||||
|
});
|
||||||
|
return (writeOnlyVals + readWriteVals) > 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pack the result types of the interface operation.
|
/// Pack the result types of the interface operation.
|
||||||
/// If the operation has multiple results, it packs them into a struct
|
/// If the operation has multiple results, it packs them into a struct
|
||||||
/// type. Otherwise, it returns the original result types.
|
/// type. Otherwise, it returns the original result types.
|
||||||
static SmallVector<Type> packResultTypes(MLIRContext *ctx,
|
static SmallVector<Type>
|
||||||
BasicPtxBuilderInterface interfaceOp) {
|
packResultTypes(BasicPtxBuilderInterface interfaceOp,
|
||||||
TypeRange results = interfaceOp->getResultTypes();
|
bool needsManualRegisterMapping,
|
||||||
|
SmallVectorImpl<PTXRegisterMod> &registerModifiers,
|
||||||
|
SmallVectorImpl<Value> &ptxOperands) {
|
||||||
|
MLIRContext *ctx = interfaceOp->getContext();
|
||||||
|
TypeRange resultRange = interfaceOp->getResultTypes();
|
||||||
|
|
||||||
if (!needsPackUnpack(interfaceOp))
|
if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
|
||||||
return llvm::to_vector<1>(results);
|
registerModifiers)) {
|
||||||
|
// Single value path:
|
||||||
|
if (interfaceOp->getResults().size() == 1)
|
||||||
|
return SmallVector<Type>{resultRange.front()};
|
||||||
|
|
||||||
SmallVector<mlir::Type> elems(results.begin(), results.end());
|
// No declared results: if there is an RW, forward its type.
|
||||||
auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
|
for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
|
||||||
return {sTy};
|
if (m == PTXRegisterMod::ReadWrite)
|
||||||
|
return SmallVector<Type>{v.getType()};
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Type> packed;
|
||||||
|
for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
|
||||||
|
if (m == PTXRegisterMod::ReadWrite)
|
||||||
|
packed.push_back(v.getType());
|
||||||
|
for (Type t : resultRange)
|
||||||
|
packed.push_back(t);
|
||||||
|
|
||||||
|
if (packed.empty())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
|
||||||
|
return SmallVector<Type>{sTy};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Canonicalize the register constraints:
|
||||||
|
/// - Turn every "+X" into "=X"
|
||||||
|
/// - Append (at the very end) the 0-based indices of tokens that were "+X"
|
||||||
|
/// Examples:
|
||||||
|
/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
|
||||||
|
/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
|
||||||
|
static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
|
||||||
|
SmallVector<llvm::StringRef> toks;
|
||||||
|
SmallVector<std::string> out;
|
||||||
|
SmallVector<unsigned> plusIdx;
|
||||||
|
|
||||||
|
csv.split(toks, ',');
|
||||||
|
out.reserve(toks.size() + 8);
|
||||||
|
|
||||||
|
for (unsigned i = 0, e = toks.size(); i < e; ++i) {
|
||||||
|
StringRef t = toks[i].trim();
|
||||||
|
if (t.consume_front("+")) {
|
||||||
|
plusIdx.push_back(i);
|
||||||
|
out.push_back(("=" + t).str());
|
||||||
|
} else {
|
||||||
|
out.push_back(t.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append indices of original "+X" tokens.
|
||||||
|
for (unsigned idx : plusIdx)
|
||||||
|
out.push_back(std::to_string(idx));
|
||||||
|
|
||||||
|
// Join back to CSV.
|
||||||
|
std::string result;
|
||||||
|
result.reserve(csv.size() + plusIdx.size() * 2);
|
||||||
|
llvm::raw_string_ostream os(result);
|
||||||
|
for (size_t i = 0; i < out.size(); ++i) {
|
||||||
|
if (i)
|
||||||
|
os << ',';
|
||||||
|
os << out[i];
|
||||||
|
}
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
|
||||||
|
constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
|
||||||
|
constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
|
||||||
|
|
||||||
|
/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
|
||||||
|
static llvm::Regex getPredicateMappingRegex() {
|
||||||
|
llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
|
||||||
|
kReadWritePrefix, kWriteOnlyPrefix,
|
||||||
|
kReadOnlyPrefix)
|
||||||
|
.str());
|
||||||
|
return rx;
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::NVVM::countPlaceholderNumbers(
|
||||||
|
StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
|
||||||
|
llvm::SmallDenseSet<unsigned int> &seenW,
|
||||||
|
llvm::SmallDenseSet<unsigned int> &seenR,
|
||||||
|
llvm::SmallVectorImpl<unsigned int> &rwNums,
|
||||||
|
llvm::SmallVectorImpl<unsigned int> &wNums,
|
||||||
|
llvm::SmallVectorImpl<unsigned int> &rNums) {
|
||||||
|
|
||||||
|
llvm::Regex rx = getPredicateMappingRegex();
|
||||||
|
StringRef rest = ptxCode;
|
||||||
|
|
||||||
|
SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
|
||||||
|
while (!rest.empty() && rx.match(rest, &m)) {
|
||||||
|
unsigned num = 0;
|
||||||
|
(void)m[2].getAsInteger(10, num);
|
||||||
|
// Insert it into the vector only the first time we see this number
|
||||||
|
if (m[1].equals_insensitive(kReadWritePrefix)) {
|
||||||
|
if (seenRW.insert(num).second)
|
||||||
|
rwNums.push_back(num);
|
||||||
|
} else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
|
||||||
|
if (seenW.insert(num).second)
|
||||||
|
wNums.push_back(num);
|
||||||
|
} else {
|
||||||
|
if (seenR.insert(num).second)
|
||||||
|
rNums.push_back(num);
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
|
||||||
|
rest = rest.drop_front(advance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
|
||||||
|
/// compact `$K` indices:
|
||||||
|
/// - All `rw*` first (sorted by N),
|
||||||
|
/// - Then `w*`,
|
||||||
|
/// - Then `r*`.
|
||||||
|
/// If there is a predicate, it always comes at the end.
|
||||||
|
/// Each number is assigned once; duplicates are ignored.
|
||||||
|
///
|
||||||
|
/// Example Input:
|
||||||
|
/// "{
|
||||||
|
/// reg .pred p;
|
||||||
|
/// setp.ge.s32 p, {$r0}, {$r1};
|
||||||
|
/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
|
||||||
|
/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
|
||||||
|
/// selp.s32 {$w0}, {$r0}, {$r1}, p;
|
||||||
|
/// selp.s32 {$w1}, {$r0}, {$r1}, p;
|
||||||
|
/// }\n"
|
||||||
|
/// Example Output:
|
||||||
|
/// "{
|
||||||
|
/// reg .pred p;
|
||||||
|
/// setp.ge.s32 p, $4, $5;
|
||||||
|
/// selp.s32 $0, $4, $5, p;
|
||||||
|
/// selp.s32 $1, $4, $5, p;
|
||||||
|
/// selp.s32 $2, $4, $5, p;
|
||||||
|
/// selp.s32 $3, $4, $5, p;
|
||||||
|
/// }\n"
|
||||||
|
static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
|
||||||
|
llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
|
||||||
|
llvm::SmallVector<unsigned> rwNums, wNums, rNums;
|
||||||
|
|
||||||
|
// Step 1. Count Register Placeholder numbers
|
||||||
|
countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
|
||||||
|
|
||||||
|
// Step 2. Sort the Register Placeholder numbers
|
||||||
|
llvm::sort(rwNums);
|
||||||
|
llvm::sort(wNums);
|
||||||
|
llvm::sort(rNums);
|
||||||
|
|
||||||
|
// Step 3. Create mapping from original to new IDs
|
||||||
|
llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
|
||||||
|
unsigned nextId = 0;
|
||||||
|
for (unsigned n : rwNums)
|
||||||
|
rwMap[n] = nextId++;
|
||||||
|
for (unsigned n : wNums)
|
||||||
|
wMap[n] = nextId++;
|
||||||
|
for (unsigned n : rNums)
|
||||||
|
rMap[n] = nextId++;
|
||||||
|
|
||||||
|
// Step 4. Rewrite the PTX code with new IDs
|
||||||
|
std::string out;
|
||||||
|
out.reserve(ptxCode.size());
|
||||||
|
size_t prev = 0;
|
||||||
|
StringRef rest = ptxCode;
|
||||||
|
SmallVector<StringRef, 3> matches;
|
||||||
|
llvm::Regex rx = getPredicateMappingRegex();
|
||||||
|
while (!rest.empty() && rx.match(rest, &matches)) {
|
||||||
|
// Compute absolute match bounds in the original buffer.
|
||||||
|
size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
|
||||||
|
size_t absEnd = absStart + matches[0].size();
|
||||||
|
|
||||||
|
// Emit text before the match.
|
||||||
|
out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
|
||||||
|
|
||||||
|
// Emit compact $K
|
||||||
|
unsigned num = 0;
|
||||||
|
(void)matches[2].getAsInteger(10, num);
|
||||||
|
unsigned id = 0;
|
||||||
|
if (matches[1].equals_insensitive(kReadWritePrefix))
|
||||||
|
id = rwMap.lookup(num);
|
||||||
|
else if (matches[1].equals_insensitive(kWriteOnlyPrefix))
|
||||||
|
id = wMap.lookup(num);
|
||||||
|
else
|
||||||
|
id = rMap.lookup(num);
|
||||||
|
|
||||||
|
out.push_back('$');
|
||||||
|
out += std::to_string(id);
|
||||||
|
|
||||||
|
prev = absEnd;
|
||||||
|
|
||||||
|
const size_t advance =
|
||||||
|
(size_t)(matches[0].data() - rest.data()) + matches[0].size();
|
||||||
|
rest = rest.drop_front(advance);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5. Tail.
|
||||||
|
out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::InlineAsmOp PtxBuilder::build() {
|
LLVM::InlineAsmOp PtxBuilder::build() {
|
||||||
MLIRContext *ctx = interfaceOp->getContext();
|
|
||||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
|
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
|
||||||
LLVM::AsmDialect::AD_ATT);
|
LLVM::AsmDialect::AD_ATT);
|
||||||
|
|
||||||
SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
|
SmallVector<Type> resultTypes = packResultTypes(
|
||||||
|
interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
|
||||||
|
|
||||||
// Remove the last comma from the constraints string.
|
// Remove the last comma from the constraints string.
|
||||||
if (!registerConstraints.empty() &&
|
if (!registerConstraints.empty() &&
|
||||||
registerConstraints[registerConstraints.size() - 1] == ',')
|
registerConstraints[registerConstraints.size() - 1] == ',')
|
||||||
registerConstraints.pop_back();
|
registerConstraints.pop_back();
|
||||||
|
registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
|
||||||
|
|
||||||
std::string ptxInstruction = interfaceOp.getPtx();
|
std::string ptxInstruction = interfaceOp.getPtx();
|
||||||
|
if (!needsManualRegisterMapping)
|
||||||
|
ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
|
||||||
|
|
||||||
// Add the predicate to the asm string.
|
// Add the predicate to the asm string.
|
||||||
if (interfaceOp.getPredicate().has_value() &&
|
if (interfaceOp.getPredicate().has_value() &&
|
||||||
@ -169,33 +399,87 @@ void PtxBuilder::buildAndReplaceOp() {
|
|||||||
LLVM::InlineAsmOp inlineAsmOp = build();
|
LLVM::InlineAsmOp inlineAsmOp = build();
|
||||||
LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
|
LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
|
||||||
|
|
||||||
// Case 1: no result
|
// Case 0: no result at all → just erase wrapper op.
|
||||||
if (inlineAsmOp->getNumResults() == 0) {
|
if (!hasResult) {
|
||||||
rewriter.eraseOp(interfaceOp);
|
rewriter.eraseOp(interfaceOp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2: single result, forward it directly
|
if (needsManualRegisterMapping) {
|
||||||
if (!needsPackUnpack(interfaceOp)) {
|
|
||||||
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
|
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 3: multiple results were packed; unpack the struct.
|
// Case 1: Simple path, return single scalar
|
||||||
assert(mlir::LLVM::LLVMStructType::classof(
|
if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
|
||||||
inlineAsmOp.getResultTypes().front()) &&
|
registerModifiers)) {
|
||||||
"Expected result type to be LLVMStructType when unpacking multiple "
|
if (inlineAsmOp->getNumResults() > 0) {
|
||||||
"results");
|
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
|
||||||
auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
|
} else {
|
||||||
inlineAsmOp.getResultTypes().front());
|
// RW-only case with no declared results: forward the RW value.
|
||||||
|
SmallVector<Value> results;
|
||||||
SmallVector<mlir::Value> unpacked;
|
for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
|
||||||
Value structVal = inlineAsmOp.getResult(0);
|
if (m == PTXRegisterMod::ReadWrite) {
|
||||||
for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
|
results.push_back(v);
|
||||||
Value unpackedValue = LLVM::ExtractValueOp::create(
|
break;
|
||||||
rewriter, interfaceOp->getLoc(), structVal, idx);
|
}
|
||||||
unpacked.push_back(unpackedValue);
|
rewriter.replaceOp(interfaceOp, results);
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(interfaceOp, unpacked);
|
const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
|
||||||
|
return m == PTXRegisterMod::ReadWrite;
|
||||||
|
});
|
||||||
|
|
||||||
|
// All multi-value paths produce a single struct result we need to unpack.
|
||||||
|
assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
|
||||||
|
"expected struct return for multi-result inline asm");
|
||||||
|
Value structVal = inlineAsmOp.getResult(0);
|
||||||
|
SmallVector<Value> unpacked =
|
||||||
|
extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
|
||||||
|
|
||||||
|
// Case 2: only declared results (no RW): replace the op with all unpacked.
|
||||||
|
if (!hasRW && interfaceOp->getResults().size() > 0) {
|
||||||
|
rewriter.replaceOp(interfaceOp, unpacked);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: RW-only (no declared results): update RW uses and erase wrapper.
|
||||||
|
if (hasRW && interfaceOp->getResults().size() == 0) {
|
||||||
|
unsigned idx = 0;
|
||||||
|
for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
|
||||||
|
if (m != PTXRegisterMod::ReadWrite)
|
||||||
|
continue;
|
||||||
|
Value repl = unpacked[idx++];
|
||||||
|
v.replaceUsesWithIf(repl, [&](OpOperand &use) {
|
||||||
|
Operation *owner = use.getOwner();
|
||||||
|
return owner != interfaceOp && owner != inlineAsmOp;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
rewriter.eraseOp(interfaceOp);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 4: mixed (RW + declared results).
|
||||||
|
{
|
||||||
|
// First rewrite RW operands in place.
|
||||||
|
unsigned idx = 0;
|
||||||
|
for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
|
||||||
|
if (m != PTXRegisterMod::ReadWrite)
|
||||||
|
continue;
|
||||||
|
Value repl = unpacked[idx++];
|
||||||
|
v.replaceUsesWithIf(repl, [&](OpOperand &use) {
|
||||||
|
Operation *owner = use.getOwner();
|
||||||
|
return owner != interfaceOp && owner != inlineAsmOp;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// The remaining unpacked values correspond to the declared results.
|
||||||
|
SmallVector<Value> tail;
|
||||||
|
tail.reserve(unpacked.size() - idx);
|
||||||
|
for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
|
||||||
|
tail.push_back(unpacked[i]);
|
||||||
|
|
||||||
|
rewriter.replaceOp(interfaceOp, tail);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
|
|||||||
return ptx;
|
return ptx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
|
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
|
||||||
RewriterBase &rewriter,
|
RewriterBase &rewriter,
|
||||||
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
|
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
|
||||||
&asmValues) {
|
&asmValues) {
|
||||||
@ -1154,7 +1154,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
|
|||||||
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
|
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
|
||||||
mlir::NVVM::PTXRegisterMod::Read});
|
mlir::NVVM::PTXRegisterMod::Read});
|
||||||
}
|
}
|
||||||
|
return true; // Has manual mapping
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult NVVM::FenceProxyOp::verify() {
|
LogicalResult NVVM::FenceProxyOp::verify() {
|
||||||
if (getKind() == NVVM::ProxyKind::TENSORMAP)
|
if (getKind() == NVVM::ProxyKind::TENSORMAP)
|
||||||
return emitOpError() << "tensormap proxy is not a supported proxy kind";
|
return emitOpError() << "tensormap proxy is not a supported proxy kind";
|
||||||
@ -1870,6 +1872,21 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool NVVM::InlinePtxOp::getAsmValues(
|
||||||
|
RewriterBase &rewriter,
|
||||||
|
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
|
||||||
|
&asmValues) {
|
||||||
|
for (auto arg : getReadWriteArgs())
|
||||||
|
asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
|
||||||
|
for (auto arg : getResults())
|
||||||
|
asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
|
||||||
|
for (auto arg : getReadOnlyArgs())
|
||||||
|
asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
|
||||||
|
if (getPredicate())
|
||||||
|
asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
|
||||||
|
return false; // No manual mapping needed
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// NVVMDialect initialization, type parsing, and registration.
|
// NVVMDialect initialization, type parsing, and registration.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -667,34 +667,82 @@ llvm.func @init_mbarrier(
|
|||||||
%count : i32,
|
%count : i32,
|
||||||
%pred : i1) {
|
%pred : i1) {
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
|
||||||
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
|
nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32)
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
|
||||||
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
|
nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32), predicate = %pred
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
llvm.func @ex2(%input : f32, %pred : i1) {
|
llvm.func @ex2(%input : f32, %pred : i1) {
|
||||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
|
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
|
||||||
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
|
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
|
||||||
|
|
||||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
|
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
|
||||||
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
|
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @multi_return(
|
// CHECK-LABEL: @multi_return(
|
||||||
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
|
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
|
||||||
llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
|
llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
|
||||||
// CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
|
// CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
|
||||||
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
|
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
|
||||||
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
|
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
|
||||||
// CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
|
// CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
|
||||||
// CHECK: llvm.return %[[S4]] : i32
|
// CHECK: llvm.return %[[S4]] : i32
|
||||||
%r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32
|
%r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
|
||||||
|
ro (%a, %b : i32,i32) -> i32,i32
|
||||||
%r3 = llvm.add %r1, %r2 : i32
|
%r3 = llvm.add %r1, %r2 : i32
|
||||||
llvm.return %r3 : i32
|
llvm.return %r3 : i32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @inline_ptx_multi_rw(
|
||||||
|
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
|
||||||
|
llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
|
||||||
|
// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}",
|
||||||
|
// CHECK-SAME: "=f,=f,r,r,0,1"
|
||||||
|
// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]]
|
||||||
|
// CHECK-SAME: : (f32, f32, i32, i32) -> !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
|
||||||
|
// CHECK: llvm.return %[[S3]] : f32
|
||||||
|
nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}"
|
||||||
|
ro (%a, %b : i32,i32)
|
||||||
|
rw (%rw_c, %rw_d: f32,f32)
|
||||||
|
%r4 = llvm.fadd %rw_c, %rw_d : f32
|
||||||
|
llvm.return %r4 : f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @inline_ptx_multi_rw_r(
|
||||||
|
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
|
||||||
|
llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
|
||||||
|
// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}",
|
||||||
|
// CHECK-SAME: "=f,=f,=r,=r,r,r,0,1"
|
||||||
|
// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] :
|
||||||
|
// CHECK-SAME: (f32, f32, i32, i32) -> !llvm.struct<(f32, f32, i32, i32)>
|
||||||
|
// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)>
|
||||||
|
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)>
|
||||||
|
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)>
|
||||||
|
// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)>
|
||||||
|
// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32
|
||||||
|
// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32
|
||||||
|
// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
|
||||||
|
// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32
|
||||||
|
// CHECK: llvm.return %[[S8]] : f32
|
||||||
|
|
||||||
|
%wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
|
||||||
|
ro (%a, %b : i32,i32)
|
||||||
|
rw (%rw_c, %rw_d: f32,f32) -> i32,i32
|
||||||
|
%r3 = llvm.add %wo0, %wo1 : i32
|
||||||
|
%r3f = llvm.sitofp %r3 : i32 to f32
|
||||||
|
%r4 = llvm.fadd %rw_c, %rw_d : f32
|
||||||
|
%r5 = llvm.fadd %r3f, %rw_d : f32
|
||||||
|
llvm.return %r5 : f32
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @nvvm_pmevent
|
// CHECK-LABEL: @nvvm_pmevent
|
||||||
|
@ -5,6 +5,8 @@ from mlir.ir import *
|
|||||||
from mlir.dialects import nvvm
|
from mlir.dialects import nvvm
|
||||||
from mlir.dialects import llvm
|
from mlir.dialects import llvm
|
||||||
from mlir.dialects import func
|
from mlir.dialects import func
|
||||||
|
import mlir.extras.types as T
|
||||||
|
from mlir.dialects import arith
|
||||||
|
|
||||||
|
|
||||||
def constructAndPrintInModule(f):
|
def constructAndPrintInModule(f):
|
||||||
@ -25,6 +27,7 @@ def testSmoke():
|
|||||||
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
|
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
|
||||||
)
|
)
|
||||||
shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
|
shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
|
||||||
|
|
||||||
# CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
|
# CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
|
||||||
@func.FuncOp.from_py_func(i64, i64)
|
@func.FuncOp.from_py_func(i64, i64)
|
||||||
def wgmma_f32_f16_f16(desc_a, desc_b):
|
def wgmma_f32_f16_f16(desc_a, desc_b):
|
||||||
@ -48,3 +51,41 @@ def testSmoke():
|
|||||||
layoutA=nvvm.MMALayout.col,
|
layoutA=nvvm.MMALayout.col,
|
||||||
layoutB=nvvm.MMALayout.col,
|
layoutB=nvvm.MMALayout.col,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: test_inline_ptx
|
||||||
|
# CHECK-LABEL: func.func @my_inline_ptx(
|
||||||
|
# CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: f32, %[[arg1:[a-zA-Z0-9_]+]]: f32, %[[arg2:[a-zA-Z0-9_]+]]: i32, %[[arg3:[a-zA-Z0-9_]+]]: i32)
|
||||||
|
# CHECK: %[[S0:.+]]:2 = nvvm.inline_ptx
|
||||||
|
# CHECK-SAME: ro(%[[arg0]], %[[arg1]] : f32, f32) rw(%[[arg2]], %[[arg3]] : i32, i32) -> f32, f32
|
||||||
|
# CHECK: %[[S1:.+]] = arith.addf %[[arg0]], %[[arg1]] : f32
|
||||||
|
# CHECK: %[[S2:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32
|
||||||
|
# CHECK: %[[S3:.+]] = arith.addf %[[S0]]#0, %[[S0]]#1 : f32
|
||||||
|
|
||||||
|
|
||||||
|
@constructAndPrintInModule
|
||||||
|
def test_inline_ptx():
|
||||||
|
i32 = T.i32()
|
||||||
|
f32 = T.f32()
|
||||||
|
|
||||||
|
@func.FuncOp.from_py_func(f32, f32, i32, i32)
|
||||||
|
def my_inline_ptx(a, b, c, d):
|
||||||
|
ptx = r"""
|
||||||
|
{
|
||||||
|
.reg .pred p;
|
||||||
|
setp.ge.s32 p, {$r0}, {$r1};
|
||||||
|
selp.s32 {$r0}, {$r0}, {$r1}, p;
|
||||||
|
selp.s32 {$r1}, {$r0}, {$r1}, p;
|
||||||
|
selp.s32 {$rw0}, {$r0}, {$r1}, p;
|
||||||
|
selp.s32 {$rw1}, {$r0}, {$r1}, p;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
wo0, wo1 = nvvm.inline_ptx(
|
||||||
|
read_only_args=[a, b],
|
||||||
|
read_write_args=[c, d],
|
||||||
|
write_only_args=[f32, f32],
|
||||||
|
ptx_code=ptx,
|
||||||
|
)
|
||||||
|
arith.addf(a, b)
|
||||||
|
arith.addi(c, d)
|
||||||
|
arith.addf(wo0, wo1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user