Skip to content

Commit 3f4dd49

Browse files
authored
Refactor: Support RT-TDDFT EDM calculation in Tensor (#6726)
1 parent 19a9e54 commit 3f4dd49

File tree

3 files changed

+309
-2
lines changed

3 files changed

+309
-2
lines changed

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::iter_finish(UnitCell& ucell,
318318
if (conv_esolver && estep == estep_max - 1 && istep >= (PARAM.inp.init_wfc == "file" ? 0 : 1)
319319
&& PARAM.inp.td_edm == 0)
320320
{
321-
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt);
321+
elecstate::cal_edm_tddft_tensor(this->pv, this->dmat, this->kv, this->p_hamilt);
322322
}
323323
}
324324

source/source_estate/module_dm/cal_edm_tddft.cpp

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,56 @@
11
#include "cal_edm_tddft.h"

#include <iomanip>
#include <ostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "source_base/module_container/ATen/core/tensor.h" // For ct::Tensor
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_container/ATen/kernels/lapack.h"
#include "source_base/module_container/ATen/kernels/memory.h" // memory operations (Tensor)
#include "source_base/module_device/memory_op.h"              // memory operations
#include "source_base/module_external/lapack_connector.h"
#include "source_base/module_external/scalapack_connector.h"
#include "source_io/module_parameter/parameter.h" // use PARAM.globalv
11+
612
namespace elecstate
713
{
14+
/**
 * @brief Debug helper: write a locally stored, column-major complex matrix block to a stream.
 *
 * Elements are printed row by row as "(re,im) " in fixed notation with 10 decimals and an
 * explicit sign. A header line is emitted when a name and/or a non-negative rank is given.
 * The stream's format flags and precision are restored to their incoming values before the
 * function returns, so the caller's formatting is not disturbed.
 *
 * @param os           Output stream to write to.
 * @param matrix_data  Column-major data; element (i, j) is read at index i + j * local_rows.
 * @param local_rows   Number of local rows (also used as the leading dimension).
 * @param local_cols   Number of local columns.
 * @param matrix_name  Optional label printed in the header.
 * @param rank         Process rank; printed 1-based in the header when >= 0.
 */
void print_local_matrix(std::ostream& os,
                        const std::complex<double>* matrix_data,
                        int local_rows, // pv.nrow
                        int local_cols, // pv.ncol
                        const std::string& matrix_name = "",
                        int rank = -1)
{
    if (!matrix_name.empty() || rank >= 0)
    {
        os << "=== ";
        if (!matrix_name.empty())
        {
            os << "Matrix: " << matrix_name;
            if (rank >= 0)
            {
                os << " ";
            }
        }
        if (rank >= 0)
        {
            os << "(Process: " << rank + 1 << ")";
        }
        os << " (Local dims: " << local_rows << " x " << local_cols << ") ===" << '\n';
    }

    // Remember the caller's formatting so it can be put back exactly as it was.
    const std::ios_base::fmtflags saved_flags = os.flags();
    const std::streamsize saved_precision = os.precision();

    os << std::fixed << std::setprecision(10) << std::showpos;

    for (int i = 0; i < local_rows; ++i) // Iterate over rows (i)
    {
        for (int j = 0; j < local_cols; ++j) // Iterate over columns (j)
        {
            // Column-major storage: element (i, j) lives at i + j * LDA with LDA = local_rows.
            // The index is computed in a wide type so large blocks cannot overflow int.
            const long long idx = i + static_cast<long long>(j) * local_rows;
            os << "(" << std::real(matrix_data[idx]) << "," << std::imag(matrix_data[idx]) << ") ";
        }
        os << '\n'; // New line after each row
    }

    os.flags(saved_flags);
    os.precision(saved_precision);
    os << std::endl; // trailing blank line; single flush for the whole dump
}
53+
854
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
955
void cal_edm_tddft(Parallel_Orbitals& pv,
1056
LCAO_domain::Setup_DM<std::complex<double>>& dmat,
@@ -252,4 +298,260 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
252298
ModuleBase::timer::tick("elecstate", "cal_edm_tddft");
253299
return;
254300
} // cal_edm_tddft
301+
302+
/**
 * @brief Throw std::runtime_error when a LAPACK/ScaLAPACK call reports a non-zero `info`.
 *
 * @param routine  Name of the routine that produced `info` (used in the message only).
 * @param info     Status value written by the routine; 0 means success.
 * @param ik       Index of the k-point being processed (used in the message only).
 */
static void check_edm_info(const char* routine, const int info, const int ik)
{
    if (info != 0)
    {
        throw std::runtime_error(std::string("elecstate::cal_edm_tddft_tensor: ") + routine
                                 + " failed with info = " + std::to_string(info)
                                 + " at ik = " + std::to_string(ik));
    }
}

/**
 * @brief Compute the RT-TDDFT energy density matrix (EDM) for every k-point.
 *
 * For each k-point the Hamiltonian H and overlap S are fetched from `p_hamilt`, S is inverted,
 * and the result
 *     EDMK[ik] = 0.5 * (S^{-1} * H * DM + DM * H * S^{-1})
 * is stored in `dmat.dm->EDMK[ik]`, where DM is the buffer returned by
 * `dmat.dm->get_DMK_pointer(ik)`. With `__MPI` defined the work buffers are CPU `ct::Tensor`
 * objects and the algebra is done through ScalapackConnector; otherwise the serial
 * ModuleBase::ComplexMatrix / LapackConnector path is used.
 *
 * @param pv        Parallel layout; nloc, nrow, ncol, nb and desc are read.
 * @param dmat      Density-matrix holder; `dm->EDMK` is resized to `kv.get_nks()` and filled.
 * @param kv        K-point set; only `get_nks()` is used.
 * @param p_hamilt  Hamiltonian; `updateHk(ik)` and `matrix(h, s)` are called for each k-point.
 *
 * @throws std::runtime_error if the LU factorization or the inversion of S reports a
 *         non-zero `info` (e.g. a singular overlap matrix), instead of silently continuing
 *         with an invalid inverse.
 */
void cal_edm_tddft_tensor(Parallel_Orbitals& pv,
                          LCAO_domain::Setup_DM<std::complex<double>>& dmat,
                          K_Vectors& kv,
                          hamilt::Hamilt<std::complex<double>>* p_hamilt)
{
    ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor");

    const int nlocal = PARAM.globalv.nlocal;
    assert(nlocal >= 0);
    dmat.dm->EDMK.resize(kv.get_nks());

    for (int ik = 0; ik < kv.get_nks(); ++ik)
    {
        p_hamilt->updateHk(ik);
        std::complex<double>* tmp_dmk = dmat.dm->get_DMK_pointer(ik);
        ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm->EDMK[ik];

#ifdef __MPI
        const int nloc = pv.nloc;
        const int ncol = pv.ncol;
        const int nrow = pv.nrow;

        // Initialize EDMK matrix
        tmp_edmk.create(ncol, nrow);

        const auto cplx_type = ct::DataType::DT_COMPLEX_DOUBLE;
        const auto int_type = ct::DataType::DT_INT;
        const auto cpu = ct::DeviceType::CpuDevice;

        // Allocate Tensor objects on CPU, one local block (nloc elements) each
        ct::Tensor Htmp_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        Htmp_tensor.zero();

        ct::Tensor Sinv_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        Sinv_tensor.zero();

        ct::Tensor tmp1_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        tmp1_tensor.zero();

        ct::Tensor tmp2_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        tmp2_tensor.zero();

        ct::Tensor tmp3_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        tmp3_tensor.zero();

        ct::Tensor tmp4_tensor(cplx_type, cpu, ct::TensorShape({nloc}));
        tmp4_tensor.zero();

        // Get raw pointers from tensors for ScaLAPACK calls
        std::complex<double>* Htmp_ptr = Htmp_tensor.data<std::complex<double>>();
        std::complex<double>* Sinv_ptr = Sinv_tensor.data<std::complex<double>>();
        std::complex<double>* tmp1_ptr = tmp1_tensor.data<std::complex<double>>();
        std::complex<double>* tmp2_ptr = tmp2_tensor.data<std::complex<double>>();
        std::complex<double>* tmp3_ptr = tmp3_tensor.data<std::complex<double>>();
        std::complex<double>* tmp4_ptr = tmp4_tensor.data<std::complex<double>>();

        const int inc = 1;
        hamilt::MatrixBlock<std::complex<double>> h_mat;
        hamilt::MatrixBlock<std::complex<double>> s_mat;
        p_hamilt->matrix(h_mat, s_mat);

        // Copy Hamiltonian and Overlap matrices into Tensor buffers using BlasConnector
        BlasConnector::copy(nloc, h_mat.p, inc, Htmp_ptr, inc);
        BlasConnector::copy(nloc, s_mat.p, inc, Sinv_ptr, inc);

        // --- ScaLAPACK Inversion of S (LU factorization, then inverse from the factors) ---
        ct::Tensor ipiv_tensor(int_type, cpu, ct::TensorShape({pv.nrow + pv.nb})); // Size for ScaLAPACK pivot array
        ipiv_tensor.zero();
        int* ipiv_ptr = ipiv_tensor.data<int>();

        int info = 0;
        const int one_int = 1;
        ScalapackConnector::getrf(nlocal, nlocal, Sinv_ptr, one_int, one_int, pv.desc, ipiv_ptr, &info);
        check_edm_info("ScalapackConnector::getrf", info, ik);

        // Workspace query: lwork = liwork = -1 asks getri to report the required sizes
        // in the first element of the work / iwork arrays.
        int lwork = -1;
        int liwork = -1;
        ct::Tensor work_tensor(cplx_type, cpu, ct::TensorShape({1}));
        ct::Tensor iwork_tensor(int_type, cpu, ct::TensorShape({1}));

        ScalapackConnector::getri(nlocal,
                                  Sinv_ptr,
                                  one_int,
                                  one_int,
                                  pv.desc,
                                  ipiv_ptr,
                                  work_tensor.data<std::complex<double>>(),
                                  &lwork,
                                  iwork_tensor.data<int>(),
                                  &liwork,
                                  &info);
        check_edm_info("ScalapackConnector::getri (workspace query)", info, ik);

        // Resize work arrays based on query results
        lwork = static_cast<int>(work_tensor.data<std::complex<double>>()[0].real());
        work_tensor.resize(ct::TensorShape({lwork}));
        liwork = iwork_tensor.data<int>()[0];
        iwork_tensor.resize(ct::TensorShape({liwork}));

        ScalapackConnector::getri(nlocal,
                                  Sinv_ptr,
                                  one_int,
                                  one_int,
                                  pv.desc,
                                  ipiv_ptr,
                                  work_tensor.data<std::complex<double>>(),
                                  &lwork,
                                  iwork_tensor.data<int>(),
                                  &liwork,
                                  &info);
        check_edm_info("ScalapackConnector::getri", info, ik);

        // --- EDM Calculation using ScaLAPACK ---
        const char N_char = 'N';
        const std::complex<double> one_complex = {1.0, 0.0};
        const std::complex<double> zero_complex = {0.0, 0.0};
        const std::complex<double> half_complex = {0.5, 0.0};

        // c = a * b for full nlocal x nlocal distributed matrices sharing the layout pv.desc.
        const auto pgemm_nn = [&](std::complex<double>* a, std::complex<double>* b, std::complex<double>* c) {
            ScalapackConnector::gemm(N_char,
                                     N_char,
                                     nlocal,
                                     nlocal,
                                     nlocal,
                                     one_complex,
                                     a,
                                     one_int,
                                     one_int,
                                     pv.desc,
                                     b,
                                     one_int,
                                     one_int,
                                     pv.desc,
                                     zero_complex,
                                     c,
                                     one_int,
                                     one_int,
                                     pv.desc);
        };

        // tmp1 = Sinv * Htmp
        pgemm_nn(Sinv_ptr, Htmp_ptr, tmp1_ptr);

        // tmp2 = tmp1 * tmp_dmk
        pgemm_nn(tmp1_ptr, tmp_dmk, tmp2_ptr);

        // tmp3 = Htmp * Sinv
        pgemm_nn(Htmp_ptr, Sinv_ptr, tmp3_ptr);

        // tmp4 = tmp_dmk * tmp3
        pgemm_nn(tmp_dmk, tmp3_ptr, tmp4_ptr);

        // tmp4 = 0.5 * tmp2 + 0.5 * tmp4 (final EDM contribution)
        ScalapackConnector::geadd(N_char,
                                  nlocal,
                                  nlocal,
                                  half_complex,
                                  tmp2_ptr,
                                  one_int,
                                  one_int,
                                  pv.desc,
                                  half_complex,
                                  tmp4_ptr,
                                  one_int,
                                  one_int,
                                  pv.desc);

        // Copy final result from Tensor buffer back to EDMK matrix
        BlasConnector::copy(nloc, tmp4_ptr, inc, tmp_edmk.c, inc);

#else
        // Serial version, using ModuleBase::ComplexMatrix directly
        tmp_edmk.create(pv.ncol, pv.nrow);
        ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
        ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);
        hamilt::MatrixBlock<std::complex<double>> h_mat;
        hamilt::MatrixBlock<std::complex<double>> s_mat;
        p_hamilt->matrix(h_mat, s_mat);
        for (int i = 0; i < nlocal; i++)
        {
            for (int j = 0; j < nlocal; j++)
            {
                Htmp(i, j) = h_mat.p[i * nlocal + j];
                Sinv(i, j) = s_mat.p[i * nlocal + j];
            }
        }

        int INFO = 0;
        // Same workspace size as before (3 * nlocal - 1), kept at least 1 so the buffers are
        // never empty. std::vector replaces the variable-length array and the raw new[]/delete[]
        // so the storage is released even when an error is thrown below.
        int lwork = (nlocal > 0) ? 3 * nlocal - 1 : 1;
        std::vector<std::complex<double>> work(lwork); // value-initialized to zero
        std::vector<int> ipiv(nlocal > 0 ? nlocal : 1);

        LapackConnector::zgetrf(nlocal, nlocal, Sinv, nlocal, ipiv.data(), &INFO);
        check_edm_info("LapackConnector::zgetrf", INFO, ik);
        LapackConnector::zgetri(nlocal, Sinv, nlocal, ipiv.data(), work.data(), lwork, &INFO);
        check_edm_info("LapackConnector::zgetri", INFO, ik);

        ModuleBase::ComplexMatrix tmp_dmk_base(nlocal, nlocal);
        for (int i = 0; i < nlocal; i++)
        {
            for (int j = 0; j < nlocal; j++)
            {
                tmp_dmk_base(i, j) = tmp_dmk[i * nlocal + j];
            }
        }
        tmp_edmk = 0.5 * (Sinv * Htmp * tmp_dmk_base + tmp_dmk_base * Htmp * Sinv);
#endif
    } // end ik
    ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor");
    return;
} // cal_edm_tddft_tensor
556+
255557
} // namespace elecstate

