1- using BlockArrays: blocks
1+ using BlockArrays: blocks, eachblockaxes1
22using BlockSparseArrays: BlockSparseArray, blockreshape
3- using GradedArrays:
4- AbstractGradedUnitRange,
5- SectorRange,
6- GradedArray,
7- flip,
8- gradedrange,
9- invblockperm,
10- sectormergesortperm,
11- sectorsortperm,
12- trivial,
13- unmerged_tensor_product
14- using TensorAlgebra:
15- TensorAlgebra,
16- ⊗ ,
17- AbstractBlockPermutation,
18- BlockedTuple,
19- FusionStyle,
20- trivial_axis,
21- unmatricize
3+ using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockReshapeFusion,
4+ BlockedTuple, FusionStyle, ReshapeFusion, matricize, matricize_axes,
5+ tensor_product_axis, unmatricize
226
237struct SectorFusion <: FusionStyle end
248
259TensorAlgebra. FusionStyle (:: Type{<:GradedArray} ) = SectorFusion ()
2610
27- function TensorAlgebra. trivial_axis (t:: Tuple{Vararg{G}} ) where {G <: AbstractGradedUnitRange }
11+ function TensorAlgebra. trivial_axis (
12+ :: BlockReshapeFusion ,
13+ :: Val{:codomain} ,
14+ a:: GradedArray ,
15+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
16+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
17+ )
18+ return trivial_gradedrange (axes (a))
19+ end
20+ function TensorAlgebra. trivial_axis (
21+ :: BlockReshapeFusion ,
22+ :: Val{:domain} ,
23+ a:: GradedArray ,
24+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
25+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
26+ )
27+ return flip (trivial_gradedrange (axes (a)))
28+ end
29+ function trivial_gradedrange (t:: Tuple{Vararg{G}} ) where {G <: AbstractGradedUnitRange }
2830 return trivial (first (t))
2931end
3032# heterogeneous sectors
31- TensorAlgebra . trivial_axis (t:: Tuple{Vararg{AbstractGradedUnitRange}} ) = ⊗ (trivial .(t)... )
33+ trivial_gradedrange (t:: Tuple{Vararg{AbstractGradedUnitRange}} ) = ⊗ (trivial .(t)... )
3234# trivial_axis from sector_type
33- function TensorAlgebra . trivial_axis (:: Type{S} ) where {S <: SectorRange }
35+ function trivial_gradedrange (:: Type{S} ) where {S <: SectorRange }
3436 return gradedrange ([trivial (S) => 1 ])
3537end
3638
37- function matricize_axes (
38- blocked_axes :: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}
39+ function TensorAlgebra . tensor_product_axis (
40+ :: ReshapeFusion , :: Val{:codomain} , r1 :: SectorUnitRange , r2 :: SectorUnitRange
3941 )
40- @assert ! isempty (blocked_axes)
41- default_axis = trivial_axis (Tuple (blocked_axes))
42- codomain_axes, domain_axes = blocks (blocked_axes)
43- codomain_axis = unmerged_tensor_product (default_axis, codomain_axes... )
44- unflipped_domain_axis = unmerged_tensor_product (default_axis, domain_axes... )
45- return codomain_axis, flip (unflipped_domain_axis)
42+ return r1 ⊗ r2
43+ end
44+ function TensorAlgebra. tensor_product_axis (
45+ :: ReshapeFusion , :: Val{:domain} , r1:: SectorUnitRange , r2:: SectorUnitRange
46+ )
47+ return flip (r1 ⊗ r2)
48+ end
49+ function TensorAlgebra. tensor_product_axis (
50+ style:: BlockReshapeFusion ,
51+ side:: Val{:codomain} ,
52+ r1:: AbstractGradedUnitRange ,
53+ r2:: AbstractGradedUnitRange ,
54+ )
55+ return tensor_product_gradedrange (style, side, r1, r2)
56+ end
57+ function TensorAlgebra. tensor_product_axis (
58+ style:: BlockReshapeFusion ,
59+ side:: Val{:domain} ,
60+ r1:: AbstractGradedUnitRange ,
61+ r2:: AbstractGradedUnitRange ,
62+ )
63+ return tensor_product_gradedrange (style, side, r1, r2)
64+ end
65+ # TODO : Could this call out to a generic tensor_product_axis for AbstractBlockedUnitRange?
66+ function tensor_product_gradedrange (
67+ :: BlockReshapeFusion ,
68+ side:: Val ,
69+ r1:: AbstractUnitRange ,
70+ r2:: AbstractUnitRange ,
71+ )
72+ (isone (first (r1)) && isone (first (r2))) ||
73+ throw (ArgumentError (" Only one-based axes are supported" ))
74+ blockaxpairs = Iterators. product (eachblockaxes1 (r1), eachblockaxes1 (r2))
75+ blockaxs = map (blockaxpairs) do (b1, b2)
76+ # TODO : Store a FusionStyle for the blocks in `BlockReshapeFusion`
77+ # and use that here.
78+ return tensor_product_axis (side, b1, b2)
79+ end
80+ return mortar_axis (vec (blockaxs))
4681end
4782
48- using TensorAlgebra: blockedtrivialperm
4983function TensorAlgebra. matricize (
50- :: SectorFusion , a:: AbstractArray , codomain_length :: Val , domain_length :: Val
84+ :: SectorFusion , a:: AbstractArray , length_codomain :: Val
5185 )
52- biperm = blockedtrivialperm ((codomain_length, domain_length))
53- codomain_axis, domain_axis = matricize_axes (axes (a)[biperm])
54- a_reshaped = blockreshape (a, (codomain_axis, domain_axis))
55- # Sort the blocks by sector and merge the equivalent sectors.
86+ a_reshaped = matricize (BlockReshapeFusion (), a, length_codomain)
5687 return sectormergesort (a_reshaped)
5788end
5889
@@ -74,7 +105,7 @@ function TensorAlgebra.unmatricize(
74105
75106 # First, fuse axes to get `sectormergesortperm`.
76107 # Then unpermute the blocks.
77- fused_axes = matricize_axes (blocked_axes )
108+ fused_axes = matricize_axes (BlockReshapeFusion (), m, codomain_axes, domain_axes )
78109
79110 blockperms = sectorsortperm .(fused_axes)
80111 sorted_axes = map ((r, I) -> only (axes (r[I])), fused_axes, blockperms)
0 commit comments