diff --git a/test/mooncake/projections.jl b/test/mooncake/projections.jl index b33967e9d..5e2a1e2c5 100644 --- a/test/mooncake/projections.jl +++ b/test/mooncake/projections.jl @@ -18,4 +18,8 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_mooncake_projections(T, (m, m); atol, rtol) TestSuite.test_mooncake_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol) end + if T ∈ BLASFloats && CUDA.functional() + TestSuite.test_mooncake_projections(CuMatrix{T}, (m, m); atol, rtol) + TestSuite.test_mooncake_projections(Diagonal{T, CuVector{T}}, (m, m); atol, rtol) + end end