Skip to content

Commit f81adcc

Browse files
committed
Optimize snap_psibeta_half_tddft for better performance
Key optimizations:
1. Cache the Gauss-Legendre grid: implemented thread-safe caching of the radial integration grid points and weights.
2. Precompute spherical harmonics: precomputed Ylm and A·r on the Lebedev angular grid to reduce inner-loop overhead.
3. Memory optimization: lifted vector allocations out of loops and reused buffers to minimize allocation costs.
4. Inline interpolation: inlined the polynomial interpolation and used a precomputed inverse step size to avoid divisions.
5. Common subexpression elimination: extracted invariant factors from the innermost loop to reduce arithmetic operations.
1 parent f1c5f1b commit f81adcc

File tree

1 file changed

+105
-23
lines changed

1 file changed

+105
-23
lines changed

source/source_lcao/module_rt/snap_psibeta_half_tddft.cpp

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,49 @@ void snap_psibeta_half_tddft(const LCAO_Orbitals& orb,
106106
const int mesh_r1 = orb.Phi[T1].PhiLN(L1, N1).getNr();
107107
const double* psi_1 = orb.Phi[T1].PhiLN(L1, N1).getPsi();
108108
const double dk_1 = orb.Phi[T1].PhiLN(L1, N1).getDk();
109+
const double inv_dk_1 = 1.0 / dk_1;
109110

110111
const int ridial_grid_num = 140;
111112
const int angular_grid_num = 110;
112113
std::vector<double> r_ridial(ridial_grid_num);
113114
std::vector<double> weights_ridial(ridial_grid_num);
114115

116+
// OPTIMIZATION START: Cache standard Gauss-Legendre grid
117+
static std::vector<double> gl_x(ridial_grid_num);
118+
static std::vector<double> gl_w(ridial_grid_num);
119+
static bool gl_init = false;
120+
121+
// Thread-safe initialization
122+
if (!gl_init)
123+
{
124+
#pragma omp critical(init_gauss_legendre)
125+
{
126+
if (!gl_init)
127+
{
128+
ModuleBase::Integral::Gauss_Legendre_grid_and_weight(ridial_grid_num, gl_x.data(), gl_w.data());
129+
gl_init = true;
130+
}
131+
}
132+
}
133+
// OPTIMIZATION END
134+
135+
// Precompute A dot r_angular for Lebedev grid
136+
std::vector<double> A_dot_lebedev(angular_grid_num);
137+
for (int ian = 0; ian < angular_grid_num; ++ian)
138+
{
139+
A_dot_lebedev[ian] = A.x * ModuleBase::Integral::Lebedev_Laikov_grid110_x[ian] +
140+
A.y * ModuleBase::Integral::Lebedev_Laikov_grid110_y[ian] +
141+
A.z * ModuleBase::Integral::Lebedev_Laikov_grid110_z[ian];
142+
}
143+
144+
// Buffers to reuse
145+
std::vector<std::complex<double>> result_angular;
146+
std::vector<std::complex<double>> result_angular_r_commu_x;
147+
std::vector<std::complex<double>> result_angular_r_commu_y;
148+
std::vector<std::complex<double>> result_angular_r_commu_z;
149+
std::vector<double> rly1;
150+
std::vector<std::vector<double>> rly0_all(angular_grid_num);
151+
115152
int index = 0;
116153
for (int nb = 0; nb < nproj; nb++)
117154
{
@@ -128,28 +165,52 @@ void snap_psibeta_half_tddft(const LCAO_Orbitals& orb,
128165
const double dk_0 = infoNL_.Beta[T0].Proj[nb].getDk();
129166

130167
const double Rcut0 = infoNL_.Beta[T0].Proj[nb].getRcut();
131-
ModuleBase::Integral::Gauss_Legendre_grid_and_weight(radial0[0],
132-
radial0[mesh_r0 - 1],
133-
ridial_grid_num,
134-
r_ridial.data(),
135-
weights_ridial.data());
168+
169+
// OPTIMIZATION: Use precomputed standard Gauss-Legendre grid
170+
double r_min = radial0[0];
171+
double r_max = radial0[mesh_r0 - 1];
172+
double xl = (r_max - r_min) * 0.5;
173+
double xmean = (r_max + r_min) * 0.5;
174+
175+
for(int i=0; i<ridial_grid_num; ++i)
176+
{
177+
r_ridial[i] = xmean + xl * gl_x[i];
178+
weights_ridial[i] = xl * gl_w[i];
179+
}
136180

137181
const double A_phase = A * R0;
138182
const std::complex<double> exp_iAR0 = std::exp(ModuleBase::IMAG_UNIT * A_phase);
139183

140-
std::vector<double> rly0(L0);
141-
std::vector<double> rly1(L1);
184+
// Precompute rly0 for all angular points
185+
for(int ian = 0; ian < angular_grid_num; ++ian) {
186+
ModuleBase::Ylm::rl_sph_harm(L0,
187+
ModuleBase::Integral::Lebedev_Laikov_grid110_x[ian],
188+
ModuleBase::Integral::Lebedev_Laikov_grid110_y[ian],
189+
ModuleBase::Integral::Lebedev_Laikov_grid110_z[ian],
190+
rly0_all[ian]);
191+
}
192+
193+
// Resize buffers
194+
if (result_angular.size() < 2 * L0 + 1)
195+
{
196+
result_angular.resize(2 * L0 + 1);
197+
if (calc_r)
198+
{
199+
result_angular_r_commu_x.resize(2 * L0 + 1);
200+
result_angular_r_commu_y.resize(2 * L0 + 1);
201+
result_angular_r_commu_z.resize(2 * L0 + 1);
202+
}
203+
}
204+
142205
for (int ir = 0; ir < ridial_grid_num; ir++)
143206
{
144-
std::vector<std::complex<double>> result_angular(2 * L0 + 1, 0.0);
145-
std::vector<std::complex<double>> result_angular_r_commu_x;
146-
std::vector<std::complex<double>> result_angular_r_commu_y;
147-
std::vector<std::complex<double>> result_angular_r_commu_z;
207+
// Reset result accumulators
208+
std::fill(result_angular.begin(), result_angular.begin() + (2 * L0 + 1), 0.0);
148209
if (calc_r)
149210
{
150-
result_angular_r_commu_x.resize(2 * L0 + 1, 0.0);
151-
result_angular_r_commu_y.resize(2 * L0 + 1, 0.0);
152-
result_angular_r_commu_z.resize(2 * L0 + 1, 0.0);
211+
std::fill(result_angular_r_commu_x.begin(), result_angular_r_commu_x.begin() + (2 * L0 + 1), 0.0);
212+
std::fill(result_angular_r_commu_y.begin(), result_angular_r_commu_y.begin() + (2 * L0 + 1), 0.0);
213+
std::fill(result_angular_r_commu_z.begin(), result_angular_r_commu_z.begin() + (2 * L0 + 1), 0.0);
153214
}
154215

155216
for (int ian = 0; ian < angular_grid_num; ian++)
@@ -158,11 +219,13 @@ void snap_psibeta_half_tddft(const LCAO_Orbitals& orb,
158219
const double y = ModuleBase::Integral::Lebedev_Laikov_grid110_y[ian];
159220
const double z = ModuleBase::Integral::Lebedev_Laikov_grid110_z[ian];
160221
const double weights_angular = ModuleBase::Integral::Lebedev_Laikov_grid110_w[ian];
161-
const ModuleBase::Vector3<double> r_angular_tmp(x, y, z);
162-
163-
const ModuleBase::Vector3<double> r_coor = r_ridial[ir] * r_angular_tmp;
222+
223+
const double r_val = r_ridial[ir];
224+
const ModuleBase::Vector3<double> r_coor(r_val * x, r_val * y, r_val * z);
225+
164226
const ModuleBase::Vector3<double> tmp_r_coor = r_coor + dRa;
165227
const double tmp_r_coor_norm = tmp_r_coor.norm();
228+
166229
if (tmp_r_coor_norm > Rcut1)
167230
{
168231
continue;
@@ -174,21 +237,40 @@ void snap_psibeta_half_tddft(const LCAO_Orbitals& orb,
174237
tmp_r_unit = tmp_r_coor / tmp_r_coor_norm;
175238
}
176239

177-
ModuleBase::Ylm::rl_sph_harm(L0, x, y, z, rly0);
240+
const std::vector<double>& rly0_vec = rly0_all[ian];
178241

179242
ModuleBase::Ylm::rl_sph_harm(L1, tmp_r_unit.x, tmp_r_unit.y, tmp_r_unit.z, rly1);
180243

181-
const double phase = A * r_coor;
244+
const double phase = r_val * A_dot_lebedev[ian];
182245
const std::complex<double> exp_iAr = std::exp(ModuleBase::IMAG_UNIT * phase);
183246

184247
const ModuleBase::Vector3<double> tmp_r_coor_r_commu = r_coor + R0;
185-
const double interp_v = ModuleBase::PolyInt::Polynomial_Interpolation(psi_1,
186-
mesh_r1, dk_1, tmp_r_coor_norm);
248+
249+
// OPTIMIZATION: Inline Polynomial Interpolation
250+
double position = tmp_r_coor_norm * inv_dk_1;
251+
int iq = static_cast<int>(position);
252+
double interp_v = 0.0;
253+
254+
if (iq <= mesh_r1 - 4)
255+
{
256+
const double x0 = position - static_cast<double>(iq);
257+
const double x1 = 1.0 - x0;
258+
const double x2 = 2.0 - x0;
259+
const double x3 = 3.0 - x0;
260+
interp_v = x1*x2*(psi_1[iq]*x3+psi_1[iq+3]*x0)/6.0
261+
+ x0*x3*(psi_1[iq+1]*x2-psi_1[iq+2]*x1)/2.0;
262+
}
263+
264+
const double weight_interp = interp_v * weights_angular;
265+
const int offset_L0 = L0 * L0;
266+
const int offset_L1 = L1 * L1 + m1;
267+
const double rly1_val = rly1[offset_L1];
268+
269+
const std::complex<double> common_factor = exp_iAr * rly1_val * weight_interp;
187270

188271
for (int m0 = 0; m0 < 2 * L0 + 1; m0++)
189272
{
190-
std::complex<double> temp = exp_iAr * rly0[L0 * L0 + m0] * rly1[L1 * L1 + m1]
191-
* interp_v * weights_angular;
273+
std::complex<double> temp = common_factor * rly0_vec[offset_L0 + m0];
192274
result_angular[m0] += temp;
193275

194276
if (calc_r)

0 commit comments

Comments
 (0)