Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")
include("tensoroperations.jl")

end
208 changes: 208 additions & 0 deletions ext/TensorKitEnzymeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# tensorcontract!
# ---------------
# TODO: it might be beneficial to compare here if it would make sense to simply compute the
# rrule of permute-permute-gemm-permute, rather than using the contractions directly.
# This could possibly out save some permutations being carried out twice, at the cost of having
# to store some more intermediate objects.
# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively
# this permutation is done multiple times.

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
Ccache = isa(β, Const) ? nothing : copy(C.val)
A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const)
Acache = A_needs_cache ? copy(A.val) : nothing
B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const)
Bcache = B_needs_cache ? copy(B.val) : nothing
AB = if !isa(α, Const)
AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val)
add!(C.val, AB, α.val, β.val)
AB
else
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (Ccache, Acache, Bcache, AB)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
cacheC, cacheA, cacheB, AB = cache
Cval = cacheC
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)

Δα = pullback_dα(α, C, AB)
Δβ = pullback_dβ(β, C, Cval)

if !isa(A, Const)
TensorKit.blas_contract_pullback_ΔA!(
A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
if !isa(B, Const)
TensorKit.blas_contract_pullback_ΔB!(
B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
!isa(C, Const) && pullback_dC!(C.dval, β.val) # this typically returns nothing
return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Annotation{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Annotation{<:Index2Tuple},
pAB::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.val, pB.val, pAB.val, α.dval, One(), backend.val, allocator.val)
!isa(A, Const) && TensorKit.blas_contract!(C.dval, A.dval, pA.val, B.val, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
!isa(B, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.dval, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
end
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end

# tensortrace!
# ------------

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache = !isa(β, Const) ? copy(C.val) : nothing
A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
At = if !isa(α, Const)
At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val)
add!(C.val, At, α.val, β.val)
At
else
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (C_cache, A_cache, At)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end


function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache, A_cache, At = cache
Aval = something(A_cache, A.val)
Cval = something(C_cache, C.val)
!isa(A, Const) && !isa(C, Const) && TensorKit.trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
Δαr = pullback_dα(α, C, At)
Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Annotation{<:Index2Tuple},
q::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
# dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC
# dC1 = dβ * C + β * dC
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.trace_permute!(C.dval, A.val, p.val, q.val, α.dval, One(), backend.val)
!isa(A, Const) && TensorKit.trace_permute!(C.dval, A.dval, p.val, q.val, α.val, One(), backend.val)
end
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end
79 changes: 7 additions & 72 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ function Mooncake.rrule!!(
function blas_contract_pullback(::NoRData)
copy!(C, C_cache)

ΔAr = blas_contract_pullback_ΔA!(
ΔAr = TensorKit.blas_contract_pullback_ΔA!(
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
) # this typically returns NoRData()
ΔBr = blas_contract_pullback_ΔB!(
ΔBr = TensorKit.blas_contract_pullback_ΔB!(
ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
) # this typically returns NoRData()
Δαr = pullback_dα(α, ΔC, AB)
Δβr = pullback_dβ(β, ΔC, C)
ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData()

return NoRData(), ΔCr,
ΔAr, NoRData(),
ΔBr, NoRData(),
return NoRData(), NoRData(),
NoRData(), NoRData(),
NoRData(), NoRData(),
NoRData(),
Δαr, Δβr,
NoRData(), NoRData()
Expand Down Expand Up @@ -99,56 +99,6 @@ function Mooncake.frule!!(
return C_ΔC
end

function blas_contract_pullback_ΔA!(
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
ipAB = invperm(linearize(pAB))
pΔC = TO.repartition(ipAB, pA)
ipA = TO.repartition(invperm(linearize(pA)), numout(A))

tB = twist(
B,
TupleTools.vcat(
filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])
); copy = false
)

TK.project_contract!(
ΔA,
ΔC, pΔC, false,
tB, reverse(pB), true,
ipA, conj(α), backend, allocator
)

return NoRData()
end

function blas_contract_pullback_ΔB!(
ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
ipAB = invperm(linearize(pAB))
pΔC = TO.repartition(ipAB, pA)
ipB = TO.repartition(invperm(linearize(pB)), numout(B))

tA = twist(
A,
TupleTools.vcat(
filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])
); copy = false
)

TK.project_contract!(
ΔB,
tA, reverse(pA), true,
ΔC, pΔC, false,
ipB, conj(α), backend, allocator
)

return NoRData()
end

# tensortrace!
# ------------
@is_primitive(
Expand Down Expand Up @@ -191,14 +141,14 @@ function Mooncake.rrule!!(
function trace_permute_pullback(::NoRData)
copy!(C, C_cache)

ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData()
ΔAr = TensorKit.trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData()

Δαr = pullback_dα(α, ΔC, At)
Δβr = pullback_dβ(β, ΔC, C)
ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData()

return NoRData(),
ΔCr, ΔAr, NoRData(), NoRData(),
NoRData(), NoRData(), NoRData(), NoRData(),
Δαr, Δβr, NoRData()
end

Expand Down Expand Up @@ -236,21 +186,6 @@ function Mooncake.frule!!(
return C_ΔC
end

function trace_permute_pullback_ΔA!(
ΔA, ΔC, A, p, q, α, backend
)
ip = invperm((linearize(p)..., q[1]..., q[2]...))
pdA = TO.repartition(ip, numout(A))
E = one!(TO.tensoralloc_add(scalartype(A), A, q, false))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
pE = ((), TO.trivialpermutation(TO.numind(q)))
pΔC = (TO.trivialpermutation(TO.numind(p)), ())
TO.tensorproduct!(
ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend
)
return NoRData()
end

@is_primitive(
DefaultCtx,
Tuple{
Expand Down
1 change: 1 addition & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,5 +284,6 @@ include("planar/planaroperations.jl")
# once all types have been declared
# ------------------------
include("auxiliary/ad.jl")
include("pullbacks/tensoroperations.jl")

end
4 changes: 2 additions & 2 deletions src/auxiliary/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ function project_mul!(C, A, B, α, β = One())
end
end
function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator)
TA = TensorKit.promote_permute(A)
TB = TensorKit.promote_permute(B)
TA = promote_permute(A)
TB = promote_permute(B)
TC = TO.promote_contract(TA, TB, scalartype(α))

return if scalartype(C) <: Real && !(TC <: Real)
Expand Down
Loading
Loading