Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitGPUArraysExt = "GPUArrays"
TensorKitMooncakeExt = "Mooncake"

[sources]
MatrixAlgebraKit = {url = "https://github.com/quantumkithub/matrixalgebrakit.jl", rev = "main"}

[compat]
AMDGPU = "2"
Adapt = "4"
Expand Down
1 change: 1 addition & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")
include("factorizations.jl")

end
122 changes: 122 additions & 0 deletions ext/TensorKitEnzymeExt/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# need these due to Enzyme choking on blocks

for f in (:project_hermitian, :project_antihermitian)
f! = Symbol(f, :!)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
arg::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
$f!(A.val, arg.val, alg.val)
primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
cache = nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
arg::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
if !isa(A, Const) && !isa(arg, Const)
$f!(arg.dval, arg.dval, alg.val)
if A.dval !== arg.dval
A.dval .+= arg.dval
make_zero!(arg.dval)
end
end
return (nothing, nothing, nothing)
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
ret = $f(A.val, alg.val)
dret = make_zero(ret)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
cache = dret
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
dret = cache
if !isa(A, Const)
$f!(dret, dret, alg.val)
add!(A.dval, dret)
end
make_zero!(dret)
return (nothing, nothing)
end
end
end

# Enzyme seems to have trouble with this one
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_compact)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ = svd_compact(A.val, alg.val)
primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing
shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing
cache = (USVᴴ, shadow)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_compact)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
!isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...)
return (nothing, nothing)
end

# Enzyme seems to have trouble with this one
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc_no_error)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ = svd_compact(A.val, alg.val.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
dUSVᴴtrunc = make_zero(USVᴴtrunc)
cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind)
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc_no_error)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind)
return (nothing, nothing)
end
1 change: 1 addition & 0 deletions ext/TensorKitEnzymeExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β))
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true

@inline EnzymeRules.inactive(::typeof(TensorKit.infimum), ::Any, ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing
Expand Down
36 changes: 36 additions & 0 deletions src/factorizations/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ for pullback! in (
end
return Δt
end
@eval function MAK.$pullback!(
Δt::AbstractTensorMap, ::Nothing, F, ΔF; kwargs...
)
foreachblock(Δt) do c, (Δb,)
Fc = block.(F, Ref(c))
ΔFc = block.(ΔF, Ref(c))
return MAK.$pullback!(Δb, nothing, Fc, ΔFc; kwargs...)
end
return Δt
end
end
for pullback! in (:qr_null_pullback!, :lq_null_pullback!)
@eval function MAK.$pullback!(
Expand Down Expand Up @@ -41,6 +51,28 @@ for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
end
return Δt
end
@eval function MAK.$pullback!(
Δt::AbstractTensorMap, ::Nothing, F, ΔF, inds; kwargs...
)
foreachblock(Δt) do c, (Δb,)
haskey(inds, c) || return nothing
ind = inds[c]
Fc = block.(F, Ref(c))
ΔFc = block.(ΔF, Ref(c))
return MAK.$pullback!(Δb, nothing, Fc, ΔFc, ind; kwargs...)
end
return Δt
end
@eval function MAK.$pullback!(
Δt::AbstractTensorMap, t::AbstractTensorMap, F, ΔF, ::Colon; kwargs...
)
return MAK.$pullback!(Δt, t, F, ΔF, _notrunc_ind(t); kwargs...)
end
@eval function MAK.$pullback!(
Δt, ::Nothing, F, ΔF; kwargs...
)
return MAK.$pullback!(Δt, nothing, F, ΔF, _notrunc_ind(Δt); kwargs...)
end
end

for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!)
Expand Down Expand Up @@ -97,3 +129,7 @@ function MAK.remove_svd_gauge_dependence!(
end
return ΔU, ΔVᴴ
end

MAK.has_equal_storage(A::AbstractTensorMap, B::AbstractTensorMap) = A === B
MAK.has_equal_storage(A::AbstractTensorMap, B::SectorVector) = false
MAK.has_equal_storage(A::SectorVector, B::AbstractTensorMap) = false
31 changes: 31 additions & 0 deletions test/enzyme-factorizations-eig/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_eig_gauge_dependence!
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (EIG): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]))
atol = default_tol(T)
rtol = default_tol(T)
DV = eig_full(t)
ΔDV = EnzymeTestUtils.rand_tangent(DV)
remove_eig_gauge_dependence!(ΔDV[2], DV...)
EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol)

