Skip to content

Commit da1cf45

Browse files
authored
fix: defn compiler now works with axon (#10)
* fix: defn compiler for axon
* fix: defn compiler
* feat: use new flag
* refactor: used_inputs as a mapset
* chore: update deps
* chore: use github deps
1 parent dd01dec commit da1cf45

File tree

7 files changed

+77
-14
lines changed

7 files changed

+77
-14
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ nx_iree-*.tar
3030

3131
/priv/iree-compile
3232
/priv/iree-runtime
33-
/priv/lbnx_iree.so
33+
/priv/libnx_iree.so

axon.exs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
Mix.install([
2+
{:axon, github: "elixir-nx/axon", branch: "main"},
3+
{:nx_iree, path: "."},
4+
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
5+
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
6+
], system_env: %{"NX_IREE_PREFER_PRECOMPILED" => false})
7+
8+
NxIREE.list_drivers() |> IO.inspect(label: "drivers")
9+
10+
{:ok, [dev | _]} = NxIREE.list_devices("metal")
11+
12+
flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
13+
Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
14+
15+
model =
16+
Axon.input("x", shape: {nil, 3})
17+
|> Axon.dense(8, activation: :relu)
18+
|> Axon.dense(1, activation: :relu)
19+
20+
Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
21+
# Nx.Defn.default_options(compiler: EXLA, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
22+
23+
template = %{"x" => Nx.template({10, 3}, :f32)}
24+
25+
{init_fn, predict_fn} = Axon.build(model, [])
26+
init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(Axon.ModelState.empty())])
27+
28+
IO.puts("\n\n\n======= BEGIN predict_compiled_fn =======\n\n\n")
29+
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template])
30+
IO.puts("\n\n\n======= END predict_compiled_fn =======\n\n\n")
31+
32+
IO.puts("\n\n\n======= BEGIN predict_compiled_fn CALL =======\n\n\n")
33+
predict_compiled_fn.(init_params, Nx.iota({10, 3}, type: :f32)) |> dbg()
34+
IO.puts("\n\n\n======= END predict_compiled_fn CALL =======\n\n\n")

c_src/nx_iree.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,14 @@ DECLARE_NIF(read_buffer_nif) {
355355
return error(env, "invalid num_bytes");
356356
}
357357

358+
std::cout << "num_bytes input: " << num_bytes << std::endl;
359+
360+
if (num_bytes == -1) {
361+
num_bytes = (*input)->size;
362+
}
363+
364+
std::cout << "num_bytes actual: " << num_bytes << std::endl;
365+
358366
ErlNifBinary binary;
359367

360368
if (!enif_alloc_binary(num_bytes, &binary)) {

lib/nx_iree.ex

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ defmodule NxIREE do
2323
{:ok, tmpfile} = create_temp_file(mlir_module)
2424

2525
compiler_path = Path.join(:code.priv_dir(:nx_iree), "iree-compile")
26-
IO.puts(mlir_module)
2726

2827
try do
2928
{output, 0} =
@@ -71,8 +70,14 @@ defmodule NxIREE do
7170

7271
input_refs =
7372
Enum.map(inputs, fn
74-
%Nx.Tensor{data: %NxIREE.Tensor{ref: ref}} -> ref
75-
t -> NxIREE.VM.allocate_buffer(t, device_ref)
73+
%Nx.Tensor{data: %NxIREE.Tensor{ref: ref}} ->
74+
ref
75+
76+
fun when is_function(fun, 0) ->
77+
NxIREE.VM.allocate_buffer(fun.(), device_ref)
78+
79+
t ->
80+
NxIREE.VM.allocate_buffer(t, device_ref)
7681
end)
7782

7883
instance_ref = NxIREE.VM.get_instance()

lib/nx_iree/compiler.ex

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ defmodule NxIREE.Compiler do
1515
@behaviour Nx.Defn.Compiler
1616

1717
@impl true
18-
def __compile__(key, vars, fun, opts) do
19-
output_container = fun.(vars)
20-
18+
def __compile__(_key, vars, fun, opts) do
2119
{iree_compiler_flags, opts} = Keyword.pop(opts, :iree_compiler_flags, nil)
2220
{iree_runtime_options, opts} = Keyword.pop(opts, :iree_runtime_options, [])
2321
{output_mode, opts} = Keyword.pop(opts, :output_mode, nil)
@@ -26,20 +24,22 @@ defmodule NxIREE.Compiler do
2624
raise "missing :iree_compiler_flags option"
2725
end
2826

29-
mlir_module = EXLA.to_mlir_module(key, vars, opts)
27+
%{mlir_module: mlir_module, output_container: output_container, used_inputs: used_inputs} =
28+
EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true))
3029

3130
bytecode = NxIREE.compile(mlir_module, iree_compiler_flags)
3231

3332
if output_mode == :bytecode do
3433
throw({:bytecode, %{bytecode: bytecode, output_container: output_container}})
3534
else
3635
fn [inputs] ->
36+
filtered_inputs =
37+
filter_inputs_by_indices(inputs, used_inputs)
38+
3739
{:ok, results} =
3840
NxIREE.call(
3941
bytecode,
40-
Enum.map(inputs, fn f ->
41-
f.()
42-
end),
42+
filtered_inputs,
4343
iree_runtime_options
4444
)
4545

@@ -68,4 +68,17 @@ defmodule NxIREE.Compiler do
6868

6969
@impl true
7070
defdelegate __to_backend__(opts), to: EXLA.Defn
71+
72+
defp filter_inputs_by_indices(args, inputs) do
73+
filter_by_indices_list(args, 0, Enum.sort(inputs), fn x, _ -> x end)
74+
end
75+
76+
defp filter_by_indices_list([var | vars], i, [i | inputs], callback),
77+
do: [callback.(var, i) | filter_by_indices_list(vars, i + 1, inputs, callback)]
78+
79+
defp filter_by_indices_list([_var | vars], i, inputs, callback),
80+
do: filter_by_indices_list(vars, i + 1, inputs, callback)
81+
82+
defp filter_by_indices_list([], _i, [], _callback),
83+
do: []
7184
end

lib/nx_iree/vm.ex

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ defmodule NxIREE.VM do
5757
def allocate_buffer(binary, device_ref, shape, type) when is_binary(binary) do
5858
element_type = to_iree_type(type)
5959

60-
NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
60+
{:ok, buffer_ref} =
61+
NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
62+
63+
buffer_ref
6164
end
6265

6366
def read_buffer(%NxIREE.Tensor{} = t) do

mix.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
%{
22
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
33
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
4-
"exla": {:git, "https://github.com/elixir-nx/nx.git", "7a3d7cd87efc9811fb8c86ec0b0b245e99bf7c6d", [sparse: "exla"]},
4+
"exla": {:git, "https://github.com/elixir-nx/nx.git", "ad28ea754dc2780b0b0726a062c46a58c588dc31", [sparse: "exla"]},
55
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
6-
"nx": {:git, "https://github.com/elixir-nx/nx.git", "7a3d7cd87efc9811fb8c86ec0b0b245e99bf7c6d", [sparse: "nx"]},
6+
"nx": {:git, "https://github.com/elixir-nx/nx.git", "ad28ea754dc2780b0b0726a062c46a58c588dc31", [sparse: "nx"]},
77
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
88
"xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"},
99
}

0 commit comments

Comments
 (0)