|
#include "cal_edm_tddft.h"

#include <cassert>
#include <complex>
#include <iomanip>
#include <ostream>
#include <string>
#include <vector>

#include "source_base/module_container/ATen/core/tensor.h" // For ct::Tensor
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_container/ATen/kernels/lapack.h"
#include "source_base/module_container/ATen/kernels/memory.h" // memory operations (Tensor)
#include "source_base/module_device/memory_op.h"              // memory operations
#include "source_base/module_external/lapack_connector.h"
#include "source_base/module_external/scalapack_connector.h"
#include "source_io/module_parameter/parameter.h" // use PARAM.globalv
| 11 | + |
6 | 12 | namespace elecstate |
7 | 13 | { |
/**
 * @brief Write a locally stored complex matrix block to a stream in readable form.
 *
 * When a non-empty name and/or a non-negative rank is given, a header line of the form
 * "=== Matrix: <name> (Process: <rank+1>) (Local dims: R x C) ===" is written first.
 * Every matrix row then goes on its own line as a sequence of "(re,im) " pairs, printed
 * in fixed notation with 10 decimals and an explicit sign. A blank line terminates the output.
 *
 * The stream's formatting flags and precision are saved before the numeric output and
 * restored afterwards, so the caller's formatting state is left untouched.
 *
 * @param os           Destination stream.
 * @param matrix_data  Column-major data; element (i, j) is read from index i + j * local_rows,
 *                     so at least local_rows * local_cols elements must be readable.
 * @param local_rows   Number of rows of the local block (also used as leading dimension).
 * @param local_cols   Number of columns of the local block.
 * @param matrix_name  Optional label for the header line.
 * @param rank         Optional zero-based process index; printed one-based. Negative means "omit".
 */
void print_local_matrix(std::ostream& os,
                        const std::complex<double>* matrix_data,
                        int local_rows, // pv.nrow
                        int local_cols, // pv.ncol
                        const std::string& matrix_name = "",
                        int rank = -1)
{
    const bool has_name = !matrix_name.empty();
    const bool has_rank = rank >= 0;

    if (has_name || has_rank)
    {
        os << "=== ";
        if (has_name)
        {
            os << "Matrix: " << matrix_name;
            if (has_rank)
            {
                os << " ";
            }
        }
        if (has_rank)
        {
            os << "(Process: " << rank + 1 << ")";
        }
        os << " (Local dims: " << local_rows << " x " << local_cols << ") ===" << '\n';
    }

    // Remember the caller's formatting so it can be put back exactly as it was.
    const std::ios_base::fmtflags saved_flags = os.flags();
    const std::streamsize saved_precision = os.precision();

    os << std::fixed << std::setprecision(10) << std::showpos;

    for (int i = 0; i < local_rows; ++i) // Iterate over rows (i)
    {
        for (int j = 0; j < local_cols; ++j) // Iterate over columns (j)
        {
            // Column-major storage: element (i, j) lives at i + j * LDA with LDA = local_rows.
            const std::complex<double>& value = matrix_data[i + j * local_rows];
            os << "(" << std::real(value) << "," << std::imag(value) << ") ";
        }
        os << '\n'; // New line after each row (no flush per row)
    }

    os.flags(saved_flags);
    os.precision(saved_precision);

    // Trailing blank line; flush once for the whole matrix.
    os << std::endl;
}
| 53 | + |
8 | 54 | // use the original formula (Hamiltonian matrix) to calculate energy density matrix |
9 | 55 | void cal_edm_tddft(Parallel_Orbitals& pv, |
10 | 56 | LCAO_domain::Setup_DM<std::complex<double>>& dmat, |
@@ -252,4 +298,260 @@ void cal_edm_tddft(Parallel_Orbitals& pv, |
252 | 298 | ModuleBase::timer::tick("elecstate", "cal_edm_tddft"); |
253 | 299 | return; |
254 | 300 | } // cal_edm_tddft |
| 301 | + |
| 302 | +void cal_edm_tddft_tensor(Parallel_Orbitals& pv, |
| 303 | + LCAO_domain::Setup_DM<std::complex<double>>& dmat, |
| 304 | + K_Vectors& kv, |
| 305 | + hamilt::Hamilt<std::complex<double>>* p_hamilt) |
| 306 | +{ |
| 307 | + ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor"); |
| 308 | + |
| 309 | + const int nlocal = PARAM.globalv.nlocal; |
| 310 | + assert(nlocal >= 0); |
| 311 | + dmat.dm->EDMK.resize(kv.get_nks()); |
| 312 | + |
| 313 | + for (int ik = 0; ik < kv.get_nks(); ++ik) |
| 314 | + { |
| 315 | + p_hamilt->updateHk(ik); |
| 316 | + std::complex<double>* tmp_dmk = dmat.dm->get_DMK_pointer(ik); |
| 317 | + ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm->EDMK[ik]; |
| 318 | + |
| 319 | +#ifdef __MPI |
| 320 | + const int nloc = pv.nloc; |
| 321 | + const int ncol = pv.ncol; |
| 322 | + const int nrow = pv.nrow; |
| 323 | + |
| 324 | + // Initialize EDMK matrix |
| 325 | + tmp_edmk.create(ncol, nrow); |
| 326 | + |
| 327 | + // Allocate Tensor objects on CPU |
| 328 | + ct::Tensor Htmp_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 329 | + Htmp_tensor.zero(); |
| 330 | + |
| 331 | + ct::Tensor Sinv_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 332 | + Sinv_tensor.zero(); |
| 333 | + |
| 334 | + ct::Tensor tmp1_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 335 | + tmp1_tensor.zero(); |
| 336 | + |
| 337 | + ct::Tensor tmp2_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 338 | + tmp2_tensor.zero(); |
| 339 | + |
| 340 | + ct::Tensor tmp3_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 341 | + tmp3_tensor.zero(); |
| 342 | + |
| 343 | + ct::Tensor tmp4_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nloc})); |
| 344 | + tmp4_tensor.zero(); |
| 345 | + |
| 346 | + // Get raw pointers from tensors for ScaLAPACK calls |
| 347 | + std::complex<double>* Htmp_ptr = Htmp_tensor.data<std::complex<double>>(); |
| 348 | + std::complex<double>* Sinv_ptr = Sinv_tensor.data<std::complex<double>>(); |
| 349 | + std::complex<double>* tmp1_ptr = tmp1_tensor.data<std::complex<double>>(); |
| 350 | + std::complex<double>* tmp2_ptr = tmp2_tensor.data<std::complex<double>>(); |
| 351 | + std::complex<double>* tmp3_ptr = tmp3_tensor.data<std::complex<double>>(); |
| 352 | + std::complex<double>* tmp4_ptr = tmp4_tensor.data<std::complex<double>>(); |
| 353 | + |
| 354 | + const int inc = 1; |
| 355 | + hamilt::MatrixBlock<std::complex<double>> h_mat; |
| 356 | + hamilt::MatrixBlock<std::complex<double>> s_mat; |
| 357 | + p_hamilt->matrix(h_mat, s_mat); |
| 358 | + |
| 359 | + // Copy Hamiltonian and Overlap matrices into Tensor buffers using BlasConnector |
| 360 | + BlasConnector::copy(nloc, h_mat.p, inc, Htmp_ptr, inc); |
| 361 | + BlasConnector::copy(nloc, s_mat.p, inc, Sinv_ptr, inc); |
| 362 | + |
| 363 | + // --- ScaLAPACK Inversion of S --- |
| 364 | + ct::Tensor ipiv_tensor(ct::DataType::DT_INT, |
| 365 | + ct::DeviceType::CpuDevice, |
| 366 | + ct::TensorShape({pv.nrow + pv.nb})); // Size for ScaLAPACK pivot array |
| 367 | + ipiv_tensor.zero(); |
| 368 | + int* ipiv_ptr = ipiv_tensor.data<int>(); |
| 369 | + |
| 370 | + int info = 0; |
| 371 | + const int one_int = 1; |
| 372 | + ScalapackConnector::getrf(nlocal, nlocal, Sinv_ptr, one_int, one_int, pv.desc, ipiv_ptr, &info); |
| 373 | + |
| 374 | + int lwork = -1; |
| 375 | + int liwork = -1; |
| 376 | + ct::Tensor work_query_tensor(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({1})); |
| 377 | + ct::Tensor iwork_query_tensor(ct::DataType::DT_INT, ct::DeviceType::CpuDevice, ct::TensorShape({1})); |
| 378 | + |
| 379 | + ScalapackConnector::getri(nlocal, |
| 380 | + Sinv_ptr, |
| 381 | + one_int, |
| 382 | + one_int, |
| 383 | + pv.desc, |
| 384 | + ipiv_ptr, |
| 385 | + work_query_tensor.data<std::complex<double>>(), |
| 386 | + &lwork, |
| 387 | + iwork_query_tensor.data<int>(), |
| 388 | + &liwork, |
| 389 | + &info); |
| 390 | + |
| 391 | + // Resize work arrays based on query results |
| 392 | + lwork = work_query_tensor.data<std::complex<double>>()[0].real(); |
| 393 | + work_query_tensor.resize(ct::TensorShape({lwork})); |
| 394 | + liwork = iwork_query_tensor.data<int>()[0]; |
| 395 | + iwork_query_tensor.resize(ct::TensorShape({liwork})); |
| 396 | + |
| 397 | + ScalapackConnector::getri(nlocal, |
| 398 | + Sinv_ptr, |
| 399 | + one_int, |
| 400 | + one_int, |
| 401 | + pv.desc, |
| 402 | + ipiv_ptr, |
| 403 | + work_query_tensor.data<std::complex<double>>(), |
| 404 | + &lwork, |
| 405 | + iwork_query_tensor.data<int>(), |
| 406 | + &liwork, |
| 407 | + &info); |
| 408 | + |
| 409 | + // --- EDM Calculation using ScaLAPACK --- |
| 410 | + const char N_char = 'N'; |
| 411 | + const char T_char = 'T'; |
| 412 | + const std::complex<double> one_complex = {1.0, 0.0}; |
| 413 | + const std::complex<double> zero_complex = {0.0, 0.0}; |
| 414 | + const std::complex<double> half_complex = {0.5, 0.0}; |
| 415 | + |
| 416 | + // tmp1 = Sinv * Htmp (result stored in tmp1) |
| 417 | + ScalapackConnector::gemm(N_char, |
| 418 | + N_char, |
| 419 | + nlocal, |
| 420 | + nlocal, |
| 421 | + nlocal, |
| 422 | + one_complex, |
| 423 | + Sinv_ptr, |
| 424 | + one_int, |
| 425 | + one_int, |
| 426 | + pv.desc, |
| 427 | + Htmp_ptr, |
| 428 | + one_int, |
| 429 | + one_int, |
| 430 | + pv.desc, |
| 431 | + zero_complex, |
| 432 | + tmp1_ptr, |
| 433 | + one_int, |
| 434 | + one_int, |
| 435 | + pv.desc); |
| 436 | + |
| 437 | + // tmp2 = tmp1 * tmp_dmk (result stored in tmp2) |
| 438 | + ScalapackConnector::gemm(N_char, |
| 439 | + N_char, |
| 440 | + nlocal, |
| 441 | + nlocal, |
| 442 | + nlocal, |
| 443 | + one_complex, |
| 444 | + tmp1_ptr, |
| 445 | + one_int, |
| 446 | + one_int, |
| 447 | + pv.desc, |
| 448 | + tmp_dmk, |
| 449 | + one_int, |
| 450 | + one_int, |
| 451 | + pv.desc, |
| 452 | + zero_complex, |
| 453 | + tmp2_ptr, |
| 454 | + one_int, |
| 455 | + one_int, |
| 456 | + pv.desc); |
| 457 | + |
| 458 | + // tmp3 = Htmp * Sinv (result stored in tmp3) |
| 459 | + ScalapackConnector::gemm(N_char, |
| 460 | + N_char, |
| 461 | + nlocal, |
| 462 | + nlocal, |
| 463 | + nlocal, |
| 464 | + one_complex, |
| 465 | + Htmp_ptr, |
| 466 | + one_int, |
| 467 | + one_int, |
| 468 | + pv.desc, |
| 469 | + Sinv_ptr, |
| 470 | + one_int, |
| 471 | + one_int, |
| 472 | + pv.desc, |
| 473 | + zero_complex, |
| 474 | + tmp3_ptr, |
| 475 | + one_int, |
| 476 | + one_int, |
| 477 | + pv.desc); |
| 478 | + |
| 479 | + // tmp4 = tmp_dmk * tmp3 (result stored in tmp4) |
| 480 | + ScalapackConnector::gemm(N_char, |
| 481 | + N_char, |
| 482 | + nlocal, |
| 483 | + nlocal, |
| 484 | + nlocal, |
| 485 | + one_complex, |
| 486 | + tmp_dmk, |
| 487 | + one_int, |
| 488 | + one_int, |
| 489 | + pv.desc, |
| 490 | + tmp3_ptr, |
| 491 | + one_int, |
| 492 | + one_int, |
| 493 | + pv.desc, |
| 494 | + zero_complex, |
| 495 | + tmp4_ptr, |
| 496 | + one_int, |
| 497 | + one_int, |
| 498 | + pv.desc); |
| 499 | + |
| 500 | + // tmp4 = 0.5 * tmp2 + 0.5 * tmp4 (final EDM contribution) |
| 501 | + ScalapackConnector::geadd(N_char, |
| 502 | + nlocal, |
| 503 | + nlocal, |
| 504 | + half_complex, |
| 505 | + tmp2_ptr, |
| 506 | + one_int, |
| 507 | + one_int, |
| 508 | + pv.desc, |
| 509 | + half_complex, |
| 510 | + tmp4_ptr, |
| 511 | + one_int, |
| 512 | + one_int, |
| 513 | + pv.desc); |
| 514 | + |
| 515 | + // Copy final result from Tensor buffer back to EDMK matrix |
| 516 | + BlasConnector::copy(nloc, tmp4_ptr, inc, tmp_edmk.c, inc); |
| 517 | + |
| 518 | +#else |
| 519 | + // Serial version remains unchanged, using ModuleBase::ComplexMatrix directly |
| 520 | + tmp_edmk.create(pv.ncol, pv.nrow); |
| 521 | + ModuleBase::ComplexMatrix Sinv(nlocal, nlocal); |
| 522 | + ModuleBase::ComplexMatrix Htmp(nlocal, nlocal); |
| 523 | + hamilt::MatrixBlock<std::complex<double>> h_mat; |
| 524 | + hamilt::MatrixBlock<std::complex<double>> s_mat; |
| 525 | + p_hamilt->matrix(h_mat, s_mat); |
| 526 | + for (int i = 0; i < nlocal; i++) |
| 527 | + { |
| 528 | + for (int j = 0; j < nlocal; j++) |
| 529 | + { |
| 530 | + Htmp(i, j) = h_mat.p[i * nlocal + j]; |
| 531 | + Sinv(i, j) = s_mat.p[i * nlocal + j]; |
| 532 | + } |
| 533 | + } |
| 534 | + int INFO = 0; |
| 535 | + int lwork = 3 * nlocal - 1; // tmp |
| 536 | + std::complex<double>* work = new std::complex<double>[lwork]; |
| 537 | + ModuleBase::GlobalFunc::ZEROS(work, lwork); |
| 538 | + int IPIV[nlocal]; |
| 539 | + LapackConnector::zgetrf(nlocal, nlocal, Sinv, nlocal, IPIV, &INFO); |
| 540 | + LapackConnector::zgetri(nlocal, Sinv, nlocal, IPIV, work, lwork, &INFO); |
| 541 | + ModuleBase::ComplexMatrix tmp_dmk_base(nlocal, nlocal); |
| 542 | + for (int i = 0; i < nlocal; i++) |
| 543 | + { |
| 544 | + for (int j = 0; j < nlocal; j++) |
| 545 | + { |
| 546 | + tmp_dmk_base(i, j) = tmp_dmk[i * nlocal + j]; |
| 547 | + } |
| 548 | + } |
| 549 | + tmp_edmk = 0.5 * (Sinv * Htmp * tmp_dmk_base + tmp_dmk_base * Htmp * Sinv); |
| 550 | + delete[] work; |
| 551 | +#endif |
| 552 | + } // end ik |
| 553 | + ModuleBase::timer::tick("elecstate", "cal_edm_tddft_tensor"); |
| 554 | + return; |
| 555 | +} // cal_edm_tddft_tensor |
| 556 | + |
255 | 557 | } // namespace elecstate |
0 commit comments