Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions torchx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ to `2.4.0` or later.
If you want torch with CUDA support, please use `LIBTORCH_TARGET` to choose
CUDA versions. The current supported targets are:

- `cpu` default CPU only version
- `cu118` CUDA 11.8 and CPU version (no macOS support, libtorch `< 2.8.0` only)
- `cu126` CUDA 12.6 and CPU version (no macOS support)
- `cu128` CUDA 12.8 and CPU version (no macOS support)
- `cu129` CUDA 12.9 and CPU version (no macOS support, libtorch `>= 2.8.0` only)
| Torch Version | Supported Targets |
| --------------- | -------------------------------------------------------------- |
| 2.7.0 | cpu, cu118 (CUDA 11.8), cu126 (CUDA 12.6), cu128 (CUDA 12.8) |
| 2.8.0 | cpu, cu126 (CUDA 12.6), cu128 (CUDA 12.8), cu129 (CUDA 12.9) |
| 2.9.0 | cpu, cu126 (CUDA 12.6), cu128 (CUDA 12.8), cu130 (CUDA 13.0) |

Once downloaded, we will compile `Torchx` bindings. You will need `make`/`nmake`,
`cmake` (3.12+) and a `C++` compiler. If building on Windows, you will need:
Expand Down
23 changes: 10 additions & 13 deletions torchx/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,19 @@ defmodule Torchx.MixProject do
end

defp libtorch_config() do
target = System.get_env("LIBTORCH_TARGET", "cpu")
version = System.get_env("LIBTORCH_VERSION", "2.8.0")
env_dir = System.get_env("LIBTORCH_DIR")

# 2.8.0 is the first version that supports cu129 and drops cu118
# cu118 might still be needed for older hardware, so we're keeping it
# for now.
valid_targets = ["cpu", "cu118", "cu126", "cu128"]

# Supported targets for each LibTorch version:
valid_targets =
case Version.parse(version) do
{:ok, parsed} ->
if Version.match?(parsed, "<= 2.7.0") do
valid_targets
else
(valid_targets -- ["cu118"]) ++ ["cu129"]
cond do
Version.match?(parsed, "< 2.8.0") ->
["cpu", "cu118", "cu126", "cu128"]

Version.match?(parsed, "< 2.9.0") ->
["cpu", "cu126", "cu129"]

true ->
["cpu", "cu126", "cu128", "cu130"]
end

_ ->
Expand Down