Skip to content

Commit 31373bc

Browse files
Feat:support customized rglob (#4763)
Currently, when a single str is passed as the root data directory, the `expand_sys_str` function automatically performs `rglob` to collect all systems. However, this depends on the structure of the data folder. There are scenarios where the train/val folders are nested, e.g. `root/dataset_*/trn` and `root/dataset_*/val`. A customizable rglob function is needed to provide more flexibility when constructing datasets, and to avoid unnecessarily long data lists in the input file.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Added support for specifying custom glob patterns to filter training and validation datasets, allowing more flexible and targeted data selection (PyTorch backend only).
  - Introduced recursive pattern matching to improve system directory selection based on user-defined criteria.
- **Tests**
  - Added new test cases to validate the customized glob pattern functionality for training and validation datasets.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ba974fb commit 31373bc

File tree

5 files changed

+83
-5
lines changed

5 files changed

+83
-5
lines changed

deepmd/common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,30 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]:
207207
return matches
208208

209209

210+
def rglob_sys_str(root_dir: Union[str, Path], patterns: list[str]) -> list[str]:
    """Recursively collect system directories under `root_dir` that match glob patterns.

    Every pattern is passed to `Path.rglob` on `root_dir`; a match is kept only
    if it contains a `type.raw` file. Directories matched by several patterns
    are reported once.

    Parameters
    ----------
    root_dir : str, Path
        starting directory
    patterns : list[str]
        list of glob patterns to match directories; a single string is treated
        as a list with one pattern

    Returns
    -------
    list[str]
        sorted list of unique strings pointing to system directories
    """
    root_dir = Path(root_dir)
    if isinstance(patterns, str):
        # Iterating a bare string would use each character as a pattern.
        patterns = [patterns]
    matches = {
        str(d)
        for pattern in patterns
        for d in root_dir.rglob(pattern)
        if (d / "type.raw").is_file()
    }
    # Sort instead of returning the set's iteration order: the order of a set
    # of strings can change between interpreter runs (hash randomization),
    # which would make the order of the systems non-reproducible.
    return sorted(matches)
