Skip to content

Commit

Permalink
FF: don't cache atom positions
Browse files Browse the repository at this point in the history
Fixes: #134
Signed-off-by: Thomas Kemmer <[email protected]>
  • Loading branch information
tkemmer committed Nov 28, 2024
1 parent d37f558 commit 8ef2d6a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 47 deletions.
17 changes: 7 additions & 10 deletions src/forcefields/common/bend_component.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ export
θ₀::T
k::T
a1::Atom{T}
a1r::Vector3{T}
a2::Atom{T}
a2r::Vector3{T}
a3::Atom{T}
a3r::Vector3{T}
end

@auto_hash_equals mutable struct QuadraticBendComponent{T<:Real} <: AbstractForceFieldComponent{T}
Expand Down Expand Up @@ -100,9 +97,9 @@ function setup!(qbc::QuadraticBendComponent{T}) where {T<:Real}
QuadraticAngleBend(
T(θ₀_factor*only(qab.theta0)),
T(k_factor*only(qab.k)),
a1, a1.r,
a2, a2.r,
a3, a3.r
a1,
a2,
a3,
))
end
end
Expand All @@ -117,8 +114,8 @@ function update!(qbc::QuadraticBendComponent{T}) where {T<:Real}
end

@inline function compute_energy(qab::QuadraticAngleBend{T})::T where {T<:Real}
v1 = qab.a1r .- qab.a2r
v2 = qab.a3r .- qab.a2r
v1 = qab.a1.r .- qab.a2.r
v2 = qab.a3.r .- qab.a2.r

sq_length = squared_norm(v1) * squared_norm(v2)

Expand Down Expand Up @@ -149,8 +146,8 @@ end

function compute_forces!(qab::QuadraticAngleBend{T}) where {T<:Real}
# calculate the vectors between the atoms and normalize if possible
v1 = qab.a1r .- qab.a2r
v2 = qab.a3r .- qab.a2r
v1 = qab.a1.r .- qab.a2.r
v2 = qab.a3.r .- qab.a2.r

v1_length = norm(v1)
v2_length = norm(v2)
Expand Down
22 changes: 9 additions & 13 deletions src/forcefields/common/nonbonded_component.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ end
distance::T
scaling_factor::T
a1::Atom{T}
a1r::Vector3{T}
a2::Atom{T}
a2r::Vector3{T}
switching_function::CubicSwitchingFunction{T}
end

Expand All @@ -98,9 +96,7 @@ end
scaling_factor::T
distance_dependent_dielectric::Bool
a1::Atom{T}
a1r::Vector3{T}
a2::Atom{T}
a2r::Vector3{T}
switching_function::CubicSwitchingFunction{T}
end

Expand Down Expand Up @@ -162,8 +158,8 @@ end
params.B_ij,
T(distance),
scaling_factor,
atom_1, atom_1.r,
atom_2, atom_2.r,
atom_1,
atom_2,
switching_function
)
)
Expand Down Expand Up @@ -416,8 +412,8 @@ function update!(nbc::NonBondedComponent{T}) where {T<:Real}
lj_candidate[3],
vicinal_pair ? scaling_es_1_4 : T(1.0),
distance_dependent_dielectric,
atom_1, atom_1.r,
atom_2, atom_2.r,
atom_1,
atom_2,
es_switching_function
)
push!(electrostatic_interactions, es)
Expand All @@ -443,8 +439,8 @@ function update!(nbc::NonBondedComponent{T}) where {T<:Real}
only(h_params.B),
T(lj_candidate[3]),
T(1.0),
atom_1, atom_1.r,
atom_2, atom_2.r,
atom_1,
atom_2,
vdw_switching_function
)
)
Expand Down Expand Up @@ -516,7 +512,7 @@ function compute_energy!(nbc::NonBondedComponent{T})::T where {T<:Real}
end

function compute_forces!(lji::LennardJonesInteraction{T, 12, 6}) where {T<:Real}
direction = lji.a1r .- lji.a2r
direction = lji.a1.r .- lji.a2.r

sq_distance = squared_norm(direction)

Expand Down Expand Up @@ -552,7 +548,7 @@ function compute_forces!(lji::LennardJonesInteraction{T, 12, 6}) where {T<:Real}
end

function compute_forces!(hb::LennardJonesInteraction{T, 12, 10}) where {T<:Real}
direction = hb.a1r .- hb.a2r
direction = hb.a1.r .- hb.a2.r

sq_distance = squared_norm(direction)

Expand Down Expand Up @@ -592,7 +588,7 @@ function compute_forces!(hb::LennardJonesInteraction{T, 12, 10}) where {T<:Real}
end

function compute_forces!(esi::ElectrostaticInteraction{T}) where {T<:Real}
direction = esi.a1r .- esi.a2r
direction = esi.a1.r .- esi.a2.r

sq_distance = squared_norm(direction)

