From fbf484009c782a3ecb8bae6526d4c10d006365fc Mon Sep 17 00:00:00 2001 From: Vito Secona <77039267+secona@users.noreply.github.com> Date: Thu, 2 Apr 2026 00:42:26 +0700 Subject: [PATCH] [mlir][sparse] add GPU num threads to sparsifier options (#189078) This change adds a `gpu-num-threads` option to the sparsifier. This allows users to specify the number of threads used for GPU codegen, similar to the `num-threads` option in the `-sparse-gpu-codegen` pass. --- mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h | 7 +++++++ .../SparseTensor/Pipelines/SparseTensorPipelines.cpp | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h index 2e76985e92e1..6bb46299738d 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -151,6 +151,13 @@ struct SparsifierOptions : public PassPipelineOptions { desc("Enables GPU acceleration by means of direct library calls (like " "cuSPARSE)")}; + /// This option is used to specify the number of threads of GPU codegen. + PassOptions::Option gpuNumThreads{ + *this, "gpu-num-threads", + desc("Number of threads for GPU codegen. Setting this to 0 enables " + "direct library calls instead."), + init(1024)}; + /// Projects out the options for `createSparsificationPass`. SparsificationOptions sparsificationOptions() const { return SparsificationOptions(parallelization, emitStrategy, diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index dabbea1bdec6..d85966ec88f9 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -53,7 +53,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, // GPU code generation. const bool gpuCodegen = options.gpuTriple.hasValue(); if (gpuCodegen) { - pm.addPass(createSparseGPUCodegenPass()); + pm.addPass(createSparseGPUCodegenPass(options.gpuNumThreads, + options.enableRuntimeLibrary)); pm.addNestedPass(createStripDebugInfoPass()); pm.addNestedPass(createSCFToControlFlowPass()); pm.addNestedPass(createConvertGpuOpsToNVVMOps());