diff --git a/ext/StaticArraysExt/StaticArraysExt.jl b/ext/StaticArraysExt/StaticArraysExt.jl index 07a082b..aa4c3b1 100644 --- a/ext/StaticArraysExt/StaticArraysExt.jl +++ b/ext/StaticArraysExt/StaticArraysExt.jl @@ -3,8 +3,11 @@ module StaticArraysExt using StaticArrays: SVector, SMatrix, SArray using Gabs -import Gabs: ptrace, tensor, ⊗, _promote_output_matrix, _promote_output_vector +using LinearAlgebra +import Gabs: ptrace, tensor, ⊗, _promote_output_matrix, _promote_output_vector, +SymplecticBasis, vacuumstate, thermalstate, coherentstate, squeezedstate, eprstate +include("cleaner_dispatch.jl") include("utils.jl") -end \ No newline at end of file +end diff --git a/ext/StaticArraysExt/cleaner_dispatch.jl b/ext/StaticArraysExt/cleaner_dispatch.jl new file mode 100644 index 0000000..c36fade --- /dev/null +++ b/ext/StaticArraysExt/cleaner_dispatch.jl @@ -0,0 +1,175 @@ +# Gaussian states with cleaner dispatch + +# Vacuum state +function vacuumstate(::Type{SArray}, basis::SymplecticBasis{N}; ħ=2) where {N<:Int} + n = basis.nmodes + T = typeof(ħ/2) + vacuumstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis; ħ=ħ) +end +function vacuumstate(::Type{SVector}, ::Type{SMatrix}, basis::SymplecticBasis{N}; ħ=2) where {N<:Int} + n = basis.nmodes + T = typeof(ħ/2) + vacuumstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis; ħ=ħ) +end +function vacuumstate(::Type{SVector{M,T1}}, ::Type{SMatrix{M,M,T2}}, basis::SymplecticBasis{N}; ħ=2) where {M,N<:Int,T1,T2} + n = basis.nmodes + M == 2n || error("Size mismatch: SVector{$M}/SMatrix{$M,$M} != 2n (n=$n)") + T = promote_type(T1, T2, typeof(ħ/2)) + GaussianState(basis, SVector{2n,T}(zeros(T, 2n)), SMatrix{2n,2n,T}((ħ/2) * I); ħ=ħ) +end + +# Thermal state +function thermalstate(::Type{SArray}, basis::SymplecticBasis{N}, photons; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), eltype(photons)) + thermalstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, photons; ħ=ħ) +end +function thermalstate(::Type{SVector}, ::Type{SMatrix}, basis::SymplecticBasis{N}, photons; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), eltype(photons)) + thermalstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, photons; ħ=ħ) +end +function thermalstate(::Type{SVector{M,T1}}, ::Type{SMatrix{M,M,T2}}, basis::SymplecticBasis{N}, photons; ħ=2) where {M,N<:Int,T1,T2} + n = basis.nmodes + M == 2n || error("Size mismatch: SVector{$M}/SMatrix{$M,$M} != 2n (n=$n)") + T = promote_type(T1, T2, typeof(ħ/2), eltype(photons)) + mean = zeros(SVector{2n,T}) + covar = (2 * photons + 1) * (ħ/2) * one(SMatrix{2n,2n,T}) + GaussianState(basis, mean, covar; ħ=ħ) +end + +# Coherent state + +function coherentstate(::Type{SArray}, basis::SymplecticBasis{N}, alpha; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), real(eltype(alpha))) + coherentstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, alpha; ħ=ħ) +end +function coherentstate(::Type{SVector}, ::Type{SMatrix}, basis::SymplecticBasis{N}, alpha; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), real(eltype(alpha))) + coherentstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, alpha; ħ=ħ) +end +function _complex_to_real_vec(alpha::Number, ħ, T, n) + real_part = sqrt(2ħ) * real(alpha) + imag_part = sqrt(2ħ) * imag(alpha) + SVector{2n,T}([repeat([real_part, imag_part], n)...]) +end +function _complex_to_real_vec(alpha::AbstractVector, ħ, T, n) + length(alpha) == n || error("Number of complex amplitudes ($(length(alpha))) must match number of modes ($n)") + SVector{2n,T}(sqrt(2ħ) * vcat(real.(alpha), imag.(alpha))) +end +function coherentstate(::Type{SVector{M,T1}}, ::Type{SMatrix{M,M,T2}}, basis::SymplecticBasis{N}, alpha; ħ=2) where {M,N<:Int,T1,T2} + n = basis.nmodes + M == 2n || error("Size mismatch: SVector{$M}/SMatrix{$M,$M} != 2n (n=$n)") + T = promote_type(T1, T2, typeof(ħ/2), real(eltype(alpha))) + mean = _complex_to_real_vec(alpha, ħ, T, n) + covar = (ħ/2) * one(SMatrix{2n,2n,T}) + GaussianState(basis, mean, covar; ħ=ħ) +end + +# Squeezed state +function squeezedstate(::Type{SArray}, basis::SymplecticBasis{N}, r, theta; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), eltype(r), eltype(theta)) + squeezedstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, r, theta; ħ=ħ) +end +function squeezedstate(::Type{SVector}, ::Type{SMatrix}, basis::SymplecticBasis{N}, r, theta; ħ=2) where {N<:Int} + n = basis.nmodes + T = promote_type(typeof(ħ/2), eltype(r), eltype(theta)) + squeezedstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, r, theta; ħ=ħ) +end +function squeezedstate(::Type{SVector{M,T1}}, ::Type{SMatrix{M,M,T2}}, basis::SymplecticBasis{N}, r, theta; ħ=2) where {M,N<:Int,T1,T2} + n = basis.nmodes + M == 2n || error("Size mismatch: SVector{$M}/SMatrix{$M,$M} != 2n (n=$n)") + T = promote_type(T1, T2, typeof(ħ/2), eltype(r), eltype(theta)) + mean = zeros(SVector{2n,T}) + covar = _squeezed_covar(SMatrix{2n,2n,T}, r, theta, ħ) + GaussianState(basis, mean, covar; ħ=ħ) +end + +function eprstate(::Type{SVector{M,T1}}, ::Type{SMatrix{M,M,T2}}, + basis::SymplecticBasis{N}, r::R, theta::R; ħ=2) where { + M, N<:Int, T1, T2, R<:Number} + n = basis.nmodes + M == 2n || error("Size mismatch: SVector{$M}/SMatrix{$M,$M} != 2n (n=$n)") + T = promote_type(T1, T2, typeof(ħ/2), R) + mean = zeros(SVector{2n,T}) + covar = _epr_covar(SMatrix{2n,2n,T}, r, theta, ħ) + GaussianState(basis, mean, covar; ħ=ħ) +end + +# Unsized StaticArrays convenience methods +function eprstate(::Type{SVector}, ::Type{SMatrix}, basis::SymplecticBasis{N}, r::R, theta::R; ħ=2) where { + N<:Int, R<:Number} + n = basis.nmodes + T = promote_type(typeof(ħ/2), R) + eprstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, r, theta; ħ=ħ) +end + +function eprstate(::Type{SArray}, basis::SymplecticBasis{N}, r::R, theta::R; ħ=2) where { + N<:Int, R<:Number} + n = basis.nmodes + T = promote_type(typeof(ħ/2), R) + eprstate(SVector{2n,T}, SMatrix{2n,2n,T}, basis, r, theta; ħ=ħ) +end + +function _epr_covar(::Type{SMatrix{M,M,T}}, r::R, theta::R, ħ) where {M, T, R<:Number} + n = M ÷ 2 + cr, sr = (ħ/2)*cosh(2*r), (ħ/2)*sinh(2*r) + ct, st = cos(theta), sin(theta) + elements = zeros(T, M, M) + for i in 1:2:n + j = n + i + elements[i,i] = cr + elements[i+1,i+1] = cr + elements[j,j] = cr + elements[j+1,j+1] = cr + elements[i,j] = -sr*ct + elements[i,j+1] = -sr*st + elements[i+1,j] = -sr*st + elements[i+1,j+1] = sr*ct + elements[j,i] = -sr*ct + elements[j,i+1] = -sr*st + elements[j+1,i] = -sr*st + elements[j+1,i+1] = sr*ct + end + return SMatrix{M,M,T}(elements) +end + +function tensor(::Type{SVector}, ::Type{SMatrix}, state1::GaussianState, state2::GaussianState) + M1, V1 = typeof(state1.mean), typeof(state1.covar) + M2, V2 = typeof(state2.mean), typeof(state2.covar) + out_size = 2*(state1.basis.nmodes + state2.basis.nmodes) + tensor( + SVector{out_size, promote_type(eltype(M1), eltype(M2))}, + SMatrix{out_size, out_size, promote_type(eltype(V1), eltype(V2))}, + state1, + state2 + ) +end +function tensor(::Type{SArray}, state1::GaussianState, state2::GaussianState) + tensor(SVector, SMatrix, state1, state2) +end +function _output_size(basis::SymplecticBasis{N}, indices) where {N} + nmodes = basis.nmodes + notindices = setdiff(1:nmodes, indices) + notidxlength = length(notindices) + 2 * notidxlength +end + +function _output_types(state::GaussianState{B,M,V}) where {B,M,V} + (M, V) +end + +# ptrace +function ptrace(::Type{SVector}, ::Type{SMatrix}, state::GaussianState, indices) + M, V = _output_types(state) + out_size = _output_size(state.basis, indices) + ptrace(SVector{out_size, eltype(M)}, SMatrix{out_size, out_size, eltype(V)}, state, indices) +end +function ptrace(::Type{SArray}, state::GaussianState, indices) + M, V = _output_types(state) + out_size = _output_size(state.basis, indices) + ptrace(SVector{out_size, eltype(M)}, SMatrix{out_size, out_size, eltype(V)}, state, indices) +end diff --git a/test/test_states.jl b/test/test_states.jl index afc58b1..5060c00 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -21,6 +21,8 @@ state_block = thermalstate(qblockbasis, n) @test state_pair isa GaussianState && state_block isa GaussianState @test thermalstate(SVector{2*nmodes}, SMatrix{2*nmodes,2*nmodes}, qpairbasis, n) isa GaussianState + @test thermalstate(SVector, SMatrix, qpairbasis, n) isa GaussianState + @test thermalstate(SArray, qpairbasis, n) isa GaussianState @test thermalstate(qblockbasis, n) == changebasis(QuadBlockBasis, state_pair) @test state_pair == changebasis(QuadPairBasis, state_block) && state_block == changebasis(QuadBlockBasis, state_pair) @test state_pair == changebasis(QuadPairBasis, state_pair) && state_block == changebasis(QuadBlockBasis, state_block) @@ -36,6 +38,8 @@ state_block = coherentstate(qblockbasis, alpha) @test state_pair isa GaussianState && state_block isa GaussianState @test coherentstate(SVector{2*nmodes}, SMatrix{2*nmodes,2*nmodes}, qpairbasis, alpha) isa GaussianState + @test coherentstate(SVector, SMatrix, qpairbasis, alpha) isa GaussianState + @test coherentstate(SArray, qpairbasis, alpha) isa GaussianState @test coherentstate(qblockbasis, alpha) == changebasis(QuadBlockBasis, state_pair) @test state_pair == changebasis(QuadPairBasis, state_block) && state_block == changebasis(QuadBlockBasis, state_pair) @test state_pair == changebasis(QuadPairBasis, state_pair) && state_block == changebasis(QuadBlockBasis, state_block) @@ -57,9 +61,11 @@ @testset "epr states" begin r, theta = rand(Float64), rand(Float64) rs, thetas = rand(Float64, nmodes), rand(Float64, nmodes) - state, array_state, static_state = eprstate(2*qpairbasis, r, theta), eprstate(Array, 2*qpairbasis, r, theta), eprstate(SVector{4*nmodes}, SMatrix{4*nmodes,4*nmodes}, 2*qpairbasis, r, theta) - @test state isa GaussianState && array_state isa GaussianState && static_state isa GaussianState + state, array_state, static_state, static_array_state, static_state1 = eprstate(2*qpairbasis, r, theta), eprstate(Array, 2*qpairbasis, r, theta), eprstate(SVector{4*nmodes}, SMatrix{4*nmodes,4*nmodes}, 2*qpairbasis, r, theta), eprstate(SArray, 2*qpairbasis, r, theta), eprstate(SVector, SMatrix, 2*qpairbasis, r, theta) + @test state isa GaussianState && array_state isa GaussianState && static_state isa GaussianState && static_state1 isa GaussianState && static_array_state isa GaussianState @test eprstate(SVector{4*nmodes}, SMatrix{4*nmodes,4*nmodes}, 2*qpairbasis, r, theta) isa GaussianState + @test eprstate(SVector, SMatrix, 2*qpairbasis, r, theta) isa GaussianState + @test eprstate(SArray, 2*qpairbasis, r, theta) isa GaussianState @test eprstate(2*qblockbasis, r, theta) == changebasis(QuadBlockBasis, state) @test eprstate(2*qblockbasis, rs, thetas) == changebasis(QuadBlockBasis, eprstate(2*qpairbasis, rs, thetas)) @test state.ħ == 2 && array_state.ħ == 2 && static_state.ħ == 2 @@ -70,6 +76,8 @@ vs = tensor(v, v) @test vs isa GaussianState @test tensor(SVector{4*nmodes}, SMatrix{4*nmodes,4*nmodes}, v, v) isa GaussianState + @test tensor(SVector, SMatrix, v, v) isa GaussianState + @test tensor(SArray, v, v) isa GaussianState @test vs == v ⊗ v @test isapprox(vs, v ⊗ v, atol = 1e-10) @@ -83,6 +91,8 @@ @test sq ⊗ sq == sqs vstatic = vacuumstate(SVector{2*nmodes}, SMatrix{2*nmodes,2*nmodes}, qpairbasis) + vstatic = vacuumstate(SVector, SMatrix, qpairbasis) + vstatic = vacuumstate(SArray, qpairbasis) tpstatic = vstatic ⊗ vstatic ⊗ vstatic @test tpstatic.mean isa SVector{6*nmodes} @test tpstatic.covar isa SMatrix{6*nmodes,6*nmodes} @@ -113,8 +123,30 @@ @test ptrace(tpstatic, 1) == sstatic ⊗ sstatic @test ptrace(tpstatic, [1,3]) == sstatic - @test ptrace(SVector{2}, SMatrix{2,2}, state, [1, 3]) isa GaussianState - @test ptrace(SVector{4}, SMatrix{4,4}, state, 1) isa GaussianState + sstatic = coherentstate(SVector, SMatrix, basis, alpha) + tpstatic = sstatic ⊗ sstatic ⊗ sstatic + @test ptrace(tpstatic, 1) == sstatic ⊗ sstatic + @test ptrace(tpstatic, [1,3]) == sstatic + + sstatic = coherentstate(SArray, basis, alpha) + tpstatic = sstatic ⊗ sstatic ⊗ sstatic + @test ptrace(tpstatic, 1) == sstatic ⊗ sstatic + @test ptrace(tpstatic, [1,3]) == sstatic + + for (T1, T2, subsys) in [ + (SVector{2}, SMatrix{2,2}, [1, 3]), + (SVector{4}, SMatrix{4,4}, 1), + (SVector, SMatrix, [1, 3]), + (SVector, SMatrix, 1), + (SArray, nothing, [1, 3]), + (SArray, nothing, 1) + ] + if T2 === nothing + @test ptrace(T1, state, subsys) isa GaussianState + else + @test ptrace(T1, T2, state, subsys) isa GaussianState + end + end eprstates = eprstate(basis ⊕ basis ⊕ basis ⊕ basis, r, theta)