@@ -44,6 +44,71 @@ defmodule Nx.DefnTest do
4444 assert % T { data: % Expr { op: :constant , args: [ % Complex { } = c ] } } = complex_constant ( )
4545 assert c == Complex . new ( 1 , :infinity )
4646 end
47+
48+ defn real_constants_as_arbitrary_precision ( x ) do
49+ cond do
50+ x == 0 -> Nx.Constants . e ( )
51+ x == 1 -> Nx.Constants . pi ( )
52+ true -> x
53+ end
54+ end
55+
56+ defn imaginary_constant_as_arbitrary_precision ( x ) do
57+ x * Nx.Constants . i ( )
58+ end
59+
60+ test "Defines real constants as numerical when used with arity 0 inside defn" do
61+ for input_type <- [ f: 8 , f: 16 , f: 32 , f: 64 ] do
62+ assert % T {
63+ type: ^ input_type ,
64+ data: % Expr { op: :cond , args: [ [ clause1 , clause2 ] , _last ] }
65+ } =
66+ real_constants_as_arbitrary_precision ( Nx . tensor ( 10.0 , type: input_type ) )
67+
68+ e = :math . exp ( 1 )
69+ pi = :math . pi ( )
70+ assert { _ , % T { type: ^ input_type , data: % Expr { op: :constant , args: [ ^ e ] } } } = clause1
71+ assert { _ , % T { type: ^ input_type , data: % Expr { op: :constant , args: [ ^ pi ] } } } = clause2
72+ end
73+
74+ for input_type <- [ c: 64 , c: 128 ] do
75+ assert % T {
76+ type: ^ input_type ,
77+ data: % Expr { op: :cond , args: [ [ clause1 , clause2 ] , _last ] }
78+ } =
79+ real_constants_as_arbitrary_precision ( Nx . tensor ( 10.0 , type: input_type ) )
80+
81+ e = Complex . new ( :math . exp ( 1 ) )
82+ pi = Complex . new ( :math . pi ( ) )
83+ assert { _ , % T { type: ^ input_type , data: % Expr { op: :constant , args: [ ^ e ] } } } = clause1
84+ assert { _ , % T { type: ^ input_type , data: % Expr { op: :constant , args: [ ^ pi ] } } } = clause2
85+ end
86+ end
87+
88+ test "Defines imaginary constant as Complex.new(0, 1) when used with arity 0 inside defn" do
89+ for input_type <- [ f: 8 , f: 16 , f: 32 , f: 64 , c: 64 , c: 128 ] do
90+ type =
91+ case input_type do
92+ { :f , 64 } -> { :c , 128 }
93+ { :c , 128 } -> { :c , 128 }
94+ _ -> { :c , 64 }
95+ end
96+
97+ # note: we expect the number to be shifted to the left
98+ # because after AST normalization, but before Expr parsing,
99+ # Nx.Constants.i() is converted to Complex.new(0, 1), which is
100+ # a literal number.
101+ assert % T {
102+ type: ^ type ,
103+ data: % Expr { op: :multiply , args: [ arg0 , arg1 ] }
104+ } =
105+ imaginary_constant_as_arbitrary_precision ( Nx . tensor ( 10.0 , type: input_type ) )
106+
107+ i = Complex . new ( 0 , 1 )
108+ % T { data: % Expr { op: :constant , args: [ ^ i ] } } = arg0
109+ % T { data: % Expr { op: :parameter , args: [ 0 ] } } = arg1
110+ end
111+ end
47112 end
48113
49114 describe "Nx.tensor" do
0 commit comments