Skip to content

Commit aadbeae

Browse files
Fixed tests OOM'ing/crashing/etc.
1 parent c5d301a commit aadbeae

File tree

4 files changed

+21
-34
lines changed

4 files changed

+21
-34
lines changed

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ jobs:
1414
with:
1515
python-version: "3.11"
1616
test-script: |
17-
python -m pip install pytest psutil jax jaxlib equinox scipy optax
17+
python -m pip install pytest jax jaxlib equinox scipy optax
1818
cp -r ${{ github.workspace }}/test ./test
19-
pytest
19+
python -m test
2020
pypi-token: ${{ secrets.pypi_token }}
2121
github-user: patrick-kidger
2222
github-token: ${{ github.token }}

.github/workflows/run_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install pytest psutil wheel scipy numpy optax jaxlib
26+
python -m pip install pytest wheel scipy numpy optax jaxlib
2727
2828
- name: Checks with pre-commit
2929
uses: pre-commit/action@v2.0.3
3030

3131
- name: Test with pytest
3232
run: |
3333
python -m pip install .
34-
python -m pytest --durations=0
34+
python -m test

test/__main__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pathlib
2+
import subprocess
3+
import sys
4+
5+
6+
here = pathlib.Path(__file__).resolve().parent
7+
8+
9+
# Each file is ran separately to avoid out-of-memorying.
10+
running_out = 0
11+
for file in here.iterdir():
12+
if file.is_file() and file.name.startswith("test"):
13+
out = subprocess.run(f"pytest {file}", shell=True).returncode
14+
running_out = max(running_out, out)
15+
sys.exit(running_out)

test/conftest.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
import gc
21
import random
3-
import sys
42

53
import jax
64
import jax.config
7-
import jax.random as jrandom
8-
import psutil
5+
import jax.random as jr
96
import pytest
107

118

@@ -16,31 +13,6 @@
1613
def getkey():
1714
def _getkey():
1815
# Not sure what the maximum actually is but this will do
19-
return jrandom.PRNGKey(random.randint(0, 2**31 - 1))
16+
return jr.PRNGKey(random.randint(0, 2**31 - 1))
2017

2118
return _getkey
22-
23-
24-
# Hugely hacky way of reducing memory usage in tests.
25-
# JAX can be a little over-happy with its caching; this is especially noticable when
26-
# performing tests and therefore doing an unusual amount of compilation etc.
27-
# This can be enough to exceed the 8GB RAM available to Ubuntu instances on GitHub
28-
# Actions.
29-
@pytest.fixture(autouse=True)
30-
def clear_caches():
31-
process = psutil.Process()
32-
if process.memory_info().vms > 4 * 2**30: # >4GB memory usage
33-
jax.clear_backends()
34-
for module_name, module in sys.modules.copy().items():
35-
if module_name.startswith("jax"):
36-
if module_name not in ["jax.interpreters.partial_eval"]:
37-
for obj_name in dir(module):
38-
obj = getattr(module, obj_name)
39-
if hasattr(obj, "cache_clear"):
40-
try:
41-
print(f"Clearing {obj}")
42-
if "Weakref" not in type(obj).__name__:
43-
obj.cache_clear()
44-
except Exception:
45-
pass
46-
gc.collect()

0 commit comments

Comments
 (0)