Skip to content

Commit a30e664

Browse files
authored
feat: add full Nx backend (#11)
* feat: add full backend * feat: run nx test suite * feat: set default device * fix: recover from partical iree clones * Update cmake/src/runtime.cc
1 parent a520f9e commit a30e664

File tree

24 files changed

+675
-264
lines changed

24 files changed

+675
-264
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ endif
2323
.PHONY: clone_iree
2424
clone_iree: $(NX_IREE_SOURCE_DIR)
2525

26-
$(NX_IREE_SOURCE_DIR):
26+
.PHONY: iree_source_dir
27+
iree_source_dir:
2728
./scripts/clone_iree.sh $(IREE_GIT_REV) $(NX_IREE_SOURCE_DIR)
2829

2930
IREE_CMAKE_BUILD_DIR ?= $(abspath iree-runtime/iree-build)
@@ -106,7 +107,7 @@ install_runtime: iree_host $(IREE_INSTALL_DIR)
106107

107108
CMAKE_SOURCES = $(abspath cmake/src/runtime.cc) $(abspath cmake/src/runtime.h)
108109

109-
$(IREE_INSTALL_DIR): $(NX_IREE_SOURCE_DIR) $(CMAKE_SOURCES)
110+
$(IREE_INSTALL_DIR): iree_source_dir $(CMAKE_SOURCES)
110111
cmake -G Ninja -B $(IREE_CMAKE_BUILD_DIR) \
111112
-DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)\
112113
-DIREE_BUILD_COMPILER=OFF\

c_src/nx_iree.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ ERL_NIF_TERM error(ErlNifEnv* env, const char* error) {
1313
return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_string(env, error, ERL_NIF_LATIN1));
1414
}
1515

16+
ERL_NIF_TERM ok(ErlNifEnv* env) {
17+
return enif_make_atom(env, "ok");
18+
}
19+
1620
ERL_NIF_TERM ok(ErlNifEnv* env, ERL_NIF_TERM term) {
1721
return enif_make_tuple2(env, enif_make_atom(env, "ok"), term);
1822
}
@@ -233,7 +237,8 @@ DECLARE_NIF(list_devices) {
233237
auto ref_term = make<iree_hal_device_t*>(env, device->ref);
234238
auto driver_name_term = enif_make_string(env, device->driver_name.c_str(), ERL_NIF_LATIN1);
235239
auto uri_term = enif_make_string(env, device->uri.c_str(), ERL_NIF_LATIN1);
236-
auto tuple = enif_make_tuple3(env, ref_term, driver_name_term, uri_term);
240+
auto id_term = enif_make_uint64(env, device->id);
241+
auto tuple = enif_make_tuple4(env, ref_term, driver_name_term, uri_term, id_term);
237242
device_terms.push_back(tuple);
238243
}
239244

