Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")
include("indexmanipulations.jl")

end
266 changes: 266 additions & 0 deletions ext/TensorKitEnzymeExt/indexmanipulations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# Reverse-mode Enzyme rules for the in-place `TK.permute!` and `TK.transpose!`.
# Both operations share one implementation, generated per symbol with `@eval`.
for transform in (:permute, :transpose)
    transform! = Symbol(transform, :!)
    transform_pb = Symbol(transform, :_pullback_dA)

    # Forward pass: performs the primal update of `C` and records the data that the
    # reverse pass needs as `(C_cache, A_cache, Ap)`.
    @eval function EnzymeRules.augmented_primal(
            config::EnzymeRules.RevConfigWidth{1},
            func::Const{typeof(TK.$transform!)},
            ::Type{RT},
            C::Annotation{<:AbstractTensorMap},
            A::Annotation{<:AbstractTensorMap},
            p::Const{<:Index2Tuple},
            α::Annotation{<:Number},
            β::Annotation{<:Number},
            ba::Const...
        ) where {RT}
        # Snapshots are taken before `C.val` is updated below.
        # `C` is only kept when `β` is not `Const`.
        C_cache = β isa Const ? nothing : copy(C.val)
        # NOTE(review): slot 3 of `overwritten(config)` is presumably `A` (func, C, A) — confirm.
        A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
        if α isa Const
            # no Δα required: run the fused in-place operation directly
            TK.$transform!(C.val, A.val, p.val, α.val, β.val, map(x -> x.val, ba)...)
            Ap = nothing
        else
            # if we need to compute Δα, it is faster to allocate an intermediate permuted A
            # and store that instead of repeating the permutation in the pullback each time.
            # effectively, we replace `add_permute` by `add ∘ permute`.
            Ap = $transform(A.val, p.val)
            add!(C.val, Ap, α.val, β.val)
        end
        primal = EnzymeRules.needs_primal(config) ? C.val : nothing
        shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
        return EnzymeRules.AugmentedReturn(primal, shadow, (C_cache, A_cache, Ap))
    end

    # Reverse pass: updates `A.dval` and `C.dval` in place and returns the scalar
    # results for `α` and `β`; every other argument slot gets `nothing`.
    @eval function EnzymeRules.reverse(
            config::EnzymeRules.RevConfigWidth{1},
            func::Const{typeof(TK.$transform!)},
            ::Type{RT},
            cache,
            C::Annotation{<:AbstractTensorMap},
            A::Annotation{<:AbstractTensorMap},
            p::Const{<:Index2Tuple},
            α::Annotation{<:Number},
            β::Annotation{<:Number},
            ba::Const...
        ) where {RT}
        C_cache, A_cache, Ap = cache
        # fall back to the current values when no snapshot was stored
        Cval = something(C_cache, C.val)
        backend_args = map(x -> x.val, ba)
        # ΔA
        if !(A isa Const || C isa Const)
            Aval = something(A_cache, A.val)
            TK.$transform_pb(A.dval, Aval, C.dval, C.val, p.val, α.val, backend_args...)
        end
        Δα = pullback_dα(α, C, Ap)
        Δβ = pullback_dβ(β, C, Cval)
        # ΔC is updated last, after it has been read by the pullbacks above
        if !(C isa Const)
            pullback_dC!(C.dval, β.val)
        end
        return (nothing, nothing, nothing, Δα, Δβ, map(Returns(nothing), ba)...)
    end
end

# Forward pass of the reverse-mode Enzyme rule for the in-place `TK.braid!`.
# Performs the primal update of `C` and returns the cache `(C_cache, A_cache, Ap)`
# consumed by the matching `EnzymeRules.reverse` method.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TK.braid!)},
        ::Type{RT},
        C::Annotation{<:AbstractTensorMap},
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Index2Tuple},
        levels::Const{<:IndexTuple},
        α::Annotation{<:Number},
        β::Annotation{<:Number},
        ba::Const...
    ) where {RT}
    # Snapshots are taken before `C.val` is updated below. A plain `copy` is used,
    # matching the permute!/transpose! rule, instead of a recursive `deepcopy`.
    C_cache = !isa(β, Const) ? copy(C.val) : nothing
    # NOTE(review): slot 3 of `overwritten(config)` is presumably `A` (func, C, A) — confirm.
    A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
    # if we need to compute Δα, it is faster to allocate an intermediate braided A
    # and store that instead of repeating the braiding in the pullback each time.
    # effectively, we replace the fused `braid!` by `add ∘ braid`.
    Ap = if !isa(α, Const)
        Ap = braid(A.val, p.val, levels.val)
        add!(C.val, Ap, α.val, β.val)
        Ap
    else
        bavs = map(a -> a.val, ba)
        TK.braid!(C.val, A.val, p.val, levels.val, α.val, β.val, bavs...)
        nothing
    end
    cache = (C_cache, A_cache, Ap)
    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
