Skip to content

Commit 01a205c

Browse files
Updated ssesolve, smesolve and mcsolve to be minmial working changes.
1 parent db3422e commit 01a205c

File tree

5 files changed

+54
-71
lines changed

5 files changed

+54
-71
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ function _mcsolve_output_func(sol, i)
2020
return (sol, false)
2121
end
2222

23-
function _normalize_state!(u, dims, normalize_states, type)
23+
function _normalize_state!(u, dims, normalize_states)
2424
getVal(normalize_states) && normalize!(u)
25-
return QuantumObject(u, type, dims)
25+
return QuantumObject(u, Ket(), dims)
2626
end
2727

2828
function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
@@ -110,15 +110,15 @@ If the environmental measurements register a quantum jump, the wave function und
110110
"""
111111
function mcsolveProblem(
112112
H::Union{AbstractQuantumObject{Operator},Tuple},
113-
ψ0::QuantumObject{ST},
113+
ψ0::QuantumObject{Ket},
114114
tlist::AbstractVector,
115115
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
116116
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
117117
params = NullParameters(),
118118
rng::AbstractRNG = default_rng(),
119119
jump_callback::TJC = ContinuousLindbladJumpCallback(),
120120
kwargs...,
121-
) where {ST<:Union{Ket,Operator}, TJC<:LindbladJumpCallbackType}
121+
) where {TJC<:LindbladJumpCallbackType}
122122
haskey(kwargs, :save_idxs) &&
123123
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
124124

@@ -221,7 +221,7 @@ If the environmental measurements register a quantum jump, the wave function und
221221
"""
222222
function mcsolveEnsembleProblem(
223223
H::Union{AbstractQuantumObject{Operator},Tuple},
224-
ψ0::QuantumObject{ST},
224+
ψ0::QuantumObject{Ket},
225225
tlist::AbstractVector,
226226
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
227227
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
@@ -234,7 +234,7 @@ function mcsolveEnsembleProblem(
234234
prob_func::Union{Function,Nothing} = nothing,
235235
output_func::Union{Tuple,Nothing} = nothing,
236236
kwargs...,
237-
) where {ST<:Union{Ket,Operator}, TJC<:LindbladJumpCallbackType}
237+
) where {TJC<:LindbladJumpCallbackType}
238238
_prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func(rng, ntraj, tlist, _mcsolve_prob_func) : prob_func
239239
_output_func =
240240
output_func isa Nothing ?
@@ -359,7 +359,7 @@ If the environmental measurements register a quantum jump, the wave function und
359359
"""
360360
function mcsolve(
361361
H::Union{AbstractQuantumObject{Operator},Tuple},
362-
ψ0::QuantumObject{ST},
362+
ψ0::QuantumObject{Ket},
363363
tlist::AbstractVector,
364364
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
365365
alg::AbstractODEAlgorithm = DP5(),
@@ -375,7 +375,7 @@ function mcsolve(
375375
keep_runs_results::Union{Val,Bool} = Val(false),
376376
normalize_states::Union{Val,Bool} = Val(true),
377377
kwargs...,
378-
) where {ST<:Union{Ket,Operator}, TJC<:LindbladJumpCallbackType}
378+
) where {TJC<:LindbladJumpCallbackType}
379379
ens_prob_mc = mcsolveEnsembleProblem(
380380
H,
381381
ψ0,
@@ -415,11 +415,8 @@ function mcsolve(
415415
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMCSolve), eachindex(sol))
416416
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
417417

418-
419-
states_all = stack(
420-
map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, [ens_prob_mc.states_type]), eachindex(sol)), # Unsure why ens_prob_mc.states_type needs to be in an array but the other two arguments don't!
421-
dims = 1,
422-
)
418+
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
419+
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
423420

424421
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
425422
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
@@ -439,4 +436,4 @@ function mcsolve(
439436
kwargs.abstol,
440437
kwargs.reltol,
441438
)
442-
end
439+
end

src/time_evolution/mesolve.jl

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ _mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both
77
c_ops are Nothing. You are probably running the wrong function."))
88

99
function _gen_mesolve_solution(sol, prob::TimeEvolutionProblem{ST}) where {ST<:Union{Operator,OperatorKet,SuperOperator}}
10-
if prob.states_type == Operator
10+
if prob.states_type isa Operator
1111
ρt = map-> QuantumObject(vec2mat(ϕ), type = prob.states_type, dims = prob.dimensions), sol.u)
1212
else
1313
ρt = map-> QuantumObject(ϕ, type = prob.states_type, dims = prob.dimensions), sol.u)
@@ -66,6 +66,7 @@ where
6666
6767
# Notes
6868
69+
- The initial state can also be [`SuperOperator`](@ref) (such as a super-identity). This is useful for simulating many density matrices simultaneously or calculating process matrices. Currently must be Square.
6970
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
7071
- If `e_ops` is empty, the default value of `saveat=tlist` (saving the states corresponding to `tlist`), otherwise, `saveat=[tlist[end]]` (only save the final state). You can also specify `e_ops` and `saveat` separately.
7172
- If `H` is an [`Operator`](@ref), `ψ0` is a [`Ket`](@ref) and `c_ops` is `Nothing`, the function will call [`sesolveProblem`](@ref) instead.
@@ -106,24 +107,15 @@ function mesolveProblem(
106107
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
107108
check_dimensions(L_evo, ψ0)
108109

109-
T = Base.promote_eltype(L_evo, ψ0)
110-
# ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
111-
# to_dense(_complex_float_type(T), copy(ψ0.data))
112-
# else
113-
# to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
114-
# end
115-
if isoper(ψ0)
116-
ρ0 = to_dense(_complex_float_type(T), mat2vec(ψ0.data))
117-
states_type = Operator()
118-
elseif isoperket(ψ0)
119-
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
120-
states_type = OperatorKet()
121-
elseif isket(ψ0)
122-
ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
110+
# Convert to dense vector with complex element type
111+
112+
T = _complex_float_type(Base.promote_eltype(L_evo, ψ0))
113+
if isoperket(ψ0) || issuper(ψ0)
114+
ρ0 = to_dense(T, copy(ψ0.data))
115+
states_type = ψ0.type
116+
else
117+
ρ0 = to_dense(T, mat2vec(ket2dm(ψ0).data))
123118
states_type = Operator()
124-
elseif issuper(ψ0)
125-
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
126-
states_type = SuperOperator()
127119
end
128120

129121
L = cache_operator(L_evo.data, ρ0)
@@ -180,6 +172,7 @@ where
180172
181173
# Notes
182174
175+
- The initial state can also be [`SuperOperator`](@ref) (such as a super-identity). This is useful for simulating many density matrices simultaneously or calculating process matrices. Currently must be Square.
183176
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
184177
- If `e_ops` is empty, the default value of `saveat=tlist` (saving the states corresponding to `tlist`), otherwise, `saveat=[tlist[end]]` (only save the final state). You can also specify `e_ops` and `saveat` separately.
185178
- If `H` is an [`Operator`](@ref), `ψ0` is a [`Ket`](@ref) and `c_ops` is `Nothing`, the function will call [`sesolve`](@ref) instead.
@@ -292,6 +285,7 @@ for each combination in the ensemble.
292285
293286
# Notes
294287
288+
- The initial state can also be [`SuperOperator`](@ref) (such as a super-identity). This is useful for simulating many density matrices simultaneously or calculating process matrices. Currently must be Square.
295289
- The function returns an array of solutions with dimensions matching the Cartesian product of initial states and parameter sets.
296290
- If `ψ0` is a vector of `m` states and `params = (p1, p2, ...)` where `p1` has length `n1`, `p2` has length `n2`, etc., the output will be of size `(m, n1, n2, ...)`.
297291
- If `H` is an [`Operator`](@ref), `ψ0` is a [`Ket`](@ref) and `c_ops` is `Nothing`, the function will call [`sesolve_map`](@ref) instead.
@@ -329,14 +323,10 @@ function mesolve_map(
329323
# Convert to appropriate format based on state type
330324
ψ0_iter = map(ψ0) do state
331325
T = _complex_float_type(eltype(state))
332-
if isoper(state)
333-
to_dense(_complex_float_type(T), mat2vec(state.data))
334-
elseif isoperket(state)
335-
to_dense(_complex_float_type(T), copy(state.data))
336-
elseif isket(state)
337-
to_dense(_complex_float_type(T), mat2vec(ket2dm(state).data))
338-
elseif issuper(state)
339-
to_dense(_complex_float_type(T), copy(state.data))
326+
if isoperket(state) || issuper(state)
327+
to_dense(T, copy(state.data))
328+
else
329+
to_dense(T, mat2vec(ket2dm(state).data))
340330
end
341331
end
342332
if params isa NullParameters

src/time_evolution/sesolve.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system:
5050
5151
# Notes
5252
53+
- Initial state can also be [`Operator`](@ref)s where each column represents a state vector, such as the Identity operator. This can be used, for example, to calculate the propagator.
5354
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
5455
- If `e_ops` is empty, the default value of `saveat=tlist` (saving the states corresponding to `tlist`), otherwise, `saveat=[tlist[end]]` (only save the final state). You can also specify `e_ops` and `saveat` separately.
5556
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
@@ -127,6 +128,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
127128
128129
# Notes
129130
131+
- Initial state can also be [`Operator`](@ref)s where each column represents a state vector, such as the Identity operator. This can be used, for example, to calculate the propagator.
130132
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
131133
- If `e_ops` is empty, the default value of `saveat=tlist` (saving the states corresponding to `tlist`), otherwise, `saveat=[tlist[end]]` (only save the final state). You can also specify `e_ops` and `saveat` separately.
132134
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
@@ -215,9 +217,10 @@ for each combination in the ensemble.
215217
- `kwargs`: The keyword arguments for the ODEProblem.
216218
217219
# Notes
218-
220+
- Initial state can also be [`Operator`](@ref)s where each column represents a state vector, such as the Identity operator. This can be used, for example, to calculate the propagator.
219221
- The function returns an array of solutions with dimensions matching the Cartesian product of initial states and parameter sets.
220222
- If `ψ0` is a vector of `m` states and `params = (p1, p2, ...)` where `p1` has length `n1`, `p2` has length `n2`, etc., the output will be of size `(m, n1, n2, ...)`.
223+
- Similarly, the initial state(s) can also be `Operator`s where each column represents a state vector, such as the Identity operator. This can be used, for example, to calculate many propagators.
221224
- See [`sesolve`](@ref) for more details.
222225
223226
# Returns
@@ -239,7 +242,7 @@ function sesolve_map(
239242

240243
ψ0 = map(to_dense, ψ0) # Convert all initial states to dense vectors
241244

242-
ψ0_iter = map(get_data, ψ0)
245+
ψ0_iter = map(state -> to_dense(_complex_float_type(eltype(state)), copy(state.data)), ψ0)
243246
if params isa NullParameters
244247
iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector
245248
else

src/time_evolution/smesolve.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
export smesolveProblem, smesolveEnsembleProblem, smesolve
22

3-
#_smesolve_generate_state(u, dims, isoperket::Val{false}) = QuantumObject(vec2mat(u), type = Operator(), dims = dims)
4-
#_smesolve_generate_state(u, dims, isoperket::Val{true}) = QuantumObject(u, type = OperatorKet(), dims = dims)
5-
function _smesolve_generate_state(u, dims, type)
6-
if type == OperatorKet
7-
return QuantumObject(u, type = type, dims = dims)
8-
else
9-
return QuantumObject(vec2mat(u), type = Operator(), dims = dims)
10-
end
11-
end
3+
_smesolve_generate_state(u, dims, isoperket::Val{false}) = QuantumObject(vec2mat(u), type = Operator(), dims = dims)
4+
_smesolve_generate_state(u, dims, isoperket::Val{true}) = QuantumObject(u, type = OperatorKet(), dims = dims)
125

136
function _smesolve_update_coeff(u, p, t, op_vec)
147
return 2 * real(dot(op_vec, u)) #this is Tr[Sn * ρ + ρ * Sn']
@@ -110,10 +103,10 @@ function smesolveProblem(
110103
T = Base.promote_eltype(L_evo, ψ0)
111104
if isoperket(ψ0) # Convert it to dense vector with complex element type
112105
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
113-
states_type = OperatorKet()
106+
state_type = OperatorKet()
114107
else
115108
ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
116-
states_type = Operator()
109+
state_type = Operator()
117110
end
118111

119112
sc_ops_evo_data = Tuple(map(get_data QobjEvo, sc_ops_list))
@@ -155,7 +148,7 @@ function smesolveProblem(
155148
kwargs4...,
156149
)
157150

158-
return TimeEvolutionProblem(prob, tlist, states_type, dims, ())
151+
return TimeEvolutionProblem(prob, tlist, state_type, dims, (isoperket = Val(isoperket(ψ0)),))
159152
end
160153

161154
@doc raw"""
@@ -285,7 +278,7 @@ function smesolveEnsembleProblem(
285278
prob_sme.times,
286279
prob_sme.states_type,
287280
prob_sme.dimensions,
288-
(progr = _output_func[2], channel = _output_func[3]),
281+
merge(prob_sme.kwargs, (progr = _output_func[2], channel = _output_func[3])),
289282
)
290283

