You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
- **New Features**
- Introduced support for the JAX backend, expanding user options for
model training and execution.
- Added installation instructions for JAX within the source installation
documentation.
- Included new environment variables related to JAX to enhance
configuration options.
- **Documentation Updates**
- Updated various documentation files to reflect the addition of JAX,
including sections on model commands, supported backends, and
environment variables.
- Enhanced documentation with a visual representation for JAX through an
icon.
- Improved clarity and organization of installation instructions for
DeePMD-kit.
- Updated the README to highlight JAX as a supported backend and reflect
changes in version history.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Copy file name to clipboardExpand all lines: README.md
+2-2Lines changed: 2 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).
19
19
20
20
### Highlighted features
21
21
22
-
-**interfaced with multiple backends**, including TensorFlowand PyTorch, the most popular deep learning frameworks, making the training process highly automatic and efficient.
22
+
-**interfaced with multiple backends**, including TensorFlow, PyTorch, and JAX, the most popular deep learning frameworks, making the training process highly automatic and efficient.
23
23
-**interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
24
24
-**implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
25
25
-**implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
@@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea
72
72
73
73
#### v3
74
74
75
-
- Multiple backends supported. Add a PyTorch backend.
75
+
- Multiple backends supported. Add PyTorch and JAX backends.
Copy file name to clipboardExpand all lines: doc/backend.md
+9Lines changed: 9 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
23
23
[PyTorch](https://pytorch.org/) 2.0 or above is required.
24
24
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.
25
25
26
+
### JAX {{ jax_icon }}
27
+
28
+
- Model filename extension: `.xlo`
29
+
- Checkpoint filename extension: `.jax`
30
+
31
+
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
32
+
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
33
+
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Copy file name to clipboardExpand all lines: doc/env.md
+1Lines changed: 1 addition & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
31
31
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
32
32
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
33
33
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
34
+
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.
Copy file name to clipboardExpand all lines: doc/install/easy-install-dev.md
+1-3Lines changed: 1 addition & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -16,14 +16,12 @@ For CUDA 11.8 support, use the `devel_cu11` tag.
16
16
17
17
## Install with pip
18
18
19
-
Below is an one-line shell command to download the [artifact](https://nightly.link/deepmodeling/deepmd-kit/workflows/build_wheel/devel/artifact.zip) containing wheels and install it with `pip`:
19
+
Follow [the documentation for the stable version](easy-install.md#install-python-interface-with-pip), but add `--pre`and `--extra-index-url` options like below:
If you have no existing TensorFlow installed, you can use `pip` to install the pre-built package of the Python interface with CUDA 12 supported:
107
+
[Create a new environment](https://docs.deepmodeling.com/faq/conda.html#how-to-create-a-new-conda-pip-environment), and then execute the following command:
108
+
109
+
:::::::{tab-set}
110
+
111
+
::::::{tab-item} TensorFlow {{ tensorflow_icon }}
112
+
113
+
:::::{tab-set}
114
+
115
+
::::{tab-item} CUDA 12
108
116
109
117
```bash
110
-
pip install deepmd-kit[gpu,cu12,torch]
118
+
pip install deepmd-kit[gpu,cu12]
111
119
```
112
120
113
121
`cu12` is required only when CUDA Toolkit and cuDNN were not installed.
114
122
115
-
To install the package built against CUDA 11.8, use
[The LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md) are only provided on Linux and macOS for the TensorFlow backend. To install LAMMPS and/or i-PI, add `lmp` and/or `ipi` to extras:
130
144
131
145
```bash
132
-
pip install deepmd-kit[gpu,cu12,torch,lmp,ipi]
146
+
pip install deepmd-kit[gpu,cu12,lmp,ipi]
133
147
```
134
148
135
149
MPICH is required for parallel running.
136
150
137
-
:::{Warning}
138
-
When installing from pip, only the TensorFlow {{ tensorflow_icon }} backend is supported with LAMMPS and i-PI.
It is suggested to install the package into an isolated environment.
142
213
The supported platform includes Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64.
143
-
A specific version of TensorFlow and PyTorch which is compatible with DeePMD-kit will be also installed.
144
214
145
215
:::{Warning}
146
-
If your platform is not supported, or you want to build against the installed TensorFlow, or you want to enable ROCM support, please [build from source](install-from-source.md).
216
+
If your platform is not supported, or you want to build against the installed backends, or you want to enable ROCM support, please [build from source](install-from-source.md).
Copy file name to clipboardExpand all lines: doc/install/install-from-source.md
+15Lines changed: 15 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal
78
78
79
79
:::
80
80
81
+
:::{tab-item} JAX {{ jax_icon }}
82
+
83
+
To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run
84
+
85
+
```sh
86
+
pip install jax-ai-stack
87
+
```
88
+
89
+
One can also install packages in JAX AI Stack manually.
90
+
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.
91
+
92
+
One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).
93
+
94
+
:::
95
+
81
96
::::
82
97
83
98
It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by
Copy file name to clipboardExpand all lines: doc/model/sel.md
+8Lines changed: 8 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H
24
24
25
25
:::
26
26
27
+
:::{tab-item} JAX {{ jax_icon }}
28
+
29
+
```sh
30
+
dp --jax neighbor-stat -s data -r 6.0 -t O H
31
+
```
32
+
33
+
:::
34
+
27
35
::::
28
36
29
37
where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.
0 commit comments