|
| 1 | +export Rotor, generator, AbstractDiff, BPDiff, QDiff, backward!, gradient, CPhaseGate, DiffBlock |
| 2 | +import Yao: expect, content, chcontent |
| 3 | +using StatsBase |
| 4 | + |
| 5 | +############# General Rotor ############ |
# A rotor is a rotation gate, either bare or wrapped in a `PutBlock`.
const Rotor{N, T} = Union{RotationGate{N, T}, PutBlock{N, <:Any, <:RotationGate, <:Complex{T}}}
# A control block whose content is a shift gate.
const CphaseGate{N, T} = ControlBlock{N,<:ShiftGate{T},<:Any}
# The export list spells this name `CPhaseGate`; alias it so the exported
# binding actually exists and refers to the same type as `CphaseGate`.
const CPhaseGate = CphaseGate
# Every block type that the differentiation tags below accept directly.
const DiffBlock{N, T} = Union{Rotor{N, T}, CphaseGate{N, T}}
"""
    generator(rot::Rotor) -> MatrixBlock
    generator(c::CphaseGate) -> ControlBlock

Return the generator of a rotation block. For a bare `RotationGate` this is the
block stored in the gate; for a `PutBlock` holding a rotation gate the generator
of the content is placed on the same locations; for a `CphaseGate` a new
`ControlBlock` is built from the control locations, control configuration and
target locations of `c`.
"""
generator(rot::RotationGate) = rot.block

function generator(rot::PutBlock{N, C, GT}) where {N, C, GT<:RotationGate}
    return PutBlock{N}(generator(rot |> content), rot |> occupied_locs)
end

function generator(c::CphaseGate{N}) where N
    # Read both control fields from `c`: the previous code used the misspelled
    # field `ctrol_locs` and the unbound name `ctrl_config`, so this method
    # could never run.
    # NOTE(review): the constructor form `ControlBlock(N, ...)` and the inner
    # block `control(2, 1, 2=>Z)` are kept as written -- confirm both against
    # the YaoBlocks version in use.
    return ControlBlock(N, c.ctrl_locs, c.ctrl_config, control(2, 1, 2=>Z), c.locs)
end
| 17 | + |
# Common supertype of the differentiation tags defined in this file.
abstract type AbstractDiff{GT, N, T} <: TagBlock{GT, N, T} end

# By default the adjoint of a tag is represented by wrapping it in `Daggered`.
function Base.adjoint(df::AbstractDiff)
    return Daggered(df)
end

# NOTE(review): `istraitkeeper` is not in the visible import list; confirm it
# extends the intended YaoBlocks function rather than defining a new one.
function istraitkeeper(::AbstractDiff)
    return Val(true)
end
| 22 | + |
| 23 | +#################### The Basic Diff ################# |
"""
    QDiff{GT, N, T} <: AbstractDiff{GT, N, Complex{T}}
    QDiff(block) -> QDiff

Tag `block` as quantum differentiable. The tag stores the wrapped block and a
scalar gradient slot `grad`, which is initialised to zero.
"""
mutable struct QDiff{GT, N, T} <: AbstractDiff{GT, N, Complex{T}}
    block::GT
    grad::T
    function QDiff(block::DiffBlock{N, T}) where {N, T}
        return new{typeof(block), N, T}(block, T(0))
    end
end

# The wrapped block.
function content(cb::QDiff)
    return cb.block
end

# Replacing the content builds a fresh tag (with a zeroed gradient).
function chcontent(cb::QDiff, blk::DiffBlock)
    return QDiff(blk)
end

# `mat` and `apply!` are delegated to the wrapped block.
@forward QDiff.block mat, apply!

