Skip to content

Commit bd60991

Browse files
committed
feat: support binary backend in compiled mode
1 parent 35a4dac commit bd60991

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

lib/emlx.ex

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -318,25 +318,32 @@ defmodule EMLX do
318318

319319
@impl Nx.Defn.Compiler
320320
def __jit__(key, vars, fun, args_list, opts) do
321-
# TODO: instead of checking the backend here,
322-
# we should automatically convert from binary backend to EMLX backend
323-
# given a device optionsg
324-
case Nx.default_backend() do
325-
EMLX.Backend ->
326-
:ok
327-
328-
{EMLX.Backend, _} ->
329-
:ok
330-
331-
other ->
332-
raise ArgumentError, "EMLX can only be used with the EMLX backend, got: #{inspect(other)}"
333-
end
334-
335321
__compile__(key, vars, fun, opts).(args_list)
336322
end
337323

338324
@impl Nx.Defn.Compiler
339325
def __compile__(key, vars, fun, opts) do
326+
backend = Nx.default_backend()
327+
328+
target_backend =
329+
case backend do
330+
EMLX.Backend ->
331+
backend
332+
333+
{EMLX.Backend, _} ->
334+
backend
335+
336+
Nx.BinaryBackend ->
337+
EMLX.Backend
338+
339+
{Nx.BinaryBackend, _} ->
340+
EMLX.Backend
341+
342+
other ->
343+
raise ArgumentError,
344+
"EMLX can only be used with the EMLX.Backend or Nx.BinaryBackend, got: #{inspect(other)}"
345+
end
346+
340347
expr = fun.(vars)
341348

342349
fn [args] ->
@@ -346,6 +353,12 @@ defmodule EMLX do
346353
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} ->
347354
{device, ref}
348355

356+
%Nx.Tensor{data: %Nx.BinaryBackend{}} = t ->
357+
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} =
358+
Nx.backend_copy(t, target_backend)
359+
360+
{device, ref}
361+
349362
other ->
350363
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} = Nx.to_tensor(other)
351364
{device, ref}

test/emlx_test.exs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,30 @@ defmodule EMLXTest do
99
assert_equal(left, Nx.tensor(3))
1010
assert_equal(right, Nx.tensor(-1))
1111
end
12+
13+
test "__jit__ supports binary backend in arguments" do
14+
{left, right} =
15+
Nx.Defn.jit_apply(
16+
&{Nx.add(&1, &2), Nx.subtract(&1, &2)},
17+
[Nx.tensor(1, backend: Nx.BinaryBackend), 2],
18+
compiler: EMLX
19+
)
20+
21+
assert_equal(left, Nx.tensor(3))
22+
assert_equal(right, Nx.tensor(-1))
23+
end
24+
25+
test "__jit__ supports binary backend as the default backend" do
26+
Nx.with_default_backend(Nx.BinaryBackend, fn ->
27+
{left, right} =
28+
Nx.Defn.jit_apply(
29+
&{Nx.add(&1, &2), Nx.subtract(&1, &2)},
30+
[Nx.tensor(1), 2],
31+
compiler: EMLX
32+
)
33+
34+
assert_equal(left, Nx.tensor(3))
35+
assert_equal(right, Nx.tensor(-1))
36+
end)
37+
end
1238
end

0 commit comments

Comments
 (0)