feat: Nx.elixir_call/3 #1627
Great! Although I'm not sure elixir_call is the best name. It also seems this relates to optional callbacks somehow: optional callbacks require a default implementation in Elixir to be given, so they have similar dispatch mechanisms. On the other hand, we may also want to allow what is defined as an Elixir call to go through grad or be optimised in EXLA. So I'm thinking there is an overall unified mechanism where they are specified the same way, but the compiler decides whether to split or compile, based on the expression's structure at compile time. Glad to chat about it later!
…-feat/elixir-call
exla/c_src/exla/exla.cc
Outdated
ElixirCallbackBridgeState *GetElixirCallbackBridgeState() {
  static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState();
  return state;
Is this supposed to allocate the state every time, or rather be a global? Currently the NIFs call this function and don't deallocate the state.
Should be "global". Although I should probably be attaching the lifecycle to the handler process' lifetime.
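For anyone following the allocate-every-time question: a function-local static is initialized only on the first call (and thread-safely since C++11), so the `new` in the hunk above runs once and every later call returns the same pointer. It is never deleted, which is the lifecycle concern raised here. A minimal sketch of that behaviour, where `BridgeState`, `MakeState`, `GetBridgeState` and the `next_callback_id` field are hypothetical stand-ins and not the PR's actual definitions:

```cpp
#include <cassert>

// Hypothetical stand-in for ElixirCallbackBridgeState; the real struct's
// fields are not shown in the hunk under review.
struct BridgeState {
  int next_callback_id = 0;
};

// Counts how many times the initializer expression actually runs.
static int g_allocations = 0;

BridgeState *MakeState() {
  ++g_allocations;
  return new BridgeState();
}

// Same shape as GetElixirCallbackBridgeState(): the static local is
// initialized on the first call only and is never freed, so the object
// lives until the process exits.
BridgeState *GetBridgeState() {
  static BridgeState *state = MakeState();
  return state;
}
```

Calling `GetBridgeState()` repeatedly yields one shared instance, so the accessor behaves as a process-wide global; tying its lifetime to the handler process would need an explicit teardown path instead.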
@@ -0,0 +1,107 @@
#include "elixir_callback_bridge.h"
@jonatanklosko I've moved things around a bit and addressed basically all of your review comments. The only one I was unable to do is use named processes.
  end

  defp ensure_compatible(%Nx.Tensor{} = left, %Nx.Tensor{} = right) do
    if left.shape == right.shape and left.type == right.type and left.names == right.names do
Don't we have Nx.compatible? or something?
  """
  @doc type: :backend
  def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do
    {:arity, arity} = Function.info(fun, :arity)
Dynamic arities will be a pain to type, my suggestion is to force either tuples or maps.
Maybe two arguments: tensors and options.
Let's force 1 or 2 arguments. First is a tensor or tensor container, second is the options.
I went with one tensor-container argument and another for opts :)
nx/lib/nx/backend.ex
Outdated
  to a custom_call that interacts with Erlang/Elixir via C; pure CPU
  backends may call the function directly.
  """
  @callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor
This is a compiler function, not a backend one? 🤔
Not sure what you mean by that. This is a backend callback, at least because we need Nx.Defn.Expr to have this defined, but when running under Nx.Defn.Evaluator we can have backends call the function directly.
nx/lib/nx.ex
Outdated
  any list arguments. Lists (including keyword lists) are treated as static
  Elixir data that is appended to the callback at runtime, while the leading
  non-list arguments are compiled as tensors and shipped to the target
  backend. Passing a tensor after a list argument raises an error.
We need examples.
      expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
      assert_equal(y, expected)
    end
  end
IMO these tests are not necessary, as there is nothing torch-specific here!
      fx = Nx.as_type(x, :f32)
      expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
      assert expected == y
    end
We should check map output too, and potentially nested containers? We already have the foundation for this inside Nx.Defn anyway.
Allows calling Elixir functions inside defn expressions.

Currently limited to Nx.Defn.Evaluator. EXLA will be able to support this in a future PR by making use of Nx.Defn.Graph: splitting the expression before and after the elixir_call node and creating an isolated stage for the Elixir call.