
Commit 9f954a8

modified analytic unit test to not train the model by adding train_* booleans
1 parent 28aaa49 commit 9f954a8

2 files changed (+14, -12 lines)

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 10 additions & 10 deletions
@@ -206,8 +206,8 @@ def train(
         learning_rate=None,
         convergence_criteria="all_converged",
         stopping_criteria=None,
-        train_mu: bool = None,
-        train_r: bool = None,
+        train_loc: bool = None,
+        train_scale: bool = None,
         use_batching=False,
         optim_algo=None,
         **kwargs
@@ -227,19 +227,19 @@ def train(
         :param stopping_criteria: Additional parameter for convergence criteria.
 
             See parameter `convergence_criteria` for exact meaning
-        :param train_mu: Set to True/False in order to enable/disable training of mu
-        :param train_r: Set to True/False in order to enable/disable training of r
+        :param train_loc: Set to True/False in order to enable/disable training of loc
+        :param train_scale: Set to True/False in order to enable/disable training of scale
         :param use_batching: If True, will use mini-batches with the batch size defined in the constructor.
             Otherwise, the gradient of the full dataset will be used.
         :param optim_algo: name of the requested train op.
             See :func:train_utils.MultiTrainer.train_op_by_name for further details.
         """
-        if train_mu is None:
+        if train_loc is None:
             # check if mu was initialized with MLE
             train_mu = self._train_loc
-        if train_r is None:
+        if train_scale is None:
             # check if r was initialized with MLE
-            train_r = self._train_scale
+            train_scale = self._train_scale
 
         # Check whether newton-rhapson is desired:
         require_hessian = False
@@ -290,15 +290,15 @@ def train(
         logging.getLogger("batchglm").debug("learning_rate " + str(learning_rate))
         logging.getLogger("batchglm").debug("convergence_criteria " + str(convergence_criteria))
         logging.getLogger("batchglm").debug("stopping_criteria " + str(stopping_criteria))
-        logging.getLogger("batchglm").debug("train_mu " + str(train_mu))
-        logging.getLogger("batchglm").debug("train_r " + str(train_r))
+        logging.getLogger("batchglm").debug("train_loc " + str(train_loc))
+        logging.getLogger("batchglm").debug("train_scale " + str(train_scale))
         logging.getLogger("batchglm").debug("use_batching " + str(use_batching))
         logging.getLogger("batchglm").debug("optim_algo " + str(optim_algo))
         if len(kwargs) > 0:
             logging.getLogger("batchglm").debug("**kwargs: ")
             logging.getLogger("batchglm").debug(kwargs)
 
-        if train_mu or train_r:
+        if train_loc or train_scale:
             if use_batching:
                 train_op = self.model.trainer_batch.train_op_by_name(optim_algo)
             else:
batchglm/unit_test/glm_all/test_acc_analytic_glm_all.py

Lines changed: 4 additions & 2 deletions
@@ -75,7 +75,9 @@ def estimate(self):
             {
                 "convergence_criteria": "all_converged",
                 "use_batching": False,
-                "optim_algo": "gd"
+                "optim_algo": "gd",
+                "train_loc": False,
+                "train_scale": False
             },
         ])
 
@@ -251,7 +253,7 @@ class Test_AccuracyAnalytic_GLM_NB(
 
     def test_a_closed_b_closed(self):
         logging.getLogger("tensorflow").setLevel(logging.ERROR),
-        logging.getLogger("batchglm").setLevel(logging.INFO)
+        logging.getLogger("batchglm").setLevel(logging.DEBUG)
         logger.error("Test_AccuracyAnalytic_GLM_NB.test_a_closed_b_closed()")
 
         self.noise_model = "nb"
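
With training disabled, the accuracy test now exercises only the closed-form ("analytic") initialization of the location and scale models rather than the gradient-descent result. A sketch of what the new parametrized entry amounts to at run time, assuming the test's estimate() helper forwards the dict to Estimator.train() (the forwarding is scaffolding inferred from context; the dict contents come from the diff above):

    # Assumed forwarding of the training-args dict shown in the diff.
    train_args = {
        "convergence_criteria": "all_converged",
        "use_batching": False,
        "optim_algo": "gd",
        "train_loc": False,    # keep the closed-form location estimates
        "train_scale": False,  # keep the closed-form scale estimates
    }
    estimator.train(**train_args)  # returns without ever running "gd"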
