Skip to content

Commit 722a24a

Browse files
authored
Upgrade to TensorAlgebra v0.6 (#83)
1 parent 8173240 commit 722a24a

File tree

11 files changed

+169
-127
lines changed

11 files changed

+169
-127
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3-
version = "0.5.4"
43
authors = ["ITensor developers <support@itensor.org> and contributors"]
4+
version = "0.5.5"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -16,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
1717
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1818
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
19-
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2019
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2120

2221
[weakdeps]
@@ -37,8 +36,7 @@ MatrixAlgebraKit = "0.6"
3736
Random = "1.10"
3837
SUNRepresentations = "0.3"
3938
SplitApplyCombine = "1.2.3"
40-
TensorAlgebra = "0.5"
39+
TensorAlgebra = "0.6.2"
4140
TensorKitSectors = "0.1, 0.2"
42-
TensorProducts = "0.1.3"
4341
TypeParameterAccessors = "0.4"
4442
julia = "1.10"

src/fusion.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
using BlockArrays: Block, blocks
22
using SplitApplyCombine: groupcount
3-
using TensorProducts: TensorProducts, , OneToOne, tensor_product
43

54
flip_dual(r::AbstractUnitRange) = isdual(r) ? flip(r) : r
65

7-
# TensorProducts interface
8-
function TensorProducts.tensor_product(sr1::SectorUnitRange, sr2::SectorUnitRange)
6+
function tensor_product(sr1::SectorUnitRange, sr2::SectorUnitRange)
97
return tensor_product(combine_styles(SymmetryStyle(sr1), SymmetryStyle(sr2)), sr1, sr2)
108
end
119

12-
function TensorProducts.tensor_product(
10+
function tensor_product(
1311
::AbelianStyle, sr1::SectorUnitRange, sr2::SectorUnitRange
1412
)
1513
s = sector(flip_dual(sr1)) sector(flip_dual(sr2))
1614
return sectorrange(s, sector_multiplicity(sr1) * sector_multiplicity(sr2))
1715
end
1816

19-
function TensorProducts.tensor_product(
17+
function tensor_product(
2018
::NotAbelianStyle, sr1::SectorUnitRange, sr2::SectorUnitRange
2119
)
2220
g0 = sector(flip_dual(sr1)) sector(flip_dual(sr2))
@@ -27,41 +25,52 @@ function TensorProducts.tensor_product(
2725
end
2826

2927
# allow to fuse a Sector with a GradedUnitRange
30-
function TensorProducts.tensor_product(
28+
function tensor_product(
3129
s::Union{SectorRange, SectorUnitRange}, g::AbstractGradedUnitRange
3230
)
3331
return to_gradedrange(s) g
3432
end
3533

36-
function TensorProducts.tensor_product(
34+
function tensor_product(
3735
g::AbstractGradedUnitRange, s::Union{SectorRange, SectorUnitRange}
3836
)
3937
return g to_gradedrange(s)
4038
end
4139

42-
function TensorProducts.tensor_product(sr::SectorUnitRange, s::SectorRange)
40+
function tensor_product(sr::SectorUnitRange, s::SectorRange)
4341
return sr sectorrange(s, 1)
4442
end
4543

46-
function TensorProducts.tensor_product(s::SectorRange, sr::SectorUnitRange)
44+
function tensor_product(s::SectorRange, sr::SectorUnitRange)
4745
return sectorrange(s, 1) sr
4846
end
4947

48+
function tensor_product(r1::AbstractUnitRange, r2::AbstractUnitRange)
49+
(isone(first(r1)) && isone(first(r2))) ||
50+
throw(ArgumentError("Only one-based axes are supported"))
51+
return Base.OneTo(length(r1) * length(r2))
52+
end
53+
54+
function tensor_product(
55+
r1::AbstractUnitRange, r2::AbstractUnitRange, r3::AbstractUnitRange,
56+
rs::AbstractUnitRange...,
57+
)
58+
return tensor_product(tensor_product(r1, r2), r3, rs...)
59+
end
60+
5061
# unmerged_tensor_product is a private function needed in GradedArraysTensorAlgebraExt
5162
# to get block permutation
5263
# it is not aimed for generic use and does not support all tensor_product methods (no dispatch on SymmetryStyle)
53-
unmerged_tensor_product() = OneToOne()
64+
unmerged_tensor_product() = Base.OneTo(1)
5465
unmerged_tensor_product(a) = a
55-
unmerged_tensor_product(a, ::OneToOne) = a
56-
unmerged_tensor_product(::OneToOne, a) = a
57-
unmerged_tensor_product(::OneToOne, ::OneToOne) = OneToOne()
58-
function unmerged_tensor_product(a1, a2, as...)
59-
return unmerged_tensor_product(unmerged_tensor_product(a1, a2), as...)
66+
function unmerged_tensor_product(a1, a2, a3, as...)
67+
return unmerged_tensor_product(unmerged_tensor_product(a1, a2), a3, as...)
6068
end
6169

6270
# default to tensor_product
6371
unmerged_tensor_product(a1, a2) = a1 a2
6472

73+
using BlockSparseArrays: mortar_axis
6574
function unmerged_tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
6675
new_axes = map(splat(), Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),)))
6776
return mortar_axis(new_axes)
@@ -103,9 +112,9 @@ end
103112
sectormergesort(g::AbstractUnitRange) = g
104113

105114
# tensor_product produces a sorted, non-dual GradedUnitRange
106-
TensorProducts.tensor_product(g::AbstractGradedUnitRange) = sectormergesort(flip_dual(g))
115+
tensor_product(g::AbstractGradedUnitRange) = sectormergesort(flip_dual(g))
107116

108-
function TensorProducts.tensor_product(
117+
function tensor_product(
109118
g1::AbstractGradedUnitRange, g2::AbstractGradedUnitRange
110119
)
111120
return sectormergesort(unmerged_tensor_product(g1, g2))

src/sectorrange.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# This file defines the interface for type Sector
22
# all fusion categories (Z{2}, SU2, Ising...) are subtypes of Sector
3-
using TensorProducts: TensorProducts,
43
import TensorKitSectors as TKS
54

65
"""
@@ -122,17 +121,19 @@ function fusion_rule(r1::SectorRange, r2::SectorRange)
122121
)
123122
end
124123

125-
# ============================= TensorProducts interface =====--==========================
124+
# ============================= Tensor products ==========================================
126125

127-
TensorProducts.tensor_product(s::SectorRange) = s
128-
TensorProducts.tensor_product(c1::SectorRange, c2::SectorRange) = fusion_rule(c1, c2)
129-
function TensorProducts.tensor_product(c1::TKS.Sector, c2::TKS.Sector)
126+
function tensor_product end
127+
const = tensor_product
128+
tensor_product(s::SectorRange) = s
129+
tensor_product(c1::SectorRange, c2::SectorRange) = fusion_rule(c1, c2)
130+
function tensor_product(c1::TKS.Sector, c2::TKS.Sector)
130131
return tensor_product(to_sector(c1), to_sector(c2))
131132
end
132-
function TensorProducts.tensor_product(c1::SectorRange, c2::TKS.Sector)
133+
function tensor_product(c1::SectorRange, c2::TKS.Sector)
133134
return tensor_product(c1, to_sector(c2))
134135
end
135-
function TensorProducts.tensor_product(c1::TKS.Sector, c2::SectorRange)
136+
function tensor_product(c1::TKS.Sector, c2::SectorRange)
136137
return tensor_product(to_sector(c1), c2)
137138
end
138139

src/tensoralgebra.jl

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,89 @@
1-
using BlockArrays: blocks
1+
using BlockArrays: blocks, eachblockaxes1
22
using BlockSparseArrays: BlockSparseArray, blockreshape
3-
using GradedArrays:
4-
AbstractGradedUnitRange,
5-
SectorRange,
6-
GradedArray,
7-
flip,
8-
gradedrange,
9-
invblockperm,
10-
sectormergesortperm,
11-
sectorsortperm,
12-
trivial,
13-
unmerged_tensor_product
14-
using TensorAlgebra:
15-
TensorAlgebra,
16-
,
17-
AbstractBlockPermutation,
18-
BlockedTuple,
19-
FusionStyle,
20-
trivial_axis,
21-
unmatricize
3+
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockReshapeFusion,
4+
BlockedTuple, FusionStyle, ReshapeFusion, matricize, matricize_axes,
5+
tensor_product_axis, unmatricize
226

237
struct SectorFusion <: FusionStyle end
248

259
TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion()
2610

27-
function TensorAlgebra.trivial_axis(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange}
11+
function TensorAlgebra.trivial_axis(
12+
::BlockReshapeFusion,
13+
::Val{:codomain},
14+
a::GradedArray,
15+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
16+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
17+
)
18+
return trivial_gradedrange(axes(a))
19+
end
20+
function TensorAlgebra.trivial_axis(
21+
::BlockReshapeFusion,
22+
::Val{:domain},
23+
a::GradedArray,
24+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
25+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
26+
)
27+
return flip(trivial_gradedrange(axes(a)))
28+
end
29+
function trivial_gradedrange(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange}
2830
return trivial(first(t))
2931
end
3032
# heterogeneous sectors
31-
TensorAlgebra.trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = (trivial.(t)...)
33+
trivial_gradedrange(t::Tuple{Vararg{AbstractGradedUnitRange}}) = (trivial.(t)...)
3234
# trivial_axis from sector_type
33-
function TensorAlgebra.trivial_axis(::Type{S}) where {S <: SectorRange}
35+
function trivial_gradedrange(::Type{S}) where {S <: SectorRange}
3436
return gradedrange([trivial(S) => 1])
3537
end
3638

37-
function matricize_axes(
38-
blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}
39+
function TensorAlgebra.tensor_product_axis(
40+
::ReshapeFusion, ::Val{:codomain}, r1::SectorUnitRange, r2::SectorUnitRange
3941
)
40-
@assert !isempty(blocked_axes)
41-
default_axis = trivial_axis(Tuple(blocked_axes))
42-
codomain_axes, domain_axes = blocks(blocked_axes)
43-
codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...)
44-
unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...)
45-
return codomain_axis, flip(unflipped_domain_axis)
42+
return r1 r2
43+
end
44+
function TensorAlgebra.tensor_product_axis(
45+
::ReshapeFusion, ::Val{:domain}, r1::SectorUnitRange, r2::SectorUnitRange
46+
)
47+
return flip(r1 r2)
48+
end
49+
function TensorAlgebra.tensor_product_axis(
50+
style::BlockReshapeFusion,
51+
side::Val{:codomain},
52+
r1::AbstractGradedUnitRange,
53+
r2::AbstractGradedUnitRange,
54+
)
55+
return tensor_product_gradedrange(style, side, r1, r2)
56+
end
57+
function TensorAlgebra.tensor_product_axis(
58+
style::BlockReshapeFusion,
59+
side::Val{:domain},
60+
r1::AbstractGradedUnitRange,
61+
r2::AbstractGradedUnitRange,
62+
)
63+
return tensor_product_gradedrange(style, side, r1, r2)
64+
end
65+
# TODO: Could this call out to a generic tensor_product_axis for AbstractBlockedUnitRange?
66+
function tensor_product_gradedrange(
67+
::BlockReshapeFusion,
68+
side::Val,
69+
r1::AbstractUnitRange,
70+
r2::AbstractUnitRange,
71+
)
72+
(isone(first(r1)) && isone(first(r2))) ||
73+
throw(ArgumentError("Only one-based axes are supported"))
74+
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
75+
blockaxs = map(blockaxpairs) do (b1, b2)
76+
# TODO: Store a FusionStyle for the blocks in `BlockReshapeFusion`
77+
# and use that here.
78+
return tensor_product_axis(side, b1, b2)
79+
end
80+
return mortar_axis(vec(blockaxs))
4681
end
4782

48-
using TensorAlgebra: blockedtrivialperm
4983
function TensorAlgebra.matricize(
50-
::SectorFusion, a::AbstractArray, codomain_length::Val, domain_length::Val
84+
::SectorFusion, a::AbstractArray, length_codomain::Val
5185
)
52-
biperm = blockedtrivialperm((codomain_length, domain_length))
53-
codomain_axis, domain_axis = matricize_axes(axes(a)[biperm])
54-
a_reshaped = blockreshape(a, (codomain_axis, domain_axis))
55-
# Sort the blocks by sector and merge the equivalent sectors.
86+
a_reshaped = matricize(BlockReshapeFusion(), a, length_codomain)
5687
return sectormergesort(a_reshaped)
5788
end
5889

@@ -74,7 +105,7 @@ function TensorAlgebra.unmatricize(
74105

75106
# First, fuse axes to get `sectormergesortperm`.
76107
# Then unpermute the blocks.
77-
fused_axes = matricize_axes(blocked_axes)
108+
fused_axes = matricize_axes(BlockReshapeFusion(), m, codomain_axes, domain_axes)
78109

79110
blockperms = sectorsortperm.(fused_axes)
80111
sorted_axes = map((r, I) -> only(axes(r[I])), fused_axes, blockperms)

test/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1212
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1313
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1414
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
15-
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1615
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1716
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1817

@@ -31,8 +30,7 @@ SUNRepresentations = "0.3"
3130
SafeTestsets = "0.1"
3231
SparseArraysBase = "0.7"
3332
Suppressor = "0.2.8"
34-
TensorAlgebra = "0.5"
33+
TensorAlgebra = "0.6"
3534
TensorKitSectors = "0.1, 0.2"
36-
TensorProducts = "0.1.3"
3735
Test = "1.10"
3836
TestExtras = "0.3.1"

test/test_fusion_rule.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ using GradedArrays:
66
TrivialSector,
77
U1,
88
Z,
9+
,
910
dual,
1011
flip,
1112
gradedrange,
1213
nsymbol,
1314
quantum_dimension,
1415
space_isequal,
16+
tensor_product,
1517
trivial,
1618
unmerged_tensor_product
17-
using TensorProducts: , tensor_product
1819
using SUNRepresentations: SUNIrrep
1920
using Test: @test, @test_throws, @testset
2021
using TestExtras: @constinferred

test/test_interface.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,8 @@ using BlockArrays: BlockedOneTo, blockedrange, blockisequal
33
using GradedArrays:
44
NoSector, dag, dual, flip, isdual, map_sectors, sectors, space_isequal, ungrade
55
using Test: @test, @testset
6-
using TensorProducts: OneToOne
76

87
@testset "GradedUnitRange interface for AbstractUnitRange" begin
9-
a0 = OneToOne()
10-
@test !isdual(a0)
11-
@test dual(a0) isa OneToOne
12-
@test space_isequal(a0, a0)
13-
@test space_isequal(a0, dual(a0))
14-
@test only(sectors(a0)) == NoSector()
15-
@test ungrade(a0) === a0
16-
@test map_sectors(identity, a0) === a0
17-
@test dag(a0) === a0
18-
198
a = 1:3
209
ad = dual(a)
2110
af = flip(a)

test/test_sectorproduct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using GradedArrays:
55
TrivialSector,
66
U1,
77
Z,
8+
,
89
×,
910
arguments,
1011
dual,
@@ -17,7 +18,6 @@ using GradedArrays:
1718
sectorrange,
1819
space_isequal,
1920
trivial
20-
using TensorProducts:
2121
using Test: @test, @test_broken, @test_throws, @testset
2222
using TestExtras: @constinferred
2323
using BlockArrays: blocklengths

0 commit comments

Comments
 (0)