diff --git a/torchx/CMakeLists.txt b/torchx/CMakeLists.txt index 94bb70c77d..3684ecd44e 100644 --- a/torchx/CMakeLists.txt +++ b/torchx/CMakeLists.txt @@ -28,6 +28,7 @@ message(STATUS "C_SRC: ${C_SRC}") message(STATUS "PRIV_DIR: $ENV{PRIV_DIR}") message(STATUS "LIBTORCH_DIR: $ENV{LIBTORCH_DIR}") message(STATUS "ERTS_INCLUDE_DIR: $ENV{ERTS_INCLUDE_DIR}") +message(STATUS "FINE_INCLUDE_DIR: $ENV{FINE_INCLUDE_DIR}") message(STATUS "LIBTORCH_BASE: $ENV{LIBTORCH_BASE}") message(STATUS "MIX_BUILD_EMBEDDED $ENV{MIX_BUILD_EMBEDDED}") message(STATUS "LIBTORCH_LINK $ENV{LIBTORCH_LINK}") @@ -43,7 +44,7 @@ add_library(torchx SHARED ${torchx_sources}) target_link_libraries(torchx "${TORCH_LIBRARIES}") set_property(TARGET torchx PROPERTY CXX_STANDARD 17) -target_include_directories(torchx PUBLIC $ENV{ERTS_INCLUDE_DIR}) +target_include_directories(torchx PUBLIC $ENV{ERTS_INCLUDE_DIR} $ENV{FINE_INCLUDE_DIR}) install( TARGETS torchx diff --git a/torchx/c_src/nx_nif_utils.hpp b/torchx/c_src/nx_nif_utils.hpp deleted file mode 100644 index 79b2d947a8..0000000000 --- a/torchx/c_src/nx_nif_utils.hpp +++ /dev/null @@ -1,304 +0,0 @@ -#pragma once - -#include "erl_nif.h" - -ErlNifResourceType *TENSOR_TYPE; - -#define GET(ARGN, VAR) \ - if (!nx::nif::get(env, argv[ARGN], &VAR)) \ - return nx::nif::error(env, "Unable to get " #VAR " param."); - -#define PARAM(ARGN, TYPE, VAR) \ - TYPE VAR; \ - GET(ARGN, VAR) - -#define ATOM_PARAM(ARGN, VAR) \ - std::string VAR; \ - if (!nx::nif::get_atom(env, argv[ARGN], VAR)) \ - return nx::nif::error(env, "Unable to get " #VAR " atom param."); - -#define TUPLE_PARAM(ARGN, TYPE, VAR) \ - TYPE VAR; \ - if (!nx::nif::get_tuple(env, argv[ARGN], VAR)) { \ - std::ostringstream msg; \ - msg << "Unable to get " #VAR " tuple param in NIF." << __func__ << "/" << argc; \ - return nx::nif::error(env, msg.str().c_str()); \ - } - -#define LIST_PARAM(ARGN, TYPE, VAR) \ - TYPE VAR; \ - if (!nx::nif::get_list(env, argv[ARGN], VAR)) \ - return nx::nif::error(env, "Unable to get " #VAR " list param."); - -#define BINARY_PARAM(ARGN, VAR) \ - ErlNifBinary VAR; \ - if (!enif_inspect_binary(env, argv[ARGN], &VAR)) \ - return nx::nif::error(env, "Unable to get " #VAR " binary param."); - -namespace nx -{ - namespace nif - { - // Status helpers - - // Helper for returning `{:error, msg}` from NIF. - ERL_NIF_TERM error(ErlNifEnv *env, const char *msg) - { - ERL_NIF_TERM atom = enif_make_atom(env, "error"); - ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); - return enif_make_tuple2(env, atom, msg_term); - } - - // Helper for returning `{:ok, term}` from NIF. - ERL_NIF_TERM ok(ErlNifEnv *env) - { - return enif_make_atom(env, "ok"); - } - - // Helper for returning `:ok` from NIF. 
- ERL_NIF_TERM ok(ErlNifEnv *env, ERL_NIF_TERM term) - { - return enif_make_tuple2(env, ok(env), term); - } - - // Numeric types - - int get(ErlNifEnv *env, ERL_NIF_TERM term, int *var) - { - return enif_get_int(env, term, - reinterpret_cast(var)); - } - - int get(ErlNifEnv *env, ERL_NIF_TERM term, int64_t *var) - { - return enif_get_int64(env, term, - reinterpret_cast(var)); - } - - int get(ErlNifEnv *env, ERL_NIF_TERM term, double *var) - { - return enif_get_double(env, term, var); - } - - // Standard types - - int get(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var) - { - unsigned len; - int ret = enif_get_list_length(env, term, &len); - - if (!ret) - { - ErlNifBinary bin; - ret = enif_inspect_binary(env, term, &bin); - if (!ret) - { - return 0; - } - var = std::string((const char *)bin.data, bin.size); - return ret; - } - - var.resize(len + 1); - ret = enif_get_string(env, term, &*(var.begin()), var.size(), ERL_NIF_LATIN1); - - if (ret > 0) - { - var.resize(ret - 1); - } - else if (ret == 0) - { - var.resize(0); - } - else - { - } - - return ret; - } - - ERL_NIF_TERM make(ErlNifEnv *env, bool var) - { - if (var) - return enif_make_atom(env, "true"); - - return enif_make_atom(env, "false"); - } - - ERL_NIF_TERM make(ErlNifEnv *env, int64_t var) - { - return enif_make_int64(env, var); - } - - ERL_NIF_TERM make(ErlNifEnv *env, int var) - { - return enif_make_int(env, var); - } - - ERL_NIF_TERM make(ErlNifEnv *env, double var) - { - return enif_make_double(env, var); - } - - ERL_NIF_TERM make(ErlNifEnv *env, ErlNifBinary var) - { - return enif_make_binary(env, &var); - } - - ERL_NIF_TERM make(ErlNifEnv *env, std::string var) - { - return enif_make_string(env, var.c_str(), ERL_NIF_LATIN1); - } - - ERL_NIF_TERM make(ErlNifEnv *env, const char *string) - { - return enif_make_string(env, string, ERL_NIF_LATIN1); - } - - // Atoms - - int get_atom(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var) - { - unsigned atom_length; - if (!enif_get_atom_length(env, term, &atom_length, ERL_NIF_LATIN1)) - { - return 0; - } - - var.resize(atom_length + 1); - - if (!enif_get_atom(env, term, &(*(var.begin())), var.size(), ERL_NIF_LATIN1)) - return 0; - - var.resize(atom_length); - - return 1; - } - - ERL_NIF_TERM atom(ErlNifEnv *env, const char *msg) - { - return enif_make_atom(env, msg); - } - - // Boolean - - int get(ErlNifEnv *env, ERL_NIF_TERM term, bool *var) - { - std::string bool_atom; - if (!get_atom(env, term, bool_atom)) - return 0; - - if (bool_atom == "true") - *var = true; - else if (bool_atom == "false") - *var = false; - else - return 0; // error - - return 1; - } - - // Containers - - template - int get_tuple(ErlNifEnv *env, ERL_NIF_TERM tuple, std::vector &var) - { - const ERL_NIF_TERM *terms; - int length; - if (!enif_get_tuple(env, tuple, &length, &terms)) - return 0; - var.reserve(length); - - for (int i = 0; i < length; i++) - { - T data; - if (!get(env, terms[i], &data)) - return 0; - var.push_back(data); - } - return 1; - } - - int get_list(ErlNifEnv *env, - ERL_NIF_TERM list, - std::vector &var) - { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) - return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) - { - ErlNifBinary elem; - if (!enif_inspect_binary(env, head, &elem)) - return 0; - var.push_back(elem); - list = tail; - } - return 1; - } - - int get_list(ErlNifEnv *env, - ERL_NIF_TERM list, - std::vector &var) - { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) - 
return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) - { - std::string elem; - if (!get_atom(env, head, elem)) - return 0; - var.push_back(elem); - list = tail; - } - return 1; - } - - int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) - { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) - return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) - { - int64_t elem; - if (!get(env, head, &elem)) - return 0; - var.push_back(elem); - list = tail; - } - return 1; - } - - int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) - { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) - return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) - { - torch::Tensor *elem; - if (!enif_get_resource(env, head, TENSOR_TYPE, reinterpret_cast(&elem))) - { - return 0; - } - var.push_back(*elem); - list = tail; - } - return 1; - } - } -} diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 67efafccfe..1b150da9c1 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -1,3 +1,4 @@ +#include #include #if defined(USING_TORCH_V1) @@ -6,1277 +7,1082 @@ #include #endif -#include +#include "torchx_nif_util.h" #include - -#include "nx_nif_utils.hpp" - -std::map dtypes = {{"byte", torch::kByte}, {"char", torch::kChar}, {"short", torch::kShort}, {"int", torch::kInt}, {"long", torch::kLong}, {"float8_e5m2", torch::kFloat8_e5m2}, {"half", torch::kHalf}, {"brain", torch::kBFloat16}, {"float", torch::kFloat}, {"double", torch::kDouble}, {"bool", torch::kBool}, {"complex", at::ScalarType::ComplexFloat}, {"complex_double", at::ScalarType::ComplexDouble}}; -std::map dtype_sizes = {{"byte", 1}, {"char", 1}, {"short", 2}, {"int", 4}, {"long", 8}, {"float8_e5m2", 1}, {"half", 2}, {"brain", 2}, {"float", 4}, {"double", 8}, {"complex", 8}, {"complex_double", 16}}; - -inline torch::ScalarType string2type(const std::string &atom) { - return dtypes[atom]; -} - -inline const std::string *type2string(const torch::ScalarType type) { - for (std::map::iterator i = dtypes.begin(); i != dtypes.end(); ++i) { - if (i->second == type) - return &i->first; +#include + +namespace torchx { + +// Register TorchTensor as a resource type +FINE_RESOURCE(TorchTensor); + +// Macro to register both _cpu and _io variants of a function +// Following EXLA's pattern - create wrapper functions +#define REGISTER_TENSOR_NIF(NAME) \ + auto NAME##_cpu = NAME; \ + auto NAME##_io = NAME; \ + FINE_NIF(NAME##_cpu, ERL_NIF_DIRTY_JOB_CPU_BOUND); \ + FINE_NIF(NAME##_io, ERL_NIF_DIRTY_JOB_IO_BOUND) + +// Macro to register both _cpu and _io variants for a specific arity +// Creates a unified NIF handler that dispatches to the function +// Usage: REGISTER_TENSOR_NIF_ARITY(name, function_symbol) +#define REGISTER_TENSOR_NIF_ARITY(NAME, SYMBOL) \ + static ERL_NIF_TERM SYMBOL##_nif(ErlNifEnv *env, int argc, \ + const ERL_NIF_TERM argv[]) { \ + return fine::nif(env, argc, argv, SYMBOL); \ + } \ + auto __nif_registration_##SYMBOL##_cpu = fine::Registration::register_nif( \ + {#NAME "_cpu", fine::nif_arity(SYMBOL), SYMBOL##_nif, \ + ERL_NIF_DIRTY_JOB_CPU_BOUND}); \ + auto __nif_registration_##SYMBOL##_io = fine::Registration::register_nif( \ + {#NAME "_io", fine::nif_arity(SYMBOL), SYMBOL##_nif, \ + ERL_NIF_DIRTY_JOB_IO_BOUND}); \ + static_assert(true, "require a semicolon after the macro") + +// Helper 
to get tensor from resource, with proper error checking +torch::Tensor &get_tensor(fine::ResourcePtr tensor_res) { + return tensor_res->tensor(); +} + +// Helper to create a tensor resource result +fine::Ok> +tensor_ok(const torch::Tensor &tensor) { + return fine::Ok(fine::make_resource(tensor)); +} + +// Helper for vector of int64 to IntArrayRef conversion +c10::IntArrayRef vec_to_array_ref(const std::vector &vec) { + return c10::IntArrayRef(vec); +} + +// Helper for device tuple (device_type, device_index) to torch::Device +torch::Device +tuple_to_device(const std::tuple &device_tuple) { + return torch::Device( + static_cast(std::get<0>(device_tuple)), + static_cast(std::get<1>(device_tuple))); +} + +// Helper to count elements in a shape +uint64_t elem_count(const std::vector &shape) { + return std::accumulate(shape.begin(), shape.end(), 1ULL, std::multiplies<>{}); +} + +// ============================================================================ +// Tensor Management Functions +// ============================================================================ + +fine::Atom delete_tensor(ErlNifEnv *env, + fine::ResourcePtr tensor) { + if (tensor->deallocate()) { + return fine::Atom("ok"); + } else { + // Throw exception so backend can catch and return :already_deallocated + throw std::invalid_argument("Tensor has been deallocated"); } - return nullptr; } -// the class instance to manage the refcount of Tensor -class TensorP { - public: - TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr) { - // setup - if (!enif_get_resource(env, arg, TENSOR_TYPE, (void **)&ptr)) { - err = nx::nif::error(env, "Unable to get tensor param in NIF"); - return; - } - - refcount = (std::atomic *)(ptr + 1); - deleted = (std::atomic_flag *)(refcount + 1); +REGISTER_TENSOR_NIF(delete_tensor); - if (refcount->load() == 0) { - // already deallocated - ptr = nullptr; - err = nx::nif::error(env, "Tensor has been deallocated"); - return; - } - - if (is_valid()) { - // increase reference count - ++(*refcount); - } - } - - ~TensorP() { - if (is_valid()) { - // decrease reference count - if (refcount->fetch_sub(1) == 0) { - ptr->~Tensor(); - } - } - } - - bool deallocate() { - if (is_valid() && atomic_flag_test_and_set(deleted) == false) { - --(*refcount); - return true; - } else { - return false; - } - } +fine::Ok> +from_blob(ErlNifEnv *env, ErlNifBinary blob, std::vector shape, + fine::Atom type_atom, std::tuple device_tuple) { - torch::Tensor *data() const { - return ptr; - } + auto type = string2type(type_atom.to_string()); + auto device = tuple_to_device(device_tuple); - bool is_valid() const { - return ptr != nullptr; + // Check if binary is large enough + if (blob.size / dtype_sizes[type_atom.to_string()] < elem_count(shape)) { + throw std::invalid_argument( + "Binary size is too small for the requested shape"); } - ERL_NIF_TERM error() { - return err; - } + auto tensor = torch::from_blob(blob.data, vec_to_array_ref(shape), + torch::device(torch::kCPU).dtype(type)); - private: - torch::Tensor *ptr; - std::atomic *refcount; - std::atomic_flag *deleted; - ERL_NIF_TERM err; -}; - -#define NIF(NAME) ERL_NIF_TERM NAME(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) - -#define SCALAR_PARAM(ARGN, VAR) \ - torch::Scalar VAR; \ - VAR.~Scalar(); \ - double double_##VAR; \ - std::vector complex_##VAR; \ - if (nx::nif::get_tuple(env, argv[ARGN], complex_##VAR)) { \ - new (&VAR) torch::Scalar(c10::complex( \ - complex_##VAR[0], \ - complex_##VAR[1])); \ - } else if (enif_get_double(env, argv[ARGN], 
&double_##VAR) == 0) { \ - int64_t int64_##VAR; \ - enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR); \ - new (&VAR) torch::Scalar(int64_##VAR); \ - } else { \ - new (&VAR) torch::Scalar(double_##VAR); \ + if (device.type() == torch::kCPU) { + return tensor_ok(tensor.clone()); + } else { + return tensor_ok(tensor.to(device)); } +} -#define SHAPE_PARAM(ARGN, VAR) TUPLE_PARAM(ARGN, std::vector, VAR) - -#define TYPE_PARAM(ARGN, VAR) \ - ATOM_PARAM(ARGN, VAR##_atom) \ - torch::ScalarType VAR = string2type(VAR##_atom) - -#define DEVICE_PARAM(ARGN, VAR) TUPLE_PARAM(ARGN, std::vector, VAR) - -#define DEVICE(DEV_VEC) torch::device(torch::Device((torch::DeviceType)DEV_VEC[0], (torch::DeviceIndex)DEV_VEC[1])) +REGISTER_TENSOR_NIF(from_blob); -#define OPTS(TYPE, DEV_VEC) DEVICE(DEV_VEC).dtype(TYPE) +// to_blob - arity 1 and 2 versions +fine::Ok to_blob_1(ErlNifEnv *env, + fine::ResourcePtr tensor_res) { + auto &t = get_tensor(tensor_res); + size_t byte_size = t.nbytes(); -#define TENSOR_PARAM(ARGN, VAR) \ - TensorP VAR##_tp(env, argv[ARGN]); \ - torch::Tensor *VAR; \ - if (!VAR##_tp.is_valid()) { \ - return VAR##_tp.error(); \ - } else { \ - VAR = VAR##_tp.data(); \ - } + torch::optional device = torch::device_of(t); + torch::Tensor reshaped = t.flatten(); + void *data_ptr = reshaped.data_ptr(); -#define CATCH() \ - catch (c10::Error & error) { \ - std::ostringstream msg; \ - msg << error.msg() << " in NIF." << __func__ << "/" << argc; \ - return nx::nif::error(env, msg.str().c_str()); \ - } + ErlNifBinary result; + enif_alloc_binary(byte_size, &result); -#define SCALAR(S) \ - try { \ - if (c10::isFloatingType(S.type())) \ - return nx::nif::ok(env, nx::nif::make(env, S.toDouble())); \ - else \ - return nx::nif::ok(env, nx::nif::make(env, (int64_t)S.toLong())); \ - } \ - CATCH() - -#define TENSOR(T) \ - try { \ - return nx::nif::ok(env, create_tensor_resource(env, T)); \ - } \ - CATCH() - -#define TENSOR_LIST(TL) \ - try { \ - const std::vector &tl = TL; \ - std::vector res_list; \ - for (torch::Tensor t : tl) \ - res_list.push_back(create_tensor_resource(env, t)); \ - return nx::nif::ok(env, enif_make_list_from_array(env, res_list.data(), res_list.size())); \ - } \ - CATCH() - -#define TENSOR_TUPLE(TT) \ - try { \ - const std::tuple &tt = TT; \ - std::vector res_list; \ - for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt)}) \ - res_list.push_back(create_tensor_resource(env, t)); \ - return nx::nif::ok(env, enif_make_tuple_from_array(env, res_list.data(), res_list.size())); \ - } \ - CATCH() - -#define TENSOR_TUPLE_3(TT) \ - try { \ - const std::tuple &tt = TT; \ - std::vector res_list; \ - for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt), std::get<2>(tt)}) \ - res_list.push_back(create_tensor_resource(env, t)); \ - return nx::nif::ok(env, enif_make_tuple_from_array(env, res_list.data(), res_list.size())); \ - } \ - CATCH() - -ERL_NIF_TERM -create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) { - ERL_NIF_TERM ret; - torch::Tensor *tensorPtr; - std::atomic *refcount; - - tensorPtr = (torch::Tensor *)enif_alloc_resource(TENSOR_TYPE, sizeof(torch::Tensor) + sizeof(std::atomic) + sizeof(std::atomic_flag)); - if (tensorPtr == NULL) - return enif_make_badarg(env); - - new (tensorPtr) torch::Tensor(tensor.variable_data()); - refcount = new (tensorPtr + 1) std::atomic(1); - new (refcount + 1) std::atomic_flag(); - - ret = enif_make_resource(env, tensorPtr); - enif_release_resource(tensorPtr); - - return ret; -} - -NIF(delete_tensor) { - TensorP tensor(env, argv[0]); - - 
return tensor.deallocate() ? nx::nif::ok(env) : enif_make_badarg(env); -} - -uint64_t elem_count(std::vector shape) { - return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{}); -} - -NIF(from_blob) { - BINARY_PARAM(0, blob); - SHAPE_PARAM(1, shape); - TYPE_PARAM(2, type); - DEVICE_PARAM(3, device); - - if (blob.size / dtype_sizes[type_atom] < elem_count(shape)) - return nx::nif::error(env, "Binary size is too small for the requested shape"); - - auto tensor = torch::from_blob(blob.data, shape, torch::device(torch::kCPU).dtype(type)); - - if (DEVICE(device).device().type() == torch::kCPU) { - TENSOR(tensor.clone()); + // Always copy data to avoid use-after-free when tensor is deallocated + if (device.has_value() && device.value().type() == torch::kCPU) { + memcpy(result.data, data_ptr, byte_size); } else { - TENSOR(tensor.to(DEVICE(device))); + memcpy(result.data, reshaped.to(torch::kCPU).data_ptr(), byte_size); } -} - -NIF(to_blob) { - ERL_NIF_TERM result; - TENSOR_PARAM(0, t); - size_t byte_size = t->nbytes(); - int64_t limit = 0; - bool has_received_limit = (argc == 2); + return fine::Ok(result); +} - if (has_received_limit) { - PARAM(1, int64_t, param_limit); - limit = param_limit; - byte_size = limit * t->itemsize(); - } +fine::Ok to_blob_2(ErlNifEnv *env, + fine::ResourcePtr tensor_res, + int64_t limit) { + auto &t = get_tensor(tensor_res); + size_t byte_size = limit * t.itemsize(); - torch::optional device = torch::device_of(*t); - // flatten the tensor to compensate for operations which return - // a column-major tensor. t->flatten() is a no-op if the tensor - // is already row-major, which was verified by printing t->data_ptr - // and reshaped.data_ptr and confirming they had the same value. - // We also slice if a limit was received and it doesn't encompass the full tensor. - torch::Tensor reshaped = (has_received_limit && byte_size < t->nbytes()) ? t->flatten().slice(0, 0, limit) : t->flatten(); + torch::optional device = torch::device_of(t); + torch::Tensor reshaped = + (byte_size < t.nbytes()) ? 
t.flatten().slice(0, 0, limit) : t.flatten(); void *data_ptr = reshaped.data_ptr(); - if (device.has_value() && device.value().type() == torch::kCPU && data_ptr == t->data_ptr()) { - // case where we own the data_ptr and the data is in the CPU already - return nx::nif::ok(env, enif_make_resource_binary(env, t, data_ptr, byte_size)); - } else if (device.has_value() && device.value().type() == torch::kCPU) { - // case where we don't own the data_ptr but the data is in the CPU already - void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); - memcpy(result_data, data_ptr, byte_size); - return nx::nif::ok(env, result); + ErlNifBinary result; + enif_alloc_binary(byte_size, &result); + + // Always copy data to avoid use-after-free when tensor is deallocated + if (device.has_value() && device.value().type() == torch::kCPU) { + memcpy(result.data, data_ptr, byte_size); } else { - // case where the data isn't in the CPU, therefore we don't own the data_ptr - void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); - memcpy(result_data, reshaped.to(torch::kCPU).data_ptr(), byte_size); - return nx::nif::ok(env, result); + memcpy(result.data, reshaped.to(torch::kCPU).data_ptr(), byte_size); } + + return fine::Ok(result); } -NIF(item) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF_ARITY(to_blob, to_blob_1); +REGISTER_TENSOR_NIF_ARITY(to_blob, to_blob_2); - SCALAR(t->item()); +fine::Ok item(ErlNifEnv *env, + fine::ResourcePtr tensor) { + return fine::Ok(get_tensor(tensor).item()); } -NIF(scalar_type) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF(item); - const std::string *type_name = type2string(t->scalar_type()); - - if (type_name != nullptr) - return nx::nif::ok(env, enif_make_atom(env, type_name->c_str())); - else - return nx::nif::error(env, "Could not determine tensor type."); +fine::Ok scalar_type(ErlNifEnv *env, + fine::ResourcePtr tensor) { + const std::string *type_name = type2string(get_tensor(tensor).scalar_type()); + if (type_name != nullptr) { + return fine::Ok(fine::Atom(*type_name)); + } else { + throw std::runtime_error("Could not determine tensor type."); + } } -NIF(shape) { - TENSOR_PARAM(0, t); +FINE_NIF(scalar_type, 0); +fine::Ok shape(ErlNifEnv *env, + fine::ResourcePtr tensor) { + auto &t = get_tensor(tensor); std::vector sizes; - for (int64_t dim = 0; dim < t->dim(); dim++) - sizes.push_back(nx::nif::make(env, (t->size(dim)))); - - return nx::nif::ok(env, enif_make_tuple_from_array(env, sizes.data(), sizes.size())); + for (int64_t dim = 0; dim < t.dim(); dim++) { + sizes.push_back(fine::encode(env, t.size(dim))); + } + // Return as tuple (not list) since Elixir expects {} not [] + return fine::Ok( + fine::Term(enif_make_tuple_from_array(env, sizes.data(), sizes.size()))); } -NIF(mps_is_available) { +FINE_NIF(shape, 0); + +bool mps_is_available(ErlNifEnv *env) { #ifdef MAC_ARM64 - bool has_mps = at::hasMPS(); + return at::hasMPS(); #else - bool has_mps = false; + return false; #endif - return nx::nif::make(env, has_mps); } -NIF(cuda_is_available) { - return nx::nif::make(env, (bool)torch::cuda::is_available()); -} +FINE_NIF(mps_is_available, 0); + +bool cuda_is_available(ErlNifEnv *env) { return torch::cuda::is_available(); } + +FINE_NIF(cuda_is_available, 0); -NIF(cuda_device_count) { - return nx::nif::make(env, (int)torch::cuda::device_count()); +int64_t cuda_device_count(ErlNifEnv *env) { + return static_cast(torch::cuda::device_count()); } -NIF(nbytes) { - TENSOR_PARAM(0, t); +FINE_NIF(cuda_device_count, 0); - return nx::nif::ok(env, 
enif_make_int64(env, t->nbytes())); +fine::Ok nbytes(ErlNifEnv *env, + fine::ResourcePtr tensor) { + return fine::Ok(static_cast(get_tensor(tensor).nbytes())); } -NIF(split) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, batch_size); +FINE_NIF(nbytes, 0); - TENSOR_LIST(torch::split(*t, batch_size)); +// ============================================================================ +// Tensor Shape Operations +// ============================================================================ + +fine::Ok>> +split(ErlNifEnv *env, fine::ResourcePtr tensor, + int64_t batch_size) { + auto tensors = torch::split(get_tensor(tensor), batch_size); + std::vector> results; + for (const auto &t : tensors) { + results.push_back(fine::make_resource(t)); + } + return fine::Ok(results); } -NIF(reshape) { - TENSOR_PARAM(0, t); - SHAPE_PARAM(1, shape); +REGISTER_TENSOR_NIF(split); - TENSOR(torch::reshape(*t, shape)); +fine::Ok> +reshape(ErlNifEnv *env, fine::ResourcePtr tensor, + std::vector shape) { + return tensor_ok(torch::reshape(get_tensor(tensor), vec_to_array_ref(shape))); } -NIF(to_type) { - TENSOR_PARAM(0, t); - TYPE_PARAM(1, type); +REGISTER_TENSOR_NIF(reshape); - TENSOR(t->toType(type)); +fine::Ok> +to_type(ErlNifEnv *env, fine::ResourcePtr tensor, + fine::Atom type_atom) { + auto type = string2type(type_atom.to_string()); + return tensor_ok(get_tensor(tensor).toType(type)); } -NIF(to_device) { - TENSOR_PARAM(0, t); - DEVICE_PARAM(1, device); +REGISTER_TENSOR_NIF(to_type); - TENSOR(t->to(DEVICE(device))); +fine::Ok> +to_device(ErlNifEnv *env, fine::ResourcePtr tensor, + std::tuple device_tuple) { + auto device = tuple_to_device(device_tuple); + return tensor_ok(get_tensor(tensor).to(device)); } -NIF(squeeze) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF(to_device); - if (argc == 2) { - PARAM(1, int64_t, dim); - TENSOR(torch::squeeze(*t, dim)); - } else - TENSOR(torch::squeeze(*t)); +std::variant>> +squeeze(ErlNifEnv *env, fine::ResourcePtr tensor, + std::optional dim) { + if (dim.has_value()) { + return tensor_ok(torch::squeeze(get_tensor(tensor), dim.value())); + } else { + return tensor_ok(torch::squeeze(get_tensor(tensor))); + } } -NIF(broadcast_to) { - TENSOR_PARAM(0, t); - SHAPE_PARAM(1, shape); +REGISTER_TENSOR_NIF(squeeze); - TENSOR(torch::broadcast_to(*t, shape).clone()); +fine::Ok> +broadcast_to(ErlNifEnv *env, fine::ResourcePtr tensor, + std::vector shape) { + return tensor_ok( + torch::broadcast_to(get_tensor(tensor), vec_to_array_ref(shape)).clone()); } -NIF(transpose) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, dim0); - PARAM(2, int64_t, dim1); +REGISTER_TENSOR_NIF(broadcast_to); - TENSOR(torch::transpose(*t, dim0, dim1)); +fine::Ok> +transpose(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t dim0, + int64_t dim1) { + return tensor_ok(torch::transpose(get_tensor(tensor), dim0, dim1)); } -NIF(slice) { - TENSOR_PARAM(0, input); - LIST_PARAM(1, std::vector, starts); - LIST_PARAM(2, std::vector, lengths); - LIST_PARAM(3, std::vector, strides); +REGISTER_TENSOR_NIF(transpose); - torch::Tensor output; - torch::Tensor destination = *input; +fine::Ok> +slice(ErlNifEnv *env, fine::ResourcePtr input, + std::vector starts, std::vector lengths, + std::vector strides) { - auto shape = input->sizes(); - size_t dim = 0; - for (dim = 0; dim < starts.size(); dim++) { - auto start = starts[dim]; - auto stride = strides[dim]; - auto length = lengths[dim]; - auto end = std::min(start + length, shape[dim]); // Ensuring we don't go out of bounds - - // arguments are dimension, start index, NON-INCLUSIVE end index 
and stride - destination = destination.slice(dim, start, end, stride); - if (dim == starts.size() - 1) { - output = destination.clone(); - } + torch::Tensor result = get_tensor(input); + auto shape = result.sizes(); + + for (size_t dim = 0; dim < starts.size(); dim++) { + int64_t start = starts[dim]; + int64_t stride = strides[dim]; + int64_t length = lengths[dim]; + int64_t end = std::min(start + length, shape[dim]); + + result = result.slice(dim, start, end, stride); } - TENSOR(output); + // Clone the result to ensure memory ownership + return tensor_ok(result.clone()); } -NIF(concatenate) { - LIST_PARAM(0, std::vector, tensors); +REGISTER_TENSOR_NIF(slice); - PARAM(1, int64_t, axis); - - TENSOR(torch::cat(tensors, axis)); +fine::Ok> +concatenate(ErlNifEnv *env, + std::vector> tensor_list, + int64_t dim) { + std::vector tensors; + for (const auto &t : tensor_list) { + tensors.push_back(get_tensor(t)); + } + return tensor_ok(torch::cat(tensors, dim)); } -NIF(gather) { - TENSOR_PARAM(0, input); - TENSOR_PARAM(1, indices); - PARAM(2, int64_t, axis); +REGISTER_TENSOR_NIF(concatenate); - TENSOR(torch::gather(*input, axis, *indices)); +fine::Ok> +gather(ErlNifEnv *env, fine::ResourcePtr input, + fine::ResourcePtr index, int64_t dim) { + return tensor_ok(torch::gather(get_tensor(input), dim, get_tensor(index))); } -NIF(index_put) { - TENSOR_PARAM(0, input); - LIST_PARAM(1, std::vector, indices); - TENSOR_PARAM(2, updates); - PARAM(3, bool, accumulate); +REGISTER_TENSOR_NIF(gather); - c10::List> convertedList; - for (const torch::Tensor &tensor : indices) { - convertedList.push_back(tensor); +fine::Ok> +index_put(ErlNifEnv *env, fine::ResourcePtr input, + std::vector> indices, + fine::ResourcePtr values, bool accumulate) { + + c10::List> torch_indices; + for (const auto &idx : indices) { + torch_indices.push_back(get_tensor(idx)); } - TENSOR(torch::index_put(*input, convertedList, *updates, accumulate)); + torch::Tensor result = get_tensor(input).clone(); + result.index_put_(torch_indices, get_tensor(values), accumulate); + return tensor_ok(result); } -NIF(index) { - TENSOR_PARAM(0, input); - LIST_PARAM(1, std::vector, indices); +REGISTER_TENSOR_NIF(index_put); + +fine::Ok> +index(ErlNifEnv *env, fine::ResourcePtr input, + std::vector> indices) { - c10::List> convertedList; - for (const torch::Tensor &tensor : indices) { - convertedList.push_back(tensor); + c10::List> torch_indices; + for (const auto &idx : indices) { + torch_indices.push_back(get_tensor(idx)); } - TENSOR(torch::index(*input, convertedList)); + return tensor_ok(get_tensor(input).index(torch_indices)); } -NIF(argsort) { - TENSOR_PARAM(0, input); - PARAM(1, bool, stable); - PARAM(2, int64_t, axis); - PARAM(3, bool, is_descending); +REGISTER_TENSOR_NIF(index); - TENSOR(torch::argsort(*input, stable, axis, is_descending)); +fine::Ok> +argsort(ErlNifEnv *env, fine::ResourcePtr input, bool stable, + int64_t dim, bool descending) { + return tensor_ok(torch::argsort(get_tensor(input), stable, dim, descending)); } -NIF(top_k) { - TENSOR_PARAM(0, input); - PARAM(1, int64_t, k); +REGISTER_TENSOR_NIF(argsort); - TENSOR_TUPLE(at::topk(*input, k)); +fine::Ok< + std::tuple, fine::ResourcePtr>> +top_k(ErlNifEnv *env, fine::ResourcePtr input, int64_t k) { + auto result = torch::topk(get_tensor(input), k); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)))); } -NIF(flip) { - TENSOR_PARAM(0, input); - LIST_PARAM(1, std::vector, dims); +REGISTER_TENSOR_NIF(top_k); - 
TENSOR(torch::flip(*input, dims)); +fine::Ok> +flip(ErlNifEnv *env, fine::ResourcePtr input, + std::vector dims) { + return tensor_ok(torch::flip(get_tensor(input), vec_to_array_ref(dims))); } -NIF(unfold) { - TENSOR_PARAM(0, input); - PARAM(1, int64_t, dim); - PARAM(2, int64_t, size); - PARAM(3, int64_t, step); +REGISTER_TENSOR_NIF(flip); - TENSOR(at::native::unfold(*input, dim, size, step)); +fine::Ok> +unfold(ErlNifEnv *env, fine::ResourcePtr input, int64_t dim, + int64_t size, int64_t step) { + return tensor_ok(get_tensor(input).unfold(dim, size, step)); } -NIF(put) { - TENSOR_PARAM(0, input); - LIST_PARAM(1, std::vector, indices); - TENSOR_PARAM(2, source); +REGISTER_TENSOR_NIF(unfold); - torch::Tensor output = input->clone(); +fine::Ok> +put(ErlNifEnv *env, fine::ResourcePtr input, + std::vector indices, fine::ResourcePtr source) { + + torch::Tensor output = get_tensor(input).clone(); torch::Tensor destination = output; - auto source_shape = source->sizes(); + auto source_shape = get_tensor(source).sizes(); size_t dim = 0; for (dim = 0; dim < indices.size() - 1; dim++) { auto start = indices[dim]; - // arguments are dimension, start index and NON-INCLUSIVE end index destination = destination.slice(dim, start, start + source_shape[dim]); } - auto start = indices[dim]; - destination.slice(dim, start, start + source_shape[dim]) = *source; - - TENSOR(output); -} - -NIF(permute) { - TENSOR_PARAM(0, t); - LIST_PARAM(1, std::vector, dims); + destination.slice(dim, indices[dim], indices[dim] + source_shape[dim]) + .copy_(get_tensor(source)); - TENSOR(t->permute(dims).contiguous()); + return tensor_ok(output); } -/* Creation */ +REGISTER_TENSOR_NIF(put); -NIF(scalar_tensor) { - SCALAR_PARAM(0, scalar); - TYPE_PARAM(1, type); - DEVICE_PARAM(2, device); - - TENSOR(torch::scalar_tensor(scalar, OPTS(type, device))); +fine::Ok> +permute(ErlNifEnv *env, fine::ResourcePtr input, + std::vector permutation) { + return tensor_ok(get_tensor(input).permute(vec_to_array_ref(permutation))); } -NIF(randint) { - PARAM(0, int64_t, min); - PARAM(1, int64_t, max); - SHAPE_PARAM(2, shape); - TYPE_PARAM(3, type); - DEVICE_PARAM(4, device); - - TENSOR(torch::randint(min, max, shape, OPTS(type, device))); -} +REGISTER_TENSOR_NIF(permute); -NIF(rand) { - PARAM(0, double, min); - PARAM(1, double, max); - SHAPE_PARAM(2, shape); - TYPE_PARAM(3, type); - DEVICE_PARAM(4, device); +// ============================================================================ +// Tensor Creation Functions +// ============================================================================ - TENSOR(min + torch::rand(shape, OPTS(type, device)) * (max - min)); +fine::Ok> +scalar_tensor(ErlNifEnv *env, torch::Scalar scalar, fine::Atom type_atom, + std::tuple device_tuple) { + auto type = string2type(type_atom.to_string()); + auto device = tuple_to_device(device_tuple); + return tensor_ok( + torch::scalar_tensor(scalar, torch::dtype(type).device(device))); } -NIF(normal) { - PARAM(0, double, mean); - PARAM(1, double, std); - SHAPE_PARAM(2, shape); - TYPE_PARAM(3, type); - DEVICE_PARAM(4, device); +REGISTER_TENSOR_NIF(scalar_tensor); - TENSOR(torch::normal(mean, std, shape, c10::nullopt, OPTS(type, device))); +fine::Ok> +randint(ErlNifEnv *env, int64_t low, int64_t high, std::vector shape, + fine::Atom type_atom, std::tuple device_tuple) { + auto type = string2type(type_atom.to_string()); + auto device = tuple_to_device(device_tuple); + return tensor_ok(torch::randint(low, high, vec_to_array_ref(shape), + torch::dtype(type).device(device))); } 
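+
+// Worked example (illustrative sketch, not part of the patch itself): with
+// the helpers defined earlier in this file, the dtype atom and the
+// {device_type, device_index} tuple sent from Elixir decode as follows:
+//
+//   auto type = string2type("long");       // -> torch::kLong
+//   auto device = tuple_to_device({1, 0}); // -> torch::Device(torch::kCUDA, 0)
+//   auto t = torch::randint(0, 10, {2, 3}, torch::dtype(type).device(device));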
-NIF(arange) {
-  PARAM(0, int64_t, start);
-  PARAM(1, int64_t, end);
-  PARAM(2, int64_t, step);
-  TYPE_PARAM(3, type);
-  DEVICE_PARAM(4, device);
+REGISTER_TENSOR_NIF(randint);
 
-  if (argc == 6) {
-    SHAPE_PARAM(5, shape);
-    TENSOR(torch::reshape(torch::arange((double)start, (double)end, (double)step, OPTS(type, device)), shape));
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+rand(ErlNifEnv *env, double min, double max, std::vector<int64_t> shape,
+     fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  auto result =
+      torch::rand(vec_to_array_ref(shape), torch::dtype(type).device(device));
+  // Scale from [0, 1) to [min, max)
+  result = result * (max - min) + min;
+  return tensor_ok(result);
+}
+
+REGISTER_TENSOR_NIF(rand);
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+normal(ErlNifEnv *env, double mean, double std, std::vector<int64_t> shape,
+       fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  return tensor_ok(torch::normal(mean, std, vec_to_array_ref(shape),
+                                 c10::nullopt,
+                                 torch::dtype(type).device(device)));
+}
+
+REGISTER_TENSOR_NIF(normal);
+
+// arange - arity 5 and 6 versions
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+arange_5(ErlNifEnv *env, int64_t start, int64_t end, int64_t step,
+         fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  return tensor_ok(torch::arange(
+      static_cast<double>(start), static_cast<double>(end),
+      static_cast<double>(step), torch::dtype(type).device(device)));
+}
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+arange_6(ErlNifEnv *env, int64_t start, int64_t end, int64_t step,
+         fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple,
+         std::vector<int64_t> shape) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  auto result = torch::arange(
+      static_cast<double>(start), static_cast<double>(end),
+      static_cast<double>(step), torch::dtype(type).device(device));
+  return tensor_ok(torch::reshape(result, vec_to_array_ref(shape)));
+}
+
+REGISTER_TENSOR_NIF_ARITY(arange, arange_5);
+REGISTER_TENSOR_NIF_ARITY(arange, arange_6);
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+ones(ErlNifEnv *env, std::vector<int64_t> shape, fine::Atom type_atom,
+     std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  return tensor_ok(
+      torch::ones(vec_to_array_ref(shape), torch::dtype(type).device(device)));
+}
+
+REGISTER_TENSOR_NIF(ones);
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+eye(ErlNifEnv *env, int64_t m, int64_t n, fine::Atom type_atom,
+    std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  return tensor_ok(torch::eye(m, n, torch::dtype(type).device(device)));
+}
+
+REGISTER_TENSOR_NIF(eye);
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+full(ErlNifEnv *env, std::vector<int64_t> shape, torch::Scalar scalar,
+     fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple) {
+  auto type = string2type(type_atom.to_string());
+  auto device = tuple_to_device(device_tuple);
+  return tensor_ok(torch::full(vec_to_array_ref(shape), scalar,
+                               torch::dtype(type).device(device)));
+}
+
+REGISTER_TENSOR_NIF(full);
+
+// ============================================================================
+// Binary Operations
+// ============================================================================
+
+#define BINARY_OP(NAME, TORCH_OP)                                             \
+  fine::Ok<fine::ResourcePtr<TorchTensor>> NAME(                              \
+      ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,                       \
+      fine::ResourcePtr<TorchTensor> b) {                                     \
+    return tensor_ok(torch::TORCH_OP(get_tensor(a), get_tensor(b)));          \
+  }                                                                           \
+  REGISTER_TENSOR_NIF(NAME)
+
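+// For reference, BINARY_OP(add, add) expands (via the macro above plus
+// REGISTER_TENSOR_NIF) to roughly the following sketch:
+//
+//   fine::Ok<fine::ResourcePtr<TorchTensor>> add(
+//       ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,
+//       fine::ResourcePtr<TorchTensor> b) {
+//     return tensor_ok(torch::add(get_tensor(a), get_tensor(b)));
+//   }
+//   auto add_cpu = add;
+//   auto add_io = add;
+//   FINE_NIF(add_cpu, ERL_NIF_DIRTY_JOB_CPU_BOUND);
+//   FINE_NIF(add_io, ERL_NIF_DIRTY_JOB_IO_BOUND);
+//
+// so each operation is exposed to Elixir as both add_cpu/2 and add_io/2,
+// scheduled on the corresponding dirty scheduler type.
+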
+BINARY_OP(bitwise_and, bitwise_and);
+BINARY_OP(bitwise_or, bitwise_or);
+BINARY_OP(bitwise_xor, bitwise_xor);
+BINARY_OP(left_shift, __lshift__);
+BINARY_OP(right_shift, __rshift__);
+BINARY_OP(equal, eq);
+BINARY_OP(not_equal, not_equal);
+BINARY_OP(greater, greater);
+BINARY_OP(less, less);
+BINARY_OP(greater_equal, greater_equal);
+BINARY_OP(less_equal, less_equal);
+BINARY_OP(logical_and, logical_and);
+BINARY_OP(logical_or, logical_or);
+BINARY_OP(logical_xor, logical_xor);
+BINARY_OP(add, add);
+BINARY_OP(subtract, subtract);
+BINARY_OP(divide, divide);
+BINARY_OP(remainder, remainder);
+BINARY_OP(quotient, floor_divide);
+BINARY_OP(multiply, multiply);
+BINARY_OP(pow, pow);
+BINARY_OP(atan2, atan2);
+BINARY_OP(min, min);
+BINARY_OP(max, max);
+BINARY_OP(fmod, fmod);
+
+#undef BINARY_OP
+
+// ============================================================================
+// Unary Operations
+// ============================================================================
+
+#define UNARY_OP(NAME, TORCH_OP)                                              \
+  fine::Ok<fine::ResourcePtr<TorchTensor>> NAME(                              \
+      ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a) {                     \
+    return tensor_ok(torch::TORCH_OP(get_tensor(a)));                         \
+  }                                                                           \
+  REGISTER_TENSOR_NIF(NAME)
+
+UNARY_OP(abs, abs);
+UNARY_OP(ceil, ceil);
+UNARY_OP(floor, floor);
+UNARY_OP(negate, neg);
+UNARY_OP(round, round);
+UNARY_OP(sign, sign);
+UNARY_OP(exp, exp);
+UNARY_OP(expm1, expm1);
+UNARY_OP(sqrt, sqrt);
+UNARY_OP(rsqrt, rsqrt);
+UNARY_OP(log, log);
+UNARY_OP(log1p, log1p);
+UNARY_OP(bitwise_not, bitwise_not);
+UNARY_OP(logical_not, logical_not);
+UNARY_OP(sigmoid, sigmoid);
+UNARY_OP(sin, sin);
+UNARY_OP(asin, asin);
+UNARY_OP(sinh, sinh);
+UNARY_OP(asinh, asinh);
+UNARY_OP(cos, cos);
+UNARY_OP(acos, acos);
+UNARY_OP(cosh, cosh);
+UNARY_OP(acosh, acosh);
+UNARY_OP(tan, tan);
+UNARY_OP(atan, atan);
+UNARY_OP(tanh, tanh);
+UNARY_OP(atanh, atanh);
+UNARY_OP(erf, erf);
+UNARY_OP(erfc, erfc);
+UNARY_OP(erf_inv, erfinv);
+// cbrt is not in the torch namespace, so it needs a custom implementation
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+cbrt(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
+  auto &t = get_tensor(tensor);
+  if (t.scalar_type() == torch::kDouble) {
+    return tensor_ok(torch::pow(t, 1.0 / 3));
   } else {
+    return tensor_ok(torch::pow(t, 1.0f / 3));
   }
 }
+REGISTER_TENSOR_NIF(cbrt);
+UNARY_OP(is_nan, isnan);
+UNARY_OP(is_infinity, isinf);
+UNARY_OP(view_as_real, view_as_real);
+// conjugate needs special handling: conj() returns a view, so we must clone
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+conjugate(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a) {
+  at::Tensor conjugated = get_tensor(a).conj();
+  return tensor_ok(conjugated.clone(conjugated.suggest_memory_format()));
 }
+REGISTER_TENSOR_NIF(conjugate);
+
+#undef UNARY_OP
+
+// ============================================================================
+// Reduction Operations
+// ============================================================================
+
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+tensordot(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,
+          fine::ResourcePtr<TorchTensor> b, std::vector<int64_t> axes_a,
+          std::vector<int64_t> batch_a,
std::vector axes_b, + std::vector batch_b) { -#define BINARY_OP(OP) BINARY_OP2(OP, OP) - -#define BINARY_OP2(OP, NATIVE_OP) \ - NIF(OP) { \ - TENSOR_PARAM(0, a); \ - TENSOR_PARAM(1, b); \ - \ - TENSOR(torch::NATIVE_OP(*a, *b)); \ - } - -#define BINARY_OPB(OP) \ - NIF(OP) { \ - TENSOR_PARAM(0, a); \ - TENSOR_PARAM(1, b); \ - \ - nx::nif::ok(env, nx::nif::make(env, torch::OP(*a, *b))); \ - } - -#define UNARY_OP(OP) UNARY_OP2(OP, OP) - -#define UNARY_OP2(OP, NATIVE) \ - NIF(OP) { \ - TENSOR_PARAM(0, a); \ - TENSOR(torch::NATIVE(*a)); \ - } - -BINARY_OP(bitwise_and) -BINARY_OP(bitwise_or) -BINARY_OP(bitwise_xor) -BINARY_OP2(left_shift, __lshift__) -BINARY_OP2(right_shift, __rshift__) - -BINARY_OP2(equal, eq) -BINARY_OP(not_equal) -BINARY_OP(greater) -BINARY_OP(less) -BINARY_OP(greater_equal) -BINARY_OP(less_equal) - -BINARY_OP(logical_and) -BINARY_OP(logical_or) -BINARY_OP(logical_xor) - -BINARY_OP(add) -BINARY_OP(subtract) -BINARY_OP(divide) -BINARY_OP(remainder) -BINARY_OP(multiply) -BINARY_OP(matmul) -BINARY_OP2(pow, pow) -BINARY_OP(atan2) -BINARY_OP(min) -BINARY_OP(max) - -NIF(fmod) { - TENSOR_PARAM(0, a); - TENSOR_PARAM(1, b); - TENSOR(at::fmod(*a, *b)); -} - -NIF(quotient) { - TENSOR_PARAM(0, a); - TENSOR_PARAM(1, b); - TENSOR(torch::divide(*a, *b, "trunc")); -} - -NIF(tensordot) { - TENSOR_PARAM(0, t1); - TENSOR_PARAM(1, t2); - LIST_PARAM(2, std::vector, axes1); - LIST_PARAM(3, std::vector, batch_axes1); - LIST_PARAM(4, std::vector, axes2); - LIST_PARAM(5, std::vector, batch_axes2); - - bool is_batched = batch_axes1.size() > 0 || batch_axes2.size() > 0; + bool is_batched = batch_a.size() > 0 || batch_b.size() > 0; torch::Tensor result; if (is_batched) { - // if any of the tensors is batched, we need to apply some transformations - // on the inputs and on the result to wrap the batched APIs that torch exposes - std::vector batch_dims1, batch_dims2; + // Handle batched tensordot using vmap/BatchDim + std::vector batch_dims_a, batch_dims_b; int64_t vmap_level = 0; - for (auto dim : batch_axes1) { - batch_dims1.push_back(at::BatchDim(vmap_level++, dim)); + for (auto dim : batch_a) { + batch_dims_a.push_back(at::BatchDim(vmap_level++, dim)); } - torch::Tensor batched_1 = at::makeBatched(*t1, at::BatchDims(batch_dims1.begin(), batch_dims1.end())); + torch::Tensor batched_a = at::makeBatched( + get_tensor(a), at::BatchDims(batch_dims_a.begin(), batch_dims_a.end())); vmap_level = 0; - - for (auto dim : batch_axes2) { - batch_dims2.push_back(at::BatchDim(vmap_level++, dim)); + for (auto dim : batch_b) { + batch_dims_b.push_back(at::BatchDim(vmap_level++, dim)); } - torch::Tensor batched_2 = at::makeBatched(*t2, at::BatchDims(batch_dims2.begin(), batch_dims2.end())); + torch::Tensor batched_b = at::makeBatched( + get_tensor(b), at::BatchDims(batch_dims_b.begin(), batch_dims_b.end())); + + torch::Tensor batched_result = + torch::tensordot(batched_a, batched_b, vec_to_array_ref(axes_a), + vec_to_array_ref(axes_b)); - torch::Tensor batched_result = torch::tensordot(batched_1, batched_2, axes1, axes2); auto impl = at::maybeGetBatchedImpl(batched_result); if (!impl) { - return nx::nif::error(env, "unable to get tensordot result"); + throw std::runtime_error("unable to get tensordot result"); } result = torch::clone(impl->value()); } else { - result = torch::tensordot(*t1, *t2, axes1, axes2); + result = + torch::tensordot(get_tensor(a), get_tensor(b), vec_to_array_ref(axes_a), + vec_to_array_ref(axes_b)); } - TENSOR(result); -} - -/* Unary Ops */ - -UNARY_OP(abs) -UNARY_OP(ceil) -UNARY_OP(floor) 
-UNARY_OP2(negate, negative) -UNARY_OP(round) -UNARY_OP(sign) -UNARY_OP(exp) -UNARY_OP(expm1) -UNARY_OP(sqrt) -UNARY_OP(rsqrt) -UNARY_OP(log) -UNARY_OP(log1p) -UNARY_OP(bitwise_not) -UNARY_OP(logical_not) -UNARY_OP2(sigmoid, sigmoid) - -UNARY_OP(sin) -UNARY_OP(asin) -UNARY_OP(sinh) -UNARY_OP(asinh) -UNARY_OP(cos) -UNARY_OP(acos) -UNARY_OP(cosh) -UNARY_OP(acosh) -UNARY_OP(tan) -UNARY_OP(atan) -UNARY_OP(tanh) -UNARY_OP(atanh) -UNARY_OP(erf) -UNARY_OP(erfc) -UNARY_OP2(erf_inv, erfinv) - -NIF(view_as_real) { - TENSOR_PARAM(0, tensor); - TENSOR(torch::view_as_real(*tensor)); -} - -NIF(conjugate) { - TENSOR_PARAM(0, tensor); - at::Tensor conjugated = tensor->conj(); - TENSOR(conjugated.clone(conjugated.suggest_memory_format())); -} - -NIF(triangular_solve) { - TENSOR_PARAM(0, a); - TENSOR_PARAM(1, b); - PARAM(2, bool, transpose); - PARAM(3, bool, upper); - - auto ts_a = *a; - if (transpose) { - auto num_dims = a->dim(); - ts_a = torch::transpose(*a, num_dims - 2, num_dims - 1); - upper = !upper; - } + return tensor_ok(result); +} - torch::Tensor result = torch::linalg_solve_triangular(ts_a, *b, upper, true, false); +REGISTER_TENSOR_NIF(tensordot); - TENSOR(result); +fine::Ok> +matmul(ErlNifEnv *env, fine::ResourcePtr a, + fine::ResourcePtr b) { + return tensor_ok(torch::matmul(get_tensor(a), get_tensor(b))); } -NIF(determinant) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF(matmul); - TENSOR(t->det()); +fine::Ok> +pad(ErlNifEnv *env, fine::ResourcePtr tensor, + fine::ResourcePtr constant, std::vector config) { + return tensor_ok(torch::constant_pad_nd(get_tensor(tensor), + vec_to_array_ref(config), + get_tensor(constant).item())); } -NIF(sort) { - TENSOR_PARAM(0, t); - PARAM(1, bool, stable); - PARAM(2, int64_t, axis); - PARAM(3, bool, descending); +REGISTER_TENSOR_NIF(pad); + +fine::Ok> +triangular_solve(ErlNifEnv *env, fine::ResourcePtr a, + fine::ResourcePtr b, bool transpose, bool upper) { + auto ts_a = get_tensor(a); + if (transpose) { + auto num_dims = ts_a.dim(); + ts_a = torch::transpose(ts_a, num_dims - 2, num_dims - 1); + upper = !upper; + } - std::tuple result = t->sort(stable, axis, descending); - TENSOR(std::get<0>(result)); + torch::Tensor result = + torch::linalg_solve_triangular(ts_a, get_tensor(b), upper, true, false); + return tensor_ok(result); } -NIF(clip) { - TENSOR_PARAM(0, t); - TENSOR_PARAM(1, min); - TENSOR_PARAM(2, max); +REGISTER_TENSOR_NIF(triangular_solve); - TENSOR(torch::clip(*t, *min, *max)); +fine::Ok> +determinant(ErlNifEnv *env, fine::ResourcePtr t) { + return tensor_ok(get_tensor(t).det()); } -NIF(where) { - TENSOR_PARAM(0, pred); - TENSOR_PARAM(1, on_true); - TENSOR_PARAM(2, on_false); +REGISTER_TENSOR_NIF(determinant); - TENSOR(torch::where(*pred, *on_true, *on_false)); +fine::Ok> sort(ErlNifEnv *env, + fine::ResourcePtr t, + bool stable, int64_t axis, + bool descending) { + std::tuple result = + get_tensor(t).sort(stable, axis, descending); + return tensor_ok(std::get<0>(result)); } -/* Aggregates */ +REGISTER_TENSOR_NIF(sort); -NIF(sum) { - TENSOR_PARAM(0, t); - LIST_PARAM(1, std::vector, dims); - PARAM(2, bool, keep_dim); - - TENSOR(torch::sum(*t, dims, keep_dim)); +fine::Ok> +clip(ErlNifEnv *env, fine::ResourcePtr t, + fine::ResourcePtr min, fine::ResourcePtr max) { + return tensor_ok( + torch::clip(get_tensor(t), get_tensor(min), get_tensor(max))); } -NIF(product) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF(clip); - if (argc == 1) { - TENSOR(torch::prod(*t)); - } +fine::Ok> +where(ErlNifEnv *env, fine::ResourcePtr pred, + fine::ResourcePtr on_true, + 
fine::ResourcePtr on_false) { + return tensor_ok(torch::where(get_tensor(pred), get_tensor(on_true), + get_tensor(on_false))); +} - PARAM(1, int64_t, dim); - PARAM(2, bool, keep_dim); +REGISTER_TENSOR_NIF(where); - TENSOR(torch::prod(*t, dim, keep_dim)); +fine::Ok> sum(ErlNifEnv *env, + fine::ResourcePtr t, + std::vector dims, + bool keep_dim) { + return tensor_ok(torch::sum(get_tensor(t), vec_to_array_ref(dims), keep_dim)); } -NIF(argmax) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, dim); - PARAM(2, bool, keep_dim); +REGISTER_TENSOR_NIF(sum); - if (dim == -1) { - TENSOR(torch::argmax(*t)); - } else { - TENSOR(torch::argmax(*t, dim, keep_dim)); - } +// product - arity 1 and 3 versions +fine::Ok> +product_1(ErlNifEnv *env, fine::ResourcePtr t) { + return tensor_ok(torch::prod(get_tensor(t))); +} + +fine::Ok> +product_3(ErlNifEnv *env, fine::ResourcePtr t, int64_t dim, + bool keep_dim) { + return tensor_ok(torch::prod(get_tensor(t), dim, keep_dim)); } -NIF(argmin) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, dim); - PARAM(2, bool, keep_dim); +REGISTER_TENSOR_NIF_ARITY(product, product_1); +REGISTER_TENSOR_NIF_ARITY(product, product_3); +fine::Ok> +argmax(ErlNifEnv *env, fine::ResourcePtr t, int64_t dim, + bool keep_dim) { if (dim == -1) { - TENSOR(torch::argmin(*t)); + return tensor_ok(torch::argmax(get_tensor(t))); } else { - TENSOR(torch::argmin(*t, dim, keep_dim)); + return tensor_ok(torch::argmax(get_tensor(t), dim, keep_dim)); } } -NIF(cbrt) { - TENSOR_PARAM(0, tensor); +REGISTER_TENSOR_NIF(argmax); - if (tensor->scalar_type() == torch::kDouble) { - TENSOR(torch::pow(*tensor, 1.0 / 3)); +fine::Ok> +argmin(ErlNifEnv *env, fine::ResourcePtr t, int64_t dim, + bool keep_dim) { + if (dim == -1) { + return tensor_ok(torch::argmin(get_tensor(t))); } else { - TENSOR(torch::pow(*tensor, 1.0f / 3)); + return tensor_ok(torch::argmin(get_tensor(t), dim, keep_dim)); } } -NIF(fft) { - TENSOR_PARAM(0, tensor); - PARAM(1, int64_t, length); - PARAM(2, int64_t, axis); - TENSOR(torch::fft::fft(*tensor, length, axis)); -} -NIF(ifft) { - TENSOR_PARAM(0, tensor); - PARAM(1, int64_t, length); - PARAM(2, int64_t, axis); - TENSOR(torch::fft::ifft(*tensor, length, axis)); -} +REGISTER_TENSOR_NIF(argmin); -NIF(fft2) { - TENSOR_PARAM(0, tensor); - LIST_PARAM(1, std::vector, lengths); - LIST_PARAM(2, std::vector, axes); - TENSOR(torch::fft::fft2(*tensor, lengths, axes)); +fine::Ok> +fft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::fft(get_tensor(tensor), length, axis)); } -NIF(ifft2) { - TENSOR_PARAM(0, tensor); - LIST_PARAM(1, std::vector, lengths); - LIST_PARAM(2, std::vector, axes); - TENSOR(torch::fft::ifft2(*tensor, lengths, axes)); +REGISTER_TENSOR_NIF(fft); + +fine::Ok> +ifft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::ifft(get_tensor(tensor), length, axis)); } -NIF(is_nan) { - TENSOR_PARAM(0, tensor); +REGISTER_TENSOR_NIF(ifft); - TENSOR(torch::isnan(*tensor)); +fine::Ok> +fft2(ErlNifEnv *env, fine::ResourcePtr tensor, + std::vector lengths, std::vector axes) { + return tensor_ok(torch::fft::fft2( + get_tensor(tensor), vec_to_array_ref(lengths), vec_to_array_ref(axes))); } -NIF(is_infinity) { - TENSOR_PARAM(0, tensor); +REGISTER_TENSOR_NIF(fft2); - TENSOR(torch::isinf(*tensor)); +fine::Ok> +ifft2(ErlNifEnv *env, fine::ResourcePtr tensor, + std::vector lengths, std::vector axes) { + return tensor_ok(torch::fft::ifft2( + get_tensor(tensor), vec_to_array_ref(lengths), 
vec_to_array_ref(axes))); } -NIF(all) { - TENSOR_PARAM(0, t); +REGISTER_TENSOR_NIF(ifft2); - if (argc == 1) { - TENSOR(torch::all(*t)); - } else { - PARAM(1, int64_t, axis); - PARAM(2, bool, keep_dim); - - TENSOR(torch::all(*t, axis, keep_dim)); - } +// all - arity 1 and 3 versions +fine::Ok> +all_1(ErlNifEnv *env, fine::ResourcePtr t) { + return tensor_ok(torch::all(get_tensor(t))); } -NIF(any) { - TENSOR_PARAM(0, t); +fine::Ok> all_3(ErlNifEnv *env, + fine::ResourcePtr t, + int64_t axis, bool keep_dim) { + return tensor_ok(torch::all(get_tensor(t), axis, keep_dim)); +} - if (argc == 1) { - TENSOR(torch::any(*t)); - } else { - PARAM(1, int64_t, axis); - PARAM(2, bool, keep_dim); +REGISTER_TENSOR_NIF_ARITY(all, all_1); +REGISTER_TENSOR_NIF_ARITY(all, all_3); - TENSOR(torch::any(*t, axis, keep_dim)); - } +// any - arity 1 and 3 versions +fine::Ok> +any_1(ErlNifEnv *env, fine::ResourcePtr t) { + return tensor_ok(torch::any(get_tensor(t))); } -NIF(all_close) { - TENSOR_PARAM(0, a); - TENSOR_PARAM(1, b); - PARAM(2, double, rtol); - PARAM(3, double, atol); - PARAM(4, bool, equal_nan); +fine::Ok> any_3(ErlNifEnv *env, + fine::ResourcePtr t, + int64_t axis, bool keep_dim) { + return tensor_ok(torch::any(get_tensor(t), axis, keep_dim)); +} - bool all_close = torch::allclose(*a, *b, rtol, atol, equal_nan); +REGISTER_TENSOR_NIF_ARITY(any, any_1); +REGISTER_TENSOR_NIF_ARITY(any, any_3); +fine::Ok> +all_close(ErlNifEnv *env, fine::ResourcePtr a, + fine::ResourcePtr b, double rtol, double atol, + bool equal_nan) { + bool result = + torch::allclose(get_tensor(a), get_tensor(b), rtol, atol, equal_nan); auto init_opts = torch::device(torch::kCPU).dtype(torch::kBool); - TENSOR(torch::scalar_tensor(all_close, init_opts)); + return tensor_ok(torch::scalar_tensor(result, init_opts)); } -NIF(cumulative_sum) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, axis); +REGISTER_TENSOR_NIF(all_close); - TENSOR(torch::cumsum(*t, axis)); +fine::Ok> +cumulative_sum(ErlNifEnv *env, fine::ResourcePtr t, int64_t axis) { + return tensor_ok(torch::cumsum(get_tensor(t), axis)); } -NIF(cumulative_product) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, axis); +REGISTER_TENSOR_NIF(cumulative_sum); - TENSOR(torch::cumprod(*t, axis)); +fine::Ok> +cumulative_product(ErlNifEnv *env, fine::ResourcePtr t, + int64_t axis) { + return tensor_ok(torch::cumprod(get_tensor(t), axis)); } -NIF(cumulative_min) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, axis); +REGISTER_TENSOR_NIF(cumulative_product); - const std::tuple &tt = torch::cummin(*t, axis); - TENSOR(std::get<0>(tt)); +fine::Ok> +cumulative_min(ErlNifEnv *env, fine::ResourcePtr t, int64_t axis) { + const std::tuple &tt = + torch::cummin(get_tensor(t), axis); + return tensor_ok(std::get<0>(tt)); } -NIF(cumulative_max) { - TENSOR_PARAM(0, t); - PARAM(1, int64_t, axis); +REGISTER_TENSOR_NIF(cumulative_min); - const std::tuple &tt = torch::cummax(*t, axis); - TENSOR(std::get<0>(tt)); +fine::Ok> +cumulative_max(ErlNifEnv *env, fine::ResourcePtr t, int64_t axis) { + const std::tuple &tt = + torch::cummax(get_tensor(t), axis); + return tensor_ok(std::get<0>(tt)); } -NIF(cholesky) { - TENSOR_PARAM(0, t); - bool upper = false; +REGISTER_TENSOR_NIF(cumulative_max); - if (argc == 2) { - GET(1, upper); - } +// cholesky - arity 1 and 2 versions +fine::Ok> +cholesky_1(ErlNifEnv *env, fine::ResourcePtr t) { + return tensor_ok(torch::cholesky(get_tensor(t))); +} +fine::Ok> +cholesky_2(ErlNifEnv *env, fine::ResourcePtr t, bool upper) { if (upper) { - TENSOR(torch::cholesky(*t).mH()); + return 
tensor_ok(torch::cholesky(get_tensor(t)).mH()); } - - TENSOR(torch::cholesky(*t)); + return tensor_ok(torch::cholesky(get_tensor(t))); } -NIF(pad) { - TENSOR_PARAM(0, tensor); - TENSOR_PARAM(1, constant); - LIST_PARAM(2, std::vector, config); +REGISTER_TENSOR_NIF_ARITY(cholesky, cholesky_1); +REGISTER_TENSOR_NIF_ARITY(cholesky, cholesky_2); - TENSOR(torch::constant_pad_nd(*tensor, config, constant->item())); +// qr - arity 1 and 2 versions +fine::Ok< + std::tuple, fine::ResourcePtr>> +qr_1(ErlNifEnv *env, fine::ResourcePtr t) { + auto result = torch::linalg_qr(get_tensor(t), "reduced"); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)))); } -/* Transformations */ - -NIF(qr) { - TENSOR_PARAM(0, t); - bool reduced = true; - - if (argc == 2) { - GET(1, reduced); - } - - TENSOR_TUPLE(torch::linalg_qr(*t, reduced ? "reduced" : "complete")); +fine::Ok< + std::tuple, fine::ResourcePtr>> +qr_2(ErlNifEnv *env, fine::ResourcePtr t, bool reduced) { + auto result = + torch::linalg_qr(get_tensor(t), reduced ? "reduced" : "complete"); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)))); } -NIF(svd) { - TENSOR_PARAM(0, t); - bool full_matrices = true; - - if (argc == 2) { - GET(1, full_matrices); - } +REGISTER_TENSOR_NIF_ARITY(qr, qr_1); +REGISTER_TENSOR_NIF_ARITY(qr, qr_2); - TENSOR_TUPLE_3(torch::linalg_svd(*t, full_matrices)); +// svd - arity 1 and 2 versions +fine::Ok< + std::tuple, fine::ResourcePtr, + fine::ResourcePtr>> +svd_1(ErlNifEnv *env, fine::ResourcePtr t) { + auto result = torch::linalg_svd(get_tensor(t), true); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)), + fine::make_resource(std::get<2>(result)))); } -NIF(lu) { - TENSOR_PARAM(0, t); - - std::tuple lu_result = torch::linalg_lu_factor(*t); - std::tuple plu = torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result)); - - TENSOR_TUPLE_3(plu); +fine::Ok< + std::tuple, fine::ResourcePtr, + fine::ResourcePtr>> +svd_2(ErlNifEnv *env, fine::ResourcePtr t, bool full_matrices) { + auto result = torch::linalg_svd(get_tensor(t), full_matrices); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)), + fine::make_resource(std::get<2>(result)))); } -NIF(amax) { - TENSOR_PARAM(0, tensor); - LIST_PARAM(1, std::vector, axes); - PARAM(2, bool, keep_axes); +REGISTER_TENSOR_NIF_ARITY(svd, svd_1); +REGISTER_TENSOR_NIF_ARITY(svd, svd_2); + +fine::Ok< + std::tuple, fine::ResourcePtr, + fine::ResourcePtr>> +lu(ErlNifEnv *env, fine::ResourcePtr t) { + std::tuple lu_result = + torch::linalg_lu_factor(get_tensor(t)); + std::tuple plu = + torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result)); - TENSOR(at::amax(*tensor, axes, keep_axes)); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(plu)), + fine::make_resource(std::get<1>(plu)), + fine::make_resource(std::get<2>(plu)))); } -NIF(amin) { - TENSOR_PARAM(0, tensor); - LIST_PARAM(1, std::vector, axes); - PARAM(2, bool, keep_axes); +REGISTER_TENSOR_NIF(lu); - TENSOR(at::amin(*tensor, axes, keep_axes)); +fine::Ok> +amax(ErlNifEnv *env, fine::ResourcePtr tensor, + std::vector axes, bool keep_axes) { + return tensor_ok( + at::amax(get_tensor(tensor), vec_to_array_ref(axes), keep_axes)); } -NIF(eigh) { - TENSOR_PARAM(0, tensor); +REGISTER_TENSOR_NIF(amax); - TENSOR_TUPLE(torch::linalg_eigh(*tensor)); 
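+
+// Note on tuple results (sketch, assuming fine's std::tuple encoding as used
+// throughout this file): a return type of
+//
+//   fine::Ok<std::tuple<fine::ResourcePtr<TorchTensor>,
+//                       fine::ResourcePtr<TorchTensor>>>
+//
+// reaches Elixir as {:ok, {ref1, ref2}}, which preserves the shape of the
+// TENSOR_TUPLE results produced by the code being removed here.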
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+amin(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
+     std::vector<int64_t> axes, bool keep_axes) {
+  return tensor_ok(
+      at::amin(get_tensor(tensor), vec_to_array_ref(axes), keep_axes));
 }
 
-NIF(solve) {
-  TENSOR_PARAM(0, tensorA);
-  TENSOR_PARAM(1, tensorB);
+REGISTER_TENSOR_NIF(amin);
 
-  TENSOR(torch::linalg_solve(*tensorA, *tensorB));
+fine::Ok<
+    std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
+eigh(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
+  auto result = torch::linalg_eigh(get_tensor(tensor));
+  return fine::Ok(
+      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
+                      fine::make_resource<TorchTensor>(std::get<1>(result))));
 }
 
-NIF(conv) {
-  TENSOR_PARAM(0, tensor);
-  TENSOR_PARAM(1, kernel);
+REGISTER_TENSOR_NIF(eigh);
 
-  LIST_PARAM(2, std::vector<int64_t>, stride);
-  LIST_PARAM(3, std::vector<int64_t>, padding);
-  LIST_PARAM(4, std::vector<int64_t>, dilation);
-  PARAM(5, bool, transposed);
-  PARAM(6, int64_t, groups);
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+solve(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensorA,
+      fine::ResourcePtr<TorchTensor> tensorB) {
+  return tensor_ok(
+      torch::linalg_solve(get_tensor(tensorA), get_tensor(tensorB)));
+}
+
+REGISTER_TENSOR_NIF(solve);
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+conv(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
+     fine::ResourcePtr<TorchTensor> kernel, std::vector<int64_t> stride,
+     std::vector<int64_t> padding, std::vector<int64_t> dilation,
+     bool transposed, int64_t groups) {
   c10::optional<torch::Tensor> bias_tensor;
-
   std::vector<int64_t> output_padding;
   output_padding.push_back(0);
-
-  // aten::convolution(Tensor input, Tensor weight, Tensor? bias,
-  //                   int[] stride, int[] padding, int[] dilation, bool transposed,
-  //                   int[] output_padding, int groups) -> Tensor
-  TENSOR(at::convolution(*tensor, *kernel, bias_tensor,
-                         stride, padding, dilation, transposed, output_padding, groups));
+  return tensor_ok(at::convolution(get_tensor(tensor), get_tensor(kernel),
+                                   bias_tensor, vec_to_array_ref(stride),
+                                   vec_to_array_ref(padding),
+                                   vec_to_array_ref(dilation), transposed,
+                                   vec_to_array_ref(output_padding), groups));
 }
 
-NIF(max_pool_3d) {
-  TENSOR_PARAM(0, tensor);
-  LIST_PARAM(1, std::vector<int64_t>, kernel_size);
-  LIST_PARAM(2, std::vector<int64_t>, strides);
-  LIST_PARAM(3, std::vector<int64_t>, padding);
-  LIST_PARAM(4, std::vector<int64_t>, dilation);
+REGISTER_TENSOR_NIF(conv);
 
-  TENSOR(at::max_pool3d(*tensor, kernel_size, strides, padding, dilation));
+fine::Ok<fine::ResourcePtr<TorchTensor>>
+max_pool_3d(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
+            std::vector<int64_t> kernel_size, std::vector<int64_t> strides,
+            std::vector<int64_t> padding, std::vector<int64_t> dilation) {
+  return tensor_ok(
+      at::max_pool3d(get_tensor(tensor), vec_to_array_ref(kernel_size),
+                     vec_to_array_ref(strides), vec_to_array_ref(padding),
+                     vec_to_array_ref(dilation)));
 }
 
-void free_tensor(ErlNifEnv *env, void *obj) {
-  torch::Tensor *tensor = reinterpret_cast<torch::Tensor *>(obj);
-  std::atomic<int> *refcount = reinterpret_cast<std::atomic<int> *>(tensor + 1);
-  std::atomic_flag *deleted = reinterpret_cast<std::atomic_flag *>(refcount + 1);
+REGISTER_TENSOR_NIF(max_pool_3d);
 
-  if (atomic_flag_test_and_set(deleted) == false) {
-    tensor->~Tensor();
-  }
-
-  deleted->~atomic_flag();
-  refcount->~atomic();
-}
-
-static int
-open_resource_type(ErlNifEnv *env) {
-  const char *name = "Tensor";
-  ErlNifResourceFlags flags = (ErlNifResourceFlags)(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER);
-
-  TENSOR_TYPE = enif_open_resource_type(env, NULL, name, free_tensor, flags, NULL);
-  if (TENSOR_TYPE == NULL)
-    return -1;
-  return 0;
-}
-
-int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) {
-  // Silence "unused var" warnings.
-  (void)(env);
-  (void)(priv_data);
-  (void)(old_priv_data);
-  (void)(load_info);
-
-  return 0;
-}
-
-int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) {
-  if (open_resource_type(env) == -1)
-    return -1;
-
-  // Silence "unused var" warnings.
-  (void)(priv_data);
-  (void)(load_info);
-
-  return 0;
-}
-
-#define F(NAME, ARITY) \
-  {                    \
-#NAME, ARITY, NAME, 0 \
-  }
-
-#define DF(NAME, ARITY)                                     \
-  {#NAME "_cpu", ARITY, NAME, ERL_NIF_DIRTY_JOB_CPU_BOUND}, \
-  {                                                         \
-#NAME "_io", ARITY, NAME, ERL_NIF_DIRTY_JOB_IO_BOUND \
-  }
+} // namespace torchx
 
-static ErlNifFunc nif_functions[] = {
-    DF(randint, 5),
-    DF(rand, 5),
-    DF(normal, 5),
-    DF(arange, 5),
-    DF(arange, 6),
-    DF(scalar_tensor, 3),
-    DF(ones, 3),
-    DF(eye, 4),
-    DF(full, 4),
-
-    DF(item, 1),
-    DF(from_blob, 4),
-    DF(to_blob, 1),
-    DF(to_blob, 2),
-    DF(delete_tensor, 1),
-    DF(reshape, 2),
-    DF(split, 2),
-    DF(to_type, 2),
-    DF(to_device, 2),
-    DF(squeeze, 2),
-    DF(squeeze, 1),
-    DF(broadcast_to, 2),
-    DF(transpose, 3),
-    DF(permute, 2),
-    DF(slice, 4),
-    DF(concatenate, 2),
-    DF(gather, 3),
-    DF(index, 2),
-    DF(index_put, 4),
-    DF(argsort, 4),
-    DF(top_k, 2),
-    DF(flip, 2),
-    DF(unfold, 4),
-    DF(put, 3),
-
-    DF(add, 2),
-    DF(subtract, 2),
-    DF(divide, 2),
-    DF(remainder, 2),
-    DF(fmod, 2),
-    DF(quotient, 2),
-    DF(multiply, 2),
-    DF(pow, 2),
-    DF(atan2, 2),
-    DF(min, 2),
-    DF(max, 2),
-    DF(solve, 2),
-
-    DF(bitwise_and, 2),
-    DF(bitwise_or, 2),
-    DF(bitwise_xor, 2),
-    DF(left_shift, 2),
-    DF(right_shift, 2),
-
-    DF(equal, 2),
-    DF(not_equal, 2),
-    DF(greater, 2),
-    DF(less, 2),
-    DF(greater_equal, 2),
-    DF(less_equal, 2),
-
-    DF(logical_and, 2),
-    DF(logical_or, 2),
-    DF(logical_xor, 2),
-    DF(logical_not, 1),
-
-    DF(sum, 3),
-    DF(product, 1),
-    DF(product, 3),
-    DF(argmax, 3),
-    DF(argmin, 3),
-    DF(any, 1),
-    DF(any, 3),
-    DF(all, 1),
-    DF(all, 3),
-    DF(all_close, 5),
-
-    DF(cumulative_sum, 2),
-    DF(cumulative_product, 2),
-    DF(cumulative_min, 2),
-    DF(cumulative_max, 2),
-
-    DF(abs, 1),
-    DF(ceil, 1),
-    DF(floor, 1),
-    DF(negate, 1),
-    DF(round, 1),
-    DF(sign, 1),
-    DF(exp, 1),
-    DF(expm1, 1),
-    DF(sqrt, 1),
-    DF(rsqrt, 1),
-    DF(log, 1),
-    DF(log1p, 1),
-    DF(bitwise_not, 1),
-    DF(sigmoid, 1),
-    DF(sin, 1),
-    DF(asin, 1),
-    DF(sinh, 1),
-    DF(asinh, 1),
-    DF(view_as_real, 1),
-    DF(conjugate, 1),
-    DF(cos, 1),
-    DF(acos, 1),
-    DF(cosh, 1),
-    DF(acosh, 1),
-    DF(tan, 1),
-    DF(atan, 1),
-    DF(tanh, 1),
-    DF(atanh, 1),
-    DF(erf, 1),
-    DF(erfc, 1),
-    DF(erf_inv, 1),
-    DF(cbrt, 1),
-    DF(is_nan, 1),
-    DF(is_infinity, 1),
-    DF(fft, 3),
-    DF(ifft, 3),
-    DF(fft2, 3),
-    DF(ifft2, 3),
-
-    DF(tensordot, 6),
-    DF(matmul, 2),
-    DF(pad, 3),
-
-    DF(cholesky, 1),
-    DF(cholesky, 2),
-    DF(eigh, 1),
-    DF(qr, 1),
-    DF(qr, 2),
-    DF(svd, 1),
-    DF(svd, 2),
-    DF(lu, 1),
-    DF(triangular_solve, 4),
-    DF(determinant, 1),
-    DF(sort, 4),
-    DF(clip, 3),
-    DF(where, 3),
-    DF(amax, 3),
-    DF(amin, 3),
-
-    DF(conv, 7),
-    DF(max_pool_3d, 5),
-
-    F(mps_is_available, 0),
-    F(cuda_is_available, 0),
-    F(cuda_device_count, 0),
-    F(scalar_type, 1),
-    F(shape, 1),
-    F(nbytes, 1)};
-
-ERL_NIF_INIT(Elixir.Torchx.NIF, nif_functions, load, NULL, upgrade, NULL)
+// Initialize the NIF module
+FINE_INIT("Elixir.Torchx.NIF");
diff --git a/torchx/c_src/torchx_nif_util.h b/torchx/c_src/torchx_nif_util.h
new file mode 100644
index 0000000000..5bd83e7dc6
--- /dev/null
+++ b/torchx/c_src/torchx_nif_util.h
@@ -0,0 +1,234 @@
+#ifndef TORCHX_NIF_UTIL_H_
+#define TORCHX_NIF_UTIL_H_
+
+#include <fine.hpp>
+#include <map>
+#include <string>
+#include <torch/torch.h>
+
+namespace torchx {
+
+// Atom definitions
+namespace atoms {
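+// Atom shared by NIFs that may observe a tensor resource which was
+// already freed via TorchTensor::deallocate (defined below).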
+static auto already_deallocated = fine::Atom("already_deallocated");
+} // namespace atoms
+
+// Type mappings
+inline std::map<std::string, torch::ScalarType> dtypes = {
+    {"byte", torch::kByte},
+    {"char", torch::kChar},
+    {"short", torch::kShort},
+    {"int", torch::kInt},
+    {"long", torch::kLong},
+    {"float8_e5m2", torch::kFloat8_e5m2},
+    {"half", torch::kHalf},
+    {"brain", torch::kBFloat16},
+    {"float", torch::kFloat},
+    {"double", torch::kDouble},
+    {"bool", torch::kBool},
+    {"complex", at::ScalarType::ComplexFloat},
+    {"complex_double", at::ScalarType::ComplexDouble}};
+
+inline std::map<std::string, int> dtype_sizes = {
+    {"byte", 1},
+    {"char", 1},
+    {"short", 2},
+    {"int", 4},
+    {"long", 8},
+    {"float8_e5m2", 1},
+    {"half", 2},
+    {"brain", 2},
+    {"float", 4},
+    {"double", 8},
+    {"complex", 8},
+    {"complex_double", 16}};
+
+inline torch::ScalarType string2type(const std::string &atom) {
+  return dtypes[atom];
+}
+
+inline const std::string *type2string(const torch::ScalarType type) {
+  for (std::map<std::string, torch::ScalarType>::iterator i =
+           dtypes.begin();
+       i != dtypes.end(); ++i) {
+    if (i->second == type)
+      return &i->first;
+  }
+  return nullptr;
+}
+
+// Tensor resource wrapper with deallocation tracking
+class TorchTensor {
+public:
+  TorchTensor(torch::Tensor tensor) : tensor_(tensor), deallocated_(false) {}
+
+  torch::Tensor &tensor() {
+    if (deallocated_) {
+      throw std::runtime_error("Tensor has been deallocated");
+    }
+    return tensor_;
+  }
+
+  const torch::Tensor &tensor() const {
+    if (deallocated_) {
+      throw std::runtime_error("Tensor has been deallocated");
+    }
+    return tensor_;
+  }
+
+  bool deallocate() {
+    if (!deallocated_) {
+      deallocated_ = true;
+      // Assigning an empty tensor properly handles destruction and frees the
+      // memory; the old tensor's destructor is called automatically by the
+      // assignment operator.
+      tensor_ = torch::Tensor();
+      return true;
+    }
+    return false;
+  }
+
+  bool is_deallocated() const { return deallocated_; }
+
+private:
+  torch::Tensor tensor_;
+  bool deallocated_;
+};
+
+} // namespace torchx
+
+// Fine specializations for torch types
+namespace fine {
+
+// Decoder for std::vector<int64_t> from tuple (for shape parameters)
+// Elixir passes shapes as tuples like {2, 3}, but we need std::vector<int64_t>
+template <> struct Decoder<std::vector<int64_t>> {
+  static std::vector<int64_t> decode(ErlNifEnv *env, const ERL_NIF_TERM &term) {
+    // First try to decode as tuple (for shapes)
+    int size;
+    const ERL_NIF_TERM *terms;
+    if (enif_get_tuple(env, term, &size, &terms)) {
+      std::vector<int64_t> vec;
+      vec.reserve(size);
+      for (int i = 0; i < size; i++) {
+        vec.push_back(fine::decode<int64_t>(env, terms[i]));
+      }
+      return vec;
+    }
+
+    // Otherwise try to decode as list
+    unsigned int length;
+    if (!enif_get_list_length(env, term, &length)) {
+      throw std::invalid_argument("decode failed, expected a tuple or list");
+    }
+
+    std::vector<int64_t> vector;
+    vector.reserve(length);
+
+    auto list = term;
+    ERL_NIF_TERM head, tail;
+    while (enif_get_list_cell(env, list, &head, &tail)) {
+      auto elem = fine::decode<int64_t>(env, head);
+      vector.push_back(elem);
+      list = tail;
+    }
+
+    return vector;
+  }
+};
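+
+// Note: fine resolves arguments and results through these Decoder/Encoder
+// specializations; a NIF parameter declared as T is read via
+// fine::Decoder<T>::decode and a result is written via fine::Encoder<T>.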
+
+// Decoder for torch::Scalar
+template <> struct Decoder<torch::Scalar> {
+  static torch::Scalar decode(ErlNifEnv *env, const ERL_NIF_TERM &term) {
+    // Try to decode as double
+    try {
+      return torch::Scalar(fine::decode<double>(env, term));
+    } catch (const std::invalid_argument &) {
+      // Try to decode as int64
+      try {
+        return torch::Scalar(fine::decode<int64_t>(env, term));
+      } catch (const std::invalid_argument &) {
+        // Try to decode as complex number (tuple of two doubles)
+        auto complex_tuple = fine::decode<std::tuple<double, double>>(env, term);
+        return torch::Scalar(c10::complex<double>(std::get<0>(complex_tuple),
+                                                  std::get<1>(complex_tuple)));
+      }
+    }
+  }
+};
+
+// Encoder for torch::Scalar
+template <> struct Encoder<torch::Scalar> {
+  static ERL_NIF_TERM encode(ErlNifEnv *env, const torch::Scalar &scalar) {
+    if (scalar.isIntegral(false)) {
+      return fine::encode(env, scalar.toLong());
+    } else if (scalar.isFloatingPoint()) {
+      return fine::encode(env, scalar.toDouble());
+    } else if (scalar.isComplex()) {
+      auto complex = scalar.toComplexDouble();
+      return fine::encode(env, std::make_tuple(complex.real(), complex.imag()));
+    } else {
+      throw std::runtime_error("Unknown scalar type");
+    }
+  }
+};
+
+// Decoder for torch::ScalarType (from atom string)
+template <> struct Decoder<torch::ScalarType> {
+  static torch::ScalarType decode(ErlNifEnv *env, const ERL_NIF_TERM &term) {
+    auto type_string = fine::decode<fine::Atom>(env, term).to_string();
+    return torchx::string2type(type_string);
+  }
+};
+
+// Decoder for torch::Device
+template <> struct Decoder<torch::Device> {
+  static torch::Device decode(ErlNifEnv *env, const ERL_NIF_TERM &term) {
+    auto device_string = fine::decode<std::string>(env, term);
+    return torch::Device(device_string);
+  }
+};
+
+// Decoder for c10::IntArrayRef (from list of int64)
+template <> struct Decoder<c10::IntArrayRef> {
+  static c10::IntArrayRef decode(ErlNifEnv *env, const ERL_NIF_TERM &term) {
+    // We would need to store the vector somewhere persistent for IntArrayRef
+    // to reference. This is tricky - IntArrayRef is a view, so we decode to
+    // std::vector instead.
+    throw std::runtime_error(
+        "Cannot decode directly to IntArrayRef, use std::vector");
+  }
+};
+
+// Decoder for std::vector<torch::Tensor>
+template <>
+struct Decoder<std::vector<torch::Tensor>> {
+  static std::vector<torch::Tensor> decode(ErlNifEnv *env,
+                                           const ERL_NIF_TERM &term) {
+    auto tensor_resources =
+        fine::decode<std::vector<fine::ResourcePtr<torchx::TorchTensor>>>(
+            env, term);
+    std::vector<torch::Tensor> tensors;
+    tensors.reserve(tensor_resources.size());
+    for (const auto &res : tensor_resources) {
+      tensors.push_back(res->tensor());
+    }
+    return tensors;
+  }
+};
+
+// Encoder for std::vector<torch::Tensor>
+template <>
+struct Encoder<std::vector<torch::Tensor>> {
+  static ERL_NIF_TERM encode(ErlNifEnv *env,
+                             const std::vector<torch::Tensor> &tensors) {
+    std::vector<fine::ResourcePtr<torchx::TorchTensor>> tensor_resources;
+    tensor_resources.reserve(tensors.size());
+    for (const auto &tensor : tensors) {
+      tensor_resources.push_back(
+          fine::make_resource<torchx::TorchTensor>(tensor));
+    }
+    return fine::encode(env, tensor_resources);
+  }
+};
+
+} // namespace fine
+
+#endif // TORCHX_NIF_UTIL_H_
+
diff --git a/torchx/mix.exs b/torchx/mix.exs
index faaad7d9b7..e6e88bd54b 100644
--- a/torchx/mix.exs
+++ b/torchx/mix.exs
@@ -43,6 +43,7 @@ defmodule Torchx.MixProject do
     [
       {:nx, "~> 0.10.0"},
       # {:nx, path: "../nx"},
+      {:fine, "~> 0.1.0", runtime: false},
      {:ex_doc, "~> 0.29", only: :docs}
     ]
   end
@@ -284,7 +285,8 @@ defmodule Torchx.MixProject do
       "LIBTORCH_LINK" => "#{libtorch_link_path}/lib",
       "MIX_APP_PATH" => mix_app_path,
       "PRIV_DIR" => priv_path,
-      "ERTS_INCLUDE_DIR" => erts_include_dir
+      "ERTS_INCLUDE_DIR" => erts_include_dir,
+      "FINE_INCLUDE_DIR" => Fine.include_dir()
     }
 
     cmd!(cmake, ["-S", ".", "-B", cmake_build_dir], env)
diff --git a/torchx/mix.lock b/torchx/mix.lock
index 841e637316..d90591d64a 100644
--- a/torchx/mix.lock
+++ b/torchx/mix.lock
@@ -2,6 +2,7 @@
   "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
   "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
  "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
+  "fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
  "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"},
  "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"},
  "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"},