From 0fc13e3e70ed6f1170c22a5bfa4cce6af2b264e8 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 26 Jun 2026 14:01:28 -0400 Subject: [PATCH] Skip the redundant permute in matricize, returning a view for dense arrays `matricize`/`matricizeop` now skip the permuted copy when the requested row and column grouping is already in storage order, instead of always permuting first. A new `matricizekind` classifier, dispatched on `FusionStyle`, decides this per bipermutation: the generic classifier recognizes the always-safe aligned case (skipping a no-op permute is valid for any style), and `ReshapeFusion` (dense) additionally recognizes a pure codomain/domain swap, which it realizes as a lazy `transpose`. For a dense array the aligned and swapped cases return a `reshape`/`transpose` view of the input. For a graded array `matricize` still gathers blocks into a new matrix, but the redundant permute copy beforehand is skipped. The fast paths require `op === identity`, since a plain view cannot carry a fused `op` like `conj`. The result may alias the input and is read-only, which matches the `matricizeop` docstring's existing contract. This removes input-copy allocations from aligned contractions. At bond dimension 64 the dense matmul runs about 35% faster and a memory-bound rank-3 contraction about 50% faster, closing most of the gap to an optimized reference, and an aligned `AbelianGradedArray` contraction drops one operand's permute copy. Co-Authored-By: Claude Opus 4.8 (1M context) --- Project.toml | 2 +- src/matricize.jl | 52 ++++++++++++++++++++++-- test/test_matricize.jl | 89 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 test/test_matricize.jl diff --git a/Project.toml b/Project.toml index d1519b1..4b9a2cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.11.1" +version = "0.11.2" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/matricize.jl b/src/matricize.jl index 4e28422..77e6e66 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -259,15 +259,51 @@ function matricizeop( ) return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...) end -# This is the primary function that should be overloaded for new fusion styles to fold -# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy -# wrappers like StridedView with op metadata for zero-copy). +# Classifies how `matricize` realizes the bipermutation `(perm_codomain, perm_domain)` +# against storage, so `matricizeop` can skip the redundant permuted copy: +# ReshapeMatricizeKind — the groups are already in storage order, so the permute is a +# no-op and `matricize(style, a, ...)` can be called directly. +# For a dense array that is a `reshape` view; for a graded array +# it still gathers blocks, but skips the extra permute copy. +# TransposeMatricizeKind — the only reordering is a codomain/domain swap, which a dense +# array realizes as a `transpose` of a `reshape` (a view gemm +# reads via BLAS' transpose flag). +# PermuteMatricizeKind — the groups interleave storage, so a permuted copy is required. +# Pure: depends only on the index pattern, not on `a`'s data. Dispatched on `FusionStyle`. +# The generic classifier only recognizes the always-safe `ReshapeMatricizeKind` (skipping a +# no-op permute is valid for any style); `TransposeMatricizeKind` is opt-in for styles whose +# `matricize` composes with a lazy `transpose`, currently only `ReshapeFusion`. +@enum MatricizeKind ReshapeMatricizeKind TransposeMatricizeKind PermuteMatricizeKind + +# Whether `perm` is the identity permutation `(1, …, n)`. +isidentityperm(perm::Tuple{Vararg{Int}}) = perm == ntuple(identity, length(perm)) + +function matricizekind( + ::FusionStyle, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} + ) + # Already in storage order: the permute is a no-op, so `matricize` can run directly. + isidentityperm((perm_codomain..., perm_domain...)) && return ReshapeMatricizeKind + return PermuteMatricizeKind +end + +# Skip the permuted copy when the classifier says it is unnecessary. `ReshapeMatricizeKind` +# calls `matricize` directly on `a` (a view for dense, a gather without the extra permute +# for graded); `TransposeMatricizeKind` returns a lazy `transpose` of the reshape. Both +# fast paths require `op === identity`, since a plain view cannot carry a fused `op` like +# `conj`. The result may alias `a` and must be treated as read-only, matching the docstring. function matricizeop( style::FusionStyle, op, a::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} ) ndims(a) == length(perm_codomain) + length(perm_domain) || throw(ArgumentError("Invalid bipermutation")) + if op === identity + kind = matricizekind(style, perm_codomain, perm_domain) + kind == ReshapeMatricizeKind && + return matricize(style, a, Val(length(perm_codomain))) + kind == TransposeMatricizeKind && + return transpose(matricize(style, a, Val(length(perm_domain)))) + end a_perm_op = permutedimsop(op, a, perm_codomain, perm_domain) return matricize(style, a_perm_op, Val(length(perm_codomain))) end @@ -373,6 +409,16 @@ end function matricize(style::ReshapeFusion, a::AbstractArray, ndims_codomain::Val) return reshape(a, matricize_axes(style, a, ndims_codomain)) end +# A dense array additionally realizes a codomain/domain swap as a lazy `transpose` of a +# reshape (a view), so it opts into `TransposeMatricizeKind` on top of the generic +# reshape/permute classification. +function matricizekind( + ::ReshapeFusion, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} + ) + isidentityperm((perm_codomain..., perm_domain...)) && return ReshapeMatricizeKind + isidentityperm((perm_domain..., perm_codomain...)) && return TransposeMatricizeKind + return PermuteMatricizeKind +end function unmatricize( style::ReshapeFusion, m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, diff --git a/test/test_matricize.jl b/test/test_matricize.jl new file mode 100644 index 0000000..fdd8f0f --- /dev/null +++ b/test/test_matricize.jl @@ -0,0 +1,89 @@ +using LinearAlgebra: Transpose +using StableRNGs: StableRNG +using TensorAlgebra: TensorAlgebra, PermuteMatricizeKind, ReshapeFusion, + ReshapeMatricizeKind, TransposeMatricizeKind, matricize, matricizeop +using Test: @test, @testset + +# A non-`ReshapeFusion` style, to check the always-safe generic fallback. +struct DummyFusion <: TensorAlgebra.FusionStyle end + +# Ground-truth matricization: permute into `(codomain..., domain...)` order, then reshape. +function matricize_ref(a, perm_codomain, perm_domain) + a_perm = permutedims(a, (perm_codomain..., perm_domain...)) + nrow = prod(i -> size(a, i), perm_codomain; init = 1) + ncol = prod(i -> size(a, i), perm_domain; init = 1) + return reshape(a_perm, (nrow, ncol)) +end + +@testset "matricizekind classifier" begin + style = ReshapeFusion() + # Already in storage order → plain reshape view. + @test TensorAlgebra.matricizekind(style, (1,), (2, 3)) == ReshapeMatricizeKind + @test TensorAlgebra.matricizekind(style, (1, 2), (3,)) == ReshapeMatricizeKind + @test TensorAlgebra.matricizekind(style, (1, 2, 3), ()) == ReshapeMatricizeKind + @test TensorAlgebra.matricizekind(style, (), (1, 2, 3)) == ReshapeMatricizeKind + # Pure codomain/domain swap → transpose of a reshape view. + @test TensorAlgebra.matricizekind(style, (2, 3), (1,)) == TransposeMatricizeKind + @test TensorAlgebra.matricizekind(style, (3,), (1, 2)) == TransposeMatricizeKind + # Interleaved → permuted copy. + @test TensorAlgebra.matricizekind(style, (3, 1), (2,)) == PermuteMatricizeKind + @test TensorAlgebra.matricizekind(style, (2,), (1, 3)) == PermuteMatricizeKind + @test TensorAlgebra.matricizekind(style, (1, 3), (2,)) == PermuteMatricizeKind + # Generic fusion styles recognize the always-safe reshape (no-op permute) but never + # claim a transpose (which only styles with a lazy `transpose` can realize). + @test TensorAlgebra.matricizekind(DummyFusion(), (1,), (2, 3)) == ReshapeMatricizeKind + @test TensorAlgebra.matricizekind(DummyFusion(), (1, 2, 3), ()) == ReshapeMatricizeKind + @test TensorAlgebra.matricizekind(DummyFusion(), (2, 3), (1,)) == PermuteMatricizeKind + @test TensorAlgebra.matricizekind(DummyFusion(), (3, 1), (2,)) == PermuteMatricizeKind +end + +@testset "maybe-view matricizeop (eltype=$elt)" for elt in (Float64, ComplexF64) + a = randn(StableRNG(123), elt, 2, 3, 4) + + # Reshape branch: correct values and a view aliasing `a`. + m = matricize(a, (1,), (2, 3)) + @test m ≈ matricize_ref(a, (1,), (2, 3)) + @test Base.mightalias(m, a) + + # Transpose branch: correct values and a transpose view aliasing `a`. + m = matricize(a, (2, 3), (1,)) + @test m ≈ matricize_ref(a, (2, 3), (1,)) + @test m isa Transpose + @test Base.mightalias(m, a) + + # Permute branch: correct values, but a fresh copy (no aliasing). + m = matricize(a, (3, 1), (2,)) + @test m ≈ matricize_ref(a, (3, 1), (2,)) + @test !Base.mightalias(m, a) + + # `conj` cannot ride a view, so it copies even on the reshape/transpose patterns. + m = matricizeop(conj, a, (1,), (2, 3)) + @test m ≈ conj.(matricize_ref(a, (1,), (2, 3))) + @test !Base.mightalias(m, a) + m = matricizeop(conj, a, (2, 3), (1,)) + @test m ≈ conj.(matricize_ref(a, (2, 3), (1,))) + @test !Base.mightalias(m, a) +end + +@testset "view branches track source mutations, copy branch does not" begin + rng = StableRNG(7) + + # Reshape view tracks an in-place update of `a`. + a = randn(rng, 2, 3, 4) + m = matricize(a, (1,), (2, 3)) + a .= randn(rng, 2, 3, 4) + @test m ≈ matricize_ref(a, (1,), (2, 3)) + + # Transpose view tracks an in-place update of `a`. + a = randn(rng, 2, 3, 4) + m = matricize(a, (2, 3), (1,)) + a .= randn(rng, 2, 3, 4) + @test m ≈ matricize_ref(a, (2, 3), (1,)) + + # Permute copy is independent of later updates to `a`. + a = randn(rng, 2, 3, 4) + m = matricize(a, (3, 1), (2,)) + snapshot = copy(m) + a .= a .+ 1 + @test m == snapshot +end