Skip to content

Commit 7b2476f

Browse files
support LKF optimizer
1 parent afd4746 commit 7b2476f

File tree

5 files changed

+15
-10
lines changed

5 files changed

+15
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).
1919

2020
### Highlighted features
2121

22-
- **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle the most popular deep learning frameworks, making the training process highly automatic and efficient.
22+
- **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
2323
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
2424
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
2525
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.

deepmd/pd/optimizer/KFWrapper.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,17 @@ def update_energy(
5858
mask = error < 0
5959

6060
error = error * update_prefactor
61-
error[mask] = -1 * error[mask]
61+
# error[mask] = -1 * error[mask]
62+
error = _mask_update(error, mask, -error[mask])
6263
error = error.mean()
6364

6465
if self.is_distributed:
6566
dist.all_reduce(error)
6667
error /= dist.get_world_size()
6768

6869
Etot_predict = update_prefactor * Etot_predict
69-
Etot_predict[mask] = -Etot_predict[mask]
70+
# Etot_predict[mask] = -Etot_predict[mask]
71+
Etot_predict = _mask_update(Etot_predict, mask, -Etot_predict[mask])
7072

7173
Etot_predict.sum().backward()
7274
error = error * math.sqrt(bs)
@@ -91,7 +93,7 @@ def update_force(
9193
error_tmp = Force_label[:, index[i]] - force_predict[:, index[i]]
9294
error_tmp = update_prefactor * error_tmp
9395
mask = error_tmp < 0
94-
error_tmp = _mask_update(error_tmp, mask, -1 * error_tmp[mask])
96+
error_tmp = _mask_update(error_tmp, mask, -error_tmp[mask])
9597
# error_tmp[mask] = -1 * error_tmp[mask]
9698
error = error_tmp.mean() / natoms_sum
9799

deepmd/pd/optimizer/LKF.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def __update(self, H, error, weights):
265265
def set_grad_prefactor(self, grad_prefactor):
266266
self.grad_prefactor = grad_prefactor
267267

268+
@paddle.no_grad()
268269
def step(self, error):
269270
params_packed_index = self._state.get("params_packed_index")
270271

deepmd/pd/train/training.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@
4242
get_model,
4343
get_zbl_model,
4444
)
45-
from deepmd.pd.optimizer import ( # LKFOptimizer,
45+
from deepmd.pd.optimizer import (
4646
KFOptimizerWrapper,
47+
LKFOptimizer,
4748
)
4849
from deepmd.pd.train.wrapper import (
4950
ModelWrapper,
@@ -601,10 +602,12 @@ def warm_up_linear(step, warmup_steps):
601602
if optimizer_state_dict is not None and self.restart_training:
602603
self.optimizer.set_state_dict(optimizer_state_dict)
603604
elif self.opt_type == "LKF":
604-
raise NotImplementedError("LKF is not supported yet in Paddle backend.")
605-
# self.optimizer = LKFOptimizer(
606-
# [{'params': self.wrapper.parameters()}], 0.98, 0.99870, self.opt_param["kf_blocksize"]
607-
# )
605+
self.optimizer = LKFOptimizer(
606+
[{"params": self.wrapper.parameters()}],
607+
0.98,
608+
0.99870,
609+
self.opt_param["kf_blocksize"],
610+
)
608611
else:
609612
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
610613

source/tests/pd/test_LKF.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212

1313

14-
@unittest.skip("Paddle do not support LKF now")
1514
class TestLKF(unittest.TestCase):
1615
def test_lkf(self):
1716
with open(str(Path(__file__).parent / "water/lkf.json")) as fin:

0 commit comments

Comments
 (0)