Skip to content

Commit

Permalink
Refactor: Remove global dependence of some functions in DeePKS. (#5778)
Browse files Browse the repository at this point in the history
* Move cal_o_delta from GlobalC::ld to DeePKS_domain and remove variable o_delta in ld.

* Remove F_delta in ld and lessen the tedious dimension in orbital related variables in DeePKS.
  • Loading branch information
ErjieWu authored Dec 28, 2024
1 parent 5e8bc21 commit 9bf2533
Show file tree
Hide file tree
Showing 18 changed files with 350 additions and 357 deletions.
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
LCAO_deepks_odelta.o\
deepks_orbital.o\
LCAO_deepks_io.o\
LCAO_deepks_mpi.o\
LCAO_deepks_pdm.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Force_LCAO
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
typename TGint<T>::type& gint,
Expand Down
18 changes: 14 additions & 4 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
ModuleBase::matrix fewalds;
ModuleBase::matrix fcc;
ModuleBase::matrix fscc;
#ifdef __DEEPKS
ModuleBase::matrix fvnl_dalpha; // deepks
#endif

fvl_dphi.create(nat, 3); // must do it now, update it later, noted by zhengdy

Expand All @@ -93,6 +96,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
fewalds.create(nat, 3);
fcc.create(nat, 3);
fscc.create(nat, 3);
#ifdef __DEEPKS
fvnl_dalpha.create(nat, 3); // deepks
#endif

// calculate basic terms in Force, same method with PW base
this->calForcePwPart(ucell,
Expand Down Expand Up @@ -172,6 +178,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
svnl_dbeta,
svl_dphi,
#ifdef __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_gamma,
Expand Down Expand Up @@ -454,7 +461,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
// mohan add 2021-08-04
if (PARAM.inp.deepks_scf)
{
fcs(iat, i) += GlobalC::ld.F_delta(iat, i);
fcs(iat, i) += fvnl_dalpha(iat, i);
}
#endif
// sum total force for correction
Expand Down Expand Up @@ -499,7 +506,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
if (PARAM.inp.deepks_scf)
{
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
LCAO_deepks_io::save_npy_f(fcs - GlobalC::ld.F_delta,
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha,
file_fbase,
ucell.nat,
GlobalV::MY_RANK); // Ry/Bohr, F_base
Expand Down Expand Up @@ -636,8 +643,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
// caoyu add 2021-06-03
if (PARAM.inp.deepks_scf)
{
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", GlobalC::ld.F_delta, true);
// this->print_force("DeePKS FORCE", GlobalC::ld.F_delta, 1, ry);
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", fvnl_dalpha, true);
}
#endif
}
Expand Down Expand Up @@ -891,6 +897,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma, // mohan add 2024-04-01
Expand All @@ -917,6 +924,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
svnl_dbeta,
svl_dphi,
#if __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_gamma,
Expand Down Expand Up @@ -944,6 +952,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma,
Expand All @@ -969,6 +978,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
svnl_dbeta,
svl_dphi,
#if __DEEPKS
fvnl_dalpha,
svnl_dalpha,
#endif
gint_k,
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Force_Stress_LCAO
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
Gint_Gamma& gint_gamma,
Expand Down
70 changes: 35 additions & 35 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
TGint<double>::type& gint,
Expand Down Expand Up @@ -246,15 +247,13 @@ void Force_LCAO<double>::ftable(const bool isforce,
false /*reset dm to gint*/);

#ifdef __DEEPKS
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
if (PARAM.inp.deepks_scf)
{
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();

// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);

GlobalC::ld.cal_gedm(ucell.nat);

const int nks = 1;
Expand All @@ -269,40 +268,9 @@ void Force_LCAO<double>::ftable(const bool isforce,
GlobalC::ld.phialpha,
GlobalC::ld.gedm,
GlobalC::ld.inl_index,
GlobalC::ld.F_delta,
fvnl_dalpha,
isstress,
svnl_dalpha);

#ifdef __MPI
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);

if (isstress)
{
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
}
#endif

if (PARAM.inp.deepks_out_unittest)
{
const int nks = 1; // 1 for gamma-only
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);

GlobalC::ld.check_gedm();

GlobalC::ld.cal_e_delta_band(dm_gamma, nks);

std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;

DeePKS_domain::check_f_delta(ucell.nat, GlobalC::ld.F_delta, svnl_dalpha);
}
}
#endif

