@@ -3,39 +3,40 @@ import Yao: expect, content, chcontent, mat, apply!
33using StatsBase
44
55# ############ General Rotor ############
6- const Rotor{N, T} = Union{RotationGate{N, T}, PutBlock{N, <: Any , <: RotationGate , <: Complex{ T} }}
6+ const Rotor{N, T} = Union{RotationGate{N, T}, PutBlock{N, <: Any , <: RotationGate{<:Any, T} }}
77const CphaseGate{N, T} = ControlBlock{N,<: ShiftGate{T} ,<: Any }
88const DiffBlock{N, T} = Union{Rotor{N, T}, CphaseGate{N, T}}
99"""
10- generator(rot::Rotor) -> MatrixBlock
10+ generator(rot::Rotor) -> AbstractBlock
1111
1212Return the generator of rotation block.
1313"""
1414generator (rot:: RotationGate ) = rot. block
1515generator (rot:: PutBlock{N, C, GT} ) where {N, C, GT<: RotationGate } = PutBlock {N} (generator (rot|> content), rot |> occupied_locs)
1616generator (c:: CphaseGate{N} ) where N = ControlBlock (N, c. ctrol_locs, ctrl_config, control (2 ,1 ,2 => Z), c. locs)
1717
18- abstract type AbstractDiff{GT, N, T} <: TagBlock{GT, N, T } end
18+ abstract type AbstractDiff{GT, N, T} <: TagBlock{GT, N} end
1919Base. adjoint (df:: AbstractDiff ) = Daggered (df)
2020
2121istraitkeeper (:: AbstractDiff ) = Val (true )
2222
2323# ################### The Basic Diff #################
2424"""
25- QDiff{GT, N, T } <: AbstractDiff{GT, N, Complex{T} }
25+ QDiff{GT, N} <: AbstractDiff{GT, N, T }
2626 QDiff(block) -> QDiff
2727
2828Mark a block as quantum differentiable.
2929"""
30- mutable struct QDiff{GT, N, T} <: AbstractDiff{GT, N, Complex{T} }
30+ mutable struct QDiff{GT, N, T} <: AbstractDiff{GT, N, T }
3131 block:: GT
3232 grad:: T
3333 QDiff (block:: DiffBlock{N, T} ) where {N, T} = new {typeof(block), N, T} (block, T (0 ))
3434end
3535content (cb:: QDiff ) = cb. block
3636chcontent (cb:: QDiff , blk:: DiffBlock ) = QDiff (blk)
3737
38- @forward QDiff. block mat, apply!
38+ @forward QDiff. block apply!
39+ mat (:: Type{T} , df:: QDiff ) where T = mat (T, df. block)
3940Base. adjoint (df:: QDiff ) = QDiff (content (df)' )
4041
4142function YaoBlocks. print_annotation (io:: IO , df:: QDiff )
@@ -52,19 +53,19 @@ Mark a block as differentiable, here `GT`, `PT` is gate type, parameter type.
5253Warning:
5354 please don't use the `adjoint` after `BPDiff`! `adjoint` is reserved for special purpose! (back propagation)
5455"""
55- mutable struct BPDiff{GT, N, T, PT } <: AbstractDiff{GT, N, T}
56+ mutable struct BPDiff{GT, N, T} <: AbstractDiff{GT, N, T}
5657 block:: GT
57- grad:: PT
58+ grad:: T
5859 input:: AbstractRegister
59- BPDiff (block:: MatrixBlock{N, T } , grad:: PT ) where {N, T, PT } = new {typeof(block), N, T, typeof(grad) } (block, grad)
60+ BPDiff (block:: AbstractBlock{N } , grad:: T ) where {N, T} = new {typeof(block), N, T} (block, grad)
6061end
61- BPDiff (block:: MatrixBlock ) = BPDiff (block, zeros (parameters_eltype (block), nparameters (block)))
62+ BPDiff (block:: AbstractBlock ) = BPDiff (block, zeros (parameters_eltype (block), nparameters (block)))
6263BPDiff (block:: DiffBlock{N, T} ) where {N, T} = BPDiff (block, T (0 ))
6364
6465content (cb:: BPDiff ) = cb. block
65- chcontent (cb:: BPDiff , blk:: MatrixBlock ) = BPDiff (blk)
66+ chcontent (cb:: BPDiff , blk:: AbstractBlock ) = BPDiff (blk)
6667
67- @forward BPDiff. block mat
68+ mat ( :: Type{T} , df :: BPDiff ) where T = mat (T, df . block)
6869function apply! (reg:: AbstractRegister , df:: BPDiff )
6970 if isdefined (df, :input )
7071 copyto! (df. input, reg)
@@ -105,7 +106,7 @@ autodiff(mode::Symbol, block::AbstractBlock) = autodiff(Val(mode), block)
105106autodiff (:: Val{:BP} , block:: DiffBlock ) = BPDiff (block)
106107autodiff (:: Val{:BP} , block:: AbstractBlock ) = block
107108# Sequential, Roller and ChainBlock can propagate.
108- function autodiff (mode:: Val{:BP} , blk:: Union{ChainBlock, Roller, Sequential} )
109+ function autodiff (mode:: Val{:BP} , blk:: Union{ChainBlock, Sequential} )
109110 chsubblocks (blk, autodiff .(mode, subblocks (blk)))
110111end
111112
@@ -148,11 +149,11 @@ Numeric differentiation.
148149end
149150
150151"""
151- opdiff(psifunc, diffblock::AbstractDiff, op::MatrixBlock )
152+ opdiff(psifunc, diffblock::AbstractDiff, op::AbstractBlock )
152153
153154Operator differentiation.
154155"""
155- @inline function opdiff (psifunc, diffblock:: AbstractDiff , op:: MatrixBlock )
156+ @inline function opdiff (psifunc, diffblock:: AbstractDiff , op:: AbstractBlock )
156157 r1, r2 = _perturb (()-> expect (op, psifunc ()) |> real, diffblock, π/ 2 )
157158 diffblock. grad = (r2 - r1)/ 2
158159end
@@ -215,14 +216,14 @@ end
215216end
216217
217218"""
218- backward!(δ::AbstractRegister, circuit::MatrixBlock ) -> AbstractRegister
219+ backward!(δ::AbstractRegister, circuit::AbstractBlock ) -> AbstractRegister
219220
220221back propagate and calculate the gradient ∂f/∂θ = 2*Re(∂f/∂ψ*⋅∂ψ*/∂θ), given ∂f/∂ψ*.
221222
222223Note:
223224Here, the input circuit should be a matrix block, otherwise the back propagate may not apply (like Measure operations).
224225"""
225- backward! (δ:: AbstractRegister , circuit:: MatrixBlock ) = apply! (δ, circuit' )
226+ backward! (δ:: AbstractRegister , circuit:: AbstractBlock ) = apply! (δ, circuit' )
226227
227228"""
228229 gradient(circuit::AbstractBlock, mode::Symbol=:ANY) -> Vector
0 commit comments