diff --git a/exla/Makefile b/exla/Makefile index 43e79e3907..39371f9fcd 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -84,7 +84,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO) SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc) -HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h +HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/elixir_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o diff --git a/exla/c_src/exla/custom_calls/elixir_callback.cc b/exla/c_src/exla/custom_calls/elixir_callback.cc new file mode 100644 index 0000000000..c6eba91651 --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback.cc @@ -0,0 +1,94 @@ +#include "elixir_callback_bridge.h" + +#include +#include +#include + +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + +namespace { + +ffi::Error exla_elixir_callback_impl( + ffi::RemainingArgs args, ffi::Span callback_id_words, + uint64_t callback_id_size, + ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, ffi::RemainingRets rets) { + // Collect all input tensors into lightweight payload views. + std::vector inputs; + inputs.reserve(args.size()); + + for (size_t i = 0; i < args.size(); ++i) { + auto maybe_buf_or = args.get(i); + if (!maybe_buf_or) { + return maybe_buf_or.error(); + } + + ffi::AnyBuffer buf = *maybe_buf_or; + + exla::callback_bridge::Arg tensor; + tensor.dtype = buf.element_type(); + + auto dims = buf.dimensions(); + tensor.dims.assign(dims.begin(), dims.end()); + + tensor.data = reinterpret_cast(buf.untyped_data()); + tensor.size_bytes = buf.size_bytes(); + + inputs.push_back(std::move(tensor)); + } + + // Prepare output buffer descriptors so the callback bridge can write results + // directly into the final destination buffers. + std::vector outputs; + outputs.reserve(rets.size()); + + for (size_t i = 0; i < rets.size(); ++i) { + auto maybe_ret_or = rets.get(i); + if (!maybe_ret_or) { + return maybe_ret_or.error(); + } + + ffi::Result ret = *maybe_ret_or; + ffi::AnyBuffer out = *ret; + + exla::callback_bridge::OutputBuffer buf; + buf.data = static_cast(out.untyped_data()); + buf.size = ffi::ByteWidth(out.element_type()) * + static_cast(out.element_count()); + + outputs.push_back(buf); + } + + // Call back into Elixir through the bridge. On success, the bridge writes + // results directly into the provided output buffers. 
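+  // The bridge call below is synchronous: it allocates a Pending resource,
+  // sends an {:exla_elixir_call, ...} message to the callback server pid
+  // encoded in the custom call attributes, and blocks this XLA host thread
+  // until EXLA.CallbackServer replies through the elixir_callback_reply NIF.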
+ exla::callback_bridge::Result result = + exla::callback_bridge::InvokeElixirCallback( + callback_id_words, callback_id_size, callback_server_pid_words, + callback_server_pid_size, inputs, outputs); + + if (!result.ok) { + return ffi::Error(ffi::ErrorCode::kInternal, result.error); + } + + return ffi::Error::Success(); +} + +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + exla_elixir_callback, exla_elixir_callback_impl, + ffi::Ffi::Bind() + .RemainingArgs() + .Attr>("callback_id") + .Attr("callback_id_size") + .Attr>("callback_server_pid") + .Attr("callback_server_pid_size") + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host", + exla_elixir_callback); + + diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc new file mode 100644 index 0000000000..de0f4d1912 --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.cc @@ -0,0 +1,228 @@ +#include "elixir_callback_bridge.h" + +#include + +namespace exla { + +namespace callback_bridge { + +struct BridgeState { + ErlNifPid dispatcher_pid; + bool dispatcher_set = false; +}; + +BridgeState *GetBridgeState() { + static BridgeState *state = new BridgeState(); + return state; +} + +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { + (void)env; + auto state = GetBridgeState(); + state->dispatcher_pid = dispatcher_pid; + state->dispatcher_set = true; + return fine::Ok(); +} + +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, + fine::ResourcePtr pending, + fine::Atom status, fine::Term result) { + deliver_reply(env, pending, status, result); + return fine::Ok(); +} + +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid) { + (void)env; + auto state = GetBridgeState(); + + if (state->dispatcher_set && + std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == + 0) { + state->dispatcher_set = false; + } + + return fine::Ok(); +} + +void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result_term) { + Result cb_result; + + if (status == "ok") { + // Successful reply: result_term is a list of binaries that we decode into + // raw byte vectors via Fine and copy directly into the registered output + // buffers. 
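+    // On the :ok path each payload binary is memcpy'd into the XLA-owned
+    // output buffer registered in Pending, so the number of binaries and
+    // their byte sizes must match the registered buffers exactly.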
+ try { + auto payloads = fine::decode>(env, result_term); + + std::lock_guard lock(pending->mu); + + if (payloads.size() != pending->outputs.size()) { + cb_result.ok = false; + cb_result.error = + "mismatched number of callback outputs vs registered buffers"; + } else { + cb_result.ok = true; + + for (size_t i = 0; i < payloads.size(); ++i) { + const ErlNifBinary &bytes = payloads[i]; + auto &out_buf = pending->outputs[i]; + + if (bytes.size != out_buf.size) { + cb_result.ok = false; + cb_result.error = + "callback returned binary of unexpected size for result buffer"; + break; + } + + if (out_buf.size > 0) { + std::memcpy(out_buf.data, bytes.data, out_buf.size); + } + } + } + } catch (const std::exception &e) { + cb_result.ok = false; + cb_result.error = + std::string("failed to decode Elixir callback outputs: ") + e.what(); + } + } else { + // Error reply: result_term is expected to be {kind_atom, message :: binary} + cb_result.ok = false; + + try { + auto decoded = + fine::decode>(env, result_term); + fine::Atom kind = std::get<0>(decoded); + ErlNifBinary msg_bin = std::get<1>(decoded); + + cb_result.error = + "elixir callback returned " + kind.to_string() + ": " + + std::string(reinterpret_cast(msg_bin.data), + msg_bin.size); + } catch (const std::exception &) { + cb_result.error = "elixir callback returned error"; + } + } + + { + std::lock_guard lock(pending->mu); + pending->result = std::move(cb_result); + pending->done = true; + } + + pending->cv.notify_one(); +} + +Result InvokeElixirCallback( + xla::ffi::Span callback_id_words, uint64_t callback_id_size, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, const std::vector &inputs, + const std::vector &outputs) { + auto state = GetBridgeState(); + + if (!state->dispatcher_set) { + Result res; + res.ok = false; + res.error = "EXLA elixir callback dispatcher is not set"; + return res; + } + + auto pending = fine::make_resource(outputs); + + ErlNifEnv *msg_env = enif_alloc_env(); + + // Reinterpret the 64-bit words as a contiguous byte buffer and use the + // original (unpadded) sizes when decoding the callback id and callback + // server pid terms. + if (callback_id_size > callback_id_words.size() * sizeof(int64_t)) { + Result res; + res.ok = false; + res.error = "inconsistent callback id size"; + return res; + } + + if (callback_server_pid_size > + callback_server_pid_words.size() * sizeof(int64_t)) { + Result res; + res.ok = false; + res.error = "inconsistent callback server pid size"; + return res; + } + + const unsigned char *id_bytes = + reinterpret_cast(callback_id_words.begin()); + + ERL_NIF_TERM callback_id_term; + if (!enif_binary_to_term(msg_env, id_bytes, callback_id_size, + &callback_id_term, 0)) { + Result res; + res.ok = false; + res.error = "failed to decode callback id term"; + return res; + } + + const unsigned char *pid_bytes = reinterpret_cast( + callback_server_pid_words.begin()); + + ERL_NIF_TERM callback_server_pid_term; + if (!enif_binary_to_term(msg_env, pid_bytes, callback_server_pid_size, + &callback_server_pid_term, 0)) { + Result res; + res.ok = false; + res.error = "failed to decode callback server pid term"; + return res; + } + + ErlNifPid callback_server_pid; + if (!enif_get_local_pid(msg_env, callback_server_pid_term, + &callback_server_pid)) { + Result res; + res.ok = false; + res.error = "failed to decode callback server pid"; + return res; + } + + // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. 
We currently send + // plain binaries because the BEAM callback needs to own the data lifetime. + std::vector>>> + args_terms; + args_terms.reserve(inputs.size()); + + for (const auto &tensor : inputs) { + fine::Term bin_term = fine::make_new_binary( + msg_env, reinterpret_cast(tensor.data), + tensor.size_bytes); + + // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via + // Fine's encoder defined in exla_nif_util.h. + auto arg_tuple = + std::make_tuple(bin_term, std::make_tuple(tensor.dtype, tensor.dims)); + + args_terms.push_back(arg_tuple); + } + + auto msg = std::make_tuple(fine::Atom("exla_elixir_call"), + fine::Term(callback_id_term), args_terms, pending); + + // Use the dispatcher pid registered via start_elixir_callback_bridge/1. + // We still are within the NIF thread that started the computation, + // but we don't know its env, therefore we cannot use enif_whereis_pid. + // enif_whereis_pid can be called with NULL, but only from non-ERTS + // threads, and doing so here results in a segfault. + enif_send(msg_env, &callback_server_pid, msg_env, fine::encode(msg_env, msg)); + enif_free_env(msg_env); + + std::unique_lock lock(pending->mu); + pending->cv.wait(lock, [&pending] { return pending->done; }); + + return pending->result; +} + +} // namespace callback_bridge + +} // namespace exla + + diff --git a/exla/c_src/exla/custom_calls/elixir_callback_bridge.h b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h new file mode 100644 index 0000000000..177e57a305 --- /dev/null +++ b/exla/c_src/exla/custom_calls/elixir_callback_bridge.h @@ -0,0 +1,192 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../exla_nif_util.h" +#include "xla/ffi/api/ffi.h" +#include +#include + +namespace exla { + +namespace callback_bridge { + +struct Arg { + xla::ffi::DataType dtype; + std::vector dims; + const uint8_t *data = nullptr; + size_t size_bytes = 0; +}; + +// Result of an Elixir callback. On success, data has already been copied into +// the pre-registered output buffers held by Pending, so we only +// need to track success or an error message here. +struct Result { + bool ok = false; + std::string error; +}; + +// Host-side description of an output buffer that should receive the callback +// result for a given output index. +struct OutputBuffer { + uint8_t *data = nullptr; + size_t size = 0; +}; + +// Per-callback pending state used to synchronize between the XLA host thread +// and the Elixir-side dispatcher. This is exposed as a Fine resource so we +// can pass it as an opaque handle in messages instead of using integer tags. +struct Pending { + // Constructor used on the host callback path where we pre-register the + // destination buffers for each output. + explicit Pending(std::vector outputs) + : outputs(std::move(outputs)) {} + + std::mutex mu; + std::condition_variable cv; + bool done = false; + Result result; + std::vector outputs; +}; + +// Called from the Elixir side to deliver a reply for a given pending handle. +// We receive the reply as a status atom (e.g. :ok or :error) and a result +// term. For the :ok case the result is a list of binaries that we decode as +// ElixirCallbackTensor outputs via Fine's decoding machinery. +void deliver_reply(ErlNifEnv *env, fine::ResourcePtr pending, + fine::Atom status, fine::Term result); + +// Synchronously calls the Elixir callback identified by `callback_id` with the +// given tensor arguments. 
This function: +// +// * Allocates a unique Pending resource +// * Sends a message to the dispatcher via enif_send/3 +// * Blocks the calling native thread until the reply arrives via +// deliver_reply/3 +// +// It returns a Result that either indicates success (data has +// been written into the registered output buffers) or an error message. +Result InvokeElixirCallback( + xla::ffi::Span callback_id_words, uint64_t callback_id_size, + xla::ffi::Span callback_server_pid_words, + uint64_t callback_server_pid_size, const std::vector &inputs, + const std::vector &outputs); + +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid); + +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, + fine::ResourcePtr pending, + fine::Atom status, fine::Term result); + +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, + ErlNifPid dispatcher_pid); + +} // namespace callback_bridge + +} // namespace exla + +namespace fine { + +// Define encoding for {ffi_dtype, dims} into %EXLA.Typespec{} term. This is +// used by the Elixir callback bridge to surface type and shape information +// about callback arguments to the Elixir side. +template <> struct Encoder { + static ERL_NIF_TERM encode(ErlNifEnv *env, const xla::ffi::DataType &dtype) { + using DT = xla::ffi::DataType; + + // Tokens are encoded as the atom :token with empty shape. + if (dtype == DT::TOKEN) { + return fine::encode(env, exla::atoms::token); + } + + std::optional type_name; + std::optional type_size; + + switch (dtype) { + case DT::PRED: + type_name = exla::atoms::pred; + type_size = 8; + break; + + case DT::U8: + type_name = exla::atoms::u; + type_size = 8; + break; + case DT::U16: + type_name = exla::atoms::u; + type_size = 16; + break; + case DT::U32: + type_name = exla::atoms::u; + type_size = 32; + break; + case DT::U64: + type_name = exla::atoms::u; + type_size = 64; + break; + + case DT::S8: + type_name = exla::atoms::s; + type_size = 8; + break; + case DT::S16: + type_name = exla::atoms::s; + type_size = 16; + break; + case DT::S32: + type_name = exla::atoms::s; + type_size = 32; + break; + case DT::S64: + type_name = exla::atoms::s; + type_size = 64; + break; + + case DT::F16: + type_name = exla::atoms::f; + type_size = 16; + break; + case DT::F32: + type_name = exla::atoms::f; + type_size = 32; + break; + case DT::F64: + type_name = exla::atoms::f; + type_size = 64; + break; + + case DT::BF16: + type_name = exla::atoms::bf; + type_size = 16; + break; + + case DT::C64: + type_name = exla::atoms::c; + type_size = 64; + break; + case DT::C128: + type_name = exla::atoms::c; + type_size = 128; + break; + + default: + break; + } + + if (type_name && type_size) { + return fine::encode( + env, std::make_tuple(type_name.value(), type_size.value())); + } + + throw std::invalid_argument("encode failed, unexpected xla::ffi::DataType"); + } +}; + +} // namespace fine + + diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index abde2f8fca..e1eeef7f72 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,8 +1,12 @@ +#include #include #include +#include #include #include +#include +#include "custom_calls/elixir_callback_bridge.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" @@ -19,6 +23,8 @@ namespace exla { +using callback_bridge::Pending; + FINE_RESOURCE(llvm::StdThreadPool); FINE_RESOURCE(mlir::MLIRContext); FINE_RESOURCE(mlir::Value); @@ -28,6 +34,7 @@ FINE_RESOURCE(exla::ExlaBuffer); FINE_RESOURCE(exla::ExlaExecutable); FINE_RESOURCE(exla::MLIRModule); 
FINE_RESOURCE(exla::MLIRFunction); +FINE_RESOURCE(Pending); // MLIR Functions @@ -196,7 +203,8 @@ fine::ResourcePtr mlir_compile(ErlNifEnv *env, fine::ResourcePtr client, fine::ResourcePtr module, std::vector argument_layouts, int64_t num_replicas, - int64_t num_partitions, bool use_spmd, int64_t device_id) { + int64_t num_partitions, bool use_spmd, int64_t device_id, + fine::Term callback_server_pid_term) { auto build_options = xla::ExecutableBuildOptions(); build_options.set_num_replicas(num_replicas); @@ -209,8 +217,21 @@ mlir_compile(ErlNifEnv *env, fine::ResourcePtr client, build_options.set_device_ordinal(device_id); } + // Decode the optional callback server pid. If the term is a pid, we convert + // it to an ErlNifPid; otherwise we treat it as "no pid" (e.g. nil). + absl::optional pid_opt; + ERL_NIF_TERM pid_term = callback_server_pid_term; + + if (enif_is_pid(env, pid_term)) { + ErlNifPid pid; + if (enif_get_local_pid(env, pid_term, &pid)) { + pid_opt = pid; + } + } + return unwrap(client->Compile(module->module(), argument_layouts, - build_options, compile_portable_executable)); + build_options, compile_portable_executable, + pid_opt)); } FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND); @@ -521,6 +542,16 @@ get_per_device_memory(ErlNifEnv *env, fine::ResourcePtr client) { FINE_NIF(get_per_device_memory, 0); +// Elixir callback bridge NIF registrations + +using callback_bridge::clear_elixir_callback_bridge; +using callback_bridge::elixir_callback_reply; +using callback_bridge::start_elixir_callback_bridge; + +FINE_NIF(start_elixir_callback_bridge, 0); +FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND); +FINE_NIF(clear_elixir_callback_bridge, 0); + // Logging fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) { diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index fc48505883..5ec2aee98a 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -97,9 +97,22 @@ ExlaBuffer::CopyToDevice(xla::PjRtDevice *dst_device) { ExlaExecutable::ExlaExecutable( std::unique_ptr executable, - absl::optional fingerprint, ExlaClient *client) + absl::optional fingerprint, ExlaClient *client, + absl::optional callback_server_pid) : executable_(std::move(executable)), fingerprint_(std::move(fingerprint)), - client_(client) {} + client_(client), callback_server_pid_(callback_server_pid) {} + +ExlaExecutable::~ExlaExecutable() { + if (callback_server_pid_.has_value()) { + ErlNifEnv *env = enif_alloc_env(); + // Notify the callback server that this executable has been dropped so it + // can clean up any associated state. 
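+    // EXLA.CallbackServer handles :exla_elixir_call_executable_dropped by
+    // stopping normally, which drops the callbacks registered for this
+    // executable's compilation.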
+ ERL_NIF_TERM msg = + fine::encode(env, fine::Atom("exla_elixir_call_executable_dropped")); + enif_send(nullptr, &callback_server_pid_.value(), env, msg); + enif_free_env(env); + } +} tsl::StatusOr> PjRtBufferFromBinary(xla::PjRtClient *client, ERL_NIF_TERM source_term, @@ -391,13 +404,15 @@ ExlaClient::DeserializeExecutable(std::string deserialized_executable) { EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return fine::make_resource(std::move(executable), - std::move(fingerprint), this); + return fine::make_resource( + std::move(executable), std::move(fingerprint), this, + /*callback_server_pid=*/absl::nullopt); } tsl::StatusOr> ExlaClient::Compile( mlir::ModuleOp module, std::vector argument_layouts, - xla::ExecutableBuildOptions &options, bool compile_portable_executable) { + xla::ExecutableBuildOptions &options, bool compile_portable_executable, + absl::optional callback_server_pid) { std::vector layouts; layouts.reserve(argument_layouts.size()); for (auto shape : argument_layouts) { @@ -419,8 +434,8 @@ tsl::StatusOr> ExlaClient::Compile( EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return fine::make_resource(std::move(executable), - std::move(fingerprint), this); + return fine::make_resource( + std::move(executable), std::move(fingerprint), this, callback_server_pid); } tsl::Status ExlaClient::TransferToInfeed(ErlNifEnv *env, diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h index 323fa26acb..061e1e511f 100644 --- a/exla/c_src/exla/exla_client.h +++ b/exla/c_src/exla/exla_client.h @@ -65,7 +65,10 @@ class ExlaExecutable { using RunResult = std::vector; ExlaExecutable(std::unique_ptr executable, - absl::optional fingerprint, ExlaClient *client); + absl::optional fingerprint, ExlaClient *client, + absl::optional callback_server_pid); + + ~ExlaExecutable(); xla::PjRtLoadedExecutable *executable() { return executable_.get(); } @@ -80,6 +83,7 @@ class ExlaExecutable { std::unique_ptr executable_; absl::optional fingerprint_; ExlaClient *client_; + absl::optional callback_server_pid_; }; class ExlaClient { @@ -95,7 +99,8 @@ class ExlaClient { tsl::StatusOr> Compile(mlir::ModuleOp computation, std::vector argument_layouts, xla::ExecutableBuildOptions &options, - bool compile_portable_executable); + bool compile_portable_executable, + absl::optional callback_server_pid); tsl::StatusOr> BufferFromBinary(ERL_NIF_TERM binary_term, xla::Shape &shape, int device_id); diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 714f74f2da..b2babd53cb 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -4,10 +4,10 @@ #include #include -#include "xla/shape.h" -#include "xla/shape_util.h" #include "mlir/IR/Types.h" #include "stablehlo/dialect/StablehloOps.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace exla { diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 9ec098a3e6..032166acc4 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -22,7 +22,8 @@ defmodule EXLA.Application do EXLA.Client, EXLA.Defn.Lock, EXLA.Defn.LockedCache, - {Task.Supervisor, name: EXLA.Defn.TaskSupervisor} + {Task.Supervisor, name: EXLA.Defn.TaskSupervisor}, + {DynamicSupervisor, name: EXLA.CallbackServer.Supervisor, strategy: :one_for_one} ] Supervisor.start_link(children, name: __MODULE__, strategy: :one_for_one) diff --git a/exla/lib/exla/callback_server.ex 
b/exla/lib/exla/callback_server.ex new file mode 100644 index 0000000000..92c04ba3e9 --- /dev/null +++ b/exla/lib/exla/callback_server.ex @@ -0,0 +1,268 @@ +defmodule EXLA.CallbackServer do + @moduledoc """ + Dispatcher and registry for `Nx.elixir_call/3` callbacks used by EXLA. + + This server has two responsibilities: + + * Receive callback requests from the native EXLA bridge thread, execute + the Elixir function, validate the result against the expected output + template, and reply back to native through a NIF. + + The native side is expected to: + + * Lower `:elixir_call` nodes to a CPU-only host `CustomCall` named + `"exla_elixir_callback"` with a callback id encoded in its attributes. + + * Run a bridge thread that sends messages of the form: + + {:exla_elixir_call, callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()} + + to this process and waits on a native future associated with `reply_tag`. + + * Provide a NIF `EXLA.NIF.elixir_callback_reply/2` that completes the + native future when we send the reply back. + """ + + use GenServer + + require Logger + + defstruct callbacks: %{} + + @type t :: %__MODULE__{ + # We store the original function, its output template, and any + # static (non-tensor) arguments that should always be appended to + # the decoded tensor arguments coming from native. + callbacks: %{term() => {fun(), Nx.t() | tuple(), [term()]}} + } + + ## Public API + + @doc """ + Starts the callback server and registers it as the EXLA dispatcher process. + + The EXLA NIF is notified of the dispatcher PID so it can route + `:exla_elixir_call` messages to this process. + """ + def start_link(_init_arg) do + GenServer.start_link(__MODULE__, :ok) + end + + @doc """ + Registers a callback function, its output template, argument template, and options. + + The `id` is typically the underlying `Nx.Defn.Expr` id of the `:elixir_call` + node, which the EXLA compiler also encodes into the host `CustomCall` so the + native side can reference the right callback. 
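+
+  For illustration only, the EXLA compiler ends up registering a callback
+  roughly as below. The pid, id, and templates are hypothetical values; only
+  the argument order of `register/6` is meaningful here:
+
+      template = {Nx.template({3}, {:s, 32}), Nx.template({3}, {:s, 32})}
+
+      EXLA.CallbackServer.register(
+        callback_server_pid,
+        expr_id,
+        fn {a, b}, _opts -> {Nx.add(a, b), Nx.subtract(a, b)} end,
+        template,
+        template,
+        []
+      )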
+ """ + @spec register(pid(), term(), fun(), Nx.t() | tuple(), term(), [term()]) :: :ok + def register(callback_server_pid, id, fun, out_template, arg_template, static_arguments) + when is_function(fun) do + GenServer.call( + callback_server_pid, + {:register, id, fun, out_template, arg_template, static_arguments} + ) + end + + ## GenServer callbacks + + @impl true + def init(:ok) do + # Inform native side that this process is the dispatcher for elixir callbacks + _ = EXLA.NIF.start_elixir_callback_bridge(self()) + + {:ok, %__MODULE__{}} + end + + @impl true + def terminate(_reason, _state) do + try do + EXLA.NIF.clear_elixir_callback_bridge(self()) + rescue + _ -> :ok + end + end + + @impl true + def handle_call( + {:register, id, fun, out_template, arg_template, opts}, + _from, + %__MODULE__{} = state + ) do + state = put_in(state.callbacks[id], {fun, out_template, arg_template, opts}) + {:reply, :ok, state} + end + + @impl true + def handle_info({:exla_elixir_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do + reply_payload = + try do + case Map.fetch(state.callbacks, callback_id) do + {:ok, {fun, out_template, arg_template, opts}} -> + args_spec + |> decode_args(arg_template) + |> run_callback(fun, opts, out_template) + |> encode_reply() + + :error -> + Logger.error( + "EXLA.CallbackServer received callback id #{inspect(callback_id)} that is not registered" + ) + + encode_reply({:error, :unknown_callback}) + end + catch + kind, reason -> + formatted = Exception.format(kind, reason, __STACKTRACE__) + encode_reply({:error, {:runtime_error, "Elixir callback server crashed: #{formatted}"}}) + end + + send_reply(reply_tag, reply_payload) + {:noreply, state} + end + + def handle_info(:exla_elixir_call_executable_dropped, state) do + {:stop, :normal, state} + end + + def handle_info(other, state) do + Logger.debug("EXLA.CallbackServer ignoring unexpected message: #{inspect(other)}") + {:noreply, state} + end + + ## Internal helpers + + defp run_callback({:error, reason}, _fun, _opts, _out_template), do: {:error, reason} + + defp run_callback({:ok, tensor_args}, fun, opts, out_template) do + result = + try do + fun.(tensor_args, opts) + rescue + exception -> + {:error, {:exception, exception, __STACKTRACE__}} + catch + kind, reason -> + {:error, {kind, reason}} + end + + case result do + {:error, _} = error -> + error + + value -> + if Nx.compatible?(value, out_template) do + {:ok, value} + else + {:error, {:shape_mismatch, value, out_template}} + end + end + end + + defp decode_args(args_spec, arg_template) when is_list(args_spec) do + materialize_args(arg_template, args_spec) + catch + {:error, reason} -> + {:error, reason} + end + + defp decode_args(other, _arg_template), do: {:error, {:invalid_args_spec, other}} + + defp encode_reply({:ok, value}), do: {:ok, encode_outputs(value)} + + # Shape mismatch between callback result and output template. + defp encode_reply({:error, {:shape_mismatch, left, right}}) do + msg = + "expected the elixir_call function to match the given output template " <> + "#{inspect(right)}, got: #{inspect(left)}" + + {:error, {:argument_error, msg}} + end + + # Callback returned something that isn't a tensor/tuple matching the template. + defp encode_reply({:error, {:invalid_result, left, right}}) do + msg = + "expected the elixir_call function to return a value compatible with the output " <> + "template #{inspect(right)}, got: #{inspect(left)}" + + {:error, {:argument_error, msg}} + end + + # Argument decoding failures. 
+ defp encode_reply({:error, {:decode_failed, exception}}) do + msg = Exception.message(exception) + msg = "failed to decode Elixir callback arguments: #{msg}" + {:error, {:runtime_error, msg}} + end + + defp encode_reply({:error, {:invalid_args_spec, other}}) do + msg = "invalid args_spec for Elixir callback: #{inspect(other)}" + {:error, {:runtime_error, msg}} + end + + # Unknown callback id from native. + defp encode_reply({:error, :unknown_callback}) do + msg = "unknown EXLA elixir_call callback id" + {:error, {:runtime_error, msg}} + end + + # User-raised exceptions. + defp encode_reply({:error, {:exception, exception, _stack}}) do + msg = Exception.message(exception) + msg = "Elixir callback raised: #{msg}" + {:error, {:runtime_error, msg}} + end + + # Catches other error tuples (throws, exits, etc). + defp encode_reply({:error, {kind, reason}}) do + msg = "Elixir callback #{kind}: #{inspect(reason)}" + {:error, {:runtime_error, msg}} + end + + defp encode_reply({:error, reason}) do + msg = "Elixir callback error: #{inspect(reason)}" + {:error, {:runtime_error, msg}} + end + + defp materialize_args(arg_template, args_spec) do + {container, remaining} = + Nx.Defn.Composite.traverse(arg_template, args_spec, fn + %Nx.Tensor{} = template, [{bin, {type, shape_list}} | rest] -> + decoded = + bin + |> Nx.from_binary(type) + |> Nx.reshape(List.to_tuple(shape_list)) + + if Nx.compatible?(decoded, template) do + {decoded, rest} + else + throw({:error, {:shape_mismatch, decoded, template}}) + end + + other, acc -> + {other, acc} + end) + + case remaining do + [] -> {:ok, container} + _ -> {:error, {:invalid_args_spec, :extra_values}} + end + end + + defp encode_outputs(container) do + [container] + |> Nx.Defn.Composite.flatten_list() + |> Enum.map(&Nx.to_binary/1) + end + + defp send_reply(reply_tag, {status, result}) do + try do + EXLA.NIF.elixir_callback_reply(reply_tag, status, result) + rescue + _ -> + Logger.error( + "EXLA.CallbackServer failed to send reply to native for tag #{inspect(reply_tag)}" + ) + end + end +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 413c38ce45..1621a1051e 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -39,36 +39,59 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, compile_options} = Keyword.pop(options, :run_options, []) debug? = Keyword.get(compile_options, :debug, false) - callback = &to_computation(&1, &2, &3, &4, &5, compile_options) - {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = - compile(key, vars, fun, compile_options, 0, [], callback) + # We start the callback server regardless if it's needed + # as it's relatively cheap to start it. + callback_server_pid = + case DynamicSupervisor.start_child(EXLA.CallbackServer.Supervisor, {EXLA.CallbackServer, []}) do + {:ok, pid} -> pid + {:error, reason} -> raise "Failed to start EXLA.CallbackServer: #{inspect(reason)}" + end - if compile_options[:module_compilation] == :to_mlir do - throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) - end + try do + callback = &to_computation(&1, &2, &3, &4, &5, compile_options, callback_server_pid) - fn [args] -> - {time, lock} = - :timer.tc(fn -> - EXLA.Defn.Lock.lock(run_key(executable)) - end) + {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = + compile(key, vars, fun, compile_options, 0, [], callback, callback_server_pid) - debug? 
&& Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms") + if compile_options[:module_compilation] == :to_mlir do + throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) + end - {time, res} = - :timer.tc(fn -> - maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) - end) + fn [args] -> + {time, lock} = + :timer.tc(fn -> + EXLA.Defn.Lock.lock(run_key(executable)) + end) - debug? && - Logger.debug("EXLA execution on device #{executable.device_id} in #{us_to_ms(time)}ms") + debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms") - res + {time, res} = + :timer.tc(fn -> + maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) + end) + + debug? && + Logger.debug("EXLA execution on device #{executable.device_id} in #{us_to_ms(time)}ms") + + res + end + catch + kind, reason -> + DynamicSupervisor.terminate_child(EXLA.CallbackServer.Supervisor, callback_server_pid) + :erlang.raise(kind, reason, __STACKTRACE__) end end - defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do + defp to_computation( + %Function{} = function, + expr, + used_typespecs, + outfeed, + client, + options, + callback_server_pid + ) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> {pos, arg} @@ -83,7 +106,8 @@ defmodule EXLA.Defn do precision: Keyword.get(options, :precision, :default), builder: function, params: Map.new(params ++ outfeed.infeeds), - scope_ids: Tree.scope_ids(expr) + scope_ids: Tree.scope_ids(expr), + callback_server_pid: callback_server_pid } {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) @@ -138,7 +162,16 @@ defmodule EXLA.Defn do ## Compile - defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation) do + defp compile( + key, + vars, + fun, + options, + used_buffers, + used_inputs, + to_computation, + callback_server_pid + ) do {cache, options} = Keyword.pop(options, :cache, true) {hooks, options} = Keyword.pop(options, :hooks, %{}) {debug?, options} = Keyword.pop(options, :debug, false) @@ -235,6 +268,8 @@ defmodule EXLA.Defn do expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed, client) + options = Keyword.put(options, :callback_server_pid, callback_server_pid) + {xla_time, executable} = :timer.tc(fn -> EXLA.MLIR.Module.compile( @@ -546,6 +581,51 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator( + :elixir_call, + %T{data: %Expr{id: id, args: [tensor_expr, opts, fun, out_template]}} = expr, + %{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} = + state, + cache + ) do + # Flatten the tensor_or_container expression into its tensor leaves so we + # can compile each as an independent operand to the host callback. + tensor_exprs = Composite.flatten_list([tensor_expr]) + + {arg_values, cache} = + Enum.map_reduce(tensor_exprs, cache, fn arg, cache -> + recur_operator(arg, state, cache) |> unwrap_single_tensor!() + end) + + # Build a template container for the tensor_or_container argument so the + # callback server can reconstruct the full structure from a flat list of + # decoded tensors. 
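+    # For example, when tensor_expr is the tuple {x, y} of two f32[3] tensors,
+    # arg_template is a tuple of two templates and the callback server zips the
+    # flat binaries it receives back into that tuple shape.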
+ arg_template = Nx.to_template(tensor_expr) + + :ok = + EXLA.CallbackServer.register(callback_server_pid, id, fun, out_template, arg_template, opts) + + typespecs = container_to_typespecs(out_template) + + results = + Value.elixir_call(arg_values, typespecs, callback_server_pid, id) + + {wrap_tuple_result(results, expr), cache} + end + + defp cached_recur_operator( + :elixir_call, + _expr, + %{client: %EXLA.Client{platform: platform}}, + _cache + ) do + raise """ + Nx.elixir_call/3 is currently only supported for EXLA CPU (platform: :host), + but the active EXLA client is configured for platform #{inspect(platform)}. + Please run on the :host client or wait for future segmentation-based support. + """ + end + defp cached_recur_operator( :lu, %T{ diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 04f1c38a3c..d1ba3d0b0b 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -92,6 +92,7 @@ defmodule EXLA.MLIR.Module do ) do num_replicas = Keyword.get(options, :num_replicas, 1) num_partitions = Keyword.get(options, :num_partitions, 1) + callback_server_pid = Keyword.get(options, :callback_server_pid, nil) # JAX comments say SPMD can lead to subtle bugs so they only enable # when strictly necessary, which is when num_partitions is greater than 1. @@ -118,7 +119,8 @@ defmodule EXLA.MLIR.Module do num_replicas, num_partitions, use_spmd, - device_id + device_id, + callback_server_pid ) end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index f955e67200..5190d83f52 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -832,6 +832,71 @@ defmodule EXLA.MLIR.Value do {p, l, u} end + @doc """ + Builds a StableHLO `custom_call` that targets the EXLA Elixir callback bridge. + + The `callback_id` is typically the underlying `Nx.Defn.Expr` id of the + `:elixir_call` node. It is encoded as a binary (via `:erlang.term_to_binary/1`) + and then represented as a list of 64-bit words in the custom call attributes, + similar to how we encode the callback server PID. + """ + def elixir_call( + [%Value{function: func} | _] = operands, + typespecs, + callback_server_pid, + callback_id + ) do + result_types = typespecs_to_mlir_types(typespecs) + + {callback_server_pid_words, callback_server_pid_size} = + term_to_int64_list(callback_server_pid) + + {callback_id_words, callback_id_size} = + term_to_int64_list(callback_id) + + attributes = [ + call_target_name: attr_string("exla_elixir_callback"), + # api_version 4 enables the typed FFI API used by our callback handler. + api_version: attr_i32(4), + backend_config: + attr_dict( + callback_id: attr_array_i64_elements(callback_id_words), + callback_id_size: attr_ui64(callback_id_size), + callback_server_pid: attr_array_i64_elements(callback_server_pid_words), + callback_server_pid_size: attr_ui64(callback_server_pid_size) + ) + ] + + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + end + + defp term_to_int64_list(term) do + bin = :erlang.term_to_binary(term) + size = byte_size(bin) + + # Zero-pad the binary so its size is a multiple of 8 and it can be + # represented as a list of 64-bit words. 
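+    # For example, a 12-byte term is padded to 16 bytes and encoded as two
+    # 64-bit words, while `size` stays 12 so the native side can trim the
+    # padding before calling enif_binary_to_term.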
+ pad = + case rem(size, 8) do + 0 -> 0 + r -> 8 - r + end + + padded_bin = + if pad == 0 do + bin + else + bin <> :binary.copy(<<0>>, pad) + end + + words = + for <> do + x + end + + {words, size} + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] @@ -994,6 +1059,7 @@ defmodule EXLA.MLIR.Value do defp attr_i32(number), do: "#{number} : i32" defp attr_i64(number), do: "#{number} : i64" + defp attr_ui64(number), do: "#{number} : ui64" defp attr_padding(padding) do list = Enum.flat_map(padding, &Tuple.to_list/1) @@ -1025,6 +1091,11 @@ defmodule EXLA.MLIR.Value do "##{name}<#{content}>" end + defp attr_dict(keyword_list) do + content = Enum.map_join(keyword_list, ", ", fn {key, value} -> "#{key} = #{value}" end) + "{#{content}}" + end + defp join_list(list) do "[" <> Enum.join(list, ", ") <> "]" end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 203dc30fd8..83a672d4c8 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -37,7 +37,8 @@ defmodule EXLA.NIF do _num_replicas, _num_partitions, _use_spmd, - _device_id + _device_id, + _callback_server_pid ), do: err!() @@ -78,5 +79,10 @@ defmodule EXLA.NIF do def reset_peak_memory(_client), do: err!() def get_per_device_memory(_client), do: err!() + # Elixir callback bridge + def start_elixir_callback_bridge(_dispatcher_pid), do: err!() + def clear_elixir_callback_bridge(_dispatcher_pid), do: err!() + def elixir_callback_reply(_reply_tag, _status, _result), do: err!() + defp err!(), do: :erlang.nif_error(:undef) end diff --git a/exla/mix.exs b/exla/mix.exs index 5818232299..bfc7ed77ad 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -69,7 +69,7 @@ defmodule EXLA.MixProject do {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, {:xla, "~> 0.9.0", runtime: false}, - {:fine, "~> 0.1.0", runtime: false}, + {:fine, "~> 0.1", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, {:benchee, "~> 1.0", only: :dev}, {:ex_doc, "~> 0.29", only: :docs}, diff --git a/exla/mix.lock b/exla/mix.lock index 4508f2a57c..2b18c93da9 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -5,7 +5,7 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "fine": {:hex, :fine, "0.1.0", "9bb99a5ff9b968f12c3b458fa1277c39e9a620b23a9439103703a25917293871", [:mix], [], "hexpm", "1d6485bf811b95dc6ae3d197c0e6f994880b86167a827983bb29cbfc03a02684"}, + "fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", 
"be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs new file mode 100644 index 0000000000..dc1f2f7e14 --- /dev/null +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -0,0 +1,166 @@ +defmodule EXLA.Defn.ElixirCallTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.default_backend(EXLA.Backend) + Nx.Defn.default_options(compiler: EXLA) + :ok + end + + defp add_offset_callback(t, opts) do + t + |> Nx.as_type(:f32) + |> Nx.add(opts[:offset]) + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, x, [offset: 10.0], &add_offset_callback/2) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, fx, fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end + + defn bad_callback(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, x, fn _t -> + # Wrong shape on purpose + Nx.tensor([1.0, 2.0, 3.0]) + end) + end + + test "elixir_call errors when result shape does not match template" do + x = Nx.iota({2}) + + assert_raise RuntimeError, + ~r/expected the elixir_call function to match the given output template/, + fn -> + bad_callback(x) + end + end + + test "works when using EXLA compiler directly" do + x = Nx.tensor([1, 2, 3]) + y = EXLA.jit_apply(&split_and_sum/1, [x]) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end + + def add_and_subtract_callback({x, y}) do + {Nx.add(x, y), Nx.subtract(x, y)} + end + + defn add_and_subtract(x, y) do + Nx.elixir_call({x, x}, {x, y}, &add_and_subtract_callback/1) + end + + test "elixir_call with tuple input" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + assert {add, sub} = add_and_subtract(x, y) + + assert_equal(add, Nx.add(x, y)) + assert_equal(sub, Nx.subtract(x, y)) + end + + def add_and_subtract_with_opts_callback({x, y}, {ref, pid}) do + send(pid, {:add_and_subtract_with_opts, ref}) + {Nx.add(x, y), Nx.subtract(x, y)} + end + + defn 
add_and_subtract_with_opts(x, y, opts) do + Nx.elixir_call( + {x, x}, + {x, y}, + {opts[:ref], opts[:pid]}, + &add_and_subtract_with_opts_callback/2 + ) + end + + test "elixir_call with non-list second argument" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + ref = make_ref() + + assert {add, sub} = add_and_subtract_with_opts(x, y, ref: ref, pid: self()) + + assert_equal(add, Nx.add(x, y)) + assert_equal(sub, Nx.subtract(x, y)) + + assert_receive {:add_and_subtract_with_opts, ^ref} + end + + defn return_as_container(x, y, template_fun, container_fun) do + Nx.elixir_call(template_fun.(x, y), {x, y}, container_fun) + end + + test "elixir_call with container output" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + + ref = make_ref() + pid = self() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + {x, y} + end + + template_fun = fn x, y -> {x, y} end + + assert {x_res, y_res} = return_as_container(x, y, template_fun, container_fun) + assert_equal(x_res, x) + assert_equal(y_res, y) + assert_receive {:container_fun, ^ref} + + ref = make_ref() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + %{x: x, y: y} + end + + template_fun = fn x, y -> %{x: x, y: y} end + + assert result = return_as_container(x, y, template_fun, container_fun) + assert %{x: _, y: _} = result + assert_equal(result.x, x) + assert_equal(result.y, y) + assert_receive {:container_fun, ^ref} + end +end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 091372d005..1c8f38bfcf 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2196,6 +2196,89 @@ defmodule Nx do list end + @doc """ + Invokes an Elixir function from within `defn`. + + This function allows integrating arbitrary Elixir code into `defn` graphs. + It receives an output template (a tensor or a tuple of tensors) that + specifies the expected shapes, types, and names of the result, a tensor + or tensor container argument, and an optional static argument, and the function + itself. + + The `static_argument` will be passed through the Elixir processes to the callback function + along with the executable Nx code. + + Tensors passed to the callback function are in the same backend as the inputs in the case + of `Nx.Defn.Evaluator` invocations. For other compilers, it is generally expected that + the tensors will be provided as `Nx.BinaryBackend` tensors. + + ## Examples + + While most code inside `defn` is restricted, `elixir_call/4` allows you + to perform arbitrary Elixir operations, such as message passing: + + iex> pid = self() + iex> x = Nx.tensor([1, 2, 3]) + iex> out = Nx.template({3}, {:s, 32}) + iex> _ = + ...> Nx.elixir_call(out, x, fn t -> + ...> send(pid, {:sum, Enum.sum(Nx.to_flat_list(t))}) + ...> t + ...> end) + iex> receive do {:sum, value} -> value end + 6 + + You can also use the `static_argument` to pass non-tensor metadata to + your callback while still validating the tensor result against a template: + + iex> pid = self() + iex> x = Nx.tensor([1, 2, 3]) + iex> y = Nx.tensor([4, 5, 6]) + iex> out = %{x: x, y: y} + iex> _ = + ...> Nx.elixir_call(out, {x, y}, [pid: pid], fn {a, b}, opts -> + ...> send(opts[:pid], {:dot, Nx.to_number(Nx.dot(a, b))}) + ...> %{x: a, y: b} + ...> end) + iex> receive do {:dot, value} -> value end + 32 + + Inside `defn`, this builds an expression node understood by compilers. + Outside `defn` or on backends without special support, it executes `fun` + directly and validates the result matches the template. 
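+
+  Inside `defn` the call shape is the same. A sketch, assuming an
+  `add_offset_callback/2` helper defined outside `defn`:
+
+      defn add_offset(x) do
+        out = %{x | type: Nx.Type.to_floating(x.type)}
+        Nx.elixir_call(out, x, [offset: 10.0], &add_offset_callback/2)
+      end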
+ """ + @doc type: :backend + def elixir_call(output, tensor_or_container, fun) when is_function(fun, 1) do + elixir_call(output, tensor_or_container, [], fn value, _opts -> fun.(value) end) + end + + def elixir_call(output, tensor_or_container, static_argument, fun) + when is_function(fun, 2) do + # Outside defn, we execute the callback directly or via the backend if it + # provides a specialized implementation. We resolve the backend from all + # tensors inside the container to support tuple/map containers. + tensors = Nx.Defn.Composite.flatten_list([tensor_or_container]) + backend = Nx.Shared.list_impl!(tensors) + + result = + if backend == Nx.Defn.Expr do + backend.elixir_call(output, tensor_or_container, static_argument, fun) + else + fun.(tensor_or_container, static_argument) + end + + ensure_call_compatible!(result, output) + end + + defp ensure_call_compatible!(left, right) do + if Nx.compatible?(left, right) do + left + else + raise ArgumentError, + "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + end + end + defp chunk([], data, type) do match_types [type] do <> = data diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 3c463ba237..650b5ab0d8 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -141,7 +141,6 @@ defmodule Nx.Backend do fallback to the default implementation. """ @callback optional(atom, [term], fun) :: tensor - @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 936ec1f4b0..2eddc8918d 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,6 +175,15 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end + defp compute_cache( + :elixir_call, + %{data: %Expr{args: [tensor_expr, _opts, _fun, _out]}}, + state, + cache + ) do + composite_compute_cache(tensor_expr, state, cache) + end + defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do %{parent_ids: parent_ids, current_ids: current_ids} = state @@ -431,6 +440,17 @@ defmodule Nx.Defn.Evaluator do end end + defp eval_apply( + :elixir_call, + %{data: %Expr{args: [tensor_expr, static_argument, fun, out_template]}}, + state, + caches + ) do + {tensor_value, caches} = composite_eval(tensor_expr, state, caches) + result = fun.(tensor_value, static_argument) + {reshape_elixir_call_result(result, out_template), caches} + end + defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do raise "unexpected vectorized axes in evaluator for operation #{inspect(op)}: #{inspect(ans)}" end @@ -466,6 +486,27 @@ defmodule Nx.Defn.Evaluator do {value, [cache | caches]} end + defp reshape_elixir_call_result(result, %Nx.Tensor{} = template) do + # Single-tensor output: just ensure compatibility with the template. + if not Nx.compatible?(template, result) do + raise "expected the elixir_call function to match the given output template" + end + + result + end + + defp reshape_elixir_call_result(result, template_container) do + # Container (tuple/map/etc) output: we expect the callback to return + # a container with the same flattened tensor leaves as the template. 
+ if not Nx.compatible?(result, template_container) do + raise "expected the elixir_call function to match the given output template" + end + + result_leaves = Composite.flatten_list([result]) + + List.to_tuple(result_leaves) + end + ## Control flow helpers defp while(acc, condition, block, state, caches) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 782e4a07fd..599e4a56a6 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -41,6 +41,8 @@ defmodule Nx.Defn.Expr do * `attach_token(token(%Nx.Defn.Token{}), expr)` + * `elixir_call(out, tensor_or_container, opts, fun)` + `defn` compilers must handle said nodes accordingly. """ @@ -1393,6 +1395,68 @@ defmodule Nx.Defn.Expr do context || acc end + @doc """ + Helper for defining an :elixir_call expression node. + """ + def elixir_call(out, tensor_or_container, static_argument, fun) when is_function(fun, 2) do + # Convert the entire tensor_or_container into an expression container, + # preserving its structure but ensuring all tensors are Expr-backed. + tensor_expr = + Composite.traverse(tensor_or_container, fn + %T{} = t -> to_expr(t) + other -> other + end) + + # Grab context from the first tensor in the flattened container. + [%T{data: %Expr{context: context}} | _] = + Composite.flatten_list([tensor_expr]) + + case out do + t when is_struct(t, Nx.Tensor) -> + out_template = Nx.to_template(t) + expr(t, context, :elixir_call, [tensor_expr, static_argument, fun, out_template]) + + tuple when is_tuple(tuple) -> + out_template = tuple_out(tuple_size(tuple)) + user_template = Nx.to_template(tuple) + + expr_node = + expr(out_template, context, :elixir_call, [ + tensor_expr, + static_argument, + fun, + user_template + ]) + + tuple(expr_node, Tuple.to_list(tuple)) + + container -> + user_template = Nx.to_template(container) + + leaf_templates = Composite.flatten_list([user_template]) + leaf_count = length(leaf_templates) + + root = + expr( + tuple_out(leaf_count), + context, + :elixir_call, + [tensor_expr, static_argument, fun, user_template] + ) + + {container_expr, _} = + Composite.traverse(user_template, {0, root}, fn + %T{} = template, {i, root} -> + {expr(template, context, :elem, [root, i]), {i + 1, root}} + + other, acc -> + {other, acc} + end) + + container_expr + end + end + ## Constant helpers and related optimizations defp constant(%{vectorized_axes: [_ | _]} = out, number) do diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index b3ffb0ab91..f94e9da0d4 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do acc end + defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do + acc + end + defp parents_args( :optional, %{data: %{args: [call, _expr, callback]}} = t, diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 582b9d4689..ba4e6e1812 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,6 +192,14 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do + [tensor_expr, callback_opts, callback, out_template] = args + + {tensor_expr, acc} = Composite.traverse(tensor_expr, acc, fun) + + {[tensor_expr, callback_opts, callback, out_template], acc} + end + def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do {hooks, acc} = Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc -> diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs 
b/nx/test/nx/defn/elixir_call_evaluator_test.exs new file mode 100644 index 0000000000..a26408acce --- /dev/null +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -0,0 +1,88 @@ +defmodule Nx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, x, [offset: 10.0], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert Nx.all_close(y, expected) |> Nx.to_number() == 1 + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, fx, fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert expected == y + end + + defn return_as_container(x, y, template_fun, container_fun) do + Nx.elixir_call(template_fun.(x, y), {x, y}, container_fun) + end + + test "elixir_call with container output" do + x = Nx.tensor([1, 2, 3]) + y = Nx.tensor([4, 5, 6]) + + ref = make_ref() + pid = self() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + {x, y} + end + + template_fun = fn x, y -> {x, y} end + + assert {x_res, y_res} = return_as_container(x, y, template_fun, container_fun) + assert x_res == x + assert y_res == y + assert_receive {:container_fun, ^ref} + + ref = make_ref() + + container_fun = fn {x, y} -> + send(pid, {:container_fun, ref}) + %{x: x, y: {%{key: y}, Nx.s32(1)}} + end + + template_fun = fn x, y -> %{x: x, y: {%{key: y}, Nx.s32(1)}} end + + assert result = return_as_container(x, y, template_fun, container_fun) + assert %{x: _, y: {%{key: _}, _}} = result + assert result.x == x + assert result.y == {%{key: y}, Nx.s32(1)} + assert_receive {:container_fun, ^ref} + end +end