From d72680a8e8fe0d087d8d41d1f0774eee87db303c Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 8 Aug 2025 12:12:16 -0300 Subject: [PATCH 01/42] feat: add initial draft --- exla/lib/exla/backend.ex | 11 +++++++++ exla/lib/exla/defn.ex | 17 +++++++++++++ exla/lib/exla/mlir/value.ex | 13 ++++++++++ nx/lib/nx/backend.ex | 10 ++++++++ nx/lib/nx/binary_backend.ex | 5 ++++ nx/lib/nx/defn/evaluator.ex | 47 ++++++++++++++++++++++++++++++++++++ nx/lib/nx/defn/expr.ex | 18 ++++++++++++++ nx/lib/nx/defn/tree.ex | 13 ++++++++++ nx/lib/nx/shared.ex | 21 ++++++++++++++++ torchx/lib/torchx/backend.ex | 5 ++++ 10 files changed, 160 insertions(+) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 50747de4bb..539cc911ff 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -329,6 +329,17 @@ defmodule EXLA.Backend do jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) end + @impl true + def elixir_call(name, args, fun) do + {tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) + + wrapper_fun = fn tensors -> + Nx.Defn.Expr.elixir_call(name, Tuple.to_list(tensors) ++ rest, fun) + end + + jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) + end + binary_ops = [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++ [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 413c38ce45..b899be577a 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -546,6 +546,23 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator(:elixir_call, %T{data: %Expr{args: args}}, state, cache) do + [call, expr, _callback] = args + %{data: %{args: in_args}} = call + + {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + {_opts, _ignored} = {opts, nil} + + {operands, cache} = Enum.map_reduce(args, cache, &recur_operator(&1, state, &2)) + + out_typespecs = 
container_to_typespecs(expr) + + # Emit a generic custom call that the EXLA runtime can bind to Erlang/Elixir. + results = Value.custom_call(state.builder, "nx_elixir_custom_call", operands, out_typespecs) + + {wrap_tuple_result(results, expr), cache} + end + defp cached_recur_operator( :lu, %T{ diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index f955e67200..69c63e0bb5 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -840,6 +840,19 @@ defmodule EXLA.MLIR.Value do |> one!() end + def custom_call( + %Function{} = func, + call_target_name, + operands, + out_typespecs, + extra_attrs \\ [] + ) do + result_types = typespecs_to_mlir_types(out_typespecs) + attributes = [call_target_name: attr_string(call_target_name), api_version: attr_i32(4), has_side_effect: attr_boolean(true)] + attributes = attributes ++ extra_attrs + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + end + def get_typespec(value) do EXLA.NIF.mlir_get_typespec(value.ref) end diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 3c463ba237..df0abdcc3e 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -142,6 +142,15 @@ defmodule Nx.Backend do """ @callback optional(atom, [term], fun) :: tensor + @doc """ + Invoked to execute a generic Elixir callback from within defn. + + The backend may choose how to execute it. For example, EXLA can lower + to a custom_call that interacts with Erlang/Elixir via C; pure CPU + backends may call the function directly. 
+ """ + @callback elixir_call(atom, [term], fun) :: tensor + @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor @@ -162,6 +171,7 @@ defmodule Nx.Backend do @optional_callbacks [ optional: 3, + elixir_call: 3, solve: 3, determinant: 2, logical_not: 2, diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 974d558b0d..75eb88e8dd 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2658,4 +2658,9 @@ defmodule Nx.BinaryBackend do defp bitstring_copy(bitstring, n) do for _ <- 1..n, into: <<>>, do: bitstring end + + @impl true + def elixir_call(_name, args, fun) when is_function(fun) do + apply(fun, args) + end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index d028ec6a63..86b9b0b019 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,6 +175,28 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end + defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do + [call, expr, _callback] = args + %{data: %{args: call_args_in, op: call_name}} = call + + {call_args, opts} = Enum.split_while(call_args_in, &(not is_list(&1))) + + cache = Enum.reduce(call_args, cache, &compute_cache(&1, state, &2)) + key = computation_key(call_name, call_args ++ opts) + + {optional_expr_cache, cache} = + case cache do + %{^key => optional_expr_cache} -> + {optional_expr_cache, cache} + + %{} -> + optional_expr_cache = {expr, init_compute_cache(expr, state)} + {optional_expr_cache, Map.put(cache, key, optional_expr_cache)} + end + + Map.put(cache, [:optional | id], optional_expr_cache) + end + defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do %{parent_ids: parent_ids, current_ids: current_ids} = state @@ -431,6 +453,31 @@ defmodule 
Nx.Defn.Evaluator do end end + defp eval_apply( + :elixir_call, + %{data: %Expr{args: [call, out, _callback], id: id}}, + state, + caches + ) do + {args, caches} = Tree.apply_args(call, caches, &eval(&1, state, &2)) + backend = Nx.Shared.list_impl!(args) + + if function_exported?(backend, call.data.op, length(args) + 1) do + out = + case call do + %{type: {:tuple, _}} -> out + _ -> call + end + + {apply(backend, call.data.op, [out | args]), caches} + else + params = Enum.map(args, &fn -> &1 end) + {{expr, optional_cache}, caches} = pop_cache!(caches, [:optional | id]) + {res, _} = composite_eval(expr, %{state | params: params}, [optional_cache]) + {res, caches} + end + end + defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do raise "unexpected vectorized axes in evaluator for operation #{inspect(op)}: #{inspect(ans)}" end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 782e4a07fd..eb457fe01e 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -41,6 +41,8 @@ defmodule Nx.Defn.Expr do * `attach_token(token(%Nx.Defn.Token{}), expr)` + * `elixir_call(name, args, fun)` + `defn` compilers must handle said nodes accordingly. 
""" @@ -384,6 +386,22 @@ defmodule Nx.Defn.Expr do end end + @impl true + def elixir_call(name, in_args, fun) do + {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + params = Enum.with_index(args, &parameter/2) + + case apply(fun, params ++ opts) do + %{data: %{context: context}} = res -> + expr(res, context, :elixir_call, [expr(res, context, name, in_args), res, fun]) + + t when is_tuple(t) -> + context = elem(t, 0).data.context + out = expr(tuple_out(tuple_size(t)), context, name, in_args) + tuple(expr(out, context, :elixir_call, [out, t, fun]), Tuple.to_list(t)) + end + end + ## Nx.Defn AST callbacks @doc false diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 582b9d4689..9ad840542b 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,6 +192,19 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do + [call, expr, callback] = args + {call, acc} = fun.(call, acc) + + {expr, acc} = + case type do + :all -> Composite.traverse(expr, acc, fun) + :scope -> {expr, acc} + end + + {[call, expr, callback], acc} + end + def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do {hooks, acc} = Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc -> diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 156748859c..5285dc4f0e 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -583,6 +583,27 @@ defmodule Nx.Shared do "expected default implementation to match template #{inspect(right)}, got: #{inspect(left)}" end + @doc false + def elixir_call(output, function_name, args, default_impl) + when is_atom(function_name) and is_list(args) and is_function(default_impl) do + arity = length(args) + 1 + backend = list_impl!(args) + + cond do + function_exported?(backend, function_name, arity) -> + apply(backend, function_name, [output | args]) + + function_exported?(backend, :elixir_call, 3) -> + 
backend.elixir_call(function_name, args, default_impl) + |> ensure_optional_compatible!(output) + + true -> + default_impl + |> apply(args) + |> ensure_optional_compatible!(output) + end + end + @doc false def raise_complex_not_supported(function, arity) do raise ArgumentError, "Nx.#{function}/#{arity} does not support complex inputs" diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94d..1a02fc78dd 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1825,4 +1825,9 @@ defmodule Torchx.Backend do raise "operation #{unquote(fun)} is not supported on Torchx.Backend" end end + + @impl true + def elixir_call(_name, args, fun) when is_function(fun) do + apply(fun, args) + end end From da7d7e44d7c4ef4e91d0897711cab09d1399b5a9 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 8 Aug 2025 13:54:12 -0300 Subject: [PATCH 02/42] evaluator mode working --- elixir_call.exs | 44 +++++++++++++++++++++++++++++++++++ exla/lib/exla/backend.ex | 10 ++------ nx/lib/nx/backend.ex | 2 +- nx/lib/nx/binary_backend.ex | 2 +- nx/lib/nx/defn/evaluator.ex | 45 ++++++++++-------------------------- nx/lib/nx/defn/expr.ex | 20 ++++++++-------- nx/lib/nx/defn/tree.ex | 20 ++++++++-------- nx/lib/nx/shared.ex | 10 +++----- torchx/lib/torchx/backend.ex | 2 +- 9 files changed, 85 insertions(+), 70 deletions(-) create mode 100644 elixir_call.exs diff --git a/elixir_call.exs b/elixir_call.exs new file mode 100644 index 0000000000..a70c76dd9d --- /dev/null +++ b/elixir_call.exs @@ -0,0 +1,44 @@ +Mix.install([{:exla, path: "exla"}, {:pythonx, "~> 0.4"}]) + +Pythonx.uv_init(""" +[project] +name = "project" +version = "0.0.0" +requires-python = "==3.13.*" +dependencies = [ + "numpy==2.2.2" +] +""") + +Nx.global_default_backend(EXLA.Backend) +t = Nx.iota({10}) + +elixir_fun = fn t, opts -> + input = Nx.to_flat_list(t) + + {res, _ctx} = + Pythonx.eval( + """ + import numpy as np + arr = 
np.array(input) + + c = np.cos(arr) + offset + + list(c) + """, + %{"input" => input, "offset" => opts[:value]} + ) + + Nx.tensor(Pythonx.decode(res)) +end + +jit_fun = fn t -> + s = Nx.size(t) + + out = + Nx.Shared.elixir_call(%{t | type: Nx.Type.to_floating(t.type)}, [t, [value: 10]], elixir_fun) + + Nx.negate(out) +end + +dbg(Nx.Defn.jit_apply(jit_fun, [t])) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 539cc911ff..cc8aef1bfd 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -330,14 +330,8 @@ defmodule EXLA.Backend do end @impl true - def elixir_call(name, args, fun) do - {tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) - - wrapper_fun = fn tensors -> - Nx.Defn.Expr.elixir_call(name, Tuple.to_list(tensors) ++ rest, fun) - end - - jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) + def elixir_call(_out, args, fun) do + apply(fun, args) end binary_ops = diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index df0abdcc3e..f8556ce308 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -149,7 +149,7 @@ defmodule Nx.Backend do to a custom_call that interacts with Erlang/Elixir via C; pure CPU backends may call the function directly. 
""" - @callback elixir_call(atom, [term], fun) :: tensor + @callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 75eb88e8dd..478f276017 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2660,7 +2660,7 @@ defmodule Nx.BinaryBackend do end @impl true - def elixir_call(_name, args, fun) when is_function(fun) do + def elixir_call(_out, args, fun) when is_function(fun) do apply(fun, args) end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 86b9b0b019..601653df8b 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -176,25 +176,12 @@ defmodule Nx.Defn.Evaluator do end defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do - [call, expr, _callback] = args - %{data: %{args: call_args_in, op: call_name}} = call - - {call_args, opts} = Enum.split_while(call_args_in, &(not is_list(&1))) - - cache = Enum.reduce(call_args, cache, &compute_cache(&1, state, &2)) - key = computation_key(call_name, call_args ++ opts) - - {optional_expr_cache, cache} = - case cache do - %{^key => optional_expr_cache} -> - {optional_expr_cache, cache} - - %{} -> - optional_expr_cache = {expr, init_compute_cache(expr, state)} - {optional_expr_cache, Map.put(cache, key, optional_expr_cache)} - end + [in_args, _fun] = args - Map.put(cache, [:optional | id], optional_expr_cache) + Enum.reduce(in_args, cache, fn + t, cache when is_list(t) -> cache + t, cache -> compute_cache(t, state, cache) + end) end defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do @@ -455,26 +442,18 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [call, out, _callback], id: id}}, + %{data: %Expr{args: [in_args, fun], id: id}} = expr, 
state, caches ) do - {args, caches} = Tree.apply_args(call, caches, &eval(&1, state, &2)) - backend = Nx.Shared.list_impl!(args) - - if function_exported?(backend, call.data.op, length(args) + 1) do - out = - case call do - %{type: {:tuple, _}} -> out - _ -> call - end + {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + {evaluated_tensors, caches} = Enum.map_reduce(tensor_args, caches, &eval(&1, state, &2)) + backend = Nx.Shared.list_impl!(evaluated_tensors) - {apply(backend, call.data.op, [out | args]), caches} + if backend == Nx.Defn.Expr do + {expr, caches} else - params = Enum.map(args, &fn -> &1 end) - {{expr, optional_cache}, caches} = pop_cache!(caches, [:optional | id]) - {res, _} = composite_eval(expr, %{state | params: params}, [optional_cache]) - {res, caches} + {apply(fun, evaluated_tensors ++ opts), caches} end end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index eb457fe01e..1d488df888 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -387,18 +387,18 @@ defmodule Nx.Defn.Expr do end @impl true - def elixir_call(name, in_args, fun) do - {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) - params = Enum.with_index(args, &parameter/2) + def elixir_call(out, in_args, fun) do + {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) + [%T{data: %Expr{context: context}} | _] = Enum.map(tensor_args, &to_expr/1) - case apply(fun, params ++ opts) do - %{data: %{context: context}} = res -> - expr(res, context, :elixir_call, [expr(res, context, name, in_args), res, fun]) + case out do + t when is_struct(t, Nx.Tensor) -> + expr(t, context, :elixir_call, [in_args, fun]) - t when is_tuple(t) -> - context = elem(t, 0).data.context - out = expr(tuple_out(tuple_size(t)), context, name, in_args) - tuple(expr(out, context, :elixir_call, [out, t, fun]), Tuple.to_list(t)) + tuple when is_tuple(tuple) -> + out_template = tuple_out(tuple_size(tuple)) + expr_node = expr(out_template, context, 
:elixir_call, [in_args, fun]) + tuple(expr_node, Tuple.to_list(tuple)) end end diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 9ad840542b..bd741346e7 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -193,16 +193,18 @@ defmodule Nx.Defn.Tree do end def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do - [call, expr, callback] = args - {call, acc} = fun.(call, acc) - - {expr, acc} = - case type do - :all -> Composite.traverse(expr, acc, fun) - :scope -> {expr, acc} - end + [in_args, callback] = args + + {in_args, acc} = + Enum.map_reduce(in_args, acc, fn t, acc -> + if is_list(t) do + {t, acc} + else + Composite.traverse(t, acc, fun) + end + end) - {[call, expr, callback], acc} + {[in_args, callback], acc} end def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 5285dc4f0e..e30b699fc5 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -584,21 +584,17 @@ defmodule Nx.Shared do end @doc false - def elixir_call(output, function_name, args, default_impl) - when is_atom(function_name) and is_list(args) and is_function(default_impl) do + def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do arity = length(args) + 1 backend = list_impl!(args) cond do - function_exported?(backend, function_name, arity) -> - apply(backend, function_name, [output | args]) - function_exported?(backend, :elixir_call, 3) -> - backend.elixir_call(function_name, args, default_impl) + backend.elixir_call(output, args, fun) |> ensure_optional_compatible!(output) true -> - default_impl + fun |> apply(args) |> ensure_optional_compatible!(output) end diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 1a02fc78dd..eb813e1e26 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1827,7 +1827,7 @@ defmodule Torchx.Backend do end @impl true - def elixir_call(_name, args, 
fun) when is_function(fun) do + def elixir_call(_out, args, fun) when is_function(fun) do apply(fun, args) end end From fc9c28ca0ade383c308bdccfba37b81702770320 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:37:20 -0300 Subject: [PATCH 03/42] test: add tests --- elixir_call.exs | 44 ------------- exla/lib/exla/backend.ex | 5 -- exla/lib/exla/defn.ex | 21 ++----- exla/test/exla/defn/elixir_call_test.exs | 61 +++++++++++++++++++ nx/lib/nx.ex | 55 +++++++++++++++++ nx/lib/nx/binary_backend.ex | 5 -- nx/lib/nx/defn/evaluator.ex | 4 +- nx/lib/nx/defn/tree.ex | 2 +- nx/lib/nx/shared.ex | 17 ------ .../nx/defn/elixir_call_evaluator_test.exs | 49 +++++++++++++++ torchx/lib/torchx/backend.ex | 5 -- torchx/mix.exs | 4 +- torchx/test/torchx/defn/elixir_call_test.exs | 51 ++++++++++++++++ 13 files changed, 225 insertions(+), 98 deletions(-) delete mode 100644 elixir_call.exs create mode 100644 exla/test/exla/defn/elixir_call_test.exs create mode 100644 nx/test/nx/defn/elixir_call_evaluator_test.exs create mode 100644 torchx/test/torchx/defn/elixir_call_test.exs diff --git a/elixir_call.exs b/elixir_call.exs deleted file mode 100644 index a70c76dd9d..0000000000 --- a/elixir_call.exs +++ /dev/null @@ -1,44 +0,0 @@ -Mix.install([{:exla, path: "exla"}, {:pythonx, "~> 0.4"}]) - -Pythonx.uv_init(""" -[project] -name = "project" -version = "0.0.0" -requires-python = "==3.13.*" -dependencies = [ - "numpy==2.2.2" -] -""") - -Nx.global_default_backend(EXLA.Backend) -t = Nx.iota({10}) - -elixir_fun = fn t, opts -> - input = Nx.to_flat_list(t) - - {res, _ctx} = - Pythonx.eval( - """ - import numpy as np - arr = np.array(input) - - c = np.cos(arr) + offset - - list(c) - """, - %{"input" => input, "offset" => opts[:value]} - ) - - Nx.tensor(Pythonx.decode(res)) -end - -jit_fun = fn t -> - s = Nx.size(t) - - out = - Nx.Shared.elixir_call(%{t | type: Nx.Type.to_floating(t.type)}, [t, [value: 10]], elixir_fun) - - 
Nx.negate(out) -end - -dbg(Nx.Defn.jit_apply(jit_fun, [t])) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index cc8aef1bfd..50747de4bb 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -329,11 +329,6 @@ defmodule EXLA.Backend do jit([], wrapper_fun, tensors, [List.to_tuple(tensors)]) end - @impl true - def elixir_call(_out, args, fun) do - apply(fun, args) - end - binary_ops = [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++ [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index b899be577a..7a7b3865a8 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -546,23 +546,6 @@ defmodule EXLA.Defn do end end - defp cached_recur_operator(:elixir_call, %T{data: %Expr{args: args}}, state, cache) do - [call, expr, _callback] = args - %{data: %{args: in_args}} = call - - {args, opts} = Enum.split_while(in_args, &(not is_list(&1))) - {_opts, _ignored} = {opts, nil} - - {operands, cache} = Enum.map_reduce(args, cache, &recur_operator(&1, state, &2)) - - out_typespecs = container_to_typespecs(expr) - - # Emit a generic custom call that the EXLA runtime can bind to Erlang/Elixir. - results = Value.custom_call(state.builder, "nx_elixir_custom_call", operands, out_typespecs) - - {wrap_tuple_result(results, expr), cache} - end - defp cached_recur_operator( :lu, %T{ @@ -1226,6 +1209,10 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end + defp to_operator(:elixir_call, _, _, _) do + raise "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler." 
+ end + defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do n = opts[:length] axis = opts[:axis] diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs new file mode 100644 index 0000000000..add051a3f6 --- /dev/null +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -0,0 +1,61 @@ +defmodule EXLA.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + Nx.default_backend(EXLA.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end + + test "fails when using EXLA compiler" do + x = Nx.tensor([1, 2, 3]) + + assert_raise RuntimeError, + "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler.", + fn -> + EXLA.jit_apply(&split_and_sum/1, [x]) + end + end +end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 091372d005..715a149286 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2196,6 +2196,61 @@ defmodule Nx do list end + @doc """ + Invokes an Elixir function from within defn. + + This function allows integrating arbitrary Elixir code into `defn` graphs. 
+ It receives an output template (a tensor or a tuple of tensors) that + specifies the expected shapes, types, and names of the result, a list of + arguments to pass to the Elixir function, and the function itself. + + Inside `defn`, this builds an expression node understood by compilers. + Outside `defn` or on backends without special support, it executes `fun` + directly and validates the result matches the template. + """ + @doc type: :backend + def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do + {:arity, arity} = Function.info(fun, :arity) + num_args = length(args) + + if arity != num_args do + raise ArgumentError, + "expected #{arity} arguments, got #{num_args}" + end + + backend = Nx.Shared.list_impl!(args) + + cond do + function_exported?(backend, :elixir_call, 3) -> + output + |> backend.elixir_call(args, fun) + |> ensure_call_compatible!(output) + + true -> + fun + |> apply(args) + |> ensure_call_compatible!(output) + end + end + + defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do + [Tuple.to_list(left), Tuple.to_list(right)] + |> Enum.zip_with(fn [l, r] -> ensure_call_compatible!(l, r) end) + + left + end + + defp ensure_call_compatible!( + %{shape: shape, type: type, names: names} = left, + %{shape: shape, type: type, names: names} + ), + do: left + + defp ensure_call_compatible!(left, right) do + raise ArgumentError, + "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + end + defp chunk([], data, type) do match_types [type] do <> = data diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 478f276017..974d558b0d 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2658,9 +2658,4 @@ defmodule Nx.BinaryBackend do defp bitstring_copy(bitstring, n) do for _ <- 1..n, into: <<>>, do: bitstring end - - @impl true - def elixir_call(_out, args, fun) when is_function(fun) do - apply(fun, 
args) - end end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 601653df8b..c913f4ec3c 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,7 +175,7 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end - defp compute_cache(:elixir_call, %{data: %Expr{args: args, id: id}}, state, cache) do + defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do [in_args, _fun] = args Enum.reduce(in_args, cache, fn @@ -442,7 +442,7 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [in_args, fun], id: id}} = expr, + %{data: %Expr{args: [in_args, fun]}} = expr, state, caches ) do diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index bd741346e7..733a131e4f 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,7 +192,7 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end - def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, type, acc, fun) do + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do [in_args, callback] = args {in_args, acc} = diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index e30b699fc5..156748859c 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -583,23 +583,6 @@ defmodule Nx.Shared do "expected default implementation to match template #{inspect(right)}, got: #{inspect(left)}" end - @doc false - def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do - arity = length(args) + 1 - backend = list_impl!(args) - - cond do - function_exported?(backend, :elixir_call, 3) -> - backend.elixir_call(output, args, fun) - |> ensure_optional_compatible!(output) - - true -> - fun - |> apply(args) - |> ensure_optional_compatible!(output) - end - end - @doc false def raise_complex_not_supported(function, arity) do raise ArgumentError, "Nx.#{function}/#{arity} does not support complex inputs" diff --git 
a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs new file mode 100644 index 0000000000..92fad6b431 --- /dev/null +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -0,0 +1,49 @@ +defmodule Nx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert Nx.all_close(y, expected) |> Nx.to_number() == 1 + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert expected == y + end +end diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index eb813e1e26..77308ce94d 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1825,9 +1825,4 @@ defmodule Torchx.Backend do raise "operation #{unquote(fun)} is not supported on Torchx.Backend" end end - - @impl true - def elixir_call(_out, args, fun) when is_function(fun) do - apply(fun, args) - end end diff --git a/torchx/mix.exs b/torchx/mix.exs index fa5531e541..5174e59cdb 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - {:nx, "~> 0.10.0"}, - # {:nx, path: "../nx"}, + # {:nx, "~> 0.10.0"}, + {:nx, path: "../nx"}, {:ex_doc, "~> 0.29", only: :docs} ] end diff --git 
a/torchx/test/torchx/defn/elixir_call_test.exs b/torchx/test/torchx/defn/elixir_call_test.exs new file mode 100644 index 0000000000..9c504fa6c8 --- /dev/null +++ b/torchx/test/torchx/defn/elixir_call_test.exs @@ -0,0 +1,51 @@ +defmodule Torchx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + Nx.default_backend(Torchx.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end +end From 25300b7d1960a8ac2136eb94f3fef37a9e7bc52a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:41:54 -0300 Subject: [PATCH 04/42] fix grad --- exla/lib/exla/mlir/value.ex | 13 ------------- nx/lib/nx/defn/grad.ex | 4 ++++ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 69c63e0bb5..f955e67200 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -840,19 +840,6 @@ defmodule EXLA.MLIR.Value do |> one!() end - def custom_call( - %Function{} = func, - call_target_name, - operands, - out_typespecs, - extra_attrs \\ [] - ) do - result_types = 
typespecs_to_mlir_types(out_typespecs) - attributes = [call_target_name: attr_string(call_target_name), api_version: attr_i32(4), has_side_effect: attr_boolean(true)] - attributes = attributes ++ extra_attrs - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - end - def get_typespec(value) do EXLA.NIF.mlir_get_typespec(value.ref) end diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 2941889f98..8c72d0fed0 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do acc end + defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do + acc + end + defp parents_args( :optional, %{data: %{args: [call, _expr, callback]}} = t, From aa3543156789e7ffcf40c6466b0ff894ccb49eff Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 22 Nov 2025 01:29:13 -0300 Subject: [PATCH 05/42] feat(exla): initial Nx.elixir_call/3 CPU wiring --- exla/c_src/exla/exla.cc | 58 +++ exla/lib/exla/application.ex | 3 +- exla/lib/exla/callback_server.ex | 205 +++++++++++ exla/lib/exla/defn.ex | 40 +- exla/lib/exla/mlir/value.ex | 23 ++ exla/lib/exla/nif.ex | 4 + exla/test/exla/defn/elixir_call_exla_test.exs | 70 ++++ exla/test/exla/defn/elixir_call_test.exs | 11 +- nx_elixir_call_exla_design.md | 344 ++++++++++++++++++ 9 files changed, 747 insertions(+), 11 deletions(-) create mode 100644 exla/lib/exla/callback_server.ex create mode 100644 exla/test/exla/defn/elixir_call_exla_test.exs create mode 100644 nx_elixir_call_exla_design.md diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index abde2f8fca..e0842599b6 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -521,6 +521,64 @@ get_per_device_memory(ErlNifEnv *env, fine::ResourcePtr client) { FINE_NIF(get_per_device_memory, 0); +// Elixir callback bridge + +namespace { + +// Very small, CPU-only bridge that forwards callback requests from the XLA +// 
host CustomCall to the Elixir dispatcher process. +// +// For Phase 1 we keep this intentionally simple: +// * Only single-output, single-replica computations are supported. +// * Arguments and results are transferred by value as host binaries. +// * Each request is synchronous: the CustomCall will block the XLA host +// thread until the Elixir side replies via `elixir_callback_reply/2`. +// +// This can be evolved later to support batching, more efficient tensor +// encoding, and timeouts. + +struct ElixirCallbackRequest { + int64_t callback_id; + std::vector args; + ERL_NIF_TERM reply_tag; +}; + +// Global state for the bridge. For simplicity we keep a single dispatcher +// PID and use a monotonically increasing integer as reply_tag. +struct ElixirCallbackBridgeState { + ErlNifPid dispatcher_pid; + std::atomic next_tag{1}; +}; + +ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { + static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState(); + return state; +} + +} // namespace + +std::tuple, fine::Error> +start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { + auto state = GetElixirCallbackBridgeState(); + state->dispatcher_pid = dispatcher_pid; + return std::make_tuple(fine::Ok<>(), fine::Error()); +} + +FINE_NIF(start_elixir_callback_bridge, 0); + +std::tuple, fine::Error> +elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, fine::Term _payload) { + // For Phase 1 we do not implement a native waiting mechanism; instead the + // CustomCall handler calls directly into Elixir and returns immediately. + // This NIF exists only as a placeholder for future, more advanced bridges. 
+ (void)env; + (void)reply_tag; + (void)_payload; + return std::make_tuple(fine::Ok<>(), fine::Error()); +} + +FINE_NIF(elixir_callback_reply, 0); + // Logging fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) { diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 9ec098a3e6..48e1e5c8db 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -22,7 +22,8 @@ defmodule EXLA.Application do EXLA.Client, EXLA.Defn.Lock, EXLA.Defn.LockedCache, - {Task.Supervisor, name: EXLA.Defn.TaskSupervisor} + {Task.Supervisor, name: EXLA.Defn.TaskSupervisor}, + EXLA.CallbackServer ] Supervisor.start_link(children, name: __MODULE__, strategy: :one_for_one) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex new file mode 100644 index 0000000000..b3434a69b0 --- /dev/null +++ b/exla/lib/exla/callback_server.ex @@ -0,0 +1,205 @@ +defmodule EXLA.CallbackServer do + @moduledoc """ + Dispatcher and registry for `Nx.elixir_call/3` callbacks used by EXLA. + + This server has two responsibilities: + + * Assign a stable integer callback id for each Elixir function + output + template pair that participates in `Nx.elixir_call/3` when using the + EXLA compiler. + + * Receive callback requests from the native EXLA bridge thread, execute + the Elixir function, validate the result against the expected output + template, and reply back to native through a NIF. + + The native side is expected to: + + * Lower `:elixir_call` nodes to a CPU-only host `CustomCall` named + `"exla_elixir_callback"` with a callback id encoded in its attributes. + + * Run a bridge thread that sends messages of the form: + + {:exla_elixir_call, callback_id :: integer, args :: [Nx.Tensor.t()], reply_tag :: term()} + + to this process and waits on a native future associated with `reply_tag`. + + * Provide a NIF `EXLA.NIF.elixir_callback_reply/2` that completes the + native future when we send the reply back. 
+ """ + + use GenServer + + require Logger + + @type callback_id :: non_neg_integer() + + defstruct next_id: 1, + callbacks: %{} + + @type t :: %__MODULE__{ + next_id: non_neg_integer(), + callbacks: %{callback_id() => {fun(), Nx.t() | tuple()}} + } + + ## Public API + + @doc """ + Starts the callback server and registers it as the EXLA dispatcher process. + + The EXLA NIF is notified of the dispatcher PID so it can route + `:exla_elixir_call` messages to this process. + """ + def start_link(_init_arg) do + GenServer.start_link(__MODULE__, :ok, name: __MODULE__) + end + + @doc """ + Registers a callback function and its output template, returning a callback id. + + The same `{fun, out_template}` pair will always return the same id for the + lifetime of this VM. This id is what the EXLA compiler should encode into + the host `CustomCall` so the native side can reference the right callback. + """ + @spec register(fun(), Nx.t() | tuple()) :: callback_id() + def register(fun, out_template) when is_function(fun) do + GenServer.call(__MODULE__, {:register, fun, out_template}) + end + + ## GenServer callbacks + + @impl true + def init(:ok) do + # Inform native side that this process is the dispatcher for elixir callbacks. + # + # If the NIF has not implemented `start_elixir_callback_bridge/1` yet, we + # fail silently so that the rest of the system continues to work. This + # allows developing the Elixir side and the native side independently. 
+ _ = + try do + EXLA.NIF.start_elixir_callback_bridge(self()) + rescue + _ -> :ok + end + + {:ok, %__MODULE__{}} + end + + @impl true + def handle_call({:register, fun, out_template}, _from, %__MODULE__{} = state) do + key = {fun, Nx.to_template(out_template)} + + {id, state} = + case find_existing_id(state.callbacks, key) do + {:ok, id} -> + {id, state} + + :error -> + id = state.next_id + callbacks = Map.put(state.callbacks, id, {fun, Nx.to_template(out_template)}) + {%{state | callbacks: callbacks, next_id: id + 1}.next_id - 1, %{state | callbacks: callbacks, next_id: id + 1}} + end + + {:reply, id, state} + end + + @impl true + def handle_info({:exla_elixir_call, callback_id, args, reply_tag}, %__MODULE__{} = state) do + case Map.fetch(state.callbacks, callback_id) do + {:ok, {fun, out_template}} -> + reply_payload = + run_callback(fun, args, out_template) + |> encode_reply() + + send_reply(reply_tag, reply_payload) + + {:noreply, state} + + :error -> + Logger.error( + "EXLA.CallbackServer received callback id #{inspect(callback_id)} that is not registered" + ) + + send_reply(reply_tag, {:error, :unknown_callback}) + {:noreply, state} + end + end + + def handle_info(other, state) do + Logger.debug("EXLA.CallbackServer ignoring unexpected message: #{inspect(other)}") + {:noreply, state} + end + + ## Internal helpers + + defp find_existing_id(callbacks, key) do + Enum.reduce_while(callbacks, :error, fn {id, value}, _acc -> + if value == key, do: {:halt, {:ok, id}}, else: {:cont, :error} + end) + end + + defp run_callback(fun, args, out_template) do + result = + try do + apply(fun, args) + rescue + exception -> + {:error, {:exception, exception, __STACKTRACE__}} + catch + kind, reason -> + {:error, {kind, reason}} + end + + case result do + {:error, _} = error -> + error + + value -> + case ensure_compatible(value, out_template) do + {:ok, tensor_or_tuple} -> {:ok, tensor_or_tuple} + {:error, reason} -> {:error, reason} + end + end + end + + defp 
ensure_compatible(left, right) when is_tuple(left) and is_tuple(right) do + if tuple_size(left) == tuple_size(right) do + [Tuple.to_list(left), Tuple.to_list(right)] + |> Enum.zip_with(fn [l, r] -> + case ensure_compatible(l, r) do + {:ok, _} -> :ok + {:error, reason} -> throw({:error, reason}) + end + end) + + {:ok, left} + else + {:error, {:mismatched_tuple_size, left, right}} + end + catch + {:error, reason} -> {:error, reason} + end + + defp ensure_compatible(%Nx.Tensor{} = left, %Nx.Tensor{} = right) do + if left.shape == right.shape and left.type == right.type and left.names == right.names do + {:ok, left} + else + {:error, {:shape_mismatch, left, right}} + end + end + + defp ensure_compatible(left, right), do: {:error, {:invalid_result, left, right}} + + defp encode_reply({:ok, value}), do: {:ok, value} + defp encode_reply({:error, reason}), do: {:error, reason} + + defp send_reply(reply_tag, payload) do + try do + EXLA.NIF.elixir_callback_reply(reply_tag, payload) + rescue + _ -> + Logger.error( + "EXLA.CallbackServer failed to send reply to native for tag #{inspect(reply_tag)}" + ) + end + end +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 7a7b3865a8..a2a04e8cd5 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -546,6 +546,42 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator( + :elixir_call, + %T{data: %Expr{args: [in_args, fun]}} = expr, + %{client: %EXLA.Client{platform: :host}} = state, + cache + ) do + {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + + {call_args, cache} = + Enum.map_reduce(tensor_args, cache, fn arg, cache -> + recur_operator(arg, state, cache) |> unwrap_single_tensor!() + end) + + callback_id = EXLA.CallbackServer.register(fun, Nx.to_template(expr)) + typespecs = container_to_typespecs(expr) + + results = + Value.elixir_call(call_args, callback_id, typespecs) + + {wrap_tuple_result(results, expr), cache} + end + + defp cached_recur_operator( + :elixir_call, + 
_expr, + %{client: %EXLA.Client{platform: platform}}, + cache + ) do + raise """ + Nx.elixir_call/3 is currently only supported for EXLA CPU (platform: :host), + but the active EXLA client is configured for platform #{inspect(platform)}. + Please run on the :host client or wait for future segmentation-based support. + """ + |> then(fn _ -> {nil, cache} end) + end + defp cached_recur_operator( :lu, %T{ @@ -1209,10 +1245,6 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp to_operator(:elixir_call, _, _, _) do - raise "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler." - end - defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do n = opts[:length] axis = opts[:axis] diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index f955e67200..fcc1ef8dd9 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -832,6 +832,29 @@ defmodule EXLA.MLIR.Value do {p, l, u} end + @doc """ + Builds a StableHLO `custom_call` that targets the EXLA Elixir callback bridge. + + The `callback_id` is a small integer assigned by `EXLA.CallbackServer` that + identifies which Elixir function should be invoked when the host callback + runs. The native side is expected to read this id from the backend config + or attributes and route the callback accordingly. + """ + def elixir_call([%Value{function: func} | _] = operands, callback_id, typespecs) + when is_integer(callback_id) and callback_id >= 0 do + result_types = typespecs_to_mlir_types(typespecs) + + attributes = [ + call_target_name: attr_string("exla_elixir_callback"), + api_version: attr_i32(4), + # We currently encode the callback id as a backend config string. + # The native handler should parse this value back into an integer. 
+ backend_config: attr_string(Integer.to_string(callback_id)) + ] + + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 203dc30fd8..5cded357bd 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -78,5 +78,9 @@ defmodule EXLA.NIF do def reset_peak_memory(_client), do: err!() def get_per_device_memory(_client), do: err!() + # Elixir callback bridge (Phase 1: CPU-only, simple APIs) + def start_elixir_callback_bridge(_dispatcher_pid), do: err!() + def elixir_callback_reply(_reply_tag, _payload), do: err!() + defp err!(), do: :erlang.nif_error(:undef) end diff --git a/exla/test/exla/defn/elixir_call_exla_test.exs b/exla/test/exla/defn/elixir_call_exla_test.exs new file mode 100644 index 0000000000..75feb748aa --- /dev/null +++ b/exla/test/exla/defn/elixir_call_exla_test.exs @@ -0,0 +1,70 @@ +defmodule EXLA.Defn.ElixirCallEXLATest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + @moduletag :exla + + setup do + Nx.Defn.default_options(compiler: EXLA) + Nx.default_backend(EXLA.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output on EXLA CPU" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output on EXLA CPU" do + x = Nx.tensor([1, 
2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end + + test "elixir_call errors when result shape does not match template" do + defn bad_callback(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x], fn _t -> + # Wrong shape on purpose + Nx.tensor([1.0, 2.0, 3.0]) + end) + end + + x = Nx.iota({2}) + + assert_raise ArgumentError, ~r/expected the elixir_call function to match/, fn -> + bad_callback(x) + end + end +end diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index add051a3f6..09d10a02be 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -49,13 +49,12 @@ defmodule EXLA.Defn.ElixirCallEvaluatorTest do assert_equal(y, expected) end - test "fails when using EXLA compiler" do + test "works when using EXLA compiler directly" do x = Nx.tensor([1, 2, 3]) + y = EXLA.jit_apply(&split_and_sum/1, [x]) - assert_raise RuntimeError, - "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler.", - fn -> - EXLA.jit_apply(&split_and_sum/1, [x]) - end + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) end end diff --git a/nx_elixir_call_exla_design.md b/nx_elixir_call_exla_design.md new file mode 100644 index 0000000000..aac22ee371 --- /dev/null +++ b/nx_elixir_call_exla_design.md @@ -0,0 +1,344 @@ +## Design: `Nx.elixir_call/3` and EXLA Integration + +### 1. Overview + +This document describes a two-phase plan to implement safe, efficient support for calling arbitrary Elixir code from `defn` via `Nx.elixir_call/3`, with a focus on EXLA. + +- **Phase 1**: CPU-only implementation using EXLA + XLA host `CustomCall`, with a safe bridge to Elixir (no `nif_call`-style reentry into BEAM). 
+- **Phase 2**: Graph segmentation in `Nx.Defn.Graph` so the compiler can: + - Treat `elixir_call` as a boundary and split the computation into stages. + - Enable cross-device execution (CPU/GPU) while preserving a single user API. + - Eventually optimize some callbacks to be compiled away or lowered differently (e.g. pure functions expressible in Nx). + +This work extends [PR #1627, “feat: Nx.elixir_call/3”](https://github.com/elixir-nx/nx/pull/1627), which currently implements `Nx.elixir_call/3` only in `Nx.Defn.Evaluator`. + +--- + +### 2. Goals and Non-goals + +- **Goals** + - **G1**: Provide a **public API** (`Nx.elixir_call/3`) that allows calling user-provided Elixir code inside `defn`. + - **G2**: Implement a **safe** EXLA backend for `elixir_call` on **CPU** using XLA host `CustomCall`. + - **G3**: Ensure callbacks have **statically known shapes/dtypes** to keep compilation and gradients well-defined. + - **G4**: Provide a **unified intermediate representation** in `Nx.Defn.Graph` so future backends (EXLA GPU, other compilers) can share the same abstraction. + - **G5**: In Phase 2, support **graph segmentation** around `elixir_call` so that: + - We can mix device computation (CPU/GPU) with Elixir callbacks. + - The compiler can decide to either split or compile callbacks, depending on their structure. + +- **Non-goals (for now)** + - **NG1**: No direct, device-side callbacks for GPU/TPU in Phase 1 (no infeed/outfeed complexity yet). + - **NG2**: No guarantees about **side-effect isolation** of callbacks (user is responsible), beyond not violating BEAM safety. + - **NG3**: No attempt to automatically infer output shapes/dtypes of callbacks at runtime; shapes must be known at `defn`/compile time. + +--- + +### 3. Terminology + +- **`elixir_call` node**: The internal IR node representing a call to arbitrary Elixir code (backed by `Nx.elixir_call/3`). 
+- **Callback ID**: A stable identifier (string or integer) used to look up the Elixir function and output spec at compile/run time. +- **Output spec**: Shape and type description for all outputs of a callback. +- **Bridge thread**: A native (C/C++) thread that acts as a mediator between XLA/EXLA and BEAM, using message-passing only (no direct BEAM calls from arbitrary XLA threads). + +--- + +### 4. Phase 1: CPU-only EXLA Backend (Host `CustomCall`) + +#### 4.1 Public API: `Nx.elixir_call/3` + +- **Goal**: Reuse and finalize the API introduced in [nx#1627](https://github.com/elixir-nx/nx/pull/1627). + +- **Shape** (subject to minor refinement): + + - `Nx.elixir_call(args, fun_or_mfa, opts \\ [])` + +- **Key options / metadata**: + - **`id` or `name`**: A stable callback identifier (string or integer). + - **`output_template`** (or equivalent): A value (or list/tuple of values) that describes the **shapes and dtypes** of the callback’s outputs: + - Can be Nx tensors or a structured spec, but must be statically known at `defn` compile time. + - Potentially an **`impure`** flag (or similar) in the future to guide compiler optimizations. + +- **Constraints**: + - `fun_or_mfa` is not executed at `defn` compile time (except possibly in the evaluator backend). + - Output shape/type comes from `output_template`, not from running the function. + +#### 4.2 Nx IR: Representing `elixir_call` + +- **Extend** `Nx.Defn.Expr` / `Nx.Defn.Graph` to carry `elixir_call` nodes explicitly. + +- Proposed internal form: + + - `{:elixir_call, meta, args}` + + Where: + - **`meta`** includes: + - `callback_id` (string/int). + - `fun_or_mfa` or internal reference (for evaluator backend and dispatcher). + - `output_spec` (shapes + dtypes). + - Any flags required for compilation/grad. + - **`args`**: list of argument expressions. + +- **Requirements**: + - Shape inference for `elixir_call` uses `output_spec`. 
+ - Optimizer must **not** fuse or eliminate `elixir_call`; it is a logical boundary and may be effectful. + - The evaluator backend (as in nx#1627) already knows how to interpret it. + +#### 4.3 EXLA Lowering: From `elixir_call` to HLO/StableHLO + +- In the EXLA backend (Elixir side): + + - When encountering an `elixir_call` node while building HLO/StableHLO: + + - Lower `args` to HLO values. + - Construct a `CustomCall` operation with: + - **Operands**: those input HLO values. + - **Result types**: from `output_spec`. + - **Call target name**: e.g. `"exla_elixir_callback"`. + - **Attributes**: + - `callback_id` (string/int). + - Optionally an encoded `output_spec` (if needed on the native side). + +- **CPU-only restriction** (Phase 1): + - If the active EXLA client is **CPU**, allow this lowering. + - If the client is GPU (or other non-CPU), raise a **clear error**: + - e.g. “`Nx.elixir_call/3` is currently only supported for EXLA CPU; please run on CPU or wait for Phase 2 segmentation support.” + +#### 4.4 Native EXLA: Callback Registry and Bridge + +- **Callback registry (Elixir → native)**: + - At the time of building an EXLA executable, collect all callbacks: + + - Map: `callback_id → {fun_or_mfa, output_spec}`. + + - Pass this mapping down to the native side, associated with the executable or run context. + +- **Native data structures** (C/C++ side): + + - `struct CallbackRequest { RunRef run_ref; CallbackId callback_id; std::vector args; ReplyTag reply_tag; std::promise promise; };` + + - `struct CallbackResult { ReplyTag reply_tag; std::vector outputs; Error error; };` + + - A **thread-safe queue** for `CallbackRequest`s. + + - A **map** `reply_tag → std::promise` guarded by a mutex. + +- **Bridge thread**: + + - Started when the EXLA NIF is initialized (or when the first callback-capable executable is created). + + - Main loop: + 1. Pop `CallbackRequest` from the queue. + 2. 
Serialize `args` into a compact binary representation (shape metadata + flat data). + 3. Use `enif_send` to send a message to a **dedicated Elixir dispatcher process**: + - Message format (conceptual): + `{:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}`. + 4. Wait on the `std::promise`/`std::future` associated with `reply_tag` until `CallbackResult` is set: + - **Important**: This wait uses only native primitives (no BEAM APIs, no `nif_call`), so it is safe w.r.t. BEAM scheduling. + 5. On success/failure, the handler (see next section) is unblocked. + +#### 4.5 XLA Host `CustomCall` Handler (CPU Client) + +- **Registration**: + + - For the EXLA CPU client, register a host call target with XLA: + + - Name: `"exla_elixir_callback"`. + +- **Handler logic**: + + 1. Extract: + - `callback_id` from `CustomCall` attributes. + - Operand buffers (inputs). + - Output buffers and their shapes/dtypes. + 2. Convert operand buffers into host tensors and build a `CallbackRequest`: + - Assign a fresh `reply_tag`. + - Create a `std::promise` and `std::future`. + - Insert `reply_tag → promise` into the map. + 3. Enqueue the `CallbackRequest` onto the native request queue. + 4. Block on the `future` until a result arrives (native wait). + 5. Once the `CallbackResult` is available: + - On success: + - Write returned tensor data into XLA’s output buffers. + - Return `OK` to XLA. + - On error or timeout: + - Return an error `Status` so the XLA run fails with a descriptive error. + +#### 4.6 Elixir Dispatcher Process + +- Implement a **GenServer** in Nx/EXLA that acts as the BEAM-side dispatcher for callbacks. + +- Responsibilities: + + - Maintain: + - `callbacks: %{ {run_ref, callback_id} => {fun_or_mfa, output_spec} }`. 
+ + - Handle messages from the bridge thread: + + ```elixir + def handle_info({:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}, state) do + {args, arg_specs} = deserialize_tensors(args_bin) + {fun_or_mfa, output_spec} = + Map.fetch!(state.callbacks, {run_ref, callback_id}) + + # Execute user code (possibly in a Task for isolation) + result = + try do + call_user_fun(fun_or_mfa, args) + rescue + exception -> {:error, {:exception, exception, __STACKTRACE__}} + catch + kind, reason -> {:error, {kind, reason}} + end + + reply_payload = + encode_result(result, output_spec) # either {:ok, tensors_bin} or {:error, reason} + + # One NIF call to signal back to native side (bridge thread sees this) + send_reply_to_nif(reply_tag, reply_payload) + + {:noreply, state} + end + ``` + + - Ensure: + - Result shapes/dtypes match `output_spec`; otherwise return a structured error. + - Optional: enforce configurable timeouts per callback and abort the run on timeout. + +- **API considerations**: + + - A worker or supervisor module (e.g. `EXLA.CallbackServer`) could manage: + - Registration of callbacks per `run_ref`. + - Cleanup after run completion. + +When registering the callback, there should be a fun/capture -> integer mapping (maybe use :counters for generating these integers) and the function should be registered with this id. The id should be returned so that the compiler can use it. This turns the callback server into the source of truth and the generator of ids. + +#### 4.7 Error Handling and Validation + +- **Compile-time checks**: + + - Verify that: + - `output_template` can be converted into a valid `output_spec`. + - Grad rules (where applicable) can be defined; if not, error clearly or fall back. + +- **Runtime checks**: + + - After the Elixir callback returns: + - Validate result shape/dtype vs `output_spec`. + - On mismatch, generate a descriptive error and fail the XLA run.
+ +- **Timeouts**: + + - Optional but recommended: + - Per-callback timeout at the dispatcher level. + - If timeout expires, reply with error; native side then aborts the run. + +- **Safety**: + + - No calls from arbitrary XLA threads into BEAM functions. + - All BEAM interaction uses `enif_send` from the bridge thread or explicit NIF calls from Elixir processes. + +--- + +### 5. Phase 2: Graph Segmentation and Cross-Device Support + +After Phase 1 is solid on CPU, we extend support to all EXLA devices (CPU/GPU) via **segmentation** in `Nx.Defn.Graph`. This aligns with [the discussion on nx#1627](https://github.com/elixir-nx/nx/pull/1627), where `elixir_call` and other “optional callback” mechanisms share a unified specification, and the compiler decides whether to split or to compile. + +#### 5.1 Treat `elixir_call` as a Stage Boundary + +- In `Nx.Defn.Graph`, treat each `elixir_call` as a **potential cut point**: + + - Find maximal subgraphs that: + - Contain no `elixir_call`. + - Are otherwise pure Nx computations. + +- Build a sequence: + + - `stage_0` → `elixir_call_0` → `stage_1` → `elixir_call_1` → … → `stage_n`. + +- Each `stage_i` will be compiled separately for a target device (CPU or GPU). + +#### 5.2 Stage Compilation + +- For each pure stage: + + - Infer shapes and types as usual. + - Choose a device (matching the EXLA client or using more advanced heuristics later). + - Compile to an EXLA executable. + +- For each `elixir_call` between stages: + + - Reuse the **Phase 1 dispatcher + bridge**: + - Inputs: outputs of the previous stage (converted to host tensors). + - Outputs: inputs to the next stage (converted back and transferred as needed). + +#### 5.3 Orchestration Runtime + +- Implement an orchestrator (in Nx/EXLA) that performs: + + 1. Run `stage_0` on its device → get outputs. + 2. Transfer these outputs to host (if needed). + 3. Invoke the Elixir callback via the dispatcher → get callback outputs. + 4. 
Transfer callback outputs to the device for `stage_1` (if needed). + 5. Repeat until all stages and callbacks are executed. + +- This orchestration: + + - Provides **consistent semantics** across CPU and GPU. + - Keeps the user API (`Nx.elixir_call/3`) unchanged. + - Allows future optimizations where: + - Some callbacks are compiled away (if expressible in pure Nx). + - Some backends choose device-specific mechanisms (e.g. XLA GPU host-callbacks or infeed/outfeed) internally. + +#### 5.4 Compiler Decisions (Future Work) + +- Over time, the compiler can classify callbacks: + + - **Pure, shape-stable callbacks definable in Nx**: + - Potentially inline/compile them, removing the runtime callback. + + - **Genuinely dynamic callbacks**: + - Keep them as segmentation boundaries. + +- This addresses the concern raised in [nx#1627](https://github.com/elixir-nx/nx/pull/1627) about having a **unified specification** for callbacks while allowing the compiler to choose between splitting and compiling. + +--- + +### 6. Open Questions / Next Steps + +- **Naming and API**: + - Finalize `Nx.elixir_call/3` naming and argument order. + - Decide whether to expose more advanced options (timeouts, impurity markers, etc.) in `opts`. + +- **Gradients**: + - For Phase 1, gradients may be: + - Not supported for arbitrary callbacks (raise on use), or + - Supported only when the callback is expressible in Nx and compiled away (future optimization). + +- **Concurrency model**: + - Decide how many bridge threads to run. + - Understand the interaction with multiple concurrent EXLA runs and multiple callback-heavy computations. + +- **Device-specific optimizations** (beyond segmentation): + - Investigate XLA’s GPU host-callback support and whether to implement a more tightly integrated path for GPU (possibly involving infeed/outfeed under the hood) once the segmentation version is stable. + +--- + +### 7. Implementation Order Checklist + +1.
**Land / refine `Nx.elixir_call/3` API and IR node** (based on [nx#1627](https://github.com/elixir-nx/nx/pull/1627)). +2. **Add EXLA lowering** for CPU: + - Map `elixir_call` → HLO/StableHLO `CustomCall` with target `"exla_elixir_callback"`. +3. **Implement native callback registry + bridge thread** in EXLA NIF. +4. **Register CPU host `CustomCall` handler** (`"exla_elixir_callback"`) and wire it to the bridge. +5. **Implement Elixir dispatcher process** for callbacks + error handling + sanity checks. +6. **Add tests** for CPU: + - Simple callbacks. + - Multiple callbacks in a single `defn`. + - Error cases (shape mismatch, thrown exceptions). +7. **Introduce segmentation in `Nx.Defn.Graph`**: + - Identify stages between `elixir_call` nodes. + - Compile/orchestrate stages for CPU/GPU. +8. **Extend EXLA to allow callbacks under segmentation** when using GPU clients. +9. Iterate on compiler-side heuristics to decide when callbacks can be compiled away vs split. + + + From 37a15afa9445eb695499af0c89bde0cefe1ed760 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 23 Nov 2025 17:29:53 -0300 Subject: [PATCH 06/42] feat: seemingly working mvp --- .../exla/custom_calls/elixir_callback.cc | 122 ++++++++ exla/c_src/exla/elixir_callback_bridge.h | 49 +++ exla/c_src/exla/exla.cc | 280 ++++++++++++++++-- exla/lib/exla/callback_server.ex | 81 +++-- exla/lib/exla/defn.ex | 36 ++- exla/lib/exla/mlir/value.ex | 12 +- exla/test/exla/defn/elixir_call_exla_test.exs | 18 +- exla/test/exla/defn/elixir_call_test.exs | 1 - nx/lib/nx/defn/evaluator.ex | 4 +- nx/lib/nx/defn/expr.ex | 6 +- nx/lib/nx/defn/tree.ex | 4 +- 11 files changed, 534 insertions(+), 79 deletions(-) create mode 100644 exla/c_src/exla/custom_calls/elixir_callback.cc create mode 100644 exla/c_src/exla/elixir_callback_bridge.h diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc new file mode 100644 index 
0000000000..acf1eb6a51 --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -0,0 +1,122 @@ +#include "../elixir_callback_bridge.h" + +#include +#include +#include + +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + +namespace { + +ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, + ffi::RemainingRets rets) { + if (args.size() == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "exla_elixir_callback expects at least one argument"); + } + + // The last argument is a scalar S64 tensor carrying the callback id. + size_t id_index = args.size() - 1; + + auto id_buf_or = args.get(id_index); + if (!id_buf_or) { + return id_buf_or.error(); + } + + ffi::AnyBuffer id_buf = *id_buf_or; + + if (id_buf.element_count() != 1 || + id_buf.element_type() != ffi::DataType::S64) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "exla_elixir_callback callback id must be scalar s64"); + } + + int64_t callback_id = id_buf.reinterpret_data()[0]; + + // Collect all remaining input tensors (excluding callback id) into + // lightweight payloads. + std::vector inputs; + inputs.reserve(args.size() - 1); + + for (size_t i = 0; i < id_index; ++i) { + auto maybe_buf_or = args.get(i); + if (!maybe_buf_or) { + return maybe_buf_or.error(); + } + + ffi::AnyBuffer buf = *maybe_buf_or; + + exla::ElixirCallbackTensor tensor; + tensor.dtype = buf.element_type(); + + auto dims = buf.dimensions(); + tensor.dims.assign(dims.begin(), dims.end()); + + size_t size_bytes = buf.size_bytes(); + tensor.data.resize(size_bytes); + if (size_bytes > 0) { + std::memcpy(tensor.data.data(), buf.untyped_data(), size_bytes); + } + + inputs.push_back(std::move(tensor)); + } + + // Call back into Elixir through the bridge. 
+ exla::ElixirCallbackResult result = + exla::CallElixirCallback(callback_id, inputs); + + if (!result.ok) { + return ffi::Error(ffi::ErrorCode::kInternal, result.error); + } + + if (result.outputs.size() != rets.size()) { + return ffi::Error( + ffi::ErrorCode::kInternal, + "mismatched number of callback outputs vs custom_call results"); + } + + // Copy returned binaries into the result buffers. We rely on the Elixir side + // (Nx.elixir_call/3) to have already validated shapes and dtypes. + for (size_t i = 0; i < rets.size(); ++i) { + auto maybe_ret_or = rets.get(i); + if (!maybe_ret_or) { + return maybe_ret_or.error(); + } + + ffi::Result ret = *maybe_ret_or; + ffi::AnyBuffer out = *ret; + + const auto &payload = result.outputs[i]; + + size_t expected = + ffi::ByteWidth(out.element_type()) * out.element_count(); + + if (payload.data.size() != expected) { + return ffi::Error( + ffi::ErrorCode::kInternal, + "callback returned binary of unexpected size for result buffer"); + } + + if (expected > 0) { + std::memcpy(out.untyped_data(), payload.data.data(), expected); + } + } + + return ffi::Error::Success(); +} + +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + exla_elixir_callback, exla_elixir_callback_impl, + ffi::Ffi::Bind() + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host", + exla_elixir_callback); + + diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h new file mode 100644 index 0000000000..1f96e5e0c1 --- /dev/null +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include "xla/ffi/api/ffi.h" + +namespace exla { + +// Lightweight tensor payload used to transfer arguments and results between +// the XLA host CustomCall handler and the Elixir dispatcher. 
+struct ElixirCallbackTensor { + xla::ffi::DataType dtype; + std::vector dims; + std::vector data; +}; + +struct ElixirCallbackResult { + bool ok = false; + std::string error; + std::vector outputs; +}; + +// Registers the Elixir dispatcher process that will receive callback requests. +void SetElixirCallbackDispatcher(ErlNifPid dispatcher_pid); + +// Called from the Elixir side to deliver a reply for a given callback tag. +void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, + fine::Term payload); + +// Synchronously calls the Elixir callback identified by `callback_id` with the +// given tensor arguments. This function: +// +// * Allocates a unique reply_tag +// * Sends a message to the dispatcher via enif_send/3 +// * Blocks the calling native thread until the reply arrives via +// DeliverElixirCallbackReply/3 +// +// It returns an ElixirCallbackResult that either contains a list of output +// tensors (on success) or an error message. +ElixirCallbackResult CallElixirCallback(int64_t callback_id, + const std::vector &inputs); + +} // namespace exla + + diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index e0842599b6..eac0b0683b 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -2,12 +2,15 @@ #include #include #include +#include +#include #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" #include "exla_mlir.h" #include "exla_nif_util.h" +#include "elixir_callback_bridge.h" #include "ipc.h" #include "mlir/IR/MLIRContext.h" #include "stablehlo/dialect/ChloOps.h" @@ -525,29 +528,19 @@ FINE_NIF(get_per_device_memory, 0); namespace { -// Very small, CPU-only bridge that forwards callback requests from the XLA -// host CustomCall to the Elixir dispatcher process. -// -// For Phase 1 we keep this intentionally simple: -// * Only single-output, single-replica computations are supported. -// * Arguments and results are transferred by value as host binaries. 
-// * Each request is synchronous: the CustomCall will block the XLA host -// thread until the Elixir side replies via `elixir_callback_reply/2`. -// -// This can be evolved later to support batching, more efficient tensor -// encoding, and timeouts. - -struct ElixirCallbackRequest { - int64_t callback_id; - std::vector args; - ERL_NIF_TERM reply_tag; +struct ElixirCallbackPending { + std::mutex mu; + std::condition_variable cv; + bool done = false; + ElixirCallbackResult result; }; -// Global state for the bridge. For simplicity we keep a single dispatcher -// PID and use a monotonically increasing integer as reply_tag. struct ElixirCallbackBridgeState { ErlNifPid dispatcher_pid; + bool dispatcher_set = false; std::atomic next_tag{1}; + std::mutex mu; + std::unordered_map> pending; }; ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { @@ -555,30 +548,259 @@ ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { return state; } +// Map ffi::DataType to a Nx-style {atom, bits} pair used on the Elixir side. 
+std::pair +EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { + const char *atom = nullptr; + int bits = 0; + + switch (dtype) { + case xla::ffi::PRED: + atom = "u"; + bits = 8; + break; + case xla::ffi::S8: + atom = "s"; + bits = 8; + break; + case xla::ffi::S16: + atom = "s"; + bits = 16; + break; + case xla::ffi::S32: + atom = "s"; + bits = 32; + break; + case xla::ffi::S64: + atom = "s"; + bits = 64; + break; + case xla::ffi::U8: + atom = "u"; + bits = 8; + break; + case xla::ffi::U16: + atom = "u"; + bits = 16; + break; + case xla::ffi::U32: + atom = "u"; + bits = 32; + break; + case xla::ffi::U64: + atom = "u"; + bits = 64; + break; + case xla::ffi::F16: + atom = "f"; + bits = 16; + break; + case xla::ffi::F32: + atom = "f"; + bits = 32; + break; + case xla::ffi::F64: + atom = "f"; + bits = 64; + break; + case xla::ffi::BF16: + atom = "bf"; + bits = 16; + break; + case xla::ffi::C64: + atom = "c"; + bits = 64; + break; + case xla::ffi::C128: + atom = "c"; + bits = 128; + break; + default: + atom = "f"; + bits = 32; + break; + } + + ERL_NIF_TERM atom_term = enif_make_atom(env, atom); + ERL_NIF_TERM bits_term = enif_make_int(env, bits); + return {atom_term, bits_term}; +} + } // namespace -std::tuple, fine::Error> +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { + (void)env; auto state = GetElixirCallbackBridgeState(); state->dispatcher_pid = dispatcher_pid; - return std::make_tuple(fine::Ok<>(), fine::Error()); + state->dispatcher_set = true; + return fine::Ok(); } FINE_NIF(start_elixir_callback_bridge, 0); -std::tuple, fine::Error> -elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, fine::Term _payload) { - // For Phase 1 we do not implement a native waiting mechanism; instead the - // CustomCall handler calls directly into Elixir and returns immediately. - // This NIF exists only as a placeholder for future, more advanced bridges. 
- (void)env; - (void)reply_tag; - (void)_payload; - return std::make_tuple(fine::Ok<>(), fine::Error()); +fine::Ok<> +elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, fine::Term payload) { + DeliverElixirCallbackReply(env, reply_tag, payload); + return fine::Ok(); } FINE_NIF(elixir_callback_reply, 0); +void SetElixirCallbackDispatcher(ErlNifPid dispatcher_pid) { + auto state = GetElixirCallbackBridgeState(); + state->dispatcher_pid = dispatcher_pid; + state->dispatcher_set = true; +} + +void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, + fine::Term payload) { + auto state = GetElixirCallbackBridgeState(); + + std::shared_ptr pending; + { + std::lock_guard lock(state->mu); + auto it = state->pending.find(reply_tag); + if (it == state->pending.end()) { + return; + } + pending = it->second; + } + + ElixirCallbackResult result; + + int arity = 0; + const ERL_NIF_TERM *tuple = nullptr; + ERL_NIF_TERM term = payload; + + if (!enif_get_tuple(env, term, &arity, &tuple) || arity != 2) { + result.ok = false; + result.error = "invalid callback reply payload, expected {status, value}"; + } else { + char atom_buf[16]; + if (enif_get_atom(env, tuple[0], atom_buf, sizeof(atom_buf), + ERL_NIF_LATIN1) && + strcmp(atom_buf, "ok") == 0) { + // tuple[1] is a list of binaries representing outputs. 
+ ERL_NIF_TERM list = tuple[1]; + ERL_NIF_TERM head, tail; + + while (enif_get_list_cell(env, list, &head, &tail)) { + ErlNifBinary bin; + if (!enif_inspect_binary(env, head, &bin)) { + result.ok = false; + result.error = "invalid binary in callback reply"; + break; + } + + ElixirCallbackTensor tensor; + tensor.dtype = xla::ffi::DataType::INVALID; + tensor.dims = {}; + tensor.data.assign(bin.data, bin.data + bin.size); + result.outputs.push_back(std::move(tensor)); + + list = tail; + } + + if (result.error.empty()) { + result.ok = true; + } + } else { + result.ok = false; + result.error = "elixir callback returned error"; + } + } + + { + std::lock_guard lock(pending->mu); + pending->result = std::move(result); + pending->done = true; + } + + pending->cv.notify_one(); + + { + std::lock_guard lock(state->mu); + state->pending.erase(reply_tag); + } +} + +ElixirCallbackResult +CallElixirCallback(int64_t callback_id, + const std::vector &inputs) { + auto state = GetElixirCallbackBridgeState(); + + if (!state->dispatcher_set) { + ElixirCallbackResult res; + res.ok = false; + res.error = "EXLA elixir callback dispatcher is not set"; + return res; + } + + auto pending = std::make_shared(); + + int64_t tag = state->next_tag.fetch_add(1, std::memory_order_relaxed); + + { + std::lock_guard lock(state->mu); + state->pending.emplace(tag, pending); + } + + ErlNifEnv *msg_env = enif_alloc_env(); + + // Encode arguments as [{bin, {type, bits}, shape_list}, ...] 
+ std::vector args_terms; + args_terms.reserve(inputs.size()); + + for (const auto &tensor : inputs) { + ERL_NIF_TERM bin_term; + unsigned char *bin_data = + enif_make_new_binary(msg_env, tensor.data.size(), &bin_term); + if (tensor.data.size() > 0) { + memcpy(bin_data, tensor.data.data(), tensor.data.size()); + } + + auto [type_atom, bits_term] = EncodeNxType(msg_env, tensor.dtype); + + ERL_NIF_TERM type_tuple = + enif_make_tuple2(msg_env, type_atom, bits_term); + + std::vector dim_terms; + dim_terms.reserve(tensor.dims.size()); + for (auto d : tensor.dims) { + dim_terms.push_back(enif_make_int64(msg_env, d)); + } + + ERL_NIF_TERM shape_list = + enif_make_list_from_array(msg_env, dim_terms.data(), + dim_terms.size()); + + ERL_NIF_TERM arg_tuple = + enif_make_tuple3(msg_env, bin_term, type_tuple, shape_list); + + args_terms.push_back(arg_tuple); + } + + ERL_NIF_TERM args_list = + enif_make_list_from_array(msg_env, args_terms.data(), + args_terms.size()); + + ERL_NIF_TERM tag_term = enif_make_int64(msg_env, tag); + ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); + + ERL_NIF_TERM msg = + enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), + cb_term, args_list, tag_term); + + enif_send(msg_env, &state->dispatcher_pid, msg_env, msg); + enif_free_env(msg_env); + + std::unique_lock lock(pending->mu); + pending->cv.wait(lock, [&pending] { return pending->done; }); + + return pending->result; +} + // Logging fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) { diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index b3434a69b0..449e5e987d 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -38,7 +38,10 @@ defmodule EXLA.CallbackServer do @type t :: %__MODULE__{ next_id: non_neg_integer(), - callbacks: %{callback_id() => {fun(), Nx.t() | tuple()}} + # We store the original function, its output template, and any + # static (non-tensor) arguments that should always 
be appended to + # the decoded tensor arguments coming from native. + callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}} } ## Public API @@ -54,15 +57,16 @@ defmodule EXLA.CallbackServer do end @doc """ - Registers a callback function and its output template, returning a callback id. + Registers a callback function, its output template, and static arguments, returning a callback id. - The same `{fun, out_template}` pair will always return the same id for the - lifetime of this VM. This id is what the EXLA compiler should encode into - the host `CustomCall` so the native side can reference the right callback. + The same `{fun, out_template, static_args}` triple will always return the + same id for the lifetime of this VM. This id is what the EXLA compiler + encodes into the host `CustomCall` so the native side can reference the + right callback. """ - @spec register(fun(), Nx.t() | tuple()) :: callback_id() - def register(fun, out_template) when is_function(fun) do - GenServer.call(__MODULE__, {:register, fun, out_template}) + @spec register(fun(), Nx.t() | tuple(), [term()]) :: callback_id() + def register(fun, out_template, static_args) when is_function(fun) and is_list(static_args) do + GenServer.call(__MODULE__, {:register, fun, out_template, static_args}) end ## GenServer callbacks @@ -85,8 +89,8 @@ defmodule EXLA.CallbackServer do end @impl true - def handle_call({:register, fun, out_template}, _from, %__MODULE__{} = state) do - key = {fun, Nx.to_template(out_template)} + def handle_call({:register, fun, out_template, static_args}, _from, %__MODULE__{} = state) do + key = {fun, Nx.to_template(out_template), static_args} {id, state} = case find_existing_id(state.callbacks, key) do @@ -95,19 +99,22 @@ defmodule EXLA.CallbackServer do :error -> id = state.next_id - callbacks = Map.put(state.callbacks, id, {fun, Nx.to_template(out_template)}) - {%{state | callbacks: callbacks, next_id: id + 1}.next_id - 1, %{state | callbacks: callbacks, next_id: 
id + 1}} + callbacks = Map.put(state.callbacks, id, {fun, Nx.to_template(out_template), static_args}) + {%{state | callbacks: callbacks, next_id: id + 1}.next_id - 1, + %{state | callbacks: callbacks, next_id: id + 1}} end {:reply, id, state} end @impl true - def handle_info({:exla_elixir_call, callback_id, args, reply_tag}, %__MODULE__{} = state) do + def handle_info({:exla_elixir_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do case Map.fetch(state.callbacks, callback_id) do - {:ok, {fun, out_template}} -> + {:ok, {fun, out_template, static_args}} -> reply_payload = - run_callback(fun, args, out_template) + args_spec + |> decode_args() + |> run_callback(fun, static_args, out_template) |> encode_reply() send_reply(reply_tag, reply_payload) @@ -137,10 +144,12 @@ defmodule EXLA.CallbackServer do end) end - defp run_callback(fun, args, out_template) do + defp run_callback({:error, reason}, _fun, _static_args, _out_template), do: {:error, reason} + + defp run_callback({:ok, tensor_args}, fun, static_args, out_template) do result = try do - apply(fun, args) + apply(fun, tensor_args ++ static_args) rescue exception -> {:error, {:exception, exception, __STACKTRACE__}} @@ -155,7 +164,7 @@ defmodule EXLA.CallbackServer do value -> case ensure_compatible(value, out_template) do - {:ok, tensor_or_tuple} -> {:ok, tensor_or_tuple} + {:ok, compatible} -> {:ok, compatible} {:error, reason} -> {:error, reason} end end @@ -189,9 +198,43 @@ defmodule EXLA.CallbackServer do defp ensure_compatible(left, right), do: {:error, {:invalid_result, left, right}} - defp encode_reply({:ok, value}), do: {:ok, value} + defp decode_args(args_spec) when is_list(args_spec) do + result = + Enum.reduce_while(args_spec, {:ok, []}, fn {bin, {type, bits}, shape}, {:ok, acc} -> + try do + tensor = + bin + |> Nx.from_binary({type, bits}) + |> Nx.reshape(List.to_tuple(shape)) + + {:cont, {:ok, [tensor | acc]}} + rescue + exception -> + {:halt, {:error, {:decode_failed, exception}}} + 
end + end) + + case result do + {:ok, tensors} -> {:ok, Enum.reverse(tensors)} + {:error, _} = error -> error + end + end + + defp decode_args(other), do: {:error, {:invalid_args_spec, other}} + + defp encode_reply({:ok, value}), do: {:ok, encode_outputs(value)} defp encode_reply({:error, reason}), do: {:error, reason} + defp encode_outputs(%Nx.Tensor{} = tensor) do + [Nx.to_binary(tensor)] + end + + defp encode_outputs(tuple) when is_tuple(tuple) do + tuple + |> Tuple.to_list() + |> Enum.map(&Nx.to_binary/1) + end + defp send_reply(reply_tag, payload) do try do EXLA.NIF.elixir_callback_reply(reply_tag, payload) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index a2a04e8cd5..cf29af1afa 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -548,22 +548,33 @@ defmodule EXLA.Defn do defp cached_recur_operator( :elixir_call, - %T{data: %Expr{args: [in_args, fun]}} = expr, + %T{data: %Expr{args: [in_args, fun, out_template]}} = expr, %{client: %EXLA.Client{platform: :host}} = state, cache ) do - {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) {call_args, cache} = Enum.map_reduce(tensor_args, cache, fn arg, cache -> recur_operator(arg, state, cache) |> unwrap_single_tensor!() end) - callback_id = EXLA.CallbackServer.register(fun, Nx.to_template(expr)) - typespecs = container_to_typespecs(expr) + static_args = Enum.drop(in_args, length(tensor_args)) + + callback_id = EXLA.CallbackServer.register(fun, out_template, static_args) + typespecs = container_to_typespecs(out_template) + + # Pass callback id as an extra scalar s64 operand at the end so that the + # native handler can retrieve it without relying on backend_config attrs. 
+ callback_id_typespec = Typespec.tensor({:s, 64}, {}) + + callback_id_value = + Value.constant(state.builder, [callback_id], callback_id_typespec) + + operands = call_args ++ [callback_id_value] results = - Value.elixir_call(call_args, callback_id, typespecs) + Value.elixir_call(operands, typespecs) {wrap_tuple_result(results, expr), cache} end @@ -1928,14 +1939,25 @@ defmodule EXLA.Defn do end defp container_to_typespecs(container) do - [container] + containers = + if is_list(container) do + container + else + [container] + end + + containers + |> Enum.reject(&is_function/1) |> Nx.Defn.Composite.flatten_list() |> Enum.flat_map(fn %Nx.Tensor{type: {:tuple, _}, data: %{args: values}} -> Enum.flat_map(values, &container_to_typespecs/1) - t -> + %Nx.Tensor{} = t -> [Typespec.tensor(t.type, t.shape)] + + _other -> + [] end) end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index fcc1ef8dd9..5f33f788b2 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -837,19 +837,15 @@ defmodule EXLA.MLIR.Value do The `callback_id` is a small integer assigned by `EXLA.CallbackServer` that identifies which Elixir function should be invoked when the host callback - runs. The native side is expected to read this id from the backend config - or attributes and route the callback accordingly. + runs. It is passed as an extra scalar S64 tensor operand (last argument) to + the custom call. """ - def elixir_call([%Value{function: func} | _] = operands, callback_id, typespecs) - when is_integer(callback_id) and callback_id >= 0 do + def elixir_call([%Value{function: func} | _] = operands, typespecs) do result_types = typespecs_to_mlir_types(typespecs) attributes = [ call_target_name: attr_string("exla_elixir_callback"), - api_version: attr_i32(4), - # We currently encode the callback id as a backend config string. - # The native handler should parse this value back into an integer. 
- backend_config: attr_string(Integer.to_string(callback_id)) + api_version: attr_i32(4) ] op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) diff --git a/exla/test/exla/defn/elixir_call_exla_test.exs b/exla/test/exla/defn/elixir_call_exla_test.exs index 75feb748aa..5bc4c6decf 100644 --- a/exla/test/exla/defn/elixir_call_exla_test.exs +++ b/exla/test/exla/defn/elixir_call_exla_test.exs @@ -51,19 +51,19 @@ defmodule EXLA.Defn.ElixirCallEXLATest do assert_equal(y, expected) end - test "elixir_call errors when result shape does not match template" do - defn bad_callback(x) do - out = %{x | type: Nx.Type.to_floating(x.type)} + defn bad_callback(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x], fn _t -> - # Wrong shape on purpose - Nx.tensor([1.0, 2.0, 3.0]) - end) - end + Nx.elixir_call(out, [x], fn _t -> + # Wrong shape on purpose + Nx.tensor([1.0, 2.0, 3.0]) + end) + end + test "elixir_call errors when result shape does not match template" do x = Nx.iota({2}) - assert_raise ArgumentError, ~r/expected the elixir_call function to match/, fn -> + assert_raise RuntimeError, ~r/elixir callback returned error/, fn -> bad_callback(x) end end diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index 09d10a02be..f3e52aaea3 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -4,7 +4,6 @@ defmodule EXLA.Defn.ElixirCallEvaluatorTest do import Nx.Testing setup do - Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) Nx.default_backend(EXLA.Backend) :ok end diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index c913f4ec3c..d088366a48 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -176,7 +176,7 @@ defmodule Nx.Defn.Evaluator do end defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do - [in_args, _fun] = args + [in_args, _fun, _out_template] = args 
Enum.reduce(in_args, cache, fn t, cache when is_list(t) -> cache @@ -442,7 +442,7 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [in_args, fun]}} = expr, + %{data: %Expr{args: [in_args, fun, _out_template]}} = expr, state, caches ) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 1d488df888..9794749db6 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -393,11 +393,13 @@ defmodule Nx.Defn.Expr do case out do t when is_struct(t, Nx.Tensor) -> - expr(t, context, :elixir_call, [in_args, fun]) + out_template = Nx.to_template(t) + expr(t, context, :elixir_call, [in_args, fun, out_template]) tuple when is_tuple(tuple) -> out_template = tuple_out(tuple_size(tuple)) - expr_node = expr(out_template, context, :elixir_call, [in_args, fun]) + user_template = Nx.to_template(tuple) + expr_node = expr(out_template, context, :elixir_call, [in_args, fun, user_template]) tuple(expr_node, Tuple.to_list(tuple)) end end diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 733a131e4f..02ee4d001c 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -193,7 +193,7 @@ defmodule Nx.Defn.Tree do end def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do - [in_args, callback] = args + [in_args, callback, out_template] = args {in_args, acc} = Enum.map_reduce(in_args, acc, fn t, acc -> @@ -204,7 +204,7 @@ defmodule Nx.Defn.Tree do end end) - {[in_args, callback], acc} + {[in_args, callback, out_template], acc} end def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do From 7127a8c8e0f74f04ac0c32196e7131fa1d3d7aa0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 24 Nov 2025 02:01:45 -0300 Subject: [PATCH 07/42] feat: first working version --- exla/c_src/exla/exla.cc | 20 ++- exla/lib/exla/callback_server.ex | 55 ++++++- exla/test/exla/defn/elixir_call_exla_test.exs | 2 +- 
nx_elixir_call_exla_design.md | 135 +++++++++++++++++- 4 files changed, 207 insertions(+), 5 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index eac0b0683b..5a32a82811 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -706,8 +706,26 @@ void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, result.ok = true; } } else { + // Error reply: tuple[1] is expected to be {kind_atom, message :: binary} result.ok = false; - result.error = "elixir callback returned error"; + ERL_NIF_TERM err_term = tuple[1]; + + int err_arity = 0; + const ERL_NIF_TERM *err_tuple = nullptr; + if (enif_get_tuple(env, err_term, &err_arity, &err_tuple) && + err_arity == 2) { + // We ignore the kind atom for now (e.g. :argument_error or + // :runtime_error) and use only the message as the XLA error text. + ErlNifBinary msg_bin; + if (enif_inspect_binary(env, err_tuple[1], &msg_bin)) { + result.error.assign(reinterpret_cast(msg_bin.data), + msg_bin.size); + } else { + result.error = "elixir callback returned error"; + } + } else { + result.error = "elixir callback returned error"; + } } } diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 449e5e987d..8642da8fcb 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -223,7 +223,60 @@ defmodule EXLA.CallbackServer do defp decode_args(other), do: {:error, {:invalid_args_spec, other}} defp encode_reply({:ok, value}), do: {:ok, encode_outputs(value)} - defp encode_reply({:error, reason}), do: {:error, reason} + + # Shape mismatch between callback result and output template. 
+  defp encode_reply({:error, {:shape_mismatch, left, right}}) do
+    msg =
+      "expected the elixir_call function to match the given output template " <>
+        "#{inspect(right)}, got: #{inspect(left)}"
+
+    {:error, {:argument_error, msg}}
+  end
+
+  # Callback returned something that isn't a tensor/tuple matching the template.
+  defp encode_reply({:error, {:invalid_result, left, right}}) do
+    msg =
+      "expected the elixir_call function to return a value compatible with the output " <>
+        "template #{inspect(right)}, got: #{inspect(left)}"
+
+    {:error, {:argument_error, msg}}
+  end
+
+  # Argument decoding failures: a tensor could not be rebuilt, or the spec was not a list.
+  defp encode_reply({:error, {:decode_failed, exception}}) do
+    msg = Exception.message(exception)
+    msg = "failed to decode Elixir callback arguments: #{msg}"
+    {:error, {:runtime_error, msg}}
+  end
+
+  defp encode_reply({:error, {:invalid_args_spec, other}}) do
+    msg = "invalid args_spec for Elixir callback: #{inspect(other)}"
+    {:error, {:runtime_error, msg}}
+  end
+
+  # Unknown callback id from native.
+  defp encode_reply({:error, :unknown_callback}) do
+    msg = "unknown EXLA elixir_call callback id"
+    {:error, {:runtime_error, msg}}
+  end
+
+  # User-raised exceptions; only the message is forwarded, the stacktrace is dropped.
+  defp encode_reply({:error, {:exception, exception, _stack}}) do
+    msg = Exception.message(exception)
+    msg = "Elixir callback raised: #{msg}"
+    {:error, {:runtime_error, msg}}
+  end
+
+  # Catches other error tuples (throws, exits, etc).
+ defp encode_reply({:error, {kind, reason}}) do + msg = "Elixir callback #{kind}: #{inspect(reason)}" + {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + end + + defp encode_reply({:error, reason}) do + msg = "Elixir callback error: #{inspect(reason)}" + {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + end defp encode_outputs(%Nx.Tensor{} = tensor) do [Nx.to_binary(tensor)] diff --git a/exla/test/exla/defn/elixir_call_exla_test.exs b/exla/test/exla/defn/elixir_call_exla_test.exs index 5bc4c6decf..3154a0a12d 100644 --- a/exla/test/exla/defn/elixir_call_exla_test.exs +++ b/exla/test/exla/defn/elixir_call_exla_test.exs @@ -63,7 +63,7 @@ defmodule EXLA.Defn.ElixirCallEXLATest do test "elixir_call errors when result shape does not match template" do x = Nx.iota({2}) - assert_raise RuntimeError, ~r/elixir callback returned error/, fn -> + assert_raise RuntimeError, ~r/expected the elixir_call function to match the given output template/, fn -> bad_callback(x) end end diff --git a/nx_elixir_call_exla_design.md b/nx_elixir_call_exla_design.md index aac22ee371..bacf3bb480 100644 --- a/nx_elixir_call_exla_design.md +++ b/nx_elixir_call_exla_design.md @@ -340,5 +340,136 @@ After Phase 1 is solid on CPU, we extend support to all EXLA devices (CPU/GPU) v 8. **Extend EXLA to allow callbacks under segmentation** when using GPU clients. 9. Iterate on compiler-side heuristics to decide when callbacks can be compiled away vs split. - - +### 8. Phase 1 – Intended vs Implemented (Status Summary) + +This section records where the **current implementation** matches or intentionally diverges from the original Phase 1 plan, so future work can see what is done vs still open. + +#### 8.1 Public API (`Nx.elixir_call/3`) + +- **Intended (this doc, §4.1)**: + - Shape: `Nx.elixir_call(args, fun_or_mfa, opts \\ [])`. + - Explicit `output_template` / `output_spec` passed via `opts`. 
+- **Implemented (Nx 0.10 + EXLA 0.10)**: + - Shape: `Nx.elixir_call(output_template, args, fun)`. + - `output_template` is the **first argument**, not an option. + - `args` is a list of runtime arguments (tensors + static values). + - `fun` is a plain Elixir function; we don’t support MFA in Phase 1. + - `Nx.Defn.Expr.elixir_call/3`: + - For tensor output: stores `:elixir_call` node with args `[in_args, fun, out_template]`, where `out_template = Nx.to_template(output)`. + - For tuple output: builds an internal tuple-shaped template (`tuple_out/1`) plus a **user template** (also passed as `out_template`) so EXLA sees the per-element shapes/dtypes. + - Rationale: keeping `output_template` as a *value argument* made the IR and EXLA lowering simpler and closer to existing `defn` conventions. + +#### 8.2 IR Representation (`Nx.Defn.Expr` / `Nx.Defn.Graph`) + +- **Intended (§4.2)**: + - `{:elixir_call, meta, args}` with `meta` carrying: + - `callback_id`, `fun_or_mfa`, `output_spec`, etc. +- **Implemented**: + - `Expr` op is still `:elixir_call`, but: + - We **do not** store `callback_id` in `meta`; it is managed solely by EXLA. + - Arguments are `[in_args, fun, out_template]`. + - Shape/type inference for the node uses `out_template` via the existing template machinery. + - `Nx.Defn.Tree.apply_args/4` and `Nx.Defn.Evaluator.compute_cache/4` / `eval_apply/4` were updated to be aware of the `out_template` third arg but largely **ignore it** at runtime (it is for compilation only). + +#### 8.3 EXLA Lowering to StableHLO `CustomCall` + +- **Intended (§4.3)**: + - `CustomCall("exla_elixir_callback")` with: + - Result types from `output_spec`. + - Attributes: + - `callback_id` (string/int). + - Possibly encoded `output_spec`. +- **Implemented**: + - We lower to a `stablehlo.custom_call` with: + - `call_target_name = "exla_elixir_callback"`. + - `api_version = 4` (typed FFI). + - **No `backend_config` or dictionary attributes** for callback id. 
+ - Instead of encoding `callback_id` as an attribute, we: + - Append a scalar S64 operand at the **end of the operand list** carrying `callback_id`. + - Register a typed FFI handler that: + - Interprets the last operand as the callback id. + - Treats the remaining operands as regular tensor arguments. + - Result types are derived from the `out_template`: + - `container_to_typespecs(out_template)` produces one `EXLA.Typespec` per tensor (including tuple elements). + +#### 8.4 Native Bridge & Callback Registry + +- **Intended (§4.4–4.5)**: + - Per-run mapping `callback_id → {fun_or_mfa, output_spec}`. + - Bridge using `RunRef`, `CallbackRequest`, `CallbackResult`, per-run state. +- **Implemented**: + - We use a **global, process-wide `EXLA.CallbackServer`**: + - Maps `callback_id (integer)` → `{fun, out_template, static_args}`. + - Reuses ids when the same `{fun, out_template, static_args}` triple is registered again. + - The C++ bridge maintains: + - A global `ElixirCallbackBridgeState` with: + - `dispatcher_pid`, `next_tag`, and a `pending` map from `reply_tag` to a small `ElixirCallbackPending` object (`std::mutex`, `std::condition_variable`, `ElixirCallbackResult`). + - The **bridge thread** concept is realized as: + - The host `CustomCall` handler runs on an XLA-controlled thread. + - It **blocks natively** on a `std::condition_variable` associated with a `reply_tag` until the BEAM side replies via `EXLA.NIF.elixir_callback_reply/2`. + - There is currently **no per-run `RunRef`**; callbacks are effectively global to the VM. + +#### 8.5 Elixir Dispatcher (`EXLA.CallbackServer`) + +- **Intended (§4.6)**: + - Dedicated dispatcher process keyed by `(run_ref, callback_id)`. + - Messages like `{:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}`. +- **Implemented**: + - `EXLA.CallbackServer` is a `GenServer` with: + - `callbacks :: %{callback_id => {fun, out_template, static_args}}`. 
+ - Native side sends: + - `{:exla_elixir_call, callback_id, args_spec, reply_tag}`. + - `args_spec` is a list of `{bin, {type_atom, bits}, shape_list}` tuples. + - Dispatcher logic: + - Decodes `args_spec` into `Nx.Tensor`s (`Nx.from_binary/3` + `Nx.reshape/2`). + - Appends `static_args` captured at registration time. + - Executes the callback function with `[tensor_args ++ static_args]`. + - Validates the result against `out_template`: + - Tuple size check + per-element shape/dtype/names check. + +#### 8.6 Error Handling & Mapping + +- **Intended (§4.7)**: + - Shape/dtype validation vs `output_spec`. + - Clear errors; possibly mapping to `ArgumentError` or similar. +- **Implemented (Phase 1)**: + - `EXLA.CallbackServer.ensure_compatible/2`: + - Returns `{:ok, value}` on success. + - Returns tagged errors: + - `{:error, {:shape_mismatch, left, right}}`. + - `{:error, {:invalid_result, left, right}}` for non-tensors/tuples. + - `encode_reply/1` maps internal error tuples to **typed error payloads**: + - Shape mismatch / invalid result: + - Encoded as `{:error, {:argument_error, message_binary}}` where the message mirrors `Nx.ensure_call_compatible!/2`, e.g. + - `"expected the elixir_call function to match the given output template ..., got: ..."` + - Decode failures, invalid args spec, unknown callback id, user exceptions, throws, exits: + - Encoded as `{:error, {:runtime_error, message_binary}}` with descriptive text (`"Elixir callback raised: ..."`, etc.). + - Native `DeliverElixirCallbackReply`: + - For `{:ok, binaries}`, fills result buffers. + - For `{:error, {kind_atom, message_binary}}`: + - Uses the **message** as `result.error` string returned to XLA. + - As a result: + - From the user’s point of view, `Nx.elixir_call/3` under EXLA now fails with: + - `RuntimeError` carrying a **descriptive, Nx-style message** (e.g. shape mismatch) instead of the generic `"elixir callback returned error"`. 
+ - We do **not** currently raise `ArgumentError` directly from EXLA runs; everything surfaces as `RuntimeError` with a rich message, which tests explicitly assert. + +#### 8.7 Timeouts & Robustness + +- **Intended (§4.7, Timeouts)**: + - Optional per-callback timeout. + - Native side aborts run on timeout. +- **Implemented**: + - **No timeouts yet**: + - The host `CustomCall` waits indefinitely on `condition_variable` for the reply. + - If the Elixir dispatcher never replies, the XLA run will hang. + - This is an explicit **TODO** for future hardening; the design here still stands, but is not yet implemented. + +#### 8.8 Phase 2 – Not Implemented Yet + +All of §5 (segmentation and cross-device support) remains **design only**: + +- `elixir_call` is **not yet used as a segmentation boundary** in `Nx.Defn.Graph`. +- There is no multi-stage orchestration for GPU/TPU plus CPU callbacks. +- Callbacks are only allowed on the **EXLA host (CPU) client**, and we eagerly raise if the client platform is not `:host`. + +This section should be updated again once segmentation and GPU support are implemented. 
From bc522051437f355c3bb81627d8e115f6a7739016 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:54:53 -0300 Subject: [PATCH 08/42] wip: step through code review --- code_review.md | 27 +++++++ .../exla/custom_calls/elixir_callback.cc | 8 +-- exla/c_src/exla/elixir_callback_bridge.h | 3 - exla/c_src/exla/exla.cc | 45 ++++++------ exla/lib/exla/callback_server.ex | 29 +++----- exla/lib/exla/defn.ex | 9 +-- exla/lib/exla/mlir/value.ex | 6 +- exla/test/exla/defn/elixir_call_exla_test.exs | 70 ------------------- exla/test/exla/defn/elixir_call_test.exs | 21 +++++- nx/lib/nx.ex | 8 +++ nx/lib/nx/defn/expr.ex | 11 +++ .../nx/defn/elixir_call_evaluator_test.exs | 15 ++++ nx_elixir_call_exla_design.md | 5 +- 13 files changed, 127 insertions(+), 130 deletions(-) create mode 100644 code_review.md delete mode 100644 exla/test/exla/defn/elixir_call_exla_test.exs diff --git a/code_review.md b/code_review.md new file mode 100644 index 0000000000..276e22c184 --- /dev/null +++ b/code_review.md @@ -0,0 +1,27 @@ +review notes: + +- Nx.Defn Expr added a third out_template argument, but the template can be inferred from the expression itself. +- exla: elixir_call_test and elixir_call_exla_test are redundant with each other. we can combine them in a single file. + +- callback server and the rest of the code seem to assume that tensor arguments are always at the beginning of the function. We should enforce this more clearly and document this. +- given that the callback server is named, enif_whereis_pid(https://www.erlang.org/doc/apps/erts/erl_nif.html#enif_whereis_pid) could be used to fetch the current pid for the function. +- EXLA.CallbackServer decode_args uses from_binary without options. We should keep track of the backend options such as the device used to allocate an EXLA tensor. Ideally, we shouldn't even be copying data back and forth. shape should already be passed as a tuple from the NIF. 
+- EXLA.CallbackServer does encode_reply/encode_outputs really need to "Nx.to_binary" the results? It seems like we should be able to pass EXLA Buffer refs back and forth. + +- EXLA.Defn operands = call_args ++ [callback_id_value] we should prepend the id instead of append id +- What is api_version in EXLA.MLIR.Value.elixir_call? +- Should the callback_id instead be an attribute given that it should not change during execution? + +---- + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + exla_elixir_callback, exla_elixir_callback_impl, + ffi::Ffi::Bind() + .RemainingArgs() + .RemainingRets()); + +This could receive the id in the first argument, and the tensors in the second. + +exla.cc:550 I think there is already another function for mapping exla to nx types. + +exla.cc:648 FINE_NIF(elixir_callback_reply, 0) I think is missing an IO-bound attr \ No newline at end of file diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index acf1eb6a51..4da48dd122 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -18,10 +18,8 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, "exla_elixir_callback expects at least one argument"); } - // The last argument is a scalar S64 tensor carrying the callback id. - size_t id_index = args.size() - 1; - - auto id_buf_or = args.get(id_index); + // The first argument is a scalar S64 tensor carrying the callback id. 
+ auto id_buf_or = args.get(0); if (!id_buf_or) { return id_buf_or.error(); } @@ -41,7 +39,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, std::vector inputs; inputs.reserve(args.size() - 1); - for (size_t i = 0; i < id_index; ++i) { + for (size_t i = 1; i < args.size(); ++i) { auto maybe_buf_or = args.get(i); if (!maybe_buf_or) { return maybe_buf_or.error(); diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index 1f96e5e0c1..e6d9414f54 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -24,9 +24,6 @@ struct ElixirCallbackResult { std::vector outputs; }; -// Registers the Elixir dispatcher process that will receive callback requests. -void SetElixirCallbackDispatcher(ErlNifPid dispatcher_pid); - // Called from the Elixir side to deliver a reply for a given callback tag. void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, fine::Term payload); diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 5a32a82811..a9da3e9205 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,16 +1,17 @@ +#include #include #include +#include #include #include #include -#include +#include "elixir_callback_bridge.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" #include "exla_mlir.h" #include "exla_nif_util.h" -#include "elixir_callback_bridge.h" #include "ipc.h" #include "mlir/IR/MLIRContext.h" #include "stablehlo/dialect/ChloOps.h" @@ -549,8 +550,8 @@ ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { } // Map ffi::DataType to a Nx-style {atom, bits} pair used on the Elixir side. 
-std::pair -EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { +std::pair EncodeNxType(ErlNifEnv *env, + xla::ffi::DataType dtype) { const char *atom = nullptr; int bits = 0; @@ -628,8 +629,8 @@ EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { } // namespace -fine::Ok<> -start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { (void)env; auto state = GetElixirCallbackBridgeState(); state->dispatcher_pid = dispatcher_pid; @@ -639,19 +640,13 @@ start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { FINE_NIF(start_elixir_callback_bridge, 0); -fine::Ok<> -elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, fine::Term payload) { +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, + fine::Term payload) { DeliverElixirCallbackReply(env, reply_tag, payload); return fine::Ok(); } -FINE_NIF(elixir_callback_reply, 0); - -void SetElixirCallbackDispatcher(ErlNifPid dispatcher_pid) { - auto state = GetElixirCallbackBridgeState(); - state->dispatcher_pid = dispatcher_pid; - state->dispatcher_set = true; -} +FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, fine::Term payload) { @@ -766,7 +761,7 @@ CallElixirCallback(int64_t callback_id, ErlNifEnv *msg_env = enif_alloc_env(); - // Encode arguments as [{bin, {type, bits}, shape_list}, ...] + // Encode arguments as [{bin, {type, bits}, shape_tuple}, ...] 
std::vector args_terms; args_terms.reserve(inputs.size()); @@ -780,8 +775,7 @@ CallElixirCallback(int64_t callback_id, auto [type_atom, bits_term] = EncodeNxType(msg_env, tensor.dtype); - ERL_NIF_TERM type_tuple = - enif_make_tuple2(msg_env, type_atom, bits_term); + ERL_NIF_TERM type_tuple = enif_make_tuple2(msg_env, type_atom, bits_term); std::vector dim_terms; dim_terms.reserve(tensor.dims.size()); @@ -789,19 +783,22 @@ CallElixirCallback(int64_t callback_id, dim_terms.push_back(enif_make_int64(msg_env, d)); } - ERL_NIF_TERM shape_list = - enif_make_list_from_array(msg_env, dim_terms.data(), - dim_terms.size()); + ERL_NIF_TERM shape_tuple; + if (dim_terms.empty()) { + shape_tuple = enif_make_tuple(msg_env, 0); + } else { + shape_tuple = enif_make_tuple_from_array(msg_env, dim_terms.data(), + dim_terms.size()); + } ERL_NIF_TERM arg_tuple = - enif_make_tuple3(msg_env, bin_term, type_tuple, shape_list); + enif_make_tuple3(msg_env, bin_term, type_tuple, shape_tuple); args_terms.push_back(arg_tuple); } ERL_NIF_TERM args_list = - enif_make_list_from_array(msg_env, args_terms.data(), - args_terms.size()); + enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); ERL_NIF_TERM tag_term = enif_make_int64(msg_env, tag); ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 8642da8fcb..7d69fa7ee7 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -79,11 +79,7 @@ defmodule EXLA.CallbackServer do # fail silently so that the rest of the system continues to work. This # allows developing the Elixir side and the native side independently. 
_ = - try do - EXLA.NIF.start_elixir_callback_bridge(self()) - rescue - _ -> :ok - end + EXLA.NIF.start_elixir_callback_bridge(self()) {:ok, %__MODULE__{}} end @@ -92,19 +88,16 @@ defmodule EXLA.CallbackServer do def handle_call({:register, fun, out_template, static_args}, _from, %__MODULE__{} = state) do key = {fun, Nx.to_template(out_template), static_args} - {id, state} = - case find_existing_id(state.callbacks, key) do - {:ok, id} -> - {id, state} + case find_existing_id(state.callbacks, key) do + {:ok, id} -> + {:reply, id, state} - :error -> - id = state.next_id - callbacks = Map.put(state.callbacks, id, {fun, Nx.to_template(out_template), static_args}) - {%{state | callbacks: callbacks, next_id: id + 1}.next_id - 1, - %{state | callbacks: callbacks, next_id: id + 1}} - end - - {:reply, id, state} + :error -> + id = state.next_id + state = put_in(state.callbacks[id], {fun, Nx.to_template(out_template), static_args}) + state = %{state | next_id: id + 1} + {:reply, id, state} + end end @impl true @@ -205,7 +198,7 @@ defmodule EXLA.CallbackServer do tensor = bin |> Nx.from_binary({type, bits}) - |> Nx.reshape(List.to_tuple(shape)) + |> Nx.reshape(shape) {:cont, {:ok, [tensor | acc]}} rescue diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index cf29af1afa..5ec08e9b0c 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -552,15 +552,13 @@ defmodule EXLA.Defn do %{client: %EXLA.Client{platform: :host}} = state, cache ) do - {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) + {tensor_args, static_args} = Enum.split_while(in_args, &(not is_list(&1))) {call_args, cache} = Enum.map_reduce(tensor_args, cache, fn arg, cache -> recur_operator(arg, state, cache) |> unwrap_single_tensor!() end) - static_args = Enum.drop(in_args, length(tensor_args)) - callback_id = EXLA.CallbackServer.register(fun, out_template, static_args) typespecs = container_to_typespecs(out_template) @@ -571,7 +569,7 @@ defmodule EXLA.Defn do 
callback_id_value = Value.constant(state.builder, [callback_id], callback_id_typespec) - operands = call_args ++ [callback_id_value] + operands = [callback_id_value | call_args] results = Value.elixir_call(operands, typespecs) @@ -583,14 +581,13 @@ defmodule EXLA.Defn do :elixir_call, _expr, %{client: %EXLA.Client{platform: platform}}, - cache + _cache ) do raise """ Nx.elixir_call/3 is currently only supported for EXLA CPU (platform: :host), but the active EXLA client is configured for platform #{inspect(platform)}. Please run on the :host client or wait for future segmentation-based support. """ - |> then(fn _ -> {nil, cache} end) end defp cached_recur_operator( diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 5f33f788b2..d92fc61320 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -837,14 +837,16 @@ defmodule EXLA.MLIR.Value do The `callback_id` is a small integer assigned by `EXLA.CallbackServer` that identifies which Elixir function should be invoked when the host callback - runs. It is passed as an extra scalar S64 tensor operand (last argument) to - the custom call. + runs. It is passed as an extra scalar S64 tensor operand (first argument) to + the custom call so the native handler can load it before touching any tensor + payloads. """ def elixir_call([%Value{function: func} | _] = operands, typespecs) do result_types = typespecs_to_mlir_types(typespecs) attributes = [ call_target_name: attr_string("exla_elixir_callback"), + # api_version 4 enables the typed FFI API used by our callback handler. 
api_version: attr_i32(4) ] diff --git a/exla/test/exla/defn/elixir_call_exla_test.exs b/exla/test/exla/defn/elixir_call_exla_test.exs deleted file mode 100644 index 3154a0a12d..0000000000 --- a/exla/test/exla/defn/elixir_call_exla_test.exs +++ /dev/null @@ -1,70 +0,0 @@ -defmodule EXLA.Defn.ElixirCallEXLATest do - use ExUnit.Case, async: true - import Nx.Defn - import Nx.Testing - - @moduletag :exla - - setup do - Nx.Defn.default_options(compiler: EXLA) - Nx.default_backend(EXLA.Backend) - :ok - end - - defn add_offset(x) do - out = %{x | type: Nx.Type.to_floating(x.type)} - - Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> - Nx.add(Nx.as_type(t, :f32), opts[:offset]) - end) - end - - test "elixir_call with single output on EXLA CPU" do - x = Nx.iota({5}) - y = add_offset(x) - - expected = Nx.add(Nx.as_type(x, :f32), 10.0) - assert_equal(y, expected) - end - - defn split_and_sum(x) do - fx = Nx.as_type(x, :f32) - - out0 = fx - out1 = fx - out_template = {out0, out1} - - {a, b} = - Nx.elixir_call(out_template, [fx], fn t -> - {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} - end) - - Nx.add(a, b) - end - - test "elixir_call with tuple output on EXLA CPU" do - x = Nx.tensor([1, 2, 3]) - y = split_and_sum(x) - - fx = Nx.as_type(x, :f32) - expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) - assert_equal(y, expected) - end - - defn bad_callback(x) do - out = %{x | type: Nx.Type.to_floating(x.type)} - - Nx.elixir_call(out, [x], fn _t -> - # Wrong shape on purpose - Nx.tensor([1.0, 2.0, 3.0]) - end) - end - - test "elixir_call errors when result shape does not match template" do - x = Nx.iota({2}) - - assert_raise RuntimeError, ~r/expected the elixir_call function to match the given output template/, fn -> - bad_callback(x) - end - end -end diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index f3e52aaea3..f7b3cbe7b8 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ 
-1,4 +1,4 @@ -defmodule EXLA.Defn.ElixirCallEvaluatorTest do +defmodule EXLA.Defn.ElixirCallTest do use ExUnit.Case, async: true import Nx.Defn import Nx.Testing @@ -48,6 +48,25 @@ defmodule EXLA.Defn.ElixirCallEvaluatorTest do assert_equal(y, expected) end + defn bad_callback(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x], fn _t -> + # Wrong shape on purpose + Nx.tensor([1.0, 2.0, 3.0]) + end) + end + + test "elixir_call errors when result shape does not match template" do + x = Nx.iota({2}) + + assert_raise RuntimeError, + ~r/expected the elixir_call function to match the given output template/, + fn -> + bad_callback(x) + end + end + test "works when using EXLA compiler directly" do x = Nx.tensor([1, 2, 3]) y = EXLA.jit_apply(&split_and_sum/1, [x]) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 715a149286..1b934a44b8 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2207,6 +2207,14 @@ defmodule Nx do Inside `defn`, this builds an expression node understood by compilers. Outside `defn` or on backends without special support, it executes `fun` directly and validates the result matches the template. + + ## Argument ordering + + When called inside `defn`, all tensor arguments must be placed **before** + any list arguments. Lists (including keyword lists) are treated as static + Elixir data that is appended to the callback at runtime, while the leading + non-list arguments are compiled as tensors and shipped to the target + backend. Passing a tensor after a list argument raises an error. 
""" @doc type: :backend def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 9794749db6..f455857f87 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -388,6 +388,8 @@ defmodule Nx.Defn.Expr do @impl true def elixir_call(out, in_args, fun) do + ensure_tensor_args_prefix!(in_args) + {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) [%T{data: %Expr{context: context}} | _] = Enum.map(tensor_args, &to_expr/1) @@ -404,6 +406,15 @@ defmodule Nx.Defn.Expr do end end + defp ensure_tensor_args_prefix!(args) do + {_tensor_prefix, static_suffix} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) + + if Enum.any?(static_suffix, &is_struct(&1, Nx.Tensor)) do + raise ArgumentError, + "Nx.elixir_call/3 expects all tensor arguments to appear before any static arguments, but got: #{inspect(args)}" + end + end + ## Nx.Defn AST callbacks @doc false diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs index 92fad6b431..a73b6e3b6f 100644 --- a/nx/test/nx/defn/elixir_call_evaluator_test.exs +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -46,4 +46,19 @@ defmodule Nx.Defn.ElixirCallEvaluatorTest do expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) assert expected == y end + + defn invalid_order(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [[offset: 10.0], x], fn opts, t -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call enforces tensor arguments before lists" do + message = ~r|Nx.elixir_call/3 expects all tensor arguments to appear before any static arguments, but got| + assert_raise ArgumentError, message, fn -> + invalid_order(Nx.iota({2})) + end + end end diff --git a/nx_elixir_call_exla_design.md b/nx_elixir_call_exla_design.md index bacf3bb480..9b5822b7da 100644 --- a/nx_elixir_call_exla_design.md +++ 
b/nx_elixir_call_exla_design.md @@ -353,10 +353,13 @@ This section records where the **current implementation** matches or intentional - Shape: `Nx.elixir_call(output_template, args, fun)`. - `output_template` is the **first argument**, not an option. - `args` is a list of runtime arguments (tensors + static values). + - Inside `defn`, we require all non-list (tensor) arguments to appear + before any list argument; the first list marks the start of the static + tail that is replayed verbatim on the BEAM side. - `fun` is a plain Elixir function; we don’t support MFA in Phase 1. - `Nx.Defn.Expr.elixir_call/3`: - For tensor output: stores `:elixir_call` node with args `[in_args, fun, out_template]`, where `out_template = Nx.to_template(output)`. - - For tuple output: builds an internal tuple-shaped template (`tuple_out/1`) plus a **user template** (also passed as `out_template`) so EXLA sees the per-element shapes/dtypes. + - For tuple output: builds an internal tuple-shaped template (`tuple_out/1`) plus a `user_template = Nx.to_template(tuple)` that is passed as the third argument. - Rationale: keeping `output_template` as a *value argument* made the IR and EXLA lowering simpler and closer to existing `defn` conventions. 
#### 8.2 IR Representation (`Nx.Defn.Expr` / `Nx.Defn.Graph`) From 95f7860638104ab8120f76c9f1129d9b6e394fd0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:48:01 -0300 Subject: [PATCH 09/42] finish changes code review --- .../exla/custom_calls/elixir_callback.cc | 13 +- exla/c_src/exla/elixir_callback_bridge.h | 12 +- exla/c_src/exla/exla.cc | 128 ++++------ exla/c_src/exla/exla_nif_util.h | 223 +++++++++++++++--- exla/lib/exla/callback_server.ex | 16 +- exla/lib/exla/nif.ex | 1 + 6 files changed, 262 insertions(+), 131 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index 4da48dd122..bbafe8c828 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -35,8 +35,8 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, int64_t callback_id = id_buf.reinterpret_data()[0]; // Collect all remaining input tensors (excluding callback id) into - // lightweight payloads. - std::vector inputs; + // lightweight payload views. 
+ std::vector inputs; inputs.reserve(args.size() - 1); for (size_t i = 1; i < args.size(); ++i) { @@ -47,17 +47,14 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, ffi::AnyBuffer buf = *maybe_buf_or; - exla::ElixirCallbackTensor tensor; + exla::ElixirCallbackArg tensor; tensor.dtype = buf.element_type(); auto dims = buf.dimensions(); tensor.dims.assign(dims.begin(), dims.end()); - size_t size_bytes = buf.size_bytes(); - tensor.data.resize(size_bytes); - if (size_bytes > 0) { - std::memcpy(tensor.data.data(), buf.untyped_data(), size_bytes); - } + tensor.data = reinterpret_cast(buf.untyped_data()); + tensor.size_bytes = buf.size_bytes(); inputs.push_back(std::move(tensor)); } diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index e6d9414f54..2e2e34ca46 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -18,6 +18,13 @@ struct ElixirCallbackTensor { std::vector data; }; +struct ElixirCallbackArg { + xla::ffi::DataType dtype; + std::vector dims; + const uint8_t *data = nullptr; + size_t size_bytes = 0; +}; + struct ElixirCallbackResult { bool ok = false; std::string error; @@ -38,8 +45,9 @@ void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, // // It returns an ElixirCallbackResult that either contains a list of output // tensors (on success) or an error message. -ElixirCallbackResult CallElixirCallback(int64_t callback_id, - const std::vector &inputs); +ElixirCallbackResult +CallElixirCallback(int64_t callback_id, + const std::vector &inputs); } // namespace exla diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index a9da3e9205..8b668ce56e 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -550,81 +550,17 @@ ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { } // Map ffi::DataType to a Nx-style {atom, bits} pair used on the Elixir side. 
-std::pair EncodeNxType(ErlNifEnv *env, - xla::ffi::DataType dtype) { - const char *atom = nullptr; - int bits = 0; - - switch (dtype) { - case xla::ffi::PRED: - atom = "u"; - bits = 8; - break; - case xla::ffi::S8: - atom = "s"; - bits = 8; - break; - case xla::ffi::S16: - atom = "s"; - bits = 16; - break; - case xla::ffi::S32: - atom = "s"; - bits = 32; - break; - case xla::ffi::S64: - atom = "s"; - bits = 64; - break; - case xla::ffi::U8: - atom = "u"; - bits = 8; - break; - case xla::ffi::U16: - atom = "u"; - bits = 16; - break; - case xla::ffi::U32: - atom = "u"; - bits = 32; - break; - case xla::ffi::U64: - atom = "u"; - bits = 64; - break; - case xla::ffi::F16: - atom = "f"; - bits = 16; - break; - case xla::ffi::F32: - atom = "f"; - bits = 32; - break; - case xla::ffi::F64: - atom = "f"; - bits = 64; - break; - case xla::ffi::BF16: - atom = "bf"; - bits = 16; - break; - case xla::ffi::C64: - atom = "c"; - bits = 64; - break; - case xla::ffi::C128: - atom = "c"; - bits = 128; - break; - default: - atom = "f"; - bits = 32; - break; +std::optional> +EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { + if (auto primitive = exla::PrimitiveTypeFromFfiDataType(dtype)) { + if (auto info = exla::PrimitiveTypeToNxTypeInfo(*primitive)) { + ERL_NIF_TERM atom_term = enif_make_atom(env, info->atom_name); + ERL_NIF_TERM bits_term = enif_make_int(env, info->bits); + return std::make_pair(atom_term, bits_term); + } } - ERL_NIF_TERM atom_term = enif_make_atom(env, atom); - ERL_NIF_TERM bits_term = enif_make_int(env, bits); - return {atom_term, bits_term}; + return std::nullopt; } } // namespace @@ -648,6 +584,22 @@ fine::Ok<> elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { + (void)env; + auto state = GetElixirCallbackBridgeState(); + + if (state->dispatcher_set && + std::memcmp(&state->dispatcher_pid, 
&dispatcher_pid, sizeof(ErlNifPid)) == + 0) { + state->dispatcher_set = false; + } + + return fine::Ok(); +} + +FINE_NIF(clear_elixir_callback_bridge, 0); + void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, fine::Term payload) { auto state = GetElixirCallbackBridgeState(); @@ -740,7 +692,7 @@ void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, ElixirCallbackResult CallElixirCallback(int64_t callback_id, - const std::vector &inputs) { + const std::vector &inputs) { auto state = GetElixirCallbackBridgeState(); if (!state->dispatcher_set) { @@ -761,21 +713,37 @@ CallElixirCallback(int64_t callback_id, ErlNifEnv *msg_env = enif_alloc_env(); - // Encode arguments as [{bin, {type, bits}, shape_tuple}, ...] + // Encode arguments as [{bin, {type, bits}, shape_tuple}, ...]. We currently + // send plain binaries because the BEAM callback needs to own the data + // lifetime. std::vector args_terms; args_terms.reserve(inputs.size()); for (const auto &tensor : inputs) { ERL_NIF_TERM bin_term; unsigned char *bin_data = - enif_make_new_binary(msg_env, tensor.data.size(), &bin_term); - if (tensor.data.size() > 0) { - memcpy(bin_data, tensor.data.data(), tensor.data.size()); + enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term); + if (tensor.size_bytes > 0) { + memcpy(bin_data, tensor.data, tensor.size_bytes); } - auto [type_atom, bits_term] = EncodeNxType(msg_env, tensor.dtype); + auto type_tuple_or = EncodeNxType(msg_env, tensor.dtype); + if (!type_tuple_or.has_value()) { + enif_free_env(msg_env); + { + std::lock_guard lock(state->mu); + state->pending.erase(tag); + } + + ElixirCallbackResult res; + res.ok = false; + res.error = "unsupported tensor type in EXLA callback argument"; + return res; + } - ERL_NIF_TERM type_tuple = enif_make_tuple2(msg_env, type_atom, bits_term); + auto type_info = type_tuple_or.value(); + ERL_NIF_TERM type_tuple = + enif_make_tuple2(msg_env, type_info.first, type_info.second); std::vector dim_terms; 
dim_terms.reserve(tensor.dims.size()); diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 714f74f2da..65caec2f0f 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -4,10 +4,11 @@ #include #include -#include "xla/shape.h" -#include "xla/shape_util.h" #include "mlir/IR/Types.h" #include "stablehlo/dialect/StablehloOps.h" +#include "xla/ffi/api/ffi.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace exla { @@ -28,6 +29,187 @@ static auto type = fine::Atom("type"); static auto u = fine::Atom("u"); static auto warning = fine::Atom("warning"); } // namespace atoms + +struct NxTypeInfo { + const char *atom_name; + const fine::Atom *atom_ref; + uint64_t bits; + + fine::Atom atom() const { + if (atom_ref) { + return *atom_ref; + } + return fine::Atom(atom_name); + } +}; + +inline std::optional +PrimitiveTypeToNxTypeInfo(xla::PrimitiveType type) { + switch (type) { + case xla::PRED: + return NxTypeInfo{"pred", &atoms::pred, 8}; + case xla::S2: + return NxTypeInfo{"s", &atoms::s, 2}; + case xla::S4: + return NxTypeInfo{"s", &atoms::s, 4}; + case xla::S8: + return NxTypeInfo{"s", &atoms::s, 8}; + case xla::S16: + return NxTypeInfo{"s", &atoms::s, 16}; + case xla::S32: + return NxTypeInfo{"s", &atoms::s, 32}; + case xla::S64: + return NxTypeInfo{"s", &atoms::s, 64}; + case xla::U2: + return NxTypeInfo{"u", &atoms::u, 2}; + case xla::U4: + return NxTypeInfo{"u", &atoms::u, 4}; + case xla::U8: + return NxTypeInfo{"u", &atoms::u, 8}; + case xla::U16: + return NxTypeInfo{"u", &atoms::u, 16}; + case xla::U32: + return NxTypeInfo{"u", &atoms::u, 32}; + case xla::U64: + return NxTypeInfo{"u", &atoms::u, 64}; + case xla::F8E4M3FN: + case xla::F8E5M2: + return NxTypeInfo{"f", &atoms::f, 8}; + case xla::F16: + return NxTypeInfo{"f", &atoms::f, 16}; + case xla::BF16: + return NxTypeInfo{"bf", &atoms::bf, 16}; + case xla::F32: + return NxTypeInfo{"f", &atoms::f, 32}; + case xla::F64: + return NxTypeInfo{"f", 
&atoms::f, 64}; + case xla::C64: + return NxTypeInfo{"c", &atoms::c, 64}; + case xla::C128: + return NxTypeInfo{"c", &atoms::c, 128}; + default: + return std::nullopt; + } +} + +inline std::optional +PrimitiveTypeFromFfiDataType(xla::ffi::DataType dtype) { + switch (dtype) { + case xla::ffi::PRED: + return xla::PRED; + case xla::ffi::S2: + return xla::S2; + case xla::ffi::S4: + return xla::S4; + case xla::ffi::S8: + return xla::S8; + case xla::ffi::S16: + return xla::S16; + case xla::ffi::S32: + return xla::S32; + case xla::ffi::S64: + return xla::S64; + case xla::ffi::U2: + return xla::U2; + case xla::ffi::U4: + return xla::U4; + case xla::ffi::U8: + return xla::U8; + case xla::ffi::U16: + return xla::U16; + case xla::ffi::U32: + return xla::U32; + case xla::ffi::U64: + return xla::U64; + case xla::ffi::F8E4M3FN: + return xla::F8E4M3FN; + case xla::ffi::F8E5M2: + return xla::F8E5M2; + case xla::ffi::F16: + return xla::F16; + case xla::ffi::BF16: + return xla::BF16; + case xla::ffi::F32: + return xla::F32; + case xla::ffi::F64: + return xla::F64; + case xla::ffi::C64: + return xla::C64; + case xla::ffi::C128: + return xla::C128; + default: + return std::nullopt; + } +} + +inline std::optional +PrimitiveTypeFromMlirElement(const mlir::Type &element_type) { + if (element_type.isSignlessInteger(1)) { + return xla::PRED; + } + + if (auto integer_type = mlir::dyn_cast(element_type)) { + int width = integer_type.getWidth(); + if (integer_type.isUnsigned()) { + switch (width) { + case 2: + return xla::U2; + case 4: + return xla::U4; + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + return xla::U32; + case 64: + return xla::U64; + } + } else { + switch (width) { + case 2: + return xla::S2; + case 4: + return xla::S4; + case 8: + return xla::S8; + case 16: + return xla::S16; + case 32: + return xla::S32; + case 64: + return xla::S64; + } + } + } else if (element_type.isBF16()) { + return xla::BF16; + } else if (auto float_type = 
mlir::dyn_cast(element_type)) { + int width = float_type.getWidth(); + switch (width) { + case 8: + return xla::F8E4M3FN; + case 16: + return xla::F16; + case 32: + return xla::F32; + case 64: + return xla::F64; + } + } else if (auto complex_type = + mlir::dyn_cast(element_type)) { + auto inner = complex_type.getElementType(); + if (auto float_type = mlir::dyn_cast(inner)) { + switch (float_type.getWidth()) { + case 32: + return xla::C64; + case 64: + return xla::C128; + } + } + } + + return std::nullopt; +} } // namespace exla namespace fine { @@ -176,46 +358,17 @@ template <> struct Encoder { return fine::encode(env, exla::atoms::token); } - std::optional type_name; - std::optional type_size; - if (mlir::isa(type)) { auto tensor_type = mlir::cast(type); auto element_type = tensor_type.getElementType(); - - if (element_type.isSignlessInteger(1)) { - type_name = exla::atoms::pred; - type_size = 8; - } else if (auto integer_type = - mlir::dyn_cast(element_type)) { - if (integer_type.isUnsigned()) { - type_name = exla::atoms::u; - } else { - type_name = exla::atoms::s; + if (auto primitive = exla::PrimitiveTypeFromMlirElement(element_type)) { + if (auto info = exla::PrimitiveTypeToNxTypeInfo(*primitive)) { + return fine::encode(env, std::make_tuple(info->atom(), info->bits)); } - - type_size = integer_type.getWidth(); - } else if (element_type.isBF16()) { - type_name = exla::atoms::bf; - type_size = 16; - } else if (auto float_type = - mlir::dyn_cast(element_type)) { - type_name = exla::atoms::f; - type_size = float_type.getWidth(); - } else if (auto complex_type = - mlir::dyn_cast(element_type)) { - auto element_type = complex_type.getElementType(); - type_name = exla::atoms::c; - type_size = mlir::cast(element_type).getWidth() * 2; } } - if (type_name) { - return fine::encode( - env, std::make_tuple(type_name.value(), type_size.value())); - } else { - throw std::invalid_argument("encode failed, unexpected mlir type"); - } + throw std::invalid_argument("encode 
failed, unexpected mlir type"); } static ERL_NIF_TERM encode_shape(ErlNifEnv *env, const mlir::Type &type) { diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 7d69fa7ee7..a5177fabab 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -74,16 +74,20 @@ defmodule EXLA.CallbackServer do @impl true def init(:ok) do # Inform native side that this process is the dispatcher for elixir callbacks. - # - # If the NIF has not implemented `start_elixir_callback_bridge/1` yet, we - # fail silently so that the rest of the system continues to work. This - # allows developing the Elixir side and the native side independently. - _ = - EXLA.NIF.start_elixir_callback_bridge(self()) + _ = EXLA.NIF.start_elixir_callback_bridge(self()) {:ok, %__MODULE__{}} end + @impl true + def terminate(_reason, _state) do + try do + EXLA.NIF.clear_elixir_callback_bridge(self()) + rescue + _ -> :ok + end + end + @impl true def handle_call({:register, fun, out_template, static_args}, _from, %__MODULE__{} = state) do key = {fun, Nx.to_template(out_template), static_args} diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 5cded357bd..2a7bb2df13 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -80,6 +80,7 @@ defmodule EXLA.NIF do # Elixir callback bridge (Phase 1: CPU-only, simple APIs) def start_elixir_callback_bridge(_dispatcher_pid), do: err!() + def clear_elixir_callback_bridge(_dispatcher_pid), do: err!() def elixir_callback_reply(_reply_tag, _payload), do: err!() defp err!(), do: :erlang.nif_error(:undef) From c9aa9bdee1b893c4e2d632739e6b256becdabd5b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:51:32 -0300 Subject: [PATCH 10/42] chore: remove unused files --- code_review.md | 27 -- nx_elixir_call_exla_design.md | 478 ---------------------------------- 2 files changed, 505 deletions(-) delete mode 100644 code_review.md delete mode 
100644 nx_elixir_call_exla_design.md diff --git a/code_review.md b/code_review.md deleted file mode 100644 index 276e22c184..0000000000 --- a/code_review.md +++ /dev/null @@ -1,27 +0,0 @@ -review notes: - -- Nx.Defn Expr added a third out_template argument, but the template can be inferred from the expression itself. -- exla: elixir_call_test and elixir_call_exla_test are redundant with each other. we can combine them in a single file. - -- callback server and the rest of the code seem to assume that tensor arguments are always at the beginning of the function. We should enforce this more clearly and document this. -- given that the callback server is named, enif_whereis_pid(https://www.erlang.org/doc/apps/erts/erl_nif.html#enif_whereis_pid) could be used to fetch the current pid for the function. -- EXLA.CallbackServer decode_args uses from_binary without options. We should keep track of the backend options such as the device used to allocate an EXLA tensor. Ideally, we shouldn't even be copying data back and forth. shape should already be passed as a tuple from the NIF. -- EXLA.CallbackServer does encode_reply/encode_outputs really need to "Nx.to_binary" the results? It seems like we should be able to pass EXLA Buffer refs back and forth. - -- EXLA.Defn operands = call_args ++ [callback_id_value] we should prepend the id instead of append id -- What is api_version in EXLA.MLIR.Value.elixir_call? -- Should the callback_id instead be an attribute given that it should not change during execution? - ----- - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - exla_elixir_callback, exla_elixir_callback_impl, - ffi::Ffi::Bind() - .RemainingArgs() - .RemainingRets()); - -This could receive the id in the first argument, and the tensors in the second. - -exla.cc:550 I think there is already another function for mapping exla to nx types. 
- -exla.cc:648 FINE_NIF(elixir_callback_reply, 0) I think is missing an IO-bound attr \ No newline at end of file diff --git a/nx_elixir_call_exla_design.md b/nx_elixir_call_exla_design.md deleted file mode 100644 index 9b5822b7da..0000000000 --- a/nx_elixir_call_exla_design.md +++ /dev/null @@ -1,478 +0,0 @@ -## Design: `Nx.elixir_call/3` and EXLA Integration - -### 1. Overview - -This document describes a two-phase plan to implement safe, efficient support for calling arbitrary Elixir code from `defn` via `Nx.elixir_call/3`, with a focus on EXLA. - -- **Phase 1**: CPU-only implementation using EXLA + XLA host `CustomCall`, with a safe bridge to Elixir (no `nif_call`-style reentry into BEAM). -- **Phase 2**: Graph segmentation in `Nx.Defn.Graph` so the compiler can: - - Treat `elixir_call` as a boundary and split the computation into stages. - - Enable cross-device execution (CPU/GPU) while preserving a single user API. - - Eventually optimize some callbacks to be compiled away or lowered differently (e.g. pure functions expressible in Nx). - -This work extends [PR #1627, “feat: Nx.elixir_call/3”](https://github.com/elixir-nx/nx/pull/1627), which currently implements `Nx.elixir_call/3` only in `Nx.Defn.Evaluator`. - ---- - -### 2. Goals and Non-goals - -- **Goals** - - **G1**: Provide a **public API** (`Nx.elixir_call/3`) that allows calling user-provided Elixir code inside `defn`. - - **G2**: Implement a **safe** EXLA backend for `elixir_call` on **CPU** using XLA host `CustomCall`. - - **G3**: Ensure callbacks have **statically known shapes/dtypes** to keep compilation and gradients well-defined. - - **G4**: Provide a **unified intermediate representation** in `Nx.Defn.Graph` so future backends (EXLA GPU, other compilers) can share the same abstraction. - - **G5**: In Phase 2, support **graph segmentation** around `elixir_call` so that: - - We can mix device computation (CPU/GPU) with Elixir callbacks. 
- - The compiler can decide to either split or compile callbacks, depending on their structure. - -- **Non-goals (for now)** - - **NG1**: No direct, device-side callbacks for GPU/TPU in Phase 1 (no infeed/outfeed complexity yet). - - **NG2**: No guarantees about **side-effect isolation** of callbacks (user is responsible), beyond not violating BEAM safety. - - **NG3**: No attempt to automatically infer output shapes/dtypes of callbacks at runtime; shapes must be known at `defn`/compile time. - ---- - -### 3. Terminology - -- **`elixir_call` node**: The internal IR node representing a call to arbitrary Elixir code (backed by `Nx.elixir_call/3`). -- **Callback ID**: A stable identifier (string or integer) used to look up the Elixir function and output spec at compile/run time. -- **Output spec**: Shape and type description for all outputs of a callback. -- **Bridge thread**: A native (C/C++) thread that acts as a mediator between XLA/EXLA and BEAM, using message-passing only (no direct BEAM calls from arbitrary XLA threads). - ---- - -### 4. Phase 1: CPU-only EXLA Backend (Host `CustomCall`) - -#### 4.1 Public API: `Nx.elixir_call/3` - -- **Goal**: Reuse and finalize the API introduced in [nx#1627](https://github.com/elixir-nx/nx/pull/1627). - -- **Shape** (subject to minor refinement): - - - `Nx.elixir_call(args, fun_or_mfa, opts \\ [])` - -- **Key options / metadata**: - - **`id` or `name`**: A stable callback identifier (string or integer). - - **`output_template`** (or equivalent): A value (or list/tuple of values) that describes the **shapes and dtypes** of the callback’s outputs: - - Can be Nx tensors or a structured spec, but must be statically known at `defn` compile time. - - Potentially an **`impure`** flag (or similar) in the future to guide compiler optimizations. - -- **Constraints**: - - `fun_or_mfa` is not executed at `defn` compile time (except possibly in the evaluator backend). 
- - Output shape/type comes from `output_template`, not from running the function. - -#### 4.2 Nx IR: Representing `elixir_call` - -- **Extend** `Nx.Defn.Expr` / `Nx.Defn.Graph` to carry `elixir_call` nodes explicitly. - -- Proposed internal form: - - - `{:elixir_call, meta, args}` - - Where: - - **`meta`** includes: - - `callback_id` (string/int). - - `fun_or_mfa` or internal reference (for evaluator backend and dispatcher). - - `output_spec` (shapes + dtypes). - - Any flags required for compilation/grad. - - **`args`**: list of argument expressions. - -- **Requirements**: - - Shape inference for `elixir_call` uses `output_spec`. - - Optimizer must **not** fuse or eliminate `elixir_call`; it is a logical boundary and may be effectful. - - The evaluator backend (as in nx#1627) already knows how to interpret it. - -#### 4.3 EXLA Lowering: From `elixir_call` to HLO/StableHLO - -- In the EXLA backend (Elixir side): - - - When encountering an `elixir_call` node while building HLO/StableHLO: - - - Lower `args` to HLO values. - - Construct a `CustomCall` operation with: - - **Operands**: those input HLO values. - - **Result types**: from `output_spec`. - - **Call target name**: e.g. `"exla_elixir_callback"`. - - **Attributes**: - - `callback_id` (string/int). - - Optionally an encoded `output_spec` (if needed on the native side). - -- **CPU-only restriction** (Phase 1): - - If the active EXLA client is **CPU**, allow this lowering. - - If the client is GPU (or other non-CPU), raise a **clear error**: - - e.g. “`Nx.elixir_call/3` is currently only supported for EXLA CPU; please run on CPU or wait for Phase 2 segmentation support.” - -#### 4.4 Native EXLA: Callback Registry and Bridge - -- **Callback registry (Elixir → native)**: - - At the time of building an EXLA executable, collect all callbacks: - - - Map: `callback_id → {fun_or_mfa, output_spec}`. - - - Pass this mapping down to the native side, associated with the executable or run context. 
- -- **Native data structures** (C/C++ side): - - - `struct CallbackRequest { RunRef run_ref; CallbackId callback_id; std::vector args; ReplyTag reply_tag; std::promise promise; };` - - - `struct CallbackResult { ReplyTag reply_tag; std::vector outputs; Error error; };` - - - A **thread-safe queue** for `CallbackRequest`s. - - - A **map** `reply_tag → std::promise` guarded by a mutex. - -- **Bridge thread**: - - - Started when the EXLA NIF is initialized (or when the first callback-capable executable is created). - - - Main loop: - 1. Pop `CallbackRequest` from the queue. - 2. Serialize `args` into a compact binary representation (shape metadata + flat data). - 3. Use `enif_send` to send a message to a **dedicated Elixir dispatcher process**: - - Message format (conceptual): - `{:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}`. - 4. Wait on the `std::promise`/`std::future` associated with `reply_tag` until `CallbackResult` is set: - - **Important**: This wait uses only native primitives (no BEAM APIs, no `nif_call`), so it is safe w.r.t. BEAM scheduling. - 5. On success/failure, the handler (see next section) is unblocked. - -#### 4.5 XLA Host `CustomCall` Handler (CPU Client) - -- **Registration**: - - - For the EXLA CPU client, register a host call target with XLA: - - - Name: `"exla_elixir_callback"`. - -- **Handler logic**: - - 1. Extract: - - `callback_id` from `CustomCall` attributes. - - Operand buffers (inputs). - - Output buffers and their shapes/dtypes. - 2. Convert operand buffers into host tensors and build a `CallbackRequest`: - - Assign a fresh `reply_tag`. - - Create a `std::promise` and `std::future`. - - Insert `reply_tag → promise` into the map. - 3. Enqueue the `CallbackRequest` onto the native request queue. - 4. Block on the `future` until a result arrives (native wait). - 5. Once the `CallbackResult` is available: - - On success: - - Write returned tensor data into XLA’s output buffers. - - Return `OK` to XLA. 
- - On error or timeout: - - Return an error `Status` so the XLA run fails with a descriptive error. - -#### 4.6 Elixir Dispatcher Process - -- Implement a **GenServer** in Nx/EXLA that acts as the BEAM-side dispatcher for callbacks. - -- Responsibilities: - - - Maintain: - - `callbacks: %{ {run_ref, callback_id} => {fun_or_mfa, output_spec} }`. - - - Handle messages from the bridge thread: - - ```elixir - def handle_info({:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}, state) do - {args, arg_specs} = deserialize_tensors(args_bin) - {fun_or_mfa, output_spec} = - Map.fetch!(state.callbacks, {run_ref, callback_id}) - - # Execute user code (possibly in a Task for isolation) - result = - try do - call_user_fun(fun_or_mfa, args) - rescue - exception -> {:error, {:exception, exception, __STACKTRACE__}} - catch - kind, reason -> {:error, {kind, reason}} - end - - reply_payload = - encode_result(result, output_spec) # either {:ok, tensors_bin} or {:error, reason} - - # One NIF call to signal back to native side (bridge thread sees this) - send_reply_to_nif(reply_tag, reply_payload) - - {:noreply, state} - end - ``` - - - Ensure: - - Result shapes/dtypes match `output_spec`; otherwise return a structured error. - - Optional: enforce configurable timeouts per callback and abort the run on timeout. - -- **API considerations**: - - - A worker or supervisor module (e.g. `EXLA.CallbackServer`) could manage: - - Registration of callbacks per `run_ref`. - - Cleanup after run completion. - -when registering the callback, there should be a fun/capture -> integer mapping (maybe use :counters for generating these integers) and the function should be registered with this id. The id should be returned so that the compiler can use it. This turns the callback server into the source of truth and the generator of ids - -#### 4.7 Error Handling and Validation - -- **Compile-time checks**: - - - Verify that: - - `output_template` can be converted into a valid `output_spec`. 
- - Grad rules (where applicable) can be defined; if not, error clearly or fallback. - -- **Runtime checks**: - - - After Elixir callback returns: - - Validate result shape/dtype vs `output_spec`. - - On mismatch, generate a descriptive error and fail the XLA run. - -- **Timeouts**: - - - Optional but recommended: - - Per-callback timeout at the dispatcher level. - - If timeout expires, reply with error; native side then aborts the run. - -- **Safety**: - - - No calls from arbitrary XLA threads into BEAM functions. - - All BEAM interaction uses `enif_send` from the bridge thread or explicit NIF calls from Elixir processes. - ---- - -### 5. Phase 2: Graph Segmentation and Cross-Device Support - -After Phase 1 is solid on CPU, we extend support to all EXLA devices (CPU/GPU) via **segmentation** in `Nx.Defn.Graph`. This aligns with [the discussion on nx#1627](https://github.com/elixir-nx/nx/pull/1627), where `elixir_call` and other “optional callback” mechanisms share a unified specification, and the compiler decides whether to split or to compile. - -#### 5.1 Treat `elixir_call` as a Stage Boundary - -- In `Nx.Defn.Graph`, treat each `elixir_call` as a **potential cut point**: - - - Find maximal subgraphs that: - - Contain no `elixir_call`. - - Are otherwise pure Nx computations. - -- Build a sequence: - - - `stage_0` → `elixir_call_0` → `stage_1` → `elixir_call_1` → … → `stage_n`. - -- Each `stage_i` will be compiled separately for a target device (CPU or GPU). - -#### 5.2 Stage Compilation - -- For each pure stage: - - - Infer shapes and types as usual. - - Choose a device (matching the EXLA client or using more advanced heuristics later). - - Compile to an EXLA executable. - -- For each `elixir_call` between stages: - - - Reuse the **Phase 1 dispatcher + bridge**: - - Inputs: outputs of the previous stage (converted to host tensors). - - Outputs: inputs to the next stage (converted back and transferred as needed). 
- -#### 5.3 Orchestration Runtime - -- Implement an orchestrator (in Nx/EXLA) that performs: - - 1. Run `stage_0` on its device → get outputs. - 2. Transfer these outputs to host (if needed). - 3. Invoke the Elixir callback via the dispatcher → get callback outputs. - 4. Transfer callback outputs to the device for `stage_1` (if needed). - 5. Repeat until all stages and callbacks are executed. - -- This orchestration: - - - Provides **consistent semantics** across CPU and GPU. - - Keeps the user API (`Nx.elixir_call/3`) unchanged. - - Allows future optimizations where: - - Some callbacks are compiled away (if expressible in pure Nx). - - Some backends choose device-specific mechanisms (e.g. XLA GPU host-callbacks or infeed/outfeed) internally. - -#### 5.4 Compiler Decisions (Future Work) - -- Over time, the compiler can classify callbacks: - - - **Pure, shape-stable callbacks definable in Nx**: - - Potentially inline/compile them, removing the runtime callback. - - - **Genuinely dynamic callbacks**: - - Keep them as segmentation boundaries. - -- This addresses the concern raised in [nx#1627](https://github.com/elixir-nx/nx/pull/1627) about having a **unified specification** for callbacks while allowing the compiler to choose between splitting and compiling. - ---- - -### 6. Open Questions / Next Steps - -- **Naming and API**: - - Finalize `Nx.elixir_call/3` naming and argument order. - - Decide whether to expose more advanced options (timeouts, impurity markers, etc.) in `opts`. - -- **Gradients**: - - For Phase 1, gradients may be: - - Not supported for arbitrary callbacks (raise on use), or - - Supported only when the callback is expressible in Nx and compiled away (future optimization). - -- **Concurrency model**: - - Decide how many bridge threads to run. - - Understand the interaction with multiple concurrent EXLA runs and multiple callback-heavy computations. 
- -- **Device-specific optimizations** (beyond segmentation): - - Investigate XLA’s GPU host-callback support and whether to implement a more tightly integrated path for GPU (possibly involving infeed/outfeed under the hood) once segmentation version is stable. - ---- - -### 7. Implementation Order Checklist - -1. **Land / refine `Nx.elixir_call/3` API and IR node** (based on [nx#1627](https://github.com/elixir-nx/nx/pull/1627)). -2. **Add EXLA lowering** for CPU: - - Map `elixir_call` → HLO/StableHLO `CustomCall` with target `"exla_elixir_callback"`. -3. **Implement native callback registry + bridge thread** in EXLA NIF. -4. **Register CPU host `CustomCall` handler** (`"exla_elixir_callback"`) and wire it to the bridge. -5. **Implement Elixir dispatcher process** for callbacks + error handling + sanity checks. -6. **Add tests** for CPU: - - Simple callbacks. - - Multiple callbacks in a single `defn`. - - Error cases (shape mismatch, thrown exceptions). -7. **Introduce segmentation in `Nx.Defn.Graph`**: - - Identify stages between `elixir_call` nodes. - - Compile/orchestrate stages for CPU/GPU. -8. **Extend EXLA to allow callbacks under segmentation** when using GPU clients. -9. Iterate on compiler-side heuristics to decide when callbacks can be compiled away vs split. - -### 8. Phase 1 – Intended vs Implemented (Status Summary) - -This section records where the **current implementation** matches or intentionally diverges from the original Phase 1 plan, so future work can see what is done vs still open. - -#### 8.1 Public API (`Nx.elixir_call/3`) - -- **Intended (this doc, §4.1)**: - - Shape: `Nx.elixir_call(args, fun_or_mfa, opts \\ [])`. - - Explicit `output_template` / `output_spec` passed via `opts`. -- **Implemented (Nx 0.10 + EXLA 0.10)**: - - Shape: `Nx.elixir_call(output_template, args, fun)`. - - `output_template` is the **first argument**, not an option. - - `args` is a list of runtime arguments (tensors + static values). 
- - Inside `defn`, we require all non-list (tensor) arguments to appear - before any list argument; the first list marks the start of the static - tail that is replayed verbatim on the BEAM side. - - `fun` is a plain Elixir function; we don’t support MFA in Phase 1. - - `Nx.Defn.Expr.elixir_call/3`: - - For tensor output: stores `:elixir_call` node with args `[in_args, fun, out_template]`, where `out_template = Nx.to_template(output)`. - - For tuple output: builds an internal tuple-shaped template (`tuple_out/1`) plus a `user_template = Nx.to_template(tuple)` that is passed as the third argument. - - Rationale: keeping `output_template` as a *value argument* made the IR and EXLA lowering simpler and closer to existing `defn` conventions. - -#### 8.2 IR Representation (`Nx.Defn.Expr` / `Nx.Defn.Graph`) - -- **Intended (§4.2)**: - - `{:elixir_call, meta, args}` with `meta` carrying: - - `callback_id`, `fun_or_mfa`, `output_spec`, etc. -- **Implemented**: - - `Expr` op is still `:elixir_call`, but: - - We **do not** store `callback_id` in `meta`; it is managed solely by EXLA. - - Arguments are `[in_args, fun, out_template]`. - - Shape/type inference for the node uses `out_template` via the existing template machinery. - - `Nx.Defn.Tree.apply_args/4` and `Nx.Defn.Evaluator.compute_cache/4` / `eval_apply/4` were updated to be aware of the `out_template` third arg but largely **ignore it** at runtime (it is for compilation only). - -#### 8.3 EXLA Lowering to StableHLO `CustomCall` - -- **Intended (§4.3)**: - - `CustomCall("exla_elixir_callback")` with: - - Result types from `output_spec`. - - Attributes: - - `callback_id` (string/int). - - Possibly encoded `output_spec`. -- **Implemented**: - - We lower to a `stablehlo.custom_call` with: - - `call_target_name = "exla_elixir_callback"`. - - `api_version = 4` (typed FFI). - - **No `backend_config` or dictionary attributes** for callback id. 
- - Instead of encoding `callback_id` as an attribute, we: - - Append a scalar S64 operand at the **end of the operand list** carrying `callback_id`. - - Register a typed FFI handler that: - - Interprets the last operand as the callback id. - - Treats the remaining operands as regular tensor arguments. - - Result types are derived from the `out_template`: - - `container_to_typespecs(out_template)` produces one `EXLA.Typespec` per tensor (including tuple elements). - -#### 8.4 Native Bridge & Callback Registry - -- **Intended (§4.4–4.5)**: - - Per-run mapping `callback_id → {fun_or_mfa, output_spec}`. - - Bridge using `RunRef`, `CallbackRequest`, `CallbackResult`, per-run state. -- **Implemented**: - - We use a **global, process-wide `EXLA.CallbackServer`**: - - Maps `callback_id (integer)` → `{fun, out_template, static_args}`. - - Reuses ids when the same `{fun, out_template, static_args}` triple is registered again. - - The C++ bridge maintains: - - A global `ElixirCallbackBridgeState` with: - - `dispatcher_pid`, `next_tag`, and a `pending` map from `reply_tag` to a small `ElixirCallbackPending` object (`std::mutex`, `std::condition_variable`, `ElixirCallbackResult`). - - The **bridge thread** concept is realized as: - - The host `CustomCall` handler runs on an XLA-controlled thread. - - It **blocks natively** on a `std::condition_variable` associated with a `reply_tag` until the BEAM side replies via `EXLA.NIF.elixir_callback_reply/2`. - - There is currently **no per-run `RunRef`**; callbacks are effectively global to the VM. - -#### 8.5 Elixir Dispatcher (`EXLA.CallbackServer`) - -- **Intended (§4.6)**: - - Dedicated dispatcher process keyed by `(run_ref, callback_id)`. - - Messages like `{:exla_elixir_call, run_ref, callback_id, args_bin, reply_tag}`. -- **Implemented**: - - `EXLA.CallbackServer` is a `GenServer` with: - - `callbacks :: %{callback_id => {fun, out_template, static_args}}`. 
- - Native side sends: - - `{:exla_elixir_call, callback_id, args_spec, reply_tag}`. - - `args_spec` is a list of `{bin, {type_atom, bits}, shape_list}` tuples. - - Dispatcher logic: - - Decodes `args_spec` into `Nx.Tensor`s (`Nx.from_binary/3` + `Nx.reshape/2`). - - Appends `static_args` captured at registration time. - - Executes the callback function with `[tensor_args ++ static_args]`. - - Validates the result against `out_template`: - - Tuple size check + per-element shape/dtype/names check. - -#### 8.6 Error Handling & Mapping - -- **Intended (§4.7)**: - - Shape/dtype validation vs `output_spec`. - - Clear errors; possibly mapping to `ArgumentError` or similar. -- **Implemented (Phase 1)**: - - `EXLA.CallbackServer.ensure_compatible/2`: - - Returns `{:ok, value}` on success. - - Returns tagged errors: - - `{:error, {:shape_mismatch, left, right}}`. - - `{:error, {:invalid_result, left, right}}` for non-tensors/tuples. - - `encode_reply/1` maps internal error tuples to **typed error payloads**: - - Shape mismatch / invalid result: - - Encoded as `{:error, {:argument_error, message_binary}}` where the message mirrors `Nx.ensure_call_compatible!/2`, e.g. - - `"expected the elixir_call function to match the given output template ..., got: ..."` - - Decode failures, invalid args spec, unknown callback id, user exceptions, throws, exits: - - Encoded as `{:error, {:runtime_error, message_binary}}` with descriptive text (`"Elixir callback raised: ..."`, etc.). - - Native `DeliverElixirCallbackReply`: - - For `{:ok, binaries}`, fills result buffers. - - For `{:error, {kind_atom, message_binary}}`: - - Uses the **message** as `result.error` string returned to XLA. - - As a result: - - From the user’s point of view, `Nx.elixir_call/3` under EXLA now fails with: - - `RuntimeError` carrying a **descriptive, Nx-style message** (e.g. shape mismatch) instead of the generic `"elixir callback returned error"`. 
- - We do **not** currently raise `ArgumentError` directly from EXLA runs; everything surfaces as `RuntimeError` with a rich message, which tests explicitly assert. - -#### 8.7 Timeouts & Robustness - -- **Intended (§4.7, Timeouts)**: - - Optional per-callback timeout. - - Native side aborts run on timeout. -- **Implemented**: - - **No timeouts yet**: - - The host `CustomCall` waits indefinitely on `condition_variable` for the reply. - - If the Elixir dispatcher never replies, the XLA run will hang. - - This is an explicit **TODO** for future hardening; the design here still stands, but is not yet implemented. - -#### 8.8 Phase 2 – Not Implemented Yet - -All of §5 (segmentation and cross-device support) remains **design only**: - -- `elixir_call` is **not yet used as a segmentation boundary** in `Nx.Defn.Graph`. -- There is no multi-stage orchestration for GPU/TPU plus CPU callbacks. -- Callbacks are only allowed on the **EXLA host (CPU) client**, and we eagerly raise if the client platform is not `:host`. - -This section should be updated again once segmentation and GPU support are implemented. From c7c4871bc3133400e21e8f06012750c05060e86b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 25 Nov 2025 01:03:17 -0300 Subject: [PATCH 11/42] docs: document the lock issue --- exla/test/exla/defn/elixir_call_test.exs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index f7b3cbe7b8..d83d80748a 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -8,12 +8,18 @@ defmodule EXLA.Defn.ElixirCallTest do :ok end + defp add_offset_callback(t, opts) do + t + |> Nx.as_type(:f32) + # TODO: if we run on the same device there will be a problem due to the device locking. 
+ |> Nx.backend_transfer({EXLA.Backend, client: :host, device_id: 1}) + |> Nx.add(opts[:offset]) |> dbg(structs: false) + end + defn add_offset(x) do out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> - Nx.add(Nx.as_type(t, :f32), opts[:offset]) - end) + Nx.elixir_call(out, [x, [offset: 10.0]], &add_offset_callback/2) end test "elixir_call with single output" do From ddb3733197743e1ae0e515fa766b7e126ca3aec3 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:07:23 -0300 Subject: [PATCH 12/42] refactor: ElixirCallbackPending as a resource --- exla/c_src/exla/elixir_callback_bridge.h | 19 +++++- exla/c_src/exla/exla.cc | 78 +++++++++--------------- exla/lib/exla/callback_server.ex | 17 ++++-- exla/lib/exla/nif.ex | 3 + exla/test/exla/defn/elixir_call_test.exs | 4 +- 5 files changed, 63 insertions(+), 58 deletions(-) diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index 2e2e34ca46..7b8964ddbc 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -1,6 +1,8 @@ #pragma once +#include #include +#include #include #include @@ -31,9 +33,20 @@ struct ElixirCallbackResult { std::vector outputs; }; -// Called from the Elixir side to deliver a reply for a given callback tag. -void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, - fine::Term payload); +// Per-callback pending state used to synchronize between the XLA host thread +// and the Elixir-side dispatcher. This is exposed as a Fine resource so we +// can pass it as an opaque handle in messages instead of using integer tags. +struct ElixirCallbackPending { + std::mutex mu; + std::condition_variable cv; + bool done = false; + ElixirCallbackResult result; +}; + +// Called from the Elixir side to deliver a reply for a given pending handle. 
+void DeliverElixirCallbackReply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Term payload); // Synchronously calls the Elixir callback identified by `callback_id` with the // given tensor arguments. This function: diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 8b668ce56e..10e600ecb7 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -32,6 +32,7 @@ FINE_RESOURCE(exla::ExlaBuffer); FINE_RESOURCE(exla::ExlaExecutable); FINE_RESOURCE(exla::MLIRModule); FINE_RESOURCE(exla::MLIRFunction); +FINE_RESOURCE(exla::ElixirCallbackPending); // MLIR Functions @@ -529,22 +530,14 @@ FINE_NIF(get_per_device_memory, 0); namespace { -struct ElixirCallbackPending { - std::mutex mu; - std::condition_variable cv; - bool done = false; - ElixirCallbackResult result; -}; - struct ElixirCallbackBridgeState { ErlNifPid dispatcher_pid; bool dispatcher_set = false; - std::atomic next_tag{1}; - std::mutex mu; - std::unordered_map> pending; }; -ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { +// We keep a single global bridge state, but expose it as a Fine resource so +// Elixir code can tie its lifetime to EXLA.CallbackServer. 
+static ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState(); return state; } @@ -567,7 +560,6 @@ EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { - (void)env; auto state = GetElixirCallbackBridgeState(); state->dispatcher_pid = dispatcher_pid; state->dispatcher_set = true; @@ -576,9 +568,11 @@ fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, FINE_NIF(start_elixir_callback_bridge, 0); -fine::Ok<> elixir_callback_reply(ErlNifEnv *env, int64_t reply_tag, - fine::Term payload) { - DeliverElixirCallbackReply(env, reply_tag, payload); +fine::Ok<> +elixir_callback_reply(ErlNifEnv *env, + fine::ResourcePtr pending, + fine::Term payload) { + DeliverElixirCallbackReply(env, pending, payload); return fine::Ok(); } @@ -586,7 +580,6 @@ FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { - (void)env; auto state = GetElixirCallbackBridgeState(); if (state->dispatcher_set && @@ -600,20 +593,21 @@ fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, FINE_NIF(clear_elixir_callback_bridge, 0); -void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, - fine::Term payload) { - auto state = GetElixirCallbackBridgeState(); +// Allocate and return a Fine resource handle associated with the bridge. +// This lets Elixir hold a reference (e.g., in EXLA.CallbackServer state) so +// the bridge lifetime is attached to that process. The actual per-callback +// pending resources are created independently for each call. 
+fine::ResourcePtr +acquire_elixir_callback_bridge(ErlNifEnv *env) { + (void)env; + return fine::make_resource(); +} - std::shared_ptr pending; - { - std::lock_guard lock(state->mu); - auto it = state->pending.find(reply_tag); - if (it == state->pending.end()) { - return; - } - pending = it->second; - } +FINE_NIF(acquire_elixir_callback_bridge, 0); +void DeliverElixirCallbackReply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Term payload) { ElixirCallbackResult result; int arity = 0; @@ -683,11 +677,6 @@ void DeliverElixirCallbackReply(ErlNifEnv *env, int64_t reply_tag, } pending->cv.notify_one(); - - { - std::lock_guard lock(state->mu); - state->pending.erase(reply_tag); - } } ElixirCallbackResult @@ -702,14 +691,7 @@ CallElixirCallback(int64_t callback_id, return res; } - auto pending = std::make_shared(); - - int64_t tag = state->next_tag.fetch_add(1, std::memory_order_relaxed); - - { - std::lock_guard lock(state->mu); - state->pending.emplace(tag, pending); - } + auto pending = fine::make_resource(); ErlNifEnv *msg_env = enif_alloc_env(); @@ -730,10 +712,6 @@ CallElixirCallback(int64_t callback_id, auto type_tuple_or = EncodeNxType(msg_env, tensor.dtype); if (!type_tuple_or.has_value()) { enif_free_env(msg_env); - { - std::lock_guard lock(state->mu); - state->pending.erase(tag); - } ElixirCallbackResult res; res.ok = false; @@ -768,14 +746,18 @@ CallElixirCallback(int64_t callback_id, ERL_NIF_TERM args_list = enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); - ERL_NIF_TERM tag_term = enif_make_int64(msg_env, tag); + ERL_NIF_TERM pending_term = fine::encode(msg_env, pending); ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); ERL_NIF_TERM msg = enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), - cb_term, args_list, tag_term); + cb_term, args_list, pending_term); - enif_send(msg_env, &state->dispatcher_pid, msg_env, msg); + // Use the dispatcher pid registered via start_elixir_callback_bridge/1. 
+ // Calling enif_whereis_pid from this non-scheduler thread is unsafe and + // was causing a segfault. + ErlNifPid dispatcher_pid = state->dispatcher_pid; + enif_send(msg_env, &dispatcher_pid, msg_env, msg); enif_free_env(msg_env); std::unique_lock lock(pending->mu); diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index a5177fabab..57df90e2ce 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -34,14 +34,21 @@ defmodule EXLA.CallbackServer do @type callback_id :: non_neg_integer() defstruct next_id: 1, - callbacks: %{} + callbacks: %{}, + # Opaque handle to the native elixir callback bridge so its + # lifetime is tied to this server process. + bridge_ref: nil @type t :: %__MODULE__{ next_id: non_neg_integer(), # We store the original function, its output template, and any # static (non-tensor) arguments that should always be appended to # the decoded tensor arguments coming from native. - callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}} + callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}}, + # Native bridge resource. We don't use it directly in Elixir, but + # holding a reference here ensures the native bridge stays alive + # as long as this server does. + bridge_ref: term() } ## Public API @@ -73,10 +80,12 @@ defmodule EXLA.CallbackServer do @impl true def init(:ok) do - # Inform native side that this process is the dispatcher for elixir callbacks. + # Inform native side that this process is the dispatcher for elixir callbacks + # and acquire a bridge resource so its lifetime is attached to this server. 
_ = EXLA.NIF.start_elixir_callback_bridge(self()) + bridge_ref = EXLA.NIF.acquire_elixir_callback_bridge() - {:ok, %__MODULE__{}} + {:ok, %__MODULE__{bridge_ref: bridge_ref}} end @impl true diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 2a7bb2df13..f3d142ff37 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -83,5 +83,8 @@ defmodule EXLA.NIF do def clear_elixir_callback_bridge(_dispatcher_pid), do: err!() def elixir_callback_reply(_reply_tag, _payload), do: err!() + # Bridge resource handle so EXLA.CallbackServer can keep the bridge alive. + def acquire_elixir_callback_bridge(), do: err!() + defp err!(), do: :erlang.nif_error(:undef) end diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index d83d80748a..b1bd4d4ab6 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -11,9 +11,7 @@ defmodule EXLA.Defn.ElixirCallTest do defp add_offset_callback(t, opts) do t |> Nx.as_type(:f32) - # TODO: if we run on the same device there will be a problem due to the device locking. 
- |> Nx.backend_transfer({EXLA.Backend, client: :host, device_id: 1}) - |> Nx.add(opts[:offset]) |> dbg(structs: false) + |> Nx.add(opts[:offset]) end defn add_offset(x) do From bdae93c3d069229dd0216e7950950add6cd48a5b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:39:40 -0300 Subject: [PATCH 13/42] refactor: improve result decoding and reduce data copies --- .../exla/custom_calls/elixir_callback.cc | 42 +++--- exla/c_src/exla/elixir_callback_bridge.h | 56 ++++++-- exla/c_src/exla/exla.cc | 134 +++++++++--------- exla/lib/exla/callback_server.ex | 4 +- exla/lib/exla/nif.ex | 2 +- 5 files changed, 129 insertions(+), 109 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index bbafe8c828..04ffaf940b 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -59,22 +59,11 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, inputs.push_back(std::move(tensor)); } - // Call back into Elixir through the bridge. - exla::ElixirCallbackResult result = - exla::CallElixirCallback(callback_id, inputs); - - if (!result.ok) { - return ffi::Error(ffi::ErrorCode::kInternal, result.error); - } - - if (result.outputs.size() != rets.size()) { - return ffi::Error( - ffi::ErrorCode::kInternal, - "mismatched number of callback outputs vs custom_call results"); - } + // Prepare output buffer descriptors so the callback bridge can write results + // directly into the final destination buffers. + std::vector outputs; + outputs.reserve(rets.size()); - // Copy returned binaries into the result buffers. We rely on the Elixir side - // (Nx.elixir_call/3) to have already validated shapes and dtypes. 
for (size_t i = 0; i < rets.size(); ++i) { auto maybe_ret_or = rets.get(i); if (!maybe_ret_or) { @@ -84,20 +73,21 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, ffi::Result ret = *maybe_ret_or; ffi::AnyBuffer out = *ret; - const auto &payload = result.outputs[i]; + exla::ElixirCallbackOutputBuffer buf; + buf.data = static_cast(out.untyped_data()); + buf.size = ffi::ByteWidth(out.element_type()) * + static_cast(out.element_count()); - size_t expected = - ffi::ByteWidth(out.element_type()) * out.element_count(); + outputs.push_back(buf); + } - if (payload.data.size() != expected) { - return ffi::Error( - ffi::ErrorCode::kInternal, - "callback returned binary of unexpected size for result buffer"); - } + // Call back into Elixir through the bridge. On success, the bridge writes + // results directly into the provided output buffers. + exla::ElixirCallbackResult result = + exla::CallElixirCallback(callback_id, inputs, outputs); - if (expected > 0) { - std::memcpy(out.untyped_data(), payload.data.data(), expected); - } + if (!result.ok) { + return ffi::Error(ffi::ErrorCode::kInternal, result.error); } return ffi::Error::Success(); diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index 7b8964ddbc..299c4bc3c6 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -12,14 +12,6 @@ namespace exla { -// Lightweight tensor payload used to transfer arguments and results between -// the XLA host CustomCall handler and the Elixir dispatcher. -struct ElixirCallbackTensor { - xla::ffi::DataType dtype; - std::vector dims; - std::vector data; -}; - struct ElixirCallbackArg { xla::ffi::DataType dtype; std::vector dims; @@ -27,26 +19,45 @@ struct ElixirCallbackArg { size_t size_bytes = 0; }; +// Result of an Elixir callback. 
On success, data has already been copied into +// the pre-registered output buffers held by ElixirCallbackPending, so we only +// need to track success or an error message here. struct ElixirCallbackResult { bool ok = false; std::string error; - std::vector outputs; +}; + +// Host-side description of an output buffer that should receive the callback +// result for a given output index. +struct ElixirCallbackOutputBuffer { + uint8_t *data = nullptr; + size_t size = 0; }; // Per-callback pending state used to synchronize between the XLA host thread // and the Elixir-side dispatcher. This is exposed as a Fine resource so we // can pass it as an opaque handle in messages instead of using integer tags. struct ElixirCallbackPending { + // Constructor used on the host callback path where we pre-register the + // destination buffers for each output. + explicit ElixirCallbackPending( + std::vector outputs) + : outputs(std::move(outputs)) {} + std::mutex mu; std::condition_variable cv; bool done = false; ElixirCallbackResult result; + std::vector outputs; }; // Called from the Elixir side to deliver a reply for a given pending handle. +// We receive the reply as a status atom (e.g. :ok or :error) and a result +// term. For the :ok case the result is a list of binaries that we decode as +// ElixirCallbackTensor outputs via Fine's decoding machinery. void DeliverElixirCallbackReply( ErlNifEnv *env, fine::ResourcePtr pending, - fine::Term payload); + fine::Atom status, fine::Term result); // Synchronously calls the Elixir callback identified by `callback_id` with the // given tensor arguments. This function: @@ -56,12 +67,31 @@ void DeliverElixirCallbackReply( // * Blocks the calling native thread until the reply arrives via // DeliverElixirCallbackReply/3 // -// It returns an ElixirCallbackResult that either contains a list of output -// tensors (on success) or an error message. 
+// It returns an ElixirCallbackResult that either indicates success (data has +// been written into the registered output buffers) or an error message. ElixirCallbackResult CallElixirCallback(int64_t callback_id, - const std::vector &inputs); + const std::vector &inputs, + const std::vector &outputs); } // namespace exla +namespace fine { + +// Decode a binary term into a raw byte vector. We only care about the payload +// bytes; dtype and shape are validated on the Elixir side. +template <> struct Decoder> { + static std::vector decode(ErlNifEnv *env, const ERL_NIF_TERM &term) { + ErlNifBinary bin; + if (!enif_inspect_binary(env, term, &bin)) { + throw std::invalid_argument( + "decode failed, expected binary for callback output"); + } + + std::vector bytes; + bytes.assign(bin.data, bin.data + bin.size); + return bytes; + } +}; +} // namespace fine diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 10e600ecb7..dd81d12b27 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -34,6 +34,12 @@ FINE_RESOURCE(exla::MLIRModule); FINE_RESOURCE(exla::MLIRFunction); FINE_RESOURCE(exla::ElixirCallbackPending); +// Opaque handle type used only so Elixir can keep the bridge alive via a +// Fine resource. It carries no data; the real bridge state is the singleton +// ElixirCallbackBridgeState below. +struct ElixirCallbackBridgeHandle {}; +FINE_RESOURCE(ElixirCallbackBridgeHandle); + // MLIR Functions fine::ResourcePtr decode_exla_buffer(ErlNifEnv *env, @@ -571,8 +577,8 @@ FINE_NIF(start_elixir_callback_bridge, 0); fine::Ok<> elixir_callback_reply(ErlNifEnv *env, fine::ResourcePtr pending, - fine::Term payload) { - DeliverElixirCallbackReply(env, pending, payload); + fine::Atom status, fine::Term result) { + DeliverElixirCallbackReply(env, pending, status, result); return fine::Ok(); } @@ -595,84 +601,77 @@ FINE_NIF(clear_elixir_callback_bridge, 0); // Allocate and return a Fine resource handle associated with the bridge. 
// This lets Elixir hold a reference (e.g., in EXLA.CallbackServer state) so -// the bridge lifetime is attached to that process. The actual per-callback -// pending resources are created independently for each call. -fine::ResourcePtr +// the bridge lifetime is attached to that process. Per-callback pending +// resources are created independently for each call. +fine::ResourcePtr acquire_elixir_callback_bridge(ErlNifEnv *env) { (void)env; - return fine::make_resource(); + return fine::make_resource(); } FINE_NIF(acquire_elixir_callback_bridge, 0); void DeliverElixirCallbackReply( ErlNifEnv *env, fine::ResourcePtr pending, - fine::Term payload) { - ElixirCallbackResult result; - - int arity = 0; - const ERL_NIF_TERM *tuple = nullptr; - ERL_NIF_TERM term = payload; - - if (!enif_get_tuple(env, term, &arity, &tuple) || arity != 2) { - result.ok = false; - result.error = "invalid callback reply payload, expected {status, value}"; - } else { - char atom_buf[16]; - if (enif_get_atom(env, tuple[0], atom_buf, sizeof(atom_buf), - ERL_NIF_LATIN1) && - strcmp(atom_buf, "ok") == 0) { - // tuple[1] is a list of binaries representing outputs. 
- ERL_NIF_TERM list = tuple[1]; - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - ErlNifBinary bin; - if (!enif_inspect_binary(env, head, &bin)) { - result.ok = false; - result.error = "invalid binary in callback reply"; - break; - } - - ElixirCallbackTensor tensor; - tensor.dtype = xla::ffi::DataType::INVALID; - tensor.dims = {}; - tensor.data.assign(bin.data, bin.data + bin.size); - result.outputs.push_back(std::move(tensor)); - - list = tail; - } - - if (result.error.empty()) { - result.ok = true; - } - } else { - // Error reply: tuple[1] is expected to be {kind_atom, message :: binary} - result.ok = false; - ERL_NIF_TERM err_term = tuple[1]; - - int err_arity = 0; - const ERL_NIF_TERM *err_tuple = nullptr; - if (enif_get_tuple(env, err_term, &err_arity, &err_tuple) && - err_arity == 2) { - // We ignore the kind atom for now (e.g. :argument_error or - // :runtime_error) and use only the message as the XLA error text. - ErlNifBinary msg_bin; - if (enif_inspect_binary(env, err_tuple[1], &msg_bin)) { - result.error.assign(reinterpret_cast(msg_bin.data), - msg_bin.size); - } else { - result.error = "elixir callback returned error"; - } + fine::Atom status, fine::Term result_term) { + ElixirCallbackResult cb_result; + + if (status == "ok") { + // Successful reply: result_term is a list of binaries that we decode into + // raw byte vectors via Fine and copy directly into the registered output + // buffers. 
+ try { + auto payloads = + fine::decode>>(env, result_term); + + std::lock_guard lock(pending->mu); + + if (payloads.size() != pending->outputs.size()) { + cb_result.ok = false; + cb_result.error = + "mismatched number of callback outputs vs registered buffers"; } else { - result.error = "elixir callback returned error"; + cb_result.ok = true; + + for (size_t i = 0; i < payloads.size(); ++i) { + const auto &bytes = payloads[i]; + auto &out_buf = pending->outputs[i]; + + if (bytes.size() != out_buf.size) { + cb_result.ok = false; + cb_result.error = + "callback returned binary of unexpected size for result buffer"; + break; + } + + if (out_buf.size > 0) { + std::memcpy(out_buf.data, bytes.data(), out_buf.size); + } + } } + } catch (const std::exception &e) { + cb_result.ok = false; + cb_result.error = + std::string("failed to decode Elixir callback outputs: ") + e.what(); + } + } else { + // Error reply: result_term is expected to be {kind_atom, message :: binary} + cb_result.ok = false; + + try { + auto decoded = + fine::decode>(env, result_term); + ErlNifBinary msg_bin = std::get<1>(decoded); + cb_result.error.assign(reinterpret_cast(msg_bin.data), + msg_bin.size); + } catch (const std::exception &) { + cb_result.error = "elixir callback returned error"; } } { std::lock_guard lock(pending->mu); - pending->result = std::move(result); + pending->result = std::move(cb_result); pending->done = true; } @@ -681,7 +680,8 @@ void DeliverElixirCallbackReply( ElixirCallbackResult CallElixirCallback(int64_t callback_id, - const std::vector &inputs) { + const std::vector &inputs, + const std::vector &outputs) { auto state = GetElixirCallbackBridgeState(); if (!state->dispatcher_set) { @@ -691,7 +691,7 @@ CallElixirCallback(int64_t callback_id, return res; } - auto pending = fine::make_resource(); + auto pending = fine::make_resource(outputs); ErlNifEnv *msg_env = enif_alloc_env(); diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 
57df90e2ce..8efb7addc8 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -294,9 +294,9 @@ defmodule EXLA.CallbackServer do |> Enum.map(&Nx.to_binary/1) end - defp send_reply(reply_tag, payload) do + defp send_reply(reply_tag, {status, result}) do try do - EXLA.NIF.elixir_callback_reply(reply_tag, payload) + EXLA.NIF.elixir_callback_reply(reply_tag, status, result) rescue _ -> Logger.error( diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index f3d142ff37..98ff89b02e 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -81,7 +81,7 @@ defmodule EXLA.NIF do # Elixir callback bridge (Phase 1: CPU-only, simple APIs) def start_elixir_callback_bridge(_dispatcher_pid), do: err!() def clear_elixir_callback_bridge(_dispatcher_pid), do: err!() - def elixir_callback_reply(_reply_tag, _payload), do: err!() + def elixir_callback_reply(_reply_tag, _status, _result), do: err!() # Bridge resource handle so EXLA.CallbackServer can keep the bridge alive. def acquire_elixir_callback_bridge(), do: err!() From cb8c3452f3621714ffa6343599efbbcfc51dc40e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:32:04 -0300 Subject: [PATCH 14/42] refactor: implement fine encoder for ffi types --- exla/c_src/exla/elixir_callback_bridge.h | 139 ++++++++++++++ exla/c_src/exla/exla.cc | 53 +----- exla/c_src/exla/exla_nif_util.h | 219 ++++------------------- exla/lib/exla/callback_server.ex | 28 +-- 4 files changed, 195 insertions(+), 244 deletions(-) diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index 299c4bc3c6..ce55f8e823 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -94,4 +94,143 @@ template <> struct Decoder> { } }; +// Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. 
This is +// used by the Elixir callback bridge to surface type and shape information +// about callback arguments to the Elixir side. +template <> +struct Encoder>> { + static ERL_NIF_TERM + encode(ErlNifEnv *env, + const std::tuple> &spec) { + const xla::ffi::DataType &dtype = std::get<0>(spec); + const std::vector &dims = std::get<1>(spec); + + ERL_NIF_TERM keys[] = { + fine::encode(env, fine::Atom("__struct__")), + fine::encode(env, fine::Atom("type")), + fine::encode(env, fine::Atom("shape")), + }; + + ERL_NIF_TERM values[] = { + fine::encode(env, fine::Atom("Elixir.EXLA.Typespec")), + encode_type(env, dtype), + encode_shape(env, dtype, dims), + }; + + ERL_NIF_TERM map; + if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { + throw std::runtime_error("encode: failed to make a map"); + } + + return map; + } + +private: + static ERL_NIF_TERM encode_type(ErlNifEnv *env, xla::ffi::DataType dtype) { + using DT = xla::ffi::DataType; + + // Tokens are encoded as the atom :token with empty shape. 
+ if (dtype == DT::TOKEN) { + return fine::encode(env, fine::Atom("token")); + } + + std::optional type_name; + std::optional type_size; + + switch (dtype) { + case DT::PRED: + type_name = fine::Atom("pred"); + type_size = 8; + break; + + case DT::U8: + type_name = fine::Atom("u"); + type_size = 8; + break; + case DT::U16: + type_name = fine::Atom("u"); + type_size = 16; + break; + case DT::U32: + type_name = fine::Atom("u"); + type_size = 32; + break; + case DT::U64: + type_name = fine::Atom("u"); + type_size = 64; + break; + + case DT::S8: + type_name = fine::Atom("s"); + type_size = 8; + break; + case DT::S16: + type_name = fine::Atom("s"); + type_size = 16; + break; + case DT::S32: + type_name = fine::Atom("s"); + type_size = 32; + break; + case DT::S64: + type_name = fine::Atom("s"); + type_size = 64; + break; + + case DT::F16: + type_name = fine::Atom("f"); + type_size = 16; + break; + case DT::F32: + type_name = fine::Atom("f"); + type_size = 32; + break; + case DT::F64: + type_name = fine::Atom("f"); + type_size = 64; + break; + + case DT::BF16: + type_name = fine::Atom("bf"); + type_size = 16; + break; + + case DT::C64: + type_name = fine::Atom("c"); + type_size = 64; + break; + case DT::C128: + type_name = fine::Atom("c"); + type_size = 128; + break; + + default: + break; + } + + if (type_name && type_size) { + return fine::encode( + env, std::make_tuple(type_name.value(), type_size.value())); + } + + throw std::invalid_argument("encode failed, unexpected ffi::DataType"); + } + + static ERL_NIF_TERM encode_shape(ErlNifEnv *env, xla::ffi::DataType dtype, + const std::vector &dims) { + if (dtype == xla::ffi::DataType::TOKEN) { + return enif_make_tuple(env, 0); + } + + std::vector dim_terms; + dim_terms.reserve(dims.size()); + + for (auto d : dims) { + dim_terms.push_back(fine::encode(env, d)); + } + + return enif_make_tuple_from_array(env, dim_terms.data(), dim_terms.size()); + } +}; + } // namespace fine diff --git a/exla/c_src/exla/exla.cc 
b/exla/c_src/exla/exla.cc index dd81d12b27..6bb9db52df 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -548,20 +548,6 @@ static ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { return state; } -// Map ffi::DataType to a Nx-style {atom, bits} pair used on the Elixir side. -std::optional> -EncodeNxType(ErlNifEnv *env, xla::ffi::DataType dtype) { - if (auto primitive = exla::PrimitiveTypeFromFfiDataType(dtype)) { - if (auto info = exla::PrimitiveTypeToNxTypeInfo(*primitive)) { - ERL_NIF_TERM atom_term = enif_make_atom(env, info->atom_name); - ERL_NIF_TERM bits_term = enif_make_int(env, info->bits); - return std::make_pair(atom_term, bits_term); - } - } - - return std::nullopt; -} - } // namespace fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, @@ -695,9 +681,8 @@ CallElixirCallback(int64_t callback_id, ErlNifEnv *msg_env = enif_alloc_env(); - // Encode arguments as [{bin, {type, bits}, shape_tuple}, ...]. We currently - // send plain binaries because the BEAM callback needs to own the data - // lifetime. + // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send + // plain binaries because the BEAM callback needs to own the data lifetime. 
std::vector args_terms; args_terms.reserve(inputs.size()); @@ -709,36 +694,12 @@ CallElixirCallback(int64_t callback_id, memcpy(bin_data, tensor.data, tensor.size_bytes); } - auto type_tuple_or = EncodeNxType(msg_env, tensor.dtype); - if (!type_tuple_or.has_value()) { - enif_free_env(msg_env); - - ElixirCallbackResult res; - res.ok = false; - res.error = "unsupported tensor type in EXLA callback argument"; - return res; - } - - auto type_info = type_tuple_or.value(); - ERL_NIF_TERM type_tuple = - enif_make_tuple2(msg_env, type_info.first, type_info.second); - - std::vector dim_terms; - dim_terms.reserve(tensor.dims.size()); - for (auto d : tensor.dims) { - dim_terms.push_back(enif_make_int64(msg_env, d)); - } - - ERL_NIF_TERM shape_tuple; - if (dim_terms.empty()) { - shape_tuple = enif_make_tuple(msg_env, 0); - } else { - shape_tuple = enif_make_tuple_from_array(msg_env, dim_terms.data(), - dim_terms.size()); - } + // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via + // Fine's encoder defined in exla_nif_util.h. 
+ ERL_NIF_TERM typespec_term = + fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims)); - ERL_NIF_TERM arg_tuple = - enif_make_tuple3(msg_env, bin_term, type_tuple, shape_tuple); + ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term); args_terms.push_back(arg_tuple); } diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 65caec2f0f..b2babd53cb 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -6,7 +6,6 @@ #include "mlir/IR/Types.h" #include "stablehlo/dialect/StablehloOps.h" -#include "xla/ffi/api/ffi.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -29,187 +28,6 @@ static auto type = fine::Atom("type"); static auto u = fine::Atom("u"); static auto warning = fine::Atom("warning"); } // namespace atoms - -struct NxTypeInfo { - const char *atom_name; - const fine::Atom *atom_ref; - uint64_t bits; - - fine::Atom atom() const { - if (atom_ref) { - return *atom_ref; - } - return fine::Atom(atom_name); - } -}; - -inline std::optional -PrimitiveTypeToNxTypeInfo(xla::PrimitiveType type) { - switch (type) { - case xla::PRED: - return NxTypeInfo{"pred", &atoms::pred, 8}; - case xla::S2: - return NxTypeInfo{"s", &atoms::s, 2}; - case xla::S4: - return NxTypeInfo{"s", &atoms::s, 4}; - case xla::S8: - return NxTypeInfo{"s", &atoms::s, 8}; - case xla::S16: - return NxTypeInfo{"s", &atoms::s, 16}; - case xla::S32: - return NxTypeInfo{"s", &atoms::s, 32}; - case xla::S64: - return NxTypeInfo{"s", &atoms::s, 64}; - case xla::U2: - return NxTypeInfo{"u", &atoms::u, 2}; - case xla::U4: - return NxTypeInfo{"u", &atoms::u, 4}; - case xla::U8: - return NxTypeInfo{"u", &atoms::u, 8}; - case xla::U16: - return NxTypeInfo{"u", &atoms::u, 16}; - case xla::U32: - return NxTypeInfo{"u", &atoms::u, 32}; - case xla::U64: - return NxTypeInfo{"u", &atoms::u, 64}; - case xla::F8E4M3FN: - case xla::F8E5M2: - return NxTypeInfo{"f", &atoms::f, 8}; - case xla::F16: - return NxTypeInfo{"f", 
&atoms::f, 16}; - case xla::BF16: - return NxTypeInfo{"bf", &atoms::bf, 16}; - case xla::F32: - return NxTypeInfo{"f", &atoms::f, 32}; - case xla::F64: - return NxTypeInfo{"f", &atoms::f, 64}; - case xla::C64: - return NxTypeInfo{"c", &atoms::c, 64}; - case xla::C128: - return NxTypeInfo{"c", &atoms::c, 128}; - default: - return std::nullopt; - } -} - -inline std::optional -PrimitiveTypeFromFfiDataType(xla::ffi::DataType dtype) { - switch (dtype) { - case xla::ffi::PRED: - return xla::PRED; - case xla::ffi::S2: - return xla::S2; - case xla::ffi::S4: - return xla::S4; - case xla::ffi::S8: - return xla::S8; - case xla::ffi::S16: - return xla::S16; - case xla::ffi::S32: - return xla::S32; - case xla::ffi::S64: - return xla::S64; - case xla::ffi::U2: - return xla::U2; - case xla::ffi::U4: - return xla::U4; - case xla::ffi::U8: - return xla::U8; - case xla::ffi::U16: - return xla::U16; - case xla::ffi::U32: - return xla::U32; - case xla::ffi::U64: - return xla::U64; - case xla::ffi::F8E4M3FN: - return xla::F8E4M3FN; - case xla::ffi::F8E5M2: - return xla::F8E5M2; - case xla::ffi::F16: - return xla::F16; - case xla::ffi::BF16: - return xla::BF16; - case xla::ffi::F32: - return xla::F32; - case xla::ffi::F64: - return xla::F64; - case xla::ffi::C64: - return xla::C64; - case xla::ffi::C128: - return xla::C128; - default: - return std::nullopt; - } -} - -inline std::optional -PrimitiveTypeFromMlirElement(const mlir::Type &element_type) { - if (element_type.isSignlessInteger(1)) { - return xla::PRED; - } - - if (auto integer_type = mlir::dyn_cast(element_type)) { - int width = integer_type.getWidth(); - if (integer_type.isUnsigned()) { - switch (width) { - case 2: - return xla::U2; - case 4: - return xla::U4; - case 8: - return xla::U8; - case 16: - return xla::U16; - case 32: - return xla::U32; - case 64: - return xla::U64; - } - } else { - switch (width) { - case 2: - return xla::S2; - case 4: - return xla::S4; - case 8: - return xla::S8; - case 16: - return xla::S16; - 
case 32: - return xla::S32; - case 64: - return xla::S64; - } - } - } else if (element_type.isBF16()) { - return xla::BF16; - } else if (auto float_type = mlir::dyn_cast(element_type)) { - int width = float_type.getWidth(); - switch (width) { - case 8: - return xla::F8E4M3FN; - case 16: - return xla::F16; - case 32: - return xla::F32; - case 64: - return xla::F64; - } - } else if (auto complex_type = - mlir::dyn_cast(element_type)) { - auto inner = complex_type.getElementType(); - if (auto float_type = mlir::dyn_cast(inner)) { - switch (float_type.getWidth()) { - case 32: - return xla::C64; - case 64: - return xla::C128; - } - } - } - - return std::nullopt; -} } // namespace exla namespace fine { @@ -358,17 +176,46 @@ template <> struct Encoder { return fine::encode(env, exla::atoms::token); } + std::optional type_name; + std::optional type_size; + if (mlir::isa(type)) { auto tensor_type = mlir::cast(type); auto element_type = tensor_type.getElementType(); - if (auto primitive = exla::PrimitiveTypeFromMlirElement(element_type)) { - if (auto info = exla::PrimitiveTypeToNxTypeInfo(*primitive)) { - return fine::encode(env, std::make_tuple(info->atom(), info->bits)); + + if (element_type.isSignlessInteger(1)) { + type_name = exla::atoms::pred; + type_size = 8; + } else if (auto integer_type = + mlir::dyn_cast(element_type)) { + if (integer_type.isUnsigned()) { + type_name = exla::atoms::u; + } else { + type_name = exla::atoms::s; } + + type_size = integer_type.getWidth(); + } else if (element_type.isBF16()) { + type_name = exla::atoms::bf; + type_size = 16; + } else if (auto float_type = + mlir::dyn_cast(element_type)) { + type_name = exla::atoms::f; + type_size = float_type.getWidth(); + } else if (auto complex_type = + mlir::dyn_cast(element_type)) { + auto element_type = complex_type.getElementType(); + type_name = exla::atoms::c; + type_size = mlir::cast(element_type).getWidth() * 2; } } - throw std::invalid_argument("encode failed, unexpected mlir type"); + if 
(type_name) { + return fine::encode( + env, std::make_tuple(type_name.value(), type_size.value())); + } else { + throw std::invalid_argument("encode failed, unexpected mlir type"); + } } static ERL_NIF_TERM encode_shape(ErlNifEnv *env, const mlir::Type &type) { diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 8efb7addc8..5f740f1b82 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -206,18 +206,22 @@ defmodule EXLA.CallbackServer do defp decode_args(args_spec) when is_list(args_spec) do result = - Enum.reduce_while(args_spec, {:ok, []}, fn {bin, {type, bits}, shape}, {:ok, acc} -> - try do - tensor = - bin - |> Nx.from_binary({type, bits}) - |> Nx.reshape(shape) - - {:cont, {:ok, [tensor | acc]}} - rescue - exception -> - {:halt, {:error, {:decode_failed, exception}}} - end + Enum.reduce_while(args_spec, {:ok, []}, fn + {bin, %EXLA.Typespec{type: type, shape: shape}}, {:ok, acc} -> + try do + tensor = + bin + |> Nx.from_binary(type) + |> Nx.reshape(shape) + + {:cont, {:ok, [tensor | acc]}} + rescue + exception -> + {:halt, {:error, {:decode_failed, exception}}} + end + + other, _acc -> + {:halt, {:error, {:invalid_args_spec, other}}} end) case result do From a79b0376c1241816231f2a35d9e23b36d2f793d4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:39:42 -0300 Subject: [PATCH 15/42] docs: update docs --- exla/c_src/exla/elixir_callback_bridge.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index ce55f8e823..daed144f61 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -62,7 +62,7 @@ void DeliverElixirCallbackReply( // Synchronously calls the Elixir callback identified by `callback_id` with the // given tensor arguments. 
This function: // -// * Allocates a unique reply_tag +// * Allocates a unique ElixirCallbackPending resource // * Sends a message to the dispatcher via enif_send/3 // * Blocks the calling native thread until the reply arrives via // DeliverElixirCallbackReply/3 From 658b8bfbfa35f132de0ab2f9ab24322ae2c728ca Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:43:34 -0300 Subject: [PATCH 16/42] chore: use exla nif atoms --- exla/c_src/exla/elixir_callback_bridge.h | 51 +++++++++++------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h index daed144f61..f5310939e9 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -6,9 +6,10 @@ #include #include +#include "exla_nif_util.h" +#include "xla/ffi/api/ffi.h" #include #include -#include "xla/ffi/api/ffi.h" namespace exla { @@ -105,17 +106,13 @@ struct Encoder>> { const xla::ffi::DataType &dtype = std::get<0>(spec); const std::vector &dims = std::get<1>(spec); - ERL_NIF_TERM keys[] = { - fine::encode(env, fine::Atom("__struct__")), - fine::encode(env, fine::Atom("type")), - fine::encode(env, fine::Atom("shape")), - }; + ERL_NIF_TERM keys[] = {fine::encode(env, exla::atoms::__struct__), + fine::encode(env, exla::atoms::type), + fine::encode(env, exla::atoms::shape)}; - ERL_NIF_TERM values[] = { - fine::encode(env, fine::Atom("Elixir.EXLA.Typespec")), - encode_type(env, dtype), - encode_shape(env, dtype, dims), - }; + ERL_NIF_TERM values[] = {fine::encode(env, exla::atoms::ElixirEXLATypespec), + encode_type(env, dtype), + encode_shape(env, dtype, dims)}; ERL_NIF_TERM map; if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { @@ -131,7 +128,7 @@ struct Encoder>> { // Tokens are encoded as the atom :token with empty shape. 
if (dtype == DT::TOKEN) { - return fine::encode(env, fine::Atom("token")); + return fine::encode(env, exla::atoms::token); } std::optional type_name; @@ -139,68 +136,68 @@ struct Encoder>> { switch (dtype) { case DT::PRED: - type_name = fine::Atom("pred"); + type_name = exla::atoms::pred; type_size = 8; break; case DT::U8: - type_name = fine::Atom("u"); + type_name = exla::atoms::u; type_size = 8; break; case DT::U16: - type_name = fine::Atom("u"); + type_name = exla::atoms::u; type_size = 16; break; case DT::U32: - type_name = fine::Atom("u"); + type_name = exla::atoms::u; type_size = 32; break; case DT::U64: - type_name = fine::Atom("u"); + type_name = exla::atoms::u; type_size = 64; break; case DT::S8: - type_name = fine::Atom("s"); + type_name = exla::atoms::s; type_size = 8; break; case DT::S16: - type_name = fine::Atom("s"); + type_name = exla::atoms::s; type_size = 16; break; case DT::S32: - type_name = fine::Atom("s"); + type_name = exla::atoms::s; type_size = 32; break; case DT::S64: - type_name = fine::Atom("s"); + type_name = exla::atoms::s; type_size = 64; break; case DT::F16: - type_name = fine::Atom("f"); + type_name = exla::atoms::f; type_size = 16; break; case DT::F32: - type_name = fine::Atom("f"); + type_name = exla::atoms::f; type_size = 32; break; case DT::F64: - type_name = fine::Atom("f"); + type_name = exla::atoms::f; type_size = 64; break; case DT::BF16: - type_name = fine::Atom("bf"); + type_name = exla::atoms::bf; type_size = 16; break; case DT::C64: - type_name = fine::Atom("c"); + type_name = exla::atoms::c; type_size = 64; break; case DT::C128: - type_name = fine::Atom("c"); + type_name = exla::atoms::c; type_size = 128; break; From 6d4d3270d379b4830f1cf9a55de035ab09b36917 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 05:15:36 -0300 Subject: [PATCH 17/42] refactor: reorganize files and namespace things --- .../exla/custom_calls/elixir_callback.cc | 4 +- 
.../custom_calls/elixir_callback_bridge.cc | 188 +++++++++++++ .../custom_calls/elixir_callback_bridge.h | 256 ++++++++++++++++++ exla/c_src/exla/elixir_callback_bridge.h | 6 +- exla/c_src/exla/exla.cc | 205 +------------- 5 files changed, 459 insertions(+), 200 deletions(-) create mode 100644 exla/c_src/exla/custom_calls/elixir_callback_bridge.cc create mode 100644 exla/c_src/exla/custom_calls/elixir_callback_bridge.h diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index 04ffaf940b..cdd5b1b0cf 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -1,4 +1,4 @@ -#include "../elixir_callback_bridge.h" +#include "elixir_callback_bridge.h" #include #include @@ -84,7 +84,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, // Call back into Elixir through the bridge. On success, the bridge writes // results directly into the provided output buffers. exla::ElixirCallbackResult result = - exla::CallElixirCallback(callback_id, inputs, outputs); + exla::callback_bridge::InvokeElixirCallback(callback_id, inputs, outputs); if (!result.ok) { return ffi::Error(ffi::ErrorCode::kInternal, result.error); diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc new file mode 100644 index 0000000000..e4d716dde1 --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -0,0 +1,188 @@ +#include "elixir_callback_bridge.h" + +#include + +namespace exla { + +namespace callback_bridge { + +struct ElixirCallbackBridgeState { + ErlNifPid dispatcher_pid; + bool dispatcher_set = false; +}; + +ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { + static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState(); + return state; +} + +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { + (void)env; + auto state = 
GetElixirCallbackBridgeState(); + state->dispatcher_pid = dispatcher_pid; + state->dispatcher_set = true; + return fine::Ok(); +} + +fine::Ok<> elixir_callback_reply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result) { + DeliverElixirCallbackReply(env, pending, status, result); + return fine::Ok(); +} + +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { + (void)env; + auto state = GetElixirCallbackBridgeState(); + + if (state->dispatcher_set && + std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == + 0) { + state->dispatcher_set = false; + } + + return fine::Ok(); +} + +fine::ResourcePtr +acquire_elixir_callback_bridge(ErlNifEnv *env) { + (void)env; + return fine::make_resource(); +} + +void DeliverElixirCallbackReply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result_term) { + ElixirCallbackResult cb_result; + + if (status == "ok") { + // Successful reply: result_term is a list of binaries that we decode into + // raw byte vectors via Fine and copy directly into the registered output + // buffers. 
+ try { + auto payloads = + fine::decode>>(env, result_term); + + std::lock_guard lock(pending->mu); + + if (payloads.size() != pending->outputs.size()) { + cb_result.ok = false; + cb_result.error = + "mismatched number of callback outputs vs registered buffers"; + } else { + cb_result.ok = true; + + for (size_t i = 0; i < payloads.size(); ++i) { + const auto &bytes = payloads[i]; + auto &out_buf = pending->outputs[i]; + + if (bytes.size() != out_buf.size) { + cb_result.ok = false; + cb_result.error = + "callback returned binary of unexpected size for result buffer"; + break; + } + + if (out_buf.size > 0) { + std::memcpy(out_buf.data, bytes.data(), out_buf.size); + } + } + } + } catch (const std::exception &e) { + cb_result.ok = false; + cb_result.error = + std::string("failed to decode Elixir callback outputs: ") + e.what(); + } + } else { + // Error reply: result_term is expected to be {kind_atom, message :: binary} + cb_result.ok = false; + + try { + auto decoded = + fine::decode>(env, result_term); + ErlNifBinary msg_bin = std::get<1>(decoded); + cb_result.error.assign(reinterpret_cast(msg_bin.data), + msg_bin.size); + } catch (const std::exception &) { + cb_result.error = "elixir callback returned error"; + } + } + + { + std::lock_guard lock(pending->mu); + pending->result = std::move(cb_result); + pending->done = true; + } + + pending->cv.notify_one(); +} + +ElixirCallbackResult InvokeElixirCallback( + int64_t callback_id, const std::vector &inputs, + const std::vector &outputs) { + auto state = GetElixirCallbackBridgeState(); + + if (!state->dispatcher_set) { + ElixirCallbackResult res; + res.ok = false; + res.error = "EXLA elixir callback dispatcher is not set"; + return res; + } + + auto pending = fine::make_resource(outputs); + + ErlNifEnv *msg_env = enif_alloc_env(); + + // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send + // plain binaries because the BEAM callback needs to own the data lifetime. 
+ std::vector args_terms; + args_terms.reserve(inputs.size()); + + for (const auto &tensor : inputs) { + ERL_NIF_TERM bin_term; + unsigned char *bin_data = + enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term); + if (tensor.size_bytes > 0) { + memcpy(bin_data, tensor.data, tensor.size_bytes); + } + + // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via + // Fine's encoder defined in exla_nif_util.h. + ERL_NIF_TERM typespec_term = + fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims)); + + ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term); + + args_terms.push_back(arg_tuple); + } + + ERL_NIF_TERM args_list = + enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); + + ERL_NIF_TERM pending_term = fine::encode(msg_env, pending); + ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); + + ERL_NIF_TERM msg = + enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), + cb_term, args_list, pending_term); + + // Use the dispatcher pid registered via start_elixir_callback_bridge/1. + // Calling enif_whereis_pid from this non-scheduler thread is unsafe and + // was causing a segfault. 
+ ErlNifPid dispatcher_pid = state->dispatcher_pid; + enif_send(msg_env, &dispatcher_pid, msg_env, msg); + enif_free_env(msg_env); + + std::unique_lock lock(pending->mu); + pending->cv.wait(lock, [&pending] { return pending->done; }); + + return pending->result; +} + +} // namespace callback_bridge + +} // namespace exla + + diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h new file mode 100644 index 0000000000..a7753b3acb --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -0,0 +1,256 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../exla_nif_util.h" +#include "xla/ffi/api/ffi.h" +#include +#include + +namespace exla { + +struct ElixirCallbackArg { + xla::ffi::DataType dtype; + std::vector dims; + const uint8_t *data = nullptr; + size_t size_bytes = 0; +}; + +// Result of an Elixir callback. On success, data has already been copied into +// the pre-registered output buffers held by ElixirCallbackPending, so we only +// need to track success or an error message here. +struct ElixirCallbackResult { + bool ok = false; + std::string error; +}; + +// Host-side description of an output buffer that should receive the callback +// result for a given output index. +struct ElixirCallbackOutputBuffer { + uint8_t *data = nullptr; + size_t size = 0; +}; + +namespace callback_bridge { + +// Opaque handle type used only so Elixir can keep the bridge alive via a +// Fine resource. It carries no data; the real bridge state is stored +// internally in the bridge implementation. +struct ElixirCallbackBridgeHandle {}; + +// Per-callback pending state used to synchronize between the XLA host thread +// and the Elixir-side dispatcher. This is exposed as a Fine resource so we +// can pass it as an opaque handle in messages instead of using integer tags. 
+struct ElixirCallbackPending { + // Constructor used on the host callback path where we pre-register the + // destination buffers for each output. + explicit ElixirCallbackPending( + std::vector outputs) + : outputs(std::move(outputs)) {} + + std::mutex mu; + std::condition_variable cv; + bool done = false; + ElixirCallbackResult result; + std::vector outputs; +}; + +// Called from the Elixir side to deliver a reply for a given pending handle. +// We receive the reply as a status atom (e.g. :ok or :error) and a result +// term. For the :ok case the result is a list of binaries that we decode as +// ElixirCallbackTensor outputs via Fine's decoding machinery. +void DeliverElixirCallbackReply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result); + +// Synchronously calls the Elixir callback identified by `callback_id` with the +// given tensor arguments. This function: +// +// * Allocates a unique ElixirCallbackPending resource +// * Sends a message to the dispatcher via enif_send/3 +// * Blocks the calling native thread until the reply arrives via +// DeliverElixirCallbackReply/3 +// +// It returns an ElixirCallbackResult that either indicates success (data has +// been written into the registered output buffers) or an error message. +ElixirCallbackResult InvokeElixirCallback( + int64_t callback_id, const std::vector &inputs, + const std::vector &outputs); + +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid); + +fine::Ok<> elixir_callback_reply( + ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result); + +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid); + +fine::ResourcePtr +acquire_elixir_callback_bridge(ErlNifEnv *env); + +} // namespace callback_bridge + +} // namespace exla + +namespace fine { + +// Decode a binary term into a raw byte vector. We only care about the payload +// bytes; dtype and shape are validated on the Elixir side. 
+template <> struct Decoder> { + static std::vector decode(ErlNifEnv *env, const ERL_NIF_TERM &term) { + ErlNifBinary bin; + if (!enif_inspect_binary(env, term, &bin)) { + throw std::invalid_argument( + "decode failed, expected binary for callback output"); + } + + std::vector bytes; + bytes.assign(bin.data, bin.data + bin.size); + return bytes; + } +}; + +// Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. This is +// used by the Elixir callback bridge to surface type and shape information +// about callback arguments to the Elixir side. +template <> +struct Encoder>> { + static ERL_NIF_TERM + encode(ErlNifEnv *env, + const std::tuple> &spec) { + const xla::ffi::DataType &dtype = std::get<0>(spec); + const std::vector &dims = std::get<1>(spec); + + ERL_NIF_TERM keys[] = {fine::encode(env, exla::atoms::__struct__), + fine::encode(env, exla::atoms::type), + fine::encode(env, exla::atoms::shape)}; + + ERL_NIF_TERM values[] = {fine::encode(env, exla::atoms::ElixirEXLATypespec), + encode_type(env, dtype), + encode_shape(env, dtype, dims)}; + + ERL_NIF_TERM map; + if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { + throw std::runtime_error("encode: failed to make a map"); + } + + return map; + } + +private: + static ERL_NIF_TERM encode_type(ErlNifEnv *env, xla::ffi::DataType dtype) { + using DT = xla::ffi::DataType; + + // Tokens are encoded as the atom :token with empty shape. 
+ if (dtype == DT::TOKEN) { + return fine::encode(env, exla::atoms::token); + } + + std::optional type_name; + std::optional type_size; + + switch (dtype) { + case DT::PRED: + type_name = exla::atoms::pred; + type_size = 8; + break; + + case DT::U8: + type_name = exla::atoms::u; + type_size = 8; + break; + case DT::U16: + type_name = exla::atoms::u; + type_size = 16; + break; + case DT::U32: + type_name = exla::atoms::u; + type_size = 32; + break; + case DT::U64: + type_name = exla::atoms::u; + type_size = 64; + break; + + case DT::S8: + type_name = exla::atoms::s; + type_size = 8; + break; + case DT::S16: + type_name = exla::atoms::s; + type_size = 16; + break; + case DT::S32: + type_name = exla::atoms::s; + type_size = 32; + break; + case DT::S64: + type_name = exla::atoms::s; + type_size = 64; + break; + + case DT::F16: + type_name = exla::atoms::f; + type_size = 16; + break; + case DT::F32: + type_name = exla::atoms::f; + type_size = 32; + break; + case DT::F64: + type_name = exla::atoms::f; + type_size = 64; + break; + + case DT::BF16: + type_name = exla::atoms::bf; + type_size = 16; + break; + + case DT::C64: + type_name = exla::atoms::c; + type_size = 64; + break; + case DT::C128: + type_name = exla::atoms::c; + type_size = 128; + break; + + default: + break; + } + + if (type_name && type_size) { + return fine::encode( + env, std::make_tuple(type_name.value(), type_size.value())); + } + + throw std::invalid_argument("encode failed, unexpected ffi::DataType"); + } + + static ERL_NIF_TERM encode_shape(ErlNifEnv *env, xla::ffi::DataType dtype, + const std::vector &dims) { + if (dtype == xla::ffi::DataType::TOKEN) { + return enif_make_tuple(env, 0); + } + + std::vector dim_terms; + dim_terms.reserve(dims.size()); + + for (auto d : dims) { + dim_terms.push_back(fine::encode(env, d)); + } + + return enif_make_tuple_from_array(env, dim_terms.data(), dim_terms.size()); + } +}; + +} // namespace fine + + diff --git a/exla/c_src/exla/elixir_callback_bridge.h 
b/exla/c_src/exla/elixir_callback_bridge.h index f5310939e9..7402ebd803 100644 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ b/exla/c_src/exla/elixir_callback_bridge.h @@ -71,9 +71,9 @@ void DeliverElixirCallbackReply( // It returns an ElixirCallbackResult that either indicates success (data has // been written into the registered output buffers) or an error message. ElixirCallbackResult -CallElixirCallback(int64_t callback_id, - const std::vector &inputs, - const std::vector &outputs); +InvokeElixirCallback(int64_t callback_id, + const std::vector &inputs, + const std::vector &outputs); } // namespace exla diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 6bb9db52df..fcc391d29e 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -6,7 +6,7 @@ #include #include -#include "elixir_callback_bridge.h" +#include "custom_calls/elixir_callback_bridge.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" @@ -23,6 +23,9 @@ namespace exla { +using callback_bridge::ElixirCallbackBridgeHandle; +using callback_bridge::ElixirCallbackPending; + FINE_RESOURCE(llvm::StdThreadPool); FINE_RESOURCE(mlir::MLIRContext); FINE_RESOURCE(mlir::Value); @@ -32,12 +35,7 @@ FINE_RESOURCE(exla::ExlaBuffer); FINE_RESOURCE(exla::ExlaExecutable); FINE_RESOURCE(exla::MLIRModule); FINE_RESOURCE(exla::MLIRFunction); -FINE_RESOURCE(exla::ElixirCallbackPending); - -// Opaque handle type used only so Elixir can keep the bridge alive via a -// Fine resource. It carries no data; the real bridge state is the singleton -// ElixirCallbackBridgeState below. 
-struct ElixirCallbackBridgeHandle {}; +FINE_RESOURCE(ElixirCallbackPending); FINE_RESOURCE(ElixirCallbackBridgeHandle); // MLIR Functions @@ -532,201 +530,18 @@ get_per_device_memory(ErlNifEnv *env, fine::ResourcePtr client) { FINE_NIF(get_per_device_memory, 0); -// Elixir callback bridge +// Elixir callback bridge NIF registrations -namespace { - -struct ElixirCallbackBridgeState { - ErlNifPid dispatcher_pid; - bool dispatcher_set = false; -}; - -// We keep a single global bridge state, but expose it as a Fine resource so -// Elixir code can tie its lifetime to EXLA.CallbackServer. -static ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { - static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState(); - return state; -} - -} // namespace - -fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, - ErlNifPid dispatcher_pid) { - auto state = GetElixirCallbackBridgeState(); - state->dispatcher_pid = dispatcher_pid; - state->dispatcher_set = true; - return fine::Ok(); -} +using callback_bridge::acquire_elixir_callback_bridge; +using callback_bridge::clear_elixir_callback_bridge; +using callback_bridge::elixir_callback_reply; +using callback_bridge::start_elixir_callback_bridge; FINE_NIF(start_elixir_callback_bridge, 0); - -fine::Ok<> -elixir_callback_reply(ErlNifEnv *env, - fine::ResourcePtr pending, - fine::Atom status, fine::Term result) { - DeliverElixirCallbackReply(env, pending, status, result); - return fine::Ok(); -} - FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); - -fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, - ErlNifPid dispatcher_pid) { - auto state = GetElixirCallbackBridgeState(); - - if (state->dispatcher_set && - std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == - 0) { - state->dispatcher_set = false; - } - - return fine::Ok(); -} - FINE_NIF(clear_elixir_callback_bridge, 0); - -// Allocate and return a Fine resource handle associated with the bridge. 
-// This lets Elixir hold a reference (e.g., in EXLA.CallbackServer state) so -// the bridge lifetime is attached to that process. Per-callback pending -// resources are created independently for each call. -fine::ResourcePtr -acquire_elixir_callback_bridge(ErlNifEnv *env) { - (void)env; - return fine::make_resource(); -} - FINE_NIF(acquire_elixir_callback_bridge, 0); -void DeliverElixirCallbackReply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result_term) { - ElixirCallbackResult cb_result; - - if (status == "ok") { - // Successful reply: result_term is a list of binaries that we decode into - // raw byte vectors via Fine and copy directly into the registered output - // buffers. - try { - auto payloads = - fine::decode>>(env, result_term); - - std::lock_guard lock(pending->mu); - - if (payloads.size() != pending->outputs.size()) { - cb_result.ok = false; - cb_result.error = - "mismatched number of callback outputs vs registered buffers"; - } else { - cb_result.ok = true; - - for (size_t i = 0; i < payloads.size(); ++i) { - const auto &bytes = payloads[i]; - auto &out_buf = pending->outputs[i]; - - if (bytes.size() != out_buf.size) { - cb_result.ok = false; - cb_result.error = - "callback returned binary of unexpected size for result buffer"; - break; - } - - if (out_buf.size > 0) { - std::memcpy(out_buf.data, bytes.data(), out_buf.size); - } - } - } - } catch (const std::exception &e) { - cb_result.ok = false; - cb_result.error = - std::string("failed to decode Elixir callback outputs: ") + e.what(); - } - } else { - // Error reply: result_term is expected to be {kind_atom, message :: binary} - cb_result.ok = false; - - try { - auto decoded = - fine::decode>(env, result_term); - ErlNifBinary msg_bin = std::get<1>(decoded); - cb_result.error.assign(reinterpret_cast(msg_bin.data), - msg_bin.size); - } catch (const std::exception &) { - cb_result.error = "elixir callback returned error"; - } - } - - { - std::lock_guard 
lock(pending->mu); - pending->result = std::move(cb_result); - pending->done = true; - } - - pending->cv.notify_one(); -} - -ElixirCallbackResult -CallElixirCallback(int64_t callback_id, - const std::vector &inputs, - const std::vector &outputs) { - auto state = GetElixirCallbackBridgeState(); - - if (!state->dispatcher_set) { - ElixirCallbackResult res; - res.ok = false; - res.error = "EXLA elixir callback dispatcher is not set"; - return res; - } - - auto pending = fine::make_resource(outputs); - - ErlNifEnv *msg_env = enif_alloc_env(); - - // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send - // plain binaries because the BEAM callback needs to own the data lifetime. - std::vector args_terms; - args_terms.reserve(inputs.size()); - - for (const auto &tensor : inputs) { - ERL_NIF_TERM bin_term; - unsigned char *bin_data = - enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term); - if (tensor.size_bytes > 0) { - memcpy(bin_data, tensor.data, tensor.size_bytes); - } - - // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via - // Fine's encoder defined in exla_nif_util.h. - ERL_NIF_TERM typespec_term = - fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims)); - - ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term); - - args_terms.push_back(arg_tuple); - } - - ERL_NIF_TERM args_list = - enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); - - ERL_NIF_TERM pending_term = fine::encode(msg_env, pending); - ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); - - ERL_NIF_TERM msg = - enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), - cb_term, args_list, pending_term); - - // Use the dispatcher pid registered via start_elixir_callback_bridge/1. - // Calling enif_whereis_pid from this non-scheduler thread is unsafe and - // was causing a segfault. 
- ErlNifPid dispatcher_pid = state->dispatcher_pid; - enif_send(msg_env, &dispatcher_pid, msg_env, msg); - enif_free_env(msg_env); - - std::unique_lock lock(pending->mu); - pending->cv.wait(lock, [&pending] { return pending->done; }); - - return pending->result; -} - // Logging fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) { From 81ac875a6b872e06a256ed05f835fdca565d1399 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 26 Nov 2025 05:31:21 -0300 Subject: [PATCH 18/42] kill handle --- exla/Makefile | 2 +- .../exla/custom_calls/elixir_callback.cc | 10 +- .../custom_calls/elixir_callback_bridge.cc | 42 ++-- .../custom_calls/elixir_callback_bridge.h | 51 ++-- exla/c_src/exla/elixir_callback_bridge.h | 233 ------------------ exla/c_src/exla/exla.cc | 8 +- exla/lib/exla/callback_server.ex | 15 +- exla/lib/exla/nif.ex | 5 +- 8 files changed, 49 insertions(+), 317 deletions(-) delete mode 100644 exla/c_src/exla/elixir_callback_bridge.h diff --git a/exla/Makefile b/exla/Makefile index 43e79e3907..39371f9fcd 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -84,7 +84,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO) SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc) -HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h +HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/elixir_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc 
index cdd5b1b0cf..6f9c90fde1 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -36,7 +36,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, // Collect all remaining input tensors (excluding callback id) into // lightweight payload views. - std::vector inputs; + std::vector inputs; inputs.reserve(args.size() - 1); for (size_t i = 1; i < args.size(); ++i) { @@ -47,7 +47,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, ffi::AnyBuffer buf = *maybe_buf_or; - exla::ElixirCallbackArg tensor; + exla::callback_bridge::Arg tensor; tensor.dtype = buf.element_type(); auto dims = buf.dimensions(); @@ -61,7 +61,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, // Prepare output buffer descriptors so the callback bridge can write results // directly into the final destination buffers. - std::vector outputs; + std::vector outputs; outputs.reserve(rets.size()); for (size_t i = 0; i < rets.size(); ++i) { @@ -73,7 +73,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, ffi::Result ret = *maybe_ret_or; ffi::AnyBuffer out = *ret; - exla::ElixirCallbackOutputBuffer buf; + exla::callback_bridge::OutputBuffer buf; buf.data = static_cast(out.untyped_data()); buf.size = ffi::ByteWidth(out.element_type()) * static_cast(out.element_count()); @@ -83,7 +83,7 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, // Call back into Elixir through the bridge. On success, the bridge writes // results directly into the provided output buffers. 
- exla::ElixirCallbackResult result = + exla::callback_bridge::Result result = exla::callback_bridge::InvokeElixirCallback(callback_id, inputs, outputs); if (!result.ok) { diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index e4d716dde1..4653e28b2e 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -6,36 +6,36 @@ namespace exla { namespace callback_bridge { -struct ElixirCallbackBridgeState { +struct BridgeState { ErlNifPid dispatcher_pid; bool dispatcher_set = false; }; -ElixirCallbackBridgeState *GetElixirCallbackBridgeState() { - static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState(); +BridgeState *GetBridgeState() { + static BridgeState *state = new BridgeState(); return state; } fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { (void)env; - auto state = GetElixirCallbackBridgeState(); + auto state = GetBridgeState(); state->dispatcher_pid = dispatcher_pid; state->dispatcher_set = true; return fine::Ok(); } -fine::Ok<> elixir_callback_reply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result) { - DeliverElixirCallbackReply(env, pending, status, result); +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, + fine::ResourcePtr pending, + fine::Atom status, fine::Term result) { + deliver_reply(env, pending, status, result); return fine::Ok(); } fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid) { (void)env; - auto state = GetElixirCallbackBridgeState(); + auto state = GetBridgeState(); if (state->dispatcher_set && std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == @@ -46,16 +46,9 @@ fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, return fine::Ok(); } -fine::ResourcePtr -acquire_elixir_callback_bridge(ErlNifEnv *env) { - (void)env; - return fine::make_resource(); -} 
- -void DeliverElixirCallbackReply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result_term) { - ElixirCallbackResult cb_result; +void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result_term) { + Result cb_result; if (status == "ok") { // Successful reply: result_term is a list of binaries that we decode into @@ -119,19 +112,18 @@ void DeliverElixirCallbackReply( pending->cv.notify_one(); } -ElixirCallbackResult InvokeElixirCallback( - int64_t callback_id, const std::vector &inputs, - const std::vector &outputs) { - auto state = GetElixirCallbackBridgeState(); +Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, + const std::vector &outputs) { + auto state = GetBridgeState(); if (!state->dispatcher_set) { - ElixirCallbackResult res; + Result res; res.ok = false; res.error = "EXLA elixir callback dispatcher is not set"; return res; } - auto pending = fine::make_resource(outputs); + auto pending = fine::make_resource(outputs); ErlNifEnv *msg_env = enif_alloc_env(); diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h index a7753b3acb..ad9228bf43 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -13,7 +13,9 @@ namespace exla { -struct ElixirCallbackArg { +namespace callback_bridge { + +struct Arg { xla::ffi::DataType dtype; std::vector dims; const uint8_t *data = nullptr; @@ -21,79 +23,66 @@ struct ElixirCallbackArg { }; // Result of an Elixir callback. On success, data has already been copied into -// the pre-registered output buffers held by ElixirCallbackPending, so we only +// the pre-registered output buffers held by Pending, so we only // need to track success or an error message here. 
-struct ElixirCallbackResult { +struct Result { bool ok = false; std::string error; }; // Host-side description of an output buffer that should receive the callback // result for a given output index. -struct ElixirCallbackOutputBuffer { +struct OutputBuffer { uint8_t *data = nullptr; size_t size = 0; }; -namespace callback_bridge { - -// Opaque handle type used only so Elixir can keep the bridge alive via a -// Fine resource. It carries no data; the real bridge state is stored -// internally in the bridge implementation. -struct ElixirCallbackBridgeHandle {}; - // Per-callback pending state used to synchronize between the XLA host thread // and the Elixir-side dispatcher. This is exposed as a Fine resource so we // can pass it as an opaque handle in messages instead of using integer tags. -struct ElixirCallbackPending { +struct Pending { // Constructor used on the host callback path where we pre-register the // destination buffers for each output. - explicit ElixirCallbackPending( - std::vector outputs) + explicit Pending(std::vector outputs) : outputs(std::move(outputs)) {} std::mutex mu; std::condition_variable cv; bool done = false; - ElixirCallbackResult result; - std::vector outputs; + Result result; + std::vector outputs; }; // Called from the Elixir side to deliver a reply for a given pending handle. // We receive the reply as a status atom (e.g. :ok or :error) and a result // term. For the :ok case the result is a list of binaries that we decode as // ElixirCallbackTensor outputs via Fine's decoding machinery. -void DeliverElixirCallbackReply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result); +void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result); // Synchronously calls the Elixir callback identified by `callback_id` with the // given tensor arguments. 
This function: // -// * Allocates a unique ElixirCallbackPending resource +// * Allocates a unique Pending resource // * Sends a message to the dispatcher via enif_send/3 // * Blocks the calling native thread until the reply arrives via -// DeliverElixirCallbackReply/3 +// deliver_reply/3 // -// It returns an ElixirCallbackResult that either indicates success (data has +// It returns an Result that either indicates success (data has // been written into the registered output buffers) or an error message. -ElixirCallbackResult InvokeElixirCallback( - int64_t callback_id, const std::vector &inputs, - const std::vector &outputs); +Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, + const std::vector &outputs); fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid); -fine::Ok<> elixir_callback_reply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result); +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, + fine::ResourcePtr pending, + fine::Atom status, fine::Term result); fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid); -fine::ResourcePtr -acquire_elixir_callback_bridge(ErlNifEnv *env); - } // namespace callback_bridge } // namespace exla diff --git a/exla/c_src/exla/elixir_callback_bridge.h b/exla/c_src/exla/elixir_callback_bridge.h deleted file mode 100644 index 7402ebd803..0000000000 --- a/exla/c_src/exla/elixir_callback_bridge.h +++ /dev/null @@ -1,233 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include "exla_nif_util.h" -#include "xla/ffi/api/ffi.h" -#include -#include - -namespace exla { - -struct ElixirCallbackArg { - xla::ffi::DataType dtype; - std::vector dims; - const uint8_t *data = nullptr; - size_t size_bytes = 0; -}; - -// Result of an Elixir callback. 
On success, data has already been copied into -// the pre-registered output buffers held by ElixirCallbackPending, so we only -// need to track success or an error message here. -struct ElixirCallbackResult { - bool ok = false; - std::string error; -}; - -// Host-side description of an output buffer that should receive the callback -// result for a given output index. -struct ElixirCallbackOutputBuffer { - uint8_t *data = nullptr; - size_t size = 0; -}; - -// Per-callback pending state used to synchronize between the XLA host thread -// and the Elixir-side dispatcher. This is exposed as a Fine resource so we -// can pass it as an opaque handle in messages instead of using integer tags. -struct ElixirCallbackPending { - // Constructor used on the host callback path where we pre-register the - // destination buffers for each output. - explicit ElixirCallbackPending( - std::vector outputs) - : outputs(std::move(outputs)) {} - - std::mutex mu; - std::condition_variable cv; - bool done = false; - ElixirCallbackResult result; - std::vector outputs; -}; - -// Called from the Elixir side to deliver a reply for a given pending handle. -// We receive the reply as a status atom (e.g. :ok or :error) and a result -// term. For the :ok case the result is a list of binaries that we decode as -// ElixirCallbackTensor outputs via Fine's decoding machinery. -void DeliverElixirCallbackReply( - ErlNifEnv *env, fine::ResourcePtr pending, - fine::Atom status, fine::Term result); - -// Synchronously calls the Elixir callback identified by `callback_id` with the -// given tensor arguments. This function: -// -// * Allocates a unique ElixirCallbackPending resource -// * Sends a message to the dispatcher via enif_send/3 -// * Blocks the calling native thread until the reply arrives via -// DeliverElixirCallbackReply/3 -// -// It returns an ElixirCallbackResult that either indicates success (data has -// been written into the registered output buffers) or an error message. 
-ElixirCallbackResult -InvokeElixirCallback(int64_t callback_id, - const std::vector &inputs, - const std::vector &outputs); - -} // namespace exla - -namespace fine { - -// Decode a binary term into a raw byte vector. We only care about the payload -// bytes; dtype and shape are validated on the Elixir side. -template <> struct Decoder> { - static std::vector decode(ErlNifEnv *env, const ERL_NIF_TERM &term) { - ErlNifBinary bin; - if (!enif_inspect_binary(env, term, &bin)) { - throw std::invalid_argument( - "decode failed, expected binary for callback output"); - } - - std::vector bytes; - bytes.assign(bin.data, bin.data + bin.size); - return bytes; - } -}; - -// Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. This is -// used by the Elixir callback bridge to surface type and shape information -// about callback arguments to the Elixir side. -template <> -struct Encoder>> { - static ERL_NIF_TERM - encode(ErlNifEnv *env, - const std::tuple> &spec) { - const xla::ffi::DataType &dtype = std::get<0>(spec); - const std::vector &dims = std::get<1>(spec); - - ERL_NIF_TERM keys[] = {fine::encode(env, exla::atoms::__struct__), - fine::encode(env, exla::atoms::type), - fine::encode(env, exla::atoms::shape)}; - - ERL_NIF_TERM values[] = {fine::encode(env, exla::atoms::ElixirEXLATypespec), - encode_type(env, dtype), - encode_shape(env, dtype, dims)}; - - ERL_NIF_TERM map; - if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { - throw std::runtime_error("encode: failed to make a map"); - } - - return map; - } - -private: - static ERL_NIF_TERM encode_type(ErlNifEnv *env, xla::ffi::DataType dtype) { - using DT = xla::ffi::DataType; - - // Tokens are encoded as the atom :token with empty shape. 
- if (dtype == DT::TOKEN) { - return fine::encode(env, exla::atoms::token); - } - - std::optional type_name; - std::optional type_size; - - switch (dtype) { - case DT::PRED: - type_name = exla::atoms::pred; - type_size = 8; - break; - - case DT::U8: - type_name = exla::atoms::u; - type_size = 8; - break; - case DT::U16: - type_name = exla::atoms::u; - type_size = 16; - break; - case DT::U32: - type_name = exla::atoms::u; - type_size = 32; - break; - case DT::U64: - type_name = exla::atoms::u; - type_size = 64; - break; - - case DT::S8: - type_name = exla::atoms::s; - type_size = 8; - break; - case DT::S16: - type_name = exla::atoms::s; - type_size = 16; - break; - case DT::S32: - type_name = exla::atoms::s; - type_size = 32; - break; - case DT::S64: - type_name = exla::atoms::s; - type_size = 64; - break; - - case DT::F16: - type_name = exla::atoms::f; - type_size = 16; - break; - case DT::F32: - type_name = exla::atoms::f; - type_size = 32; - break; - case DT::F64: - type_name = exla::atoms::f; - type_size = 64; - break; - - case DT::BF16: - type_name = exla::atoms::bf; - type_size = 16; - break; - - case DT::C64: - type_name = exla::atoms::c; - type_size = 64; - break; - case DT::C128: - type_name = exla::atoms::c; - type_size = 128; - break; - - default: - break; - } - - if (type_name && type_size) { - return fine::encode( - env, std::make_tuple(type_name.value(), type_size.value())); - } - - throw std::invalid_argument("encode failed, unexpected ffi::DataType"); - } - - static ERL_NIF_TERM encode_shape(ErlNifEnv *env, xla::ffi::DataType dtype, - const std::vector &dims) { - if (dtype == xla::ffi::DataType::TOKEN) { - return enif_make_tuple(env, 0); - } - - std::vector dim_terms; - dim_terms.reserve(dims.size()); - - for (auto d : dims) { - dim_terms.push_back(fine::encode(env, d)); - } - - return enif_make_tuple_from_array(env, dim_terms.data(), dim_terms.size()); - } -}; - -} // namespace fine diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc 
index fcc391d29e..a3db6f5463 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -23,8 +23,7 @@ namespace exla { -using callback_bridge::ElixirCallbackBridgeHandle; -using callback_bridge::ElixirCallbackPending; +using callback_bridge::Pending; FINE_RESOURCE(llvm::StdThreadPool); FINE_RESOURCE(mlir::MLIRContext); @@ -35,8 +34,7 @@ FINE_RESOURCE(exla::ExlaBuffer); FINE_RESOURCE(exla::ExlaExecutable); FINE_RESOURCE(exla::MLIRModule); FINE_RESOURCE(exla::MLIRFunction); -FINE_RESOURCE(ElixirCallbackPending); -FINE_RESOURCE(ElixirCallbackBridgeHandle); +FINE_RESOURCE(Pending); // MLIR Functions @@ -532,7 +530,6 @@ FINE_NIF(get_per_device_memory, 0); // Elixir callback bridge NIF registrations -using callback_bridge::acquire_elixir_callback_bridge; using callback_bridge::clear_elixir_callback_bridge; using callback_bridge::elixir_callback_reply; using callback_bridge::start_elixir_callback_bridge; @@ -540,7 +537,6 @@ using callback_bridge::start_elixir_callback_bridge; FINE_NIF(start_elixir_callback_bridge, 0); FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); FINE_NIF(clear_elixir_callback_bridge, 0); -FINE_NIF(acquire_elixir_callback_bridge, 0); // Logging diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 5f740f1b82..0d2d84260c 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -34,21 +34,14 @@ defmodule EXLA.CallbackServer do @type callback_id :: non_neg_integer() defstruct next_id: 1, - callbacks: %{}, - # Opaque handle to the native elixir callback bridge so its - # lifetime is tied to this server process. - bridge_ref: nil + callbacks: %{} @type t :: %__MODULE__{ next_id: non_neg_integer(), # We store the original function, its output template, and any # static (non-tensor) arguments that should always be appended to # the decoded tensor arguments coming from native. - callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}}, - # Native bridge resource. 
We don't use it directly in Elixir, but - # holding a reference here ensures the native bridge stays alive - # as long as this server does. - bridge_ref: term() + callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}} } ## Public API @@ -81,11 +74,9 @@ defmodule EXLA.CallbackServer do @impl true def init(:ok) do # Inform native side that this process is the dispatcher for elixir callbacks - # and acquire a bridge resource so its lifetime is attached to this server. _ = EXLA.NIF.start_elixir_callback_bridge(self()) - bridge_ref = EXLA.NIF.acquire_elixir_callback_bridge() - {:ok, %__MODULE__{bridge_ref: bridge_ref}} + {:ok, %__MODULE__{}} end @impl true diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 98ff89b02e..b579c015e8 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -78,13 +78,10 @@ defmodule EXLA.NIF do def reset_peak_memory(_client), do: err!() def get_per_device_memory(_client), do: err!() - # Elixir callback bridge (Phase 1: CPU-only, simple APIs) + # Elixir callback bridge def start_elixir_callback_bridge(_dispatcher_pid), do: err!() def clear_elixir_callback_bridge(_dispatcher_pid), do: err!() def elixir_callback_reply(_reply_tag, _status, _result), do: err!() - # Bridge resource handle so EXLA.CallbackServer can keep the bridge alive. 
- def acquire_elixir_callback_bridge(), do: err!() - defp err!(), do: :erlang.nif_error(:undef) end From 4149353d0b91c5d284fc145bc099ab1a5fb4a82d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 27 Nov 2025 21:44:58 -0300 Subject: [PATCH 19/42] refactor: use erlnifbinary --- .../exla/custom_calls/elixir_callback_bridge.cc | 9 ++++----- .../exla/custom_calls/elixir_callback_bridge.h | 16 ---------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index 4653e28b2e..77e15ce169 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -55,8 +55,7 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, // raw byte vectors via Fine and copy directly into the registered output // buffers. try { - auto payloads = - fine::decode>>(env, result_term); + auto payloads = fine::decode>(env, result_term); std::lock_guard lock(pending->mu); @@ -68,10 +67,10 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, cb_result.ok = true; for (size_t i = 0; i < payloads.size(); ++i) { - const auto &bytes = payloads[i]; + const ErlNifBinary &bytes = payloads[i]; auto &out_buf = pending->outputs[i]; - if (bytes.size() != out_buf.size) { + if (bytes.size != out_buf.size) { cb_result.ok = false; cb_result.error = "callback returned binary of unexpected size for result buffer"; @@ -79,7 +78,7 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, } if (out_buf.size > 0) { - std::memcpy(out_buf.data, bytes.data(), out_buf.size); + std::memcpy(out_buf.data, bytes.data, out_buf.size); } } } diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h index ad9228bf43..fb954bfa32 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h +++ 
b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -89,22 +89,6 @@ fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, namespace fine { -// Decode a binary term into a raw byte vector. We only care about the payload -// bytes; dtype and shape are validated on the Elixir side. -template <> struct Decoder> { - static std::vector decode(ErlNifEnv *env, const ERL_NIF_TERM &term) { - ErlNifBinary bin; - if (!enif_inspect_binary(env, term, &bin)) { - throw std::invalid_argument( - "decode failed, expected binary for callback output"); - } - - std::vector bytes; - bytes.assign(bin.data, bin.data + bin.size); - return bytes; - } -}; - // Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. This is // used by the Elixir callback bridge to surface type and shape information // about callback arguments to the Elixir side. From 9a60a261ef0c4b618c1741524e7f04d92f8e08b4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 27 Nov 2025 21:56:46 -0300 Subject: [PATCH 20/42] refactor: leverage fine encoding --- .../custom_calls/elixir_callback_bridge.cc | 32 +++++++------------ exla/mix.exs | 2 +- exla/mix.lock | 2 +- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index 77e15ce169..518deb8a37 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -128,42 +128,32 @@ Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send // plain binaries because the BEAM callback needs to own the data lifetime. 
- std::vector args_terms; + std::vector>>> + args_terms; args_terms.reserve(inputs.size()); for (const auto &tensor : inputs) { - ERL_NIF_TERM bin_term; - unsigned char *bin_data = - enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term); - if (tensor.size_bytes > 0) { - memcpy(bin_data, tensor.data, tensor.size_bytes); - } + fine::Term bin_term = fine::make_new_binary( + msg_env, reinterpret_cast(tensor.data), + tensor.size_bytes); // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via // Fine's encoder defined in exla_nif_util.h. - ERL_NIF_TERM typespec_term = - fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims)); - - ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term); + auto arg_tuple = + std::make_tuple(bin_term, std::make_tuple(tensor.dtype, tensor.dims)); args_terms.push_back(arg_tuple); } - ERL_NIF_TERM args_list = - enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); - - ERL_NIF_TERM pending_term = fine::encode(msg_env, pending); - ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); - - ERL_NIF_TERM msg = - enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), - cb_term, args_list, pending_term); + auto msg = std::make_tuple(fine::Atom("exla_elixir_call"), callback_id, + args_terms, pending); // Use the dispatcher pid registered via start_elixir_callback_bridge/1. // Calling enif_whereis_pid from this non-scheduler thread is unsafe and // was causing a segfault. 
ErlNifPid dispatcher_pid = state->dispatcher_pid; - enif_send(msg_env, &dispatcher_pid, msg_env, msg); + enif_send(msg_env, &dispatcher_pid, msg_env, fine::encode(msg_env, msg)); enif_free_env(msg_env); std::unique_lock lock(pending->mu); diff --git a/exla/mix.exs b/exla/mix.exs index 5818232299..bfc7ed77ad 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -69,7 +69,7 @@ defmodule EXLA.MixProject do {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, {:xla, "~> 0.9.0", runtime: false}, - {:fine, "~> 0.1.0", runtime: false}, + {:fine, "~> 0.1", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, {:benchee, "~> 1.0", only: :dev}, {:ex_doc, "~> 0.29", only: :docs}, diff --git a/exla/mix.lock b/exla/mix.lock index 4508f2a57c..2b18c93da9 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -5,7 +5,7 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "fine": {:hex, :fine, "0.1.0", "9bb99a5ff9b968f12c3b458fa1277c39e9a620b23a9439103703a25917293871", [:mix], [], "hexpm", 
"1d6485bf811b95dc6ae3d197c0e6f994880b86167a827983bb29cbfc03a02684"}, + "fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, From da72cddce2e2bd0aea1abbc1d21ae1e83c0df342 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 27 Nov 2025 21:57:13 -0300 Subject: [PATCH 21/42] chore: remove identity calls --- exla/lib/exla/callback_server.ex | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 0d2d84260c..b78dea4f52 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -231,7 +231,7 @@ defmodule EXLA.CallbackServer do "expected the elixir_call function to match the given output template " <> "#{inspect(right)}, got: #{inspect(left)}" - {:error, {:argument_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, 
{:argument_error, msg}} end # Callback returned something that isn't a tensor/tuple matching the template. @@ -240,43 +240,43 @@ defmodule EXLA.CallbackServer do "expected the elixir_call function to return a value compatible with the output " <> "template #{inspect(right)}, got: #{inspect(left)}" - {:error, {:argument_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:argument_error, msg}} end # Argument decoding failures. defp encode_reply({:error, {:decode_failed, exception}}) do msg = Exception.message(exception) msg = "failed to decode Elixir callback arguments: #{msg}" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end defp encode_reply({:error, {:invalid_args_spec, other}}) do msg = "invalid args_spec for Elixir callback: #{inspect(other)}" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end # Unknown callback id from native. defp encode_reply({:error, :unknown_callback}) do msg = "unknown EXLA elixir_call callback id" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end # User-raised exceptions. defp encode_reply({:error, {:exception, exception, _stack}}) do msg = Exception.message(exception) msg = "Elixir callback raised: #{msg}" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end # Catches other error tuples (throws, exits, etc). 
defp encode_reply({:error, {kind, reason}}) do msg = "Elixir callback #{kind}: #{inspect(reason)}" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end defp encode_reply({:error, reason}) do msg = "Elixir callback error: #{inspect(reason)}" - {:error, {:runtime_error, :erlang.binary_to_term(:erlang.term_to_binary(msg))}} + {:error, {:runtime_error, msg}} end defp encode_outputs(%Nx.Tensor{} = tensor) do From 38b35f463655261689a1a9bf0049a0adc9120ceb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 27 Nov 2025 22:01:59 -0300 Subject: [PATCH 22/42] fix: use error atom --- exla/c_src/exla/custom_calls/elixir_callback_bridge.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index 518deb8a37..c028e43fb5 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -94,9 +94,13 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, try { auto decoded = fine::decode>(env, result_term); + fine::Atom kind = std::get<0>(decoded); ErlNifBinary msg_bin = std::get<1>(decoded); - cb_result.error.assign(reinterpret_cast(msg_bin.data), - msg_bin.size); + + cb_result.error = + "elixir callback returned " + kind.to_string() + ": " + + std::string(reinterpret_cast(msg_bin.data), + msg_bin.size); } catch (const std::exception &) { cb_result.error = "elixir callback returned error"; } From e6ecfcac4c8369118ae395a6ea6b78fcadb58231 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 27 Nov 2025 23:49:13 -0300 Subject: [PATCH 23/42] refactor: do not use dynamic arities --- exla/lib/exla/callback_server.ex | 62 ++++++++++++------- exla/lib/exla/defn.ex | 19 ++++-- exla/test/exla/defn/elixir_call_test.exs | 
23 ++++++- nx/lib/nx.ex | 52 ++++++++-------- nx/lib/nx/backend.ex | 7 ++- nx/lib/nx/defn/evaluator.ex | 18 ++---- nx/lib/nx/defn/expr.ex | 32 +++++----- nx/lib/nx/defn/tree.ex | 15 ++--- .../nx/defn/elixir_call_evaluator_test.exs | 19 +----- torchx/test/torchx/defn/elixir_call_test.exs | 4 +- 10 files changed, 134 insertions(+), 117 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index b78dea4f52..c56478bc5a 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -57,16 +57,16 @@ defmodule EXLA.CallbackServer do end @doc """ - Registers a callback function, its output template, and static arguments, returning a callback id. + Registers a callback function, its output template, argument template, and options, + returning a callback id. - The same `{fun, out_template, static_args}` triple will always return the - same id for the lifetime of this VM. This id is what the EXLA compiler - encodes into the host `CustomCall` so the native side can reference the - right callback. + The same `{fun, out_template, arg_template, opts}` quadruple will always return the + same id for the lifetime of this VM. This id is what the EXLA compiler encodes into + the host `CustomCall` so the native side can reference the right callback. 
""" - @spec register(fun(), Nx.t() | tuple(), [term()]) :: callback_id() - def register(fun, out_template, static_args) when is_function(fun) and is_list(static_args) do - GenServer.call(__MODULE__, {:register, fun, out_template, static_args}) + @spec register(fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() + def register(fun, out_template, arg_template, opts) when is_function(fun) and is_list(opts) do + GenServer.call(__MODULE__, {:register, fun, out_template, arg_template, opts}) end ## GenServer callbacks @@ -89,8 +89,8 @@ defmodule EXLA.CallbackServer do end @impl true - def handle_call({:register, fun, out_template, static_args}, _from, %__MODULE__{} = state) do - key = {fun, Nx.to_template(out_template), static_args} + def handle_call({:register, fun, out_template, arg_template, opts}, _from, %__MODULE__{} = state) do + key = {fun, out_template, arg_template, opts} case find_existing_id(state.callbacks, key) do {:ok, id} -> @@ -98,7 +98,7 @@ defmodule EXLA.CallbackServer do :error -> id = state.next_id - state = put_in(state.callbacks[id], {fun, Nx.to_template(out_template), static_args}) + state = put_in(state.callbacks[id], {fun, out_template, arg_template, opts}) state = %{state | next_id: id + 1} {:reply, id, state} end @@ -107,11 +107,11 @@ defmodule EXLA.CallbackServer do @impl true def handle_info({:exla_elixir_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do case Map.fetch(state.callbacks, callback_id) do - {:ok, {fun, out_template, static_args}} -> + {:ok, {fun, out_template, arg_template, opts}} -> reply_payload = args_spec - |> decode_args() - |> run_callback(fun, static_args, out_template) + |> decode_args(arg_template) + |> run_callback(fun, opts, out_template) |> encode_reply() send_reply(reply_tag, reply_payload) @@ -141,12 +141,12 @@ defmodule EXLA.CallbackServer do end) end - defp run_callback({:error, reason}, _fun, _static_args, _out_template), do: {:error, reason} + defp run_callback({:error, reason}, _fun, 
_opts, _out_template), do: {:error, reason} - defp run_callback({:ok, tensor_args}, fun, static_args, out_template) do + defp run_callback({:ok, tensor_args}, fun, opts, out_template) do result = try do - apply(fun, tensor_args ++ static_args) + fun.(tensor_args, opts) rescue exception -> {:error, {:exception, exception, __STACKTRACE__}} @@ -195,7 +195,7 @@ defmodule EXLA.CallbackServer do defp ensure_compatible(left, right), do: {:error, {:invalid_result, left, right}} - defp decode_args(args_spec) when is_list(args_spec) do + defp decode_args(args_spec, arg_template) when is_list(args_spec) do result = Enum.reduce_while(args_spec, {:ok, []}, fn {bin, %EXLA.Typespec{type: type, shape: shape}}, {:ok, acc} -> @@ -216,12 +216,16 @@ defmodule EXLA.CallbackServer do end) case result do - {:ok, tensors} -> {:ok, Enum.reverse(tensors)} - {:error, _} = error -> error + {:ok, tensors} -> + tensors = Enum.reverse(tensors) + materialize_args(arg_template, tensors) + + {:error, _} = error -> + error end end - defp decode_args(other), do: {:error, {:invalid_args_spec, other}} + defp decode_args(other, _arg_template), do: {:error, {:invalid_args_spec, other}} defp encode_reply({:ok, value}), do: {:ok, encode_outputs(value)} @@ -279,6 +283,22 @@ defmodule EXLA.CallbackServer do {:error, {:runtime_error, msg}} end + defp materialize_args(arg_template, tensors) do + {container, remaining} = + Nx.Defn.Composite.traverse(arg_template, tensors, fn + %Nx.Tensor{} = _template, [next | rest] -> + {next, rest} + + other, acc -> + {other, acc} + end) + + case remaining do + [] -> {:ok, container} + _ -> {:error, {:invalid_args_spec, :extra_values}} + end + end + defp encode_outputs(%Nx.Tensor{} = tensor) do [Nx.to_binary(tensor)] end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 5ec08e9b0c..b789ed47b4 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -548,18 +548,25 @@ defmodule EXLA.Defn do defp cached_recur_operator( :elixir_call, - %T{data: 
%Expr{args: [in_args, fun, out_template]}} = expr, + %T{data: %Expr{args: [tensor_expr, opts, fun, out_template]}} = expr, %{client: %EXLA.Client{platform: :host}} = state, cache ) do - {tensor_args, static_args} = Enum.split_while(in_args, &(not is_list(&1))) + # Flatten the tensor_or_container expression into its tensor leaves so we + # can compile each as an independent operand to the host callback. + tensor_exprs = Composite.flatten_list([tensor_expr]) - {call_args, cache} = - Enum.map_reduce(tensor_args, cache, fn arg, cache -> + {arg_values, cache} = + Enum.map_reduce(tensor_exprs, cache, fn arg, cache -> recur_operator(arg, state, cache) |> unwrap_single_tensor!() end) - callback_id = EXLA.CallbackServer.register(fun, out_template, static_args) + # Build a template container for the tensor_or_container argument so the + # callback server can reconstruct the full structure from a flat list of + # decoded tensors. + arg_template = Nx.to_template(tensor_expr) + + callback_id = EXLA.CallbackServer.register(fun, out_template, arg_template, opts) typespecs = container_to_typespecs(out_template) # Pass callback id as an extra scalar s64 operand at the end so that the @@ -569,7 +576,7 @@ defmodule EXLA.Defn do callback_id_value = Value.constant(state.builder, [callback_id], callback_id_typespec) - operands = [callback_id_value | call_args] + operands = [callback_id_value | arg_values] results = Value.elixir_call(operands, typespecs) diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index b1bd4d4ab6..cd32de3d94 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -17,7 +17,7 @@ defmodule EXLA.Defn.ElixirCallTest do defn add_offset(x) do out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x, [offset: 10.0]], &add_offset_callback/2) + Nx.elixir_call(out, x, [offset: 10.0], &add_offset_callback/2) end test "elixir_call with single output" do @@ -36,7 +36,7 
@@ defmodule EXLA.Defn.ElixirCallTest do out_template = {out0, out1} {a, b} = - Nx.elixir_call(out_template, [fx], fn t -> + Nx.elixir_call(out_template, fx, fn t -> {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} end) @@ -55,7 +55,7 @@ defmodule EXLA.Defn.ElixirCallTest do defn bad_callback(x) do out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x], fn _t -> + Nx.elixir_call(out, x, fn _t -> # Wrong shape on purpose Nx.tensor([1.0, 2.0, 3.0]) end) @@ -79,4 +79,21 @@ defmodule EXLA.Defn.ElixirCallTest do expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) assert_equal(y, expected) end + + def add_and_subtract_callback({x, y}) do + {Nx.add(x, y), Nx.subtract(x, y)} + end + + defn add_and_subtract(x, y) do + Nx.elixir_call({x, x}, {x, y}, &add_and_subtract_callback/1) + end + + test "elixir_call with tuple input" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + assert {add, sub} = add_and_subtract(x, y) + + assert_equal(add, Nx.add(x, y)) + assert_equal(sub, Nx.subtract(x, y)) + end end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 1b934a44b8..c48a2cec95 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2197,48 +2197,48 @@ defmodule Nx do end @doc """ - Invokes an Elixir function from within defn. + Invokes an Elixir function from within `defn`. This function allows integrating arbitrary Elixir code into `defn` graphs. It receives an output template (a tensor or a tuple of tensors) that - specifies the expected shapes, types, and names of the result, a list of - arguments to pass to the Elixir function, and the function itself. + specifies the expected shapes, types, and names of the result, a tensor + or tensor container argument, optional Elixir options, and the function + itself. Inside `defn`, this builds an expression node understood by compilers. Outside `defn` or on backends without special support, it executes `fun` directly and validates the result matches the template. 
- - ## Argument ordering - - When called inside `defn`, all tensor arguments must be placed **before** - any list arguments. Lists (including keyword lists) are treated as static - Elixir data that is appended to the callback at runtime, while the leading - non-list arguments are compiled as tensors and shipped to the target - backend. Passing a tensor after a list argument raises an error. """ @doc type: :backend - def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do + def elixir_call(output, tensor_or_container, fun) when is_function(fun, 1) do + elixir_call(output, tensor_or_container, [], fn value, _opts -> fun.(value) end) + end + + def elixir_call(output, tensor_or_container, opts, fun) + when is_list(opts) and is_function(fun, 2) do {:arity, arity} = Function.info(fun, :arity) - num_args = length(args) - if arity != num_args do + if arity != 2 do raise ArgumentError, - "expected #{arity} arguments, got #{num_args}" + "expected elixir_call callback to have arity 2, got #{arity}" end - backend = Nx.Shared.list_impl!(args) + # Outside defn, we execute the callback directly or via the backend if it + # provides a specialized implementation. We resolve the backend from all + # tensors inside the container to support tuple/map containers. 
+ tensors = Nx.Defn.Composite.flatten_list([tensor_or_container]) + backend = Nx.Shared.list_impl!(tensors) - cond do - function_exported?(backend, :elixir_call, 3) -> - output - |> backend.elixir_call(args, fun) - |> ensure_call_compatible!(output) + result = + cond do + function_exported?(backend, :elixir_call, 4) -> + backend.elixir_call(output, tensor_or_container, opts, fun) - true -> - fun - |> apply(args) - |> ensure_call_compatible!(output) - end + true -> + fun.(tensor_or_container, opts) + end + + ensure_call_compatible!(result, output) end defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index f8556ce308..b588b785c5 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -143,13 +143,14 @@ defmodule Nx.Backend do @callback optional(atom, [term], fun) :: tensor @doc """ - Invoked to execute a generic Elixir callback from within defn. + Invoked to execute a generic Elixir callback from within `defn`. The backend may choose how to execute it. For example, EXLA can lower to a custom_call that interacts with Erlang/Elixir via C; pure CPU backends may call the function directly. 
""" - @callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor + @callback elixir_call(out :: tensor | tuple, tensor_or_container :: term, keyword, fun) :: + tensor @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @@ -171,7 +172,7 @@ defmodule Nx.Backend do @optional_callbacks [ optional: 3, - elixir_call: 3, + elixir_call: 4, solve: 3, determinant: 2, logical_not: 2, diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index b4459c26c0..ce7d6cad29 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,13 +175,8 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end - defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do - [in_args, _fun, _out_template] = args - - Enum.reduce(in_args, cache, fn - t, cache when is_list(t) -> cache - t, cache -> compute_cache(t, state, cache) - end) + defp compute_cache(:elixir_call, %{data: %Expr{args: [tensor_expr, _opts, _fun, _out]}}, state, cache) do + compute_cache(tensor_expr, state, cache) end defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do @@ -442,18 +437,17 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [in_args, fun, _out_template]}} = expr, + %{data: %Expr{args: [tensor_expr, opts, fun, _out_template]}} = expr, state, caches ) do - {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1))) - {evaluated_tensors, caches} = Enum.map_reduce(tensor_args, caches, &eval(&1, state, &2)) - backend = Nx.Shared.list_impl!(evaluated_tensors) + {tensor_value, caches} = eval(tensor_expr, state, caches) + backend = Nx.Shared.list_impl!([tensor_value]) if backend == Nx.Defn.Expr do {expr, caches} else - {apply(fun, evaluated_tensors ++ opts), caches} + {fun.(tensor_value, opts), caches} end end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 
f455857f87..10a70f4884 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -41,7 +41,7 @@ defmodule Nx.Defn.Expr do * `attach_token(token(%Nx.Defn.Token{}), expr)` - * `elixir_call(name, args, fun)` + * `elixir_call(out, tensor_or_container, opts, fun)` `defn` compilers must handle said nodes accordingly. """ @@ -387,31 +387,31 @@ defmodule Nx.Defn.Expr do end @impl true - def elixir_call(out, in_args, fun) do - ensure_tensor_args_prefix!(in_args) + def elixir_call(out, tensor_or_container, opts, fun) when is_function(fun, 2) do + # Convert the entire tensor_or_container into an expression container, + # preserving its structure but ensuring all tensors are Expr-backed. + tensor_expr = + Composite.traverse(tensor_or_container, fn + %T{} = t -> to_expr(t) + other -> other + end) - {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) - [%T{data: %Expr{context: context}} | _] = Enum.map(tensor_args, &to_expr/1) + # Grab context from the first tensor in the flattened container. 
+ [%T{data: %Expr{context: context}} | _] = + Composite.flatten_list([tensor_expr]) case out do t when is_struct(t, Nx.Tensor) -> out_template = Nx.to_template(t) - expr(t, context, :elixir_call, [in_args, fun, out_template]) + expr(t, context, :elixir_call, [tensor_expr, opts, fun, out_template]) tuple when is_tuple(tuple) -> out_template = tuple_out(tuple_size(tuple)) user_template = Nx.to_template(tuple) - expr_node = expr(out_template, context, :elixir_call, [in_args, fun, user_template]) - tuple(expr_node, Tuple.to_list(tuple)) - end - end + expr_node = + expr(out_template, context, :elixir_call, [tensor_expr, opts, fun, user_template]) - defp ensure_tensor_args_prefix!(args) do - {_tensor_prefix, static_suffix} = Enum.split_while(args, &is_struct(&1, Nx.Tensor)) - - if Enum.any?(static_suffix, &is_struct(&1, Nx.Tensor)) do - raise ArgumentError, - "Nx.elixir_call/3 expects all tensor arguments to appear before any static arguments, but got: #{inspect(args)}" + tuple(expr_node, Tuple.to_list(tuple)) end end diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 02ee4d001c..ba4e6e1812 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -193,18 +193,11 @@ defmodule Nx.Defn.Tree do end def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do - [in_args, callback, out_template] = args - - {in_args, acc} = - Enum.map_reduce(in_args, acc, fn t, acc -> - if is_list(t) do - {t, acc} - else - Composite.traverse(t, acc, fun) - end - end) + [tensor_expr, callback_opts, callback, out_template] = args + + {tensor_expr, acc} = Composite.traverse(tensor_expr, acc, fun) - {[in_args, callback, out_template], acc} + {[tensor_expr, callback_opts, callback, out_template], acc} end def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs index a73b6e3b6f..964d8cd260 100644 --- 
a/nx/test/nx/defn/elixir_call_evaluator_test.exs +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -10,7 +10,7 @@ defmodule Nx.Defn.ElixirCallEvaluatorTest do defn add_offset(x) do out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.elixir_call(out, x, [offset: 10.0], fn t, opts -> Nx.add(Nx.as_type(t, :f32), opts[:offset]) end) end @@ -31,7 +31,7 @@ defmodule Nx.Defn.ElixirCallEvaluatorTest do out_template = {out0, out1} {a, b} = - Nx.elixir_call(out_template, [fx], fn t -> + Nx.elixir_call(out_template, fx, fn t -> {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} end) @@ -46,19 +46,4 @@ defmodule Nx.Defn.ElixirCallEvaluatorTest do expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) assert expected == y end - - defn invalid_order(x) do - out = %{x | type: Nx.Type.to_floating(x.type)} - - Nx.elixir_call(out, [[offset: 10.0], x], fn opts, t -> - Nx.add(Nx.as_type(t, :f32), opts[:offset]) - end) - end - - test "elixir_call enforces tensor arguments before lists" do - message = ~r|Nx.elixir_call/3 expects all tensor arguments to appear before any static arguments, but got| - assert_raise ArgumentError, message, fn -> - invalid_order(Nx.iota({2})) - end - end end diff --git a/torchx/test/torchx/defn/elixir_call_test.exs b/torchx/test/torchx/defn/elixir_call_test.exs index 9c504fa6c8..8e1c2b590d 100644 --- a/torchx/test/torchx/defn/elixir_call_test.exs +++ b/torchx/test/torchx/defn/elixir_call_test.exs @@ -12,7 +12,7 @@ defmodule Torchx.Defn.ElixirCallEvaluatorTest do defn add_offset(x) do out = %{x | type: Nx.Type.to_floating(x.type)} - Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.elixir_call(out, x, [offset: 10.0], fn t, opts -> Nx.add(Nx.as_type(t, :f32), opts[:offset]) end) end @@ -33,7 +33,7 @@ defmodule Torchx.Defn.ElixirCallEvaluatorTest do out_template = {out0, out1} {a, b} = - Nx.elixir_call(out_template, [fx], fn t -> + Nx.elixir_call(out_template, fx, fn t -> {Nx.multiply(t, 2.0), 
Nx.add(t, 1.0)} end) From e2543896625b82d50ad1dea6e48f6d029b666aef Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 07:31:09 -0300 Subject: [PATCH 24/42] chore: changes due to code review --- .../custom_calls/elixir_callback_bridge.cc | 6 ++- .../custom_calls/elixir_callback_bridge.h | 46 ++----------------- exla/lib/exla/callback_server.ex | 4 +- 3 files changed, 9 insertions(+), 47 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index c028e43fb5..2031f084c6 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -154,8 +154,10 @@ Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, args_terms, pending); // Use the dispatcher pid registered via start_elixir_callback_bridge/1. - // Calling enif_whereis_pid from this non-scheduler thread is unsafe and - // was causing a segfault. + // We still are within the NIF thread that started the computation, + // but we don't know its env, therefore we cannot use enif_whereis_pid. + // enif_whereis_pid can be called with NULL, but only from non-ERTS + // threads, and doing so here results in a segfault. ErlNifPid dispatcher_pid = state->dispatcher_pid; enif_send(msg_env, &dispatcher_pid, msg_env, fine::encode(msg_env, msg)); enif_free_env(msg_env); diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h index fb954bfa32..d2e01c6d60 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -92,32 +92,8 @@ namespace fine { // Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. This is // used by the Elixir callback bridge to surface type and shape information // about callback arguments to the Elixir side. 
-template <> -struct Encoder>> { - static ERL_NIF_TERM - encode(ErlNifEnv *env, - const std::tuple> &spec) { - const xla::ffi::DataType &dtype = std::get<0>(spec); - const std::vector &dims = std::get<1>(spec); - - ERL_NIF_TERM keys[] = {fine::encode(env, exla::atoms::__struct__), - fine::encode(env, exla::atoms::type), - fine::encode(env, exla::atoms::shape)}; - - ERL_NIF_TERM values[] = {fine::encode(env, exla::atoms::ElixirEXLATypespec), - encode_type(env, dtype), - encode_shape(env, dtype, dims)}; - - ERL_NIF_TERM map; - if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { - throw std::runtime_error("encode: failed to make a map"); - } - - return map; - } - -private: - static ERL_NIF_TERM encode_type(ErlNifEnv *env, xla::ffi::DataType dtype) { +template <> struct Encoder { + static ERL_NIF_TERM encode(ErlNifEnv *env, const xla::ffi::DataType &dtype) { using DT = xla::ffi::DataType; // Tokens are encoded as the atom :token with empty shape. @@ -204,23 +180,7 @@ struct Encoder>> { env, std::make_tuple(type_name.value(), type_size.value())); } - throw std::invalid_argument("encode failed, unexpected ffi::DataType"); - } - - static ERL_NIF_TERM encode_shape(ErlNifEnv *env, xla::ffi::DataType dtype, - const std::vector &dims) { - if (dtype == xla::ffi::DataType::TOKEN) { - return enif_make_tuple(env, 0); - } - - std::vector dim_terms; - dim_terms.reserve(dims.size()); - - for (auto d : dims) { - dim_terms.push_back(fine::encode(env, d)); - } - - return enif_make_tuple_from_array(env, dim_terms.data(), dim_terms.size()); + throw std::invalid_argument("encode failed, unexpected xla::ffi::DataType"); } }; diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index c56478bc5a..895428b575 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -198,12 +198,12 @@ defmodule EXLA.CallbackServer do defp decode_args(args_spec, arg_template) when is_list(args_spec) do result = Enum.reduce_while(args_spec, 
{:ok, []}, fn - {bin, %EXLA.Typespec{type: type, shape: shape}}, {:ok, acc} -> + {bin, {type, shape_list}}, {:ok, acc} -> try do tensor = bin |> Nx.from_binary(type) - |> Nx.reshape(shape) + |> Nx.reshape(List.to_tuple(shape_list)) {:cont, {:ok, [tensor | acc]}} rescue From 4853248558d9a18b5d8b0e8fe05d6394fff71851 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 07:34:01 -0300 Subject: [PATCH 25/42] fix: use Nx.compatible --- exla/lib/exla/callback_server.ex | 35 ++++---------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 895428b575..4efe0ec2e9 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -160,41 +160,14 @@ defmodule EXLA.CallbackServer do error value -> - case ensure_compatible(value, out_template) do - {:ok, compatible} -> {:ok, compatible} - {:error, reason} -> {:error, reason} + if Nx.compatible?(value, out_template) do + {:ok, value} + else + {:error, {:shape_mismatch, value, out_template}} end end end - defp ensure_compatible(left, right) when is_tuple(left) and is_tuple(right) do - if tuple_size(left) == tuple_size(right) do - [Tuple.to_list(left), Tuple.to_list(right)] - |> Enum.zip_with(fn [l, r] -> - case ensure_compatible(l, r) do - {:ok, _} -> :ok - {:error, reason} -> throw({:error, reason}) - end - end) - - {:ok, left} - else - {:error, {:mismatched_tuple_size, left, right}} - end - catch - {:error, reason} -> {:error, reason} - end - - defp ensure_compatible(%Nx.Tensor{} = left, %Nx.Tensor{} = right) do - if left.shape == right.shape and left.type == right.type and left.names == right.names do - {:ok, left} - else - {:error, {:shape_mismatch, left, right}} - end - end - - defp ensure_compatible(left, right), do: {:error, {:invalid_result, left, right}} - defp decode_args(args_spec, arg_template) when is_list(args_spec) do result = 
Enum.reduce_while(args_spec, {:ok, []}, fn From 6961bab71c533fb8d9f139882ed3416a15473b17 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 07:45:37 -0300 Subject: [PATCH 26/42] refactor: allow any type in static argument --- exla/lib/exla/callback_server.ex | 6 +-- exla/test/exla/defn/elixir_call_test.exs | 22 +++++++++ nx/lib/nx.ex | 19 ++++---- nx/lib/nx/backend.ex | 12 ----- nx/lib/nx/defn/evaluator.ex | 4 +- nx/lib/nx/defn/expr.ex | 60 ++++++++++++------------ 6 files changed, 68 insertions(+), 55 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 4efe0ec2e9..d10b5c27e0 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -60,13 +60,13 @@ defmodule EXLA.CallbackServer do Registers a callback function, its output template, argument template, and options, returning a callback id. - The same `{fun, out_template, arg_template, opts}` quadruple will always return the + The same `{fun, out_template, arg_template, static_arguments}` quadruple will always return the same id for the lifetime of this VM. This id is what the EXLA compiler encodes into the host `CustomCall` so the native side can reference the right callback. 
""" @spec register(fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() - def register(fun, out_template, arg_template, opts) when is_function(fun) and is_list(opts) do - GenServer.call(__MODULE__, {:register, fun, out_template, arg_template, opts}) + def register(fun, out_template, arg_template, static_arguments) when is_function(fun) do + GenServer.call(__MODULE__, {:register, fun, out_template, arg_template, static_arguments}) end ## GenServer callbacks diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index cd32de3d94..975dc914c0 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -96,4 +96,26 @@ defmodule EXLA.Defn.ElixirCallTest do assert_equal(add, Nx.add(x, y)) assert_equal(sub, Nx.subtract(x, y)) end + + def add_and_subtract_with_opts_callback({x, y}, {ref, pid}) do + send(pid, {:add_and_subtract_with_opts, ref}) + {Nx.add(x, y), Nx.subtract(x, y)} + end + + defn add_and_subtract_with_opts(x, y, opts) do + Nx.elixir_call({x, x}, {x, y}, {opts[:ref], opts[:pid]}, &add_and_subtract_with_opts_callback/2) + end + + test "elixir_call with non-list second argument" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + ref = make_ref() + + assert {add, sub} = add_and_subtract_with_opts(x, y, ref: ref, pid: self()) + + assert_equal(add, Nx.add(x, y)) + assert_equal(sub, Nx.subtract(x, y)) + + assert_receive {:add_and_subtract_with_opts, ^ref} + end end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index c48a2cec95..c7731189d3 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2202,9 +2202,12 @@ defmodule Nx do This function allows integrating arbitrary Elixir code into `defn` graphs. 
It receives an output template (a tensor or a tuple of tensors) that specifies the expected shapes, types, and names of the result, a tensor - or tensor container argument, optional Elixir options, and the function + or tensor container argument, and an optional static argument, and the function itself. + The `static_argument` will be passed through the Elixir processes to the callback function + along with the executable Nx code. + Inside `defn`, this builds an expression node understood by compilers. Outside `defn` or on backends without special support, it executes `fun` directly and validates the result matches the template. @@ -2214,8 +2217,8 @@ defmodule Nx do elixir_call(output, tensor_or_container, [], fn value, _opts -> fun.(value) end) end - def elixir_call(output, tensor_or_container, opts, fun) - when is_list(opts) and is_function(fun, 2) do + def elixir_call(output, tensor_or_container, static_argument, fun) + when is_function(fun, 2) do {:arity, arity} = Function.info(fun, :arity) if arity != 2 do @@ -2230,12 +2233,10 @@ defmodule Nx do backend = Nx.Shared.list_impl!(tensors) result = - cond do - function_exported?(backend, :elixir_call, 4) -> - backend.elixir_call(output, tensor_or_container, opts, fun) - - true -> - fun.(tensor_or_container, opts) + if backend == Nx.Defn.Expr do + backend.elixir_call(output, tensor_or_container, static_argument, fun) + else + fun.(tensor_or_container, static_argument) end ensure_call_compatible!(result, output) diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index b588b785c5..650b5ab0d8 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -141,17 +141,6 @@ defmodule Nx.Backend do fallback to the default implementation. """ @callback optional(atom, [term], fun) :: tensor - - @doc """ - Invoked to execute a generic Elixir callback from within `defn`. - - The backend may choose how to execute it. 
For example, EXLA can lower - to a custom_call that interacts with Erlang/Elixir via C; pure CPU - backends may call the function directly. - """ - @callback elixir_call(out :: tensor | tuple, tensor_or_container :: term, keyword, fun) :: - tensor - @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor @@ -172,7 +161,6 @@ defmodule Nx.Backend do @optional_callbacks [ optional: 3, - elixir_call: 4, solve: 3, determinant: 2, logical_not: 2, diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index ce7d6cad29..c4a3760d47 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -437,7 +437,7 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [tensor_expr, opts, fun, _out_template]}} = expr, + %{data: %Expr{args: [tensor_expr, static_argument, fun, _out_template]}} = expr, state, caches ) do @@ -447,7 +447,7 @@ defmodule Nx.Defn.Evaluator do if backend == Nx.Defn.Expr do {expr, caches} else - {fun.(tensor_value, opts), caches} + {fun.(tensor_value, static_argument), caches} end end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 10a70f4884..e63c0fc0d7 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -386,35 +386,6 @@ defmodule Nx.Defn.Expr do end end - @impl true - def elixir_call(out, tensor_or_container, opts, fun) when is_function(fun, 2) do - # Convert the entire tensor_or_container into an expression container, - # preserving its structure but ensuring all tensors are Expr-backed. - tensor_expr = - Composite.traverse(tensor_or_container, fn - %T{} = t -> to_expr(t) - other -> other - end) - - # Grab context from the first tensor in the flattened container. 
- [%T{data: %Expr{context: context}} | _] = - Composite.flatten_list([tensor_expr]) - - case out do - t when is_struct(t, Nx.Tensor) -> - out_template = Nx.to_template(t) - expr(t, context, :elixir_call, [tensor_expr, opts, fun, out_template]) - - tuple when is_tuple(tuple) -> - out_template = tuple_out(tuple_size(tuple)) - user_template = Nx.to_template(tuple) - expr_node = - expr(out_template, context, :elixir_call, [tensor_expr, opts, fun, user_template]) - - tuple(expr_node, Tuple.to_list(tuple)) - end - end - ## Nx.Defn AST callbacks @doc false @@ -1424,6 +1395,37 @@ defmodule Nx.Defn.Expr do context || acc end + @doc """ + Helper for defining an :elixir_call expression node. + """ + def elixir_call(out, tensor_or_container, static_argument, fun) when is_function(fun, 2) do + # Convert the entire tensor_or_container into an expression container, + # preserving its structure but ensuring all tensors are Expr-backed. + tensor_expr = + Composite.traverse(tensor_or_container, fn + %T{} = t -> to_expr(t) + other -> other + end) + + # Grab context from the first tensor in the flattened container. 
+ [%T{data: %Expr{context: context}} | _] = + Composite.flatten_list([tensor_expr]) + + case out do + t when is_struct(t, Nx.Tensor) -> + out_template = Nx.to_template(t) + expr(t, context, :elixir_call, [tensor_expr, static_argument, fun, out_template]) + + tuple when is_tuple(tuple) -> + out_template = tuple_out(tuple_size(tuple)) + user_template = Nx.to_template(tuple) + expr_node = + expr(out_template, context, :elixir_call, [tensor_expr, static_argument, fun, user_template]) + + tuple(expr_node, Tuple.to_list(tuple)) + end + end + ## Constant helpers and related optimizations defp constant(%{vectorized_axes: [_ | _]} = out, number) do From 3731c9789be040713a71f2b43567e0d983a38a2e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 07:50:28 -0300 Subject: [PATCH 27/42] chore: revert torchx --- torchx/mix.exs | 4 +- torchx/test/torchx/defn/elixir_call_test.exs | 51 -------------------- 2 files changed, 2 insertions(+), 53 deletions(-) delete mode 100644 torchx/test/torchx/defn/elixir_call_test.exs diff --git a/torchx/mix.exs b/torchx/mix.exs index 3cabdf15ae..e6e88bd54b 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - # {:nx, "~> 0.10.0"}, - {:nx, path: "../nx"}, + {:nx, "~> 0.10.0"}, + # {:nx, path: "../nx"}, {:fine, "~> 0.1.0", runtime: false}, {:ex_doc, "~> 0.29", only: :docs} ] diff --git a/torchx/test/torchx/defn/elixir_call_test.exs b/torchx/test/torchx/defn/elixir_call_test.exs deleted file mode 100644 index 8e1c2b590d..0000000000 --- a/torchx/test/torchx/defn/elixir_call_test.exs +++ /dev/null @@ -1,51 +0,0 @@ -defmodule Torchx.Defn.ElixirCallEvaluatorTest do - use ExUnit.Case, async: true - import Nx.Defn - import Nx.Testing - - setup do - Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) - Nx.default_backend(Torchx.Backend) - :ok - end - - defn add_offset(x) do - out = %{x | type: Nx.Type.to_floating(x.type)} - - 
Nx.elixir_call(out, x, [offset: 10.0], fn t, opts -> - Nx.add(Nx.as_type(t, :f32), opts[:offset]) - end) - end - - test "elixir_call with single output" do - x = Nx.iota({5}) - y = add_offset(x) - - expected = Nx.add(Nx.as_type(x, :f32), 10.0) - assert_equal(y, expected) - end - - defn split_and_sum(x) do - fx = Nx.as_type(x, :f32) - - out0 = fx - out1 = fx - out_template = {out0, out1} - - {a, b} = - Nx.elixir_call(out_template, fx, fn t -> - {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} - end) - - Nx.add(a, b) - end - - test "elixir_call with tuple output" do - x = Nx.tensor([1, 2, 3]) - y = split_and_sum(x) - - fx = Nx.as_type(x, :f32) - expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) - assert_equal(y, expected) - end -end From ea141a34fdab6f63405914cc712d9e01479c237c Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 09:21:12 -0300 Subject: [PATCH 28/42] fix: handle containers --- exla/lib/exla/callback_server.ex | 10 ++---- exla/test/exla/defn/elixir_call_test.exs | 39 ++++++++++++++++++++++++ nx/lib/nx.ex | 21 ++++--------- nx/lib/nx/defn/expr.ex | 25 +++++++++++++++ 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index d10b5c27e0..9995e5c729 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -272,13 +272,9 @@ defmodule EXLA.CallbackServer do end end - defp encode_outputs(%Nx.Tensor{} = tensor) do - [Nx.to_binary(tensor)] - end - - defp encode_outputs(tuple) when is_tuple(tuple) do - tuple - |> Tuple.to_list() + defp encode_outputs(container) do + [container] + |> Nx.Defn.Composite.flatten_list() |> Enum.map(&Nx.to_binary/1) end diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index 975dc914c0..193011be19 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -5,6 +5,7 @@ defmodule 
EXLA.Defn.ElixirCallTest do setup do Nx.default_backend(EXLA.Backend) + Nx.Defn.default_options(compiler: EXLA) :ok end @@ -118,4 +119,42 @@ defmodule EXLA.Defn.ElixirCallTest do assert_receive {:add_and_subtract_with_opts, ^ref} end + + defn return_as_container(x, y, template_fun, container_fun) do + Nx.elixir_call(template_fun.(x, y), {x, y}, container_fun) + end + + test "elixir_call with container output" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + + ref = make_ref() + pid = self() + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + {x, y} + end + + template_fun = fn x, y -> {x, y} end + + assert {x_res, y_res} = return_as_container(x, y, template_fun, container_fun) + assert_equal(x_res, x) + assert_equal(y_res, y) + assert_receive {:container_fun, ^ref} + + ref = make_ref() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + %{x: x, y: y} + end + + template_fun = fn x, y -> %{x: x, y: y} end + + assert result = return_as_container(x, y, template_fun, container_fun) + assert %{x: _, y: _} = result + assert_equal(result.x, x) + assert_equal(result.y, y) + assert_receive {:container_fun, ^ref} + end end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index c7731189d3..7652533270 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2242,22 +2242,13 @@ defmodule Nx do ensure_call_compatible!(result, output) end - defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do - [Tuple.to_list(left), Tuple.to_list(right)] - |> Enum.zip_with(fn [l, r] -> ensure_call_compatible!(l, r) end) - - left - end - - defp ensure_call_compatible!( - %{shape: shape, type: type, names: names} = left, - %{shape: shape, type: type, names: names} - ), - do: left - defp ensure_call_compatible!(left, right) do - raise ArgumentError, - "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + if Nx.compatible?(left, right) do + left + else + raise ArgumentError, + 
"expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + end end defp chunk([], data, type) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index e63c0fc0d7..e1a9a902d2 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1423,6 +1423,31 @@ defmodule Nx.Defn.Expr do expr(out_template, context, :elixir_call, [tensor_expr, static_argument, fun, user_template]) tuple(expr_node, Tuple.to_list(tuple)) + + container -> + user_template = Nx.to_template(container) + + leaf_templates = Composite.flatten_list([user_template]) + leaf_count = length(leaf_templates) + + root = + expr( + tuple_out(leaf_count), + context, + :elixir_call, + [tensor_expr, static_argument, fun, user_template] + ) + + {container_expr, _} = + Composite.traverse(user_template, {0, root}, fn + %T{} = template, {i, root} -> + {expr(template, context, :elem, [root, i]), {i + 1, root}} + + other, acc -> + {other, acc} + end) + + container_expr end end From 045281ca4719085f6a9df36c4bfba33372de44a6 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 28 Nov 2025 09:23:30 -0300 Subject: [PATCH 29/42] defend against exceptions --- exla/lib/exla/callback_server.ex | 45 ++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 9995e5c729..bcc059d77c 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -106,26 +106,33 @@ defmodule EXLA.CallbackServer do @impl true def handle_info({:exla_elixir_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do - case Map.fetch(state.callbacks, callback_id) do - {:ok, {fun, out_template, arg_template, opts}} -> - reply_payload = - args_spec - |> decode_args(arg_template) - |> run_callback(fun, opts, out_template) - |> encode_reply() - - send_reply(reply_tag, reply_payload) - - 
{:noreply, state} - - :error -> - Logger.error( - "EXLA.CallbackServer received callback id #{inspect(callback_id)} that is not registered" - ) + reply_payload = + try do + case Map.fetch(state.callbacks, callback_id) do + {:ok, {fun, out_template, arg_template, opts}} -> + args_spec + |> decode_args(arg_template) + |> run_callback(fun, opts, out_template) + |> encode_reply() + + :error -> + Logger.error( + "EXLA.CallbackServer received callback id #{inspect(callback_id)} that is not registered" + ) + + encode_reply({:error, :unknown_callback}) + end + rescue + exception -> + msg = Exception.message(exception) + encode_reply({:error, {:runtime_error, "Elixir callback server crashed: " <> msg}}) + catch + kind, reason -> + encode_reply({:error, {:runtime_error, "Elixir callback server #{kind}: #{inspect(reason)}"}}) + end - send_reply(reply_tag, {:error, :unknown_callback}) - {:noreply, state} - end + send_reply(reply_tag, reply_payload) + {:noreply, state} end def handle_info(other, state) do From 696dd5f2f9b121b813962221c2d7f9f8fc51d30d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 04:12:44 -0300 Subject: [PATCH 30/42] refactor: pass callback id and pid as attributes --- .../exla/custom_calls/elixir_callback.cc | 42 ++++++------------ .../custom_calls/elixir_callback_bridge.cc | 42 ++++++++++++++++-- .../custom_calls/elixir_callback_bridge.h | 10 +++-- exla/lib/exla/application.ex | 2 +- exla/lib/exla/callback_server.ex | 8 ++-- exla/lib/exla/callback_server/supervisor.ex | 23 ++++++++++ exla/lib/exla/defn.ex | 37 +++++++++------- exla/lib/exla/mlir/value.ex | 44 ++++++++++++++++--- 8 files changed, 145 insertions(+), 63 deletions(-) create mode 100644 exla/lib/exla/callback_server/supervisor.ex diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index 6f9c90fde1..4b6259ac2a 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ 
b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -11,35 +11,16 @@ namespace ffi = xla::ffi; namespace { -ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, - ffi::RemainingRets rets) { - if (args.size() == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "exla_elixir_callback expects at least one argument"); - } - - // The first argument is a scalar S64 tensor carrying the callback id. - auto id_buf_or = args.get(0); - if (!id_buf_or) { - return id_buf_or.error(); - } - - ffi::AnyBuffer id_buf = *id_buf_or; - - if (id_buf.element_count() != 1 || - id_buf.element_type() != ffi::DataType::S64) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "exla_elixir_callback callback id must be scalar s64"); - } - - int64_t callback_id = id_buf.reinterpret_data()[0]; - - // Collect all remaining input tensors (excluding callback id) into - // lightweight payload views. +ffi::Error +exla_elixir_callback_impl(ffi::RemainingArgs args, uint64_t callback_id, + ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, + ffi::RemainingRets rets) { + // Collect all input tensors into lightweight payload views. std::vector inputs; - inputs.reserve(args.size() - 1); + inputs.reserve(args.size()); - for (size_t i = 1; i < args.size(); ++i) { + for (size_t i = 0; i < args.size(); ++i) { auto maybe_buf_or = args.get(i); if (!maybe_buf_or) { return maybe_buf_or.error(); @@ -84,7 +65,9 @@ ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, // Call back into Elixir through the bridge. On success, the bridge writes // results directly into the provided output buffers. 
exla::callback_bridge::Result result = - exla::callback_bridge::InvokeElixirCallback(callback_id, inputs, outputs); + exla::callback_bridge::InvokeElixirCallback( + callback_id, callback_server_pid_words, callback_server_pid_size, + inputs, outputs); if (!result.ok) { return ffi::Error(ffi::ErrorCode::kInternal, result.error); @@ -99,6 +82,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( exla_elixir_callback, exla_elixir_callback_impl, ffi::Ffi::Bind() .RemainingArgs() + .Attr("callback_id") + .Attr>("callback_server_pid") + .Attr("callback_server_pid_size") .RemainingRets()); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host", diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index 2031f084c6..e7407beee2 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -115,8 +115,12 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, pending->cv.notify_one(); } -Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, - const std::vector &outputs) { +Result +InvokeElixirCallback(uint64_t callback_id, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, + const std::vector &inputs, + const std::vector &outputs) { auto state = GetBridgeState(); if (!state->dispatcher_set) { @@ -130,6 +134,37 @@ Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, ErlNifEnv *msg_env = enif_alloc_env(); + // Reinterpret the 64-bit words as a contiguous byte buffer and use the + // original (unpadded) size when decoding the callback server pid term. 
+ if (callback_server_pid_size > + callback_server_pid_words.size() * sizeof(int64_t)) { + Result res; + res.ok = false; + res.error = "inconsistent callback server pid size"; + return res; + } + + const unsigned char *pid_bytes = reinterpret_cast( + callback_server_pid_words.begin()); + + ERL_NIF_TERM callback_server_pid_term; + if (!enif_binary_to_term(msg_env, pid_bytes, callback_server_pid_size, + &callback_server_pid_term, 0)) { + Result res; + res.ok = false; + res.error = "failed to decode callback server pid term"; + return res; + } + + ErlNifPid callback_server_pid; + if (!enif_get_local_pid(msg_env, callback_server_pid_term, + &callback_server_pid)) { + Result res; + res.ok = false; + res.error = "failed to decode callback server pid"; + return res; + } + // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send // plain binaries because the BEAM callback needs to own the data lifetime. std::vector &inputs, // but we don't know its env, therefore we cannot use enif_whereis_pid. // enif_whereis_pid can be called with NULL, but only from non-ERTS // threads, and doing so here results in a segfault. 
- ErlNifPid dispatcher_pid = state->dispatcher_pid; - enif_send(msg_env, &dispatcher_pid, msg_env, fine::encode(msg_env, msg)); + enif_send(msg_env, &callback_server_pid, msg_env, fine::encode(msg_env, msg)); enif_free_env(msg_env); std::unique_lock lock(pending->mu); diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h index d2e01c6d60..077d1ceb93 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -68,10 +68,14 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, // * Blocks the calling native thread until the reply arrives via // deliver_reply/3 // -// It returns an Result that either indicates success (data has +// It returns a Result that either indicates success (data has // been written into the registered output buffers) or an error message. -Result InvokeElixirCallback(int64_t callback_id, const std::vector &inputs, - const std::vector &outputs); +Result +InvokeElixirCallback(uint64_t callback_id, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, + const std::vector &inputs, + const std::vector &outputs); fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid); diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 48e1e5c8db..83985c6183 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -23,7 +23,7 @@ defmodule EXLA.Application do EXLA.Defn.Lock, EXLA.Defn.LockedCache, {Task.Supervisor, name: EXLA.Defn.TaskSupervisor}, - EXLA.CallbackServer + EXLA.CallbackServer.Supervisor ] Supervisor.start_link(children, name: __MODULE__, strategy: :one_for_one) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index bcc059d77c..d7a1da9a59 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -53,7 +53,7 @@ defmodule EXLA.CallbackServer do 
`:exla_elixir_call` messages to this process. """ def start_link(_init_arg) do - GenServer.start_link(__MODULE__, :ok, name: __MODULE__) + GenServer.start_link(__MODULE__, :ok) end @doc """ @@ -64,9 +64,9 @@ defmodule EXLA.CallbackServer do same id for the lifetime of this VM. This id is what the EXLA compiler encodes into the host `CustomCall` so the native side can reference the right callback. """ - @spec register(fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() - def register(fun, out_template, arg_template, static_arguments) when is_function(fun) do - GenServer.call(__MODULE__, {:register, fun, out_template, arg_template, static_arguments}) + @spec register(pid(), fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() + def register(callback_server_pid, fun, out_template, arg_template, static_arguments) when is_function(fun) do + GenServer.call(callback_server_pid, {:register, fun, out_template, arg_template, static_arguments}) end ## GenServer callbacks diff --git a/exla/lib/exla/callback_server/supervisor.ex b/exla/lib/exla/callback_server/supervisor.ex new file mode 100644 index 0000000000..8071aea247 --- /dev/null +++ b/exla/lib/exla/callback_server/supervisor.ex @@ -0,0 +1,23 @@ +defmodule EXLA.CallbackServer.Supervisor do + @moduledoc false + + use DynamicSupervisor + + @impl true + def start_link(init_arg) do + DynamicSupervisor.start_link(__MODULE__, init_arg, name: __MODULE__) + end + + @impl true + def init(_init_arg) do + DynamicSupervisor.init(strategy: :one_for_one) + end + + def start_callback_server do + DynamicSupervisor.start_child(__MODULE__, {EXLA.CallbackServer, []}) + end + + def terminate_callback_server(pid) do + DynamicSupervisor.terminate_child(__MODULE__, pid) + end +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index b789ed47b4..4de3ef4eae 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -39,10 +39,19 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, 
compile_options} = Keyword.pop(options, :run_options, []) debug? = Keyword.get(compile_options, :debug, false) - callback = &to_computation(&1, &2, &3, &4, &5, compile_options) + + # We start the callback server regardless if it's needed + # as it's relatively cheap to start it. + callback_server_pid = + case EXLA.CallbackServer.Supervisor.start_callback_server() do + {:ok, pid} -> pid + {:error, reason} -> raise "Failed to start EXLA.CallbackServer: #{inspect(reason)}" + end + + callback = &to_computation(&1, &2, &3, &4, &5, compile_options, callback_server_pid) {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = - compile(key, vars, fun, compile_options, 0, [], callback) + compile(key, vars, fun, compile_options, 0, [], callback, callback_server_pid) if compile_options[:module_compilation] == :to_mlir do throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) @@ -68,7 +77,7 @@ defmodule EXLA.Defn do end end - defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do + defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options, callback_server_pid) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> {pos, arg} @@ -83,7 +92,8 @@ defmodule EXLA.Defn do precision: Keyword.get(options, :precision, :default), builder: function, params: Map.new(params ++ outfeed.infeeds), - scope_ids: Tree.scope_ids(expr) + scope_ids: Tree.scope_ids(expr), + callback_server_pid: callback_server_pid } {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) @@ -138,7 +148,7 @@ defmodule EXLA.Defn do ## Compile - defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation) do + defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation, callback_server_pid) do {cache, options} = Keyword.pop(options, :cache, true) {hooks, options} = Keyword.pop(options, :hooks, %{}) {debug?, options} = 
Keyword.pop(options, :debug, false) @@ -235,6 +245,8 @@ defmodule EXLA.Defn do expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed, client) + options = Keyword.put(options, :callback_server_pid, callback_server_pid) + {xla_time, executable} = :timer.tc(fn -> EXLA.MLIR.Module.compile( @@ -549,7 +561,7 @@ defmodule EXLA.Defn do defp cached_recur_operator( :elixir_call, %T{data: %Expr{args: [tensor_expr, opts, fun, out_template]}} = expr, - %{client: %EXLA.Client{platform: :host}} = state, + %{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} = state, cache ) do # Flatten the tensor_or_container expression into its tensor leaves so we @@ -566,20 +578,11 @@ defmodule EXLA.Defn do # decoded tensors. arg_template = Nx.to_template(tensor_expr) - callback_id = EXLA.CallbackServer.register(fun, out_template, arg_template, opts) + callback_id = EXLA.CallbackServer.register(callback_server_pid, fun, out_template, arg_template, opts) typespecs = container_to_typespecs(out_template) - # Pass callback id as an extra scalar s64 operand at the end so that the - # native handler can retrieve it without relying on backend_config attrs. - callback_id_typespec = Typespec.tensor({:s, 64}, {}) - - callback_id_value = - Value.constant(state.builder, [callback_id], callback_id_typespec) - - operands = [callback_id_value | arg_values] - results = - Value.elixir_call(operands, typespecs) + Value.elixir_call(arg_values, typespecs, callback_server_pid, callback_id) {wrap_tuple_result(results, expr), cache} end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index d92fc61320..79cc9ae641 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -835,19 +835,45 @@ defmodule EXLA.MLIR.Value do @doc """ Builds a StableHLO `custom_call` that targets the EXLA Elixir callback bridge. 
- The `callback_id` is a small integer assigned by `EXLA.CallbackServer` that + The `callback_id` is a unique integer generated by `EXLA.CallbackServer` that identifies which Elixir function should be invoked when the host callback - runs. It is passed as an extra scalar S64 tensor operand (first argument) to - the custom call so the native handler can load it before touching any tensor - payloads. + runs. It is passed as an extra scalar ui64 attribute. """ - def elixir_call([%Value{function: func} | _] = operands, typespecs) do + def elixir_call([%Value{function: func} | _] = operands, typespecs, callback_server_pid, callback_id) do result_types = typespecs_to_mlir_types(typespecs) + pid_bin = :erlang.term_to_binary(callback_server_pid) + pid_size = byte_size(pid_bin) + + # Zero-pad the pid binary so its size is a multiple of 8 and it can be + # represented as a list of 64-bit words. + pad = + case rem(pid_size, 8) do + 0 -> 0 + r -> 8 - r + end + + padded_bin = + if pad == 0 do + pid_bin + else + pid_bin <> :binary.copy(<<0>>, pad) + end + + callback_server_pid_words = + for <> do + x + end + attributes = [ call_target_name: attr_string("exla_elixir_callback"), # api_version 4 enables the typed FFI API used by our callback handler. 
- api_version: attr_i32(4) + api_version: attr_i32(4), + backend_config: attr_dict( + callback_id: attr_ui64(callback_id), + callback_server_pid: attr_array_i64_elements(callback_server_pid_words), + callback_server_pid_size: attr_ui64(pid_size) + ) ] op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) @@ -1015,6 +1041,7 @@ defmodule EXLA.MLIR.Value do defp attr_i32(number), do: "#{number} : i32" defp attr_i64(number), do: "#{number} : i64" + defp attr_ui64(number), do: "#{number} : ui64" defp attr_padding(padding) do list = Enum.flat_map(padding, &Tuple.to_list/1) @@ -1046,6 +1073,11 @@ defmodule EXLA.MLIR.Value do "##{name}<#{content}>" end + defp attr_dict(keyword_list) do + content = Enum.map_join(keyword_list, ", ", fn {key, value} -> "#{key} = #{value}" end) + "{#{content}}" + end + defp join_list(list) do "[" <> Enum.join(list, ", ") <> "]" end From 66588a84009f957b783e38999f9ce52bde366150 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 04:28:10 -0300 Subject: [PATCH 31/42] feat: gc process when Exla Executable is gc'd --- exla/c_src/exla/exla.cc | 18 +++++++++++-- exla/c_src/exla/exla_client.cc | 29 ++++++++++++++++----- exla/c_src/exla/exla_client.h | 9 +++++-- exla/lib/exla/callback_server.ex | 4 +++ exla/lib/exla/callback_server/supervisor.ex | 1 - exla/lib/exla/mlir/module.ex | 4 ++- exla/lib/exla/nif.ex | 3 ++- 7 files changed, 54 insertions(+), 14 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index a3db6f5463..e1eeef7f72 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -203,7 +203,8 @@ fine::ResourcePtr mlir_compile(ErlNifEnv *env, fine::ResourcePtr client, fine::ResourcePtr module, std::vector argument_layouts, int64_t num_replicas, - int64_t num_partitions, bool use_spmd, int64_t device_id) { + int64_t num_partitions, bool use_spmd, int64_t device_id, + fine::Term callback_server_pid_term) { auto 
build_options = xla::ExecutableBuildOptions(); build_options.set_num_replicas(num_replicas); @@ -216,8 +217,21 @@ mlir_compile(ErlNifEnv *env, fine::ResourcePtr client, build_options.set_device_ordinal(device_id); } + // Decode the optional callback server pid. If the term is a pid, we convert + // it to an ErlNifPid; otherwise we treat it as "no pid" (e.g. nil). + absl::optional pid_opt; + ERL_NIF_TERM pid_term = callback_server_pid_term; + + if (enif_is_pid(env, pid_term)) { + ErlNifPid pid; + if (enif_get_local_pid(env, pid_term, &pid)) { + pid_opt = pid; + } + } + return unwrap(client->Compile(module->module(), argument_layouts, - build_options, compile_portable_executable)); + build_options, compile_portable_executable, + pid_opt)); } FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND); diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index fc48505883..5ec2aee98a 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -97,9 +97,22 @@ ExlaBuffer::CopyToDevice(xla::PjRtDevice *dst_device) { ExlaExecutable::ExlaExecutable( std::unique_ptr executable, - absl::optional fingerprint, ExlaClient *client) + absl::optional fingerprint, ExlaClient *client, + absl::optional callback_server_pid) : executable_(std::move(executable)), fingerprint_(std::move(fingerprint)), - client_(client) {} + client_(client), callback_server_pid_(callback_server_pid) {} + +ExlaExecutable::~ExlaExecutable() { + if (callback_server_pid_.has_value()) { + ErlNifEnv *env = enif_alloc_env(); + // Notify the callback server that this executable has been dropped so it + // can clean up any associated state. 
+ ERL_NIF_TERM msg = + fine::encode(env, fine::Atom("exla_elixir_call_executable_dropped")); + enif_send(nullptr, &callback_server_pid_.value(), env, msg); + enif_free_env(env); + } +} tsl::StatusOr> PjRtBufferFromBinary(xla::PjRtClient *client, ERL_NIF_TERM source_term, @@ -391,13 +404,15 @@ ExlaClient::DeserializeExecutable(std::string deserialized_executable) { EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return fine::make_resource(std::move(executable), - std::move(fingerprint), this); + return fine::make_resource( + std::move(executable), std::move(fingerprint), this, + /*callback_server_pid=*/absl::nullopt); } tsl::StatusOr> ExlaClient::Compile( mlir::ModuleOp module, std::vector argument_layouts, - xla::ExecutableBuildOptions &options, bool compile_portable_executable) { + xla::ExecutableBuildOptions &options, bool compile_portable_executable, + absl::optional callback_server_pid) { std::vector layouts; layouts.reserve(argument_layouts.size()); for (auto shape : argument_layouts) { @@ -419,8 +434,8 @@ tsl::StatusOr> ExlaClient::Compile( EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return fine::make_resource(std::move(executable), - std::move(fingerprint), this); + return fine::make_resource( + std::move(executable), std::move(fingerprint), this, callback_server_pid); } tsl::Status ExlaClient::TransferToInfeed(ErlNifEnv *env, diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h index 323fa26acb..061e1e511f 100644 --- a/exla/c_src/exla/exla_client.h +++ b/exla/c_src/exla/exla_client.h @@ -65,7 +65,10 @@ class ExlaExecutable { using RunResult = std::vector; ExlaExecutable(std::unique_ptr executable, - absl::optional fingerprint, ExlaClient *client); + absl::optional fingerprint, ExlaClient *client, + absl::optional callback_server_pid); + + ~ExlaExecutable(); xla::PjRtLoadedExecutable *executable() { return executable_.get(); } @@ -80,6 +83,7 @@ class 
ExlaExecutable { std::unique_ptr executable_; absl::optional fingerprint_; ExlaClient *client_; + absl::optional callback_server_pid_; }; class ExlaClient { @@ -95,7 +99,8 @@ class ExlaClient { tsl::StatusOr> Compile(mlir::ModuleOp computation, std::vector argument_layouts, xla::ExecutableBuildOptions &options, - bool compile_portable_executable); + bool compile_portable_executable, + absl::optional callback_server_pid); tsl::StatusOr> BufferFromBinary(ERL_NIF_TERM binary_term, xla::Shape &shape, int device_id); diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index d7a1da9a59..dc819fcb4b 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -135,6 +135,10 @@ defmodule EXLA.CallbackServer do {:noreply, state} end + def handle_info(:exla_elixir_call_executable_dropped, state) do + {:stop, :normal, state} + end + def handle_info(other, state) do Logger.debug("EXLA.CallbackServer ignoring unexpected message: #{inspect(other)}") {:noreply, state} diff --git a/exla/lib/exla/callback_server/supervisor.ex b/exla/lib/exla/callback_server/supervisor.ex index 8071aea247..bf9d4aaf5c 100644 --- a/exla/lib/exla/callback_server/supervisor.ex +++ b/exla/lib/exla/callback_server/supervisor.ex @@ -3,7 +3,6 @@ defmodule EXLA.CallbackServer.Supervisor do use DynamicSupervisor - @impl true def start_link(init_arg) do DynamicSupervisor.start_link(__MODULE__, init_arg, name: __MODULE__) end diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 04f1c38a3c..d1ba3d0b0b 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -92,6 +92,7 @@ defmodule EXLA.MLIR.Module do ) do num_replicas = Keyword.get(options, :num_replicas, 1) num_partitions = Keyword.get(options, :num_partitions, 1) + callback_server_pid = Keyword.get(options, :callback_server_pid, nil) # JAX comments say SPMD can lead to subtle bugs so they only enable # when strictly necessary, which is when num_partitions 
is greater than 1. @@ -118,7 +119,8 @@ defmodule EXLA.MLIR.Module do num_replicas, num_partitions, use_spmd, - device_id + device_id, + callback_server_pid ) end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index b579c015e8..83a672d4c8 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -37,7 +37,8 @@ defmodule EXLA.NIF do _num_replicas, _num_partitions, _use_spmd, - _device_id + _device_id, + _callback_server_pid ), do: err!() From 4aa2f66b6737fc921849eb44e20de3283822e2b1 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 06:43:08 -0300 Subject: [PATCH 32/42] Update exla/lib/exla/callback_server.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/lib/exla/callback_server.ex | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index dc819fcb4b..924e023a5b 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -122,13 +122,10 @@ defmodule EXLA.CallbackServer do encode_reply({:error, :unknown_callback}) end - rescue - exception -> - msg = Exception.message(exception) - encode_reply({:error, {:runtime_error, "Elixir callback server crashed: " <> msg}}) catch kind, reason -> - encode_reply({:error, {:runtime_error, "Elixir callback server #{kind}: #{inspect(reason)}"}}) + formatted = Exception.format(kind, reason, __STACKTRACE__) + encode_reply({:error, {:runtime_error, "Elixir callback server crashed: #{formatted}"}}) end send_reply(reply_tag, reply_payload) From 741a6b51679fd9def864e31dfba57b023dff8fb2 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 06:54:03 -0300 Subject: [PATCH 33/42] single pass materialize --- exla/lib/exla/callback_server.ex | 62 +++++++++++------------- exla/lib/exla/defn.ex | 28 +++++++++-- 
exla/lib/exla/mlir/value.ex | 18 ++++--- exla/test/exla/defn/elixir_call_test.exs | 8 ++- 4 files changed, 71 insertions(+), 45 deletions(-) diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 924e023a5b..16babeb8e4 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -65,8 +65,12 @@ defmodule EXLA.CallbackServer do the host `CustomCall` so the native side can reference the right callback. """ @spec register(pid(), fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() - def register(callback_server_pid, fun, out_template, arg_template, static_arguments) when is_function(fun) do - GenServer.call(callback_server_pid, {:register, fun, out_template, arg_template, static_arguments}) + def register(callback_server_pid, fun, out_template, arg_template, static_arguments) + when is_function(fun) do + GenServer.call( + callback_server_pid, + {:register, fun, out_template, arg_template, static_arguments} + ) end ## GenServer callbacks @@ -89,7 +93,11 @@ defmodule EXLA.CallbackServer do end @impl true - def handle_call({:register, fun, out_template, arg_template, opts}, _from, %__MODULE__{} = state) do + def handle_call( + {:register, fun, out_template, arg_template, opts}, + _from, + %__MODULE__{} = state + ) do key = {fun, out_template, arg_template, opts} case find_existing_id(state.callbacks, key) do @@ -177,33 +185,10 @@ defmodule EXLA.CallbackServer do end defp decode_args(args_spec, arg_template) when is_list(args_spec) do - result = - Enum.reduce_while(args_spec, {:ok, []}, fn - {bin, {type, shape_list}}, {:ok, acc} -> - try do - tensor = - bin - |> Nx.from_binary(type) - |> Nx.reshape(List.to_tuple(shape_list)) - - {:cont, {:ok, [tensor | acc]}} - rescue - exception -> - {:halt, {:error, {:decode_failed, exception}}} - end - - other, _acc -> - {:halt, {:error, {:invalid_args_spec, other}}} - end) - - case result do - {:ok, tensors} -> - tensors = Enum.reverse(tensors) - 
materialize_args(arg_template, tensors) - - {:error, _} = error -> - error - end + materialize_args(arg_template, args_spec) + catch + {:error, reason} -> + {:error, reason} end defp decode_args(other, _arg_template), do: {:error, {:invalid_args_spec, other}} @@ -264,11 +249,20 @@ defmodule EXLA.CallbackServer do {:error, {:runtime_error, msg}} end - defp materialize_args(arg_template, tensors) do + defp materialize_args(arg_template, args_spec) do {container, remaining} = - Nx.Defn.Composite.traverse(arg_template, tensors, fn - %Nx.Tensor{} = _template, [next | rest] -> - {next, rest} + Nx.Defn.Composite.traverse(arg_template, args_spec, fn + %Nx.Tensor{} = template, [{bin, {type, shape_list}} | rest] -> + decoded = + bin + |> Nx.from_binary(type) + |> Nx.reshape(List.to_tuple(shape_list)) + + if Nx.compatible?(decoded, template) do + {decoded, rest} + else + throw({:error, {:shape_mismatch, decoded, template}}) + end other, acc -> {other, acc} diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 4de3ef4eae..4352b7a49d 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -77,7 +77,15 @@ defmodule EXLA.Defn do end end - defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options, callback_server_pid) do + defp to_computation( + %Function{} = function, + expr, + used_typespecs, + outfeed, + client, + options, + callback_server_pid + ) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> {pos, arg} @@ -148,7 +156,16 @@ defmodule EXLA.Defn do ## Compile - defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation, callback_server_pid) do + defp compile( + key, + vars, + fun, + options, + used_buffers, + used_inputs, + to_computation, + callback_server_pid + ) do {cache, options} = Keyword.pop(options, :cache, true) {hooks, options} = Keyword.pop(options, :hooks, %{}) {debug?, options} = Keyword.pop(options, :debug, false) @@ -561,7 +578,8 @@ 
defmodule EXLA.Defn do defp cached_recur_operator( :elixir_call, %T{data: %Expr{args: [tensor_expr, opts, fun, out_template]}} = expr, - %{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} = state, + %{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} = + state, cache ) do # Flatten the tensor_or_container expression into its tensor leaves so we @@ -578,7 +596,9 @@ defmodule EXLA.Defn do # decoded tensors. arg_template = Nx.to_template(tensor_expr) - callback_id = EXLA.CallbackServer.register(callback_server_pid, fun, out_template, arg_template, opts) + callback_id = + EXLA.CallbackServer.register(callback_server_pid, fun, out_template, arg_template, opts) + typespecs = container_to_typespecs(out_template) results = diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 79cc9ae641..444e4ec6ef 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -839,7 +839,12 @@ defmodule EXLA.MLIR.Value do identifies which Elixir function should be invoked when the host callback runs. It is passed as an extra scalar ui64 attribute. """ - def elixir_call([%Value{function: func} | _] = operands, typespecs, callback_server_pid, callback_id) do + def elixir_call( + [%Value{function: func} | _] = operands, + typespecs, + callback_server_pid, + callback_id + ) do result_types = typespecs_to_mlir_types(typespecs) pid_bin = :erlang.term_to_binary(callback_server_pid) @@ -869,11 +874,12 @@ defmodule EXLA.MLIR.Value do call_target_name: attr_string("exla_elixir_callback"), # api_version 4 enables the typed FFI API used by our callback handler. 
api_version: attr_i32(4), - backend_config: attr_dict( - callback_id: attr_ui64(callback_id), - callback_server_pid: attr_array_i64_elements(callback_server_pid_words), - callback_server_pid_size: attr_ui64(pid_size) - ) + backend_config: + attr_dict( + callback_id: attr_ui64(callback_id), + callback_server_pid: attr_array_i64_elements(callback_server_pid_words), + callback_server_pid_size: attr_ui64(pid_size) + ) ] op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs index 193011be19..dc1f2f7e14 100644 --- a/exla/test/exla/defn/elixir_call_test.exs +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -104,7 +104,12 @@ defmodule EXLA.Defn.ElixirCallTest do end defn add_and_subtract_with_opts(x, y, opts) do - Nx.elixir_call({x, x}, {x, y}, {opts[:ref], opts[:pid]}, &add_and_subtract_with_opts_callback/2) + Nx.elixir_call( + {x, x}, + {x, y}, + {opts[:ref], opts[:pid]}, + &add_and_subtract_with_opts_callback/2 + ) end test "elixir_call with non-list second argument" do @@ -130,6 +135,7 @@ defmodule EXLA.Defn.ElixirCallTest do ref = make_ref() pid = self() + container_fun = fn {x, y} -> send(pid, {:container_fun, ref}) {x, y} From c6d272033d286a02d623b5488fdaac4f88361a91 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 06:56:48 -0300 Subject: [PATCH 34/42] chore: revert container_to_typespecs --- exla/lib/exla/defn.ex | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 4352b7a49d..4d4c094616 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1966,25 +1966,14 @@ defmodule EXLA.Defn do end defp container_to_typespecs(container) do - containers = - if is_list(container) do - container - else - [container] - end - - containers - |> Enum.reject(&is_function/1) + [container] |> 
Nx.Defn.Composite.flatten_list() |> Enum.flat_map(fn %Nx.Tensor{type: {:tuple, _}, data: %{args: values}} -> Enum.flat_map(values, &container_to_typespecs/1) - %Nx.Tensor{} = t -> + t -> [Typespec.tensor(t.type, t.shape)] - - _other -> - [] end) end From c5a2cec3366817f868f304dbc38c97fb26f9b67e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:03:51 -0300 Subject: [PATCH 35/42] fix: do not leak callback servers on error --- exla/lib/exla/defn.ex | 48 ++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 4d4c094616..0ca3af1701 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -48,32 +48,42 @@ defmodule EXLA.Defn do {:error, reason} -> raise "Failed to start EXLA.CallbackServer: #{inspect(reason)}" end - callback = &to_computation(&1, &2, &3, &4, &5, compile_options, callback_server_pid) + try do + callback = &to_computation(&1, &2, &3, &4, &5, compile_options, callback_server_pid) - {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = - compile(key, vars, fun, compile_options, 0, [], callback, callback_server_pid) + {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = + compile(key, vars, fun, compile_options, 0, [], callback, callback_server_pid) - if compile_options[:module_compilation] == :to_mlir do - throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) - end + if compile_options[:module_compilation] == :to_mlir do + throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) + end - fn [args] -> - {time, lock} = - :timer.tc(fn -> - EXLA.Defn.Lock.lock(run_key(executable)) - end) + fn [args] -> + {time, lock} = + :timer.tc(fn -> + EXLA.Defn.Lock.lock(run_key(executable)) + end) - debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms") + debug? 
&& Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms") - {time, res} = - :timer.tc(fn -> - maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) - end) + {time, res} = + :timer.tc(fn -> + maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) + end) - debug? && - Logger.debug("EXLA execution on device #{executable.device_id} in #{us_to_ms(time)}ms") + debug? && + Logger.debug("EXLA execution on device #{executable.device_id} in #{us_to_ms(time)}ms") - res + res + end + rescue + e -> + EXLA.CallbackServer.Supervisor.terminate_callback_server(callback_server_pid) + reraise e, __STACKTRACE__ + catch + kind, reason -> + EXLA.CallbackServer.Supervisor.terminate_callback_server(callback_server_pid) + :erlang.raise(kind, reason, __STACKTRACE__) end end From d20e83f03a1fb177db7f70ad827ab8eba45dcc7a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:06:28 -0300 Subject: [PATCH 36/42] refactor: skip supervisor module --- exla/lib/exla/application.ex | 2 +- exla/lib/exla/callback_server/supervisor.ex | 22 --------------------- exla/lib/exla/defn.ex | 6 +++--- 3 files changed, 4 insertions(+), 26 deletions(-) delete mode 100644 exla/lib/exla/callback_server/supervisor.ex diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 83985c6183..032166acc4 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -23,7 +23,7 @@ defmodule EXLA.Application do EXLA.Defn.Lock, EXLA.Defn.LockedCache, {Task.Supervisor, name: EXLA.Defn.TaskSupervisor}, - EXLA.CallbackServer.Supervisor + {DynamicSupervisor, name: EXLA.CallbackServer.Supervisor, strategy: :one_for_one} ] Supervisor.start_link(children, name: __MODULE__, strategy: :one_for_one) diff --git a/exla/lib/exla/callback_server/supervisor.ex b/exla/lib/exla/callback_server/supervisor.ex deleted file mode 100644 index bf9d4aaf5c..0000000000 --- 
a/exla/lib/exla/callback_server/supervisor.ex +++ /dev/null @@ -1,22 +0,0 @@ -defmodule EXLA.CallbackServer.Supervisor do - @moduledoc false - - use DynamicSupervisor - - def start_link(init_arg) do - DynamicSupervisor.start_link(__MODULE__, init_arg, name: __MODULE__) - end - - @impl true - def init(_init_arg) do - DynamicSupervisor.init(strategy: :one_for_one) - end - - def start_callback_server do - DynamicSupervisor.start_child(__MODULE__, {EXLA.CallbackServer, []}) - end - - def terminate_callback_server(pid) do - DynamicSupervisor.terminate_child(__MODULE__, pid) - end -end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 0ca3af1701..17d77324a7 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -43,7 +43,7 @@ defmodule EXLA.Defn do # We start the callback server regardless if it's needed # as it's relatively cheap to start it. callback_server_pid = - case EXLA.CallbackServer.Supervisor.start_callback_server() do + case DynamicSupervisor.start_child(EXLA.CallbackServer.Supervisor, {EXLA.CallbackServer, []}) do {:ok, pid} -> pid {:error, reason} -> raise "Failed to start EXLA.CallbackServer: #{inspect(reason)}" end @@ -78,11 +78,11 @@ defmodule EXLA.Defn do end rescue e -> - EXLA.CallbackServer.Supervisor.terminate_callback_server(callback_server_pid) + DynamicSupervisor.terminate_child(EXLA.CallbackServer.Supervisor, callback_server_pid) reraise e, __STACKTRACE__ catch kind, reason -> - EXLA.CallbackServer.Supervisor.terminate_callback_server(callback_server_pid) + DynamicSupervisor.terminate_child(EXLA.CallbackServer.Supervisor, callback_server_pid) :erlang.raise(kind, reason, __STACKTRACE__) end end From 66f62cc74daa681d250bd15a1400f5ce7359960a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:07:08 -0300 Subject: [PATCH 37/42] Update nx/lib/nx.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José 
Valim --- nx/lib/nx.ex | 6 ------ 1 file changed, 6 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 7652533270..306e407119 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2219,12 +2219,6 @@ defmodule Nx do def elixir_call(output, tensor_or_container, static_argument, fun) when is_function(fun, 2) do - {:arity, arity} = Function.info(fun, :arity) - - if arity != 2 do - raise ArgumentError, - "expected elixir_call callback to have arity 2, got #{arity}" - end # Outside defn, we execute the callback directly or via the backend if it # provides a specialized implementation. We resolve the backend from all From 35eda2de88503cf5fa2fa972b877755ed84167ce Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:24:31 -0300 Subject: [PATCH 38/42] fix: proper container support --- nx/lib/nx/defn/evaluator.ex | 33 +++++++++++++--- .../nx/defn/elixir_call_evaluator_test.exs | 39 +++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index c4a3760d47..23839b4b7f 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -176,7 +176,7 @@ defmodule Nx.Defn.Evaluator do end defp compute_cache(:elixir_call, %{data: %Expr{args: [tensor_expr, _opts, _fun, _out]}}, state, cache) do - compute_cache(tensor_expr, state, cache) + composite_compute_cache(tensor_expr, state, cache) end defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do @@ -437,17 +437,18 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [tensor_expr, static_argument, fun, _out_template]}} = expr, + %{data: %Expr{args: [tensor_expr, static_argument, fun, out_template]}} = expr, state, caches ) do - {tensor_value, caches} = eval(tensor_expr, state, caches) - backend = Nx.Shared.list_impl!([tensor_value]) + {tensor_value, caches} = composite_eval(tensor_expr, state, caches) + 
backend = Nx.Shared.list_impl!(Composite.flatten_list([tensor_value])) if backend == Nx.Defn.Expr do {expr, caches} else - {fun.(tensor_value, static_argument), caches} + result = fun.(tensor_value, static_argument) + {reshape_elixir_call_result(result, out_template), caches} end end @@ -486,6 +487,28 @@ defmodule Nx.Defn.Evaluator do {value, [cache | caches]} end + + defp reshape_elixir_call_result(result, %Nx.Tensor{} = template) do + # Single-tensor output: just ensure compatibility with the template. + if not Nx.compatible?(template, result) do + raise "expected the elixir_call function to match the given output template" + end + + result + end + + defp reshape_elixir_call_result(result, template_container) do + # Container (tuple/map/etc) output: we expect the callback to return + # a container with the same flattened tensor leaves as the template. + if not Nx.compatible?(result, template_container) do + raise "expected the elixir_call function to match the given output template" + end + + result_leaves = Composite.flatten_list([result]) + + List.to_tuple(result_leaves) + end + ## Control flow helpers defp while(acc, condition, block, state, caches) do diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs index 964d8cd260..a26408acce 100644 --- a/nx/test/nx/defn/elixir_call_evaluator_test.exs +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -46,4 +46,43 @@ defmodule Nx.Defn.ElixirCallEvaluatorTest do expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) assert expected == y end + + defn return_as_container(x, y, template_fun, container_fun) do + Nx.elixir_call(template_fun.(x, y), {x, y}, container_fun) + end + + test "elixir_call with container output" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + + ref = make_ref() + pid = self() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + {x, y} + end + + template_fun = fn x, y -> {x, y} end + + assert {x_res, y_res} = 
return_as_container(x, y, template_fun, container_fun) + assert x_res == x + assert y_res == y + assert_receive {:container_fun, ^ref} + + ref = make_ref() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + %{x: x, y: {%{key: y}, Nx.s32(1)}} + end + + template_fun = fn x, y -> %{x: x, y: {%{key: y}, Nx.s32(1)}} end + + assert result = return_as_container(x, y, template_fun, container_fun) + assert %{x: _, y: {%{key: _}, _}} = result + assert result.x == x + assert result.y == {%{key: y}, Nx.s32(1)} + assert_receive {:container_fun, ^ref} + end end From 5df4c5def6449a09888b440a71252e6fdaba9228 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:26:30 -0300 Subject: [PATCH 39/42] fix: values cannot be expr in defn evaluator --- nx/lib/nx.ex | 5 ++--- nx/lib/nx/defn/evaluator.ex | 20 +++++++++----------- nx/lib/nx/defn/expr.ex | 8 +++++++- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 306e407119..9ca770641e 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2219,7 +2219,6 @@ defmodule Nx do def elixir_call(output, tensor_or_container, static_argument, fun) when is_function(fun, 2) do - # Outside defn, we execute the callback directly or via the backend if it # provides a specialized implementation. We resolve the backend from all # tensors inside the container to support tuple/map containers. 
@@ -2228,9 +2227,9 @@ defmodule Nx do result = if backend == Nx.Defn.Expr do - backend.elixir_call(output, tensor_or_container, static_argument, fun) + backend.elixir_call(output, tensor_or_container, static_argument, fun) else - fun.(tensor_or_container, static_argument) + fun.(tensor_or_container, static_argument) end ensure_call_compatible!(result, output) diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 23839b4b7f..2eddc8918d 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,7 +175,12 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end - defp compute_cache(:elixir_call, %{data: %Expr{args: [tensor_expr, _opts, _fun, _out]}}, state, cache) do + defp compute_cache( + :elixir_call, + %{data: %Expr{args: [tensor_expr, _opts, _fun, _out]}}, + state, + cache + ) do composite_compute_cache(tensor_expr, state, cache) end @@ -437,19 +442,13 @@ defmodule Nx.Defn.Evaluator do defp eval_apply( :elixir_call, - %{data: %Expr{args: [tensor_expr, static_argument, fun, out_template]}} = expr, + %{data: %Expr{args: [tensor_expr, static_argument, fun, out_template]}}, state, caches ) do {tensor_value, caches} = composite_eval(tensor_expr, state, caches) - backend = Nx.Shared.list_impl!(Composite.flatten_list([tensor_value])) - - if backend == Nx.Defn.Expr do - {expr, caches} - else - result = fun.(tensor_value, static_argument) - {reshape_elixir_call_result(result, out_template), caches} - end + result = fun.(tensor_value, static_argument) + {reshape_elixir_call_result(result, out_template), caches} end defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do @@ -487,7 +486,6 @@ defmodule Nx.Defn.Evaluator do {value, [cache | caches]} end - defp reshape_elixir_call_result(result, %Nx.Tensor{} = template) do # Single-tensor output: just ensure compatibility with the template. 
if not Nx.compatible?(template, result) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index e1a9a902d2..599e4a56a6 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1419,8 +1419,14 @@ defmodule Nx.Defn.Expr do tuple when is_tuple(tuple) -> out_template = tuple_out(tuple_size(tuple)) user_template = Nx.to_template(tuple) + expr_node = - expr(out_template, context, :elixir_call, [tensor_expr, static_argument, fun, user_template]) + expr(out_template, context, :elixir_call, [ + tensor_expr, + static_argument, + fun, + user_template + ]) tuple(expr_node, Tuple.to_list(tuple)) From ebee6542fdbc61ab0c405ea71a65ff8ad6b842c8 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:33:49 -0300 Subject: [PATCH 40/42] docs: add examples --- nx/lib/nx.ex | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 9ca770641e..992a564e9f 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2208,6 +2208,37 @@ defmodule Nx do The `static_argument` will be passed through the Elixir processes to the callback function along with the executable Nx code. 
+ ## Examples + + While most code inside `defn` is restricted, `elixir_call/4` allows you + to perform arbitrary Elixir operations, such as message passing: + + iex> pid = self() + iex> x = Nx.tensor([1, 2, 3]) + iex> out = Nx.template({3}, {:s, 32}) + iex> _ = + ...> Nx.elixir_call(out, x, fn t -> + ...> send(pid, {:sum, Enum.sum(Nx.to_flat_list(t))}) + ...> t + ...> end) + iex> receive do {:sum, value} -> value end + 6 + + You can also use the `static_argument` to pass non-tensor metadata to + your callback while still validating the tensor result against a template: + + iex> pid = self() + iex> x = Nx.tensor([1, 2, 3]) + iex> y = Nx.tensor([4, 5, 6]) + iex> out = %{x: x, y: y} + iex> _ = + ...> Nx.elixir_call(out, {x, y}, [pid: pid], fn {a, b}, opts -> + ...> send(opts[:pid], {:dot, Nx.to_number(Nx.dot(a, b))}) + ...> %{x: a, y: b} + ...> end) + iex> receive do {:dot, value} -> value end + 32 + Inside `defn`, this builds an expression node understood by compilers. Outside `defn` or on backends without special support, it executes `fun` directly and validates the result matches the template. 
From 37fe2777081af471e90710f032b1e05f45ce33fc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 14:55:27 -0300 Subject: [PATCH 41/42] Update exla/lib/exla/defn.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/lib/exla/defn.ex | 4 ---- 1 file changed, 4 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 17d77324a7..10c50cd350 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -76,10 +76,6 @@ defmodule EXLA.Defn do res end - rescue - e -> - DynamicSupervisor.terminate_child(EXLA.CallbackServer.Supervisor, callback_server_pid) - reraise e, __STACKTRACE__ catch kind, reason -> DynamicSupervisor.terminate_child(EXLA.CallbackServer.Supervisor, callback_server_pid) From a2eaaf6c86e8cc1cb7088f4c743ca2ad50f029e2 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:17:32 -0300 Subject: [PATCH 42/42] refactor: register callbacks based on expression id --- .../exla/custom_calls/elixir_callback.cc | 17 +++--- .../custom_calls/elixir_callback_bridge.cc | 37 +++++++++--- .../custom_calls/elixir_callback_bridge.h | 11 ++-- exla/lib/exla/callback_server.ex | 51 +++++----------- exla/lib/exla/defn.ex | 8 +-- exla/lib/exla/mlir/value.ex | 58 +++++++++++-------- nx/lib/nx.ex | 4 ++ 7 files changed, 98 insertions(+), 88 deletions(-) diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc index 4b6259ac2a..c6eba91651 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -11,11 +11,11 @@ namespace ffi = xla::ffi; namespace { -ffi::Error -exla_elixir_callback_impl(ffi::RemainingArgs args, uint64_t callback_id, - ffi::Span callback_server_pid_words, - uint64_t callback_server_pid_size, - ffi::RemainingRets rets) { +ffi::Error 
exla_elixir_callback_impl( + ffi::RemainingArgs args, ffi::Span callback_id_words, + uint64_t callback_id_size, + ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, ffi::RemainingRets rets) { // Collect all input tensors into lightweight payload views. std::vector inputs; inputs.reserve(args.size()); @@ -66,8 +66,8 @@ exla_elixir_callback_impl(ffi::RemainingArgs args, uint64_t callback_id, // results directly into the provided output buffers. exla::callback_bridge::Result result = exla::callback_bridge::InvokeElixirCallback( - callback_id, callback_server_pid_words, callback_server_pid_size, - inputs, outputs); + callback_id_words, callback_id_size, callback_server_pid_words, + callback_server_pid_size, inputs, outputs); if (!result.ok) { return ffi::Error(ffi::ErrorCode::kInternal, result.error); @@ -82,7 +82,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( exla_elixir_callback, exla_elixir_callback_impl, ffi::Ffi::Bind() .RemainingArgs() - .Attr("callback_id") + .Attr>("callback_id") + .Attr("callback_id_size") .Attr>("callback_server_pid") .Attr("callback_server_pid_size") .RemainingRets()); diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc index e7407beee2..de0f4d1912 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -115,12 +115,11 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, pending->cv.notify_one(); } -Result -InvokeElixirCallback(uint64_t callback_id, - xla::ffi::Span callback_server_pid_words, - uint64_t callback_server_pid_size, - const std::vector &inputs, - const std::vector &outputs) { +Result InvokeElixirCallback( + xla::ffi::Span callback_id_words, uint64_t callback_id_size, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, const std::vector &inputs, + const std::vector &outputs) { auto state = GetBridgeState(); if 
(!state->dispatcher_set) { @@ -135,7 +134,15 @@ InvokeElixirCallback(uint64_t callback_id, ErlNifEnv *msg_env = enif_alloc_env(); // Reinterpret the 64-bit words as a contiguous byte buffer and use the - // original (unpadded) size when decoding the callback server pid term. + // original (unpadded) sizes when decoding the callback id and callback + // server pid terms. + if (callback_id_size > callback_id_words.size() * sizeof(int64_t)) { + Result res; + res.ok = false; + res.error = "inconsistent callback id size"; + return res; + } + if (callback_server_pid_size > callback_server_pid_words.size() * sizeof(int64_t)) { Result res; @@ -144,6 +151,18 @@ InvokeElixirCallback(uint64_t callback_id, return res; } + const unsigned char *id_bytes = + reinterpret_cast(callback_id_words.begin()); + + ERL_NIF_TERM callback_id_term; + if (!enif_binary_to_term(msg_env, id_bytes, callback_id_size, + &callback_id_term, 0)) { + Result res; + res.ok = false; + res.error = "failed to decode callback id term"; + return res; + } + const unsigned char *pid_bytes = reinterpret_cast( callback_server_pid_words.begin()); @@ -185,8 +204,8 @@ InvokeElixirCallback(uint64_t callback_id, args_terms.push_back(arg_tuple); } - auto msg = std::make_tuple(fine::Atom("exla_elixir_call"), callback_id, - args_terms, pending); + auto msg = std::make_tuple(fine::Atom("exla_elixir_call"), + fine::Term(callback_id_term), args_terms, pending); // Use the dispatcher pid registered via start_elixir_callback_bridge/1. 
// We still are within the NIF thread that started the computation, diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h index 077d1ceb93..177e57a305 100644 --- a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -70,12 +70,11 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, // // It returns a Result that either indicates success (data has // been written into the registered output buffers) or an error message. -Result -InvokeElixirCallback(uint64_t callback_id, - xla::ffi::Span callback_server_pid_words, - uint64_t callback_server_pid_size, - const std::vector &inputs, - const std::vector &outputs); +Result InvokeElixirCallback( + xla::ffi::Span callback_id_words, uint64_t callback_id_size, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, const std::vector &inputs, + const std::vector &outputs); fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, ErlNifPid dispatcher_pid); diff --git a/exla/lib/exla/callback_server.ex b/exla/lib/exla/callback_server.ex index 16babeb8e4..92c04ba3e9 100644 --- a/exla/lib/exla/callback_server.ex +++ b/exla/lib/exla/callback_server.ex @@ -4,10 +4,6 @@ defmodule EXLA.CallbackServer do This server has two responsibilities: - * Assign a stable integer callback id for each Elixir function + output - template pair that participates in `Nx.elixir_call/3` when using the - EXLA compiler. - * Receive callback requests from the native EXLA bridge thread, execute the Elixir function, validate the result against the expected output template, and reply back to native through a NIF. 
@@ -19,7 +15,7 @@ defmodule EXLA.CallbackServer do * Run a bridge thread that sends messages of the form: - {:exla_elixir_call, callback_id :: integer, args :: [Nx.Tensor.t()], reply_tag :: term()} + {:exla_elixir_call, callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()} to this process and waits on a native future associated with `reply_tag`. @@ -31,17 +27,13 @@ defmodule EXLA.CallbackServer do require Logger - @type callback_id :: non_neg_integer() - - defstruct next_id: 1, - callbacks: %{} + defstruct callbacks: %{} @type t :: %__MODULE__{ - next_id: non_neg_integer(), # We store the original function, its output template, and any # static (non-tensor) arguments that should always be appended to # the decoded tensor arguments coming from native. - callbacks: %{callback_id() => {fun(), Nx.t() | tuple(), [term()]}} + callbacks: %{term() => {fun(), Nx.t() | tuple(), [term()]}} } ## Public API @@ -57,19 +49,18 @@ defmodule EXLA.CallbackServer do end @doc """ - Registers a callback function, its output template, argument template, and options, - returning a callback id. + Registers a callback function, its output template, argument template, and options. - The same `{fun, out_template, arg_template, static_arguments}` quadruple will always return the - same id for the lifetime of this VM. This id is what the EXLA compiler encodes into - the host `CustomCall` so the native side can reference the right callback. + The `id` is typically the underlying `Nx.Defn.Expr` id of the `:elixir_call` + node, which the EXLA compiler also encodes into the host `CustomCall` so the + native side can reference the right callback. 
""" - @spec register(pid(), fun(), Nx.t() | tuple(), term(), [term()]) :: callback_id() - def register(callback_server_pid, fun, out_template, arg_template, static_arguments) + @spec register(pid(), term(), fun(), Nx.t() | tuple(), term(), [term()]) :: :ok + def register(callback_server_pid, id, fun, out_template, arg_template, static_arguments) when is_function(fun) do GenServer.call( callback_server_pid, - {:register, fun, out_template, arg_template, static_arguments} + {:register, id, fun, out_template, arg_template, static_arguments} ) end @@ -94,22 +85,12 @@ defmodule EXLA.CallbackServer do @impl true def handle_call( - {:register, fun, out_template, arg_template, opts}, + {:register, id, fun, out_template, arg_template, opts}, _from, %__MODULE__{} = state ) do - key = {fun, out_template, arg_template, opts} - - case find_existing_id(state.callbacks, key) do - {:ok, id} -> - {:reply, id, state} - - :error -> - id = state.next_id - state = put_in(state.callbacks[id], {fun, out_template, arg_template, opts}) - state = %{state | next_id: id + 1} - {:reply, id, state} - end + state = put_in(state.callbacks[id], {fun, out_template, arg_template, opts}) + {:reply, :ok, state} end @impl true @@ -151,12 +132,6 @@ defmodule EXLA.CallbackServer do ## Internal helpers - defp find_existing_id(callbacks, key) do - Enum.reduce_while(callbacks, :error, fn {id, value}, _acc -> - if value == key, do: {:halt, {:ok, id}}, else: {:cont, :error} - end) - end - defp run_callback({:error, reason}, _fun, _opts, _out_template), do: {:error, reason} defp run_callback({:ok, tensor_args}, fun, opts, out_template) do diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 17d77324a7..06d9b588a1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -587,7 +587,7 @@ defmodule EXLA.Defn do defp cached_recur_operator( :elixir_call, - %T{data: %Expr{args: [tensor_expr, opts, fun, out_template]}} = expr, + %T{data: %Expr{id: id, args: [tensor_expr, opts, fun, 
out_template]}} = expr, %{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} = state, cache @@ -606,13 +606,13 @@ defmodule EXLA.Defn do # decoded tensors. arg_template = Nx.to_template(tensor_expr) - callback_id = - EXLA.CallbackServer.register(callback_server_pid, fun, out_template, arg_template, opts) + :ok = + EXLA.CallbackServer.register(callback_server_pid, id, fun, out_template, arg_template, opts) typespecs = container_to_typespecs(out_template) results = - Value.elixir_call(arg_values, typespecs, callback_server_pid, callback_id) + Value.elixir_call(arg_values, typespecs, callback_server_pid, id) {wrap_tuple_result(results, expr), cache} end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 444e4ec6ef..5190d83f52 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -835,9 +835,10 @@ defmodule EXLA.MLIR.Value do @doc """ Builds a StableHLO `custom_call` that targets the EXLA Elixir callback bridge. - The `callback_id` is a unique integer generated by `EXLA.CallbackServer` that - identifies which Elixir function should be invoked when the host callback - runs. It is passed as an extra scalar ui64 attribute. + The `callback_id` is typically the underlying `Nx.Defn.Expr` id of the + `:elixir_call` node. It is encoded as a binary (via `:erlang.term_to_binary/1`) + and then represented as a list of 64-bit words in the custom call attributes, + similar to how we encode the callback server PID. 
""" def elixir_call( [%Value{function: func} | _] = operands, @@ -847,42 +848,53 @@ defmodule EXLA.MLIR.Value do ) do result_types = typespecs_to_mlir_types(typespecs) - pid_bin = :erlang.term_to_binary(callback_server_pid) - pid_size = byte_size(pid_bin) + {callback_server_pid_words, callback_server_pid_size} = + term_to_int64_list(callback_server_pid) - # Zero-pad the pid binary so its size is a multiple of 8 and it can be + {callback_id_words, callback_id_size} = + term_to_int64_list(callback_id) + + attributes = [ + call_target_name: attr_string("exla_elixir_callback"), + # api_version 4 enables the typed FFI API used by our callback handler. + api_version: attr_i32(4), + backend_config: + attr_dict( + callback_id: attr_array_i64_elements(callback_id_words), + callback_id_size: attr_ui64(callback_id_size), + callback_server_pid: attr_array_i64_elements(callback_server_pid_words), + callback_server_pid_size: attr_ui64(callback_server_pid_size) + ) + ] + + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + end + + defp term_to_int64_list(term) do + bin = :erlang.term_to_binary(term) + size = byte_size(bin) + + # Zero-pad the binary so its size is a multiple of 8 and it can be # represented as a list of 64-bit words. pad = - case rem(pid_size, 8) do + case rem(size, 8) do 0 -> 0 r -> 8 - r end padded_bin = if pad == 0 do - pid_bin + bin else - pid_bin <> :binary.copy(<<0>>, pad) + bin <> :binary.copy(<<0>>, pad) end - callback_server_pid_words = + words = for <> do x end - attributes = [ - call_target_name: attr_string("exla_elixir_callback"), - # api_version 4 enables the typed FFI API used by our callback handler. 
- api_version: attr_i32(4), - backend_config: - attr_dict( - callback_id: attr_ui64(callback_id), - callback_server_pid: attr_array_i64_elements(callback_server_pid_words), - callback_server_pid_size: attr_ui64(pid_size) - ) - ] - - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + {words, size} end def get_tuple_element(%Value{function: func} = operand, index, typespec) do diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 992a564e9f..1c8f38bfcf 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2208,6 +2208,10 @@ defmodule Nx do The `static_argument` will be passed through the Elixir processes to the callback function along with the executable Nx code. + Tensors passed to the callback function are in the same backend as the inputs in the case + of `Nx.Defn.Evaluator` invocations. For other compilers, it is generally expected that + the tensors will be provided as `Nx.BinaryBackend` tensors. + ## Examples While most code inside `defn` is restricted, `elixir_call/4` allows you