[flang] Improved performance of runtime Matmul/MatmulTranspose.

This patch mostly affects performance of the code produced by
HLIFR lowering. If MATMUL argument is an array slice, then
HLFIR lowering passes the slice to the runtime, whereas
FIR lowering would create a contiguous temporary for the slice.
Performance might be better than the generic implementation
for cases where the leading dimension is contiguous.
This patch improves CPU2000/178.galgel making HLFIR version
faster than FIR version (due to avoiding the temporary copies
for MATMUL arguments).

Reviewed By: klausler

Differential Revision: https://reviews.llvm.org/D159134
This commit is contained in:
Slava Zakharin 2023-08-29 15:08:23 -07:00
parent 8f48392bc0
commit 4d9771741d
4 changed files with 421 additions and 30 deletions

View File

@ -52,25 +52,64 @@ using namespace Fortran::runtime;
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesMatrix(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
SubscriptValue n) {
SubscriptValue n, std::size_t xColumnByteStride = 0,
std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * cols * sizeof *product);
for (SubscriptValue j{0}; j < cols; ++j) {
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
ResultType y_kj = static_cast<ResultType>(y[j * n + k]);
ResultType x_ki;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
x_ki = static_cast<ResultType>(x[i * n + k]);
} else {
x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
}
ResultType y_kj;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
y_kj = static_cast<ResultType>(y[j * n + k]);
} else {
y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
}
product[j * rows + i] += x_ki * y_kj;
}
}
}
}
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline static void MatrixTransposedTimesMatrixHelper(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
SubscriptValue n, std::optional<std::size_t> xColumnByteStride,
std::optional<std::size_t> yColumnByteStride) {
if (!xColumnByteStride) {
if (!yColumnByteStride) {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
product, rows, cols, x, y, n);
} else {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
product, rows, cols, x, y, n, 0, *yColumnByteStride);
}
} else {
if (!yColumnByteStride) {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
product, rows, cols, x, y, n, *xColumnByteStride);
} else {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
}
}
}
// Contiguous numeric matrix*vector multiplication
// matrix(rows,n) * column vector(n) -> column vector(rows)
// Straightforward algorithm:
@ -85,21 +124,43 @@ inline static void MatrixTransposedTimesMatrix(
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I) = RES(I) + X(K,I)*Y(K)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesVector(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) {
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
ResultType x_ki;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
x_ki = static_cast<ResultType>(x[i * n + k]);
} else {
x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
}
ResultType y_k = static_cast<ResultType>(y[k]);
product[i] += x_ki * y_k;
}
}
}
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline static void MatrixTransposedTimesVectorHelper(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
std::optional<std::size_t> xColumnByteStride) {
if (!xColumnByteStride) {
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>(
product, rows, n, x, y);
} else {
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>(
product, rows, n, x, y, *xColumnByteStride);
}
}
// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
@ -149,19 +210,39 @@ inline static void DoMatmulTranspose(
const SubscriptValue rows{extent[0]};
const SubscriptValue cols{extent[1]};
if constexpr (RCAT != TypeCategory::Logical) {
if (x.IsContiguous() && y.IsContiguous() &&
if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
// Contiguous numeric matrices
// Contiguous numeric matrices (maybe with columns
// separated by a stride).
std::optional<std::size_t> xColumnByteStride;
if (!x.IsContiguous()) {
// X's columns are strided.
SubscriptValue xAt[2]{};
x.GetLowerBounds(xAt);
xAt[1]++;
xColumnByteStride = x.SubscriptsToByteOffset(xAt);
}
std::optional<std::size_t> yColumnByteStride;
if (!y.IsContiguous()) {
// Y's columns are strided.
SubscriptValue yAt[2]{};
y.GetLowerBounds(yAt);
yAt[1]++;
yColumnByteStride = y.SubscriptsToByteOffset(yAt);
}
if (resRank == 2) { // M*M -> M
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT>(
// TODO: use BLAS-3 GEMM for supported types.
MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, cols,
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
yColumnByteStride);
return;
}
if (xRank == 2) { // M*V -> V
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT>(
// TODO: use BLAS-2 GEMM for supported types.
MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, n,
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
}
// else V*M -> V (not allowed because TRANSPOSE() is only defined for rank

View File

@ -69,10 +69,12 @@ private:
// DO 2 J = 1, NCOLS
// DO 2 I = 1, NROWS
// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, SubscriptValue n) {
const YT *RESTRICT y, SubscriptValue n, std::size_t xColumnByteStride = 0,
std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * cols * sizeof *product);
const XT *RESTRICT xp0{x};
@ -80,12 +82,48 @@ inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
ResultType *RESTRICT p{product};
for (SubscriptValue j{0}; j < cols; ++j) {
const XT *RESTRICT xp{xp0};
auto yv{static_cast<ResultType>(y[k + j * n])};
ResultType yv;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
yv = static_cast<ResultType>(y[k + j * n]);
} else {
yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
}
for (SubscriptValue i{0}; i < rows; ++i) {
*p++ += static_cast<ResultType>(*xp++) * yv;
}
}
xp0 += rows;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
xp0 += rows;
} else {
xp0 = reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(xp0) + xColumnByteStride);
}
}
}
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline void MatrixTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, SubscriptValue n,
std::optional<std::size_t> xColumnByteStride,
std::optional<std::size_t> yColumnByteStride) {
if (!xColumnByteStride) {
if (!yColumnByteStride) {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
product, rows, cols, x, y, n);
} else {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
product, rows, cols, x, y, n, 0, *yColumnByteStride);
}
} else {
if (!yColumnByteStride) {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
product, rows, cols, x, y, n, *xColumnByteStride);
} else {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
}
}
}
@ -103,18 +141,37 @@ inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NROWS
// 2 RES(J) = RES(J) + X(J,K)*Y(K)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS>
inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
const YT *RESTRICT y) {
const YT *RESTRICT y, std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
[[maybe_unused]] const XT *RESTRICT xp0{x};
for (SubscriptValue k{0}; k < n; ++k) {
ResultType *RESTRICT p{product};
auto yv{static_cast<ResultType>(*y++)};
for (SubscriptValue j{0}; j < rows; ++j) {
*p++ += static_cast<ResultType>(*x++) * yv;
}
if constexpr (X_HAS_STRIDED_COLUMNS) {
xp0 = reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(xp0) + xColumnByteStride);
x = xp0;
}
}
}
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline void MatrixTimesVectorHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
const YT *RESTRICT y, std::optional<std::size_t> xColumnByteStride) {
if (!xColumnByteStride) {
MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
} else {
MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
product, rows, n, x, y, *xColumnByteStride);
}
}
@ -132,10 +189,11 @@ inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NCOLS
// 2 RES(J) = RES(J) + X(K)*Y(K,J)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool Y_HAS_STRIDED_COLUMNS>
inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y) {
const YT *RESTRICT y, std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, cols * sizeof *product);
for (SubscriptValue k{0}; k < n; ++k) {
@ -144,11 +202,29 @@ inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
const YT *RESTRICT yp{&y[k]};
for (SubscriptValue j{0}; j < cols; ++j) {
*p++ += xv * static_cast<ResultType>(*yp);
yp += n;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
yp += n;
} else {
yp = reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(yp) + yColumnByteStride);
}
}
}
}
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool SPARSE_COLUMNS = false>
inline void VectorTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, std::optional<std::size_t> yColumnByteStride) {
if (!yColumnByteStride) {
VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
} else {
VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
product, n, cols, x, y, *yColumnByteStride);
}
}
// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
@ -194,13 +270,35 @@ static inline void DoMatmul(
CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
RKIND>;
if constexpr (RCAT != TypeCategory::Logical) {
if (x.IsContiguous() && y.IsContiguous() &&
if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
// Contiguous numeric matrices
// Contiguous numeric matrices (maybe with columns
// separated by a stride).
std::optional<std::size_t> xColumnByteStride;
if (!x.IsContiguous()) {
// X's columns are strided.
SubscriptValue xAt[2]{};
x.GetLowerBounds(xAt);
xAt[1]++;
xColumnByteStride = x.SubscriptsToByteOffset(xAt);
}
std::optional<std::size_t> yColumnByteStride;
if (!y.IsContiguous()) {
// Y's columns are strided.
SubscriptValue yAt[2]{};
y.GetLowerBounds(yAt);
yAt[1]++;
yColumnByteStride = y.SubscriptsToByteOffset(yAt);
}
// Note that BLAS GEMM can be used for the strided
// columns by setting proper leading dimension size.
// This implies that the column stride is divisible
// by the element size, which is usually true.
if (resRank == 2) { // M*M -> M
if (std::is_same_v<XT, YT>) {
if constexpr (std::is_same_v<XT, float>) {
// TODO: call BLAS-3 SGEMM
// TODO: try using CUTLASS for device.
} else if constexpr (std::is_same_v<XT, double>) {
// TODO: call BLAS-3 DGEMM
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
@ -209,9 +307,10 @@ static inline void DoMatmul(
// TODO: call BLAS-3 ZGEMM
}
}
MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], extent[1],
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
yColumnByteStride);
return;
} else if (xRank == 2) { // M*V -> V
if (std::is_same_v<XT, YT>) {
@ -225,9 +324,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(x,y)
}
}
MatrixTimesVector<RCAT, RKIND, XT, YT>(
MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], n,
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
} else { // V*M -> V
if (std::is_same_v<XT, YT>) {
@ -241,9 +340,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(y,x)
}
}
VectorTimesMatrix<RCAT, RKIND, XT, YT>(
VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), n, extent[0],
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
return;
}
}