@@ -315,13 +320,13 @@ iree_hal_element_type_t nx_type_to_iree_type(std::string type) {
315320
return type_enum::IREE_HAL_ELEMENT_TYPE_INT_32;
316321
} else if (type == "i64") {
317322
return type_enum::IREE_HAL_ELEMENT_TYPE_INT_64;
318-
} else if (type == "u8") {
323+
} else if (type == "ui8") {
319324
return type_enum::IREE_HAL_ELEMENT_TYPE_UINT_8;
320-
} else if (type == "u16") {
325+
} else if (type == "ui16") {
321326
return type_enum::IREE_HAL_ELEMENT_TYPE_UINT_16;
322-
} else if (type == "u32") {
327+
} else if (type == "ui32") {
323328
return type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32;
324-
} else if (type == "u64") {
329+
} else if (type == "ui64") {
325330
return type_enum::IREE_HAL_ELEMENT_TYPE_UINT_64;
326331
} else if (type == "bf16") {
327332
return type_enum::IREE_HAL_ELEMENT_TYPE_BFLOAT_16;
@@ -355,14 +360,10 @@ DECLARE_NIF(read_buffer_nif) {
355360
return error(env, "invalid num_bytes");
356361
}
357362

358-
std::cout << "num_bytes input: " << num_bytes << std::endl;
359-
360363
if (num_bytes == -1) {
361364
num_bytes = (*input)->size;
362365
}
363366

364-
std::cout << "num_bytes actual: " << num_bytes << std::endl;
365-
366367
ErlNifBinary binary;
367368

368369
if (!enif_alloc_binary(num_bytes, &binary)) {
@@ -418,6 +419,22 @@ DECLARE_NIF(allocate_buffer) {
418419
return ok(env, make<iree::runtime::IREETensor*>(env, input));
419420
}
420421

422+
DECLARE_NIF(deallocate_buffer) {
423+
if (argc != 1) {
424+
return error(env, "invalid number of arguments");
425+
}
426+
427+
iree::runtime::IREETensor** input;
428+
429+
if (!get<iree::runtime::IREETensor*>(env, argv[0], input)) {
430+
return error(env, "invalid input");
431+
}
432+
433+
(*input)->deallocate();
434+
435+
return ok(env);
436+
}
437+
421438
DECLARE_NIF(serialize_tensor) {
422439
if (argc != 1) {
423440
return error(env, "invalid number of arguments");
@@ -510,6 +527,7 @@ static ErlNifFunc funcs[] = {
510527
{"list_devices", 1, list_devices},
511528
{"list_devices", 2, list_devices},
512529
{"list_drivers", 1, list_drivers},
530+
{"deallocate_buffer", 1, deallocate_buffer},
513531
{"allocate_buffer", 4, allocate_buffer},
514532
{"serialize_tensor", 1, serialize_tensor},
515533
{"deserialize_tensor", 1, deserialize_tensor},

cmake/src/runtime.cc

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,24 @@ iree::runtime::IREETensor::IREETensor(char *buffer) {
7070
this->buffer_view = nullptr;
7171
}
7272

73-
std::vector<char> *iree::runtime::IREETensor::serialize() {
73+
iree::runtime::IREETensor::~IREETensor() {
74+
this->deallocate();
75+
}
76+
77+
void iree::runtime::IREETensor::deallocate() {
78+
if (data) {
79+
std::free(data);
80+
data = nullptr;
81+
}
82+
83+
if (buffer_view) {
84+
iree_hal_buffer_view_release(buffer_view);
85+
buffer_view = nullptr;
86+
}
87+
}
88+
89+
std::vector<char> *
90+
iree::runtime::IREETensor::serialize() {
7491
auto buffer = new std::vector<char>();
7592

7693
// Serialize 'type'
@@ -196,16 +213,6 @@ iree_status_t list_devices(iree_hal_driver_registry_t *registry, std::string dri
196213
return status;
197214
}
198215

199-
auto out_device = new iree::runtime::Device(driver_name);
200-
status = iree_hal_driver_create_default_device(driver, iree_allocator_system(),
201-
&out_device->ref);
202-
if (!iree_status_is_ok(status)) {
203-
return status;
204-
}
205-
206-
out_device->uri = driver_name + "://default";
207-
devices.push_back(out_device);
208-
209216
status = iree_hal_driver_query_available_devices(
210217
driver, iree_allocator_system(), &device_info_count, &device_infos);
211218

@@ -217,7 +224,9 @@ iree_status_t list_devices(iree_hal_driver_registry_t *registry, std::string dri
217224
for (size_t i = 0; i < device_info_count; i++) {
218225
auto device = new iree::runtime::Device(driver_name);
219226
auto info = device_infos[i];
220-
device->uri = driver_name + "://" + std::string(info.path.data, info.path.size);
227+
std::string device_urn(info.path.data, info.path.size);
228+
device->uri = driver_name + "://" + device_urn;
229+
device->id = info.device_id;
221230

222231
status = iree_hal_create_device(
223232
registry,

cmake/src/runtime.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Device {
2929
std::string uri;
3030
iree_hal_device_t* ref;
3131
std::string driver_name;
32+
uint64_t id;
3233

3334
Device(std::string driver_name) : driver_name(driver_name) {}
3435
~Device();
@@ -48,12 +49,9 @@ class IREETensor {
4849
IREETensor(void* data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type);
4950

5051
// Destructor
51-
~IREETensor() {
52-
if (data) {
53-
std::free(data);
54-
data = nullptr;
55-
}
56-
}
52+
~IREETensor();
53+
54+
void deallocate();
5755

5856
// Disable copy and move semantics for simplicity
5957
IREETensor(const IREETensor&) = delete;

compiler.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, i
1818

1919
f = Nx.Defn.compile(fun, args)
2020

21-
Nx.default_backend(NxIREE.Tensor)
21+
Nx.default_backend(NxIREE.Backend)
2222
arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
2323
arg1 = Nx.tensor([1, -1, 1, -1])
2424
f.(arg0, arg1) |> dbg()

config/config.exs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import Config
2+
3+
import_config "#{config_env()}.exs"

config/dev.exs

Whitespace-only changes.

config/prod.exs

Whitespace-only changes.

config/test.exs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import Config
2+
3+
config :nx_iree, :add_backend_on_inspect, false

embedded_devices/live_nx_iree/config/runtime.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if System.get_env("PHX_SERVER") do
2020
config :live_nx_iree, LiveNxIREEWeb.Endpoint, server: true
2121
end
2222

23-
config :nx, :default_backend, NxIREE.Tensor
23+
config :nx, :default_backend, NxIREE.Backend
2424

2525
if config_env() == :prod do
2626
database_url =

0 commit comments

Comments
 (0)