source/source_estate/module_dm/cal_edm_tddft.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
namespace elecstate
1010
{
1111
void cal_edm_tddft(Parallel_Orbitals& pv,
12-
LCAO_domain::Setup_DM<std::complex<double>> &dmat,
12+
LCAO_domain::Setup_DM<std::complex<double>>& dmat,
1313
K_Vectors& kv,
1414
hamilt::Hamilt<std::complex<double>>* p_hamilt);
15+
16+
/// @brief Compute the RT-TDDFT energy density matrix for every k-point and store it in
///        dmat.dm->EDMK (one entry per k-point, resized to kv.get_nks()).
///
/// For each k-point the result is 0.5 * (S^{-1} * H * DM + DM * H * S^{-1}), with H and S
/// obtained from p_hamilt after updateHk(ik) and DM taken from dmat.dm->get_DMK_pointer(ik).
/// In MPI builds the work buffers are CPU ct::Tensor objects and the algebra goes through
/// ScalapackConnector; serial builds use ModuleBase::ComplexMatrix and LapackConnector.
///
/// @param pv        Parallel layout of the local matrix blocks (nloc, nrow, ncol, nb, desc).
/// @param dmat      Holder of the density matrix; its EDMK member is overwritten.
/// @param kv        K-point set; only the number of k-points is used.
/// @param p_hamilt  Hamiltonian object providing H(k) and S(k).
void cal_edm_tddft_tensor(Parallel_Orbitals& pv,
17+
LCAO_domain::Setup_DM<std::complex<double>>& dmat,
18+
K_Vectors& kv,
19+
hamilt::Hamilt<std::complex<double>>* p_hamilt);
1520
} // namespace elecstate
1621
#endif // CAL_EDM_TDDFT_H

0 commit comments

Comments
 (0)