From e3229817d0e3f63534e1cd60a41454c0a8410ec0 Mon Sep 17 00:00:00 2001 From: henry-ald Date: Wed, 3 Dec 2025 10:49:47 +0000 Subject: [PATCH] add complex support --- stringgen/string_emulator.py | 39 ++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/stringgen/string_emulator.py b/stringgen/string_emulator.py index 8e9693f..59e9893 100644 --- a/stringgen/string_emulator.py +++ b/stringgen/string_emulator.py @@ -11,6 +11,14 @@ from pkg_resources import resource_filename +def complex_to_real(x): + return np.array([x.real, x.imag]).ravel() + + +def real_to_complex(x): + return x[: len(x) // 2] + 1j * x[len(x) // 2 :] + + class CosmicStringEmulator: def __init__( self, @@ -26,6 +34,7 @@ def __init__( self.emulation_shape = emulation_shape self.norm = norm # Normalization self.pbc = pbc # Periodic boundary conditions + self.cplx = cplx # Complex inputs if device is None: self.device = 0 if torch.cuda.is_available() else "cpu" else: @@ -196,13 +205,34 @@ def emulate(self, features, n_emulations=1, max_iterations=100): del tensor_norms # initialise random starting image with mean=0, std=1 - x0 = np.random.normal(0, 1, self.emulation_shape) + if self.cplx: + x0 = np.random.normal( + 0, + np.sqrt(0.5), + self.emulation_shape, + ) + 1j * np.random.normal( + 0, + np.sqrt(0.5), + self.emulation_shape, + ) + x0 = complex_to_real(x0) + else: + x0 = np.random.normal( + 0, + 1, + self.emulation_shape, + ) + x0 = x0.ravel() + coeffs = torch.from_numpy(wph_coeffs).to(self.device).contiguous() # synthesis def objective(x): start_time = time.time() # Reshape x + if self.cplx: + x = real_to_complex(x) + x_curr = x.reshape((self.emulation_shape)) # Compute the loss (squared 2-norm) loss_tot = torch.zeros(1) @@ -228,12 +258,14 @@ def objective(x): print( f"Loss: {loss_tot.item()} (computed in {time.time() - start_time}s)" ) + if self.cplx: + return loss_tot.item(), complex_to_real(x_grad) return loss_tot.item(), x_grad.ravel() total_start_time = time.time() result = opt.minimize( objective, - x0.ravel(), + x0, method="L-BFGS-B", jac=True, tol=None, @@ -250,6 +282,9 @@ def objective(x): ) print(f"Synthesis time: {time.time() - total_start_time}s") + if self.cplx: + x_final = real_to_complex(x_final) + x_final = x_final.reshape(self.emulation_shape).astype(np.float32) x_final = x_final * std + mean emulations.append(x_final)