From 7dfedba6bfcd3add9e207be100244c7db9460700 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 24 Jun 2026 02:19:48 +0200 Subject: [PATCH 1/3] Consolidate duplicated GPU logic into new GPUArrays extension --- Project.toml | 3 + ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl | 3 +- ext/TensorKitAMDGPUExt/roctensormap.jl | 17 --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 17 --- ext/TensorKitCUDAExt/cutensormap.jl | 17 --- ext/TensorKitCUDAExt/truncation.jl | 69 ----------- ext/TensorKitGPUArraysExt.jl | 116 +++++++++++++++++++ test/Project.toml | 1 - 8 files changed, 120 insertions(+), 123 deletions(-) delete mode 100644 ext/TensorKitCUDAExt/truncation.jl create mode 100644 ext/TensorKitGPUArraysExt.jl diff --git a/Project.toml b/Project.toml index 4b19a7dcb..417e25d9c 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] @@ -31,6 +32,7 @@ TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = "CUDA" TensorKitChainRulesCoreExt = "ChainRulesCore" TensorKitFiniteDifferencesExt = "FiniteDifferences" +TensorKitGPUArraysExt = "GPUArrays" TensorKitMooncakeExt = "Mooncake" [workspace] @@ -43,6 +45,7 @@ CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" FiniteDifferences = "0.12" +GPUArrays = "11.4.1" LRUCache = "1.0.2" LinearAlgebra = "1" MatrixAlgebraKit = "0.6.7" diff --git a/ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl b/ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl index e163ca005..33afca113 100644 --- a/ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl +++ b/ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl @@ -1,7 +1,6 @@ module TensorKitAMDGPUExt using AMDGPU, AMDGPU.rocBLAS, AMDGPU.rocSOLVER, LinearAlgebra -using AMDGPU: @allowscalar import AMDGPU: rand as rocrand, rand! as rocrand!, randn as rocrandn, randn! as rocrandn! using TensorKit @@ -9,7 +8,7 @@ using TensorKit.Factorizations using Strided using MatrixAlgebraKit using MatrixAlgebraKit: AbstractAlgorithm -using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check +using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype import TensorKit: randisometry using Base: rand, randn diff --git a/ext/TensorKitAMDGPUExt/roctensormap.jl b/ext/TensorKitAMDGPUExt/roctensormap.jl index f2f094c60..a2ccd6cec 100644 --- a/ext/TensorKitAMDGPUExt/roctensormap.jl +++ b/ext/TensorKitAMDGPUExt/roctensormap.jl @@ -7,16 +7,6 @@ function ROCTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂ return ROCTensorMap{T, S, N₁, N₂}(ROCArray{T}(t.data), space(t)) end -# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy -function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: ROCVector{T}} - h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V) - h_t = TensorKit.project_symmetric!(h_t, Array(data)) - # verify result - isapprox(Array(reshape(data, dims(h_t))), convert(Array, h_t); atol = tol) || - throw(ArgumentError("Data has non-zero elements at incompatible positions")) - return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V) -end - for (fname, felt) in ((:zeros, :zero), (:ones, :one)) @eval begin function AMDGPU.$fname( @@ -92,13 +82,6 @@ for randfun in (:rocrand, :rocrandn) end end -# Scalar implementation -#----------------------- -function TensorKit.scalar(t::ROCTensorMap{T, S, 0, 0}) where {T, S} - inds = findall(!iszero, t.data) - return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] -end - function Base.convert( TT::Type{ROCTensorMap{T, S, N₁, N₂}}, t::AbstractTensorMap{<:Any, S, N₁, N₂} diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index dc48042a4..a54d2490d 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -1,10 +1,8 @@ module TensorKitCUDAExt using CUDA, CUDA.cuBLAS, CUDA.cuSOLVER, CUDA.cuRAND, LinearAlgebra -using CUDA: @allowscalar import CUDA.cuRAND: rand as curand, rand! as curand!, randn as curandn, randn! as curandn! using Strided: StridedViews -using CUDA.CUDACore.KernelAbstractions: @kernel, @index, get_backend using Adapt: Adapt @@ -20,20 +18,5 @@ using TensorKit: MatrixAlgebraKit using Random include("cutensormap.jl") -include("truncation.jl") - -function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}} - # COV_EXCL_START - # kernels are not reachable by coverage - @kernel function fill_subblock_kernel!(subblock, val) - idx = @index(Global, Cartesian) - idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val) - @inbounds subblock[idx] = idx_val - end - # COV_EXCL_STOP - kernel = fill_subblock_kernel!(get_backend(data)) - kernel(data, val; ndrange = size(data)) - return data -end end diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 016749fce..e6febf5b5 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -7,16 +7,6 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) end -# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy -function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} - h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V) - h_t = TensorKit.project_symmetric!(h_t, Array(data)) - # verify result - isapprox(Array(reshape(data, dims(h_t))), convert(Array, h_t); atol = tol) || - throw(ArgumentError("Data has non-zero elements at incompatible positions")) - return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V) -end - for (fname, felt) in ((:zeros, :zero), (:ones, :one)) @eval begin function CUDA.$fname( @@ -94,13 +84,6 @@ for randfun in (:curand, :curandn) end end -# Scalar implementation -#----------------------- -function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S} - inds = findall(!iszero, t.data) - return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] -end - function LinearAlgebra.isposdef(t::CuTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl deleted file mode 100644 index a87e6c97a..000000000 --- a/ext/TensorKitCUDAExt/truncation.jl +++ /dev/null @@ -1,69 +0,0 @@ -const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}} - -function MatrixAlgebraKit.findtruncated( - values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder - ) - I = sectortype(values) - - dims = similar(values, Base.promote_op(dim, I)) - for (c, v) in pairs(dims) - fill!(v, dim(c)) - end - - isempty(parent(values)) && return similar(values, Bool) - - perm = sortperm(parent(values); strategy.by, strategy.rev) - cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) - - result = similar(values, Bool) - parent(result)[perm] .= cumulative_dim .<= strategy.howmany - return result -end - -function MatrixAlgebraKit.findtruncated( - values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByError - ) - (isfinite(strategy.p) && strategy.p > 0) || - throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) - ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) - ϵᵖ = similar(values, typeof(ϵᵖmax)) - - # dimensions are all 1 so no need to account for weight - if FusionStyle(sectortype(values)) isa UniqueFusion - parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p - else - for (c, v) in pairs(values) - v′ = ϵᵖ[c] - v′ .= abs.(v) .^ strategy.p .* dim(c) - end - end - - isempty(parent(values)) && return similar(values, Bool) - - perm = sortperm(parent(values); by = abs, rev = false) - cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) - - result = similar(values, Bool) - parent(result)[perm] .= cumulative_err .> ϵᵖmax - return result -end - -function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy} - # returning a CuSectorVector wrecks things in truncate_{co}domain - # because of scalar indexing - return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) -end - -for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace)) - @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat) - # returning a CuSectorVector wrecks things in truncate_{co}domain - # because of scalar indexing - return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) - end -end - -function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) - atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) - strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) -end diff --git a/ext/TensorKitGPUArraysExt.jl b/ext/TensorKitGPUArraysExt.jl new file mode 100644 index 000000000..1295b4178 --- /dev/null +++ b/ext/TensorKitGPUArraysExt.jl @@ -0,0 +1,116 @@ +module TensorKitGPUArraysExt + +using GPUArrays +using GPUArrays: @allowscalar +using GPUArarys.KernelAbstractions: @kernel, @index, get_backend + +using TensorKit +using TensorKit.Factorizations +using TensorKit.Strided +using TensorKit.Factorizations: AbstractAlgorithm +using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check +import TensorKit: randisometry, rand, randn, fill_braidingsubblock! + +function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:AnyGPUMatrix{T}, <:StridedViews.StridedView{T, 4, <:AnyGPUArray{T}}}} + # COV_EXCL_START + # kernels are not reachable by coverage + @kernel function fill_subblock_kernel!(subblock, val) + idx = @index(Global, Cartesian) + idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val) + @inbounds subblock[idx] = idx_val + end + # COV_EXCL_STOP + kernel = fill_subblock_kernel!(get_backend(data)) + kernel(data, val; ndrange = size(data)) + return data +end + +const GPUSectorVector{T, I} = TensorKit.SectorVector{T, I, <:AnyGPUVector{T}} + +function MatrixAlgebraKit.findtruncated( + values::GPUSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder + ) + I = sectortype(values) + + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + + isempty(parent(values)) && return similar(values, Bool) + + perm = sortperm(parent(values); strategy.by, strategy.rev) + cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_dim .<= strategy.howmany + return result +end + +function MatrixAlgebraKit.findtruncated( + values::GPUSectorVector, strategy::MatrixAlgebraKit.TruncationByError + ) + (isfinite(strategy.p) && strategy.p > 0) || + throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) + ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) + ϵᵖ = similar(values, typeof(ϵᵖmax)) + + # dimensions are all 1 so no need to account for weight + if FusionStyle(sectortype(values)) isa UniqueFusion + parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p + else + for (c, v) in pairs(values) + v′ = ϵᵖ[c] + v′ .= abs.(v) .^ strategy.p .* dim(c) + end + end + + isempty(parent(values)) && return similar(values, Bool) + + perm = sortperm(parent(values); by = abs, rev = false) + cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_err .> ϵᵖmax + return result +end + +function MatrixAlgebraKit.findtruncated_svd(values::GPUSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy} + # returning a GPUSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) +end + +for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace)) + @eval function MatrixAlgebraKit.findtruncated_svd(values::GPUSectorVector, strategy::$strat) + # returning a GPUSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) + end +end + +function MatrixAlgebraKit.findtruncated_svd(values::GPUSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) + atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) + strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) + return SectorDict(c => Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) +end + +# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy +function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: AnyGPUVector{T}} + h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V) + h_t = TensorKit.project_symmetric!(h_t, Array(data)) + # verify result + isapprox(Array(reshape(data, dims(h_t))), convert(Array, h_t); atol = tol) || + throw(ArgumentError("Data has non-zero elements at incompatible positions")) + return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V) +end + +# Scalar implementation +#----------------------- +function TensorKit.scalar(t::TensorMap{T, S, 0, 0, <:AnyGPUArra}) where {T, S} + inds = findall(!iszero, t.data) + return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] +end + + +end diff --git a/test/Project.toml b/test/Project.toml index f8602fad0..3c9e9b743 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,6 @@ AllocCheck = "0.2" ChainRulesTestUtils = "1" Combinatorics = "1" cuTENSOR = "6" -GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" ParallelTestRunner = "2" Test = "1" From 181fe9c46d2f4a21f3f69745eb1d1bd22f3d6bb8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 24 Jun 2026 11:15:58 -0400 Subject: [PATCH 2/3] Fixes --- ext/TensorKitGPUArraysExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitGPUArraysExt.jl b/ext/TensorKitGPUArraysExt.jl index 1295b4178..b38089c93 100644 --- a/ext/TensorKitGPUArraysExt.jl +++ b/ext/TensorKitGPUArraysExt.jl @@ -2,11 +2,12 @@ module TensorKitGPUArraysExt using GPUArrays using GPUArrays: @allowscalar -using GPUArarys.KernelAbstractions: @kernel, @index, get_backend +using GPUArrays.KernelAbstractions: @kernel, @index, get_backend +using Strided: StridedViews +using MatrixAlgebraKit using TensorKit using TensorKit.Factorizations -using TensorKit.Strided using TensorKit.Factorizations: AbstractAlgorithm using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check import TensorKit: randisometry, rand, randn, fill_braidingsubblock! @@ -107,7 +108,7 @@ end # Scalar implementation #----------------------- -function TensorKit.scalar(t::TensorMap{T, S, 0, 0, <:AnyGPUArra}) where {T, S} +function TensorKit.scalar(t::TensorMap{T, S, 0, 0, <:AnyGPUArray}) where {T, S} inds = findall(!iszero, t.data) return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] end From 3677da54982ddccbf3f2e537340af6aaa5bb0765 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 24 Jun 2026 13:59:31 -0400 Subject: [PATCH 3/3] missing import --- ext/TensorKitGPUArraysExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitGPUArraysExt.jl b/ext/TensorKitGPUArraysExt.jl index b38089c93..179a112ea 100644 --- a/ext/TensorKitGPUArraysExt.jl +++ b/ext/TensorKitGPUArraysExt.jl @@ -5,7 +5,7 @@ using GPUArrays: @allowscalar using GPUArrays.KernelAbstractions: @kernel, @index, get_backend using Strided: StridedViews -using MatrixAlgebraKit +using MatrixAlgebraKit, Adapt using TensorKit using TensorKit.Factorizations using TensorKit.Factorizations: AbstractAlgorithm