# The adjoint of a `QDiff` re-wraps the adjoint of its content in a new `QDiff`.
function Base.adjoint(df::QDiff)
    inner = content(df)
    return QDiff(inner')
end
| 40 | + |
# Prefix printed in front of a `QDiff` block when a block tree is displayed.
YaoBlocks.print_annotation(io::IO, df::QDiff) =
    printstyled(io, "[̂∂] "; color=:yellow, bold=true)
| 44 | + |
| 45 | +#################### The Back Propagation Diff ################# |
"""
    BPDiff{GT, N, T, PT} <: AbstractDiff{GT, N, T}
    BPDiff(block, [grad]) -> BPDiff

Tag `block` as differentiable by back propagation. `GT` is the gate type and
`PT` is the type of the stored gradient `grad`. When `grad` is omitted it is a
zero scalar for a `DiffBlock`, and otherwise a zero vector with one entry per
parameter of `block`.

The `input` field is left undefined at construction; it is filled in the first
time the tagged block is applied to a register.

Warning:
    please don't use the `adjoint` after `BPDiff`! `adjoint` is reserved for special purpose! (back propagation)
"""
mutable struct BPDiff{GT, N, T, PT} <: AbstractDiff{GT, N, T}
    block::GT
    grad::PT
    input::AbstractRegister
    function BPDiff(block::MatrixBlock{N, T}, grad::PT) where {N, T, PT}
        # `input` is deliberately not initialised here.
        return new{typeof(block), N, T, typeof(grad)}(block, grad)
    end
end

function BPDiff(block::MatrixBlock)
    grad = zeros(parameters_eltype(block), nparameters(block))
    return BPDiff(block, grad)
end

function BPDiff(block::DiffBlock{N, T}) where {N, T}
    return BPDiff(block, T(0))
end

# The wrapped block.
function content(cb::BPDiff)
    return cb.block
end

# Replacing the content builds a fresh tag with a default gradient.
function chcontent(cb::BPDiff, blk::MatrixBlock)
    return BPDiff(blk)
end

# Only `mat` is delegated; `apply!` has dedicated methods for `BPDiff`.
@forward BPDiff.block mat
# Forward pass through a `BPDiff`: remember the incoming register in `df.input`
# (allocating a copy on first use, overwriting the existing copy afterwards),
# then apply the wrapped block to `reg` and return `reg`.
function apply!(reg::AbstractRegister, df::BPDiff)
    if !isdefined(df, :input)
        df.input = copy(reg)
    else
        copyto!(df.input, reg)
    end
    apply!(reg, content(df))
    return reg
end
| 77 | + |
# Backward pass through a daggered `BPDiff` that wraps a rotor: apply the
# adjoint of the wrapped gate to `δ`, then store in `df.grad` the imaginary part
# of minus the inner product between the generator acting on the stored input
# and `δ`. Returns `δ`.
function apply!(δ::AbstractRegister, adf::Daggered{<:BPDiff{<:Rotor}})
    df = content(adf)
    gate = content(df)
    apply!(δ, gate')
    # NOTE(review): piping `df.input` into the generator presumably applies it
    # to the stored register in place -- confirm this is intended.
    generated = statevec(df.input |> generator(gate))
    df.grad = imag(-generated' * statevec(δ))
    return δ
end
| 84 | + |
# Prefix printed in front of a `BPDiff` block when a block tree is displayed.
YaoBlocks.print_annotation(io::IO, df::BPDiff) =
    printstyled(io, "[∂] "; color=:yellow, bold=true)
| 88 | + |
| 89 | + |
| 90 | +#### interface ##### |
| 91 | +export autodiff, numdiff, opdiff, StatFunctional, statdiff, as_weights |
| 92 | + |
# Wrap a probability vector as `Weights`, passing `T(1)` as the second
# constructor argument.
function as_weights(probs::AbstractVector{T}) where T
    total = T(1)
    return Weights(probs, total)
end
"""
    autodiff(mode::Symbol, block::AbstractBlock) -> AbstractBlock
    autodiff(mode::Symbol) -> Function

Walk a block tree and tag its differentiable blocks. `mode` is turned into a
`Val` and dispatched on; methods for `:BP` and `:QC` are defined in this file.
The one-argument form returns a function that applies `autodiff(mode, block)`.
"""
function autodiff end

function autodiff(mode::Symbol)
    return block -> autodiff(mode, block)
end

function autodiff(mode::Symbol, block::AbstractBlock)
    return autodiff(Val(mode), block)
end
| 103 | + |
# Back-propagation mode: a `DiffBlock` gets a `BPDiff` tag, any other block is
# returned unchanged ...
function autodiff(::Val{:BP}, block::DiffBlock)
    return BPDiff(block)
end

function autodiff(::Val{:BP}, block::AbstractBlock)
    return block
end

# ... except Sequential, Roller and ChainBlock, which propagate to their sub-blocks.
function autodiff(mode::Val{:BP}, blk::Union{ChainBlock, Roller, Sequential})
    tagged = autodiff.(mode, subblocks(blk))
    return chsubblocks(blk, tagged)
end
| 111 | + |
# Quantum-circuit mode: rotation gates and controlled phase gates get a `QDiff` tag.
autodiff(::Val{:QC}, block::RotationGate) = QDiff(block)
autodiff(::Val{:QC}, block::CphaseGate) = QDiff(block)

# Any other control block is left as is and not descended into.
autodiff(::Val{:QC}, block::ControlBlock) = block

# Every other block is rebuilt from its tagged sub-blocks; blocks without
# sub-blocks are returned unchanged.
function autodiff(mode::Val{:QC}, blk::AbstractBlock)
    children = subblocks(blk)
    if isempty(children)
        return blk
    end
    return chsubblocks(blk, autodiff.(mode, children))
end
| 121 | + |
# Evaluate `func()` twice: once after subtracting `δ` from the parameter of
# `gate` and once after adding `2δ` to that (i.e. at +δ relative to the start).
# A final `-δ` dispatch moves the parameter back before returning `(r1, r2)`.
#
# One method serves every tagged `DiffBlock`. `Rotor` is one member of the
# `DiffBlock` union, so the former separate `AbstractDiff{<:Rotor}` method,
# whose body was identical, was redundant and has been merged into this one.
@inline function _perturb(func, gate::AbstractDiff{<:DiffBlock}, δ::Real)
    dispatch!(-, gate, (δ,))
    r1 = func()
    dispatch!(+, gate, (2δ,))
    r2 = func()
    dispatch!(-, gate, (δ,))
    return r1, r2
end
| 139 | + |
"""
    numdiff(loss, diffblock::AbstractDiff; δ::Real=1e-2)

Numeric differentiation: evaluate `loss()` with the parameter of `diffblock`
shifted by `-δ` and `+δ`, and store the difference divided by `2δ` in
`diffblock.grad`.
"""
@inline function numdiff(loss, diffblock::AbstractDiff; δ::Real=1e-2)
    lower, upper = _perturb(loss, diffblock, δ)
    diffblock.grad = (upper - lower) / (2δ)
end
| 149 | + |
"""
    opdiff(psifunc, diffblock::AbstractDiff, op::MatrixBlock)

Operator differentiation: evaluate the real part of `expect(op, psifunc())`
with the parameter of `diffblock` shifted by `-π/2` and `+π/2`, and store half
of the difference in `diffblock.grad`.
"""
@inline function opdiff(psifunc, diffblock::AbstractDiff, op::MatrixBlock)
    observable() = real(expect(op, psifunc()))
    lower, upper = _perturb(observable, diffblock, π/2)
    diffblock.grad = (upper - lower) / 2
end
| 159 | + |
"""
    StatFunctional{N, AT}
    StatFunctional(array::AT<:Array) -> StatFunctional{N, <:Array}
    StatFunctional{N}(func::AT<:Function) -> StatFunctional{N, <:Function}

statistic functional, i.e.
    * if `AT` is an array, A[i,j,k...], it is defined on finite Hilbert space, which is `∫A[i,j,k...]p[i]p[j]p[k]...`
    * if `AT` is a function, F(xᵢ,xⱼ,xₖ...), this functional is `1/C(r,n)... ∑ᵢⱼₖ...F(xᵢ,xⱼ,xₖ...)`, see U-statistics for detail.

For an array, `N` is taken from the array's number of dimensions; for a
function it must be given explicitly as the type parameter.

References:
    U-statistics, http://personal.psu.edu/drh20/asymp/fall2006/lectures/ANGELchpt10.pdf
"""
struct StatFunctional{N, AT}
    data::AT
    function StatFunctional{N}(data::AT) where {N, AT<:Function}
        return new{N, AT}(data)
    end
    function StatFunctional(data::AT) where {N, AT<:AbstractArray{<:Real, N}}
        return new{N, AT}(data)
    end
end

# `ndims` is forwarded to the wrapped data.
@forward StatFunctional.data Base.ndims

# The wrapped array or function.
function Base.parent(stat::StatFunctional)
    return stat.data
end
| 180 | + |
# Two-index array functional: bilinear form `px' * A * py`; `py` defaults to `px`.
function expect(stat::StatFunctional{2, <:AbstractArray}, px::Weights, py::Weights=px)
    return px.values' * stat.data * py.values
end

# One-index array functional: inner product of the stored array with `px`.
function expect(stat::StatFunctional{1, <:AbstractArray}, px::Weights)
    return stat.data' * px.values
end
# Two-argument function functional on one sample set: sum the kernel over all
# pairs (i, j) with j < i and divide by the number of such pairs.
function expect(stat::StatFunctional{2, <:Function}, xs::AbstractVector{T}) where T
    count = length(xs)
    kernel = stat.data
    # Seed the accumulator with a zero of the kernel's result type.
    total = zero(kernel(xs[1], xs[1]))
    for i = 2:count, j = 1:i-1
        @inbounds total += kernel(xs[i], xs[j])
    end
    return total / binomial(count, 2)
end
# Two-argument function functional on two sample sets: average the kernel over
# every (x, y) pair.
function expect(stat::StatFunctional{2, <:Function}, xs::AbstractVector, ys::AbstractVector)
    nx = length(xs)
    ny = length(ys)
    pairs_index = CartesianIndices((nx, ny))
    @inbounds total = mapreduce(ind -> stat.data(xs[ind[1]], ys[ind[2]]), +, pairs_index)
    return total / nx / ny
end

# One-argument function functional: mean of the function over the samples.
function expect(stat::StatFunctional{1, <:Function}, xs::AbstractVector)
    return mean(stat.data.(xs))
end

# The order of the functional is its first type parameter.
function Base.ndims(stat::StatFunctional{N}) where N
    return N
end
| 201 | + |
"""
    statdiff(probfunc, diffblock::AbstractDiff, stat::StatFunctional{2}; initial::AbstractVector=probfunc())
    statdiff(probfunc, diffblock::AbstractDiff, stat::StatFunctional{1})

Differentiation for statistic functionals. `expect` is evaluated with the
parameter of `diffblock` shifted by `-π/2` and `+π/2`; the difference times
`ndims(stat)/2` is stored in `diffblock.grad`. For a second-order functional
the second argument of `expect` is held fixed at `initial`, which defaults to
`probfunc()` evaluated before the shifts. The first-order method takes no
`initial` keyword.
"""
@inline function statdiff(probfunc, diffblock::AbstractDiff, stat::StatFunctional{2}; initial::AbstractVector=probfunc())
    shifted() = expect(stat, probfunc(), initial)
    lower, upper = _perturb(shifted, diffblock, π/2)
    diffblock.grad = (upper - lower)*ndims(stat)/2
end

@inline function statdiff(probfunc, diffblock::AbstractDiff, stat::StatFunctional{1})
    shifted() = expect(stat, probfunc())
    lower, upper = _perturb(shifted, diffblock, π/2)
    diffblock.grad = (upper - lower)*ndims(stat)/2
end
| 216 | + |
"""
    backward!(δ::AbstractRegister, circuit::MatrixBlock) -> AbstractRegister

back propagate and calculate the gradient ∂f/∂θ = 2*Re(∂f/∂ψ*⋅∂ψ*/∂θ), given ∂f/∂ψ*.
This is done by applying the adjoint of `circuit` to `δ`.

Note:
Here, the input circuit should be a matrix block, otherwise the back propagate may not apply (like Measure operations).
"""
function backward!(δ::AbstractRegister, circuit::MatrixBlock)
    return apply!(δ, circuit')
end
| 226 | + |
"""
    gradient(circuit::AbstractBlock, mode::Symbol=:ANY) -> Vector

collect all gradients in a circuit, mode can be :BP/:QC/:ANY, they will collect `grad` from BPDiff/QDiff/AbstractDiff respectively.
The result starts as an empty vector of `parameters_eltype(circuit)`.
"""
function gradient(circuit::AbstractBlock, mode::Symbol=:ANY)
    collected = parameters_eltype(circuit)[]
    return gradient!(circuit, collected, mode)
end

# Lift the symbolic mode to a `Val` so the methods below can dispatch on it.
function gradient!(circuit::AbstractBlock, grad, mode::Symbol)
    return gradient!(circuit, grad, Val(mode))
end

# Generic case: recurse into the sub-blocks and return the accumulated `grad`.
function gradient!(circuit::AbstractBlock, grad, mode::Val)
    foreach(child -> gradient!(child, grad, mode), subblocks(circuit))
    return grad
end

# Leaves: a tag matching the requested mode contributes its stored gradient.
function gradient!(circuit::BPDiff, grad, mode::Val{:BP})
    return append!(grad, circuit.grad)
end

function gradient!(circuit::QDiff, grad, mode::Val{:QC})
    return push!(grad, circuit.grad)
end

function gradient!(circuit::AbstractDiff, grad, mode::Val{:ANY})
    return append!(grad, circuit.grad)
end
0 commit comments