-
Notifications
You must be signed in to change notification settings - Fork 212
feat: Nx.elixir_call/3 #1627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+1,488
−42
Merged
feat: Nx.elixir_call/3 #1627
Changes from 21 commits
Commits
Show all changes
47 commits
Select commit
Hold shift + click to select a range
d72680a
feat: add initial draft
polvalente da7d7e4
evaluator mode working
polvalente fc9c28c
test: add tests
polvalente 25300b7
fix grad
polvalente a91c42b
Merge remote-tracking branch 'origin/main' into pv-feat/elixir-call
polvalente aa35431
feat(exla): initial Nx.elixir_call/3 CPU wiring
polvalente 37a15af
feat: seemingly working mvp
polvalente 7127a8c
feat: first working version
polvalente bc52205
wip: step through code review
polvalente 95f7860
finish changes code review
polvalente 2a1e627
Merge branch 'main' into pv-feat/elixir-call
polvalente c9aa9bd
chore: remove unused files
polvalente c36a4c6
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente c7c4871
docs: document the lock issue
polvalente ddb3733
refactor: ElixirCallbackPending as a resource
polvalente bdae93c
refactor: improve result decoding and reduce data copies
polvalente cb8c345
refactor: implement fine encoder for ffi types
polvalente a79b037
docs: update docs
polvalente 658b8bf
chore: use exla nif atoms
polvalente 6d4d327
refactor: reorganize files and namespace things
polvalente 81ac875
kill handle
polvalente 4149353
refactor: use erlnifbinary
polvalente 9a60a26
refactor: leverage fine encoding
polvalente da72cdd
chore: remove identity calls
polvalente 38b35f4
fix: use error atom
polvalente e6ecfca
refactor: do not use dynamic arities
polvalente e254389
chore: changes due to code review
polvalente 4853248
fix: use Nx.compatible
polvalente 6961bab
refactor: allow any type in static argument
polvalente 3731c97
chore: revert torchx
polvalente ea141a3
fix: handle containers
polvalente 045281c
defend against exceptions
polvalente 696dd5f
refactor: pass callback id and pid as attributes
polvalente 66588a8
feat: gc process when Exla Executable is gc'd
polvalente 4aa2f66
Update exla/lib/exla/callback_server.ex
polvalente 741a6b5
single pass materialize
polvalente c6d2720
chore: revert container_to_typespecs
polvalente c5a2cec
fix: do not leak callback servers on error
polvalente d20e83f
refactor: skip supervisor module
polvalente 66f62cc
Update nx/lib/nx.ex
polvalente 35eda2d
fix: proper container support
polvalente 40fcea4
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente 5df4c5d
fix: values cannot be expr in defn evaluator
polvalente ebee654
docs: add examples
polvalente 37fe277
Update exla/lib/exla/defn.ex
polvalente a2eaaf6
refactor: register callbacks based on expression id
polvalente af79101
Merge branch 'pv-feat/elixir-call' of github.com:elixir-nx/nx into pv…
polvalente File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| #include "elixir_callback_bridge.h" | ||
|
|
||
| #include <cstring> | ||
| #include <vector> | ||
| #include <string> | ||
|
|
||
| #include "xla/ffi/api/ffi.h" | ||
| #include "xla/ffi/ffi_api.h" | ||
|
|
||
| namespace ffi = xla::ffi; | ||
|
|
||
| namespace { | ||
|
|
||
| ffi::Error exla_elixir_callback_impl(ffi::RemainingArgs args, | ||
| ffi::RemainingRets rets) { | ||
| if (args.size() == 0) { | ||
| return ffi::Error(ffi::ErrorCode::kInvalidArgument, | ||
| "exla_elixir_callback expects at least one argument"); | ||
| } | ||
|
|
||
| // The first argument is a scalar S64 tensor carrying the callback id. | ||
| auto id_buf_or = args.get<ffi::AnyBuffer>(0); | ||
| if (!id_buf_or) { | ||
| return id_buf_or.error(); | ||
| } | ||
|
|
||
| ffi::AnyBuffer id_buf = *id_buf_or; | ||
|
|
||
| if (id_buf.element_count() != 1 || | ||
| id_buf.element_type() != ffi::DataType::S64) { | ||
| return ffi::Error(ffi::ErrorCode::kInvalidArgument, | ||
| "exla_elixir_callback callback id must be scalar s64"); | ||
| } | ||
|
|
||
| int64_t callback_id = id_buf.reinterpret_data<int64_t>()[0]; | ||
|
|
||
| // Collect all remaining input tensors (excluding callback id) into | ||
| // lightweight payload views. | ||
| std::vector<exla::callback_bridge::Arg> inputs; | ||
| inputs.reserve(args.size() - 1); | ||
|
|
||
| for (size_t i = 1; i < args.size(); ++i) { | ||
| auto maybe_buf_or = args.get<ffi::AnyBuffer>(i); | ||
| if (!maybe_buf_or) { | ||
| return maybe_buf_or.error(); | ||
| } | ||
|
|
||
| ffi::AnyBuffer buf = *maybe_buf_or; | ||
|
|
||
| exla::callback_bridge::Arg tensor; | ||
| tensor.dtype = buf.element_type(); | ||
|
|
||
| auto dims = buf.dimensions(); | ||
| tensor.dims.assign(dims.begin(), dims.end()); | ||
|
|
||
| tensor.data = reinterpret_cast<const uint8_t *>(buf.untyped_data()); | ||
| tensor.size_bytes = buf.size_bytes(); | ||
|
|
||
| inputs.push_back(std::move(tensor)); | ||
| } | ||
|
|
||
| // Prepare output buffer descriptors so the callback bridge can write results | ||
| // directly into the final destination buffers. | ||
| std::vector<exla::callback_bridge::OutputBuffer> outputs; | ||
| outputs.reserve(rets.size()); | ||
|
|
||
| for (size_t i = 0; i < rets.size(); ++i) { | ||
| auto maybe_ret_or = rets.get<ffi::AnyBuffer>(i); | ||
| if (!maybe_ret_or) { | ||
| return maybe_ret_or.error(); | ||
| } | ||
|
|
||
| ffi::Result<ffi::AnyBuffer> ret = *maybe_ret_or; | ||
| ffi::AnyBuffer out = *ret; | ||
|
|
||
| exla::callback_bridge::OutputBuffer buf; | ||
| buf.data = static_cast<uint8_t *>(out.untyped_data()); | ||
| buf.size = ffi::ByteWidth(out.element_type()) * | ||
| static_cast<size_t>(out.element_count()); | ||
|
|
||
| outputs.push_back(buf); | ||
| } | ||
|
|
||
| // Call back into Elixir through the bridge. On success, the bridge writes | ||
| // results directly into the provided output buffers. | ||
| exla::callback_bridge::Result result = | ||
| exla::callback_bridge::InvokeElixirCallback(callback_id, inputs, outputs); | ||
|
|
||
| if (!result.ok) { | ||
| return ffi::Error(ffi::ErrorCode::kInternal, result.error); | ||
| } | ||
|
|
||
| return ffi::Error::Success(); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| XLA_FFI_DEFINE_HANDLER_SYMBOL( | ||
| exla_elixir_callback, exla_elixir_callback_impl, | ||
| ffi::Ffi::Bind() | ||
| .RemainingArgs() | ||
| .RemainingRets()); | ||
|
|
||
| XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host", | ||
| exla_elixir_callback); | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| #include "elixir_callback_bridge.h" | ||
|
|
||
| #include <cstring> | ||
|
|
||
| namespace exla { | ||
|
|
||
| namespace callback_bridge { | ||
|
|
||
| struct BridgeState { | ||
| ErlNifPid dispatcher_pid; | ||
| bool dispatcher_set = false; | ||
| }; | ||
|
|
||
| BridgeState *GetBridgeState() { | ||
| static BridgeState *state = new BridgeState(); | ||
| return state; | ||
| } | ||
|
|
||
| fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, | ||
| ErlNifPid dispatcher_pid) { | ||
| (void)env; | ||
| auto state = GetBridgeState(); | ||
| state->dispatcher_pid = dispatcher_pid; | ||
| state->dispatcher_set = true; | ||
| return fine::Ok(); | ||
| } | ||
|
|
||
| fine::Ok<> elixir_callback_reply(ErlNifEnv *env, | ||
| fine::ResourcePtr<Pending> pending, | ||
| fine::Atom status, fine::Term result) { | ||
| deliver_reply(env, pending, status, result); | ||
| return fine::Ok(); | ||
| } | ||
|
|
||
| fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, | ||
| ErlNifPid dispatcher_pid) { | ||
| (void)env; | ||
| auto state = GetBridgeState(); | ||
|
|
||
| if (state->dispatcher_set && | ||
| std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == | ||
| 0) { | ||
| state->dispatcher_set = false; | ||
| } | ||
|
|
||
| return fine::Ok(); | ||
| } | ||
|
|
||
| void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending, | ||
| fine::Atom status, fine::Term result_term) { | ||
| Result cb_result; | ||
|
|
||
| if (status == "ok") { | ||
| // Successful reply: result_term is a list of binaries that we decode into | ||
| // raw byte vectors via Fine and copy directly into the registered output | ||
| // buffers. | ||
| try { | ||
| auto payloads = | ||
| fine::decode<std::vector<std::vector<uint8_t>>>(env, result_term); | ||
polvalente marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| std::lock_guard<std::mutex> lock(pending->mu); | ||
|
|
||
| if (payloads.size() != pending->outputs.size()) { | ||
| cb_result.ok = false; | ||
| cb_result.error = | ||
| "mismatched number of callback outputs vs registered buffers"; | ||
| } else { | ||
| cb_result.ok = true; | ||
|
|
||
| for (size_t i = 0; i < payloads.size(); ++i) { | ||
| const auto &bytes = payloads[i]; | ||
| auto &out_buf = pending->outputs[i]; | ||
|
|
||
| if (bytes.size() != out_buf.size) { | ||
| cb_result.ok = false; | ||
| cb_result.error = | ||
| "callback returned binary of unexpected size for result buffer"; | ||
| break; | ||
| } | ||
|
|
||
| if (out_buf.size > 0) { | ||
| std::memcpy(out_buf.data, bytes.data(), out_buf.size); | ||
| } | ||
| } | ||
| } | ||
| } catch (const std::exception &e) { | ||
| cb_result.ok = false; | ||
| cb_result.error = | ||
| std::string("failed to decode Elixir callback outputs: ") + e.what(); | ||
| } | ||
| } else { | ||
| // Error reply: result_term is expected to be {kind_atom, message :: binary} | ||
| cb_result.ok = false; | ||
|
|
||
| try { | ||
| auto decoded = | ||
| fine::decode<std::tuple<fine::Atom, ErlNifBinary>>(env, result_term); | ||
polvalente marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ErlNifBinary msg_bin = std::get<1>(decoded); | ||
| cb_result.error.assign(reinterpret_cast<const char *>(msg_bin.data), | ||
| msg_bin.size); | ||
| } catch (const std::exception &) { | ||
| cb_result.error = "elixir callback returned error"; | ||
| } | ||
| } | ||
|
|
||
| { | ||
| std::lock_guard<std::mutex> lock(pending->mu); | ||
| pending->result = std::move(cb_result); | ||
| pending->done = true; | ||
| } | ||
|
|
||
| pending->cv.notify_one(); | ||
| } | ||
|
|
||
| Result InvokeElixirCallback(int64_t callback_id, const std::vector<Arg> &inputs, | ||
| const std::vector<OutputBuffer> &outputs) { | ||
| auto state = GetBridgeState(); | ||
|
|
||
| if (!state->dispatcher_set) { | ||
| Result res; | ||
| res.ok = false; | ||
| res.error = "EXLA elixir callback dispatcher is not set"; | ||
| return res; | ||
| } | ||
|
|
||
| auto pending = fine::make_resource<Pending>(outputs); | ||
|
|
||
| ErlNifEnv *msg_env = enif_alloc_env(); | ||
|
|
||
| // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send | ||
| // plain binaries because the BEAM callback needs to own the data lifetime. | ||
| std::vector<ERL_NIF_TERM> args_terms; | ||
polvalente marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| args_terms.reserve(inputs.size()); | ||
|
|
||
| for (const auto &tensor : inputs) { | ||
| ERL_NIF_TERM bin_term; | ||
| unsigned char *bin_data = | ||
| enif_make_new_binary(msg_env, tensor.size_bytes, &bin_term); | ||
| if (tensor.size_bytes > 0) { | ||
| memcpy(bin_data, tensor.data, tensor.size_bytes); | ||
| } | ||
|
|
||
| // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via | ||
| // Fine's encoder defined in exla_nif_util.h. | ||
| ERL_NIF_TERM typespec_term = | ||
| fine::encode(msg_env, std::make_tuple(tensor.dtype, tensor.dims)); | ||
|
|
||
| ERL_NIF_TERM arg_tuple = enif_make_tuple2(msg_env, bin_term, typespec_term); | ||
|
|
||
| args_terms.push_back(arg_tuple); | ||
| } | ||
|
|
||
| ERL_NIF_TERM args_list = | ||
| enif_make_list_from_array(msg_env, args_terms.data(), args_terms.size()); | ||
|
|
||
| ERL_NIF_TERM pending_term = fine::encode(msg_env, pending); | ||
| ERL_NIF_TERM cb_term = enif_make_int64(msg_env, callback_id); | ||
|
|
||
| ERL_NIF_TERM msg = | ||
| enif_make_tuple4(msg_env, enif_make_atom(msg_env, "exla_elixir_call"), | ||
| cb_term, args_list, pending_term); | ||
|
|
||
| // Use the dispatcher pid registered via start_elixir_callback_bridge/1. | ||
| // Calling enif_whereis_pid from this non-scheduler thread is unsafe and | ||
| // was causing a segfault. | ||
| ErlNifPid dispatcher_pid = state->dispatcher_pid; | ||
| enif_send(msg_env, &dispatcher_pid, msg_env, msg); | ||
| enif_free_env(msg_env); | ||
|
|
||
| std::unique_lock<std::mutex> lock(pending->mu); | ||
| pending->cv.wait(lock, [&pending] { return pending->done; }); | ||
|
|
||
| return pending->result; | ||
| } | ||
|
|
||
| } // namespace callback_bridge | ||
|
|
||
| } // namespace exla | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.