From 550b348f8de3995b813c9a92d10832c55dbbf47e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 11:10:20 +0200 Subject: [PATCH 1/7] Add Mooncake fwd rules for polar --- .../MatrixAlgebraKitMooncakeExt.jl | 40 ++++++++++++++++++- src/MatrixAlgebraKit.jl | 2 + src/pushforwards/polar.jl | 21 ++++++++++ test/testsuite/mooncake/polar.jl | 16 ++++---- 4 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 src/pushforwards/polar.jl diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e2ae96c11..28d56ca65 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -9,6 +9,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -16,7 +17,7 @@ using LinearAlgebra Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent # needed for GPU tests because Mooncake can't differentiate through CUDA kernels -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(zero!), AbstractArray} function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual) A, dA = arrayify(A_dA) Ac = copy(A) @@ -28,6 +29,12 @@ function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual) end return A_dA, zero_adjoint end +function Mooncake.frule!!(::Dual{typeof(zero!)}, A_dA::Dual) + A, dA = arrayify(A_dA) + zero!(A) + zero!(dA) + return A_dA +end # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( @@ -40,7 +47,6 @@ for (f!, f, pb, adj) in ( (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint), (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint), ) - @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) @@ -104,6 +110,36 @@ for (f!, f, pb, adj) in ( end end +for (f!, f, pf) in ( + (:left_polar!, :left_polar, :left_polar_pushforward!), + (:right_polar!, :right_polar, :right_polar_pushforward!), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + $f!(A, args, Mooncake.primal(alg_dalg)) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return args_dargs + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + doutput = Mooncake.zero_tangent(output) + output_dual = Dual(output, doutput) + arg1, darg1 = arrayify(output[1], doutput[1]) + arg2, darg2 = arrayify(output[2], doutput[2]) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return output_dual + end + end +end + for (f!, f, pb, adj) in ( (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint), (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint), diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 25dc60c1b..37b964b44 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -129,6 +129,8 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pushforwards/polar.jl") + include("precompile.jl") end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 000000000..1e0da1b22 --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,21 @@ +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) + W, P = WP + ΔW, ΔP = ΔWP + aWdA = adjoint(W) * ΔA + K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P) + ΔW .= W * K̇ + L̇ + ΔP .= aWdA - K̇ * P + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + dAW = ΔA * adjoint(Wᴴ) + K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + L̇ = inv(P) * ΔA * (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + ΔWᴴ .= K̇ * Wᴴ + L̇ + ΔP .= dAW - P * K̇ + return (ΔWᴴ, ΔP) +end diff --git a/test/testsuite/mooncake/polar.jl b/test/testsuite/mooncake/polar.jl index 161360d62..fee32e71e 100644 --- a/test/testsuite/mooncake/polar.jl +++ b/test/testsuite/mooncake/polar.jl @@ -14,8 +14,8 @@ end """ test_mooncake_left_polar(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `left_polar` and its in-place variant. Only runs -for tall or square matrices (`m >= n`). +Test the Mooncake forward- and reverse-mode AD rule for `left_polar` and its in-place variant. +Only runs for tall or square matrices (`m >= n`). """ function test_mooncake_left_polar( T, sz; @@ -31,11 +31,11 @@ function test_mooncake_left_polar( Mooncake.TestUtils.test_rule( rng, left_polar, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, left_polar!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end @@ -43,8 +43,8 @@ end """ test_mooncake_right_polar(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `right_polar` and its in-place variant. Only runs -for wide or square matrices (`m <= n`). +Test the Mooncake forward- and reverse-mode AD rule for `right_polar` and its in-place variant. +Only runs for wide or square matrices (`m <= n`). """ function test_mooncake_right_polar( T, sz; @@ -60,11 +60,11 @@ function test_mooncake_right_polar( Mooncake.TestUtils.test_rule( rng, right_polar, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, right_polar!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end From be8bfdb429cf8d198232fe5b6ed945773d27c522 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 11:23:33 +0200 Subject: [PATCH 2/7] Forward rules and tests for Enzyme + polar --- .../MatrixAlgebraKitEnzymeExt.jl | 31 +++++++++++++++++++ test/testsuite/enzyme/polar.jl | 18 +++++++---- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index a7014a463..f834881fd 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules @@ -117,6 +118,36 @@ for (f, pb) in ( end end +for (f, pf) in ( + (left_polar!, left_polar_pushforward!), + (right_polar!, right_polar_pushforward!), + ) + @eval begin + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{TA}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + $f(A.val, arg.val, alg.val) + if !isa(A, Const) && !isa(arg, Const) + $pf(A.dval, A.val, arg.val, arg.dval) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return arg + elseif EnzymeRules.needs_primal(config) + return arg.val + elseif EnzymeRules.needs_shadow(config) + return arg.dval + else + return nothing + end + end + end +end + for (f, pb) in ( (qr_null!, qr_null_pullback!), (lq_null!, lq_null_pullback!), diff --git a/test/testsuite/enzyme/polar.jl b/test/testsuite/enzyme/polar.jl index bfc889c24..e342f416d 100644 --- a/test/testsuite/enzyme/polar.jl +++ b/test/testsuite/enzyme/polar.jl @@ -14,14 +14,14 @@ end """ test_enzyme_left_polar(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `left_polar` and its in-place variant. Only runs -for tall or square matrices (`m >= n`). +Test the Enzyme forward- and reverse-mode AD rule for `left_polar` and its in-place variant. +Only runs for tall or square matrices (`m >= n`). """ function test_enzyme_left_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) ) - return @testset "left_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "left_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) m, n = size(A) if m >= n @@ -29,6 +29,9 @@ function test_enzyme_left_polar( WP, ΔWP = ad_left_polar_setup(A) test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) test_reverse(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol) + A = instantiate_matrix(T, sz) + test_forward(left_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_forward(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol) end end end @@ -36,14 +39,14 @@ end """ test_enzyme_right_polar(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `right_polar` and its in-place variant. Only runs -for wide or square matrices (`m <= n`). +Test the Enzyme forward- and reverse-mode AD rule for `right_polar` and its in-place variant. +Only runs for wide or square matrices (`m <= n`). """ function test_enzyme_right_polar( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) ) - return @testset "right_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "right_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) m, n = size(A) if m <= n @@ -51,6 +54,9 @@ function test_enzyme_right_polar( PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) test_reverse(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol) + A = instantiate_matrix(T, sz) + test_forward(right_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_forward(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol) end end end From 0a9222d1dc024b45aefb8acf346d62c3b6e3df07 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 12:01:03 +0200 Subject: [PATCH 3/7] Use _sylvester fallback --- src/pushforwards/polar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index 1e0da1b22..8dd22770a 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -2,7 +2,7 @@ function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) W, P = WP ΔW, ΔP = ΔWP aWdA = adjoint(W) * ΔA - K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + K̇ = _sylvester(P, P, -(aWdA - adjoint(aWdA))) L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P) ΔW .= W * K̇ + L̇ ΔP .= aWdA - K̇ * P @@ -13,7 +13,7 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) P, Wᴴ = PWᴴ ΔP, ΔWᴴ = ΔPWᴴ dAW = ΔA * adjoint(Wᴴ) - K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + K̇ = _sylvester(P, P, -(dAW - adjoint(dAW))) L̇ = inv(P) * ΔA * (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) ΔWᴴ .= K̇ * Wᴴ + L̇ ΔP .= dAW - P * K̇ From a19ea21ebd3570e9d4f6271a54ca8b3063e4537e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 06:44:42 -0400 Subject: [PATCH 4/7] Comments and fix --- src/pushforwards/polar.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index 8dd22770a..489d9a7e2 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -1,11 +1,11 @@ function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) W, P = WP ΔW, ΔP = ΔWP - aWdA = adjoint(W) * ΔA - K̇ = _sylvester(P, P, -(aWdA - adjoint(aWdA))) - L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P) + WᴴdA = adjoint(W) * ΔA + K̇ = _sylvester(P, P, -(WᴴdA - adjoint(WᴴdA))) + L̇ = (LinearAlgebra.UniformScaling(1) - W * adjoint(W)) * ΔA * inv(P) ΔW .= W * K̇ + L̇ - ΔP .= aWdA - K̇ * P + ΔP .= WᴴdA - K̇ * P return (ΔW, ΔP) end @@ -14,7 +14,7 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) ΔP, ΔWᴴ = ΔPWᴴ dAW = ΔA * adjoint(Wᴴ) K̇ = _sylvester(P, P, -(dAW - adjoint(dAW))) - L̇ = inv(P) * ΔA * (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + L̇ = inv(P) * ΔA * (LinearAlgebra.UniformScaling(1) - adjoint(Wᴴ) * Wᴴ) ΔWᴴ .= K̇ * Wᴴ + L̇ ΔP .= dAW - P * K̇ return (ΔWᴴ, ΔP) From b26d5b35ca438915149dcdf6015af3d6a7269e40 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 09:12:37 -0400 Subject: [PATCH 5/7] Apply Jutho suggestions --- src/pushforwards/polar.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index 489d9a7e2..b5887ed9e 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -3,7 +3,9 @@ function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) ΔW, ΔP = ΔWP WᴴdA = adjoint(W) * ΔA K̇ = _sylvester(P, P, -(WᴴdA - adjoint(WᴴdA))) - L̇ = (LinearAlgebra.UniformScaling(1) - W * adjoint(W)) * ΔA * inv(P) + dAiP = ΔA * inv(P) + WᴴdAiP = W' * dAiP + L̇ = mul!(dAiP, W, WᴴdAiP, -1, +1) ΔW .= W * K̇ + L̇ ΔP .= WᴴdA - K̇ * P return (ΔW, ΔP) @@ -14,7 +16,9 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) ΔP, ΔWᴴ = ΔPWᴴ dAW = ΔA * adjoint(Wᴴ) K̇ = _sylvester(P, P, -(dAW - adjoint(dAW))) - L̇ = inv(P) * ΔA * (LinearAlgebra.UniformScaling(1) - adjoint(Wᴴ) * Wᴴ) + iPdA = inv(P) * ΔA + iPdAW = iPdA * Wᴴ' + L̇ = mul!(iPdA, iPdAW, Wᴴ, -1, +1) ΔWᴴ .= K̇ * Wᴴ + L̇ ΔP .= dAW - P * K̇ return (ΔWᴴ, ΔP) From 887b1cb8fd48bc4c48df069d0d338db9daa84f0a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Jun 2026 02:24:06 -0400 Subject: [PATCH 6/7] Pushforward improvements --- src/pushforwards/polar.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index b5887ed9e..509913973 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -1,25 +1,25 @@ function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) W, P = WP ΔW, ΔP = ΔWP - WᴴdA = adjoint(W) * ΔA - K̇ = _sylvester(P, P, -(WᴴdA - adjoint(WᴴdA))) + mul!(ΔP, adjoint(W), ΔA, +1, 0) + K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) dAiP = ΔA * inv(P) WᴴdAiP = W' * dAiP L̇ = mul!(dAiP, W, WᴴdAiP, -1, +1) ΔW .= W * K̇ + L̇ - ΔP .= WᴴdA - K̇ * P + ΔP = mul!(ΔP, K̇, P, -1, +1) return (ΔW, ΔP) end function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) P, Wᴴ = PWᴴ ΔP, ΔWᴴ = ΔPWᴴ - dAW = ΔA * adjoint(Wᴴ) - K̇ = _sylvester(P, P, -(dAW - adjoint(dAW))) + mul!(ΔP, ΔA, adjoint(Wᴴ), +1, 0) + K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) iPdA = inv(P) * ΔA iPdAW = iPdA * Wᴴ' L̇ = mul!(iPdA, iPdAW, Wᴴ, -1, +1) ΔWᴴ .= K̇ * Wᴴ + L̇ - ΔP .= dAW - P * K̇ + ΔP = mul!(ΔP, P, K̇, -1, +1) return (ΔWᴴ, ΔP) end From 741278da5037701b9a0af128abbfce312ea6d055 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Jun 2026 05:11:50 -0400 Subject: [PATCH 7/7] even more optimization --- src/pushforwards/polar.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index 509913973..6238ffa4c 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -3,10 +3,10 @@ function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) ΔW, ΔP = ΔWP mul!(ΔP, adjoint(W), ΔA, +1, 0) K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) - dAiP = ΔA * inv(P) - WᴴdAiP = W' * dAiP - L̇ = mul!(dAiP, W, WᴴdAiP, -1, +1) - ΔW .= W * K̇ + L̇ + mul!(ΔW, ΔA, inv(P), +1, 0) + WᴴdAiP = W' * ΔW + mul!(ΔW, W, WᴴdAiP, -1, +1) + ΔW = mul!(ΔW, W, K̇, +1, +1) ΔP = mul!(ΔP, K̇, P, -1, +1) return (ΔW, ΔP) end @@ -16,10 +16,10 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) ΔP, ΔWᴴ = ΔPWᴴ mul!(ΔP, ΔA, adjoint(Wᴴ), +1, 0) K̇ = _sylvester(P, P, adjoint(ΔP) - ΔP) - iPdA = inv(P) * ΔA - iPdAW = iPdA * Wᴴ' - L̇ = mul!(iPdA, iPdAW, Wᴴ, -1, +1) - ΔWᴴ .= K̇ * Wᴴ + L̇ + mul!(ΔWᴴ, inv(P), ΔA, +1, 0) + iPdAW = ΔWᴴ * Wᴴ' + mul!(ΔWᴴ, iPdAW, Wᴴ, -1, +1) + ΔWᴴ = mul!(ΔWᴴ, K̇, Wᴴ, +1, +1) ΔP = mul!(ΔP, P, K̇, -1, +1) return (ΔWᴴ, ΔP) end