Commit cefce47
Fix: support "max:N" and "filter:N" batch_size rules in DeepmdDataSystem (#4876)
# Fix: support "max:N" and "filter:N" batch_size rules in DeepmdDataSystem
## Problem
- Using `batch_size: "max:..."` or `"filter:..."` in configs caused:
- `RuntimeError: unknown batch_size rule max` during the PyTorch path
(neighbor statistics).
- Docs mention these rules and PyTorch `DpLoaderSet` already supports
them, so behavior was inconsistent across layers.
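For illustration, a config that hits this path might look like the fragment below, written as a Python dict. This is a sketch: the `training` / `training_data` key layout and the system paths are assumptions, not taken from this commit. The only detail drawn from the patch is that the rule string is split into a name and a number.

```python
# Hypothetical config fragment; key layout and paths are assumptions.
config = {
    "training": {
        "training_data": {
            "systems": ["path/to/system_a", "path/to/system_b"],
            "batch_size": "max:256",
        }
    }
}

# The patch inspects words[0] (rule name) and words[1] (the number N).
rule = config["training"]["training_data"]["batch_size"]
words = rule.split(":")
print(words)
```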
## Cause
- The common data layer `DeepmdDataSystem` only implemented `"auto"` and
`"mixed"` for string `batch_size`, missing `"max"` and `"filter"`.
- PT training performs neighbor statistics via `DeepmdDataSystem` before
real training, so it failed early when those rules were used.
## Fix
- Implement `"max:N"` and `"filter:N"` in `DeepmdDataSystem.__init__` to
mirror `DpLoaderSet` semantics:
- `max:N`: per-system `batch_size = max(1, N // natoms)`, so `batch_size * natoms <= N` whenever `natoms <= N` (larger systems fall back to `batch_size = 1`).
- `filter:N`: drop systems with `natoms > N` (warn if any removed; error
if none left), then set per-system `batch_size` as in `max:N`.
- After filtering, update `self.data_systems`, `self.system_dirs`, and
`self.nsystems` before computing other metadata.
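The two rules can be sketched on plain atom counts, independent of `DeepmdDataSystem`. The helper names `batch_sizes_max` and `batch_sizes_filter` are hypothetical and exist only for this illustration. The arithmetic and the messages follow the patch.

```python
import warnings


def batch_sizes_max(natoms_per_system, rule):
    """Per-system batch size for "max:N": rule // natoms, never below 1."""
    return [max(1, rule // natoms) for natoms in natoms_per_system]


def batch_sizes_filter(natoms_per_system, rule):
    """For "filter:N": drop systems with natoms > rule, then size like "max:N"."""
    kept = [natoms for natoms in natoms_per_system if natoms <= rule]
    if not kept:
        raise RuntimeError(
            f"No system left after removing systems with more than {rule} atoms"
        )
    removed = len(natoms_per_system) - len(kept)
    if removed:
        warnings.warn(f"Remove {removed} systems with more than {rule} atoms")
    return kept, batch_sizes_max(kept, rule)


natoms = [64, 192, 512]
print(batch_sizes_max(natoms, 256))     # [4, 1, 1]
print(batch_sizes_filter(natoms, 256))  # ([64, 192], [4, 1])
```

With `max:256`, the 512-atom system is kept at a batch size of 1, so its batch exceeds 256 atoms. With `filter:256`, that system is dropped and a warning is issued.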
## Impact
- Aligns the common layer behavior with PyTorch `DpLoaderSet` and with
the documentation.
- Prevents PT neighbor-stat crashes with configs using
`"max"`/`"filter"`.
## Compatibility
- No change to numeric `batch_size` or existing
`"auto"/"auto:N"/"mixed:N"` rules.
- TF/PT/PD paths now accept the same `batch_size` rules consistently in
the common layer.
## Files Changed
- `deepmd/utils/data_system.py`: add parsing branches for `"max:N"` and
`"filter:N"` in `DeepmdDataSystem.__init__`.
```python
elif "max" == words[0]:
    # Determine batch size so that batch_size * natoms <= rule, at least 1
    if len(words) != 2:
        raise RuntimeError("batch size must be specified for max systems")
    rule = int(words[1])
    bs = []
    for ii in self.data_systems:
        ni = ii.get_natoms()
        bsi = rule // ni
        if bsi == 0:
            bsi = 1
        bs.append(bsi)
    self.batch_size = bs
elif "filter" == words[0]:
    # Remove systems with natoms > rule, then set batch size like "max:rule"
    if len(words) != 2:
        raise RuntimeError("batch size must be specified for filter systems")
    rule = int(words[1])
    filtered_data_systems = []
    filtered_system_dirs = []
    for sys_dir, data_sys in zip(self.system_dirs, self.data_systems):
        if data_sys.get_natoms() <= rule:
            filtered_data_systems.append(data_sys)
            filtered_system_dirs.append(sys_dir)
    if len(filtered_data_systems) == 0:
        raise RuntimeError(f"No system left after removing systems with more than {rule} atoms")
    if len(filtered_data_systems) != len(self.data_systems):
        warnings.warn(f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms")
    self.data_systems = filtered_data_systems
    self.system_dirs = filtered_system_dirs
    self.nsystems = len(self.data_systems)
    bs = []
    for ii in self.data_systems:
        ni = ii.get_natoms()
        bsi = rule // ni
        if bsi == 0:
            bsi = 1
        bs.append(bsi)
    self.batch_size = bs
```
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added support for "max" and "filter" batch size rules, allowing more
flexible control over batch sizing and filtering of data systems based
on atom counts.
* **Bug Fixes**
* Improved error handling for incorrect batch size string formats and
cases where no systems remain after filtering.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 88b71e8 commit cefce47
1 file changed, +45 -0 lines changed (added lines 155 to 199 of the new file)