Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
d72680a
feat: add initial draft
polvalente Aug 8, 2025
da7d7e4
evaluator mode working
polvalente Aug 8, 2025
fc9c28c
test: add tests
polvalente Aug 9, 2025
25300b7
fix grad
polvalente Aug 9, 2025
a91c42b
Merge remote-tracking branch 'origin/main' into pv-feat/elixir-call
polvalente Aug 10, 2025
aa35431
feat(exla): initial Nx.elixir_call/3 CPU wiring
polvalente Nov 22, 2025
37a15af
feat: seemingly working mvp
polvalente Nov 23, 2025
7127a8c
feat: first working version
polvalente Nov 24, 2025
bc52205
wip: step through code review
polvalente Nov 25, 2025
95f7860
finish changes code review
polvalente Nov 25, 2025
2a1e627
Merge branch 'main' into pv-feat/elixir-call
polvalente Nov 25, 2025
c9aa9bd
chore: remove unused files
polvalente Nov 25, 2025
c36a4c6
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente Nov 25, 2025
c7c4871
docs: document the lock issue
polvalente Nov 25, 2025
ddb3733
refactor: ElixirCallbackPending as a resource
polvalente Nov 26, 2025
bdae93c
refactor: improve result decoding and reduce data copies
polvalente Nov 26, 2025
cb8c345
refactor: implement fine encoder for ffi types
polvalente Nov 26, 2025
a79b037
docs: update docs
polvalente Nov 26, 2025
658b8bf
chore: use exla nif atoms
polvalente Nov 26, 2025
6d4d327
refactor: reorganize files and namespace things
polvalente Nov 26, 2025
81ac875
kill handle
polvalente Nov 26, 2025
4149353
refactor: use erlnifbinary
polvalente Nov 28, 2025
9a60a26
refactor: leverage fine encoding
polvalente Nov 28, 2025
da72cdd
chore: remove identity calls
polvalente Nov 28, 2025
38b35f4
fix: use error atom
polvalente Nov 28, 2025
e6ecfca
refactor: do not use dynamic arities
polvalente Nov 28, 2025
e254389
chore: changes due to code review
polvalente Nov 28, 2025
4853248
fix: use Nx.compatible
polvalente Nov 28, 2025
6961bab
refactor: allow any type in static argument
polvalente Nov 28, 2025
3731c97
chore: revert torchx
polvalente Nov 28, 2025
ea141a3
fix: handle containers
polvalente Nov 28, 2025
045281c
defend against exceptions
polvalente Nov 28, 2025
696dd5f
refactor: pass callback id and pid as attributes
polvalente Nov 29, 2025
66588a8
feat: gc process when Exla Executable is gc'd
polvalente Nov 29, 2025
4aa2f66
Update exla/lib/exla/callback_server.ex
polvalente Nov 29, 2025
741a6b5
single pass materialize
polvalente Nov 29, 2025
c6d2720
chore: revert container_to_typespecs
polvalente Nov 29, 2025
c5a2cec
fix: do not leak callback servers on error
polvalente Nov 29, 2025
d20e83f
refactor: skip supervisor module
polvalente Nov 29, 2025
66f62cc
Update nx/lib/nx.ex
polvalente Nov 29, 2025
35eda2d
fix: proper container support
polvalente Nov 29, 2025
40fcea4
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente Nov 29, 2025
5df4c5d
fix: values cannot be expr in defn devaluator
polvalente Nov 29, 2025
ebee654
docs: add examples
polvalente Nov 29, 2025
37fe277
Update exla/lib/exla/defn.ex
polvalente Nov 29, 2025
a2eaaf6
refactor: register callbacks based on expression id
polvalente Nov 29, 2025
af79101
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente Nov 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/elixir_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


Expand Down
107 changes: 107 additions & 0 deletions exla/c_src/exla/custom_calls/elixir_callback.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#include "elixir_callback_bridge.h"

