Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,43 +1,48 @@
name = "ITensorBase"
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
version = "0.6.2"
version = "0.6.3"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
projects = ["benchmark", "dev", "docs", "examples", "test"]

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"

[weakdeps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[extensions]
ITensorBaseAbstractTreesExt = "AbstractTrees"
ITensorBaseAdaptExt = "Adapt"
ITensorBaseBlockArraysExt = "BlockArrays"
ITensorBaseMooncakeExt = "Mooncake"
ITensorBaseTensorOperationsExt = "TensorOperations"

[compat]
AbstractTrees = "0.4.5"
Accessors = "0.1.39"
Adapt = "4.1.1"
ArrayLayouts = "1.11"
BlockArrays = "1.3"
Combinatorics = "1"
Compat = "4.16"
ConstructionBase = "1.6"
LinearAlgebra = "1.10"
Expand All @@ -46,7 +51,10 @@ OrderedCollections = "1.6"
Random = "1.10"
SimpleTraits = "0.9.4"
TensorAlgebra = "0.9.6"
TensorOperations = "5.3.1"
TermInterface = "2"
TupleTools = "1.6"
UUIDs = "1.10"
VectorInterface = "0.5, 0.6"
WrappedUnions = "0.3"
julia = "1.10"
13 changes: 0 additions & 13 deletions ext/ITensorBaseAbstractTreesExt/ITensorBaseAbstractTreesExt.jl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ITensorBaseTensorOperationsExt

using ITensorBase.TermInterface: arguments
using ITensorBase: ITensorBase, Optimal, denamed, inds, ismul, optimize_contraction_order,
substitute, symnameddims
using TensorOperations: TensorOperations, optimaltree

function contraction_tree_to_expr(f, tree)
return if !(tree isa AbstractVector)
f(tree)
else
prod(Base.Fix1(contraction_tree_to_expr, f), tree)
end
end

function ITensorBase.optimize_contraction_order(alg::Optimal, a)
@assert ismul(a)
ts = arguments(a)
inds_network = collect.(inds.(ts))
# Converting dims to Float64 to minimize overflow issues
inds_to_dims = Dict(i => Float64(length(denamed(i))) for i in reduce(∪, inds_network))
tree, _ = optimaltree(inds_network, inds_to_dims)
return contraction_tree_to_expr(i -> ts[i], tree)
end

end
10 changes: 10 additions & 0 deletions src/ITensorBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,14 @@ include("itensoroperator.jl")
include("index.jl")
include("quirks.jl")

# Lazy and symbolic ITensor expressions.
include("lazyitensors/baseextensions.jl")
include("lazyitensors/itensorbaseextensions.jl")
include("lazyitensors/applied.jl")
include("lazyitensors/lazyinterface.jl")
include("lazyitensors/lazybroadcast.jl")
include("lazyitensors/lazyitensor.jl")
include("lazyitensors/symbolicitensor.jl")
include("lazyitensors/evaluation_order.jl")

end
49 changes: 49 additions & 0 deletions src/lazyitensors/applied.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using AbstractTrees: AbstractTrees
using TermInterface: TermInterface, arguments, iscall, operation

# Generic functionality for Applied types, like `Mul`, `Add`, etc.
ismul(a) = iscall(a) && operation(a) ≡ *
head_applied(a) = operation(a)
iscall_applied(a) = true
isexpr_applied(a) = iscall(a)
function show_applied(io::IO, a)
args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a))
print(io, "(", join(args, " $(operation(a)) "), ")")
return nothing
end
sorted_arguments_applied(a) = arguments(a)
children_applied(a) = arguments(a)
sorted_children_applied(a) = sorted_arguments(a)
function maketerm_applied(type, head, args, metadata)
term = type(args)
@assert head ≡ operation(term)
return term
end
map_arguments_applied(f, a) = Base.typename(typeof(a)).wrapper(map(f, arguments(a)))
function hash_applied(a, h::UInt64)
h = hash(Symbol(Base.typename(typeof(a)).wrapper), h)
for arg in arguments(a)
h = hash(arg, h)
end
return h
end

abstract type Applied end
TermInterface.head(a::Applied) = head_applied(a)
TermInterface.iscall(a::Applied) = iscall_applied(a)
TermInterface.isexpr(a::Applied) = isexpr_applied(a)
Base.show(io::IO, a::Applied) = show_applied(io, a)
TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a)
TermInterface.children(a::Applied) = children_applied(a)
TermInterface.sorted_children(a::Applied) = sorted_children_applied(a)
function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata)
return maketerm_applied(type, head, args, metadata)
end
map_arguments(f, a::Applied) = map_arguments_applied(f, a)
Base.hash(a::Applied, h::UInt64) = hash_applied(a, h)

struct Mul{A} <: Applied
arguments::Vector{A}
end
TermInterface.arguments(m::Mul) = getfield(m, :arguments)
TermInterface.operation(m::Mul) = *
3 changes: 3 additions & 0 deletions src/lazyitensors/baseextensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
generic_map(f, v) = map(f, v)
generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v)))
generic_map(f, v::AbstractSet) = Set([f(x) for x in v])
110 changes: 110 additions & 0 deletions src/lazyitensors/evaluation_order.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
using TermInterface: arguments, arity, operation

