Skip to content

Commit 0392d0e

Browse files
authored
feat: add serde for embedding nx iree tensors (#8)
1 parent fe2f093 commit 0392d0e

File tree

7 files changed

+139
-5
lines changed

7 files changed

+139
-5
lines changed

c_src/nx_iree.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,46 @@ DECLARE_NIF(allocate_buffer) {
410410
return ok(env, make<iree::runtime::IREETensor*>(env, input));
411411
}
412412

413+
// NIF: serialize_tensor(tensor_ref).
//
// Takes exactly one argument, a reference to an IREETensor resource, and
// returns (wrapped by the ok() helper) an Erlang binary holding the bytes
// produced by IREETensor::serialize(). Any failure is reported through the
// error() helper with a short message.
DECLARE_NIF(serialize_tensor) {
  if (argc != 1) {
    return error(env, "invalid number of arguments");
  }

  iree::runtime::IREETensor** input;

  if (!get<iree::runtime::IREETensor*>(env, argv[0], input)) {
    return error(env, "invalid input");
  }

  // serialize() returns a heap-allocated vector; this function owns it and
  // must delete it on every exit path, otherwise each call leaks the buffer.
  std::vector<char>* serialized = (*input)->serialize();

  ErlNifBinary binary;

  if (!enif_alloc_binary(serialized->size(), &binary)) {
    delete serialized;
    return error(env, "unable to allocate binary");
  }

  // Copy into the VM-owned binary, then drop the temporary vector.
  std::memcpy(binary.data, serialized->data(), serialized->size());
  delete serialized;

  return ok(env, enif_make_binary(env, &binary));
}
436+
437+
// NIF: deserialize_tensor(binary).
//
// Takes exactly one argument, a binary in the layout written by
// IREETensor::serialize():
//   type | byte size | data | number of dims | dims
// and returns (wrapped by the ok() helper) a reference to a newly allocated
// IREETensor built from it. Any failure is reported through the error()
// helper with a short message.
DECLARE_NIF(deserialize_tensor) {
  if (argc != 1) {
    return error(env, "invalid number of arguments");
  }

  ErlNifBinary input;

  if (!enif_inspect_binary(env, argv[0], &input)) {
    return error(env, "invalid input");
  }

  // The IREETensor(char*) constructor receives no length and trusts the
  // length fields embedded in the buffer. Bounds-check those fields against
  // the real binary size first so a truncated or malformed binary cannot
  // trigger an out-of-bounds read.
  // NOTE(review): assumes the tensor's `size` field is a size_t, matching the
  // size_t constructor parameter it is initialised from -- confirm in runtime.h.
  const char* bytes = reinterpret_cast<const char*>(input.data);
  size_t remaining = input.size;

  const size_t header_size = sizeof(iree_hal_element_type_t) + sizeof(size_t);
  if (remaining < header_size) {
    return error(env, "invalid serialized tensor");
  }

  size_t data_size;
  std::memcpy(&data_size, bytes + sizeof(iree_hal_element_type_t), sizeof(data_size));
  remaining -= header_size;

  // Written as subtractions on `remaining` so a huge data_size cannot overflow.
  size_t num_dims;
  if (data_size > remaining || remaining - data_size < sizeof(num_dims)) {
    return error(env, "invalid serialized tensor");
  }

  std::memcpy(&num_dims, bytes + header_size + data_size, sizeof(num_dims));
  remaining -= data_size + sizeof(num_dims);

  // Division instead of multiplication avoids overflow on a hostile num_dims.
  if (num_dims > remaining / sizeof(iree_hal_dim_t)) {
    return error(env, "invalid serialized tensor");
  }

  auto tensor = new iree::runtime::IREETensor(reinterpret_cast<char*>(input.data));

  return ok(env, make<iree::runtime::IREETensor*>(env, tensor));
}
452+
413453
DECLARE_NIF(call_nif) {
414454
iree_vm_instance_t** instance;
415455
iree_hal_device_t** device;
@@ -463,6 +503,8 @@ static ErlNifFunc funcs[] = {
463503
{"list_devices", 2, list_devices},
464504
{"list_drivers", 1, list_drivers},
465505
{"allocate_buffer", 4, allocate_buffer},
506+
{"serialize_tensor", 1, serialize_tensor},
507+
{"deserialize_tensor", 1, deserialize_tensor},
466508
{"read_buffer", 3, read_buffer_nif},
467509
{"call_io", 5, call_nif, ERL_NIF_DIRTY_JOB_IO_BOUND},
468510
{"call_cpu", 5, call_nif, ERL_NIF_DIRTY_JOB_CPU_BOUND}};

cmake/src/runtime.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ iree::runtime::Device::~Device() {
2828

2929
iree::runtime::IREETensor::IREETensor(iree_hal_buffer_view_t *buffer_view, iree_hal_element_type_t type) : buffer_view(buffer_view), type(type) {
3030
size = iree_hal_buffer_view_byte_length(buffer_view);
31+
// TODO: fill in dim metadata
3132
}
3233

3334
iree::runtime::IREETensor::IREETensor(void *data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type) : size(size), type(type) {
@@ -43,6 +44,55 @@ iree::runtime::IREETensor::IREETensor(void *data, size_t size, std::vector<int64
4344
this->buffer_view = nullptr;
4445
}
4546

47+
// Builds a tensor from a serialized buffer laid out as:
//   type | byte size | data | number of dims | dims
// The payload is copied into freshly allocated memory owned by `data`.
// No buffer length is supplied, so the caller must guarantee that `buffer`
// is at least as long as its embedded length fields claim.
iree::runtime::IREETensor::IREETensor(char *buffer) {
  // Read cursor: each call copies `count` bytes out and advances past them.
  const char *cursor = buffer;
  auto read_into = [&cursor](void *dst, size_t count) {
    std::memcpy(dst, cursor, count);
    cursor += count;
  };

  // Fixed-size header fields.
  read_into(&type, sizeof(type));
  read_into(&size, sizeof(size));

  // Payload: allocate raw storage of the advertised size and copy into it.
  data = operator new(size);
  read_into(data, size);

  // Shape: a dimension count followed by that many dimension entries.
  size_t rank;
  read_into(&rank, sizeof(rank));
  dims.resize(rank);
  read_into(dims.data(), rank * sizeof(iree_hal_dim_t));

  // A deserialized tensor has no buffer view attached.
  this->buffer_view = nullptr;
}
72+
73+
std::vector<char> *iree::runtime::IREETensor::serialize() {
74+
auto buffer = new std::vector<char>();
75+
76+
// Serialize 'type'
77+
size_t type_size = sizeof(type);
78+
buffer->insert(buffer->end(), reinterpret_cast<const char *>(&type), reinterpret_cast<const char *>(&type) + type_size);
79+
80+
// Serialize 'size'
81+
size_t size_size = sizeof(size);
82+
buffer->insert(buffer->end(), reinterpret_cast<const char *>(&size), reinterpret_cast<const char *>(&size) + size_size);
83+
84+
// Serialize 'data'
85+
buffer->insert(buffer->end(), reinterpret_cast<const char *>(data), reinterpret_cast<const char *>(data) + size);
86+
87+
// Serialize 'dims'
88+
size_t dims_size = sizeof(iree_hal_dim_t) * dims.size();
89+
size_t num_dims = dims.size();
90+
buffer->insert(buffer->end(), reinterpret_cast<const char *>(&num_dims), reinterpret_cast<const char *>(&num_dims) + sizeof(num_dims));
91+
buffer->insert(buffer->end(), reinterpret_cast<const char *>(dims.data()), reinterpret_cast<const char *>(dims.data()) + dims_size);
92+
93+
return buffer;
94+
}
95+
4696
iree_vm_instance_t *create_instance() {
4797
iree_vm_instance_t *instance = nullptr;
4898
iree_status_t status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance);

cmake/src/runtime.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class IREETensor {
4242
iree_hal_element_type_t type;
4343
iree_hal_buffer_view_t* buffer_view;
4444

45+
IREETensor(char* serialized_data);
4546
IREETensor(iree_hal_buffer_view_t* buffer_view, iree_hal_element_type_t type);
4647
IREETensor(void* data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type);
4748

@@ -62,6 +63,10 @@ class IREETensor {
6263
iree_const_byte_span_t data_byte_span() const {
6364
return iree_make_const_byte_span(static_cast<uint8_t*>(data), size);
6465
}
66+
67+
// Serializes the tensor to a buffer that can be transmitted over the wire.
68+
// Fields in order: type, byte size, data, rank (number of dims), dims
69+
std::vector<char>* serialize();
6570
};
6671

6772
} // namespace runtime

lib/nx_iree/native.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@ defmodule NxIREE.Native do
2525

2626
def call_cpu(_instance_ref, _device_ref, _driver_name, _bytecode, _inputs),
2727
do: :erlang.nif_error(:undef)
28+
29+
# Stub for the `serialize_tensor/1` NIF. As written it only raises via
# `:erlang.nif_error(:undef)`; presumably it is replaced by the native
# implementation once the NIF library is loaded (loading code not visible here).
def serialize_tensor(_reference), do: :erlang.nif_error(:undef)
30+
# Stub for the `deserialize_tensor/1` NIF. As written it only raises via
# `:erlang.nif_error(:undef)`; presumably it is replaced by the native
# implementation once the NIF library is loaded (loading code not visible here).
def deserialize_tensor(_binary), do: :erlang.nif_error(:undef)
2831
end

lib/nx_iree/tensor.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ defmodule NxIREE.Tensor do
5959

6060
@impl true
6161
def from_binary(out, binary, opts) do
62-
device_uri = opts[:device]
62+
device_uri = opts[:device] || "local-sync://default"
6363

6464
device_ref =
6565
case NxIREE.Device.get(device_uri) do

lib/nx_iree/vm.ex

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,7 @@ defmodule NxIREE.VM do
5757
def allocate_buffer(binary, device_ref, shape, type) when is_binary(binary) do
5858
element_type = to_iree_type(type)
5959

60-
{:ok, buffer_ref} =
61-
NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
62-
63-
buffer_ref
60+
NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
6461
end
6562

6663
def read_buffer(%NxIREE.Tensor{} = t) do

test/nx_iree/native_test.exs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
defmodule NxIREE.NativeTest do
  use ExUnit.Case, async: true

  test "serializes and deserializes a tensor" do
    tensor = Nx.tensor([[[1, 2], [3, 4], [5, 6]]], type: :s32, backend: NxIREE.Tensor)

    {:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref)

    # Wire layout: element type, payload byte count, payload,
    # dimension count, then the dimensions themselves.
    assert <<
             element_type::unsigned-integer-native-size(32),
             byte_count::unsigned-integer-native-size(64),
             payload::binary-size(byte_count),
             rank::unsigned-integer-native-size(64),
             shape_bin::bitstring
           >> = serialized

    shape = for <<dim::signed-integer-native-size(64) <- shape_bin>>, do: dim

    # The element type encoding is internal to IREE; it is checked here only
    # as a sanity check and these two assertions can be dropped if it changes.
    assert Bitwise.band(element_type, 0xFF) == 32

    assert element_type |> Bitwise.bsr(24) |> Bitwise.band(0xFF) == 0x10

    assert byte_count == Nx.byte_size(tensor)
    assert payload == Nx.to_binary(tensor)
    assert rank == 3
    assert shape == [1, 3, 2]

    # Round trip: a tensor backed by the deserialized ref holds the same bytes.
    {:ok, deserialized_ref} = NxIREE.Native.deserialize_tensor(serialized)
    round_tripped = put_in(tensor.data.ref, deserialized_ref)

    assert Nx.to_binary(tensor) == Nx.to_binary(round_tripped)
  end
end

0 commit comments

Comments
 (0)