Out of memory: How to optimize memory use in the Diffrax framework? #685

@timnotavailable

Hello,
I wrote a simulator for an ODE system (at least 256x256 ODEs) with max_steps=1000, solver Tsit5, and adjoint=diffrax.RecursiveCheckpointAdjoint(). However, I get this error: Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate xxxx bytes.

I'm new to neural ODEs and would appreciate some advice on how to reduce memory usage. (From previous issues, it seems I should perhaps reduce max_steps.) What about the solver? Should I use a simpler one, such as a 2nd- or 3rd-order Runge-Kutta solver? Are there any other suggestions for optimizing GPU usage?

Does the adjoint method influence memory usage? According to https://docs.kidger.site/diffrax/api/adjoints/, one can use max_steps or the checkpoints argument of the RecursiveCheckpointAdjoint class to control memory usage. Would a different adjoint method use less memory? A past post suggested RecursiveCheckpointAdjoint because it has been optimized to O(log n); is any other method more memory-efficient?

If the time grid (rastering) of the ODE is known in advance, would ConstantStepSize be better than an adaptive step size controller?

Thanks for your brilliant library!