# The time complexity of evaluating `f(args...)`.
function time_complexity(f, args...)
return error("Not implemented.")
end
# The space complexity of evaluating `f(args...)`.
function space_complexity(f, args...)
return error("Not implemented.")
end
# The space complexity of `args`.
function input_space_complexity(f, args...)
return error("Not implemented.")
end

function time_complexity(
::typeof(*), t1::AbstractITensor, t2::AbstractITensor
)
return prod(length ∘ denamed, (inds(t1) ∪ inds(t2)))
end
function time_complexity(
::typeof(+), t1::AbstractITensor, t2::AbstractITensor
)
@assert issetequal(dimnames(t1), dimnames(t2))
return prod(denamed, size(t1))
end
function time_complexity(::typeof(*), c::Number, t::AbstractITensor)
return prod(denamed, size(t))
end
function time_complexity(::typeof(*), t::AbstractITensor, c::Number)
return time_complexity(*, c, t)
end

function evaluation_time_complexity(a)
t = Ref(0)
opwalk(a) do f
return function (args...)
t[] += time_complexity(f, args...)
return f(args...)
end
end
return t[]
end

# The workspace complexity of evaluating expression.
function evaluation_space_complexity(a)
# TODO: Walk the expression and call `space_complexity` on each node.
return error("Not implemented.")
end
# The complexity of storing the arguments of the expression.
function argument_space_complexity(a)
# TODO: Walk the expression and call `input_space_complexity` on each node.
return error("Not implemented.")
end

# Flatten a nested expression down to a flat expression,
# removing information about the order of operations.
function flatten_expression(a)
if !iscall(a)
return a
elseif ismul(a)
flattened_arguments = mapreduce(to_mul_arguments, vcat, arguments(a))
return lazy(Mul(flattened_arguments))
else
return error("Variant not supported.")
end
end

function optimize_evaluation_order(alg, a)
if !iscall(a)
return a
elseif ismul(a)
return optimize_contraction_order(alg, a)
else
# TODO: Recurse into other operations, calling `optimize_evaluation_order`.
return error("Variant not supported.")
end
end

function optimize_evaluation_order(
a; alg = default_optimize_evaluation_order_alg(a)
)
return optimize_evaluation_order(alg, a)
end

abstract type EvaluationOrderAlgorithm end
struct Greedy <: EvaluationOrderAlgorithm end
# `Optimal` finds the cost-optimal contraction order. The method is provided by
# the TensorOperations extension.
struct Optimal <: EvaluationOrderAlgorithm end
default_optimize_evaluation_order_alg(a) = Greedy()

function optimize_contraction_order(alg, a)
return error("`alg = $alg` not supported.")
end

using Combinatorics: combinations
function optimize_contraction_order(alg::Greedy, a)
@assert ismul(a)
arity(a) in (1, 2) && return a
a1, a2 = argmin(combinations(arguments(a), 2)) do (a1, a2)
# Penalize outer product contractions.
# TODO: Still order the outer products by time complexity,
# say by checking if there are only outer products left.
isdisjoint(dimnames(a1), dimnames(a2)) && return typemax(Int)
return time_complexity(*, a1, a2)
end
contracted_arguments = [filter(∉((a1, a2)), arguments(a)); [a1 * a2]]
return optimize_contraction_order(alg, lazy(Mul(contracted_arguments)))
end
29 changes: 29 additions & 0 deletions src/lazyitensors/itensorbaseextensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Defined to avoid type piracy.
# TODO: Define a proper hash function
# in ITensorBase.jl, maybe one that is
# independent of the order of dimensions.
function _hash(a::ITensor, h::UInt64)
h = hash(:ITensor, h)
h = hash(denamed(a), h)
for i in inds(a)
h = hash(i, h)
end
return h
end
function _hash(x, h::UInt64)
return hash(x, h)
end

using AbstractTrees: AbstractTrees
# Only print the dimension names when printing with `AbstractTrees.print_tree`.
function AbstractTrees.printnode(io::IO, a::AbstractITensor)
dimnames_a = "{" * join(map(s -> "\"$s\"", dimnames(a)), ", ") * "}"
print(io, dimnames_a)
return nothing
end

# Custom version of `AbstractTrees.printnode` to
# avoid type piracy when overloading on `AbstractITensor`.
# Method specializations (`LazyITensor`, `SymbolicITensor`) live in
# `lazyitensor.jl` and `symbolicitensor.jl`.
printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x)
13 changes: 13 additions & 0 deletions src/lazyitensors/lazybroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Lazy broadcasting.
struct LazyITensorStyle <: Base.Broadcast.AbstractArrayStyle{Any} end
function Broadcast.broadcasted(::LazyITensorStyle, f, as...)
return error("Arbitrary broadcasting not supported for LazyITensor.")
end
# Linear operations.
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(+), a1, a2) = a1 + a2
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a1, a2) = a1 - a2
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), c::Number, a) = c * a
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a, c::Number) = a * c
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a::Number, b::Number) = a * b
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(/), a, c::Number) = a / c
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a) = -a
Loading