Skip to content

Commit 058e066

Browse files
feat(jax): add options to use TensorFlow C library to build the JAX backend (#4357)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Updated installation documentation to clarify requirements for TensorFlow and JAX backends. - Expanded supported platforms to include Windows x86-64. - Added instructions for enabling JAX backend during installation from source. - **Documentation** - Enhanced clarity of installation prerequisites and supported platforms. - Included a note directing users to the TensorFlow tab for additional information. - **Bug Fixes** - Improved error handling for unsupported backend configurations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent d5295d5 commit 058e066

File tree

8 files changed

+111
-5
lines changed

8 files changed

+111
-5
lines changed

doc/install/easy-install.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ pip install deepmd-kit[jax]
206206

207207
:::::
208208

209+
To generate a SavedModel and use [the LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md),
210+
you need to install the TensorFlow.
211+
Switch to the TensorFlow {{ tensorflow_icon }} tab for more information.
212+
209213
::::::
210214

211215
:::::::

doc/install/install-from-c-library.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Install from pre-compiled C library {{ tensorflow_icon }}
1+
# Install from pre-compiled C library {{ tensorflow_icon }}, JAX {{ jax_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, JAX {{ jax_icon }}
55
:::
66

77
DeePMD-kit provides pre-compiled C library package (`libdeepmd_c.tar.gz`) in each [release](https://github.com/deepmodeling/deepmd-kit/releases). It can be used to build the [LAMMPS plugin](./install-lammps.md) and [GROMACS patch](./install-gromacs.md), as well as many [third-party software packages](../third-party/out-of-deepmd-kit.md), without building TensorFlow and DeePMD-kit on one's own.

doc/install/install-from-source.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,15 @@ You can also download libtorch prebuilt library from the [PyTorch website](https
316316

317317
:::
318318

319+
:::{tab-item} JAX {{ jax_icon }}
320+
321+
The JAX backend only depends on the TensorFlow C API, which is included in both TensorFlow C++ library and [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
322+
If you want to use the TensorFlow C++ library, just enable the TensorFlow backend (which depends on the TensorFlow C++ library) and nothing else needs to do.
323+
If you want to use the TensorFlow C library and disable the TensorFlow backend,
324+
download the TensorFlow C library from [this page](https://www.tensorflow.org/install/lang_c#download_and_extract).
325+
326+
:::
327+
319328
::::
320329

321330
### Install DeePMD-kit's C++ interface
@@ -369,6 +378,17 @@ cmake -DENABLE_PYTORCH=TRUE -DUSE_PT_PYTHON_LIBS=TRUE -DCMAKE_INSTALL_PREFIX=$de
369378

370379
:::
371380

381+
:::{tab-item} JAX {{ jax_icon }}
382+
383+
If you want to use the TensorFlow C++ library, just enable the TensorFlow backend and nothing else needs to do.
384+
If you want to use the TensorFlow C library and disable the TensorFlow backend, set {cmake:variable}`ENABLE_JAX` to `ON` and `CMAKE_PREFIX_PATH` to the root directory of the [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
385+
386+
```bash
387+
cmake -DENABLE_JAX=ON -D CMAKE_PREFIX_PATH=${tensorflow_c_root} ..
388+
```
389+
390+
:::
391+
372392
::::
373393

374394
One may add the following CMake variables to `cmake` using the [`-D <var>=<value>` option](https://cmake.org/cmake/help/latest/manual/cmake.1.html#cmdoption-cmake-D):
@@ -378,6 +398,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
378398
**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`
379399

380400
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
401+
Setting this option to `ON` will also set {cmake:variable}`ENABLE_JAX` to `ON`.
381402

382403
:::
383404

@@ -389,6 +410,16 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
389410

390411
:::
391412

413+
:::{cmake:variable} ENABLE_JAX
414+
415+
**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`
416+
417+
{{ jax_icon }} Build the JAX backend.
418+
If {cmake:variable}`ENABLE_TENSORFLOW` is `ON`, the TensorFlow C++ library is used to build the JAX backend;
419+
If {cmake:variable}`ENABLE_TENSORFLOW` is `OFF`, the TensorFlow C library is used to build the JAX backend.
420+
421+
:::
422+
392423
:::{cmake:variable} TENSORFLOW_ROOT
393424

394425
**Type**: `PATH`

source/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ project(DeePMD)
44

55
option(ENABLE_TENSORFLOW "Enable TensorFlow interface" OFF)
66
option(ENABLE_PYTORCH "Enable PyTorch interface" OFF)
7+
option(ENABLE_JAX "Enable JAX interface" OFF)
8+
if(ENABLE_TENSORFLOW)
9+
# JAX requires TF C interface, contained in TF C++ library
10+
set(ENABLE_JAX ON)
11+
endif()
712
option(BUILD_TESTING "Build test and enable coverage" OFF)
813
set(DEEPMD_C_ROOT
914
""
@@ -246,6 +251,22 @@ if(ENABLE_PYTORCH AND NOT DEEPMD_C_ROOT)
246251
list(APPEND BACKEND_LIBRARY_PATH ${PyTorch_LIBRARY_PATH})
247252
list(APPEND BACKEND_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS})
248253
endif()
254+
if(ENABLE_JAX
255+
AND BUILD_CPP_IF
256+
AND NOT DEEPMD_C_ROOT)
257+
# no way to find it using Python
258+
find_package(TensorFlowC REQUIRED MODULE)
259+
if(DEFINED TENSORFLOWC_LIBRARY)
260+
list(APPEND BACKEND_LIBRARY_PATH ${TENSORFLOWC_LIBRARY})
261+
endif()
262+
if(DEFINED TENSORFLOWC_INCLUDE_DIR)
263+
list(APPEND BACKEND_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
264+
endif()
265+
endif()
266+
if(NOT DEFINED OP_CXX_ABI)
267+
# prevent setting an empty value; this is default on GCC>=5
268+
set(OP_CXX_ABI 1)
269+
endif()
249270
# log enabled backends
250271
if(NOT DEEPMD_C_ROOT)
251272
message(STATUS "Enabled backends:")
@@ -255,8 +276,12 @@ if(NOT DEEPMD_C_ROOT)
255276
if(ENABLE_PYTORCH)
256277
message(STATUS "- PyTorch")
257278
endif()
279+
if(ENABLE_JAX)
280+
message(STATUS "- JAX")
281+
endif()
258282
if(NOT ENABLE_TENSORFLOW
259283
AND NOT ENABLE_PYTORCH
284+
AND NOT ENABLE_JAX
260285
AND NOT BUILD_PY_IF)
261286
message(FATAL_ERROR "No backend is enabled.")
262287
endif()

source/api_cc/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ if(ENABLE_PYTORCH
2323
target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}")
2424
target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH)
2525
endif()
26+
if(ENABLE_JAX)
27+
target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_c)
28+
target_compile_definitions(${libname} PRIVATE BUILD_JAX)
29+
endif()
2630

2731
target_include_directories(
2832
${libname}

source/api_cc/src/DeepPot.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#include "AtomMap.h"
88
#include "common.h"
99
#ifdef BUILD_TENSORFLOW
10-
#include "DeepPotJAX.h"
1110
#include "DeepPotTF.h"
1211
#endif
1312
#ifdef BUILD_PYTORCH
1413
#include "DeepPotPT.h"
1514
#endif
15+
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
16+
#include "DeepPotJAX.h"
17+
#endif
1618
#include "device.h"
1719

1820
using namespace deepmd;
@@ -63,7 +65,7 @@ void DeepPot::init(const std::string& model,
6365
} else if (deepmd::DPBackend::Paddle == backend) {
6466
throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
6567
} else if (deepmd::DPBackend::JAX == backend) {
66-
#ifdef BUILD_TENSORFLOW
68+
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
6769
dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
6870
#else
6971
throw deepmd::deepmd_exception(

source/api_cc/src/DeepPotJAX.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: LGPL-3.0-or-later
2-
#ifdef BUILD_TENSORFLOW
2+
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
33

44
#include "DeepPotJAX.h"
55

source/cmake/FindTensorFlowC.cmake

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Find TensorFlow C library (libtensorflow) Define target
2+
# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
3+
# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY
4+
5+
if(TARGET TensorFlow::tensorflow_cc)
6+
# since tensorflow_cc contain tensorflow_c, just use it
7+
add_library(TensorFlow::tensorflow_c ALIAS TensorFlow::tensorflow_cc)
8+
set(TensorFlowC_FOUND TRUE)
9+
endif()
10+
11+
if(NOT TensorFlowC_FOUND)
12+
find_path(
13+
TENSORFLOWC_INCLUDE_DIR
14+
NAMES tensorflow/c/c_api.h
15+
PATH_SUFFIXES include
16+
DOC "Path to TensorFlow C include directory")
17+
18+
find_library(
19+
TENSORFLOWC_LIBRARY
20+
NAMES tensorflow
21+
PATH_SUFFIXES lib
22+
DOC "Path to TensorFlow C library")
23+
24+
include(FindPackageHandleStandardArgs)
25+
find_package_handle_standard_args(
26+
TensorFlowC REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)
27+
28+
if(TensorFlowC_FOUND)
29+
set(TensorFlowC_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
30+
set(TensorFlowC_LIBRARIES ${TENSORFLOWC_LIBRARY})
31+
endif()
32+
33+
add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
34+
set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION
35+
${TENSORFLOWC_LIBRARY})
36+
target_include_directories(TensorFlow::tensorflow_c
37+
INTERFACE ${TENSORFLOWC_INCLUDE_DIR})
38+
39+
mark_as_advanced(TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)
40+
endif()

0 commit comments

Comments
 (0)