View File

@ -27,6 +27,16 @@ TEST(Matmul, Basic) {
std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
auto v{MakeArray<TypeCategory::Integer, 8>(
std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
// X2 0 2 4 Y2 -1 -1
// 1 3 5 6 9
// -1 -1 -1 7 10
// 8 11
auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{3, 3},
std::vector<std::int32_t>{0, 1, -1, 2, 3, -1, 4, 5})};
auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
StaticDescriptor<2, true> statDesc;
Descriptor &result{statDesc.descriptor()};
@ -73,6 +83,98 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
// Test non-contiguous sections.
static constexpr int sectionRank{2};
StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
sectionX2.Establish(x2->type(), x2->ElementBytes(),
/*p=*/nullptr, /*rank=*/sectionRank);
static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{2, 3};
// Section of X2:
// +--------+
// | 0 2 4|
// | 1 3 5|
// +--------+
// -1 -1 -1
const auto errorX2{CFI_section(
&sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
sectionY2.Establish(y2->type(), y2->ElementBytes(),
/*p=*/nullptr, /*rank=*/sectionRank);
static const SubscriptValue lowersY2[]{2, 1};
// Section of Y2:
// -1 -1
// +-----+
// | 6 9|
// | 7 10|
// | 8 11|
// +-----+
const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
/*uppers=*/nullptr, /*strides=*/nullptr)};
ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
RTNAME(Matmul)(result, sectionX2, *y, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 3);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
result.Destroy();
RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 3);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
// X F F T Y F T
// F T T F T
// F F

