Skip to content

Commit a020ccf

Browse files
remove raising on tie_break :high (#79)
* Remove raising for tie_break: :high * Config for logger warning on unsupported option * changes due to code review --------- Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 250b583 commit a020ccf

File tree

3 files changed

+73
-4
lines changed

3 files changed

+73
-4
lines changed

README.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
EMLX is the Nx Backend for the [MLX](https://github.com/ml-explore/mlx) library.
44

5-
Because of MLX's nature, EMLX with GPU backend is only supported on macOS.
5+
Because of MLX's nature, EMLX with GPU backend is only supported on macOS.
66

7-
MLX with CPU backend is available on most mainstream platforms, however, the CPU backend may not be as optimized as the GPU backend,
7+
MLX with CPU backend is available on most mainstream platforms, however, the CPU backend may not be as optimized as the GPU backend,
88
especially for non-macOS OSes, as they're not prioritized for development. Right now, EMLX supports x86_64 and arm64 architectures
99
on both macOS and Linux.
1010

@@ -49,6 +49,23 @@ Defaulting to Nx.Defn.Evaluator is the safest option for now.
4949
Nx.Defn.default_options(compiler: EMLX)
5050
```
5151

52+
### Configuration
53+
54+
EMLX supports several configuration options that can be set in your application's config:
55+
56+
#### `:warn_unsupported_option`
57+
58+
Controls whether warnings are logged when unsupported options are used with certain operations.
59+
60+
- **Type**: `boolean`
61+
- **Default**: `true`
62+
- **Description**: When enabled, EMLX will log warnings for operations that receive options not supported by the MLX backend. For example, `Nx.argmax/2` and `Nx.argmin/2` with `tie_break: :high` will log a warning since MLX doesn't support this tie-breaking behavior.
63+
64+
```elixir
65+
# In config/config.exs
66+
config :emlx, :warn_unsupported_option, false
67+
```
68+
5269
### MLX binaries
5370

5471
EMLX relies on the [MLX](https://github.com/ml-explore/mlx) library to function, and currently EMLX will download precompiled builds from [mlx-build](https://github.com/cocoa-xu/mlx-build).

lib/emlx/backend.ex

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ defmodule EMLX.Backend do
44
alias Nx.Tensor, as: T
55
alias EMLX.Backend, as: Backend
66

7+
require Logger
8+
79
defstruct [:ref, :shape, :type, :data]
810

911
@impl true
@@ -535,8 +537,10 @@ defmodule EMLX.Backend do
535537
axis = opts[:axis]
536538
keep_axis = opts[:keep_axis] == true
537539

538-
if opts[:tie_break] == :high do
539-
raise "Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
540+
if Application.get_env(:emlx, :warn_unsupported_option, true) and opts[:tie_break] == :high do
541+
Logger.warning(
542+
"Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
543+
)
540544
end
541545

542546
t_mx = from_nx(tensor)

test/emlx/config_test.exs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
defmodule EMLX.ConfigTest do
2+
use EMLX.Case, async: false
3+
4+
import ExUnit.CaptureLog
5+
6+
setup do
7+
# Store original config value to restore after each test
8+
original_value = Application.get_env(:emlx, :warn_unsupported_option, true)
9+
10+
on_exit(fn ->
11+
Application.put_env(:emlx, :warn_unsupported_option, original_value)
12+
end)
13+
14+
{:ok, original_value: original_value}
15+
end
16+
17+
# Test both argmax and argmin with config enabled/disabled
18+
for op <- [:argmax, :argmin] do
19+
describe "#{op} with warn_unsupported_option" do
20+
test "logs warning when config is enabled (default)" do
21+
Application.put_env(:emlx, :warn_unsupported_option, true)
22+
23+
tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend)
24+
25+
log_output =
26+
capture_log(fn ->
27+
Nx.unquote(op)(tensor, axis: 0, tie_break: :high)
28+
end)
29+
30+
assert log_output =~
31+
"Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
32+
end
33+
34+
test "does not log warning when config is disabled" do
35+
Application.put_env(:emlx, :warn_unsupported_option, false)
36+
37+
tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend)
38+
39+
log_output =
40+
capture_log(fn ->
41+
Nx.unquote(op)(tensor, axis: 0, tie_break: :high)
42+
end)
43+
44+
assert log_output == ""
45+
end
46+
end
47+
end
48+
end

0 commit comments

Comments
 (0)