1+ #include < gmock/gmock.h>
12#include < gtest/gtest.h>
23#include < torch/torch.h>
34
5+ #include < cstdlib>
6+
47#include " test/cpp/torch_xla_test.h"
58#include " torch_xla/csrc/xla_generator.h"
69
@@ -102,5 +105,122 @@ TEST_F(XLAGeneratorTest, Clone) {
102105 ASSERT_NE (cloned_gen.current_seed (), gen_.current_seed ());
103106}
104107
108+ TEST_F (XLAGeneratorTest, GetDefaultXLAGenerator) {
109+ // Test getting default generator for device 0
110+ auto result = at::detail::GetDefaultXLAGenerator (0 );
111+ ASSERT_TRUE (result.ok ()) << " Failed to get default generator: "
112+ << result.status ();
113+
114+ const at::Generator& default_gen = result.value ();
115+ ASSERT_EQ (default_gen.device ().type (), at::DeviceType::XLA);
116+ ASSERT_EQ (default_gen.device ().index (), 0 );
117+
118+ // Test getting default generator with -1 (should default to device 0)
119+ auto result_default = at::detail::GetDefaultXLAGenerator (-1 );
120+ ASSERT_TRUE (result_default.ok ())
121+ << " Failed to get default generator with -1: " << result_default.status ();
122+
123+ const at::Generator& default_gen_neg1 = result_default.value ();
124+ ASSERT_EQ (default_gen_neg1.device ().type (), at::DeviceType::XLA);
125+ ASSERT_EQ (default_gen_neg1.device ().index (), 0 );
126+ ASSERT_EQ (default_gen, default_gen_neg1);
127+
128+ // Test that subsequent calls return the same generator instance
129+ auto result2 = at::detail::GetDefaultXLAGenerator (0 );
130+ ASSERT_TRUE (result2.ok ());
131+ const at::Generator& default_gen2 = result2.value ();
132+ ASSERT_EQ (default_gen, default_gen2);
133+
134+ // Test getting non-defuault device generator
135+ auto result_device1 = at::detail::GetDefaultXLAGenerator (1 );
136+ ASSERT_TRUE (result_device1.ok ())
137+ << " Failed to get default generator for device 1: "
138+ << result_device1.status ();
139+
140+ const at::Generator& default_gen_device1 = result_device1.value ();
141+ ASSERT_EQ (default_gen_device1.device ().type (), at::DeviceType::XLA);
142+ ASSERT_EQ (default_gen_device1.device ().index (), 1 );
143+ ASSERT_NE (default_gen_device1, default_gen);
144+ }
145+
146+ TEST_F (XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
147+ // Test with invalid device indices
148+ auto result_neg2 = at::detail::GetDefaultXLAGenerator (-2 );
149+ ASSERT_FALSE (result_neg2.ok ());
150+ ASSERT_TRUE (absl::IsInvalidArgument (result_neg2.status ()));
151+ ASSERT_THAT (result_neg2.status ().message (),
152+ testing::HasSubstr (" Invalid XLA device index" ));
153+
154+ // Test with very large device index (assuming there aren't 1000 XLA devices)
155+ auto result_large = at::detail::GetDefaultXLAGenerator (100 );
156+ ASSERT_FALSE (result_large.ok ());
157+ ASSERT_TRUE (absl::IsInvalidArgument (result_large.status ()));
158+ ASSERT_THAT (result_large.status ().message (),
159+ testing::HasSubstr (" Invalid XLA device index" ));
160+ }
161+
162+ TEST_F (XLAGeneratorTest, CreateXLAGenerator) {
163+ // Test creating generator for device 1
164+ auto result = at::detail::CreateXLAGenerator (1 );
165+ ASSERT_TRUE (result.ok ()) << " Failed to create generator: " << result.status ();
166+
167+ at::Generator created_gen = result.value ();
168+ ASSERT_EQ (created_gen.device ().type (), at::DeviceType::XLA);
169+ ASSERT_EQ (created_gen.device ().index (), 1 );
170+
171+ // Test that the generator is initialized with default seed
172+ ASSERT_EQ (created_gen.current_seed (), c10::default_rng_seed_val);
173+
174+ // Test creating generator with -1 (should use current device)
175+ auto result_current = at::detail::CreateXLAGenerator (-1 );
176+ ASSERT_TRUE (result_current.ok ())
177+ << " Failed to create generator with -1: " << result_current.status ();
178+
179+ at::Generator created_gen_neg1 = result_current.value ();
180+ ASSERT_EQ (created_gen_neg1.device ().type (), at::DeviceType::XLA);
181+ // Device index should be >= 0 (actual device depends on current XLA device)
182+ ASSERT_GE (created_gen_neg1.device ().index (), 0 );
183+ }
184+
185+ TEST_F (XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
186+ // Test that each call creates a new generator instance
187+ auto result1 = at::detail::CreateXLAGenerator (0 );
188+ auto result2 = at::detail::CreateXLAGenerator (0 );
189+
190+ ASSERT_TRUE (result1.ok ());
191+ ASSERT_TRUE (result2.ok ());
192+
193+ at::Generator gen1 = result1.value ();
194+ at::Generator gen2 = result2.value ();
195+
196+ // Should be different instances (compare generators, not their stack
197+ // addresses)
198+ ASSERT_NE (gen1, gen2);
199+
200+ // But should have same device and initial seed
201+ ASSERT_EQ (gen1.device (), gen2.device ());
202+ ASSERT_EQ (gen1.current_seed (), gen2.current_seed ());
203+
204+ // Modifying one should not affect the other
205+ gen1.set_current_seed (12345 );
206+ ASSERT_NE (gen1.current_seed (), gen2.current_seed ());
207+ }
208+
209+ TEST_F (XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
210+ // Test with invalid device indices
211+ auto result_neg2 = at::detail::CreateXLAGenerator (-2 );
212+ ASSERT_FALSE (result_neg2.ok ());
213+ ASSERT_TRUE (absl::IsInvalidArgument (result_neg2.status ()));
214+ ASSERT_THAT (result_neg2.status ().message (),
215+ testing::HasSubstr (" Invalid XLA device index" ));
216+
217+ // Test with very large device index (assuming there aren't 100 XLA devices)
218+ auto result_large = at::detail::CreateXLAGenerator (100 );
219+ ASSERT_FALSE (result_large.ok ());
220+ ASSERT_TRUE (absl::IsInvalidArgument (result_large.status ()));
221+ ASSERT_THAT (result_large.status ().message (),
222+ testing::HasSubstr (" Invalid XLA device index" ));
223+ }
224+
105225} // namespace cpp_test
106- } // namespace torch_xla
226+ } // namespace torch_xla
0 commit comments