File tree Expand file tree Collapse file tree 4 files changed +21
-34
lines changed
Expand file tree Collapse file tree 4 files changed +21
-34
lines changed Original file line number Diff line number Diff line change 1414 with :
1515 python-version : " 3.11"
1616 test-script : |
17- python -m pip install pytest psutil jax jaxlib equinox scipy optax
17+ python -m pip install pytest jax jaxlib equinox scipy optax
1818 cp -r ${{ github.workspace }}/test ./test
19- pytest
19+ python -m test
2020 pypi-token : ${{ secrets.pypi_token }}
2121 github-user : patrick-kidger
2222 github-token : ${{ github.token }}
Original file line number Diff line number Diff line change @@ -23,12 +23,12 @@ jobs:
2323 - name : Install dependencies
2424 run : |
2525 python -m pip install --upgrade pip
26- python -m pip install pytest psutil wheel scipy numpy optax jaxlib
26+ python -m pip install pytest wheel scipy numpy optax jaxlib
2727
2828 - name : Checks with pre-commit
2929 uses : pre-commit/action@v2.0.3
3030
3131 - name : Test with pytest
3232 run : |
3333 python -m pip install .
34- python -m pytest --durations=0
34+ python -m test
Original file line number Diff line number Diff line change 1+ import pathlib
2+ import subprocess
3+ import sys
4+
5+
6+ here = pathlib .Path (__file__ ).resolve ().parent
7+
8+
9+ # Each file is ran separately to avoid out-of-memorying.
10+ running_out = 0
11+ for file in here .iterdir ():
12+ if file .is_file () and file .name .startswith ("test" ):
13+ out = subprocess .run (f"pytest { file } " , shell = True ).returncode
14+ running_out = max (running_out , out )
15+ sys .exit (running_out )
Original file line number Diff line number Diff line change 1- import gc
21import random
3- import sys
42
53import jax
64import jax .config
7- import jax .random as jrandom
8- import psutil
5+ import jax .random as jr
96import pytest
107
118
1613def getkey ():
1714 def _getkey ():
1815 # Not sure what the maximum actually is but this will do
19- return jrandom .PRNGKey (random .randint (0 , 2 ** 31 - 1 ))
16+ return jr .PRNGKey (random .randint (0 , 2 ** 31 - 1 ))
2017
2118 return _getkey
22-
23-
24- # Hugely hacky way of reducing memory usage in tests.
25- # JAX can be a little over-happy with its caching; this is especially noticable when
26- # performing tests and therefore doing an unusual amount of compilation etc.
27- # This can be enough to exceed the 8GB RAM available to Ubuntu instances on GitHub
28- # Actions.
29- @pytest .fixture (autouse = True )
30- def clear_caches ():
31- process = psutil .Process ()
32- if process .memory_info ().vms > 4 * 2 ** 30 : # >4GB memory usage
33- jax .clear_backends ()
34- for module_name , module in sys .modules .copy ().items ():
35- if module_name .startswith ("jax" ):
36- if module_name not in ["jax.interpreters.partial_eval" ]:
37- for obj_name in dir (module ):
38- obj = getattr (module , obj_name )
39- if hasattr (obj , "cache_clear" ):
40- try :
41- print (f"Clearing { obj } " )
42- if "Weakref" not in type (obj ).__name__ :
43- obj .cache_clear ()
44- except Exception :
45- pass
46- gc .collect ()
You can’t perform that action at this time.
0 commit comments