Skip to content

Commit

Permalink
optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBrosch committed Nov 9, 2023
1 parent 001e254 commit 1e64989
Showing 1 changed file with 62 additions and 16 deletions.
78 changes: 62 additions & 16 deletions src/FlagModels/QuadraticModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function modelBlockSizes(m::QuadraticModule)
return modelBlockSizes(m.baseModel)
end

function buildJuMPModel(m::QuadraticModule, replaceBlocks=Dict(), jumpModel=Model())
function buildJuMPModel(m::QuadraticModule{T,U,B,N,D}, replaceBlocks=Dict(), jumpModel=Model()) where {T, U, B, N, D}
b = modelBlockSizes(m)
Y = Dict()
constraints = Dict()
Expand All @@ -85,12 +85,31 @@ function buildJuMPModel(m::QuadraticModule, replaceBlocks=Dict(), jumpModel=Mode
end
end

graphCoefficients = Dict(
G => sum(
dot(Y[mu], Symmetric(m.sdpData[G][mu])) for
mu in keys(b) if haskey(m.sdpData[G], mu)
) for G in keys(m.sdpData)
)
AT = typeof(sum(collect(values(Y))[1]))
graphCoefficients = Dict()

for G in keys(m.sdpData)
# eG = AffExpr()
# eG = GenericAffExpr{D, GenericVariableRef{D}}()
eG = AT()
for mu in keys(b)
if haskey(m.sdpData[G], mu)
for (i, j, c) in Iterators.zip(findnz(m.sdpData[G][mu])...)
i > j && continue
fact = (i == j ? D(1) : D(2))
add_to_expression!(eG, fact * D(c), Y[mu][i, j])
# add_to_expression!(eG, m.sdpData[G][mu][c],Y[mu][c])
end
end
end
graphCoefficients[G] = eG
end
# graphCoefficients = Dict(
# G => sum(
# dot(Y[mu], Symmetric(m.sdpData[G][mu])) for
# mu in keys(b) if haskey(m.sdpData[G], mu)
# ) for G in keys(m.sdpData)
# )

return (model=jumpModel, variables=graphCoefficients, blocks=Y, constraints=constraints)
end
Expand All @@ -110,21 +129,27 @@ mutable struct EqualityModule{T<:Flag,U<:Flag,N,D} <: AbstractFlagModel{T,N,D}
equality::QuantumFlag{U}
reservedVerts::Int

function EqualityModule{T,U}(equality::QuantumFlag{U}, reservedVerts::Int=0) where {T<:Flag,U<:Flag}
function EqualityModule{T,U}(
equality::QuantumFlag{U}, reservedVerts::Int=0
) where {T<:Flag,U<:Flag}
return new{T,U,:limit,Int}(Dict(), U[], equality, reservedVerts)
end
function EqualityModule{T,U,N,D}(equality::QuantumFlag{U}, reservedVerts::Int=0) where {T<:Flag,U<:Flag,N,D}
function EqualityModule{T,U,N,D}(
equality::QuantumFlag{U}, reservedVerts::Int=0
) where {T<:Flag,U<:Flag,N,D}
return new{T,U,N,D}(Dict(), U[], equality, reservedVerts)
end
end

function computeSDP!(m::EqualityModule{T,U,N,D}, reservedVerts::Int) where {T<:Flag,U<:Flag,N,D}
function computeSDP!(
m::EqualityModule{T,U,N,D}, reservedVerts::Int
) where {T<:Flag,U<:Flag,N,D}
m.sdpData = Dict()
# @assert N == :limit "TODO"
for (i, G) in enumerate(m.basis)
for (G2, c) in m.equality.coeff
# GG2 = G * G2
GG2s = D(1)*glueFinite(N == :limit ? N : N - reservedVerts, G, G2)
GG2s = D(1) * glueFinite(N == :limit ? N : N - reservedVerts, G, G2)
for (GG2, c2) in GG2s.coeff
GG2 === nothing && continue
if GG2 isa PartiallyLabeledFlag{T}
Expand Down Expand Up @@ -152,7 +177,7 @@ function modelBlockSizes(m::EqualityModule)
return Dict(i => -1 for i in 1:length(m.basis))
end

function buildJuMPModel(m::EqualityModule, replaceBlocks=Dict(), jumpModel=Model())
function buildJuMPModel(m::EqualityModule{T,U,N,D}, replaceBlocks=Dict(), jumpModel=Model()) where {T,U,N,D}
@assert length(replaceBlocks) == 0

b = modelBlockSizes(m)
Expand All @@ -161,10 +186,31 @@ function buildJuMPModel(m::EqualityModule, replaceBlocks=Dict(), jumpModel=Model
Y[mu] = @variable(jumpModel)
end

graphCoefficients = Dict(
G => sum(Y[mu] * m.sdpData[G][mu] for mu in keys(b) if haskey(m.sdpData[G], mu)) for
G in keys(m.sdpData)
)
# graphCoefficients = Dict(
# G => sum(Y[mu] * m.sdpData[G][mu] for mu in keys(b) if haskey(m.sdpData[G], mu)) for
# G in keys(m.sdpData)
# )

AT = typeof(sum(D(1)*collect(values(Y))[1]))
graphCoefficients = Dict()

for G in keys(m.sdpData)
# eG = AffExpr()
# eG = GenericAffExpr{D, GenericVariableRef{D}}()
eG = AT()
for mu in keys(b)
if haskey(m.sdpData[G], mu)
# for (i, j, c) in Iterators.zip(findnz(m.sdpData[G][mu])...)
# i > j && continue
# fact = (i == j ? D(1) : D(2))
# add_to_expression!(eG, fact * D(c), Y[mu][i, j])
# # add_to_expression!(eG, m.sdpData[G][mu][c],Y[mu][c])
# end
add_to_expression!(eG, m.sdpData[G][mu], Y[mu])
end
end
graphCoefficients[G] = eG
end

return (model=jumpModel, variables=graphCoefficients, blocks=Y, constraints=Dict())
end
Expand Down

0 comments on commit 1e64989

Please sign in to comment.