diff --git a/Project.toml b/Project.toml index 9212700..ffdfee9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,16 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" -version = "0.6.2" +version = "0.6.3" authors = ["ITensor developers and contributors"] [workspace] projects = ["benchmark", "dev", "docs", "examples", "test"] [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,21 +18,23 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] -AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" [extensions] -ITensorBaseAbstractTreesExt = "AbstractTrees" ITensorBaseAdaptExt = "Adapt" ITensorBaseBlockArraysExt = "BlockArrays" ITensorBaseMooncakeExt = "Mooncake" +ITensorBaseTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" @@ -38,6 +42,7 @@ Accessors = "0.1.39" Adapt = "4.1.1" ArrayLayouts = "1.11" BlockArrays = "1.3" +Combinatorics = "1" Compat = "4.16" ConstructionBase = "1.6" LinearAlgebra = "1.10" @@ -46,7 +51,10 @@ OrderedCollections = "1.6" Random = "1.10" SimpleTraits = "0.9.4" TensorAlgebra = "0.9.6" +TensorOperations = "5.3.1" +TermInterface = "2" TupleTools = "1.6" UUIDs = "1.10" VectorInterface = "0.5, 0.6" +WrappedUnions = "0.3" julia = "1.10" diff --git a/ext/ITensorBaseAbstractTreesExt/ITensorBaseAbstractTreesExt.jl b/ext/ITensorBaseAbstractTreesExt/ITensorBaseAbstractTreesExt.jl deleted file mode 100644 index 75886ef..0000000 --- a/ext/ITensorBaseAbstractTreesExt/ITensorBaseAbstractTreesExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module ITensorBaseAbstractTreesExt - -using AbstractTrees: AbstractTrees -using ITensorBase: AbstractITensor, dimnames - -# Only print the dimension names when printing with `AbstractTrees.print_tree`. -function AbstractTrees.printnode(io::IO, a::AbstractITensor) - dimnames_a = "{" * join(map(s -> "\"$s\"", dimnames(a)), ", ") * "}" - print(io, dimnames_a) - return nothing -end - -end diff --git a/ext/ITensorBaseTensorOperationsExt/ITensorBaseTensorOperationsExt.jl b/ext/ITensorBaseTensorOperationsExt/ITensorBaseTensorOperationsExt.jl new file mode 100644 index 0000000..f1797d3 --- /dev/null +++ b/ext/ITensorBaseTensorOperationsExt/ITensorBaseTensorOperationsExt.jl @@ -0,0 +1,26 @@ +module ITensorBaseTensorOperationsExt + +using ITensorBase.TermInterface: arguments +using ITensorBase: ITensorBase, Optimal, denamed, inds, ismul, optimize_contraction_order, + substitute, symnameddims +using TensorOperations: TensorOperations, optimaltree + +function contraction_tree_to_expr(f, tree) + return if !(tree isa AbstractVector) + f(tree) + else + prod(Base.Fix1(contraction_tree_to_expr, f), tree) + end +end + +function ITensorBase.optimize_contraction_order(alg::Optimal, a) + @assert ismul(a) + ts = arguments(a) + inds_network = collect.(inds.(ts)) + # Converting dims to Float64 to minimize overflow issues + inds_to_dims = Dict(i => Float64(length(denamed(i))) for i in reduce(∪, inds_network)) + tree, _ = optimaltree(inds_network, inds_to_dims) + return contraction_tree_to_expr(i -> ts[i], tree) +end + +end diff --git a/src/ITensorBase.jl b/src/ITensorBase.jl index 2c8100c..37433e8 100644 --- a/src/ITensorBase.jl +++ b/src/ITensorBase.jl @@ -25,4 +25,14 @@ include("itensoroperator.jl") include("index.jl") include("quirks.jl") +# Lazy and symbolic ITensor expressions. +include("lazyitensors/baseextensions.jl") +include("lazyitensors/itensorbaseextensions.jl") +include("lazyitensors/applied.jl") +include("lazyitensors/lazyinterface.jl") +include("lazyitensors/lazybroadcast.jl") +include("lazyitensors/lazyitensor.jl") +include("lazyitensors/symbolicitensor.jl") +include("lazyitensors/evaluation_order.jl") + end diff --git a/src/lazyitensors/applied.jl b/src/lazyitensors/applied.jl new file mode 100644 index 0000000..0eaca17 --- /dev/null +++ b/src/lazyitensors/applied.jl @@ -0,0 +1,49 @@ +using AbstractTrees: AbstractTrees +using TermInterface: TermInterface, arguments, iscall, operation + +# Generic functionality for Applied types, like `Mul`, `Add`, etc. +ismul(a) = iscall(a) && operation(a) ≡ * +head_applied(a) = operation(a) +iscall_applied(a) = true +isexpr_applied(a) = iscall(a) +function show_applied(io::IO, a) + args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a)) + print(io, "(", join(args, " $(operation(a)) "), ")") + return nothing +end +sorted_arguments_applied(a) = arguments(a) +children_applied(a) = arguments(a) +sorted_children_applied(a) = sorted_arguments(a) +function maketerm_applied(type, head, args, metadata) + term = type(args) + @assert head ≡ operation(term) + return term +end +map_arguments_applied(f, a) = Base.typename(typeof(a)).wrapper(map(f, arguments(a))) +function hash_applied(a, h::UInt64) + h = hash(Symbol(Base.typename(typeof(a)).wrapper), h) + for arg in arguments(a) + h = hash(arg, h) + end + return h +end + +abstract type Applied end +TermInterface.head(a::Applied) = head_applied(a) +TermInterface.iscall(a::Applied) = iscall_applied(a) +TermInterface.isexpr(a::Applied) = isexpr_applied(a) +Base.show(io::IO, a::Applied) = show_applied(io, a) +TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a) +TermInterface.children(a::Applied) = children_applied(a) +TermInterface.sorted_children(a::Applied) = sorted_children_applied(a) +function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata) + return maketerm_applied(type, head, args, metadata) +end +map_arguments(f, a::Applied) = map_arguments_applied(f, a) +Base.hash(a::Applied, h::UInt64) = hash_applied(a, h) + +struct Mul{A} <: Applied + arguments::Vector{A} +end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.operation(m::Mul) = * diff --git a/src/lazyitensors/baseextensions.jl b/src/lazyitensors/baseextensions.jl new file mode 100644 index 0000000..984b197 --- /dev/null +++ b/src/lazyitensors/baseextensions.jl @@ -0,0 +1,3 @@ +generic_map(f, v) = map(f, v) +generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v))) +generic_map(f, v::AbstractSet) = Set([f(x) for x in v]) diff --git a/src/lazyitensors/evaluation_order.jl b/src/lazyitensors/evaluation_order.jl new file mode 100644 index 0000000..e34ec22 --- /dev/null +++ b/src/lazyitensors/evaluation_order.jl @@ -0,0 +1,110 @@ +using TermInterface: arguments, arity, operation + +# The time complexity of evaluating `f(args...)`. +function time_complexity(f, args...) + return error("Not implemented.") +end +# The space complexity of evaluating `f(args...)`. +function space_complexity(f, args...) + return error("Not implemented.") +end +# The space complexity of `args`. +function input_space_complexity(f, args...) + return error("Not implemented.") +end + +function time_complexity( + ::typeof(*), t1::AbstractITensor, t2::AbstractITensor + ) + return prod(length ∘ denamed, (inds(t1) ∪ inds(t2))) +end +function time_complexity( + ::typeof(+), t1::AbstractITensor, t2::AbstractITensor + ) + @assert issetequal(dimnames(t1), dimnames(t2)) + return prod(denamed, size(t1)) +end +function time_complexity(::typeof(*), c::Number, t::AbstractITensor) + return prod(denamed, size(t)) +end +function time_complexity(::typeof(*), t::AbstractITensor, c::Number) + return time_complexity(*, c, t) +end + +function evaluation_time_complexity(a) + t = Ref(0) + opwalk(a) do f + return function (args...) + t[] += time_complexity(f, args...) + return f(args...) + end + end + return t[] +end + +# The workspace complexity of evaluating expression. +function evaluation_space_complexity(a) + # TODO: Walk the expression and call `space_complexity` on each node. + return error("Not implemented.") +end +# The complexity of storing the arguments of the expression. +function argument_space_complexity(a) + # TODO: Walk the expression and call `input_space_complexity` on each node. + return error("Not implemented.") +end + +# Flatten a nested expression down to a flat expression, +# removing information about the order of operations. +function flatten_expression(a) + if !iscall(a) + return a + elseif ismul(a) + flattened_arguments = mapreduce(to_mul_arguments, vcat, arguments(a)) + return lazy(Mul(flattened_arguments)) + else + return error("Variant not supported.") + end +end + +function optimize_evaluation_order(alg, a) + if !iscall(a) + return a + elseif ismul(a) + return optimize_contraction_order(alg, a) + else + # TODO: Recurse into other operations, calling `optimize_evaluation_order`. + return error("Variant not supported.") + end +end + +function optimize_evaluation_order( + a; alg = default_optimize_evaluation_order_alg(a) + ) + return optimize_evaluation_order(alg, a) +end + +abstract type EvaluationOrderAlgorithm end +struct Greedy <: EvaluationOrderAlgorithm end +# `Optimal` finds the cost-optimal contraction order. The method is provided by +# the TensorOperations extension. +struct Optimal <: EvaluationOrderAlgorithm end +default_optimize_evaluation_order_alg(a) = Greedy() + +function optimize_contraction_order(alg, a) + return error("`alg = $alg` not supported.") +end + +using Combinatorics: combinations +function optimize_contraction_order(alg::Greedy, a) + @assert ismul(a) + arity(a) in (1, 2) && return a + a1, a2 = argmin(combinations(arguments(a), 2)) do (a1, a2) + # Penalize outer product contractions. + # TODO: Still order the outer products by time complexity, + # say by checking if there are only outer products left. + isdisjoint(dimnames(a1), dimnames(a2)) && return typemax(Int) + return time_complexity(*, a1, a2) + end + contracted_arguments = [filter(∉((a1, a2)), arguments(a)); [a1 * a2]] + return optimize_contraction_order(alg, lazy(Mul(contracted_arguments))) +end diff --git a/src/lazyitensors/itensorbaseextensions.jl b/src/lazyitensors/itensorbaseextensions.jl new file mode 100644 index 0000000..34886d0 --- /dev/null +++ b/src/lazyitensors/itensorbaseextensions.jl @@ -0,0 +1,29 @@ +# Defined to avoid type piracy. +# TODO: Define a proper hash function +# in ITensorBase.jl, maybe one that is +# independent of the order of dimensions. +function _hash(a::ITensor, h::UInt64) + h = hash(:ITensor, h) + h = hash(denamed(a), h) + for i in inds(a) + h = hash(i, h) + end + return h +end +function _hash(x, h::UInt64) + return hash(x, h) +end + +using AbstractTrees: AbstractTrees +# Only print the dimension names when printing with `AbstractTrees.print_tree`. +function AbstractTrees.printnode(io::IO, a::AbstractITensor) + dimnames_a = "{" * join(map(s -> "\"$s\"", dimnames(a)), ", ") * "}" + print(io, dimnames_a) + return nothing +end + +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractITensor`. +# Method specializations (`LazyITensor`, `SymbolicITensor`) live in +# `lazyitensor.jl` and `symbolicitensor.jl`. +printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) diff --git a/src/lazyitensors/lazybroadcast.jl b/src/lazyitensors/lazybroadcast.jl new file mode 100644 index 0000000..157747f --- /dev/null +++ b/src/lazyitensors/lazybroadcast.jl @@ -0,0 +1,13 @@ +# Lazy broadcasting. +struct LazyITensorStyle <: Base.Broadcast.AbstractArrayStyle{Any} end +function Broadcast.broadcasted(::LazyITensorStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyITensor.") +end +# Linear operations. +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(+), a1, a2) = a1 + a2 +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a1, a2) = a1 - a2 +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), c::Number, a) = c * a +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a, c::Number) = a * c +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a::Number, b::Number) = a * b +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(/), a, c::Number) = a / c +Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a) = -a diff --git a/src/lazyitensors/lazyinterface.jl b/src/lazyitensors/lazyinterface.jl new file mode 100644 index 0000000..111ca18 --- /dev/null +++ b/src/lazyitensors/lazyinterface.jl @@ -0,0 +1,213 @@ +using TermInterface: iscall, maketerm, operation, sorted_arguments +using WrappedUnions: unwrap + +lazy(x) = error("Not defined.") + +# Walk the expression `ex`, modifying the +# operations by `opmap` and the arguments by `argmap`. +function walk(opmap, argmap, ex) + if !iscall(ex) + return argmap(ex) + else + return mapfoldl(opmap(operation(ex)), arguments(ex)) do (args...) + return walk(opmap, argmap, args...) + end + end +end +# Walk the expression `ex`, modifying the +# operations by `opmap`. +opwalk(opmap, a) = walk(opmap, identity, a) +# Walk the expression `ex`, modifying the +# arguments by `argmap`. +argwalk(argmap, a) = walk(identity, argmap, a) + +# Generic lazy functionality. +function maketerm_lazy(type::Type, head, args, metadata) + if head ≡ * + return type(maketerm(Mul, head, args, metadata)) + else + return error("Only mul supported right now.") + end +end +function getindex_lazy(a::AbstractArray, I...) + u = unwrap(a) + if !iscall(u) + return u[I...] + else + return error("Indexing into expression not supported.") + end +end +function arguments_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No arguments.") + elseif ismul(u) + return arguments(u) + else + return error("Variant not supported.") + end +end +using TermInterface: children +children_lazy(a) = arguments(a) +using TermInterface: head +head_lazy(a) = operation(a) +iscall_lazy(a) = iscall(unwrap(a)) +using TermInterface: isexpr +isexpr_lazy(a) = iscall(a) +function operation_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No operation.") + elseif ismul(u) + return operation(u) + else + return error("Variant not supported.") + end +end +function sorted_arguments_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No arguments.") + elseif ismul(u) + return sorted_arguments(u) + else + return error("Variant not supported.") + end +end +using TermInterface: sorted_children +sorted_children_lazy(a) = sorted_arguments(a) +ismul_lazy(a) = ismul(unwrap(a)) +using AbstractTrees: AbstractTrees +function abstracttrees_children_lazy(a) + if !iscall(a) + return () + else + return arguments(a) + end +end +using AbstractTrees: nodevalue +function nodevalue_lazy(a) + if !iscall(a) + return unwrap(a) + else + return operation(a) + end +end +using Base.Broadcast: materialize +materialize_lazy(a) = argwalk(unwrap, a) +copy_lazy(a) = materialize(a) +function equals_lazy(a1, a2) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return u1 == u2 + elseif ismul(u1) && ismul(u2) + return arguments(u1) == arguments(u2) + else + return false + end +end +function isequal_lazy(a1, a2) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return isequal(u1, u2) + elseif ismul(u1) && ismul(u2) + return isequal(arguments(u1), arguments(u2)) + else + return false + end +end +function hash_lazy(a, h::UInt64) + h = hash(Symbol(Base.typename(typeof(a)).wrapper), h) + # Use `_hash`, which defines a custom hash for ITensor. + return _hash(unwrap(a), h) +end +function map_arguments_lazy(f, a) + u = unwrap(a) + if !iscall(u) + return error("No arguments to map.") + elseif ismul(u) + return lazy(map_arguments(f, u)) + else + return error("Variant not supported.") + end +end +function substitute end +function substitute_lazy(a, substitutions::AbstractDict) + haskey(substitutions, a) && return substitutions[a] + !iscall(a) && return a + return map_arguments(arg -> substitute(arg, substitutions), a) +end +substitute_lazy(a, substitutions) = substitute(a, Dict(substitutions)) +using AbstractTrees: printnode +function printnode_lazy(io, a) + # Use `printnode_nameddims` to avoid type piracy, + # since it overloads on `AbstractITensor`. + return printnode_nameddims(io, unwrap(a)) +end +function show_lazy(io::IO, a) + if !iscall(a) + return show(io, unwrap(a)) + else + return AbstractTrees.printnode(io, a) + end +end +function show_lazy(io::IO, mime::MIME"text/plain", a) + summary(io, a) + println(io, ":") + !iscall(a) ? show(io, mime, unwrap(a)) : show(io, a) + return nothing +end +add_lazy(a1, a2) = error("Not implemented.") +sub_lazy(a) = error("Not implemented.") +sub_lazy(a1, a2) = error("Not implemented.") +function mul_lazy(a) + u = unwrap(a) + if !iscall(u) + return lazy(Mul([a])) + elseif ismul(u) + return a + else + return error("Variant not supported.") + end +end +# Note that this is nested by default. +function mul_lazy(a1, a2; flatten::Bool = false) + return flatten ? mul_lazy_flattened(a1, a2) : mul_lazy_nested(a1, a2) +end +mul_lazy_nested(a1, a2) = lazy(Mul([a1, a2])) +to_mul_arguments(a) = ismul(a) ? arguments(a) : [a] +mul_lazy_flattened(a1, a2) = lazy(Mul([to_mul_arguments(a1); to_mul_arguments(a2)])) +mul_lazy(a1::Number, a2) = error("Not implemented.") +mul_lazy(a1, a2::Number) = error("Not implemented.") +mul_lazy(a1::Number, a2::Number) = a1 * a2 +div_lazy(a1, a2::Number) = error("Not implemented.") + +# ITensorBase.jl named-tensor interface. +function dimnames_lazy(a) + u = unwrap(a) + if !iscall(u) + return dimnames(u) + elseif ismul(u) + return mapreduce(dimnames, symdiff, arguments(u)) + else + return error("Variant not supported.") + end +end +function inds_lazy(a) + u = unwrap(a) + if !iscall(u) + return inds(u) + elseif ismul(u) + return mapreduce(inds, symdiff, arguments(u)) + else + return error("Variant not supported.") + end +end +function denamed_lazy(a) + u = unwrap(a) + if !iscall(u) + return denamed(u) + else + return error("Variant not supported.") + end +end diff --git a/src/lazyitensors/lazyitensor.jl b/src/lazyitensors/lazyitensor.jl new file mode 100644 index 0000000..c440bef --- /dev/null +++ b/src/lazyitensors/lazyitensor.jl @@ -0,0 +1,66 @@ +using WrappedUnions: @wrapped + +@wrapped struct LazyITensor{ + DimName, A <: AbstractITensor{DimName}, + } <: AbstractITensor{DimName} + union::Union{A, Mul{LazyITensor{DimName, A}}} +end + +parenttype(::Type{LazyITensor{DimName, A}}) where {DimName, A} = A +parenttype(::Type{LazyITensor{DimName}}) where {DimName} = AbstractITensor{DimName} +parenttype(::Type{LazyITensor}) = AbstractITensor + +function LazyITensor(a::AbstractITensor) + return LazyITensor{dimnametype(typeof(a)), typeof(a)}(a) +end +function LazyITensor(a::Mul{L}) where {L <: LazyITensor} + return LazyITensor{dimnametype(L), parenttype(L)}(a) +end +lazy(a::LazyITensor) = a +lazy(a::AbstractITensor) = LazyITensor(a) +lazy(a::Mul{<:LazyITensor}) = LazyITensor(a) + +dimnames(a::LazyITensor) = dimnames_lazy(a) +inds(a::LazyITensor) = inds_lazy(a) +denamed(a::LazyITensor) = denamed_lazy(a) + +# Broadcasting +function Base.BroadcastStyle(::Type{<:LazyITensor}) + return LazyITensorStyle() +end + +# Derived functionality. +function TermInterface.maketerm(type::Type{LazyITensor}, head, args, metadata) + return maketerm_lazy(type, head, args, metadata) +end +Base.getindex(a::LazyITensor, I::Int...) = getindex_lazy(a, I...) +TermInterface.arguments(a::LazyITensor) = arguments_lazy(a) +TermInterface.children(a::LazyITensor) = children_lazy(a) +TermInterface.head(a::LazyITensor) = head_lazy(a) +TermInterface.iscall(a::LazyITensor) = iscall_lazy(a) +TermInterface.isexpr(a::LazyITensor) = isexpr_lazy(a) +TermInterface.operation(a::LazyITensor) = operation_lazy(a) +TermInterface.sorted_arguments(a::LazyITensor) = sorted_arguments_lazy(a) +AbstractTrees.children(a::LazyITensor) = abstracttrees_children_lazy(a) +TermInterface.sorted_children(a::LazyITensor) = sorted_children_lazy(a) +ismul(a::LazyITensor) = ismul_lazy(a) +AbstractTrees.nodevalue(a::LazyITensor) = nodevalue_lazy(a) +Base.Broadcast.materialize(a::LazyITensor) = materialize_lazy(a) +Base.copy(a::LazyITensor) = copy_lazy(a) +Base.:(==)(a1::LazyITensor, a2::LazyITensor) = equals_lazy(a1, a2) +Base.isequal(a1::LazyITensor, a2::LazyITensor) = isequal_lazy(a1, a2) +Base.hash(a::LazyITensor, h::UInt64) = hash_lazy(a, h) +map_arguments(f, a::LazyITensor) = map_arguments_lazy(f, a) +substitute(a::LazyITensor, substitutions) = substitute_lazy(a, substitutions) +AbstractTrees.printnode(io::IO, a::LazyITensor) = printnode_lazy(io, a) +printnode_nameddims(io::IO, a::LazyITensor) = printnode_lazy(io, a) +Base.show(io::IO, a::LazyITensor) = show_lazy(io, a) +Base.show(io::IO, mime::MIME"text/plain", a::LazyITensor) = show_lazy(io, mime, a) +Base.:*(a::LazyITensor) = mul_lazy(a) +Base.:*(a1::LazyITensor, a2::LazyITensor) = mul_lazy(a1, a2) +Base.:+(a1::LazyITensor, a2::LazyITensor) = add_lazy(a1, a2) +Base.:-(a1::LazyITensor, a2::LazyITensor) = sub_lazy(a1, a2) +Base.:*(a1::Number, a2::LazyITensor) = mul_lazy(a1, a2) +Base.:*(a1::LazyITensor, a2::Number) = mul_lazy(a1, a2) +Base.:/(a1::LazyITensor, a2::Number) = div_lazy(a1, a2) +Base.:-(a::LazyITensor) = sub_lazy(a) diff --git a/src/lazyitensors/symbolicitensor.jl b/src/lazyitensors/symbolicitensor.jl new file mode 100644 index 0000000..1dff729 --- /dev/null +++ b/src/lazyitensors/symbolicitensor.jl @@ -0,0 +1,75 @@ +# Expression leaf with no array payload, so it defines no `denamed`/`getindex`. +# A symbolic tensor is a placeholder substituted with a real tensor before +# contraction, so it only needs what drives contraction-order selection: the +# `dimnames` and the index `size`s (the cost model uses lengths). `inds` is +# reconstructed as plain ranges of those sizes. Storing sizes and dimnames as +# fields rather than type parameters lets symbolic tensors of different rank +# share one concrete type so a flat `Mul` over them stays concretely typed. +struct SymbolicITensor{DimName, Name} <: AbstractITensor{DimName} + name::Name + size::Vector{Int} + dimnames::Vector{DimName} +end +function SymbolicITensor(symname, inds) + dnames = collect(name.(inds)) + DimName = isempty(inds) ? typeof(symname) : eltype(dnames) + sizes = Int[length(denamed(i)) for i in inds] + return SymbolicITensor{DimName, typeof(symname)}(symname, sizes, dnames) +end + +symname(a::SymbolicITensor) = getfield(a, :name) + +dimnames(a::SymbolicITensor) = getfield(a, :dimnames) +function inds(a::SymbolicITensor) + return named.(Tuple(Base.OneTo.(getfield(a, :size))), Tuple(getfield(a, :dimnames))) +end +dimnametype(::Type{<:SymbolicITensor{DimName}}) where {DimName} = DimName +Base.ndims(a::SymbolicITensor) = length(getfield(a, :dimnames)) + +function Base.:(==)(a::SymbolicITensor, b::SymbolicITensor) + return symname(a) == symname(b) && dimnames(a) == dimnames(b) +end +Base.isequal(a::SymbolicITensor, b::SymbolicITensor) = a == b +function Base.hash(a::SymbolicITensor, h::UInt64) + h = hash(:SymbolicITensor, h) + h = hash(symname(a), h) + return hash(dimnames(a), h) +end + +# Products build lazy expressions rather than contracting numerically. +Base.:*(a::SymbolicITensor, b::SymbolicITensor) = lazy(a) * lazy(b) +Base.:*(a::SymbolicITensor, b::LazyITensor) = lazy(a) * b +Base.:*(a::LazyITensor, b::SymbolicITensor) = a * lazy(b) + +issymbolic(a) = a isa SymbolicITensor +issymbolic(a::LazyITensor) = !iscall(a) && issymbolic(unwrap(a)) + +function Base.show(io::IO, a::SymbolicITensor) + print(io, symname(a)) + if ndims(a) > 0 + print(io, "[", join(dimnames(a), ","), "]") + end + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicITensor) + summary(io, a) + println(io, ":") + show(io, a) + return nothing +end + +using AbstractTrees: AbstractTrees +function AbstractTrees.printnode(io::IO, a::SymbolicITensor) + show(io, a) + return nothing +end + +function symnameddims(symname, dims) + return lazy(SymbolicITensor(symname, dims)) +end +symnameddims(name) = symnameddims(name, ()) + +function printnode_nameddims(io::IO, a::SymbolicITensor) + AbstractTrees.printnode(io, a) + return nothing +end diff --git a/test/Project.toml b/test/Project.toml index e85dbbe..6088dc5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,9 +14,12 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [sources.ITensorBase] path = ".." @@ -37,6 +40,9 @@ SafeTestsets = "0.1" StableRNGs = "1" Suppressor = "0.2" TensorAlgebra = "0.9.5" +TensorOperations = "5.3.1" +TermInterface = "2" Test = "1.10" UUIDs = "1.10" VectorInterface = "0.5, 0.6" +WrappedUnions = "0.3" diff --git a/test/test_abstracttreesext.jl b/test/test_abstracttrees.jl similarity index 85% rename from test/test_abstracttreesext.jl rename to test/test_abstracttrees.jl index 4660cb8..bce94e9 100644 --- a/test/test_abstracttreesext.jl +++ b/test/test_abstracttrees.jl @@ -2,7 +2,7 @@ using AbstractTrees: printnode using ITensorBase: nameddims using Test: @test, @testset -@testset "AbstractTreesExt" begin +@testset "AbstractTrees" begin a = randn(3, 4) na = nameddims(a, ("i", "j")) @test sprint(printnode, na) == "{\"i\", \"j\"}" diff --git a/test/test_lazyitensors.jl b/test/test_lazyitensors.jl new file mode 100644 index 0000000..f35e008 --- /dev/null +++ b/test/test_lazyitensors.jl @@ -0,0 +1,122 @@ +using AbstractTrees: AbstractTrees, print_tree, printnode +using Base.Broadcast: materialize +using ITensorBase: @names, Greedy, ITensor, LazyITensor, Mul, Optimal, SymbolicITensor, + dimnames, inds, ismul, lazy, nameddims, namedoneto, optimize_evaluation_order, + substitute, symnameddims +using TensorOperations: TensorOperations +using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, + sorted_arguments, sorted_children +using Test: @test, @test_broken, @test_throws, @testset +using WrappedUnions: unwrap + +@testset "LazyITensors" begin + @testset "Basics" begin + i, j, k, l = namedoneto.(2, (:i, :j, :k, :l)) + a1 = randn(i, j) + a2 = randn(j, k) + a3 = randn(k, l) + l1, l2, l3 = lazy.((a1, a2, a3)) + for li in (l1, l2, l3) + @test li isa LazyITensor + @test unwrap(li) isa ITensor + @test inds(li) == inds(unwrap(li)) + @test copy(li) == unwrap(li) + @test materialize(li) == unwrap(li) + end + l = l1 * l2 * l3 + @test copy(l) ≈ a1 * a2 * a3 + @test materialize(l) ≈ a1 * a2 * a3 + @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) + @test unwrap(l) isa Mul + @test ismul(unwrap(l)) + @test unwrap(l).arguments == [l1 * l2, l3] + # TermInterface.jl + @test operation(unwrap(l)) ≡ * + @test arguments(unwrap(l)) == [l1 * l2, l3] + end + + @testset "TermInterface" begin + a1 = nameddims(randn(2, 2), (:i, :j)) + a2 = nameddims(randn(2, 2), (:j, :k)) + a3 = nameddims(randn(2, 2), (:k, :l)) + l1, l2, l3 = lazy.((a1, a2, a3)) + + @test_throws ErrorException arguments(l1) + @test_throws ErrorException arity(l1) + @test_throws ErrorException children(l1) + @test_throws ErrorException head(l1) + @test !iscall(l1) + @test !isexpr(l1) + @test_throws ErrorException operation(l1) + @test_throws ErrorException sorted_arguments(l1) + @test_throws ErrorException sorted_children(l1) + @test AbstractTrees.children(l1) ≡ () + @test AbstractTrees.nodevalue(l1) ≡ a1 + @test sprint(show, l1) == sprint(show, a1) + # The leaf format mirrors ITensorBase's display of a tensor's index names. + @test sprint(printnode, l1) == "{\"i\", \"j\"}" + @test sprint(print_tree, l1) == "{\"i\", \"j\"}\n" + + l = l1 * l2 * l3 + @test arguments(l) == [l1 * l2, l3] + @test arity(l) == 2 + @test children(l) == [l1 * l2, l3] + @test head(l) ≡ * + @test iscall(l) + @test isexpr(l) + @test l == maketerm(LazyITensor, *, [l1 * l2, l3], nothing) + @test operation(l) ≡ * + @test sorted_arguments(l) == [l1 * l2, l3] + @test sorted_children(l) == [l1 * l2, l3] + @test AbstractTrees.children(l) == [l1 * l2, l3] + @test AbstractTrees.nodevalue(l) ≡ * + @test sprint(show, l) == "(({\"i\", \"j\"} * {\"j\", \"k\"}) * {\"k\", \"l\"})" + @test sprint(printnode, l) == "(({\"i\", \"j\"} * {\"j\", \"k\"}) * {\"k\", \"l\"})" + @test sprint(print_tree, l) == + "(({\"i\", \"j\"} * {\"j\", \"k\"}) * {\"k\", \"l\"})\n" * + "├─ ({\"i\", \"j\"} * {\"j\", \"k\"})\n" * + "│ ├─ {\"i\", \"j\"}\n│ └─ {\"j\", \"k\"}\n" * + "└─ {\"k\", \"l\"}\n" + end + + @testset "symnameddims" begin + a1, a2, a3 = symnameddims.((:a1, :a2, :a3)) + @test a1 isa LazyITensor + @test unwrap(a1) isa SymbolicITensor + @test unwrap(a1) == SymbolicITensor(:a1, ()) + @test isequal(unwrap(a1), SymbolicITensor(:a1, ())) + @test inds(a1) == () + @test isempty(dimnames(a1)) + + ex = a1 * a2 * a3 + @test copy(ex) == ex + @test arguments(ex) == [a1 * a2, a3] + @test operation(ex) ≡ * + @test sprint(show, ex) == "((a1 * a2) * a3)" + end + + @testset "substitute" begin + s = symnameddims.((:a1, :a2, :a3)) + i = @names i[1:4] + a = (randn(2, 2)[i[1], i[2]], randn(2, 2)[i[2], i[3]], randn(2, 2)[i[3], i[4]]) + l = lazy.(a) + + seq = s[1] * (s[2] * s[3]) + net = substitute(seq, s .=> l) + @test net == l[1] * (l[2] * l[3]) + @test arguments(net) == [l[1], l[2] * l[3]] + end + + @testset "optimize_evaluation_order ($alg)" for alg in (Greedy(), Optimal()) + i, j, k, l = namedoneto.((2, 3, 4, 5), (:i, :j, :k, :l)) + s = [symnameddims(:a, (i, j)), symnameddims(:b, (j, k)), symnameddims(:c, (k, l))] + flat = lazy(Mul(s)) + ordered = optimize_evaluation_order(flat; alg) + @test ordered isa LazyITensor + @test ismul(ordered) + # Reordering nests the flat product into binary contractions and preserves + # the open indices. + @test arity(ordered) == 2 + @test issetequal(dimnames(ordered), dimnames(flat)) + end +end