Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using Enzyme
using Enzyme.EnzymeCore
using Enzyme.EnzymeCore: EnzymeRules
Expand Down Expand Up @@ -117,6 +118,36 @@ for (f, pb) in (
end
end

for (f, pf) in (
(left_polar!, left_polar_pushforward!),
(right_polar!, right_polar_pushforward!),
)
@eval begin
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation,
arg::Annotation{TA},
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT, TA}
$f(A.val, arg.val, alg.val)
if !isa(A, Const) && !isa(arg, Const)
$pf(A.dval, A.val, arg.val, arg.dval)
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return arg
elseif EnzymeRules.needs_primal(config)
return arg.val
elseif EnzymeRules.needs_shadow(config)
return arg.dval
else
return nothing
end
end
end
end

for (f, pb) in (
(qr_null!, qr_null_pullback!),
(lq_null!, lq_null_pullback!),
Expand Down
40 changes: 38 additions & 2 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: TruncatedAlgorithm
using LinearAlgebra

Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent

# needed for GPU tests because Mooncake can't differentiate through CUDA kernels
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(zero!), AbstractArray}
function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
A, dA = arrayify(A_dA)
Ac = copy(A)
Expand All @@ -28,6 +29,12 @@ function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
end
return A_dA, zero_adjoint
end
function Mooncake.frule!!(::Dual{typeof(zero!)}, A_dA::Dual)
A, dA = arrayify(A_dA)
zero!(A)
zero!(dA)
return A_dA
end

# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
Expand All @@ -40,7 +47,6 @@ for (f!, f, pb, adj) in (
(:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint),
(:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint),
)

@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
Expand Down Expand Up @@ -104,6 +110,36 @@ for (f!, f, pb, adj) in (
end
end

for (f!, f, pf) in (
(:left_polar!, :left_polar, :left_polar_pushforward!),
(:right_polar!, :right_polar, :right_polar_pushforward!),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
args = Mooncake.primal(args_dargs)
dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(args[1], dargs[1])
arg2, darg2 = arrayify(args[2], dargs[2])
$f!(A, args, Mooncake.primal(alg_dalg))
$pf(dA, A, (arg1, arg2), (darg1, darg2))
return args_dargs
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
output = $f(A, Mooncake.primal(alg_dalg))
doutput = Mooncake.zero_tangent(output)
output_dual = Dual(output, doutput)
arg1, darg1 = arrayify(output[1], doutput[1])
arg2, darg2 = arrayify(output[2], doutput[2])
$pf(dA, A, (arg1, arg2), (darg1, darg2))
return output_dual
end
end
end

for (f!, f, pb, adj) in (
(:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint),
(:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint),
Expand Down
2 changes: 2 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ include("pullbacks/eigh.jl")
include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("pushforwards/polar.jl")

include("precompile.jl")

end
25 changes: 25 additions & 0 deletions src/pushforwards/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
W, P = WP
ΔW, ΔP = ΔWP
mul!(ΔP, adjoint(W), ΔA, +1, 0)
K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP)
mul!(ΔW, ΔA, inv(P), +1, 0)
WᴴdAiP = W' * ΔW
mul!(ΔW, W, WᴴdAiP, -1, +1)
ΔW = mul!(ΔW, W, K̇, +1, +1)
ΔP = mul!(ΔP, K̇, P, -1, +1)
return (ΔW, ΔP)
end

function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
P, Wᴴ = PWᴴ
ΔP, ΔWᴴ = ΔPWᴴ
mul!(ΔP, ΔA, adjoint(Wᴴ), +1, 0)
K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP)
mul!(ΔWᴴ, inv(P), ΔA, +1, 0)
iPdAW = ΔWᴴ * Wᴴ'
mul!(ΔWᴴ, iPdAW, Wᴴ, -1, +1)
ΔWᴴ = mul!(ΔWᴴ, K̇, Wᴴ, +1, +1)
ΔP = mul!(ΔP, P, K̇, -1, +1)
return (ΔWᴴ, ΔP)
end
18 changes: 12 additions & 6 deletions test/testsuite/enzyme/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,49 @@ end
"""
test_enzyme_left_polar(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rule for `left_polar` and its in-place variant. Only runs
for tall or square matrices (`m >= n`).
Test the Enzyme forward- and reverse-mode AD rule for `left_polar` and its in-place variant.
Only runs for tall or square matrices (`m >= n`).
"""
function test_enzyme_left_polar(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
)
return @testset "left_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "left_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
m, n = size(A)
if m >= n
alg = MatrixAlgebraKit.select_algorithm(left_polar, A)
WP, ΔWP = ad_left_polar_setup(A)
test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol)
test_reverse(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol)
A = instantiate_matrix(T, sz)
test_forward(left_polar, RT, (A, TA), (alg, Const); atol, rtol)
test_forward(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol)
end
end
end

"""
test_enzyme_right_polar(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rule for `right_polar` and its in-place variant. Only runs
for wide or square matrices (`m <= n`).
Test the Enzyme forward- and reverse-mode AD rule for `right_polar` and its in-place variant.
Only runs for wide or square matrices (`m <= n`).
"""
function test_enzyme_right_polar(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
)
return @testset "right_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "right_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
m, n = size(A)
if m <= n
alg = MatrixAlgebraKit.select_algorithm(right_polar, A)
PWᴴ, ΔPWᴴ = ad_right_polar_setup(A)
test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol)
test_reverse(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol)
A = instantiate_matrix(T, sz)
test_forward(right_polar, RT, (A, TA), (alg, Const); atol, rtol)
test_forward(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol)
end
end
end
16 changes: 8 additions & 8 deletions test/testsuite/mooncake/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ end
"""
test_mooncake_left_polar(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rule for `left_polar` and its in-place variant. Only runs
for tall or square matrices (`m >= n`).
Test the Mooncake forward- and reverse-mode AD rule for `left_polar` and its in-place variant.
Only runs for tall or square matrices (`m >= n`).
"""
function test_mooncake_left_polar(
T, sz;
Expand All @@ -31,20 +31,20 @@ function test_mooncake_left_polar(

Mooncake.TestUtils.test_rule(
rng, left_polar, A, alg;
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
output_tangent, atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, call_and_zero!, left_polar!, A, alg;
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
output_tangent, atol, rtol, is_primitive = false
)
end
end

"""
test_mooncake_right_polar(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rule for `right_polar` and its in-place variant. Only runs
for wide or square matrices (`m <= n`).
Test the Mooncake forward- and reverse-mode AD rule for `right_polar` and its in-place variant.
Only runs for wide or square matrices (`m <= n`).
"""
function test_mooncake_right_polar(
T, sz;
Expand All @@ -60,11 +60,11 @@ function test_mooncake_right_polar(

Mooncake.TestUtils.test_rule(
rng, right_polar, A, alg;
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
output_tangent, atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, call_and_zero!, right_polar!, A, alg;
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
output_tangent, atol, rtol, is_primitive = false
)
end
end
Loading