|
| 1 | +defmodule EMLX.ConfigTest do |
| 2 | + use EMLX.Case, async: false |
| 3 | + |
| 4 | + import ExUnit.CaptureLog |
| 5 | + |
| 6 | + setup do |
| 7 | + # Store original config value to restore after each test |
| 8 | + original_value = Application.get_env(:emlx, :warn_unsupported_option, true) |
| 9 | + |
| 10 | + on_exit(fn -> |
| 11 | + Application.put_env(:emlx, :warn_unsupported_option, original_value) |
| 12 | + end) |
| 13 | + |
| 14 | + {:ok, original_value: original_value} |
| 15 | + end |
| 16 | + |
| 17 | + # Test both argmax and argmin with config enabled/disabled |
| 18 | + for op <- [:argmax, :argmin] do |
| 19 | + describe "#{op} with warn_unsupported_option" do |
| 20 | + test "logs warning when config is enabled (default)" do |
| 21 | + Application.put_env(:emlx, :warn_unsupported_option, true) |
| 22 | + |
| 23 | + tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend) |
| 24 | + |
| 25 | + log_output = |
| 26 | + capture_log(fn -> |
| 27 | + Nx.unquote(op)(tensor, axis: 0, tie_break: :high) |
| 28 | + end) |
| 29 | + |
| 30 | + assert log_output =~ |
| 31 | + "Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX" |
| 32 | + end |
| 33 | + |
| 34 | + test "does not log warning when config is disabled" do |
| 35 | + Application.put_env(:emlx, :warn_unsupported_option, false) |
| 36 | + |
| 37 | + tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend) |
| 38 | + |
| 39 | + log_output = |
| 40 | + capture_log(fn -> |
| 41 | + Nx.unquote(op)(tensor, axis: 0, tie_break: :high) |
| 42 | + end) |
| 43 | + |
| 44 | + assert log_output == "" |
| 45 | + end |
| 46 | + end |
| 47 | + end |
| 48 | +end |
0 commit comments