Skip to content

Commit a520f9e

Browse files
committed
fix: cuda and linux support
1 parent da1cf45 commit a520f9e

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

Makefile

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ IREE_BUILD_TARGET ?= host
3737

3838
BUILD_TARGET_FLAGS = -S $(abspath cmake)
3939

40+
CUDA_PRESENT := $(shell command -v nvcc >/dev/null 2>&1 && echo true || echo false)
41+
42+
ifeq ($(CUDA_PRESENT), true)
43+
CFLAGS += -DCUDA_ENABLED
44+
CMAKE_CXX_FLAGS += -DCUDA_ENABLED
45+
endif
46+
4047
# flags for xcode 15.4
4148
ifeq ($(IREE_BUILD_TARGET), host)
4249
else ifeq ($(IREE_BUILD_TARGET), ios)
@@ -106,6 +113,7 @@ $(IREE_INSTALL_DIR): $(NX_IREE_SOURCE_DIR) $(CMAKE_SOURCES)
106113
-DIREE_RUNTIME_BUILD_DIR=$(IREE_RUNTIME_BUILD_DIR)\
107114
-DIREE_RUNTIME_INCLUDE_PATH=$(IREE_RUNTIME_INCLUDE_PATH)\
108115
-DNX_IREE_SOURCE_DIR=$(NX_IREE_SOURCE_DIR) \
116+
-DCMAKE_CXX_FLAGS=$(CMAKE_CXX_FLAGS) \
109117
$(BUILD_TARGET_FLAGS)
110118

111119
cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG)
@@ -119,6 +127,7 @@ iree_host:
119127
-DCMAKE_INSTALL_PREFIX=$(IREE_HOST_INSTALL_DIR) \
120128
-DIREE_BUILD_COMPILER=OFF\
121129
-DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG) \
130+
-DCMAKE_CXX_FLAGS=$(CMAKE_CXX_FLAGS) \
122131
-S $(NX_IREE_SOURCE_DIR)
123132
cmake --build $(IREE_HOST_BUILD_DIR) --target install
124133
else
@@ -132,8 +141,9 @@ NX_IREE_SO ?= $(MIX_APP_PATH)/priv/libnx_iree.so
132141
NX_IREE_CACHE_SO ?= cache/libnx_iree.so
133142
NX_IREE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(NX_IREE_CACHE_SO)
134143

135-
NX_IREE_RUNTIME_LIB = cache/iree-runtime/
144+
NX_IREE_RUNTIME_LIB = cache/iree-runtime
136145
NX_IREE__IREE_RUNTIME_INCLUDE_PATH = $(NX_IREE_RUNTIME_LIB)/include
146+
NX_IREE_RUNTIME_SO ?= $(MIX_APP_PATH)/priv/libnx_iree_runtime.so
137147

138148
CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(NX_IREE__IREE_RUNTIME_INCLUDE_PATH) -Wall -Wno-sign-compare \
139149
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
@@ -157,14 +167,8 @@ else
157167
LDFLAGS += -Wl,-rpath,'$$ORIGIN/iree-runtime'
158168
endif
159169

160-
CUDA_PRESENT := $(shell command -v nvcc >/dev/null 2>&1 && echo true || echo false)
161-
162-
ifeq ($(CUDA_PRESENT), true)
163-
CFLAGS += -DCUDA_ENABLED
164-
endif
165-
166170
NX_IREE_LIB_DIR = $(MIX_APP_PATH)/priv/iree-runtime
167-
NX_IREE_LIB_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(NX_IREE_RUNTIME_LIB)
171+
NX_IREE_LIB_LINK_PATH = $(abspath $(NX_IREE_RUNTIME_LIB))
168172
NX_IREE_CACHE_SO_LINK_PATH = $(NX_IREE_CACHE_SO)
169173

170174
SOURCES = $(wildcard c_src/*.cc)

axon.exs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ Mix.install([
55
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
66
], system_env: %{"NX_IREE_PREFER_PRECOMPILED" => false})
77

8-
NxIREE.list_drivers() |> IO.inspect(label: "drivers")
8+
{:ok, drivers} = NxIREE.list_drivers() |> IO.inspect(label: "drivers")
99

10-
{:ok, [dev | _]} = NxIREE.list_devices("metal")
10+
{:ok, [dev | _]} = NxIREE.list_devices("cuda")
1111

12-
flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
12+
flags = ["--iree-hal-target-backends=cuda", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
1313
Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
1414

1515
model =

0 commit comments

Comments
 (0)