@@ -64,6 +64,49 @@ def test_to_local_requires_grad(self):
     # All gradients should be 1.0 since we did a sum()
     self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))
 
+  def test_to_local_grad_independence(self):
+    """Test that gradients are independent between original and local tensor."""
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    tensor = torch.randn(100_000, 88, requires_grad=True)
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+
+    # Create gradients
+    res = sharded_tensor.sum()
+    res.backward()
+
+    # Get local tensor
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify gradients are initially the same
+    self.assertTrue(torch.allclose(local_tensor.grad, sharded_tensor.grad))
+
+    # Modify local tensor's gradient
+    local_tensor.grad[0, 0] = 999.0
+
+    # Verify gradients are now independent (not the same object)
+    self.assertFalse(local_tensor.grad is sharded_tensor.grad)
+    self.assertFalse(torch.allclose(local_tensor.grad, sharded_tensor.grad))
+
+  def test_to_local_grad_none_handling(self):
+    """Test that to_local() handles None gradients correctly."""
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    tensor = torch.randn(100_000, 88, requires_grad=True)
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+
+    # Don't do backward pass, so grad remains None
+    self.assertIsNone(sharded_tensor.grad)
+
+    # Get local tensor
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify local tensor has correct properties
+    self.assertTrue(local_tensor.requires_grad)
+    self.assertIsNone(local_tensor.grad)
+
 
 if __name__ == "__main__":
   result = unittest.main(exit=False)
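
For context, the behavior these two tests pin down is that the tensor returned by `to_local()` keeps `requires_grad`, carries a copy of the gradient that does not alias the sharded tensor's `.grad`, and tolerates the case where no backward pass has run yet. The helper below is a minimal sketch of logic consistent with those assertions, not the actual `XLAShardedTensor.to_local()` implementation from torch_xla; the standalone function name and the use of `clone()` to decouple the gradients are assumptions made purely for illustration.

```python
import torch


def to_local_sketch(sharded_tensor):
  """Hypothetical helper mirroring the properties the tests above assert.

  Assumptions (not taken from this PR): the sharded tensor behaves like a
  regular torch.Tensor for detach()/clone(), and gradient independence is
  achieved by cloning .grad rather than sharing it.
  """
  # Detach to drop the autograd history, clone to get an independent leaf
  # tensor, then restore the requires_grad flag.
  local = sharded_tensor.detach().clone()
  local.requires_grad_(sharded_tensor.requires_grad)

  # Copy the gradient if one exists; cloning keeps the two .grad tensors
  # independent, so mutating one cannot affect the other.
  if sharded_tensor.grad is not None:
    local.grad = sharded_tensor.grad.clone()

  # If backward() has not been called, .grad simply stays None.
  return local
```

Under these assumptions, the in-place write `local_tensor.grad[0, 0] = 999.0` in `test_to_local_grad_independence` cannot leak back into the sharded tensor's gradient, which is exactly what the two `assertFalse` checks verify.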