Skip to content

Commit

Permalink
Fix type conversion of tlist in time evolution (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Sep 20, 2024
2 parents fb48447 + 725fa0b commit 5af51a2
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 10 deletions.
6 changes: 6 additions & 0 deletions ext/QuantumToolboxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,10 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32

sparse_to_dense(::Type{T}, A::CuArray{T}) where {T<:Number} = A
sparse_to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A)
sparse_to_dense(::Type{T}, A::CuSparseVector) where {T<:Number} = CuArray{T}(A)
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSC) where {T<:Number} = CuArray{T}(A)
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSR) where {T<:Number} = CuArray{T}(A)

end
6 changes: 5 additions & 1 deletion src/qobj/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,16 @@ variance(O::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject}, ψ::Vector
Converts a sparse QuantumObject to a dense QuantumObject.
"""
sparse_to_dense(A::QuantumObject{<:AbstractVecOrMat}) = QuantumObject(sparse_to_dense(A.data), A.type, A.dims)
sparse_to_dense(A::MT) where {MT<:AbstractSparseMatrix} = Array(A)
sparse_to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
for op in (:Transpose, :Adjoint)
@eval sparse_to_dense(A::$op{T,<:AbstractSparseMatrix}) where {T<:BlasFloat} = Array(A)
end
sparse_to_dense(A::MT) where {MT<:AbstractArray} = A

sparse_to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
sparse_to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
sparse_to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A

function sparse_to_dense(::Type{M}) where {M<:SparseMatrixCSC}
T = M
par = T.parameters
Expand Down
4 changes: 4 additions & 0 deletions src/qobj/quantum_object.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,7 @@ SparseArrays.SparseMatrixCSC(A::QuantumObject{<:AbstractMatrix}) =
QuantumObject(SparseMatrixCSC(A.data), A.type, A.dims)
SparseArrays.SparseMatrixCSC{T}(A::QuantumObject{<:SparseMatrixCSC}) where {T<:Number} =
QuantumObject(SparseMatrixCSC{T}(A.data), A.type, A.dims)

# functions for getting Float or Complex element type
_FType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _FType(T)
_CType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _CType(T)
7 changes: 4 additions & 3 deletions src/steadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ function steadystate(
(H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

N = prod(H.dims)
u0 = mat2vec(ket2dm(ψ0).data)
u0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data))

L = MatrixOperator(liouvillian(H, c_ops).data)

prob = ODEProblem{true}(L, u0, (0.0, tspan))
ftype = _FType(ψ0)
prob = ODEProblem{true}(L, u0, (ftype(0), ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
sol = solve(
prob,
solver.alg;
Expand All @@ -109,7 +111,6 @@ function steadystate(
)

ρss = reshape(sol.u[end], N, N)
ρss = (ρss + ρss') / 2 # Hermitianize
return QuantumObject(ρss, Operator, H.dims)
end

Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function mcsolveProblem(
c_ops isa Nothing &&
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2

Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ function mesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type

ρ0 = mat2vec(ket2dm(ψ0).data)
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ function sesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type

ϕ0 = get_data(ψ0)
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
16 changes: 16 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,19 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1

# functions for getting Float or Complex element type
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
_FType(::Type{Int32}) = Float32
_FType(::Type{Int64}) = Float64
_FType(::Type{Float32}) = Float32
_FType(::Type{Float64}) = Float64
_FType(::Type{ComplexF32}) = Float32
_FType(::Type{ComplexF64}) = Float64
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
_CType(::Type{Int32}) = ComplexF32
_CType(::Type{Int64}) = ComplexF64
_CType(::Type{Float32}) = ComplexF32
_CType(::Type{Float64}) = ComplexF64
_CType(::Type{ComplexF32}) = ComplexF32
_CType(::Type{ComplexF64}) = ComplexF64
8 changes: 7 additions & 1 deletion test/core-test/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
"reltol = $(sol.reltol)\n"

@testset "Type Inference sesolve" begin
@inferred sesolveProblem(H, psi0, t_l)
@inferred sesolveProblem(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolveProblem(H, psi0, [0, 10], progress_bar = Val(false))
@inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
Expand Down Expand Up @@ -91,6 +93,8 @@

@testset "Type Inference mesolve" begin
@inferred mesolveProblem(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolveProblem(H, psi0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolveProblem(H, Qobj(zeros(Int64, N)), t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
Expand All @@ -108,6 +112,8 @@
)
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
@inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false))
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false))
end
end

Expand Down

0 comments on commit 5af51a2

Please sign in to comment.