Commit cefce47

Fix: support "max:N" and "filter:N" batch_size rules in DeepmdDataSystem (#4876)
## Problem

- Using `batch_size: "max:..."` or `"filter:..."` in configs caused `RuntimeError: unknown batch_size rule max` on the PyTorch path (neighbor statistics).
- The docs mention these rules and the PyTorch `DpLoaderSet` already supports them, so behavior was inconsistent across layers.

## Cause

- The common data layer `DeepmdDataSystem` implemented only `"auto"` and `"mixed"` for string `batch_size`, missing `"max"` and `"filter"`.
- PT training computes neighbor statistics via `DeepmdDataSystem` before the actual training, so it failed early when those rules were used.

## Fix

- Implement `"max:N"` and `"filter:N"` in `DeepmdDataSystem.__init__` to mirror the `DpLoaderSet` semantics:
  - `max:N`: per-system `batch_size = max(1, N // natoms)`, so that `batch_size * natoms <= N`.
  - `filter:N`: drop systems with `natoms > N` (warn if any are removed; error if none are left), then set the per-system `batch_size` as in `max:N`.
- After filtering, update `self.data_systems`, `self.system_dirs`, and `self.nsystems` before computing other metadata.

## Impact

- Aligns the common-layer behavior with the PyTorch `DpLoaderSet` and with the documentation.
- Prevents PT neighbor-stat crashes for configs using `"max"`/`"filter"`.

## Compatibility

- No change to numeric `batch_size` or to the existing `"auto"`/`"auto:N"`/`"mixed:N"` rules.
- TF/PT/PD paths now accept the same `batch_size` rules consistently in the common layer.

## Files Changed

- `deepmd/utils/data_system.py`: add parsing branches for `"max:N"` and `"filter:N"` in `DeepmdDataSystem.__init__`.
```python
elif "max" == words[0]:
    # Determine batch size so that batch_size * natoms <= rule, at least 1
    if len(words) != 2:
        raise RuntimeError("batch size must be specified for max systems")
    rule = int(words[1])
    bs = []
    for ii in self.data_systems:
        ni = ii.get_natoms()
        bsi = rule // ni
        if bsi == 0:
            bsi = 1
        bs.append(bsi)
    self.batch_size = bs
elif "filter" == words[0]:
    # Remove systems with natoms > rule, then set batch size like "max:rule"
    if len(words) != 2:
        raise RuntimeError("batch size must be specified for filter systems")
    rule = int(words[1])
    filtered_data_systems = []
    filtered_system_dirs = []
    for sys_dir, data_sys in zip(self.system_dirs, self.data_systems):
        if data_sys.get_natoms() <= rule:
            filtered_data_systems.append(data_sys)
            filtered_system_dirs.append(sys_dir)
    if len(filtered_data_systems) == 0:
        raise RuntimeError(
            f"No system left after removing systems with more than {rule} atoms"
        )
    if len(filtered_data_systems) != len(self.data_systems):
        warnings.warn(
            f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms"
        )
    self.data_systems = filtered_data_systems
    self.system_dirs = filtered_system_dirs
    self.nsystems = len(self.data_systems)
    bs = []
    for ii in self.data_systems:
        ni = ii.get_natoms()
        bsi = rule // ni
        if bsi == 0:
            bsi = 1
        bs.append(bsi)
    self.batch_size = bs
```

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
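The two rules can be illustrated outside the class. The sketch below is a hypothetical standalone reimplementation of the `max:N`/`filter:N` semantics described above; the function names `apply_max_rule` and `apply_filter_rule` are illustrative and not part of DeePMD-kit:

```python
import warnings


def apply_max_rule(natoms_per_system: list[int], rule: int) -> list[int]:
    # "max:N": largest batch size with batch_size * natoms <= N, at least 1.
    return [max(1, rule // n) for n in natoms_per_system]


def apply_filter_rule(
    natoms_per_system: list[int], rule: int
) -> tuple[list[int], list[int]]:
    # "filter:N": keep only systems with natoms <= N,
    # then size the remaining batches as in "max:N".
    kept = [n for n in natoms_per_system if n <= rule]
    if not kept:
        raise RuntimeError(
            f"No system left after removing systems with more than {rule} atoms"
        )
    if len(kept) != len(natoms_per_system):
        warnings.warn(
            f"Remove {len(natoms_per_system) - len(kept)} systems "
            f"with more than {rule} atoms"
        )
    return kept, apply_max_rule(kept, rule)


# Three systems of 16, 64, and 200 atoms under a budget of 128 atoms per batch:
print(apply_max_rule([16, 64, 200], 128))  # [8, 2, 1] -- the 200-atom system is clamped to 1
kept, bs = apply_filter_rule([16, 64, 200], 128)
print(kept, bs)  # [16, 64] [8, 2] -- the 200-atom system is dropped with a warning
```

Note the key behavioral difference: `max:N` keeps oversized systems with `batch_size = 1` (so a batch may exceed `N` atoms), while `filter:N` removes them entirely.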
1 parent 88b71e8 commit cefce47


deepmd/utils/data_system.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -152,6 +152,51 @@ def __init__(
                 else:
                     raise RuntimeError("batch size must be specified for mixed systems")
                 self.batch_size = rule * np.ones(self.nsystems, dtype=int)
+            elif "max" == words[0]:
+                # Determine batch size so that batch_size * natoms <= rule, at least 1
+                if len(words) != 2:
+                    raise RuntimeError("batch size must be specified for max systems")
+                rule = int(words[1])
+                bs = []
+                for ii in self.data_systems:
+                    ni = ii.get_natoms()
+                    bsi = rule // ni
+                    if bsi == 0:
+                        bsi = 1
+                    bs.append(bsi)
+                self.batch_size = bs
+            elif "filter" == words[0]:
+                # Remove systems with natoms > rule, then set batch size like "max:rule"
+                if len(words) != 2:
+                    raise RuntimeError(
+                        "batch size must be specified for filter systems"
+                    )
+                rule = int(words[1])
+                filtered_data_systems = []
+                filtered_system_dirs = []
+                for sys_dir, data_sys in zip(self.system_dirs, self.data_systems):
+                    if data_sys.get_natoms() <= rule:
+                        filtered_data_systems.append(data_sys)
+                        filtered_system_dirs.append(sys_dir)
+                if len(filtered_data_systems) == 0:
+                    raise RuntimeError(
+                        f"No system left after removing systems with more than {rule} atoms"
+                    )
+                if len(filtered_data_systems) != len(self.data_systems):
+                    warnings.warn(
+                        f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms"
+                    )
+                self.data_systems = filtered_data_systems
+                self.system_dirs = filtered_system_dirs
+                self.nsystems = len(self.data_systems)
+                bs = []
+                for ii in self.data_systems:
+                    ni = ii.get_natoms()
+                    bsi = rule // ni
+                    if bsi == 0:
+                        bsi = 1
+                    bs.append(bsi)
+                self.batch_size = bs
             else:
                 raise RuntimeError("unknown batch_size rule " + words[0])
         elif isinstance(self.batch_size, list):
```
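For reference, a hypothetical training-config fragment exercising the new rules might look like the following (the surrounding keys are abbreviated and the system paths are placeholders, not taken from this commit):

```json
{
    "training": {
        "training_data": {
            "systems": ["data/system_small", "data/system_large"],
            "batch_size": "max:128"
        }
    }
}
```

Swapping `"max:128"` for `"filter:128"` would additionally drop any listed system with more than 128 atoms instead of clamping its batch size to 1.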