Expand All @@ -312,14 +280,46 @@ void Force_LCAO<double>::ftable(const bool isforce,
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
#endif
}
if (isstress)
{
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
#endif
}

#ifdef __DEEPKS
// It seems these test should not all be here, should be moved in the future
// Also, these test are not in multi-k case now
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
{
const int nks = 1; // 1 for gamma-only
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);

GlobalC::ld.check_gedm();

GlobalC::ld.cal_e_delta_band(dm_gamma, nks);

std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;

DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
}
#endif

// delete DSloc_x, DSloc_y, DSloc_z
// delete DHloc_fixed_x, DHloc_fixed_y, DHloc_fixed_z
Expand Down
17 changes: 8 additions & 9 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
ModuleBase::matrix& svnl_dbeta,
ModuleBase::matrix& svl_dphi,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
#endif
TGint<std::complex<double>>::type& gint,
Expand Down Expand Up @@ -363,17 +364,9 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
GlobalC::ld.phialpha,
GlobalC::ld.gedm,
GlobalC::ld.inl_index,
GlobalC::ld.F_delta,
fvnl_dalpha,
isstress,
svnl_dalpha);

#ifdef __MPI
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);
if (isstress)
{
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
}
#endif
}
#endif

Expand All @@ -386,13 +379,19 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
#endif
}
if (isstress)
{
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
#ifdef __DEEPKS
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
#endif
}

ModuleBase::timer::tick("Force_LCAO", "ftable");
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ if(ENABLE_DEEPKS)
list(APPEND objects
LCAO_deepks.cpp
deepks_force.cpp
LCAO_deepks_odelta.cpp
deepks_orbital.cpp
LCAO_deepks_io.cpp
LCAO_deepks_mpi.cpp
LCAO_deepks_pdm.cpp
Expand Down
34 changes: 9 additions & 25 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
}
if (PARAM.inp.cal_force)
{
// init F_delta
F_delta.create(nat, 3);
if (PARAM.inp.deepks_out_labels)
{
this->init_gdmx(nat);
Expand All @@ -342,34 +340,24 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
// gdmx is used only in calculating gvx
}

if (PARAM.inp.deepks_bandgap)
{
// init o_delta
o_delta.create(nks, 1);
}

return;
}

void LCAO_Deepks::init_orbital_pdm_shell(const int nks)
{

this->orbital_pdm_shell = new double***[nks];
this->orbital_pdm_shell = new double**[nks];

for (int iks = 0; iks < nks; iks++)
{
this->orbital_pdm_shell[iks] = new double**[1];
for (int hl = 0; hl < 1; hl++)
this->orbital_pdm_shell[iks] = new double*[this->inlmax];
for (int inl = 0; inl < this->inlmax; inl++)
{
this->orbital_pdm_shell[iks][hl] = new double*[this->inlmax];

for (int inl = 0; inl < this->inlmax; inl++)
{
this->orbital_pdm_shell[iks][hl][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][hl][inl],
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
}
this->orbital_pdm_shell[iks][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][inl],
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
}

}

return;
Expand All @@ -379,13 +367,9 @@ void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
{
for (int iks = 0; iks < nks; iks++)
{
for (int hl = 0; hl < 1; hl++)
for (int inl = 0; inl < this->inlmax; inl++)
{
for (int inl = 0; inl < this->inlmax; inl++)
{
delete[] this->orbital_pdm_shell[iks][hl][inl];
}
delete[] this->orbital_pdm_shell[iks][hl];
delete[] this->orbital_pdm_shell[iks][inl];
}
delete[] this->orbital_pdm_shell[iks];
}
Expand Down
Loading

0 comments on commit 9bf2533

Please sign in to comment.