# Reverse pass of the Enzyme rule for the in-place `TK.braid!`.
# Updates `A.dval` and `C.dval` in place and returns the scalar results for `α` and
# `β`; every other argument slot (C, A, p, levels and the trailing `ba`) gets `nothing`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TK.braid!)},
        ::Type{RT},
        cache,
        C::Annotation{<:AbstractTensorMap},
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Index2Tuple},
        levels::Const{<:IndexTuple},
        α::Annotation{<:Number},
        β::Annotation{<:Number},
        ba::Const...
    ) where {RT}
    C_cache, A_cache, Ap = cache
    # fall back to the current values when the forward pass stored no snapshot
    Cval = something(C_cache, C.val)
    backend_args = map(x -> x.val, ba)
    # ΔA
    if !(A isa Const || C isa Const)
        Aval = something(A_cache, A.val)
        TK.braid_pb(
            A.dval, Aval, C.dval, C.val, p.val, levels.val, α.val, backend_args...
        )
    end
    Δαr = pullback_dα(α, C, Ap)
    Δβr = pullback_dβ(β, C, Cval)
    # ΔC is updated last, after it has been read by the pullbacks above
    if !(C isa Const)
        pullback_dC!(C.dval, β.val)
    end
    return (nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...)
end

# Forward pass of the Enzyme rule for the in-place `twist!`.
# Applies the twist to `t.val` and stores no cache (`nothing`).
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(twist!)},
        ::Type{RT},
        t::Annotation{<:AbstractTensorMap},
        inds::Const;
        inv::Bool = false
    ) where {RT}
    twist!(t.val, inds.val; inv = inv)
    # only hand back primal / shadow when the config requests them
    primal_out = EnzymeRules.needs_primal(config) ? t.val : nothing
    shadow_out = EnzymeRules.needs_shadow(config) ? t.dval : nothing
    return EnzymeRules.AugmentedReturn(primal_out, shadow_out, nothing)
end

# Reverse pass of the Enzyme rule for the in-place `twist!`.
# Unless `t` is `Const`, applies `twist!` to the shadow `t.dval` with the `inv`
# flag negated. Both argument slots (t, inds) get `nothing`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(twist!)},
        ::Type{RT},
        cache,
        t::Annotation{<:AbstractTensorMap},
        inds::Const;
        inv::Bool = false
    ) where {RT}
    if !(t isa Const)
        twist!(t.dval, inds.val; inv = !inv)
    end
    return (nothing, nothing)
end

# Forward pass of the Enzyme rule for the out-of-place `flip`.
# Computes the flipped tensor and a zero-initialised shadow for it; that shadow is
# also the cache handed to the reverse pass.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(flip)},
        ::Type{RT},
        t::Annotation{<:AbstractTensorMap},
        inds::Const;
        inv::Bool = false
    ) where {RT}
    flipped = flip(t.val, inds.val; inv = inv)
    dflipped = make_zero(flipped)
    primal_out = EnzymeRules.needs_primal(config) ? flipped : nothing
    shadow_out = EnzymeRules.needs_shadow(config) ? dflipped : nothing
    return EnzymeRules.AugmentedReturn(primal_out, shadow_out, dflipped)
end

# Reverse pass of the Enzyme rule for `flip`.
# `cache` is the output shadow created in the forward pass. Unless `t` is `Const`,
# it is flipped with the `inv` flag negated and added into `t.dval`.
# Both argument slots (t, inds) get `nothing`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(flip)},
        ::Type{RT},
        cache,
        t::Annotation{<:AbstractTensorMap},
        inds::Const;
        inv::Bool = false
    ) where {RT}
    t isa Const && return (nothing, nothing)
    Δin = flip(cache, inds.val; inv = !inv)
    # a real-valued input shadow only receives the real part of the contribution
    if scalartype(t.dval) <: Real
        add!(t.dval, real(Δin))
    else
        add!(t.dval, Δin)
    end
    return (nothing, nothing)
end

