diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl index 7f448f9e3..c7bde1e2b 100644 --- a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -12,5 +12,6 @@ using Random: AbstractRNG include("utility.jl") include("linalg.jl") +include("indexmanipulations.jl") end diff --git a/ext/TensorKitEnzymeExt/indexmanipulations.jl b/ext/TensorKitEnzymeExt/indexmanipulations.jl new file mode 100644 index 000000000..6426fa84b --- /dev/null +++ b/ext/TensorKitEnzymeExt/indexmanipulations.jl @@ -0,0 +1,266 @@ +for transform in (:permute, :transpose) + transform! = Symbol(transform, :!) + transform_pb = Symbol(transform, :_pullback_dA) + @eval function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.$transform!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + # if we need to compute Δa, it is faster to allocate an intermediate permuted A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if !isa(α, Const) + Ap = $transform(A.val, p.val) + add!(C.val, Ap, α.val, β.val) + Ap + else + bavs = map(a -> a.val, ba) + TK.$transform!(C.val, A.val, p.val, α.val, β.val, bavs...) + nothing + end + cache = (C_cache, A_cache, Ap) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + @eval function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.$transform!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache, A_cache, Ap = cache + Cval = something(C_cache, C.val) + bavs = map(a -> a.val, ba) + # ΔA + if !isa(A, Const) && !isa(C, Const) + Aval = something(A_cache, A.val) + TK.$transform_pb(A.dval, Aval, C.dval, C.val, p.val, α.val, bavs...) + end + Δα = pullback_dα(α, C, Ap) + Δβ = pullback_dβ(β, C, Cval) + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, Δα, Δβ, map(Returns(nothing), ba)... + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.braid!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + levels::Const{<:IndexTuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache = !isa(β, Const) ? deepcopy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? deepcopy(A.val) : nothing + # if we need to compute Δa, it is faster to allocate an intermediate braided A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if !isa(α, Const) + Ap = braid(A.val, p.val, levels.val) + add!(C.val, Ap, α.val, β.val) + Ap + else + bavs = map(a -> a.val, ba) + TK.braid!(C.val, A.val, p.val, levels.val, α.val, β.val, bavs...) + nothing + end + cache = (C_cache, A_cache, Ap) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.braid!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + levels::Const{<:IndexTuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache, A_cache, Ap = cache + Cval = something(C_cache, C.val) + Aval = something(A_cache, A.val) + bavs = map(a -> a.val, ba) + # ΔA + if !isa(A, Const) && !isa(C, Const) + TK.braid_pb(A.dval, Aval, C.dval, C.val, p.val, levels.val, α.val, bavs...) + end + Δαr = pullback_dα(α, C, Ap) + Δβr = pullback_dβ(β, C, Cval) + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(twist!)}, + ::Type{RT}, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + twist!(t.val, inds.val; inv) + primal = EnzymeRules.needs_primal(config) ? t.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? t.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(twist!)}, + ::Type{RT}, + cache, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + !isa(t, Const) && twist!(t.dval, inds.val; inv = !inv) + return (nothing, nothing) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(flip)}, + ::Type{RT}, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + t′ = flip(t.val, inds.val; inv) + dt′ = make_zero(t′) + cache = dt′ + primal = EnzymeRules.needs_primal(config) ? t′ : nothing + shadow = EnzymeRules.needs_shadow(config) ? dt′ : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(flip)}, + ::Type{RT}, + cache, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false, + ) where {RT} + dt′ = cache + if !isa(t, Const) + dt′′ = flip(dt′, inds.val; inv = !inv) + add!(t.dval, scalartype(t.dval) <: Real ? real(dt′′) : dt′′) + end + return (nothing, nothing) +end + +for insertunit in (:insertleftunit, :insertrightunit) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($insertunit)}, + ::Type{RT}, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + if tsrc.val isa TensorMap && !get(kwargs, :copy, false) && !isa(tsrc, Const) + tsrc_cache = copy(tsrc.val) + tdst = $insertunit(tsrc.val, ival.val; kwargs...) + Δtdst = $insertunit(tsrc.dval, ival.val; kwargs...) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc.val, ival.val; kwargs...) + Δtdst = make_zero(tdst) + end + primal = EnzymeRules.needs_primal(config) ? tdst : nothing + shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing + cache = (tsrc_cache, tdst, Δtdst) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($insertunit)}, + ::Type{RT}, + cache, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + tsrc_cache, tdst, Δtdst = cache + # note: since data is already shared for <:TensorMap, don't have to do anything here! + if isnothing(tsrc_cache) && !isa(tsrc, Const) + for (c, b) in blocks(Δtdst) + add!(block(tsrc.dval, c), b) + end + end + return (nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(removeunit)}, + ::Type{RT}, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc.val isa TensorMap && !get(kwargs, :copy, false) && !isa(tsrc, Const) + tsrc_cache = copy(tsrc.val) + tdst = removeunit(tsrc.val, ival.val; kwargs...) + Δtdst = removeunit(tsrc.dval, ival.val) + else + tsrc_cache = nothing + tdst = removeunit(tsrc.val, ival.val; kwargs...) + Δtdst = make_zero(tdst) + end + primal = EnzymeRules.needs_primal(config) ? tdst : nothing + shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing + cache = (tsrc_cache, tdst, Δtdst) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(removeunit)}, + ::Type{RT}, + cache, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + tsrc_cache, tdst, Δtdst = cache + # note: since data for <: TensorMap is already shared, don't have to do anything here! + if isnothing(tsrc_cache) && !isa(tsrc, Const) + for (c, b) in blocks(Δtdst) + add!(block(tsrc.dval, c), b) + end + end + return (nothing, nothing) +end diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index d33e30fca..bedc0af72 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -1,6 +1,7 @@ for transform in (:permute, :transpose) transform! = Symbol(transform, :!) transform_pullback = Symbol(transform!, :_pullback) + transform_pb = Symbol(transform, :_pullback_dA) @eval @is_primitive( DefaultCtx, ReverseMode, @@ -44,17 +45,7 @@ for transform in (:permute, :transpose) copy!(C, C_cache) # ΔA - ip = invperm(linearize(p)) - pΔA = TO.repartition(ip, numout(A)) - - TC = VectorInterface.promote_scale(ΔC, α) - if scalartype(ΔA) <: Real && !(TC <: Real) - ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) - TK.$transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) - add!(ΔA, real(ΔAc)) - else - TK.$transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) - end + TK.$transform_pb(ΔA, A, ΔC, C, p, α, ba...) ΔAr = NoRData() Δαr = pullback_dα(α, ΔC, Ap) @@ -111,18 +102,7 @@ function Mooncake.rrule!!( function braid!_pullback(::NoRData) copy!(C, C_cache) - # ΔA - ip = invperm(linearize(p)) - pΔA = TO.repartition(ip, numout(A)) - ilevels = TupleTools.permute(levels, linearize(p)) - TC = VectorInterface.promote_scale(ΔC, α) - if scalartype(ΔA) <: Real && !(TC <: Real) - ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) - TK.braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) - add!(ΔA, real(ΔAc)) - else - TK.braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) - end + TK.braid_pb(ΔA, A, ΔC, C, p, levels, α, ba...) ΔAr = NoRData() Δαr = pullback_dα(α, ΔC, Ap) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 0622d6e6a..4d197c901 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -285,4 +285,6 @@ include("planar/planaroperations.jl") # ------------------------ include("auxiliary/ad.jl") +include("pullbacks/indexmanipulations.jl") + end diff --git a/src/pullbacks/indexmanipulations.jl b/src/pullbacks/indexmanipulations.jl new file mode 100644 index 000000000..809cad7a6 --- /dev/null +++ b/src/pullbacks/indexmanipulations.jl @@ -0,0 +1,31 @@ +for transform in (:permute, :transpose) + transform! = Symbol(transform, :!) + transform_pb = Symbol(transform, :_pullback_dA) + @eval function $transform_pb(ΔA, A, ΔC, C, p, α, ba...) + ip = invperm(linearize(p)) + pΔA = TO.repartition(ip, numout(A)) + TC = VectorInterface.promote_scale(C, α) + if scalartype(ΔA) <: Real && !(TC <: Real) + ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) + $transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) + add!(ΔA, real(ΔAc)) + else + $transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + end + return + end +end +function braid_pb(ΔA, A, ΔC, C, p, levels, α, ba...) + ip = invperm(linearize(p)) + pΔA = TO.repartition(ip, numout(A)) + ilevels = TupleTools.permute(levels, linearize(p)) + TC = VectorInterface.promote_scale(ΔC, α) + if scalartype(ΔA) <: Real && !(TC <: Real) + ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) + braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) + add!(ΔA, real(ΔAc)) + else + braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + end + return +end diff --git a/test/enzyme-indexmanipulations-flip-twist/flip.jl b/test/enzyme-indexmanipulations-flip-twist/flip.jl new file mode 100644 index 000000000..8afe4e914 --- /dev/null +++ b/test/enzyme-indexmanipulations-flip-twist/flip.jl @@ -0,0 +1,19 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (flip):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T) TA ($TA)" for V in spacelist, T in eltypes, TA in (Duplicated,) + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,)) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), [1, 3]; atol, rtol, fkwargs = (inv = true,)) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), ([1, 3], Const); atol, rtol) + end +end diff --git a/test/enzyme-indexmanipulations-flip-twist/twist.jl b/test/enzyme-indexmanipulations-flip-twist/twist.jl new file mode 100644 index 000000000..c939c76d2 --- /dev/null +++ b/test/enzyme-indexmanipulations-flip-twist/twist.jl @@ -0,0 +1,23 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (twist):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TA in (Duplicated,) + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real)) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,)) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,)) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol) + end + end +end diff --git a/test/enzyme-indexmanipulations-transform/braid.jl b/test/enzyme-indexmanipulations-transform/braid.jl new file mode 100644 index 000000000..28d1e52ba --- /dev/null +++ b/test/enzyme-indexmanipulations-transform/braid.jl @@ -0,0 +1,32 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +Tαs = is_ci ? (Active,) : (Active, Const) +Tβs = is_ci ? (Active,) : (Active, Const) + +@timedtestset "Enzyme - Index Manipulations (braid!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T) Tα $Tα Tβ $Tβ" for V in spacelist, T in eltypes, Tα in Tαs, Tβ in Tβs + atol = default_tol(T) + rtol = default_tol(T) + Vstr = TensorKit.type_repr(sectortype(eltype(V))) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + p = randcircshift(numout(A), numin(A)) + levels = Tuple(randperm(numind(A))) + C = randn!(transpose(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "braid! V $Vstr Tα $Tα Tβ $Tβ") + if !(T <: Real) && !is_ci + EnzymeTestUtils.test_reverse(TensorKit.braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "braid! V $Vstr Tα $Tα Tβ $Tβ") + EnzymeTestUtils.test_reverse(TensorKit.braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "braid! V $Vstr Tα $Tα Tβ $Tβ") + EnzymeTestUtils.test_reverse(TensorKit.braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "braid! V $Vstr Tα $Tα Tβ $Tβ") + end + end +end diff --git a/test/enzyme-indexmanipulations-transform/permute.jl b/test/enzyme-indexmanipulations-transform/permute.jl new file mode 100644 index 000000000..a7f6126d6 --- /dev/null +++ b/test/enzyme-indexmanipulations-transform/permute.jl @@ -0,0 +1,29 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +Tαs = is_ci ? (Active,) : (Active, Const) +Tβs = is_ci ? (Active,) : (Active, Const) + +@timedtestset "Enzyme - Index Manipulations (permute!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "permute! Tα $Tα, Tβ $Tβ" for Tα in Tαs, Tβ in Tβs + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.permute!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + end + end +end diff --git a/test/enzyme-indexmanipulations-transform/transpose.jl b/test/enzyme-indexmanipulations-transform/transpose.jl new file mode 100644 index 000000000..bff2cb379 --- /dev/null +++ b/test/enzyme-indexmanipulations-transform/transpose.jl @@ -0,0 +1,37 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" + +Tαs = is_ci ? (Active,) : (Active, Const) +Tβs = is_ci ? (Active,) : (Active, Const) + +@timedtestset "Enzyme - Index Manipulations (transpose!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol) + @testset for Tα in Tαs, Tβ in Tβs + EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + if !(T <: Real) && !is_ci + EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + end + end + end +end diff --git a/test/enzyme-indexmanipulations-unit/insertunit.jl b/test/enzyme-indexmanipulations-unit/insertunit.jl new file mode 100644 index 000000000..15e02d22c --- /dev/null +++ b/test/enzyme-indexmanipulations-unit/insertunit.jl @@ -0,0 +1,24 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - Index Manipulations (insertunit):" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TA in (Duplicated,) + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + @testset for insertunit in (insertleftunit, insertrightunit) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(4), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(2), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol, fkwargs = (copy = false,)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(2), Const); atol, rtol, fkwargs = (copy = true,)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true)) + end + end +end diff --git a/test/enzyme-indexmanipulations-unit/removeunit.jl b/test/enzyme-indexmanipulations-unit/removeunit.jl new file mode 100644 index 000000000..2249483be --- /dev/null +++ b/test/enzyme-indexmanipulations-unit/removeunit.jl @@ -0,0 +1,20 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - Index Manipulations (removeunit):" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TB in (Duplicated,) + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + for i in 1:2 + B = insertleftunit(A, i; dual = rand(Bool)) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = false,)) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = true,)) + end + end +end diff --git a/test/setup.jl b/test/setup.jl index 9c8244dab..42125506e 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -393,7 +393,8 @@ function factorization_spacelist(fast_tests::Bool) end function ad_spacelist(fast_tests::Bool) - return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4) + #return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4) + return (Vtr,) end