Skip to content

Conversation

@polvalente
Copy link
Contributor

Allows calling Elixir functions inside defn expressions.

Currently limited to Nx.Defn.Evaluator. EXLA will be able to support this in a future PR by making use of Nx.Defn.Graph, splitting the expression before and after the elixir_call node, creating an isolated stage for the elixir call.

@josevalim
Copy link
Collaborator

Great! Although I'm not sure if it elixir_call is the best name? It also seems this relates to optional callbacks somehow? For example, optional callbacks require a default implementation in Elixir to be given, so they have similar dispatch mechanisms. On the other hand, we may also want to allow what is defined as an Elixir call to go through grad or be optimised in EXLA. So I'm thinking there is an overall unified mechanism where they are specified the same, but the compiler decides if it is a split or compiled, based on its structure at compile time. Glad to chat about it later!

Comment on lines 547 to 549
ElixirCallbackBridgeState *GetElixirCallbackBridgeState() {
static ElixirCallbackBridgeState *state = new ElixirCallbackBridgeState();
return state;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to allocate the state every time, or rather be a global? Currently the NIFs call this function and don't deallocate the state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be "global". Although I should probably be attaching the lifecycle to the handler process' lifetime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be "global". Although I should probably be attaching the lifecycle to the handler process' lifetime.

@@ -0,0 +1,107 @@
#include "elixir_callback_bridge.h"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonatanklosko I've moved things around a bit and complied to basically all of your reviews. The only one I was unable to do is use named processes.

end

defp ensure_compatible(%Nx.Tensor{} = left, %Nx.Tensor{} = right) do
if left.shape == right.shape and left.type == right.type and left.names == right.names do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have Nx.compatible? or something?

"""
@doc type: :backend
def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do
{:arity, arity} = Function.info(fun, :arity)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dynamic arities will be a pain to type, my suggestion is to force either tuples or maps.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe two arguments: tensors and options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's force 1 or 2 arguments. First is a tensor or tensor container, second is the options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with one tnesor-container argument and another for opts :)

to a custom_call that interacts with Erlang/Elixir via C; pure CPU
backends may call the function directly.
"""
@callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a compiler function, not a backend one? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean by that. This is a backend callback at least because we need Nx.Defn.Expr to have this defined, but when running Nx.Defn.Evaluator we can have backends call the function directly and so on.

nx/lib/nx.ex Outdated
any list arguments. Lists (including keyword lists) are treated as static
Elixir data that is appended to the callback at runtime, while the leading
non-list arguments are compiled as tensors and shipped to the target
backend. Passing a tensor after a list argument raises an error.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need examples.

expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
assert_equal(y, expected)
end
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO these tests are not necessary as there is nothing torch specific!

fx = Nx.as_type(x, :f32)
expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
assert expected == y
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check map output too and potentially nesting? We have the foundation for this already inside Nx.Defn anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants