diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 3cfbd898e49e..e516118f7520 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -532,6 +532,9 @@ void GpuToLLVMConversionPass::runOnOperation() { // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); + // Transform N-D vector.from_elements to 1-D vector.from_elements before + // conversion. + vector::populateVectorFromElementsLoweringPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 317bfc2970cf..50ac1c60184e 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -369,6 +370,9 @@ struct LowerGpuOpsToNVVMOpsPass final { RewritePatternSet patterns(m.getContext()); populateGpuRewritePatterns(patterns); + // Transform N-D vector.from_elements to 1-D vector.from_elements before + // conversion. + vector::populateVectorFromElementsLoweringPatterns(patterns); if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); }