Skip to content

Commit c6fc98d

Browse files
authored
feat: improve precision for nx constants (#1574)
1 parent 9cfcd05 commit c6fc98d

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

nx/lib/nx/constants.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ defmodule Nx.Constants do
33
Common constants used in computations.
44
55
This module can be used in `defn`.
6+
7+
The functions `e/0`, `pi/0` and `i/0` will follow the same rules as
8+
literal constants when used inside `defn`. This means that they will
9+
use the surrounding precision instead of defaulting to f32.
610
"""
711

812
import Nx.Shared

nx/lib/nx/defn/compiler.ex

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,18 @@ defmodule Nx.Defn.Compiler do
592592
{{{:., dot_meta, [Complex, :new]}, meta, args}, state}
593593
end
594594

595+
defp normalize({{:., dot_meta, [Nx.Constants, :i]}, meta, []}, state) do
596+
{{{:., dot_meta, [Complex, :new]}, meta, [0, 1]}, state}
597+
end
598+
599+
defp normalize({{:., dot_meta, [Nx.Constants, :e]}, meta, []}, state) do
600+
{{{:., dot_meta, [:math, :exp]}, meta, [1]}, state}
601+
end
602+
603+
defp normalize({{:., dot_meta, [Nx.Constants, :pi]}, meta, []}, state) do
604+
{{{:., dot_meta, [:math, :pi]}, meta, []}, state}
605+
end
606+
595607
defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do
596608
{args, state} = normalize_list(args, state)
597609
{{{:., dot_meta, [mod, name]}, meta, args}, state}

nx/test/nx/defn_test.exs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,71 @@ defmodule Nx.DefnTest do
4444
assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
4545
assert c == Complex.new(1, :infinity)
4646
end
47+
48+
defn real_constants_as_arbitrary_precision(x) do
49+
cond do
50+
x == 0 -> Nx.Constants.e()
51+
x == 1 -> Nx.Constants.pi()
52+
true -> x
53+
end
54+
end
55+
56+
defn imaginary_constant_as_arbitrary_precision(x) do
57+
x * Nx.Constants.i()
58+
end
59+
60+
test "Defines real constants as numerical when used with arity 0 inside defn" do
61+
for input_type <- [f: 8, f: 16, f: 32, f: 64] do
62+
assert %T{
63+
type: ^input_type,
64+
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
65+
} =
66+
real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))
67+
68+
e = :math.exp(1)
69+
pi = :math.pi()
70+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1
71+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2
72+
end
73+
74+
for input_type <- [c: 64, c: 128] do
75+
assert %T{
76+
type: ^input_type,
77+
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
78+
} =
79+
real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))
80+
81+
e = Complex.new(:math.exp(1))
82+
pi = Complex.new(:math.pi())
83+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1
84+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2
85+
end
86+
end
87+
88+
test "Defines imaginary constant as Complex.new(0, 1) when used with arity 0 inside defn" do
89+
for input_type <- [f: 8, f: 16, f: 32, f: 64, c: 64, c: 128] do
90+
type =
91+
case input_type do
92+
{:f, 64} -> {:c, 128}
93+
{:c, 128} -> {:c, 128}
94+
_ -> {:c, 64}
95+
end
96+
97+
# note: we expect the number to be shifted to the left
98+
# because after AST normalization, but before Expr parsing,
99+
# Nx.Constants.i() is converted to Complex.new(0, 1), which is
100+
# a literal number.
101+
assert %T{
102+
type: ^type,
103+
data: %Expr{op: :multiply, args: [arg0, arg1]}
104+
} =
105+
imaginary_constant_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))
106+
107+
i = Complex.new(0, 1)
108+
%T{data: %Expr{op: :constant, args: [^i]}} = arg0
109+
%T{data: %Expr{op: :parameter, args: [0]}} = arg1
110+
end
111+
end
47112
end
48113

49114
describe "Nx.tensor" do

0 commit comments

Comments
 (0)