[mlir] Fix parsing of empty complex tensors (#134322)

After https://github.com/llvm/llvm-project/pull/133220 we had some empty
complex literals (`tensor<0xcomplex<f32>>`) failing to parse.

This was largely due to the ambiguity between `shape.empty()` meaning
splat (`dense<1>`) or empty literal (`dense<>`). Used type's numel to
disambiguate during verification.
This commit is contained in:
Kevin Gleason 2025-04-04 11:29:51 -05:00 committed by GitHub
parent 0d3f5ec0da
commit e8d5009784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -566,8 +566,10 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
eltType = complexTy.getElementType(); eltType = complexTy.getElementType();
isComplex = true; isComplex = true;
// Complex types have 2 elements. // Complex types have N*2 elements or complex splat.
if (shape.empty() && storage.size() != 2) { // Empty shape may mean a splat or empty literal, only validate splats.
bool isSplat = shape.empty() && type.getNumElements() != 0;
if (isSplat && storage.size() != 2) {
p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" p.emitError(loc) << "parsed " << storage.size() << " elements, but type ("
<< complexTy << ") expected 2 elements"; << complexTy << ") expected 2 elements";
return nullptr; return nullptr;

View File

@ -730,6 +730,10 @@ func.func @densetensorattr() -> () {
"complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> () "complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> ()
// CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>> // CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>
"complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> () "complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> ()
// CHECK: dense<> : tensor<0xcomplex<i64>>
"complex_attr"(){bar = dense<> : tensor<0xcomplex<i64>>} : () -> ()
// CHECK: dense<> : tensor<2x0xcomplex<i64>>
"complex_attr"(){bar = dense<> : tensor<2x0xcomplex<i64>>} : () -> ()
return return
} }