#D = eig_vals(t)
#EnzymeTestUtils.test_reverse(eig_vals, Duplicated, (t, Duplicated); atol, rtol)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(eig_trunc_no_error, t, nothing; trunc)
DVtrunc = eig_trunc_no_error(t, alg)
ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc)
remove_eig_gauge_dependence!(ΔDVtrunc[2], DVtrunc...)
EnzymeTestUtils.test_reverse(eig_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol)
end
33 changes: 33 additions & 0 deletions test/enzyme-factorizations-eig/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_eigh_gauge_dependence!
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (EIGH): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]))
atol = default_tol(T)
rtol = default_tol(T)
th = project_hermitian(t)
DV = eigh_full(th)
ΔDV = EnzymeTestUtils.rand_tangent(DV)
remove_eigh_gauge_dependence!(ΔDV[2], DV...)
proj_eigh_full(t) = eigh_full(project_hermitian(t))
EnzymeTestUtils.test_reverse(proj_eigh_full, Duplicated, (th, Duplicated); output_tangent = ΔDV, atol, rtol)

#D = eigh_vals(th)
#EnzymeTestUtils.test_reverse(eigh_vals ∘ project_hermitian, Duplicated, (th, Duplicated); atol, rtol)

V_trunc = spacetype(th)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(eigh_trunc_no_error, th, nothing; trunc)
DVtrunc = eigh_trunc_no_error(th, alg)
ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc)
remove_eigh_gauge_dependence!(ΔDVtrunc[2], DVtrunc...)
proj_eigh(t, alg) = eigh_trunc_no_error(project_hermitian(t), alg)
EnzymeTestUtils.test_reverse(proj_eigh, Duplicated, (th, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol)
end
18 changes: 18 additions & 0 deletions test/enzyme-factorizations-eig/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using MatrixAlgebraKit
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (PROJECTIONS): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]))
atol = default_tol(T)
rtol = default_tol(T)
EnzymeTestUtils.test_reverse(project_hermitian, Duplicated, (t, Duplicated); atol, rtol)
EnzymeTestUtils.test_reverse(project_antihermitian, Duplicated, (t, Duplicated); atol, rtol)
EnzymeTestUtils.test_reverse(project_hermitian!, Duplicated, (t, Duplicated); atol, rtol)
EnzymeTestUtils.test_reverse(project_antihermitian!, Duplicated, (t, Duplicated); atol, rtol)
end
28 changes: 28 additions & 0 deletions test/enzyme-factorizations-lqqr/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_lq_gauge_dependence!, remove_lq_null_gauge_dependence!
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (LQ): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])'))
atol = default_tol(T)
rtol = default_tol(T)
EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol)

# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = EnzymeTestUtils.rand_tangent(LQ)
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol)

Nᴴ = lq_null(A)
Q = lq_compact(A)[2]
ΔNᴴ = EnzymeTestUtils.rand_tangent(Nᴴ)
remove_lq_null_gauge_dependence!(ΔNᴴ, Q, Nᴴ)
EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); output_tangent = ΔNᴴ, atol, rtol)
end
30 changes: 30 additions & 0 deletions test/enzyme-factorizations-lqqr/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_qr_null_gauge_dependence!
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (QR): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])'))
atol = default_tol(T)
rtol = default_tol(T)

EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol)

# qr_full/qr_null requires being careful with gauges
QR = qr_full(A)
ΔQR = EnzymeTestUtils.rand_tangent(QR)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol)

N = qr_null(A)
Q = qr_compact(A)[1]
ΔN = EnzymeTestUtils.rand_tangent(N)
remove_qr_null_gauge_dependence!(ΔN, A, N)
EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol, output_tangent = ΔN)
end
36 changes: 36 additions & 0 deletions test/enzyme-factorizations-svd/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_svd_gauge_dependence!
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Factorizations (SVD): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])'))
atol = default_tol(T)
rtol = default_tol(T)

#S = svd_vals(t)
#EnzymeTestUtils.test_reverse(svd_vals, Duplicated, (t, Duplicated); atol, rtol)

USVᴴ = svd_compact(t)
ΔUSVᴴ = EnzymeTestUtils.rand_tangent(USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol)

USVᴴ = svd_full(t)
ΔUSVᴴ = EnzymeTestUtils.rand_tangent(USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc)
USVᴴtrunc = svd_trunc_no_error(t, alg)
ΔUSVᴴtrunc = EnzymeTestUtils.rand_tangent(USVᴴtrunc)
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...)
EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol)
end
Loading