Skip to content

Commit 6b96629

Browse files
authored
Feature: Support RT-TDDFT EDM calculation on GPU (#6762)
1 parent 6a94875 commit 6b96629

File tree

5 files changed

+360
-69
lines changed

5 files changed

+360
-69
lines changed

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,13 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
264264
{
265265
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
266266
hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver);
267-
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm,
268-
this->chr, PARAM.inp.nspin, skip_charge);
267+
hsolver_lcao_obj.solve(this->p_hamilt,
268+
this->psi[0],
269+
this->pelec,
270+
*this->dmat.dm,
271+
this->chr,
272+
PARAM.inp.nspin,
273+
skip_charge);
269274
}
270275
}
271276

@@ -318,7 +323,14 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::iter_finish(UnitCell& ucell,
318323
if (conv_esolver && estep == estep_max - 1 && istep >= (PARAM.inp.init_wfc == "file" ? 0 : 1)
319324
&& PARAM.inp.td_edm == 0)
320325
{
321-
elecstate::cal_edm_tddft_tensor(this->pv, this->dmat, this->kv, this->p_hamilt);
326+
if (use_tensor && use_lapack)
327+
{
328+
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, this->p_hamilt);
329+
}
330+
else
331+
{
332+
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt);
333+
}
322334
}
323335
}
324336

@@ -434,7 +446,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::store_h_s_psi(UnitCell& ucell,
434446
1);
435447
} // end use_tensor
436448
} // end ik
437-
}// conv_esolver
449+
} // conv_esolver
438450
}
439451

440452
template <typename TR, typename Device>
@@ -483,7 +495,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::weight_dm_rho(const UnitCell& ucell)
483495
elecstate::calEBand(this->pelec->ekb, this->pelec->wg, this->pelec->f_en);
484496

485497
elecstate::cal_dm_psi(this->dmat.dm->get_paraV_pointer(), this->pelec->wg, this->psi[0], *this->dmat.dm);
486-
if(PARAM.inp.td_stype == 2)
498+
if (PARAM.inp.td_stype == 2)
487499
{
488500
this->dmat.dm->cal_DMR_td(ucell, TD_info::cart_At);
489501
}

0 commit comments

Comments
 (0)