Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions stringgen/string_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from pkg_resources import resource_filename


def complex_to_real(x):
return np.array([x.real, x.imag]).ravel()


def real_to_complex(x):
return x[: len(x) // 2] + 1j * x[len(x) // 2 :]


class CosmicStringEmulator:
def __init__(
self,
Expand All @@ -26,6 +34,7 @@ def __init__(
self.emulation_shape = emulation_shape
self.norm = norm # Normalization
self.pbc = pbc # Periodic boundary conditions
self.cplx = cplx # Complex inputs
if device is None:
self.device = 0 if torch.cuda.is_available() else "cpu"
else:
Expand Down Expand Up @@ -196,13 +205,34 @@ def emulate(self, features, n_emulations=1, max_iterations=100):
del tensor_norms

# initialise random starting image with mean=0, std=1
x0 = np.random.normal(0, 1, self.emulation_shape)
if self.cplx:
x0 = np.random.normal(
0,
np.sqrt(0.5),
self.emulation_shape,
) + 1j * np.random.normal(
0,
np.sqrt(0.5),
self.emulation_shape,
)
x0 = complex_to_real(x0)
else:
x0 = np.random.normal(
0,
1,
self.emulation_shape,
)
x0 = x0.ravel()

coeffs = torch.from_numpy(wph_coeffs).to(self.device).contiguous()

# synthesis
def objective(x):
start_time = time.time()
# Reshape x
if self.cplx:
x = real_to_complex(x)

x_curr = x.reshape((self.emulation_shape))
# Compute the loss (squared 2-norm)
loss_tot = torch.zeros(1)
Expand All @@ -228,12 +258,14 @@ def objective(x):
print(
f"Loss: {loss_tot.item()} (computed in {time.time() - start_time}s)"
)
if self.cplx:
return loss_tot.item(), complex_to_real(x_grad)
return loss_tot.item(), x_grad.ravel()

total_start_time = time.time()
result = opt.minimize(
objective,
x0.ravel(),
x0,
method="L-BFGS-B",
jac=True,
tol=None,
Expand All @@ -250,6 +282,9 @@ def objective(x):
)
print(f"Synthesis time: {time.time() - total_start_time}s")

if self.cplx:
x_final = real_to_complex(x_final)

x_final = x_final.reshape(self.emulation_shape).astype(np.float32)
x_final = x_final * std + mean
emulations.append(x_final)
Expand Down