#include <cstring>
#include <vector>
#include <string>

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

namespace {

ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args,
ffi::RemainingRets rets) {
if (args.size() == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"exla_elixir_callback expects at least one argument");
}

// The first argument is a scalar S64 tensor carrying the callback id.
auto id_buf_or = args.get<ffi::AnyBuffer>(0);
if (!id_buf_or) {
return id_buf_or.error();
}

ffi::AnyBuffer id_buf = *id_buf_or;

if (id_buf.element_count() != 1 ||
id_buf.element_type() != ffi::DataType::S64) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"exla_elixir_callback callback id must be scalar s64");
}

int64_t callback_id = id_buf.reinterpret_data<int64_t>()[0];

// Collect all remaining input tensors (excluding callback id) into
// lightweight payload views.
std::vector<exla::callback_bridge::Arg> inputs;
inputs.reserve(args.size() - 1);

for (size_t i = 1; i < args.size(); ++i) {
auto maybe_buf_or = args.get<ffi::AnyBuffer>(i);
if (!maybe_buf_or) {
return maybe_buf_or.error();
}

ffi::AnyBuffer buf = *maybe_buf_or;

exla::callback_bridge::Arg tensor;
tensor.dtype = buf.element_type();

auto dims = buf.dimensions();
tensor.dims.assign(dims.begin(), dims.end());

tensor.data = reinterpret_cast<const uint8_t *>(buf.untyped_data());
tensor.size_bytes = buf.size_bytes();

inputs.push_back(std::move(tensor));
}

// Prepare output buffer descriptors so the callback bridge can write results
// directly into the final destination buffers.
std::vector<exla::callback_bridge::OutputBuffer> outputs;
outputs.reserve(rets.size());

for (size_t i = 0; i < rets.size(); ++i) {
auto maybe_ret_or = rets.get<ffi::AnyBuffer>(i);
if (!maybe_ret_or) {
return maybe_ret_or.error();
}

ffi::Result<ffi::AnyBuffer> ret = *maybe_ret_or;
ffi::AnyBuffer out = *ret;

exla::callback_bridge::OutputBuffer buf;
buf.data = static_cast<uint8_t *>(out.untyped_data());
buf.size = ffi::ByteWidth(out.element_type()) *
static_cast<size_t>(out.element_count());

outputs.push_back(buf);
}

// Call back into Elixir through the bridge. On success, the bridge writes
// results directly into the provided output buffers.
exla::callback_bridge::Result result =
exla::callback_bridge::InvokeElixirCallback(callback_id, inputs, outputs);

if (!result.ok) {
return ffi::Error(ffi::ErrorCode::kInternal, result.error);
}

return ffi::Error::Success();
}

} // namespace

XLA_FFI_DEFINE_HANDLER_SYMBOL(
exla_elixir_callback, exla_elixir_callback_impl,
ffi::Ffi::Bind()
.RemainingArgs()
.RemainingRets());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host",
exla_elixir_callback);


180 changes: 180 additions & 0 deletions exla/c_src/exla/custom_calls/elixir_callback_bridge.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#include "elixir_callback_bridge.h"

#include <cstring>

namespace exla {

namespace callback_bridge {

struct BridgeState {
ErlNifPid dispatcher_pid;
bool dispatcher_set = false;
};

BridgeState *GetBridgeState() {
static BridgeState *state = new BridgeState();
return state;
}

fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env,
ErlNifPid dispatcher_pid) {
(void)env;
auto state = GetBridgeState();
state->dispatcher_pid = dispatcher_pid;
state->dispatcher_set = true;
return fine::Ok();
}

fine::Ok<> elixir_callback_reply(ErlNifEnv *env,
fine::ResourcePtr<Pending> pending,
fine::Atom status, fine::Term result) {
deliver_reply(env, pending, status, result);
return fine::Ok();
}

fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env,
ErlNifPid dispatcher_pid) {
(void)env;
auto state = GetBridgeState();

if (state->dispatcher_set &&
std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) ==
0) {
state->dispatcher_set = false;
}

