Skip to content

Commit 49eaa53

Browse files
committed
add model.sae and model.load_model
1 parent 9a1303a commit 49eaa53

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tensordiffeq/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def predict(self, X_star):
180180
f_u_star = self.f_model(self.u_model, *X_star)
181181
return u_star.numpy(), f_u_star.numpy()
182182

183+
def save(self, path):
184+
self.u_model.save(path)
185+
186+
def load_model(self, path):
187+
self.u_model = tf.keras.models.load_model(path)
188+
183189

184190
# WIP
185191
# TODO Distributed Discovery Model

0 commit comments

Comments
 (0)