From 4f9a8050f9a0bf16b02f98cd4ed7fb47c170d790 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 25 Jun 2026 21:40:48 +0200 Subject: [PATCH 1/5] Add Enzyme rules for factorizations --- ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl | 1 + ext/TensorKitEnzymeExt/factorizations.jl | 122 ++++++++++++++++++ src/factorizations/pullbacks.jl | 36 ++++++ test/enzyme-factorizations-eig/eig.jl | 31 +++++ test/enzyme-factorizations-eig/eigh.jl | 33 +++++ test/enzyme-factorizations-eig/projections.jl | 18 +++ test/enzyme-factorizations-lqqr/lq.jl | 28 ++++ test/enzyme-factorizations-lqqr/qr.jl | 30 +++++ test/enzyme-factorizations-svd/svd.jl | 36 ++++++ 9 files changed, 335 insertions(+) create mode 100644 ext/TensorKitEnzymeExt/factorizations.jl create mode 100644 test/enzyme-factorizations-eig/eig.jl create mode 100644 test/enzyme-factorizations-eig/eigh.jl create mode 100644 test/enzyme-factorizations-eig/projections.jl create mode 100644 test/enzyme-factorizations-lqqr/lq.jl create mode 100644 test/enzyme-factorizations-lqqr/qr.jl create mode 100644 test/enzyme-factorizations-svd/svd.jl diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl index 7f448f9e3..8d9c8d5e0 100644 --- a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -12,5 +12,6 @@ using Random: AbstractRNG include("utility.jl") include("linalg.jl") +include("factorizations.jl") end diff --git a/ext/TensorKitEnzymeExt/factorizations.jl b/ext/TensorKitEnzymeExt/factorizations.jl new file mode 100644 index 000000000..afe69c119 --- /dev/null +++ b/ext/TensorKitEnzymeExt/factorizations.jl @@ -0,0 +1,122 @@ +# need these due to Enzyme choking on blocks + +for f in (:project_hermitian, :project_antihermitian) + f! = Symbol(f, :!) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + arg::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + $f!(A.val, arg.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + cache = nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + arg::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + if !isa(A, Const) && !isa(arg, Const) + $f!(arg.dval, arg.dval, alg.val) + if A.dval !== arg.dval + A.dval .+= arg.dval + make_zero!(arg.dval) + end + end + return (nothing, nothing, nothing) + end + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret = $f(A.val, alg.val) + dret = make_zero(ret) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + cache = dret + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + dret = cache + if !isa(A, Const) + $f!(dret, dret, alg.val) + add!(A.dval, dret) + end + make_zero!(dret) + return (nothing, nothing) + end + end +end + +# Enzyme seems to have trouble with this one +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_compact)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = svd_compact(A.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing + shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing + cache = (USVᴴ, shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_compact)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + !isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...) + return (nothing, nothing) +end + +# Enzyme seems to have trouble with this one +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = svd_compact(A.val, alg.val.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc) + dUSVᴴtrunc = make_zero(USVᴴtrunc) + cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind) + return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind) + return (nothing, nothing) +end diff --git a/src/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl index b910a184c..e0a044ebb 100644 --- a/src/factorizations/pullbacks.jl +++ b/src/factorizations/pullbacks.jl @@ -11,6 +11,16 @@ for pullback! in ( end return Δt end + @eval function MAK.$pullback!( + Δt::AbstractTensorMap, ::Nothing, F, ΔF; kwargs... + ) + foreachblock(Δt) do c, (Δb,) + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + return MAK.$pullback!(Δb, nothing, Fc, ΔFc; kwargs...) + end + return Δt + end end for pullback! in (:qr_null_pullback!, :lq_null_pullback!) @eval function MAK.$pullback!( @@ -41,6 +51,28 @@ for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!) end return Δt end + @eval function MAK.$pullback!( + Δt::AbstractTensorMap, ::Nothing, F, ΔF, inds; kwargs... + ) + foreachblock(Δt) do c, (Δb,) + haskey(inds, c) || return nothing + ind = inds[c] + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + return MAK.$pullback!(Δb, nothing, Fc, ΔFc, ind; kwargs...) + end + return Δt + end + @eval function MAK.$pullback!( + Δt::AbstractTensorMap, t::AbstractTensorMap, F, ΔF, ::Colon; kwargs... + ) + return MAK.$pullback!(Δt, t, F, ΔF, _notrunc_ind(t); kwargs...) + end + @eval function MAK.$pullback!( + Δt, ::Nothing, F, ΔF; kwargs... + ) + return MAK.$pullback!(Δt, nothing, F, ΔF, _notrunc_ind(Δt); kwargs...) + end end for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!) @@ -97,3 +129,7 @@ function MAK.remove_svd_gauge_dependence!( end return ΔU, ΔVᴴ end + +MAK.has_equal_storage(A::AbstractTensorMap, B::AbstractTensorMap) = A === B +MAK.has_equal_storage(A::AbstractTensorMap, B::SectorVector) = false +MAK.has_equal_storage(A::SectorVector, B::AbstractTensorMap) = false diff --git a/test/enzyme-factorizations-eig/eig.jl b/test/enzyme-factorizations-eig/eig.jl new file mode 100644 index 000000000..812017a64 --- /dev/null +++ b/test/enzyme-factorizations-eig/eig.jl @@ -0,0 +1,31 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_eig_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (EIG): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + DV = eig_full(t) + ΔDV = EnzymeTestUtils.rand_tangent(DV) + remove_eig_gauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) + + #D = eig_vals(t) + #EnzymeTestUtils.test_reverse(eig_vals, Duplicated, (t, Duplicated); atol, rtol) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(eig_trunc_no_error, t, nothing; trunc) + DVtrunc = eig_trunc_no_error(t, alg) + ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc) + remove_eig_gauge_dependence!(ΔDVtrunc[2], DVtrunc...) + EnzymeTestUtils.test_reverse(eig_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol) +end diff --git a/test/enzyme-factorizations-eig/eigh.jl b/test/enzyme-factorizations-eig/eigh.jl new file mode 100644 index 000000000..970e2bfef --- /dev/null +++ b/test/enzyme-factorizations-eig/eigh.jl @@ -0,0 +1,33 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_eigh_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float32, ComplexF64) + +@timedtestset "Enzyme - Factorizations (EIGH): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + th = project_hermitian(t) + DV = eigh_full(th) + ΔDV = EnzymeTestUtils.rand_tangent(DV) + remove_eigh_gauge_dependence!(ΔDV[2], DV...) + proj_eigh_full(t) = eigh_full(project_hermitian(t)) + EnzymeTestUtils.test_reverse(proj_eigh_full, Duplicated, (th, Duplicated); output_tangent = ΔDV, atol, rtol) + + #D = eigh_vals(th) + #EnzymeTestUtils.test_reverse(eigh_vals ∘ project_hermitian, Duplicated, (th, Duplicated); atol, rtol) + + V_trunc = spacetype(th)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(eigh_trunc_no_error, th, nothing; trunc) + DVtrunc = eigh_trunc_no_error(th, alg) + ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc) + remove_eigh_gauge_dependence!(ΔDVtrunc[2], DVtrunc...) + proj_eigh(t, alg) = eigh_trunc_no_error(project_hermitian(t), alg) + EnzymeTestUtils.test_reverse(proj_eigh, Duplicated, (th, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol) +end diff --git a/test/enzyme-factorizations-eig/projections.jl b/test/enzyme-factorizations-eig/projections.jl new file mode 100644 index 000000000..1b5b3a28f --- /dev/null +++ b/test/enzyme-factorizations-eig/projections.jl @@ -0,0 +1,18 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (PROJECTIONS): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + EnzymeTestUtils.test_reverse(project_hermitian, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_antihermitian, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_hermitian!, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_antihermitian!, Duplicated, (t, Duplicated); atol, rtol) +end diff --git a/test/enzyme-factorizations-lqqr/lq.jl b/test/enzyme-factorizations-lqqr/lq.jl new file mode 100644 index 000000000..d081a0ec9 --- /dev/null +++ b/test/enzyme-factorizations-lqqr/lq.jl @@ -0,0 +1,28 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_lq_gauge_dependence!, remove_lq_null_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (LQ): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = EnzymeTestUtils.rand_tangent(LQ) + remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + + Nᴴ = lq_null(A) + Q = lq_compact(A)[2] + ΔNᴴ = EnzymeTestUtils.rand_tangent(Nᴴ) + remove_lq_null_gauge_dependence!(ΔNᴴ, Q, Nᴴ) + EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); output_tangent = ΔNᴴ, atol, rtol) +end diff --git a/test/enzyme-factorizations-lqqr/qr.jl b/test/enzyme-factorizations-lqqr/qr.jl new file mode 100644 index 000000000..d57a2baaa --- /dev/null +++ b/test/enzyme-factorizations-lqqr/qr.jl @@ -0,0 +1,30 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_qr_null_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (QR): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = EnzymeTestUtils.rand_tangent(QR) + remove_qr_gauge_dependence!(ΔQR..., A, QR...) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + + N = qr_null(A) + Q = qr_compact(A)[1] + ΔN = EnzymeTestUtils.rand_tangent(N) + remove_qr_null_gauge_dependence!(ΔN, A, N) + EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol, output_tangent = ΔN) +end diff --git a/test/enzyme-factorizations-svd/svd.jl b/test/enzyme-factorizations-svd/svd.jl new file mode 100644 index 000000000..c51d75847 --- /dev/null +++ b/test/enzyme-factorizations-svd/svd.jl @@ -0,0 +1,36 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_svd_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64,) # ComplexF64) + +@timedtestset "Enzyme - Factorizations (SVD): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + + #S = svd_vals(t) + #EnzymeTestUtils.test_reverse(svd_vals, Duplicated, (t, Duplicated); atol, rtol) + + USVᴴ = svd_compact(t) + ΔUSVᴴ = EnzymeTestUtils.rand_tangent(USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + USVᴴ = svd_full(t) + ΔUSVᴴ = EnzymeTestUtils.rand_tangent(USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc) + USVᴴtrunc = svd_trunc_no_error(t, alg) + ΔUSVᴴtrunc = EnzymeTestUtils.rand_tangent(USVᴴtrunc) + remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...) + EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol) +end From c91bb806ac96825ed770e63f8fad3dd536a6ffd5 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 26 Jun 2026 02:54:11 +0200 Subject: [PATCH 2/5] Don't use unreleased MAK method have_equal_storage --- src/factorizations/pullbacks.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl index e0a044ebb..bd70420cd 100644 --- a/src/factorizations/pullbacks.jl +++ b/src/factorizations/pullbacks.jl @@ -130,6 +130,8 @@ function MAK.remove_svd_gauge_dependence!( return ΔU, ΔVᴴ end +#= # uncomment at next MAK release! MAK.has_equal_storage(A::AbstractTensorMap, B::AbstractTensorMap) = A === B MAK.has_equal_storage(A::AbstractTensorMap, B::SectorVector) = false MAK.has_equal_storage(A::SectorVector, B::AbstractTensorMap) = false +=# From b05080cc524aa8ca7afb45056216e14efa730847 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 26 Jun 2026 16:39:10 +0200 Subject: [PATCH 3/5] Try using latest MAK --- Project.toml | 3 +++ src/factorizations/pullbacks.jl | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 10d69842f..c6b2eb9bd 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,9 @@ TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitGPUArraysExt = "GPUArrays" TensorKitMooncakeExt = "Mooncake" +[sources] +MatrixAlgebraKit = {url = "https://github.com/quantumkithub/matrixalgebrakit.jl", rev = "main"} + [compat] AMDGPU = "2" Adapt = "4" diff --git a/src/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl index bd70420cd..e0a044ebb 100644 --- a/src/factorizations/pullbacks.jl +++ b/src/factorizations/pullbacks.jl @@ -130,8 +130,6 @@ function MAK.remove_svd_gauge_dependence!( return ΔU, ΔVᴴ end -#= # uncomment at next MAK release! MAK.has_equal_storage(A::AbstractTensorMap, B::AbstractTensorMap) = A === B MAK.has_equal_storage(A::AbstractTensorMap, B::SectorVector) = false MAK.has_equal_storage(A::SectorVector, B::AbstractTensorMap) = false -=# From ba1dea1f92dfa3982b57ceb324de046cbc03467d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 26 Jun 2026 20:15:45 +0200 Subject: [PATCH 4/5] Consistent eltypes --- test/enzyme-factorizations-eig/eigh.jl | 2 +- test/enzyme-factorizations-svd/svd.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/enzyme-factorizations-eig/eigh.jl b/test/enzyme-factorizations-eig/eigh.jl index 970e2bfef..d465123cf 100644 --- a/test/enzyme-factorizations-eig/eigh.jl +++ b/test/enzyme-factorizations-eig/eigh.jl @@ -7,7 +7,7 @@ using Enzyme, EnzymeTestUtils using Random spacelist = ad_spacelist(fast_tests) -eltypes = (Float32, ComplexF64) +eltypes = (Float64, ComplexF64) @timedtestset "Enzyme - Factorizations (EIGH): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) atol = default_tol(T) diff --git a/test/enzyme-factorizations-svd/svd.jl b/test/enzyme-factorizations-svd/svd.jl index c51d75847..78320a532 100644 --- a/test/enzyme-factorizations-svd/svd.jl +++ b/test/enzyme-factorizations-svd/svd.jl @@ -7,7 +7,7 @@ using Enzyme, EnzymeTestUtils using Random spacelist = ad_spacelist(fast_tests) -eltypes = (Float64,) # ComplexF64) +eltypes = (Float64, ComplexF64) @timedtestset "Enzyme - Factorizations (SVD): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) atol = default_tol(T) From dd8652f42c847f395035ad4262ec7256e708e1dd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 26 Jun 2026 22:16:09 +0200 Subject: [PATCH 5/5] Mark infimum inactive --- ext/TensorKitEnzymeExt/utility.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl index 2d6366c90..1461aaf62 100644 --- a/ext/TensorKitEnzymeExt/utility.jl +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -19,6 +19,7 @@ pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) @inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true @inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true +@inline EnzymeRules.inactive(::typeof(TensorKit.infimum), ::Any, ::Any) = nothing @inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing @inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing @inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing