diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 861187e0..ef91bf3f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: test-script: | cp -r ${{ github.workspace }}/test ./test cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml - python -m pip install -r ./test/requirements.txt + python -m pip install '.[tests]' python -m test pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index b209bb3d..f6f10c8c 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -23,8 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r ./test/requirements.txt - + python -m pip install '.[tests]' - name: Checks with pre-commit uses: pre-commit/action@v3.0.1 @@ -33,3 +32,26 @@ jobs: run: | python -m pip install . python -m test + + # Run a test with JAX tracer leak detection enabled + run-test-tracer: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python 3.13 + uses: actions/setup-python@v2 + with: + python-version: "3.13" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install '.[tests]' + + - name: Test tracer functionality + run: | + python -m pip install . + export JAX_CHECK_TRACER_LEAKS=1 + python -m test --tracer diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1c9b3ced..be5355ab 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,7 +34,7 @@ Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass: ```bash -pip install -r test/requirements.txt +pip install -e '.[tests]' pytest ``` diff --git a/pyproject.toml b/pyproject.toml index e7326e29..1da93de9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,14 @@ docs = [ "mkdocstrings-python==1.16.8", "pymdown-extensions==10.14.3" ] +tests = [ + "beartype>=0.22.5", + "jaxlib>=0.6.2", + "optax>=0.2.6", + "pytest>=9.0.1", + "scipy>=1.15.3", + "tqdm>=4.67.1", +] [tool.hatch.build] include = ["diffrax/*"] diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index 9de88eb6..00000000 --- a/test/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -beartype -jaxlib -optax -pytest -scipy -tqdm