Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dependencies = [
"jax>=0.6.0",
]

[project.entry-points.jax_plugins]
mpibackend4jax = "mpibackend4jax"

[tool.hatch.build.targets.wheel]
packages = ["src/mpibackend4jax"]

Expand Down
67 changes: 34 additions & 33 deletions src/mpibackend4jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,40 @@
from pathlib import Path

# Import the cluster to register it automatically
from .mpitrampoline_cluster import MPITrampolineLocalCluster
# from .mpitrampoline_cluster import MPITrampolineLocalCluster

__version__ = "0.1.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__version__ = "0.1.0"
__version__ = "0.1.1"


# Get the package installation directory
_package_dir = Path(__file__).parent
_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so"

# Set environment variables for MPITrampoline
if _mpiwrapper_lib.exists():
os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute())
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"

print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}")
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi")
else:
print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}")
print("Please ensure the package was installed correctly.")


# Convenience function to check if MPITrampoline is properly configured
def is_configured():
"""Check if MPITrampoline is properly configured for JAX"""
return (
"MPITRAMPOLINE_LIB" in os.environ
and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi"
and Path(os.environ["MPITRAMPOLINE_LIB"]).exists()
)


def get_library_path():
"""Get the path to the MPIWrapper library"""
return os.environ.get("MPITRAMPOLINE_LIB")


__all__ = ["is_configured", "get_library_path", "MPITrampolineLocalCluster"]
def initialize():
# Get the package installation directory
_package_dir = Path(__file__).parent
_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so"

# Set environment variables for MPITrampoline
if _mpiwrapper_lib.exists():
os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute())
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"

print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}")
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi")
else:
print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}")
print("Please ensure the package was installed correctly.")


# # Convenience function to check if MPITrampoline is properly configured
# def is_configured():
# """Check if MPITrampoline is properly configured for JAX"""
# return (
# "MPITRAMPOLINE_LIB" in os.environ
# and os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") == "mpi"
# and Path(os.environ["MPITRAMPOLINE_LIB"]).exists()
# )
#
#
# def get_library_path():
# """Get the path to the MPIWrapper library"""
# return os.environ.get("MPITRAMPOLINE_LIB")


__all__ = ["is_configured", "get_library_path"]#, "MPITrampolineLocalCluster"]