Skip to content

Commit 85600f7

Browse files
authored
Add helper functions getDefaultXLAGenerator and createXLAGenerator to XLA random number generator (#9682)
Add helper functions getDefaultXLAGenerator and createXLAGenerator to XLA random number generator These helper functions will be used with XLA hook later. Refer to #9159
1 parent dd60969 commit 85600f7

File tree

5 files changed

+254
-2
lines changed

5 files changed

+254
-2
lines changed

.github/scripts/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ function run_torch_xla_cpp_tests() {
3131
TORCH_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch').get_filename()))")
3232
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TORCH_DIR}/lib
3333
export PJRT_DEVICE=CPU
34+
export CPU_NUM_DEVICES=2
3435
export XLA_EXPERIMENTAL="nonzero:masked_select:nms"
3536

3637
test_names=("test_aten_xla_tensor_1"

test/cpp/run_tests.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ if [[ "$BAZEL_VERB" == "coverage" ]]; then
8282
EXTRA_FLAGS="$EXTRA_FLAGS --remote_download_outputs=all" # for lcov symlink
8383
fi
8484

85+
# Forward PJRT_DEVICE and CPU_NUM_DEVICES to bazel test environment.
86+
# Set sensible defaults when not provided so tests run reproducibly.
87+
: "${PJRT_DEVICE:=CPU}"
88+
: "${CPU_NUM_DEVICES:=2}"
89+
export PJRT_DEVICE CPU_NUM_DEVICES
90+
if [[ -n "${PJRT_DEVICE}" ]]; then
91+
EXTRA_FLAGS="$EXTRA_FLAGS --test_env=PJRT_DEVICE=${PJRT_DEVICE}"
92+
fi
93+
if [[ -n "${CPU_NUM_DEVICES}" ]]; then
94+
EXTRA_FLAGS="$EXTRA_FLAGS --test_env=CPU_NUM_DEVICES=${CPU_NUM_DEVICES}"
95+
fi
96+
8597
test_names=("all")
8698
if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
8799
test_names=("test_aten_xla_tensor_1"

test/cpp/test_xla_generator.cpp

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#include <gmock/gmock.h>
12
#include <gtest/gtest.h>
23
#include <torch/torch.h>
34

5+
#include <cstdlib>
6+
47
#include "test/cpp/torch_xla_test.h"
58
#include "torch_xla/csrc/xla_generator.h"
69

@@ -102,5 +105,122 @@ TEST_F(XLAGeneratorTest, Clone) {
102105
ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
103106
}
104107

108+
TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) {
109+
// Test getting default generator for device 0
110+
auto result = at::detail::GetDefaultXLAGenerator(0);
111+
ASSERT_TRUE(result.ok()) << "Failed to get default generator: "
112+
<< result.status();
113+
114+
const at::Generator& default_gen = result.value();
115+
ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA);
116+
ASSERT_EQ(default_gen.device().index(), 0);
117+
118+
// Test getting default generator with -1 (should default to device 0)
119+
auto result_default = at::detail::GetDefaultXLAGenerator(-1);
120+
ASSERT_TRUE(result_default.ok())
121+
<< "Failed to get default generator with -1: " << result_default.status();
122+
123+
const at::Generator& default_gen_neg1 = result_default.value();
124+
ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA);
125+
ASSERT_EQ(default_gen_neg1.device().index(), 0);
126+
ASSERT_EQ(default_gen, default_gen_neg1);
127+
128+
// Test that subsequent calls return the same generator instance
129+
auto result2 = at::detail::GetDefaultXLAGenerator(0);
130+
ASSERT_TRUE(result2.ok());
131+
const at::Generator& default_gen2 = result2.value();
132+
ASSERT_EQ(default_gen, default_gen2);
133+
134+
// Test getting non-defuault device generator
135+
auto result_device1 = at::detail::GetDefaultXLAGenerator(1);
136+
ASSERT_TRUE(result_device1.ok())
137+
<< "Failed to get default generator for device 1: "
138+
<< result_device1.status();
139+
140+
const at::Generator& default_gen_device1 = result_device1.value();
141+
ASSERT_EQ(default_gen_device1.device().type(), at::DeviceType::XLA);
142+
ASSERT_EQ(default_gen_device1.device().index(), 1);
143+
ASSERT_NE(default_gen_device1, default_gen);
144+
}
145+
146+
TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
147+
// Test with invalid device indices
148+
auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2);
149+
ASSERT_FALSE(result_neg2.ok());
150+
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));
151+
ASSERT_THAT(result_neg2.status().message(),
152+
testing::HasSubstr("Invalid XLA device index"));
153+
154+
// Test with very large device index (assuming there aren't 1000 XLA devices)
155+
auto result_large = at::detail::GetDefaultXLAGenerator(100);
156+
ASSERT_FALSE(result_large.ok());
157+
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
158+
ASSERT_THAT(result_large.status().message(),
159+
testing::HasSubstr("Invalid XLA device index"));
160+
}
161+
162+
TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
163+
// Test creating generator for device 1
164+
auto result = at::detail::CreateXLAGenerator(1);
165+
ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status();
166+
167+
at::Generator created_gen = result.value();
168+
ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA);
169+
ASSERT_EQ(created_gen.device().index(), 1);
170+
171+
// Test that the generator is initialized with default seed
172+
ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val);
173+
174+
// Test creating generator with -1 (should use current device)
175+
auto result_current = at::detail::CreateXLAGenerator(-1);
176+
ASSERT_TRUE(result_current.ok())
177+
<< "Failed to create generator with -1: " << result_current.status();
178+
179+
at::Generator created_gen_neg1 = result_current.value();
180+
ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA);
181+
// Device index should be >= 0 (actual device depends on current XLA device)
182+
ASSERT_GE(created_gen_neg1.device().index(), 0);
183+
}
184+
185+
TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
186+
// Test that each call creates a new generator instance
187+
auto result1 = at::detail::CreateXLAGenerator(0);
188+
auto result2 = at::detail::CreateXLAGenerator(0);
189+
190+
ASSERT_TRUE(result1.ok());
191+
ASSERT_TRUE(result2.ok());
192+
193+
at::Generator gen1 = result1.value();
194+
at::Generator gen2 = result2.value();
195+
196+
// Should be different instances (compare generators, not their stack
197+
// addresses)
198+
ASSERT_NE(gen1, gen2);
199+
200+
// But should have same device and initial seed
201+
ASSERT_EQ(gen1.device(), gen2.device());
202+
ASSERT_EQ(gen1.current_seed(), gen2.current_seed());
203+
204+
// Modifying one should not affect the other
205+
gen1.set_current_seed(12345);
206+
ASSERT_NE(gen1.current_seed(), gen2.current_seed());
207+
}
208+
209+
TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
210+
// Test with invalid device indices
211+
auto result_neg2 = at::detail::CreateXLAGenerator(-2);
212+
ASSERT_FALSE(result_neg2.ok());
213+
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));
214+
ASSERT_THAT(result_neg2.status().message(),
215+
testing::HasSubstr("Invalid XLA device index"));
216+
217+
// Test with very large device index (assuming there aren't 100 XLA devices)
218+
auto result_large = at::detail::CreateXLAGenerator(100);
219+
ASSERT_FALSE(result_large.ok());
220+
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
221+
ASSERT_THAT(result_large.status().message(),
222+
testing::HasSubstr("Invalid XLA device index"));
223+
}
224+
105225
} // namespace cpp_test
106-
} // namespace torch_xla
226+
} // namespace torch_xla

