diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index a7014a463..f834881fd 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules @@ -117,6 +118,36 @@ for (f, pb) in ( end end +for (f, pf) in ( + (left_polar!, left_polar_pushforward!), + (right_polar!, right_polar_pushforward!), + ) + @eval begin + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{TA}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + $f(A.val, arg.val, alg.val) + if !isa(A, Const) && !isa(arg, Const) + $pf(A.dval, A.val, arg.val, arg.dval) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return arg + elseif EnzymeRules.needs_primal(config) + return arg.val + elseif EnzymeRules.needs_shadow(config) + return arg.dval + else + return nothing + end + end + end +end + for (f, pb) in ( (qr_null!, qr_null_pullback!), (lq_null!, lq_null_pullback!), diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e2ae96c11..28d56ca65 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -9,6 +9,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -16,7 +17,7 @@ using LinearAlgebra Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent # needed for GPU tests because Mooncake can't differentiate through CUDA kernels -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(zero!), AbstractArray} function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual) A, dA = arrayify(A_dA) Ac = copy(A) @@ -28,6 +29,12 @@ function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual) end return A_dA, zero_adjoint end +function Mooncake.frule!!(::Dual{typeof(zero!)}, A_dA::Dual) + A, dA = arrayify(A_dA) + zero!(A) + zero!(dA) + return A_dA +end # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( @@ -40,7 +47,6 @@ for (f!, f, pb, adj) in ( (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint), (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint), ) - @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) @@ -104,6 +110,36 @@ for (f!, f, pb, adj) in ( end end +for (f!, f, pf) in ( + (:left_polar!, :left_polar, :left_polar_pushforward!), + (:right_polar!, :right_polar, :right_polar_pushforward!), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + $f!(A, args, Mooncake.primal(alg_dalg)) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return args_dargs + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + doutput = Mooncake.zero_tangent(output) + output_dual = Dual(output, doutput) + arg1, darg1 = arrayify(output[1], doutput[1]) + arg2, darg2 = arrayify(output[2], doutput[2]) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return output_dual + end + end +end + for (f!, f, pb, adj) in ( (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint), (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint), diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 25dc60c1b..37b964b44 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -129,6 +129,8 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pushforwards/polar.jl") + include("precompile.jl") end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 000000000..6238ffa4c --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,25 @@ +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) + W, P = WP + ΔW, ΔP = ΔWP + mul!(ΔP, adjoint(W), ΔA, +1, 0) + K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) + mul!(ΔW, ΔA, inv(P), +1, 0) + WᴴdAiP = W' * ΔW + mul!(ΔW, W, WᴴdAiP, -1, +1) + ΔW = mul!(ΔW, W, K̇, +1, +1) + ΔP = mul!(ΔP, K̇, P, -1, +1) + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + mul!(ΔP, ΔA, adjoint(Wᴴ), +1, 0) + K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) + mul!(ΔWᴴ, inv(P), ΔA, +1, 0) + iPdAW = ΔWᴴ * Wᴴ' + mul!(ΔWᴴ, iPdAW, Wᴴ, -1, +1) + ΔWᴴ = mul!(ΔWᴴ, K̇, Wᴴ, +1, +1) + ΔP = mul!(ΔP, P, K̇, -1, +1) + return (ΔWᴴ, ΔP) +end diff --git a/test/testsuite/enzyme/polar.jl b/test/testsuite/enzyme/polar.jl index bfc889c24..e342f416d 100644 --- a/test/testsuite/enzyme/polar.jl +++ b/test/testsuite/enzyme/polar.jl @@ -14,14 +14,14 @@ end """ test_enzyme_left_polar(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `left_polar` and its in-place variant. Only runs -for tall or square matrices (`m >= n`). +Test the Enzyme forward- and reverse-mode AD rule for `left_polar` and its in-place variant. +Only runs for tall or square matrices (`m >= n`). """ function test_enzyme_left_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) ) - return @testset "left_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "left_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) m, n = size(A) if m >= n @@ -29,6 +29,9 @@ function test_enzyme_left_polar( WP, ΔWP = ad_left_polar_setup(A) test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) test_reverse(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol) + A = instantiate_matrix(T, sz) + test_forward(left_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_forward(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol) end end end @@ -36,14 +39,14 @@ end """ test_enzyme_right_polar(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `right_polar` and its in-place variant. Only runs -for wide or square matrices (`m <= n`). +Test the Enzyme forward- and reverse-mode AD rule for `right_polar` and its in-place variant. +Only runs for wide or square matrices (`m <= n`). """ function test_enzyme_right_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) ) - return @testset "right_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "right_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) m, n = size(A) if m <= n @@ -51,6 +54,9 @@ function test_enzyme_right_polar( PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) test_reverse(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol) + A = instantiate_matrix(T, sz) + test_forward(right_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_forward(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol) end end end diff --git a/test/testsuite/mooncake/polar.jl b/test/testsuite/mooncake/polar.jl index 161360d62..fee32e71e 100644 --- a/test/testsuite/mooncake/polar.jl +++ b/test/testsuite/mooncake/polar.jl @@ -14,8 +14,8 @@ end """ test_mooncake_left_polar(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `left_polar` and its in-place variant. Only runs -for tall or square matrices (`m >= n`). +Test the Mooncake forward- and reverse-mode AD rule for `left_polar` and its in-place variant. +Only runs for tall or square matrices (`m >= n`). """ function test_mooncake_left_polar( T, sz; @@ -31,11 +31,11 @@ function test_mooncake_left_polar( Mooncake.TestUtils.test_rule( rng, left_polar, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, left_polar!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end @@ -43,8 +43,8 @@ end """ test_mooncake_right_polar(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `right_polar` and its in-place variant. Only runs -for wide or square matrices (`m <= n`). +Test the Mooncake forward- and reverse-mode AD rule for `right_polar` and its in-place variant. +Only runs for wide or square matrices (`m <= n`). """ function test_mooncake_right_polar( T, sz; @@ -60,11 +60,11 @@ function test_mooncake_right_polar( Mooncake.TestUtils.test_rule( rng, right_polar, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, right_polar!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end