232+
233+
210234
def get_np_precision(precision: "_PRECISION") -> np.dtype:
211235
"""Get numpy precision constant from string.
212236

deepmd/pt/entrypoints/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ def prepare_trainer_input_single(
114114
validation_dataset_params["systems"] if validation_dataset_params else None
115115
)
116116
training_systems = training_dataset_params["systems"]
117-
training_systems = process_systems(training_systems)
117+
trn_patterns = training_dataset_params.get("rglob_patterns", None)
118+
training_systems = process_systems(training_systems, patterns=trn_patterns)
118119
if validation_systems is not None:
119-
validation_systems = process_systems(validation_systems)
120+
val_patterns = validation_dataset_params.get("rglob_patterns", None)
121+
validation_systems = process_systems(validation_systems, val_patterns)
120122

121123
# stat files
122124
stat_file_path_single = data_dict_single.get("stat_file", None)

deepmd/utils/argcheck.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,6 +2926,9 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
29262926
"This key can be provided with a list that specifies the systems, or be provided with a string "
29272927
"by which the prefix of all systems are given and the list of the systems is automatically generated."
29282928
)
2929+
doc_patterns = (
2930+
"The customized patterns used in `rglob` to collect all training systems. "
2931+
)
29292932
doc_batch_size = f'This key can be \n\n\
29302933
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
29312934
- int: all {link_sys} use the same batch size.\n\n\
@@ -2949,6 +2952,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
29492952
Argument(
29502953
"systems", [list[str], str], optional=False, default=".", doc=doc_systems
29512954
),
2955+
Argument(
2956+
"rglob_patterns",
2957+
[list[str]],
2958+
optional=True,
2959+
default=None,
2960+
doc=doc_patterns + doc_only_pt_supported,
2961+
),
29522962
Argument(
29532963
"batch_size",
29542964
[list[int], int, str],
@@ -2995,6 +3005,9 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
29953005
"This key can be provided with a list that specifies the systems, or be provided with a string "
29963006
"by which the prefix of all systems are given and the list of the systems is automatically generated."
29973007
)
3008+
doc_patterns = (
3009+
"The customized patterns used in `rglob` to collect all validation systems. "
3010+
)
29983011
doc_batch_size = f'This key can be \n\n\
29993012
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
30003013
- int: all {link_sys} use the same batch size.\n\n\
@@ -3015,6 +3028,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
30153028
Argument(
30163029
"systems", [list[str], str], optional=False, default=".", doc=doc_systems
30173030
),
3031+
Argument(
3032+
"rglob_patterns",
3033+
[list[str]],
3034+
optional=True,
3035+
default=None,
3036+
doc=doc_patterns + doc_only_pt_supported,
3037+
),
30183038
Argument(
30193039
"batch_size",
30203040
[list[int], int, str],

deepmd/utils/data_system.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from deepmd.common import (
1818
expand_sys_str,
1919
make_default_mesh,
20+
rglob_sys_str,
2021
)
2122
from deepmd.env import (
2223
GLOBAL_NP_FLOAT_PRECISION,
@@ -730,7 +731,9 @@ def prob_sys_size_ext(keywords, nsystems, nbatch):
730731
return sys_probs
731732

732733

733-
def process_systems(systems: Union[str, list[str]]) -> list[str]:
734+
def process_systems(
735+
systems: Union[str, list[str]], patterns: Optional[list[str]] = None
736+
) -> list[str]:
734737
"""Process the user-input systems.
735738
736739
If it is a single directory, search for all the systems in the directory.
@@ -740,14 +743,19 @@ def process_systems(systems: Union[str, list[str]]) -> list[str]:
740743
----------
741744
systems : str or list of str
742745
The user-input systems
746+
patterns : list of str, optional
747+
The patterns to match the systems, by default None
743748
744749
Returns
745750
-------
746751
list of str
747752
The valid systems
748753
"""
749754
if isinstance(systems, str):
750-
systems = expand_sys_str(systems)
755+
if patterns is None:
756+
systems = expand_sys_str(systems)
757+
else:
758+
systems = rglob_sys_str(systems, patterns)
751759
elif isinstance(systems, list):
752760
systems = systems.copy()
753761
return systems
@@ -777,7 +785,8 @@ def get_data(
777785
The data system
778786
"""
779787
systems = jdata["systems"]
780-
systems = process_systems(systems)
788+
rglob_patterns = jdata.get("rglob_patterns", None)
789+
systems = process_systems(systems, patterns=rglob_patterns)
781790

782791
batch_size = jdata["batch_size"]
783792
sys_probs = jdata.get("sys_probs", None)

source/tests/pt/test_training.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,5 +516,28 @@ def tearDown(self) -> None:
516516
shutil.rmtree(f)
517517

518518

519+
class TestCustomizedRGLOB(unittest.TestCase, DPTrainTest):
    """Training test whose data systems are selected through `rglob_patterns`."""

    def setUp(self) -> None:
        """Load the water `se_atten` config and point both datasets at glob patterns."""
        tests_dir = Path(__file__).parent
        with open(str(tests_dir / "water/se_atten.json")) as fp:
            self.config = json.load(fp)

        # Both datasets use this test directory as the root; the patterns
        # then select the system directories beneath it.
        data_root = str(tests_dir)
        training_cfg = self.config["training"]

        train_data = training_cfg["training_data"]
        train_data["rglob_patterns"] = ["water/data/data_*"]
        train_data["systems"] = data_root

        valid_data = training_cfg["validation_data"]
        valid_data["rglob_patterns"] = ["water/*/data_0"]
        valid_data["systems"] = data_root

        self.config["model"] = deepcopy(model_dpa1)
        # A single step with a checkpoint at every step keeps the run short.
        training_cfg["numb_steps"] = 1
        training_cfg["save_freq"] = 1

    def tearDown(self) -> None:
        """Delegate cleanup to the shared `DPTrainTest` teardown."""
        DPTrainTest.tearDown(self)
541+
519542
if __name__ == "__main__":
520543
unittest.main()

0 commit comments

Comments
 (0)