From f7cf2e8b1e452603140aff15cd7382c38e8c015b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 27 Jan 2025 14:38:34 -0300 Subject: [PATCH 1/2] feat: improve precision for nx constants --- nx/lib/nx/defn/compiler.ex | 12 +++++++ nx/test/nx/defn_test.exs | 65 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index 809076c748..bbbd3ddcd4 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -592,6 +592,18 @@ defmodule Nx.Defn.Compiler do {{{:., dot_meta, [Complex, :new]}, meta, args}, state} end + defp normalize({{:., dot_meta, [Nx.Constants, :i]}, meta, []}, state) do + {{{:., dot_meta, [Complex, :new]}, meta, [0, 1]}, state} + end + + defp normalize({{:., dot_meta, [Nx.Constants, :e]}, meta, []}, state) do + {{{:., dot_meta, [:math, :exp]}, meta, [1]}, state} + end + + defp normalize({{:., dot_meta, [Nx.Constants, :pi]}, meta, []}, state) do + {{{:., dot_meta, [:math, :pi]}, meta, []}, state} + end + defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do {args, state} = normalize_list(args, state) {{{:., dot_meta, [mod, name]}, meta, args}, state} diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 90fe9263dd..62993b07a3 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -44,6 +44,71 @@ defmodule Nx.DefnTest do assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant() assert c == Complex.new(1, :infinity) end + + defn real_constants_as_arbitrary_precision(x) do + cond do + x == 0 -> Nx.Constants.e() + x == 1 -> Nx.Constants.pi() + true -> x + end + end + + defn imaginary_constant_as_arbitrary_precision(x) do + x * Nx.Constants.i() + end + + test "Defines real constants as numerical when used with arity 0 inside defn" do + for input_type <- [f: 8, f: 16, f: 32, f: 64] do + assert %T{ + type: ^input_type, + data: %Expr{op: :cond, args: [[clause1, clause2], _last]} + } = + real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type)) + + e = :math.exp(1) + pi = :math.pi() + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1 + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2 + end + + for input_type <- [c: 64, c: 128] do + assert %T{ + type: ^input_type, + data: %Expr{op: :cond, args: [[clause1, clause2], _last]} + } = + real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type)) + + e = Complex.new(:math.exp(1)) + pi = Complex.new(:math.pi()) + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1 + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2 + end + end + + test "Defines imaginary constant as Complex.new(0, 1) when used with arity 0 inside defn" do + for input_type <- [f: 8, f: 16, f: 32, f: 64, c: 64, c: 128] do + type = + case input_type do + {:f, 64} -> {:c, 128} + {:c, 128} -> {:c, 128} + _ -> {:c, 64} + end + + # note: we expect the number to be shifted to the left + # because after AST normalization, but before Expr parsing, + # Nx.Constants.i() is converted to Complex.new(0, 1), which is + # a literal number. + assert %T{ + type: ^type, + data: %Expr{op: :multiply, args: [arg0, arg1]} + } = + imaginary_constant_as_arbitrary_precision(Nx.tensor(10.0, type: input_type)) + + i = Complex.new(0, 1) + %T{data: %Expr{op: :constant, args: [^i]}} = arg0 + %T{data: %Expr{op: :parameter, args: [0]}} = arg1 + end + end end describe "Nx.tensor" do From d53812eeeca739c8dc4d1c7449c08844e6abfb3f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 27 Jan 2025 14:41:54 -0300 Subject: [PATCH 2/2] docs: add moduledoc to Nx.Constants --- nx/lib/nx/constants.ex | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nx/lib/nx/constants.ex b/nx/lib/nx/constants.ex index 0f41b64aca..5dc2e4a875 100644 --- a/nx/lib/nx/constants.ex +++ b/nx/lib/nx/constants.ex @@ -3,6 +3,10 @@ defmodule Nx.Constants do Common constants used in computations. This module can be used in `defn`. + + The functions `e/0`, `pi/0` and `i/0` will follow the same rules as + literal constants when used inside `defn`. This means that they will + use the surrounding precision instead of defaulting to f32. """ import Nx.Shared