Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nx/lib/nx/constants.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ defmodule Nx.Constants do
Common constants used in computations.

This module can be used in `defn`.

The functions `e/0`, `pi/0` and `i/0` will follow the same rules as
literal constants when used inside `defn`. This means that they will
use the surrounding precision instead of defaulting to f32.
"""

import Nx.Shared
Expand Down
12 changes: 12 additions & 0 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,18 @@ defmodule Nx.Defn.Compiler do
{{{:., dot_meta, [Complex, :new]}, meta, args}, state}
end

defp normalize({{:., dot_meta, [Nx.Constants, :i]}, meta, []}, state) do
{{{:., dot_meta, [Complex, :new]}, meta, [0, 1]}, state}
end

defp normalize({{:., dot_meta, [Nx.Constants, :e]}, meta, []}, state) do
{{{:., dot_meta, [:math, :exp]}, meta, [1]}, state}
end

defp normalize({{:., dot_meta, [Nx.Constants, :pi]}, meta, []}, state) do
{{{:., dot_meta, [:math, :pi]}, meta, []}, state}
end

defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [mod, name]}, meta, args}, state}
Expand Down
65 changes: 65 additions & 0 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,71 @@ defmodule Nx.DefnTest do
assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
assert c == Complex.new(1, :infinity)
end

defn real_constants_as_arbitrary_precision(x) do
cond do
x == 0 -> Nx.Constants.e()
x == 1 -> Nx.Constants.pi()
true -> x
end
end

defn imaginary_constant_as_arbitrary_precision(x) do
x * Nx.Constants.i()
end

test "Defines real constants as numerical when used with arity 0 inside defn" do
for input_type <- [f: 8, f: 16, f: 32, f: 64] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))

e = :math.exp(1)
pi = :math.pi()
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2
end

for input_type <- [c: 64, c: 128] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
real_constants_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))

e = Complex.new(:math.exp(1))
pi = Complex.new(:math.pi())
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^e]}}} = clause1
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [^pi]}}} = clause2
end
end

test "Defines imaginary constant as Complex.new(0, 1) when used with arity 0 inside defn" do
for input_type <- [f: 8, f: 16, f: 32, f: 64, c: 64, c: 128] do
type =
case input_type do
{:f, 64} -> {:c, 128}
{:c, 128} -> {:c, 128}
_ -> {:c, 64}
end

# note: we expect the number to be shifted to the left
# because after AST normalization, but before Expr parsing,
# Nx.Constants.i() is converted to Complex.new(0, 1), which is
# a literal number.
assert %T{
type: ^type,
data: %Expr{op: :multiply, args: [arg0, arg1]}
} =
imaginary_constant_as_arbitrary_precision(Nx.tensor(10.0, type: input_type))

i = Complex.new(0, 1)
%T{data: %Expr{op: :constant, args: [^i]}} = arg0
%T{data: %Expr{op: :parameter, args: [0]}} = arg1
end
end
end

describe "Nx.tensor" do
Expand Down
Loading