Skip to content

Commit dd36e6c

Browse files
authored
docs: document JAX backend (#4259)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced support for the JAX backend, expanding user options for model training and execution. - Added installation instructions for JAX within the source installation documentation. - Included new environment variables related to JAX to enhance configuration options. - **Documentation Updates** - Updated various documentation files to reflect the addition of JAX, including sections on model commands, supported backends, and environment variables. - Enhanced documentation with a visual representation for JAX through an icon. - Improved clarity and organization of installation instructions for DeePMD-kit. - Updated the README to highlight JAX as a supported backend and reflect changes in version history. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 82aaa0d commit dd36e6c

File tree

14 files changed

+131
-28
lines changed

14 files changed

+131
-28
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).
1919

2020
### Highlighted features
2121

22-
- **interfaced with multiple backends**, including TensorFlow and PyTorch, the most popular deep learning frameworks, making the training process highly automatic and efficient.
22+
- **interfaced with multiple backends**, including TensorFlow, PyTorch, and JAX, the most popular deep learning frameworks, making the training process highly automatic and efficient.
2323
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
2424
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
2525
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
@@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea
7272

7373
#### v3
7474

75-
- Multiple backends supported. Add a PyTorch backend.
75+
- Multiple backends supported. Add PyTorch and JAX backends.
7676
- The DPA-2 model.
7777

7878
## Install and use DeePMD-kit

doc/_static/jax.svg

Lines changed: 1 addition & 0 deletions
Loading

