
Commit bacafe9

cast inputs to f_model into tf.float32 for more stability
1 parent 64b3132 commit bacafe9

File tree

1 file changed: +9 -6 lines changed


tensordiffeq/models.py

Lines changed: 9 additions & 6 deletions
@@ -27,7 +27,8 @@ def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
         self.u_weights = u_weights
         self.X_f_dims = tf.shape(self.domain.X_f)
         self.X_f_len = tf.slice(self.X_f_dims, [0], [1]).numpy()
-        tmp = [np.reshape(vec, (-1,1)) for i, vec in enumerate(self.domain.X_f.T)]
+        # must explicitly cast data into tf.float32 for stability
+        tmp = [tf.cast(np.reshape(vec, (-1,1)), tf.float32) for i, vec in enumerate(self.domain.X_f.T)]
         self.X_f_in = np.asarray(tmp)
         self.u_model = neural_net(self.layer_sizes)

@@ -120,12 +121,14 @@ def loss_and_flat_grad(w):
         return loss_and_flat_grad

     def predict(self, X_star):
-        X_star = convertTensor(X_star)
+        # predict using concatenated data
         u_star = self.u_model(X_star)
-
-        f_u_star = self.f_model(self.u_model, X_star[:, 0:1],
-                                X_star[:, 1:2])
-
+        # split data into tuples for ND support
+        # must explicitly cast data into tf.float32 for stability
+        tmp = [tf.cast(np.reshape(vec, (-1,1)), tf.float32) for i, vec in enumerate(X_star.T)]
+        X_star = np.asarray(tmp)
+        X_star = tuple(X_star)
+        f_u_star = self.f_model(self.u_model, *X_star)
         return u_star.numpy(), f_u_star.numpy()
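For context, here is a minimal standalone sketch of the pattern this commit introduces: splitting an (N, d) NumPy array into per-dimension column vectors, casting each to tf.float32, and unpacking them into the residual function so the call stays N-dimensional. The u_model and f_model below are hypothetical stand-ins for illustration, not TensorDiffEq's own definitions; casting up front simply keeps everything at a single precision so float64 NumPy data never meets float32 network weights.

import numpy as np
import tensorflow as tf

# Dummy collocation points with shape (N, 2); NumPy defaults to float64,
# which can clash with the float32 weights of a Keras network.
X_star = np.random.uniform(size=(100, 2))

# Hypothetical stand-in for the solution network (not the library's neural_net).
u_model = tf.keras.Sequential([
    tf.keras.layers.Dense(20, activation="tanh"),
    tf.keras.layers.Dense(1),
])

def f_model(u_model, x, t):
    # Hypothetical placeholder residual; a real f_model would differentiate
    # u with respect to x and t via tf.GradientTape.
    return u_model(tf.concat([x, t], axis=1))

# Split the (N, d) array into d column vectors and cast each to tf.float32,
# mirroring the change made in compile() and predict().
cols = tuple(tf.cast(np.reshape(vec, (-1, 1)), tf.float32) for vec in X_star.T)

# Unpacking the per-dimension tensors generalizes to any number of dimensions.
u_star = u_model(tf.cast(X_star, tf.float32))
f_u_star = f_model(u_model, *cols)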
