Skip to content

Commit f879b48

Browse files
authored
fix(jax): handle DPA-2 pbc/nopbc without mapping (#4363)
In the C++ API, generate the mapping for the no PBC and throw the error for PBC. Considering I forgot setting `atom_modify map yes` when testing it, others may also forget. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a function to determine if the model supports message passing, enhancing the model's interface. - Added a private member variable to facilitate message passing functionality. - Implemented unit tests for the `DeepPot` class, validating its functionality under various conditions. - **Bug Fixes** - Improved error handling for TensorFlow function retrieval, ensuring more specific exceptions are thrown. - Enhanced compatibility with earlier model versions by managing exceptions related to the new message passing variable. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 031c3ce commit f879b48

File tree

4 files changed

+424
-1
lines changed

4 files changed

+424
-1
lines changed

deepmd/jax/jax2tf/serialization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def get_model_def_script():
294294
)
295295

296296
tf_model.get_model_def_script = get_model_def_script
297+
298+
@tf.function
299+
def has_message_passing() -> tf.Tensor:
300+
return tf.constant(model.has_message_passing(), dtype=tf.bool)
301+
302+
tf_model.has_message_passing = has_message_passing
297303
tf.saved_model.save(
298304
tf_model,
299305
model_file,

source/api_cc/include/DeepPotJAX.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ class DeepPotJAX : public DeepPotBackend {
189189
std::vector<int64_t> sel;
190190
// number of neighbors
191191
int nnei;
192+
// do message passing
193+
bool do_message_passing;
192194
// padding to nall
193195
int padding_to_nall = 0;
194196
// padding for nloc

source/api_cc/src/DeepPotJAX.cc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ inline TF_DataType get_data_tensor_type(const std::vector<int64_t>& data) {
6464
return TF_INT64;
6565
}
6666

67+
struct tf_function_not_found : public deepmd::deepmd_exception {
68+
public:
69+
tf_function_not_found() : deepmd_exception() {};
70+
tf_function_not_found(const std::string& msg) : deepmd_exception(msg) {};
71+
};
72+
6773
inline TFE_Op* get_func_op(TFE_Context* ctx,
6874
const std::string func_name,
6975
const std::vector<TF_Function*>& funcs,
@@ -72,7 +78,7 @@ inline TFE_Op* get_func_op(TFE_Context* ctx,
7278
TF_Function* func = NULL;
7379
find_function(func, funcs, func_name);
7480
if (func == NULL) {
75-
throw std::runtime_error("Function " + func_name + " not found");
81+
throw tf_function_not_found("Function " + func_name + " not found");
7682
}
7783
const char* real_func_name = TF_FunctionName(func);
7884
// execute the function
@@ -314,6 +320,13 @@ void deepmd::DeepPotJAX::init(const std::string& model,
314320
ntypes = type_map_.size();
315321
sel = get_vector<int64_t>(ctx, "get_sel", func_vector, device, status);
316322
nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0));
323+
try {
324+
do_message_passing = get_scalar<bool>(ctx, "do_message_passing",
325+
func_vector, device, status);
326+
} catch (tf_function_not_found& e) {
327+
// compatibile with models generated by v3.0.0rc0
328+
do_message_passing = false;
329+
}
317330
inited = true;
318331
}
319332

@@ -584,6 +597,15 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
584597
for (size_t ii = 0; ii < nall_real; ii++) {
585598
mapping[ii] = lmp_list.mapping[fwd_map[ii]];
586599
}
600+
} else if (nloc_real == nall_real) {
601+
// no ghost atoms
602+
for (size_t ii = 0; ii < nall_real; ii++) {
603+
mapping[ii] = ii;
604+
}
605+
} else if (do_message_passing) {
606+
throw deepmd::deepmd_exception(
607+
"Mapping is required for a message passing model. If you are using "
608+
"LAMMPS, set `atom_modify map yes`");
587609
}
588610
input_list[3] = add_input(op, mapping, mapping_shape, data_tensor[3], status);
589611
// fparam

0 commit comments

Comments
 (0)