Skip to content

Commit 4a48cdd

Browse files
committed
push helmholtz example
1 parent b31e5fe commit 4a48cdd

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

examples/steady-state.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import math
2+
import matplotlib.pyplot as plt
3+
import tensorflow as tf
4+
import tensordiffeq as tdq
5+
from tensordiffeq.boundaries import *
6+
from tensordiffeq.models import CollocationSolverND
7+
from tensorflow.math import sin
8+
from tensordiffeq.utils import constant
9+
10+
Domain = DomainND(["x", "y"])
11+
12+
Domain.add("x", [-1.0, 1.0], 1001)
13+
Domain.add("y", [-1.0, 1.0], 1001)
14+
15+
N_f = 10000
16+
Domain.generate_collocation_points(N_f)
17+
18+
19+
def f_model(u_model, x, y):
20+
u = u_model(tf.concat([x, y], 1))
21+
u_x = tf.gradients(u, x)[0]
22+
u_y = tf.gradients(u, y)[0]
23+
u_xx = tf.gradients(u_x, x)[0]
24+
u_yy = tf.gradients(u_y, y)[0]
25+
26+
a1 = constant(1.0)
27+
a2 = constant(4.0)
28+
ksq = constant(1.0)
29+
pi = constant(math.pi)
30+
31+
# note that we must use tensorflow math primitives such as sin, cos, etc!
32+
forcing = - (a1 * pi) ** 2 * sin(a1 * pi * x) * sin(a2 * pi * y) - \
33+
(a2 * pi) ** 2 * sin(a1 * pi * x) * sin(a2 * pi * y) + \
34+
ksq * sin(a1 * pi * x) * sin(a2 * pi * y)
35+
print(np.shape(u_xx + u_yy))
36+
37+
f_u = u_xx + u_yy + ksq * u - forcing # = 0
38+
39+
return f_u
40+
41+
42+
upper_x = dirichletBC(Domain, val=0.0, var='x', target="upper")
43+
lower_x = dirichletBC(Domain, val=0.0, var='x', target="lower")
44+
upper_y = dirichletBC(Domain, val=0.0, var='y', target="upper")
45+
lower_y = dirichletBC(Domain, val=0.0, var='y', target="lower")
46+
47+
BCs = [upper_x, lower_x, upper_y, lower_y]
48+
49+
layer_sizes = [2, 50, 50, 50, 50, 1]
50+
51+
model = CollocationSolverND()
52+
model.compile(layer_sizes, f_model, Domain, BCs)
53+
54+
model.fit(tf_iter=100, newton_iter=100)
55+
56+
# get exact solution
57+
nx, ny = (1001, 1001)
58+
x = np.linspace(-1, 1, nx)
59+
y = np.linspace(-1, 1, ny)
60+
61+
xv, yv = np.meshgrid(x, y)
62+
63+
x = np.reshape(x, (-1, 1))
64+
y = np.reshape(y, (-1, 1))
65+
66+
# Exact analytical soln is available:
67+
Exact_u = np.sin(math.pi * xv) * np.sin(4 * math.pi * yv)
68+
69+
# Flatten for use
70+
u_star = Exact_u.flatten()[:, None]
71+
72+
# Plotting
73+
x = Domain.domaindict[0]['xlinspace']
74+
y = Domain.domaindict[1]["ylinspace"]
75+
76+
X, Y = np.meshgrid(x, y)
77+
78+
# print(np.shape((X,Y))) # 2, 256, 256
79+
X_star = np.hstack((X.flatten()[:, None], Y.flatten()[:, None]))
80+
81+
lb = np.array([-1.0, -1.0])
82+
ub = np.array([1.0, 1])
83+
84+
u_pred, f_u_pred = model.predict(X_star)
85+
86+
error_u = tdq.helpers.find_L2_error(u_pred, u_star)
87+
print('Error u: %e' % (error_u))
88+
89+
U_pred = tdq.plotting.get_griddata(X_star, u_pred.flatten(), (X, Y))
90+
FU_pred = tdq.plotting.get_griddata(X_star, f_u_pred.flatten(), (X, Y))
91+
92+
lb = np.array([-1.0, -1.0])
93+
ub = np.array([1.0, 1.0])
94+
95+
tdq.plotting.plot_solution_domain1D(model, [x, y], ub=ub, lb=lb, Exact_u=Exact_u)

0 commit comments

Comments
 (0)