return fine::Ok();
}

void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending,
fine::Atom status, fine::Term result_term) {
Result cb_result;

if (status == "ok") {
// Successful reply: result_term is a list of binaries that we decode into
// raw byte vectors via Fine and copy directly into the registered output
// buffers.
try {
auto payloads =
fine::decode<std::vector<std::vector<uint8_t>>>(env, result_term);

std::lock_guard<std::mutex> lock(pending->mu);

if (payloads.size() != pending->outputs.size()) {
cb_result.ok = false;
cb_result.error =
"mismatched number of callback outputs vs registered buffers";
} else {
cb_result.ok = true;

for (size_t i = 0; i < payloads.size(); ++i) {
const auto &bytes = payloads[i];
auto &out_buf = pending->outputs[i];

if (bytes.size() != out_buf.size) {
cb_result.ok = false;
cb_result.error =
"callback returned binary of unexpected size for result buffer";
break;
}

if (out_buf.size > 0) {
std::memcpy(out_buf.data, bytes.data(), out_buf.size);
}
}
}
} catch (const std::exception &e) {
cb_result.ok = false;
cb_result.error =
std::string("failed to decode Elixir callback outputs: ") + e.what();
}
} else {
// Error reply: result_term is expected to be {kind_atom, message :: binary}
cb_result.ok = false;

try {
auto decoded =
fine::decode<std::tuple<fine::Atom, ErlNifBinary>>(env, result_term);
ErlNifBinary msg_bin = std::get<1>(decoded);
cb_result.error.assign(reinterpret_cast<const char *>(msg_bin.data),
msg_bin.size);
} catch (const std::exception &) {
cb_result.error = "elixir callback returned error";
}
}

{
std::lock_guard<std::mutex> lock(pending->mu);
pending->result = std::move(cb_result);
pending->done = true;
}

pending->cv.notify_one();
}

Result InvokeElixirCallback(int64_t callback_id, const std::vector<Arg> &inputs,
const std::vector<OutputBuffer> &outputs) {
auto state = GetBridgeState();

if (!state->dispatcher_set) {
Result res;
res.ok = false;
res.error = "EXLA elixir callback dispatcher is not set";
return res;
}

auto pending = fine::make_resource<Pending>(outputs);

ErlNifEnv *msg_env = enif_alloc_env();

// Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send
// plain binaries because the BEAM callback needs to own the data lifetime.
std::vector<ERL_NIF_TERM> args_terms;
args_terms.reserve(inputs.size());

for (const auto &tensor : inputs) {
ERL_NIF_TERM bin_term;
unsigned char *bin_data =
enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term);
if (tensor.size_bytes > 0) {
memcpy(bin_data, tensor.data, tensor.size_bytes);
}

// Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via
// Fine's encoder defined in exla_nif_util.h.
ERL_NIF_TERM typespec_term =
fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims));

ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term);

args_terms.push_back(arg_tuple);
}

ERL_NIF_TERM args_list =
enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size());

ERL_NIF_TERM pending_term = fine::encode(msg_env, pending);
ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id);

ERL_NIF_TERM msg =
enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"),
cb_term, args_list, pending_term);

// Use the dispatcher pid registered via start_elixir_callback_bridge/1.
// Calling enif_whereis_pid from this non-scheduler thread is unsafe and
// was causing a segfault.
ErlNifPid dispatcher_pid = state->dispatcher_pid;
enif_send(msg_env, &dispatcher_pid, msg_env, msg);
enif_free_env(msg_env);

std::unique_lock<std::mutex> lock(pending->mu);
pending->cv.wait(lock, [&pending] { return pending->done; });

return pending->result;
}

} // namespace callback_bridge

} // namespace exla


Loading