View File

@ -32,6 +32,17 @@ TEST(MatmulTranspose, Basic) {
std::vector<std::int16_t>{0, 0, 0, 1, 1, 0, 1, 1})};
auto v{MakeArray<TypeCategory::Integer, 8>(
std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
// X2 0 1 Y2 -1 -1 Z2 6 7 8
// 2 3 6 9 9 10 11
// 4 5 7 10 -1 -1 -1
// -1 -1 8 11
auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{4, 2},
std::vector<std::int32_t>{0, 2, 4, -1, 1, 3, 5, -1})};
auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
auto z2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{3, 3},
std::vector<std::int16_t>{6, 9, -1, 7, 10, -1, 8, 11, -1})};
StaticDescriptor<2, true> statDesc;
Descriptor &result{statDesc.descriptor()};
@ -89,6 +100,104 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
result.Destroy();
// Test non-contiguous sections.
static constexpr int sectionRank{2};
StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
sectionX2.Establish(x2->type(), x2->ElementBytes(),
/*p=*/nullptr, /*rank=*/sectionRank);
static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2};
// Section of X2:
// +-----+
// | 0 1|
// | 2 3|
// | 4 5|
// +-----+
// -1 -1
const auto errorX2{CFI_section(
&sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
sectionY2.Establish(y2->type(), y2->ElementBytes(),
/*p=*/nullptr, /*rank=*/sectionRank);
static const SubscriptValue lowersY2[]{2, 1};
// Section of Y2:
// -1 -1
// +-----+
// | 6 0|
// | 7 10|
// | 8 11|
// +-----+
const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
/*uppers=*/nullptr, /*strides=*/nullptr)};
ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
StaticDescriptor<sectionRank> sectionStaticDescriptorZ2;
Descriptor &sectionZ2{sectionStaticDescriptorZ2.descriptor()};
sectionZ2.Establish(z2->type(), z2->ElementBytes(),
/*p=*/nullptr, /*rank=*/sectionRank);
static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3};
// Section of Z2:
// +--------+
// | 6 7 8|
// | 9 10 11|
// +--------+
// -1 -1 -1
const auto errorZ2{CFI_section(
&sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)};
ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2;
RTNAME(MatmulTranspose)(result, sectionX2, *y, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 2);
EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(1).Extent(), 2);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(result.GetDimension(0).Extent(), 3);
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
// X F F Y F T V T F T
// T F F T
// T T F F