291284
return ensemble_prob
@@ -432,7 +425,7 @@ function smesolve(
432425

433426
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
434427
states_all = stack(
435-
map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), [ens_prob.states_type]), eachindex(sol)),
428+
map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), ens_prob.kwargs.isoperket), eachindex(sol)),
436429
dims = 1,
437430
)
438431

@@ -454,4 +447,4 @@ function smesolve(
454447
kwargs.abstol,
455448
kwargs.reltol,
456449
)
457-
end
450+
end

src/time_evolution/ssesolve.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th
7676
"""
7777
function ssesolveProblem(
7878
H::Union{AbstractQuantumObject{Operator},Tuple},
79-
ψ0::QuantumObject{ST},
79+
ψ0::QuantumObject{Ket},
8080
tlist::AbstractVector,
8181
sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing;
8282
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
@@ -85,7 +85,7 @@ function ssesolveProblem(
8585
progress_bar::Union{Val,Bool} = Val(true),
8686
store_measurement::Union{Val,Bool} = Val(false),
8787
kwargs...,
88-
) where {ST<:Union{Ket,Operator}}
88+
)
8989
haskey(kwargs, :save_idxs) &&
9090
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
9191

@@ -95,13 +95,13 @@ function ssesolveProblem(
9595
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject
9696

9797
tlist = _check_tlist(tlist, _float_type(ψ0))
98-
states_type = ψ0.type
9998

10099
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops_list)
101100
isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
102101
check_dimensions(H_eff_evo, ψ0)
103102
dims = H_eff_evo.dimensions
104103

104+
states_type = ψ0.type
105105
ψ0 = to_dense(_complex_float_type(ψ0), get_data(ψ0))
106106

107107
sc_ops_evo_data = Tuple(map(get_data QobjEvo, sc_ops_list))
@@ -219,7 +219,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is t
219219
"""
220220
function ssesolveEnsembleProblem(
221221
H::Union{AbstractQuantumObject{Operator},Tuple},
222-
ψ0::QuantumObject{ST},
222+
ψ0::QuantumObject{Ket},
223223
tlist::AbstractVector,
224224
sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing;
225225
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
@@ -232,7 +232,7 @@ function ssesolveEnsembleProblem(
232232
progress_bar::Union{Val,Bool} = Val(true),
233233
store_measurement::Union{Val,Bool} = Val(false),
234234
kwargs...,
235-
) where {ST<:Union{Ket,Operator}}
235+
)
236236
_prob_func =
237237
isnothing(prob_func) ?
238238
_ensemble_dispatch_prob_func(
@@ -253,7 +253,7 @@ function ssesolveEnsembleProblem(
253253
progr_desc = "[ssesolve] ",
254254
) : output_func
255255

256-
prob_sse = ssesolveProblem(
256+
prob_sme = ssesolveProblem(
257257
H,
258258
ψ0,
259259
tlist,
@@ -267,10 +267,10 @@ function ssesolveEnsembleProblem(
267267
)
268268

269269
ensemble_prob = TimeEvolutionProblem(
270-
EnsembleProblem(prob_sse, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
271-
prob_sse.times,
272-
prob_sse.states_type,
273-
prob_sse.dimensions,
270+
EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
271+
prob_sme.times,
272+
prob_sme.states_type,
273+
prob_sme.dimensions,
274274
(progr = _output_func[2], channel = _output_func[3]),
275275
)
276276

@@ -357,7 +357,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th
357357
"""
358358
function ssesolve(
359359
H::Union{AbstractQuantumObject{Operator},Tuple},
360-
ψ0::QuantumObject{ST},
360+
ψ0::QuantumObject{Ket},
361361
tlist::AbstractVector,
362362
sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing;
363363
alg::Union{Nothing,AbstractSDEAlgorithm} = nothing,
@@ -372,7 +372,7 @@ function ssesolve(
372372
keep_runs_results::Union{Val,Bool} = Val(false),
373373
store_measurement::Union{Val,Bool} = Val(false),
374374
kwargs...,
375-
) where {ST<:Union{Ket,Operator}}
375+
)
376376
ens_prob = ssesolveEnsembleProblem(
377377
H,
378378
ψ0,
@@ -419,7 +419,7 @@ function ssesolve(
419419
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
420420

421421
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
422-
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, [ens_prob.states_type]), eachindex(sol)), dims = 1)
422+
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
423423

424424
_m_expvals =
425425
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
@@ -439,4 +439,4 @@ function ssesolve(
439439
kwargs.abstol,
440440
kwargs.reltol,
441441
)
442-
end
442+
end

0 commit comments

Comments
 (0)