doc/backend.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
2323
[PyTorch](https://pytorch.org/) 2.0 or above is required.
2424
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.
2525

26+
### JAX {{ jax_icon }}
27+
28+
- Model filename extension: `.xlo`
29+
- Checkpoint filename extension: `.jax`
30+
31+
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
32+
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
33+
Currently, this backend is developed actively, and has no support for training and the C++ interface.
34+
2635
### DP {{ dpmodel_icon }}
2736

2837
:::{note}

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
myst_substitutions = {
169169
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
170170
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
171+
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
171172
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
172173
}
173174

doc/env.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
3131
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
3232
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
3333
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
34+
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.
3435

3536
## Python interface only
3637

doc/install/easy-install-dev.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@ For CUDA 11.8 support, use the `devel_cu11` tag.
1616

1717
## Install with pip
1818

19-
Below is an one-line shell command to download the [artifact](https://nightly.link/deepmodeling/deepmd-kit/workflows/build_wheel/devel/artifact.zip) containing wheels and install it with `pip`:
19+
Follow [the documentation for the stable version](easy-install.md#install-python-interface-with-pip), but add `--pre` and `--extra-index-url` options like below:
2020

2121
```sh
2222
pip install -U --pre deepmd-kit[gpu,cu12,lmp,torch] --extra-index-url https://deepmodeling.github.io/deepmd-kit/simple
2323
```
2424

25-
`cu12` and `lmp` are optional, which is the same as the stable version.
26-
2725
## Download pre-compiled C Library {{ tensorflow_icon }}
2826

2927
:::{note}

doc/install/easy-install.md

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,44 +104,114 @@ docker pull ghcr.io/deepmodeling/deepmd-kit:2.2.8_cuda12.0_gpu
104104

105105
## Install Python interface with pip
106106

107-
If you have no existing TensorFlow installed, you can use `pip` to install the pre-built package of the Python interface with CUDA 12 supported:
107+
[Create a new environment](https://docs.deepmodeling.com/faq/conda.html#how-to-create-a-new-conda-pip-environment), and then execute the following command:
108+
109+
:::::::{tab-set}
110+
111+
::::::{tab-item} TensorFlow {{ tensorflow_icon }}
112+
113+
:::::{tab-set}
114+
115+
::::{tab-item} CUDA 12
108116

109117
```bash
110-
pip install deepmd-kit[gpu,cu12,torch]
118+
pip install deepmd-kit[gpu,cu12]
111119
```
112120

113121
`cu12` is required only when CUDA Toolkit and cuDNN were not installed.
114122

115-
To install the package built against CUDA 11.8, use
123+
::::
124+
125+
::::{tab-item} CUDA 11
116126

117127
```bash
118-
pip install torch --index-url https://download.pytorch.org/whl/cu118
119128
pip install deepmd-kit-cu11[gpu,cu11]
120129
```
121130

122-
Or install the CPU version without CUDA supported:
131+
::::
132+
133+
::::{tab-item} CPU
123134

124135
```bash
125-
pip install torch --index-url https://download.pytorch.org/whl/cpu
126136
pip install deepmd-kit[cpu]
127137
```
128138

139+
::::
140+
141+
:::::
142+
129143
[The LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md) are only provided on Linux and macOS for the TensorFlow backend. To install LAMMPS and/or i-PI, add `lmp` and/or `ipi` to extras:
130144

131145
```bash
132-
pip install deepmd-kit[gpu,cu12,torch,lmp,ipi]
146+
pip install deepmd-kit[gpu,cu12,lmp,ipi]
133147
```
134148

135149
MPICH is required for parallel running.
136150

137-
:::{Warning}
138-
When installing from pip, only the TensorFlow {{ tensorflow_icon }} backend is supported with LAMMPS and i-PI.
139-
:::
151+
::::::
152+
153+
::::::{tab-item} PyTorch {{ pytorch_icon }}
154+
155+
:::::{tab-set}
156+
157+
::::{tab-item} CUDA 12
158+
159+
```bash
160+
pip install deepmd-kit[torch]
161+
```
162+
163+
::::
164+
165+
::::{tab-item} CUDA 11.8
166+
167+
```bash
168+
pip install torch --index-url https://download.pytorch.org/whl/cu118
169+
pip install deepmd-kit-cu11
170+
```
171+
172+
::::
173+
174+
::::{tab-item} CPU
175+
176+
```bash
177+
pip install torch --index-url https://download.pytorch.org/whl/cpu
178+
pip install deepmd-kit
179+
```
180+
181+
::::
182+
183+
:::::
184+
185+
::::::
186+
187+
::::::{tab-item} JAX {{ jax_icon }}
188+
189+
:::::{tab-set}
190+
191+
::::{tab-item} CUDA 12
192+
193+
```bash
194+
pip install deepmd-kit[jax] jax[cuda12]
195+
```
196+
197+
::::
198+
199+
::::{tab-item} CPU
200+
201+
```bash
202+
pip install deepmd-kit[jax]
203+
```
204+
205+
::::
206+
207+
:::::
208+
209+
::::::
210+
211+
:::::::
140212

141-
It is suggested to install the package into an isolated environment.
142213
The supported platform includes Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64.
143-
A specific version of TensorFlow and PyTorch which is compatible with DeePMD-kit will be also installed.
144214

145215
:::{Warning}
146-
If your platform is not supported, or you want to build against the installed TensorFlow, or you want to enable ROCM support, please [build from source](install-from-source.md).
216+
If your platform is not supported, or you want to build against the installed backends, or you want to enable ROCM support, please [build from source](install-from-source.md).
147217
:::

doc/install/install-from-source.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal
7878

7979
:::
8080

81+
:::{tab-item} JAX {{ jax_icon }}
82+
83+
To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run
84+
85+
```sh
86+
pip install jax-ai-stack
87+
```
88+
89+
One can also install packages in JAX AI Stack manually.
90+
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.
91+
92+
One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).
93+
94+
:::
95+
8196
::::
8297

8398
It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by

doc/model/sel.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H
2424

2525
:::
2626

27+
:::{tab-item} JAX {{ jax_icon }}
28+
29+
```sh
30+
dp --jax neighbor-stat -s data -r 6.0 -t O H
31+
```
32+
33+
:::
34+
2735
::::
2836

2937
where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.

doc/model/train-energy.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.

0 commit comments

Comments
 (0)