# Reverse-mode Enzyme rules for `insertleftunit` and `insertrightunit`, generated
# once per symbol with `@eval`.
for insertunit in (:insertleftunit, :insertrightunit)
    @eval begin
        # Forward pass: builds the output tensor and its shadow.
        # The cache is `(shares_data, tdst, Δtdst)`, where `shares_data` tells the
        # reverse pass whether the output shadow was built directly from `tsrc.dval`.
        function EnzymeRules.augmented_primal(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($insertunit)},
                ::Type{RT},
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val};
                kwargs...
            ) where {RT}
            # tdst shares data with tsrc if <:TensorMap & copy=false; in that case the
            # shadow is built from tsrc.dval the same way so that it shares data too.
            # Only a flag is stored: the reverse pass never reads a copy of tsrc,
            # so copying the whole tensor here would be wasted work.
            shares_data = tsrc.val isa TensorMap && !get(kwargs, :copy, false) &&
                !isa(tsrc, Const)
            tdst = $insertunit(tsrc.val, ival.val; kwargs...)
            Δtdst = if shares_data
                $insertunit(tsrc.dval, ival.val; kwargs...)
            else
                make_zero(tdst)
            end
            primal = EnzymeRules.needs_primal(config) ? tdst : nothing
            shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
            cache = (shares_data, tdst, Δtdst)
            return EnzymeRules.AugmentedReturn(primal, shadow, cache)
        end

        # Reverse pass: when the output shadow is a separate tensor, add it block by
        # block into `tsrc.dval`. Both argument slots (tsrc, ival) get `nothing`.
        function EnzymeRules.reverse(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($insertunit)},
                ::Type{RT},
                cache,
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val};
                kwargs...
            ) where {RT}
            shares_data, tdst, Δtdst = cache
            # note: since data is already shared for <:TensorMap, don't have to do anything here!
            if !shares_data && !isa(tsrc, Const)
                for (c, b) in blocks(Δtdst)
                    add!(block(tsrc.dval, c), b)
                end
            end
            return (nothing, nothing)
        end
    end
end

# Forward pass of the reverse-mode Enzyme rule for `removeunit`.
# Builds the output tensor and its shadow. The cache is `(shares_data, tdst, Δtdst)`,
# where `shares_data` tells the reverse pass whether the output shadow was built
# directly from `tsrc.dval`.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(removeunit)},
        ::Type{RT},
        tsrc::Annotation{<:AbstractTensorMap},
        ival::Const{<:Val};
        kwargs...
    ) where {RT}
    # tdst shares data with tsrc if <:TensorMap & copy=false; in that case the shadow
    # is built from tsrc.dval the same way so that it shares data too.
    # Only a flag is stored: the reverse pass never reads a copy of tsrc, so copying
    # the whole tensor here would be wasted work.
    shares_data = tsrc.val isa TensorMap && !get(kwargs, :copy, false) &&
        !isa(tsrc, Const)
    tdst = removeunit(tsrc.val, ival.val; kwargs...)
    Δtdst = if shares_data
        # forward the same keyword arguments as the primal call so that primal and
        # shadow are always constructed identically
        removeunit(tsrc.dval, ival.val; kwargs...)
    else
        make_zero(tdst)
    end
    primal = EnzymeRules.needs_primal(config) ? tdst : nothing
    shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
    cache = (shares_data, tdst, Δtdst)
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

# Reverse pass of the Enzyme rule for `removeunit`.
# When the output shadow is a separate tensor, add it block by block into
# `tsrc.dval`. Both argument slots (tsrc, ival) get `nothing`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(removeunit)},
        ::Type{RT},
        cache,
        tsrc::Annotation{<:AbstractTensorMap},
        ival::Const{<:Val};
        kwargs...
    ) where {RT}
    shares_data, tdst, Δtdst = cache
    # note: since data for <: TensorMap is already shared, don't have to do anything here!
    if !shares_data && !isa(tsrc, Const)
        for (c, b) in blocks(Δtdst)
            add!(block(tsrc.dval, c), b)
        end
    end
    return (nothing, nothing)
end
26 changes: 3 additions & 23 deletions ext/TensorKitMooncakeExt/indexmanipulations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
for transform in (:permute, :transpose)
transform! = Symbol(transform, :!)
transform_pullback = Symbol(transform!, :_pullback)
transform_pb = Symbol(transform, :_pullback_dA)
@eval @is_primitive(
DefaultCtx,
ReverseMode,
Expand Down Expand Up @@ -44,17 +45,7 @@ for transform in (:permute, :transpose)
copy!(C, C_cache)

# ΔA
ip = invperm(linearize(p))
pΔA = TO.repartition(ip, numout(A))

TC = VectorInterface.promote_scale(ΔC, α)
if scalartype(ΔA) <: Real && !(TC <: Real)
ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false))
TK.$transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...)
add!(ΔA, real(ΔAc))
else
TK.$transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...)
end
TK.$transform_pb(ΔA, A, ΔC, C, p, α, ba...)
ΔAr = NoRData()

