diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index f834881fd..343bd9681 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -6,6 +6,7 @@ using MatrixAlgebraKit: diagview, inv_safe, truncate using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! +using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! @@ -119,8 +120,10 @@ for (f, pb) in ( end for (f, pf) in ( - (left_polar!, left_polar_pushforward!), - (right_polar!, right_polar_pushforward!), + (:right_polar!, :right_polar_pushforward!), + (:left_polar!, :left_polar_pushforward!), + (:eigh_full!, :eigh_pushforward!), + (:eig_full!, :eig_pushforward!), ) @eval begin function EnzymeRules.forward( @@ -128,13 +131,17 @@ for (f, pf) in ( func::Const{typeof($f)}, ::Type{RT}, A::Annotation, - arg::Annotation{TA}, + arg::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT, TA} + ) where {RT} + A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] + A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] + A_is_arg = A_is_arg1 || A_is_arg2 $f(A.val, arg.val, alg.val) if !isa(A, Const) && !isa(arg, Const) $pf(A.dval, A.val, arg.val, arg.dval) end + !A_is_arg && make_zero!(A.dval) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return arg elseif EnzymeRules.needs_primal(config) @@ -367,9 +374,9 @@ for (f, trunc_f, full_f, pb) in ( end end -for (f!, f_full!, pb!) in ( - (eig_vals!, eig_full!, eig_vals_pullback!), - (eigh_vals!, eigh_full!, eigh_vals_pullback!), +for (f!, f_full!, pb!, pf!) in ( + (:eig_vals!, :eig_full!, :eig_vals_pullback!, :eig_vals_pushforward!), + (:eigh_vals!, :eigh_full!, :eigh_vals_pullback!, :eigh_vals_pushforward!), ) @eval begin function EnzymeRules.augmented_primal( @@ -418,6 +425,34 @@ for (f!, f_full!, pb!) in ( !isa(D, Const) && !A_is_arg && make_zero!(D.dval) return (nothing, nothing, nothing) end + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{TA}, + D::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval + DV = $f_full!(A.val, alg.val) + Dval, V = DV + if !isa(A, Const) && !isa(D, Const) + ΔD = A_is_arg ? make_zero(D.dval) : D.dval + $pf!(A.dval, A.val, (Diagonal(diagview(Dval)), V), ΔD) + A_is_arg && (D.dval .= ΔD) + end + copyto!(D.val, diagview(Dval)) + !isa(A, Const) && !A_is_arg && make_zero!(A.dval) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return D + elseif EnzymeRules.needs_primal(config) + return D.val + elseif EnzymeRules.needs_shadow(config) + return D.dval + else + return nothing + end + end end end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 46b4f9538..16241385b 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -7,6 +7,8 @@ using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero! using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! +using MatrixAlgebraKit: eig_pushforward!, eig_vals_pushforward! +using MatrixAlgebraKit: eigh_pushforward!, eigh_vals_pushforward! using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! @@ -113,6 +115,8 @@ end for (f!, f, pf) in ( (:left_polar!, :left_polar, :left_polar_pushforward!), (:right_polar!, :right_polar, :right_polar_pushforward!), + (:eig_full!, :eig_full, :eig_pushforward!), + (:eigh_full!, :eigh_full, :eigh_pushforward!), ) @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} @@ -177,12 +181,12 @@ for (f!, f, pb, adj) in ( end end -for (f!, f, f_full, pb, adj) in ( - (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint), - (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint), +for (f!, f, f_full, pb, pf, adj) in ( + (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_pushforward!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_pushforward!, :eigh_vals_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -210,7 +214,18 @@ for (f!, f, f_full, pb, adj) in ( end return D_dD, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + D, dD = arrayify(D_dD) + # update primal + DV = $f_full(A, Mooncake.primal(alg_dalg)) + V = DV[2] + copyto!(D, diagview(DV[1])) + $pf(dA, A, (D, V), dD) + return D_dD + end + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -227,6 +242,18 @@ for (f!, f, f_full, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + # update primal + DV = $f_full(A, Mooncake.primal(alg_dalg)) + V = DV[2] + output = diagview(DV[1]) + output_dual = Dual(output, Mooncake.zero_tangent(output)) + D, dD = arrayify(output_dual) + $pf(dA, A, DV, dD) + return output_dual + end end end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 37b964b44..65de152c4 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -130,6 +130,8 @@ include("pullbacks/svd.jl") include("pullbacks/polar.jl") include("pushforwards/polar.jl") +include("pushforwards/eig.jl") +include("pushforwards/eigh.jl") include("precompile.jl") diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl new file mode 100644 index 000000000..9e39f6395 --- /dev/null +++ b/src/pushforwards/eig.jl @@ -0,0 +1,22 @@ +function eig_pushforward!( + ΔA, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + D, V = DV + ΔD, ΔV = ΔDV + ΔAV = isnothing(ΔV) ? ΔA * V : mul!(ΔV, ΔA, V) # reusing ΔV memory if possible + ∂K = V \ ΔAV + if !iszerotangent(ΔD) + diagview(ΔD) .= diagview(∂K) + end + if !iszerotangent(ΔV) + ∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol) + mul!(ΔV, V, ∂K, 1, 0) + end + return ΔDV +end + +function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...) + return eig_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...) +end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl new file mode 100644 index 000000000..e610867fe --- /dev/null +++ b/src/pushforwards/eigh.jl @@ -0,0 +1,22 @@ +function eigh_pushforward!( + ΔA, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + D, V = DV + ΔD, ΔV = ΔDV + ΔAV = isnothing(ΔV) ? ΔA * V : mul!(ΔV, ΔA, V) # reusing ΔV memory if possible + ∂K = V' * ΔAV + if !iszerotangent(ΔD) + diagview(ΔD) .= real.(diagview(∂K)) + end + if !iszerotangent(ΔV) + ∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol) + ΔV = mul!(ΔV, V, ∂K) + end + return (ΔD, ΔV) +end + +function eigh_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...) + return eigh_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...) +end diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 949129eac..1404d751c 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + TestSuite.test_enzyme_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index 64c796fc6..05ae7f384 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -14,8 +14,8 @@ m = 19 for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite - #TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + TestSuite.test_enzyme_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/testsuite/enzyme/eig.jl b/test/testsuite/enzyme/eig.jl index 2c678df3a..b260e557b 100644 --- a/test/testsuite/enzyme/eig.jl +++ b/test/testsuite/enzyme/eig.jl @@ -15,38 +15,48 @@ end """ test_enzyme_eig_full(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `eig_full` and its in-place variant. +Test the Enzyme foward- and reverse-mode AD rule for `eig_full` and its in-place variant. """ function test_enzyme_eig_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "eig_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "eig_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = make_eig_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eig_full, A) DV, ΔDV = ad_eig_full_setup(A) test_reverse(eig_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) test_reverse(call_and_zero!, RT, (eig_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + if eltype(T) <: Real && T <: Diagonal + A = make_eig_matrix(T, sz) + test_forward(eig_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (eig_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + end end end """ test_enzyme_eig_vals(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `eig_vals` and its in-place variant. +Test the Enzyme forward- and reverse-mode AD rule for `eig_vals` and its in-place variant. """ function test_enzyme_eig_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "eig_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "eig_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = make_eig_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eig_vals, A) D, ΔD = ad_eig_vals_setup(A) test_reverse(eig_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) test_reverse(call_and_zero!, RT, (eig_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + if eltype(T) <: Real + A = make_eig_matrix(T, sz) + test_forward(eig_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (eig_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + end end end diff --git a/test/testsuite/enzyme/eigh.jl b/test/testsuite/enzyme/eigh.jl index 73dc84ed2..83c81296d 100644 --- a/test/testsuite/enzyme/eigh.jl +++ b/test/testsuite/enzyme/eigh.jl @@ -15,38 +15,46 @@ end """ test_enzyme_eigh_full(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `eigh_full` and its in-place variant. +Test the Enzyme forward- and reverse-mode AD rule for `eigh_full` and its in-place variant. """ function test_enzyme_eigh_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "eigh_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "eigh_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = make_eigh_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eigh_full, A) DV, ΔDV = ad_eigh_full_setup(A) test_reverse(eigh_wrapper, RT, (eigh_full, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) test_reverse(eigh!_wrapper, RT, (eigh_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + if eltype(T) <: Real + A = make_eigh_matrix(T, sz) + test_forward(eigh_wrapper, RT, (eigh_full, Const), (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(eigh!_wrapper, RT, (eigh_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm) + end end end """ test_enzyme_eigh_vals(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `eigh_vals` and its in-place variant. +Test the Enzyme forward- and reverse-mode AD rule for `eigh_vals` and its in-place variant. """ function test_enzyme_eigh_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "eigh_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "eigh_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = make_eigh_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eigh_vals, A) D, ΔD = ad_eigh_vals_setup(A) test_reverse(eigh_wrapper, RT, (eigh_vals, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) test_reverse(eigh!_wrapper, RT, (eigh_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + A = make_eigh_matrix(T, sz) + test_forward(eigh_wrapper, RT, (eigh_vals, Const), (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(eigh!_wrapper, RT, (eigh_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index aad4b4881..3cc0063c7 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -15,7 +15,7 @@ end """ test_mooncake_eig_full(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `eig_full` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `eig_full` and its in-place variant. """ function test_mooncake_eig_full( T, sz; @@ -29,25 +29,39 @@ function test_mooncake_eig_full( Mooncake.TestUtils.test_rule( rng, eig_full, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + mode = Mooncake.ReverseMode, + output_tangent, atol, rtol ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, eig_full!, A, alg; + mode = Mooncake.ReverseMode, + output_tangent, atol, rtol, is_primitive = false + ) + if !(eltype(T) <: ComplexF64) + Mooncake.TestUtils.test_rule( + rng, eig_full, A, alg; + mode = Mooncake.ForwardMode, + atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, eig_full!, A, alg; + mode = Mooncake.ForwardMode, + atol, rtol, is_primitive = false + ) + end if T <: Diagonal{<:Complex} Mooncake.TestUtils.test_rule( rng, eig_full!, A, (A, DV[2]), alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) end - Mooncake.TestUtils.test_rule( - rng, call_and_zero!, eig_full!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false - ) end end """ test_mooncake_eig_vals(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `eig_vals` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `eig_vals` and its in-place variant. """ function test_mooncake_eig_vals( T, sz; @@ -61,17 +75,17 @@ function test_mooncake_eig_vals( Mooncake.TestUtils.test_rule( rng, eig_vals, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) if T <: Diagonal{<:Complex} Mooncake.TestUtils.test_rule( rng, eig_vals!, A, A.diag, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) end Mooncake.TestUtils.test_rule( rng, call_and_zero!, eig_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl index 62952aa8e..00f29a6f7 100644 --- a/test/testsuite/mooncake/eigh.jl +++ b/test/testsuite/mooncake/eigh.jl @@ -35,6 +35,16 @@ function test_mooncake_eigh_full( rng, eigh!_wrapper, eigh_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol ) + if !(eltype(T) <: Complex) + Mooncake.TestUtils.test_rule( + rng, eigh_wrapper, eigh_full, A, alg; + mode = Mooncake.ForwardMode, output_tangent, is_primitive = false, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, eigh!_wrapper, eigh_full!, A, alg; + mode = Mooncake.ForwardMode, output_tangent, is_primitive = false, atol, rtol + ) + end end end @@ -55,11 +65,11 @@ function test_mooncake_eigh_vals( Mooncake.TestUtils.test_rule( rng, eigh_wrapper, eigh_vals, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) end end