Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.11.1"
version = "0.11.2"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
52 changes: 49 additions & 3 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,51 @@ function matricizeop(
)
return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...)
end
# This is the primary function that should be overloaded for new fusion styles to fold
# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy
# wrappers like StridedView with op metadata for zero-copy).
# Classifies how `matricize` realizes the bipermutation `(perm_codomain, perm_domain)`
# against storage, so `matricizeop` can skip the redundant permuted copy:
# ReshapeMatricizeKind — the groups are already in storage order, so the permute is a
# no-op and `matricize(style, a, ...)` can be called directly.
# For a dense array that is a `reshape` view; for a graded array
# it still gathers blocks, but skips the extra permute copy.
# TransposeMatricizeKind — the only reordering is a codomain/domain swap, which a dense
# array realizes as a `transpose` of a `reshape` (a view gemm
# reads via BLAS' transpose flag).
# PermuteMatricizeKind — the groups interleave storage, so a permuted copy is required.
# Pure: depends only on the index pattern, not on `a`'s data. Dispatched on `FusionStyle`.
# The generic classifier only recognizes the always-safe `ReshapeMatricizeKind` (skipping a
# no-op permute is valid for any style); `TransposeMatricizeKind` is opt-in for styles whose
# `matricize` composes with a lazy `transpose`, currently only `ReshapeFusion`.
@enum MatricizeKind ReshapeMatricizeKind TransposeMatricizeKind PermuteMatricizeKind

# Whether `perm` is the identity permutation `(1, …, n)`.
isidentityperm(perm::Tuple{Vararg{Int}}) = perm == ntuple(identity, length(perm))

function matricizekind(
::FusionStyle, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
)
# Already in storage order: the permute is a no-op, so `matricize` can run directly.
isidentityperm((perm_codomain..., perm_domain...)) && return ReshapeMatricizeKind
return PermuteMatricizeKind
end