Δαr = pullback_dα(α, ΔC, Ap)
Expand Down Expand Up @@ -111,18 +102,7 @@ function Mooncake.rrule!!(
function braid!_pullback(::NoRData)
copy!(C, C_cache)

# ΔA
ip = invperm(linearize(p))
pΔA = TO.repartition(ip, numout(A))
ilevels = TupleTools.permute(levels, linearize(p))
TC = VectorInterface.promote_scale(ΔC, α)
if scalartype(ΔA) <: Real && !(TC <: Real)
ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false))
TK.braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...)
add!(ΔA, real(ΔAc))
else
TK.braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...)
end
TK.braid_pb(ΔA, A, ΔC, C, p, levels, α, ba...)
ΔAr = NoRData()

Δαr = pullback_dα(α, ΔC, Ap)
Expand Down
2 changes: 2 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,4 +285,6 @@ include("planar/planaroperations.jl")
# ------------------------
include("auxiliary/ad.jl")

include("pullbacks/indexmanipulations.jl")

end
31 changes: 31 additions & 0 deletions src/pullbacks/indexmanipulations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Shared ΔA pullbacks for `permute!` and `transpose!`, generated per symbol as
# `permute_pullback_dA` and `transpose_pullback_dA`.
for transform in (:permute, :transpose)
    transform! = Symbol(transform, :!)
    transform_pb = Symbol(transform, :_pullback_dA)
    # Accumulate into `ΔA` the result of applying the transform to `ΔC` with the
    # inverse of `p` and scale factor `conj(α)`. `A` is only queried for `numout(A)`;
    # `C` is accepted for signature uniformity but not read. Returns `nothing`.
    @eval function $transform_pb(ΔA, A, ΔC, C, p, α, ba...)
        ip = invperm(linearize(p))
        pΔA = TO.repartition(ip, numout(A))
        # Scalar type of `conj(α) * ΔC`. It is derived from `ΔC`, the tensor that is
        # actually scaled and used to allocate the temporary below (as in `braid_pb`).
        TC = VectorInterface.promote_scale(ΔC, α)
        if scalartype(ΔA) <: Real && !(TC <: Real)
            # A real ΔA cannot hold a complex result: compute into a temporary and
            # add only its real part.
            ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false))
            $transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...)
            add!(ΔA, real(ΔAc))
        else
            # `One()` as β accumulates directly into ΔA
            $transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...)
        end
        return
    end
end
# ΔA pullback for `braid!`: accumulates into `ΔA` the result of braiding `ΔC` with
# the inverse of `p`, the correspondingly permuted `levels`, and scale factor
# `conj(α)`. `A` is only queried for `numout(A)`; `C` is not read. Returns `nothing`.
function braid_pb(ΔA, A, ΔC, C, p, levels, α, ba...)
    plin = linearize(p)
    p_inv = TO.repartition(invperm(plin), numout(A))
    levels_inv = TupleTools.permute(levels, plin)
    Tscaled = VectorInterface.promote_scale(ΔC, α)
    needs_real_projection = scalartype(ΔA) <: Real && !(Tscaled <: Real)
    if !needs_real_projection
        # `One()` as β accumulates directly into ΔA
        braid!(ΔA, ΔC, p_inv, levels_inv, conj(α), One(), ba...)
        return
    end
    # A real ΔA cannot hold a complex result: compute into a temporary and add only
    # its real part.
    tmp = TO.tensoralloc_add(Tscaled, ΔC, p_inv, false, Val(false))
    braid!(tmp, ΔC, p_inv, levels_inv, conj(α), Zero(), ba...)
    add!(ΔA, real(tmp))
    return
end
19 changes: 19 additions & 0 deletions test/enzyme-indexmanipulations-flip-twist/flip.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Test, TestExtras
using TensorKit
using Enzyme, EnzymeTestUtils
using Random

# Spaces and scalar types exercised by the tests below.
# NOTE(review): `ad_spacelist`, `fast_tests` and `default_tol` are not defined in this
# file; presumably they come from the shared test setup — confirm.
spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Index Manipulations (flip):" begin
    @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T) TA ($TA)" for V in spacelist, T in eltypes, TA in (Duplicated,)
        atol = rtol = default_tol(T)
        A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
        # Each case is the index argument exactly as handed to `test_reverse`, plus
        # the extra keyword arguments for that call (empty means `flip`'s defaults).
        cases = (
            ((1, Const), (; fkwargs = (inv = false,))),
            ([1, 3], (; fkwargs = (inv = true,))),
            ((1, Const), NamedTuple()),
            (([1, 3], Const), NamedTuple()),
        )
        for (inds, extra) in cases
            EnzymeTestUtils.test_reverse(flip, TA, (A, TA), inds; atol, rtol, extra...)
        end
    end
end
Loading
Loading