diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 28d56ca6..46b4f953 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -858,7 +858,7 @@ for (f!, f, adj) in ( (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), ) @eval begin - @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) @@ -879,8 +879,14 @@ for (f!, f, adj) in ( return arg_darg, $adj end - - @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual, arg_darg::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) + $f!(A, arg, Mooncake.primal(alg_dalg)) + $f!(dA, darg, Mooncake.primal(alg_dalg)) + return arg_darg + end + @is_primitive DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) @@ -896,6 +902,14 @@ for (f!, f, adj) in ( return output_doutput, $adj end + function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_doutput = Mooncake.zero_dual(output) + doutput = last(arrayify(output_doutput)) + $f!(dA, doutput, Mooncake.primal(alg_dalg)) + return output_doutput + end end end diff --git a/test/testsuite/enzyme/projections.jl b/test/testsuite/enzyme/projections.jl index bc3e7014..2c5240a2 100644 --- a/test/testsuite/enzyme/projections.jl +++ b/test/testsuite/enzyme/projections.jl @@ -15,39 +15,45 @@ end """ test_enzyme_project_hermitian(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `project_hermitian` and its in-place variant. +Test the Enzyme forward- and reverse-mode AD rule for `project_hermitian` and its in-place variant. """ function test_enzyme_project_hermitian( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "project_hermitian reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "project_hermitian: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) B = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) test_reverse(project_hermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm) test_reverse(project_hermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm) test_reverse(project_hermitian_inplace!, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_hermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_hermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_hermitian_inplace!, RT, (A, TA), (alg, Const); atol, rtol, fdm) end end """ test_enzyme_project_antihermitian(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rule for `project_antihermitian` and its in-place variant. +Test the Enzyme forward- and reverse-mode AD rule for `project_antihermitian` and its in-place variant. """ function test_enzyme_project_antihermitian( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "project_antihermitian reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "project_antihermitian: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) B = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) test_reverse(project_antihermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm) test_reverse(project_antihermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm) test_reverse(project_antihermitian_inplace!, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_antihermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_antihermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm) + test_forward(project_antihermitian_inplace!, RT, (A, TA), (alg, Const); atol, rtol, fdm) end end diff --git a/test/testsuite/mooncake/projections.jl b/test/testsuite/mooncake/projections.jl index 74e22a44..359c9cbe 100644 --- a/test/testsuite/mooncake/projections.jl +++ b/test/testsuite/mooncake/projections.jl @@ -15,7 +15,7 @@ end """ test_mooncake_project_hermitian(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `project_hermitian` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `project_hermitian` and its in-place variant. """ function test_mooncake_project_hermitian( T, sz; @@ -27,15 +27,15 @@ function test_mooncake_project_hermitian( alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) Mooncake.TestUtils.test_rule( rng, project_hermitian, A, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) Mooncake.TestUtils.test_rule( rng, project_hermitian!, A, A, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) Mooncake.TestUtils.test_rule( rng, project_hermitian!, A, B, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) end end @@ -43,7 +43,7 @@ end """ test_mooncake_project_antihermitian(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `project_antihermitian` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `project_antihermitian` and its in-place variant. """ function test_mooncake_project_antihermitian( T, sz; @@ -55,15 +55,15 @@ function test_mooncake_project_antihermitian( alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) Mooncake.TestUtils.test_rule( rng, project_antihermitian, A, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) Mooncake.TestUtils.test_rule( rng, project_antihermitian!, A, A, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) Mooncake.TestUtils.test_rule( rng, project_antihermitian!, A, B, alg; - mode = Mooncake.ReverseMode, atol, rtol + atol, rtol ) end end