# Skip the permuted copy when the classifier says it is unnecessary. `ReshapeMatricizeKind`
# calls `matricize` directly on `a` (a view for dense, a gather without the extra permute
# for graded); `TransposeMatricizeKind` returns a lazy `transpose` of the reshape. Both
# fast paths require `op === identity`, since a plain view cannot carry a fused `op` like
# `conj`. The result may alias `a` and must be treated as read-only, matching the docstring.
function matricizeop(
style::FusionStyle, op, a::AbstractArray,
perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
)
ndims(a) == length(perm_codomain) + length(perm_domain) ||
throw(ArgumentError("Invalid bipermutation"))
if op === identity
kind = matricizekind(style, perm_codomain, perm_domain)
kind == ReshapeMatricizeKind &&
return matricize(style, a, Val(length(perm_codomain)))
kind == TransposeMatricizeKind &&
return transpose(matricize(style, a, Val(length(perm_domain))))
end
a_perm_op = permutedimsop(op, a, perm_codomain, perm_domain)
return matricize(style, a_perm_op, Val(length(perm_codomain)))
end
Expand Down Expand Up @@ -373,6 +409,16 @@ end
function matricize(style::ReshapeFusion, a::AbstractArray, ndims_codomain::Val)
return reshape(a, matricize_axes(style, a, ndims_codomain))
end
# A dense array additionally realizes a codomain/domain swap as a lazy `transpose` of a
# reshape (a view), so it opts into `TransposeMatricizeKind` on top of the generic
# reshape/permute classification.
function matricizekind(
::ReshapeFusion, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
)
isidentityperm((perm_codomain..., perm_domain...)) && return ReshapeMatricizeKind
isidentityperm((perm_domain..., perm_codomain...)) && return TransposeMatricizeKind
return PermuteMatricizeKind
end
function unmatricize(
style::ReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
Expand Down
89 changes: 89 additions & 0 deletions test/test_matricize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using LinearAlgebra: Transpose
using StableRNGs: StableRNG
using TensorAlgebra: TensorAlgebra, PermuteMatricizeKind, ReshapeFusion,
ReshapeMatricizeKind, TransposeMatricizeKind, matricize, matricizeop
using Test: @test, @testset

# A non-`ReshapeFusion` style, to check the always-safe generic fallback.
struct DummyFusion <: TensorAlgebra.FusionStyle end

# Ground-truth matricization: permute into `(codomain..., domain...)` order, then reshape.
function matricize_ref(a, perm_codomain, perm_domain)
a_perm = permutedims(a, (perm_codomain..., perm_domain...))
nrow = prod(i -> size(a, i), perm_codomain; init = 1)
ncol = prod(i -> size(a, i), perm_domain; init = 1)
return reshape(a_perm, (nrow, ncol))
end

@testset "matricizekind classifier" begin
style = ReshapeFusion()
# Already in storage order → plain reshape view.
@test TensorAlgebra.matricizekind(style, (1,), (2, 3)) == ReshapeMatricizeKind
@test TensorAlgebra.matricizekind(style, (1, 2), (3,)) == ReshapeMatricizeKind
@test TensorAlgebra.matricizekind(style, (1, 2, 3), ()) == ReshapeMatricizeKind
@test TensorAlgebra.matricizekind(style, (), (1, 2, 3)) == ReshapeMatricizeKind
# Pure codomain/domain swap → transpose of a reshape view.
@test TensorAlgebra.matricizekind(style, (2, 3), (1,)) == TransposeMatricizeKind
@test TensorAlgebra.matricizekind(style, (3,), (1, 2)) == TransposeMatricizeKind
# Interleaved → permuted copy.
@test TensorAlgebra.matricizekind(style, (3, 1), (2,)) == PermuteMatricizeKind
@test TensorAlgebra.matricizekind(style, (2,), (1, 3)) == PermuteMatricizeKind
@test TensorAlgebra.matricizekind(style, (1, 3), (2,)) == PermuteMatricizeKind
# Generic fusion styles recognize the always-safe reshape (no-op permute) but never
# claim a transpose (which only styles with a lazy `transpose` can realize).
@test TensorAlgebra.matricizekind(DummyFusion(), (1,), (2, 3)) == ReshapeMatricizeKind
@test TensorAlgebra.matricizekind(DummyFusion(), (1, 2, 3), ()) == ReshapeMatricizeKind
@test TensorAlgebra.matricizekind(DummyFusion(), (2, 3), (1,)) == PermuteMatricizeKind
@test TensorAlgebra.matricizekind(DummyFusion(), (3, 1), (2,)) == PermuteMatricizeKind
end

@testset "maybe-view matricizeop (eltype=$elt)" for elt in (Float64, ComplexF64)
a = randn(StableRNG(123), elt, 2, 3, 4)

# Reshape branch: correct values and a view aliasing `a`.
m = matricize(a, (1,), (2, 3))
@test m ≈ matricize_ref(a, (1,), (2, 3))
@test Base.mightalias(m, a)

# Transpose branch: correct values and a transpose view aliasing `a`.
m = matricize(a, (2, 3), (1,))
@test m ≈ matricize_ref(a, (2, 3), (1,))
@test m isa Transpose
@test Base.mightalias(m, a)

# Permute branch: correct values, but a fresh copy (no aliasing).
m = matricize(a, (3, 1), (2,))
@test m ≈ matricize_ref(a, (3, 1), (2,))
@test !Base.mightalias(m, a)

# `conj` cannot ride a view, so it copies even on the reshape/transpose patterns.
m = matricizeop(conj, a, (1,), (2, 3))
@test m ≈ conj.(matricize_ref(a, (1,), (2, 3)))
@test !Base.mightalias(m, a)
m = matricizeop(conj, a, (2, 3), (1,))
@test m ≈ conj.(matricize_ref(a, (2, 3), (1,)))
@test !Base.mightalias(m, a)
end

@testset "view branches track source mutations, copy branch does not" begin
rng = StableRNG(7)

# Reshape view tracks an in-place update of `a`.
a = randn(rng, 2, 3, 4)
m = matricize(a, (1,), (2, 3))
a .= randn(rng, 2, 3, 4)
@test m ≈ matricize_ref(a, (1,), (2, 3))

# Transpose view tracks an in-place update of `a`.
a = randn(rng, 2, 3, 4)
m = matricize(a, (2, 3), (1,))
a .= randn(rng, 2, 3, 4)
@test m ≈ matricize_ref(a, (2, 3), (1,))

# Permute copy is independent of later updates to `a`.
a = randn(rng, 2, 3, 4)
m = matricize(a, (3, 1), (2,))
snapshot = copy(m)
a .= a .+ 1
@test m == snapshot
end
Loading