From 7a873693d98624bf92c879c09eebadb79d19f5c6 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:05:06 -0300 Subject: [PATCH] feat: allow complex literals in defn --- nx/lib/nx/defn/compiler.ex | 7 +++++++ nx/lib/nx/defn/expr.ex | 2 +- nx/test/nx/defn_test.exs | 9 +++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index b529eae365..809076c748 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -585,6 +585,13 @@ defmodule Nx.Defn.Compiler do {{{:., dot_meta, [Nx, name]}, meta, args}, state} end + # We also allow specifically Complex.new so that literal complex numbers + # can be written in defn. + defp normalize({{:., dot_meta, [Complex, :new]}, meta, args}, state) do + {args, state} = normalize_list(args, state) + {{{:., dot_meta, [Complex, :new]}, meta, args}, state} + end + defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do {args, state} = normalize_list(args, state) {{{:., dot_meta, [mod, name]}, meta, args}, state} diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 638891eaf1..32b78436a2 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1271,7 +1271,7 @@ defmodule Nx.Defn.Expr do "value and inline it inside the defn expression. Got: #{inspect(t)}" end - defp to_expr(number) when is_number(number), + defp to_expr(number) when is_number(number) or is_struct(number, Complex), do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number) defp to_expr(other) do diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index d532ea4043..231e8e6a9f 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -25,6 +25,10 @@ defmodule Nx.DefnTest do @tensor [1, 2, 3] defn(list_constant, do: Nx.tensor(@tensor)) + defn complex_constant do + Complex.new(1, :infinity) + end + test "from list" do assert %T{data: %Expr{op: :tensor}} = list_constant() end @@ -35,6 +39,11 @@ defmodule Nx.DefnTest do test "from binary" do assert %T{data: %Expr{op: :tensor}} = binary_constant() end + + test "complex literals" do + assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant() + assert c == Complex.new(1, :infinity) + end end describe "Nx.tensor" do