diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp index 99e3dbc95519..66e1f8490629 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp @@ -1497,14 +1497,21 @@ struct SgToWiConvertLayout ConversionPatternRewriter &rewriter) const override { auto inputLayout = op.getInputLayoutAttr(); auto targetLayout = op.getTargetLayoutAttr(); - auto resShape = cast(op.getResult().getType()).getShape(); - SmallVector resShapeVec(resShape.begin(), resShape.end()); + Type valType = op.getResult().getType(); + if (valType.isIntOrFloat()) { + rewriter.replaceOp(op, op.getSource()); + return success(); + } + + auto resShape = cast(valType).getShape(); + SmallVector resShapeVec(resShape.begin(), resShape.end()); if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec, xegpu::LayoutKind::Lane)) { return rewriter.notifyMatchFailure( op, "lowering incompatible convert_layout not yet supported"); } + rewriter.replaceOp(op, adaptor.getSource()); return success(); } diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir index 4b386ac9317c..842c2375dd31 100644 --- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir +++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir @@ -858,6 +858,21 @@ gpu.func @convert_layout_removed_when_compatible() { } } +// ----- +gpu.module @xevm_module { +// CHECK-LABEL: gpu.func @convert_layout_scalar +// CHECK-NOT: xegpu.convert_layout +gpu.func @convert_layout_scalar() { + %0 = "some_op"() : () -> f32 + %1 = xegpu.convert_layout %0 + <{input_layout = #xegpu.slice<#xegpu.layout, dims = [0]>, + target_layout = #xegpu.slice<#xegpu.layout, dims = [0]>}> + : f32 + "some_use"(%1) : (f32) -> () + gpu.return +} +} + // ----- // load_matrix and store_matrix with coordinate computation (offsets [0,0]) gpu.module @xevm_module {