Commit cf70d94
perf: use
The benchmark result shows an overhead introduced by slicing operators
in the current implementation.
This PR replaces slicing for each tensor with a unified `torch.split`
op.
It brings a speed-up of 6.7% while improves the code readability.
Tested on OMat with 9 DPA-3 layers and batch size=auto:512.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
- **Refactor**
- Streamlined the internal logic for processing numerical components to
reduce complexity.
- Enhanced internal validation checks related to the `bias` variable to
boost overall system robustness and maintainability.
- Updated method signatures for `optim_angle_update` and
`optim_edge_update` to improve clarity and usability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Chun Cai <amoycaic@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>torch.split in replace of slicing ops in repflow (#4687)1 parent a1b5089 commit cf70d94
File tree
2 files changed
+60
-96
lines changed- deepmd
- dpmodel/descriptor
- pt/model/descriptor
2 files changed
+60
-96
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
796 | 796 | | |
797 | 797 | | |
798 | 798 | | |
799 | | - | |
800 | | - | |
801 | | - | |
802 | | - | |
803 | | - | |
804 | | - | |
805 | | - | |
806 | | - | |
807 | | - | |
808 | | - | |
809 | 799 | | |
810 | 800 | | |
| 801 | + | |
811 | 802 | | |
812 | 803 | | |
| 804 | + | |
813 | 805 | | |
814 | 806 | | |
815 | 807 | | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
816 | 813 | | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
| 817 | + | |
| 818 | + | |
| 819 | + | |
| 820 | + | |
817 | 821 | | |
818 | 822 | | |
819 | | - | |
820 | | - | |
821 | | - | |
822 | | - | |
| 823 | + | |
823 | 824 | | |
824 | | - | |
825 | | - | |
826 | | - | |
827 | | - | |
| 825 | + | |
828 | 826 | | |
829 | | - | |
830 | | - | |
831 | | - | |
832 | | - | |
833 | | - | |
834 | | - | |
| 827 | + | |
| 828 | + | |
835 | 829 | | |
836 | 830 | | |
837 | 831 | | |
| |||
851 | 845 | | |
852 | 846 | | |
853 | 847 | | |
854 | | - | |
855 | | - | |
856 | | - | |
857 | | - | |
858 | | - | |
859 | 848 | | |
860 | 849 | | |
861 | 850 | | |
862 | 851 | | |
863 | 852 | | |
864 | 853 | | |
865 | 854 | | |
866 | | - | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
867 | 862 | | |
868 | 863 | | |
869 | | - | |
870 | | - | |
871 | | - | |
| 864 | + | |
872 | 865 | | |
873 | 866 | | |
874 | | - | |
875 | | - | |
876 | | - | |
| 867 | + | |
877 | 868 | | |
878 | 869 | | |
879 | 870 | | |
880 | 871 | | |
881 | | - | |
882 | | - | |
883 | | - | |
| 872 | + | |
884 | 873 | | |
885 | 874 | | |
886 | 875 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
397 | 397 | | |
398 | 398 | | |
399 | 399 | | |
400 | | - | |
401 | | - | |
402 | | - | |
403 | | - | |
404 | | - | |
405 | | - | |
406 | | - | |
407 | | - | |
408 | | - | |
409 | | - | |
410 | | - | |
411 | 400 | | |
| 401 | + | |
412 | 402 | | |
413 | 403 | | |
| 404 | + | |
414 | 405 | | |
415 | 406 | | |
416 | 407 | | |
417 | | - | |
| 408 | + | |
418 | 409 | | |
419 | | - | |
420 | | - | |
421 | | - | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
422 | 416 | | |
423 | 417 | | |
| 418 | + | |
| 419 | + | |
424 | 420 | | |
425 | | - | |
426 | | - | |
427 | | - | |
428 | | - | |
| 421 | + | |
429 | 422 | | |
430 | | - | |
431 | | - | |
432 | | - | |
433 | | - | |
434 | | - | |
435 | | - | |
| 423 | + | |
| 424 | + | |
436 | 425 | | |
437 | 426 | | |
438 | 427 | | |
439 | | - | |
440 | | - | |
441 | | - | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
442 | 431 | | |
443 | 432 | | |
444 | 433 | | |
| |||
451 | 440 | | |
452 | 441 | | |
453 | 442 | | |
454 | | - | |
455 | | - | |
456 | | - | |
457 | | - | |
458 | | - | |
459 | | - | |
460 | 443 | | |
461 | 444 | | |
462 | 445 | | |
463 | 446 | | |
464 | 447 | | |
465 | 448 | | |
466 | | - | |
| 449 | + | |
467 | 450 | | |
468 | | - | |
469 | | - | |
470 | | - | |
471 | | - | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
472 | 455 | | |
| 456 | + | |
| 457 | + | |
473 | 458 | | |
474 | | - | |
475 | | - | |
476 | | - | |
| 459 | + | |
477 | 460 | | |
478 | 461 | | |
479 | | - | |
480 | 462 | | |
481 | | - | |
482 | | - | |
483 | | - | |
| 463 | + | |
484 | 464 | | |
485 | 465 | | |
486 | | - | |
487 | | - | |
488 | | - | |
489 | | - | |
| 466 | + | |
490 | 467 | | |
491 | 468 | | |
492 | 469 | | |
| |||
614 | 591 | | |
615 | 592 | | |
616 | 593 | | |
617 | | - | |
| 594 | + | |
618 | 595 | | |
619 | 596 | | |
620 | 597 | | |
| |||
649 | 626 | | |
650 | 627 | | |
651 | 628 | | |
652 | | - | |
653 | | - | |
| 629 | + | |
| 630 | + | |
654 | 631 | | |
655 | 632 | | |
656 | 633 | | |
657 | 634 | | |
658 | 635 | | |
659 | | - | |
| 636 | + | |
660 | 637 | | |
661 | 638 | | |
662 | 639 | | |
| |||
704 | 681 | | |
705 | 682 | | |
706 | 683 | | |
707 | | - | |
708 | | - | |
709 | | - | |
| 684 | + | |
710 | 685 | | |
711 | 686 | | |
712 | 687 | | |
| |||
0 commit comments