Skip to content

Commit ee6535c

Browse files
dyzhengdyzheng
andauthored
Feature: support init_chg dm to restart with DMR (#6753)
* Feature: support `init_chg dm` * Fix: makefile and refactor * add test case for restart_dm * Fix: compiling error without MPI --------- Co-authored-by: dyzheng <zhengdy@bjaisi.com>
1 parent 3b58709 commit ee6535c

File tree

13 files changed

+341
-15
lines changed

13 files changed

+341
-15
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ OBJS_HCONTAINER=base_matrix.o\
356356
atom_pair.o\
357357
hcontainer.o\
358358
output_hcontainer.o\
359+
read_hcontainer.o\
359360
func_folding.o\
360361
func_transfer.o\
361362
transfer.o\

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
7575

7676
LCAO_domain::set_psi_occ_dm_chg<TK>(this->kv, this->psi, this->pv, this->pelec,
7777
this->dmat, this->chr, inp);
78+
79+
if(inp.init_chg == "dm")
80+
{
81+
//! 4.1) init density matrix from file
82+
std::string dmfile = PARAM.globalv.global_readin_dir + "/dmrs1_nao.csr";
83+
LCAO_domain::init_dm_from_file<TK>(dmfile, this->dmat, ucell, &(this->pv));
84+
}
7885

7986
LCAO_domain::set_pot<TK>(ucell, this->kv, this->sf, *this->pw_rho, *this->pw_rhod,
8087
this->pelec, this->orb_, this->pv, this->locpp, this->dftu,
@@ -160,7 +167,12 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
160167
// 11) set xc type before the first cal of xc in pelec->init_scf, Peize Lin add 2016-12-03
161168
this->exx_nao.before_scf(ucell, this->kv, orb_, this->p_chgmix, istep, PARAM.inp);
162169

163-
// 12) init_scf, should be before_scf? mohan add 2025-03-10
170+
// 12.1) if init_chg = "dm", then calculate rho from readin DMR before init_scf
171+
if(PARAM.inp.init_chg == "dm")
172+
{
173+
LCAO_domain::dm2rho(this->dmat.dm->get_DMR_vector(), PARAM.inp.nspin, this->pelec->charge, true);
174+
}
175+
// 12.2) init_scf, should be before_scf? mohan add 2025-03-10
164176
this->pelec->init_scf(istep, ucell, this->Pgrid, this->sf.strucFac, this->locpp.numeric, ucell.symm);
165177

166178
// 13) initalize DM(R), which has the same size with Hamiltonian(R)
@@ -169,7 +181,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
169181
{
170182
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","p_hamilt does not exist");
171183
}
172-
this->dmat.dm->init_DMR(*hamilt_lcao->getHR());
184+
if(PARAM.inp.init_chg != "dm") this->dmat.dm->init_DMR(*hamilt_lcao->getHR());
173185

174186
#ifdef __MLALGO
175187
// 14) initialize DM2(R) of DeePKS, the DM2(R) is different from DM(R)

source/source_io/read_input_item_system.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ void ReadInput::item_system()
552552
}
553553
};
554554
item.check_value = [](const Input_Item& item, const Parameter& para) {
555-
const std::vector<std::string> init_chgs = {"atomic", "file", "wfc", "auto"};
555+
const std::vector<std::string> init_chgs = {"atomic", "file", "wfc", "auto", "dm"};
556556
if (std::find(init_chgs.begin(), init_chgs.end(), para.input.init_chg) == init_chgs.end())
557557
{
558558
const std::string warningstr = nofound_str(init_chgs, "init_chg");

source/source_lcao/LCAO_set.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "source_psi/setup_psi.h" // use Setup_Psi
44
#include "source_io/read_wfc_nao.h" // use read_wfc_nao
55
#include "source_estate/elecstate_tools.h" // use fixed_weights
6+
#include "source_lcao/module_hcontainer/read_hcontainer.h"
67

78
template <typename TK>
89
void LCAO_domain::set_psi_occ_dm_chg(
@@ -91,6 +92,27 @@ void LCAO_domain::set_pot(
9192
return;
9293
}
9394

95+
template <typename TK>
96+
void LCAO_domain::init_dm_from_file(
97+
const std::string dmfile,
98+
LCAO_domain::Setup_DM<TK>& dmat,
99+
const UnitCell& ucell,
100+
const Parallel_Orbitals* pv)
101+
{
102+
ModuleBase::TITLE("LCAO_domain::init_dm_from_file", "init_dm_from_file");
103+
hamilt::HContainer<double>* dm_container = new hamilt::HContainer<double>(pv);
104+
dmat.dm->init_DMR(dm_container[0]);
105+
hamilt::Read_HContainer<double> reader_dm(
106+
dmat.dm->get_DMR_vector()[0],
107+
dmfile,
108+
PARAM.globalv.nlocal,
109+
&ucell
110+
);
111+
reader_dm.read();
112+
delete dm_container;
113+
return;
114+
}
115+
94116

95117

96118
template void LCAO_domain::set_psi_occ_dm_chg<double>(
@@ -142,3 +164,14 @@ template void LCAO_domain::set_pot<std::complex<double>>(
142164
Exx_NAO<std::complex<double>> &exx_nao,
143165
Setup_DeePKS<std::complex<double>> &deepks,
144166
const Input_para &inp);
167+
168+
template void LCAO_domain::init_dm_from_file<double>(
169+
const std::string dmfile,
170+
LCAO_domain::Setup_DM<double>& dmat,
171+
const UnitCell& ucell,
172+
const Parallel_Orbitals* pv);
173+
template void LCAO_domain::init_dm_from_file<std::complex<double>>(
174+
const std::string dmfile,
175+
LCAO_domain::Setup_DM<std::complex<double>>& dmat,
176+
const UnitCell& ucell,
177+
const Parallel_Orbitals* pv);

source/source_lcao/LCAO_set.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ void set_pot(
5353
Setup_DeePKS<TK> &deepks,
5454
const Input_para &inp);
5555

56+
template <typename TK>
57+
void init_dm_from_file(
58+
const std::string dmfile,
59+
LCAO_domain::Setup_DM<TK>& dmat,
60+
const UnitCell& ucell,
61+
const Parallel_Orbitals* pv);
5662
} // end namespace
5763

5864
#endif

source/source_lcao/module_hcontainer/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ list(APPEND objects
55
atom_pair.cpp
66
hcontainer.cpp
77
output_hcontainer.cpp
8+
read_hcontainer.cpp
89
func_folding.cpp
910
transfer.cpp
1011
func_transfer.cpp
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#include "read_hcontainer.h"
2+
3+
#include "source_io/sparse_matrix.h"
4+
#include "source_io/csr_reader.h"
5+
#include "hcontainer_funcs.h"
6+
7+
#include <fstream>
8+
9+
namespace hamilt
10+
{
11+
12+
/**
13+
* @brief Constructor of Read_HContainer
14+
* @attention ifs should be open outside of this interface
15+
*/
16+
template <typename T>
17+
Read_HContainer<T>::Read_HContainer(hamilt::HContainer<T>* hcontainer,
18+
const std::string& filename,
19+
const int nlocal,
20+
const UnitCell* ucell)
21+
: _hcontainer(hcontainer), _filename(filename), _nlocal(nlocal), _ucell(ucell)
22+
{
23+
}
24+
25+
template <typename T>
26+
void Read_HContainer<T>::read()
27+
{
28+
// build atom index of col and row
29+
std::vector<int> atom_index_row;
30+
std::vector<int> atom_index_col;
31+
int natom = this->_ucell->nat;
32+
Parallel_Orbitals pv_serial;
33+
pv_serial.set_serial(this->_nlocal, this->_nlocal);
34+
pv_serial.set_atomic_trace(this->_ucell->get_iat2iwt(), this->_ucell->nat, this->_nlocal);
35+
for (int iat = 0; iat < natom; ++iat)
36+
{
37+
int row_size = pv_serial.get_row_size(iat);
38+
int col_size = pv_serial.get_col_size(iat);
39+
for (int i = 0; i < row_size; ++i)
40+
{
41+
atom_index_row.push_back(iat);
42+
}
43+
for (int j = 0; j < col_size; ++j)
44+
{
45+
atom_index_col.push_back(iat);
46+
}
47+
}
48+
//
49+
hamilt::HContainer<T> hcontainer_serial(&pv_serial);
50+
51+
#ifdef __MPI
52+
if(GlobalV::MY_RANK == 0)
53+
{
54+
#endif
55+
ModuleIO::csrFileReader<T> csr(this->_filename);
56+
int step = csr.getStep();
57+
int matrix_dimension = csr.getMatrixDimension();
58+
int r_number = csr.getNumberOfR();
59+
60+
//construct serial hcontainer firstly
61+
// prepare atom index mapping from csr row/col to atom index
62+
for (int i = 0; i < r_number; i++)
63+
{
64+
std::vector<int> RCoord = csr.getRCoordinate(i);
65+
ModuleIO::SparseMatrix<T> sparse_matrix = csr.getMatrix(i);
66+
for (const auto& element: sparse_matrix.getElements())
67+
{
68+
int row = element.first.first;
69+
int col = element.first.second;
70+
T value = element.second;
71+
72+
73+
//insert into hcontainer
74+
int atom_i = atom_index_row[row];
75+
int atom_j = atom_index_col[col];
76+
auto* ij_pair = hcontainer_serial.find_pair(atom_i, atom_j);
77+
if(ij_pair == nullptr)
78+
{
79+
//insert new pair
80+
hamilt::AtomPair<T> new_pair(atom_i, atom_j, RCoord[0], RCoord[1], RCoord[2], &pv_serial);
81+
hcontainer_serial.insert_pair(new_pair);
82+
}
83+
else
84+
{
85+
if(ij_pair->find_R(RCoord[0], RCoord[1], RCoord[2]) == -1)
86+
{
87+
//insert new R
88+
hamilt::AtomPair<T> new_pair(atom_i, atom_j, RCoord[0], RCoord[1], RCoord[2], &pv_serial);
89+
hcontainer_serial.insert_pair(new_pair);
90+
}
91+
}
92+
}
93+
}
94+
hcontainer_serial.allocate(nullptr, true);
95+
// second loop, add values into hcontainer
96+
for (int i = 0; i < r_number; i++)
97+
{
98+
std::vector<int> RCoord = csr.getRCoordinate(i);
99+
ModuleIO::SparseMatrix<T> sparse_matrix = csr.getMatrix(i);
100+
for (const auto& element: sparse_matrix.getElements())
101+
{
102+
int row = element.first.first;
103+
int col = element.first.second;
104+
T value = element.second;
105+
106+
//insert into hcontainer
107+
int atom_i = atom_index_row[row];
108+
int atom_j = atom_index_col[col];
109+
auto* matrix = hcontainer_serial.find_matrix(atom_i, atom_j, RCoord[0], RCoord[1], RCoord[2]);
110+
matrix->add_element(row - pv_serial.atom_begin_row[atom_i],
111+
col - pv_serial.atom_begin_col[atom_j],
112+
value);
113+
}
114+
}
115+
#ifdef __MPI
116+
}
117+
// thirdly, distribute hcontainer_serial to parallel hcontainer
118+
// send <IJR>s from serial_rank to all ranks
119+
int my_rank, size;
120+
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
121+
MPI_Comm_size(MPI_COMM_WORLD, &size);
122+
std::vector<int> para_ijrs;
123+
if (my_rank == 0)
124+
{
125+
para_ijrs = hcontainer_serial.get_ijr_info();
126+
this->_hcontainer->insert_ijrs(&para_ijrs);
127+
this->_hcontainer->allocate();
128+
}
129+
if (my_rank != 0)
130+
{
131+
std::vector<int> tmp_ijrs;
132+
MPI_Status status;
133+
long tmp_size = 0;
134+
MPI_Recv(&tmp_size, 1, MPI_LONG, 0, 0, MPI_COMM_WORLD, &status);
135+
tmp_ijrs.resize(tmp_size);
136+
MPI_Recv(tmp_ijrs.data(),
137+
tmp_ijrs.size(),
138+
MPI_INT,
139+
0,
140+
1,
141+
MPI_COMM_WORLD,
142+
&status);
143+
this->_hcontainer->insert_ijrs(&tmp_ijrs);
144+
this->_hcontainer->allocate();
145+
}
146+
else
147+
{
148+
for (int i = 1; i < size; ++i)
149+
{
150+
long tmp_size = para_ijrs.size();
151+
MPI_Send(&tmp_size, 1, MPI_LONG, i, 0, MPI_COMM_WORLD);
152+
MPI_Send(para_ijrs.data(), para_ijrs.size(), MPI_INT, i, 1, MPI_COMM_WORLD);
153+
}
154+
}
155+
// gather values from serial_rank to Parallels
156+
transferSerial2Parallels(hcontainer_serial, this->_hcontainer, 0);
157+
#else
158+
std::vector<int> para_ijrs = hcontainer_serial.get_ijr_info();
159+
this->_hcontainer->insert_ijrs(&para_ijrs);
160+
this->_hcontainer->allocate();
161+
this->_hcontainer->add(hcontainer_serial);
162+
#endif
163+
164+
}
165+
166+
template class Read_HContainer<double>;
167+
template class Read_HContainer<std::complex<double>>;
168+
169+
} // namespace hamilt
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef READ_HCONTAINER_H
2+
#define READ_HCONTAINER_H
3+
4+
#include "source_lcao/module_hcontainer/hcontainer.h"
5+
#include "source_cell/unitcell.h"
6+
7+
namespace hamilt
8+
{
9+
10+
/**
11+
* @brief A class to read the HContainer
12+
*/
13+
template <typename T>
14+
class Read_HContainer
15+
{
16+
public:
17+
Read_HContainer(
18+
hamilt::HContainer<T>* hcontainer,
19+
const std::string& filename,
20+
const int nlocal,
21+
const UnitCell* ucell
22+
);
23+
// read the matrices of all R vectors to the read stream
24+
void read();
25+
26+
/**
27+
* read the matrix of a single R vector to the output stream
28+
* rx_in, ry_in, rz_in: the R vector from the input
29+
*/
30+
void read(int rx_in, int ry_in, int rz_in);
31+
32+
/**
33+
* read the matrix of a single R vector to the output stream
34+
* rx, ry, rz: the R vector from the HContainer
35+
*/
36+
void read_single_R(int rx, int ry, int rz);
37+
38+
private:
39+
hamilt::HContainer<T>* _hcontainer;
40+
std::string _filename;
41+
int _nlocal;
42+
const UnitCell* _ucell;
43+
};
44+
45+
} // namespace hamilt
46+
47+
#endif // OUTPUT_HCONTAINER_H

source/source_lcao/rho_tau_lcao.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
void LCAO_domain::dm2rho(std::vector<hamilt::HContainer<double>*> &dmr,
66
const int nspin,
7-
Charge* chr)
7+
Charge* chr,
8+
bool skip_normalize)
89
{
910
ModuleBase::TITLE("LCAO_domain", "dm2rho");
1011
ModuleBase::timer::tick("LCAO_domain", "dm2rho");
@@ -16,7 +17,7 @@ void LCAO_domain::dm2rho(std::vector<hamilt::HContainer<double>*> &dmr,
1617

1718
ModuleGint::cal_gint_rho(dmr, nspin, chr->rho);
1819

19-
chr->renormalize_rho();
20+
if(!skip_normalize)chr->renormalize_rho();
2021

2122
// should be moved somewhere else, mohan 20251024
2223
if (XC_Functional::get_ked_flag())

source/source_lcao/rho_tau_lcao.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ namespace LCAO_domain
99
{
1010
void dm2rho(std::vector<hamilt::HContainer<double>*> &dmr,
1111
const int nspin,
12-
Charge* chr);
12+
Charge* chr,
13+
bool skip_normalize = false);
1314

1415
void dm2tau(std::vector<hamilt::HContainer<double>*> &dmr,
1516
const int nspin,

0 commit comments

Comments
 (0)