Skip to content

Commit 04b2922

Browse files
authored
Fix: Fix the errors in building abacus with libtorch-gpu (#6554)
* Fix: Fix the errors in building abacus with libtorch-gpu * Fix: Mark submodule directory as safe for CI workflow * Fix: Update CI workflow to manually handle submodules and set safe directories * Fix: Update CI workflow to handle submodule ownership and initialization * Fix: Add conditional compilation for CUDA and ROCM in MLKEDF descriptor functions * Fix: Refactor Psi constructor to separate declaration and implementation for better readability and maintainability * Fix: Improve code readability by adding colons to constructor comments and formatting memory operations
1 parent 5411878 commit 04b2922

File tree

4 files changed

+80
-4
lines changed

4 files changed

+80
-4
lines changed

.github/workflows/test.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@ jobs:
1717
volumes:
1818
- /tmp/ccache:/github/home/.ccache
1919
steps:
20-
- name: Checkout
20+
- name: Checkout repository
2121
uses: actions/checkout@v5
2222
with:
23-
submodules: recursive
2423
fetch-depth: 0
24+
# We will handle submodules manually after fixing ownership
25+
submodules: 'false'
26+
27+
- name: Take ownership of the workspace and update submodules
28+
run: |
29+
sudo chown -R $(whoami) .
30+
git submodule update --init --recursive
2531
2632
- name: Install CI tools
2733
run: |

source/source_io/write_mlkedf_descriptors.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,53 @@ void Write_MLKEDF_Descriptors::generateTrainData_KS(
6767
delete ptempRho;
6868
}
6969

70+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
71+
const std::string& out_dir,
72+
psi::Psi<std::complex<float>> *psi,
73+
elecstate::ElecState *pelec,
74+
ModulePW::PW_Basis_K *pw_psi,
75+
ModulePW::PW_Basis *pw_rho,
76+
UnitCell& ucell,
77+
const double* veff
78+
)
79+
{
80+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_double(*psi);
81+
82+
this->generateTrainData_KS(out_dir, &psi_double, pelec, pw_psi, pw_rho, ucell, veff);
83+
}
84+
85+
#if ((defined __CUDA) || (defined __ROCM))
86+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
87+
const std::string& out_dir,
88+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
89+
elecstate::ElecState *pelec,
90+
ModulePW::PW_Basis_K *pw_psi,
91+
ModulePW::PW_Basis *pw_rho,
92+
UnitCell& ucell,
93+
const double* veff
94+
)
95+
{
96+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu(*psi);
97+
98+
this->generateTrainData_KS(out_dir, &psi_cpu, pelec, pw_psi, pw_rho, ucell, veff);
99+
}
100+
101+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
102+
const std::string& dir,
103+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
104+
elecstate::ElecState *pelec,
105+
ModulePW::PW_Basis_K *pw_psi,
106+
ModulePW::PW_Basis *pw_rho,
107+
UnitCell& ucell,
108+
const double *veff
109+
)
110+
{
111+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu_double(*psi);
112+
113+
this->generateTrainData_KS(dir, &psi_cpu_double, pelec, pw_psi, pw_rho, ucell, veff);
114+
}
115+
#endif
116+
70117
void Write_MLKEDF_Descriptors::generate_descriptor(
71118
const std::string& out_dir,
72119
const double * const *prho,

source/source_io/write_mlkedf_descriptors.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,29 @@ class Write_MLKEDF_Descriptors
4040
ModulePW::PW_Basis *pw_rho,
4141
UnitCell& ucell,
4242
const double *veff
43-
){} // a mock function
43+
);
44+
45+
#if ((defined __CUDA) || (defined __ROCM))
46+
void generateTrainData_KS(
47+
const std::string& dir,
48+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
49+
elecstate::ElecState *pelec,
50+
ModulePW::PW_Basis_K *pw_psi,
51+
ModulePW::PW_Basis *pw_rho,
52+
UnitCell& ucell,
53+
const double *veff
54+
);
55+
void generateTrainData_KS(
56+
const std::string& dir,
57+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
58+
elecstate::ElecState *pelec,
59+
ModulePW::PW_Basis_K *pw_psi,
60+
ModulePW::PW_Basis *pw_rho,
61+
UnitCell& ucell,
62+
const double *veff
63+
);
64+
#endif
65+
4466
void generate_descriptor(
4567
const std::string& out_dir,
4668
const double * const *prho,

source/source_psi/psi.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ Psi<T, Device>::Psi(const Psi& psi_in)
171171
this->psi_current = this->psi + psi_in.get_psi_bias();
172172
}
173173

174-
175174
// Constructor 2-2:
176175
template <typename T, typename Device>
177176
template <typename T_in, typename Device_in>
@@ -545,6 +544,8 @@ template Psi<double, base_device::DEVICE_CPU>::Psi(const Psi<double, base_device
545544
template Psi<double, base_device::DEVICE_GPU>::Psi(const Psi<double, base_device::DEVICE_CPU>&);
546545
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
547546
const Psi<std::complex<double>, base_device::DEVICE_GPU>&);
547+
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
548+
const Psi<std::complex<float>, base_device::DEVICE_GPU>&);
548549
template Psi<std::complex<double>, base_device::DEVICE_GPU>::Psi(
549550
const Psi<std::complex<double>, base_device::DEVICE_CPU>&);
550551
template Psi<std::complex<float>, base_device::DEVICE_GPU>::Psi(

0 commit comments

Comments
 (0)