File tree Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -186,7 +186,7 @@ def setUp(self):
186186 self .env = torchax .default_env ()
187187 torchax .enable_accuracy_mode ()
188188 #self.env.config.debug_accuracy_for_each_op = True
189- self .env .config .debug_print_each_op = True
189+ self .env .config .debug_print_each_op = False
190190 torch .manual_seed (0 )
191191 self .old_var = self .env .config .use_torch_native_for_cpu_tensor
192192 self .env .config .use_torch_native_for_cpu_tensor = False
Original file line number Diff line number Diff line change @@ -10,6 +10,11 @@ class Configuration:
1010
1111 use_int32_for_index : bool = False
1212
13+ # normally, math between CPU torch.Tensor with torchax.Tensor is not
14+ # allowed. However, if that torch.Tensor happens to be scalar, then we
15+ # can use scalar * tensor math to handle it
16+ allow_mixed_math_with_scalar_tensor : bool = True
17+
1318 # If true, we will convert Views into torchax.Tensors eagerly
1419 force_materialize_views : bool = False
1520
Original file line number Diff line number Diff line change @@ -639,6 +639,10 @@ def t2j_iso(self, torchtensors):
639639 """
640640
641641 def to_jax (x ):
642+ if self .config .allow_mixed_math_with_scalar_tensor and not isinstance (
643+ x , Tensor ):
644+ if x .squeeze ().ndim == 0 :
645+ return x .item ()
642646 if isinstance (
643647 x , torch .distributed ._functional_collectives .AsyncCollectiveTensor ):
644648 x = x .wait ()
You can’t perform that action at this time.
0 commit comments