Expand Down
14 changes: 6 additions & 8 deletions src/forcefields/common/stretch_component.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ export
r0::T
k::T
a1::Atom{T}
a1r::Vector3{T}
a2::Atom{T}
a2r::Vector3{T}
end

@auto_hash_equals mutable struct QuadraticStretchComponent{T<:Real} <: AbstractForceFieldComponent{T}
Expand Down Expand Up @@ -60,7 +58,7 @@ function setup!(qsc::QuadraticStretchComponent{T}) where {T}

if has_flag(bond, :TYPE__HYDROGEN)
# skip hydrogen bonds
stretches[i] = QuadraticBondStretch(one(T), zero(T), a1, a1.r, a2, a2.r)
stretches[i] = QuadraticBondStretch(one(T), zero(T), a1, a2)
end

qbs = coalesce(
Expand All @@ -82,13 +80,13 @@ function setup!(qsc::QuadraticStretchComponent{T}) where {T}
end

# we don't want to get any force or energy component from this stretch
QuadraticBondStretch(one(T), zero(T), a1, a1.r, a2, a2.r)
QuadraticBondStretch(one(T), zero(T), a1, a2)
else
QuadraticBondStretch(
T(r0_factor*only(qbs.r0)),
T(k_factor*only(qbs.k)),
a1, a1.r,
a2, a2.r
a1,
a2
)
end
end
Expand All @@ -101,7 +99,7 @@ function update!(qsc::QuadraticStretchComponent{T}) where {T<:Real}
end

@inline function compute_energy(qbs::QuadraticBondStretch{T})::T where {T<:Real}
d = distance(qbs.a1r, qbs.a2r)
d = distance(qbs.a1.r, qbs.a2.r)

qbs.k * (d - qbs.r0)^2
end
Expand All @@ -117,7 +115,7 @@ function compute_energy!(qsc::QuadraticStretchComponent{T})::T where {T<:Real}
end

function compute_forces!(qbs::QuadraticBondStretch{T}) where {T<:Real}
direction = qbs.a1r .- qbs.a2r
direction = qbs.a1.r .- qbs.a2.r
distance = norm(direction)

if distance == zero(T)
Expand Down
28 changes: 12 additions & 16 deletions src/forcefields/common/torsion_component.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ export
f::Vector{Int}
div::Vector{Int}
a1::Atom{T}
a1r::Vector3{T}
a2::Atom{T}
a2r::Vector3{T}
a3::Atom{T}
a3r::Vector3{T}
a4::Atom{T}
a4r::Vector3{T}
end

@auto_hash_equals mutable struct TorsionComponent{T<:Real} <: AbstractForceFieldComponent{T}
Expand Down Expand Up @@ -113,10 +109,10 @@ function _try_assign_torsion!(
ϕ₀_factor .* getproperty.(pts, :phi0),
getproperty.(pts, :f),
getproperty.(pts, :div),
a1, a1.r,
a2, a2.r,
a3, a3.r,
a4, a4.r
a1,
a2,
a3,
a4
)
)
end
Expand Down Expand Up @@ -262,10 +258,10 @@ end
@inline function compute_energy(pt::CosineTorsion{T})::T where {T<:Real}
energy = zero(T)

a23 = pt.a3r .- pt.a2r
a23 = pt.a3.r .- pt.a2.r

cross2321 = normalize(cross(a23, pt.a1r .- pt.a2r))
cross2334 = normalize(cross(a23, pt.a4r .- pt.a3r))
cross2321 = normalize(cross(a23, pt.a1.r .- pt.a2.r))
cross2334 = normalize(cross(a23, pt.a4.r .- pt.a3.r))

if !isnan(cross2321[1]) && !isnan(cross2334[1])
cos_ϕ = clamp(dot(cross2321, cross2334), T(-1.0), T(1.0))
Expand Down Expand Up @@ -296,9 +292,9 @@ function compute_energy!(tc::TorsionComponent{T})::T where {T<:Real}
end

function compute_forces!(ct::CosineTorsion{T}) where {T<:Real}
a21 = ct.a1r .- ct.a2r
a23 = ct.a3r .- ct.a2r
a34 = ct.a4r .- ct.a3r
a21 = ct.a1.r .- ct.a2.r
a23 = ct.a3.r .- ct.a2.r
a34 = ct.a4.r .- ct.a3.r

cross2321 = cross(a23, a21)
cross2334 = cross(a23, a34)
Expand All @@ -325,8 +321,8 @@ function compute_forces!(ct::CosineTorsion{T}) where {T<:Real}
∂E∂ϕ *= -1
end

a13 = ct.a3r .- ct.a1r
a24 = ct.a4r .- ct.a2r
a13 = ct.a3.r .- ct.a1.r
a24 = ct.a4.r .- ct.a2.r

dEdt = (∂E∂ϕ / (length_cross2321^2 * norm(a23)) * cross(cross2321, a23))
dEdu = -(∂E∂ϕ / (length_cross2334^2 * norm(a23)) * cross(cross2334, a23))
Expand Down

0 comments on commit 8ef2d6a

Please sign in to comment.