torch_xla/csrc/xla_generator.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,113 @@
55
#include <ATen/core/Tensor.h>
66
#include <c10/core/Device.h>
77
#include <c10/core/DeviceType.h>
8+
#include <c10/core/GeneratorImpl.h>
89
#include <c10/core/TensorImpl.h>
10+
#include <c10/util/CallOnce.h>
911
#include <c10/util/intrusive_ptr.h>
1012

1113
#include <cstring>
14+
#include <deque>
15+
#include <vector>
16+
17+
#include "absl/status/status.h"
18+
#include "torch_xla/csrc/aten_xla_bridge.h"
19+
#include "torch_xla/csrc/runtime/computation_client.h"
20+
#include "torch_xla/csrc/runtime/runtime.h"
21+
#include "torch_xla/csrc/status.h"
22+
23+
namespace at {
24+
25+
namespace detail {
26+
27+
namespace {
28+
29+
// Total number of XLA devices in the system.
30+
static int64_t num_xla_devices;
31+
32+
// Ensures default_gens_xla is initialized once.
33+
static std::deque<c10::once_flag> xla_gens_init_flag;
34+
35+
// Default, global XLA generators, one per XLA device.
36+
static std::vector<at::Generator> default_gens_xla;
37+
38+
/*
39+
* Populates the global variables related to XLA generators
40+
* Warning: this function must only be called once!
41+
*/
42+
static absl::Status InitGlobalVars() {
43+
static const absl::Status* init_status = new absl::Status([]() {
44+
XLA_ASSIGN_OR_RETURN(auto c_client,
45+
torch_xla::runtime::GetComputationClient());
46+
num_xla_devices = static_cast<int64_t>(c_client->GetNumDevices());
47+
xla_gens_init_flag.resize(num_xla_devices);
48+
default_gens_xla.resize(num_xla_devices);
49+
return absl::OkStatus();
50+
}());
51+
return *init_status;
52+
}
53+
54+
// Validates and normalizes an XLA device index.
55+
// If requested_index == -1, the current device index is used.
56+
// Returns InvalidArgument if the resolved index is out of range.
57+
static absl::StatusOr<c10::DeviceIndex> NormalizeXLADeviceIndex(
58+
c10::DeviceIndex requested_index) {
59+
c10::DeviceIndex idx = requested_index;
60+
if (idx == -1) {
61+
idx = torch_xla::bridge::GetCurrentAtenDevice().index();
62+
}
63+
if (idx < 0 || idx >= num_xla_devices) {
64+
return absl::InvalidArgumentError(
65+
"Invalid device index for XLA generator. Provided index: " +
66+
std::to_string(idx));
67+
}
68+
return idx;
69+
}
70+
71+
} // anonymous namespace
72+
73+
/**
74+
* PyTorch maintains a collection of default generators that get
75+
* initialized once. The purpose of these default generators is to
76+
* maintain a global running state of the pseudo random number generation,
77+
* when a user does not explicitly mention any generator.
78+
* GetDefaultXLAGenerator gets the default generator for a particular
79+
* XLA device.
80+
*/
81+
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
82+
c10::DeviceIndex device_index) {
83+
XLA_RETURN_IF_ERROR(InitGlobalVars(), "Failed to initialize XLA generators");
84+
// Normalize and validate the target device index; default to current device
85+
// when unspecified
86+
XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx,
87+
NormalizeXLADeviceIndex(device_index),
88+
"Invalid XLA device index");
89+
c10::call_once(xla_gens_init_flag[idx], [&] {
90+
default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
91+
default_gens_xla[idx].seed();
92+
});
93+
return default_gens_xla[idx];
94+
}
95+
96+
/**
97+
* Utility to create a XLAGeneratorImpl. Returns a shared_ptr
98+
*/
99+
absl::StatusOr<at::Generator> CreateXLAGenerator(
100+
c10::DeviceIndex device_index) {
101+
XLA_RETURN_IF_ERROR(InitGlobalVars(), "Failed to initialize XLA generators");
102+
// Normalize and validate the target device index; default to current device
103+
// when unspecified
104+
XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx,
105+
NormalizeXLADeviceIndex(device_index),
106+
"Invalid XLA device index");
107+
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
108+
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
109+
xla_gen->set_current_seed(c10::default_rng_seed_val);
110+
return gen;
111+
}
112+
113+
} // namespace detail
114+
} // namespace at
12115

13116
namespace at {
14117

torch_xla/csrc/xla_generator.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
#include <ATen/core/Generator.h>
44
#include <ATen/core/Tensor.h>
5+
#include <c10/core/Device.h>
6+
#include <c10/core/DeviceType.h>
7+
#include <c10/core/GeneratorImpl.h>
8+
#include <c10/core/TensorImpl.h>
59
#include <c10/util/intrusive_ptr.h>
610

711
#include <cstdint>
812

13+
#include "absl/status/status.h"
14+
#include "absl/status/statusor.h"
15+
916
namespace at {
1017

1118
// Holds the actual state variables for the XLA generator.
@@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
5360
c10::intrusive_ptr<XLAGeneratorState> state_;
5461
};
5562

56-
} // namespace at
63+
namespace detail {
64+
65+
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
66+
c10::DeviceIndex device_index = -1);
67+
absl::StatusOr<at::Generator> CreateXLAGenerator(
68+
c10::DeviceIndex device_index = -1);
69+
70+
} // namespace detail
71+
72+
} // namespace at

0 commit comments

Comments
 (0)