Skip to content
This repository was archived by the owner on Nov 14, 2023. It is now read-only.

Commit a9dc83d

Browse files
authored
Fix ExperimentAnalysis usage for Ray 2.7 (#272)
Ray 2.7 has some breaking API changes for `tune.ExperimentAnalysis`, which `tune-sklearn` depends on to construct the output. This PR fixes the API usage. --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com>
1 parent 6e813e1 commit a9dc83d

File tree

5 files changed

+42
-40
lines changed

5 files changed

+42
-40
lines changed

tests/test_randomizedsearch.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_local_dir(self):
156156
parameter_grid,
157157
early_stopping=scheduler,
158158
max_iters=10,
159-
local_dir="./test-result")
159+
local_dir=os.path.abspath("./test-result"))
160160
tune_search.fit(x, y)
161161

162162
self.assertTrue(len(os.listdir("./test-result")) != 0)
@@ -290,7 +290,7 @@ def test_warm_start_detection(self):
290290
parameter_grid,
291291
n_jobs=1,
292292
max_iters=10,
293-
local_dir="./test-result")
293+
local_dir=os.path.abspath("./test-result"))
294294
self.assertEqual(tune_search.early_stop_type,
295295
EarlyStopping.NO_EARLY_STOP)
296296

@@ -301,7 +301,7 @@ def test_warm_start_detection(self):
301301
parameter_grid,
302302
n_jobs=1,
303303
max_iters=10,
304-
local_dir="./test-result")
304+
local_dir=os.path.abspath("./test-result"))
305305
self.assertEqual(tune_search2.early_stop_type,
306306
EarlyStopping.NO_EARLY_STOP)
307307

@@ -312,7 +312,7 @@ def test_warm_start_detection(self):
312312
parameter_grid,
313313
n_jobs=1,
314314
max_iters=10,
315-
local_dir="./test-result")
315+
local_dir=os.path.abspath("./test-result"))
316316

317317
self.assertEqual(tune_search3.early_stop_type,
318318
EarlyStopping.NO_EARLY_STOP)
@@ -323,7 +323,7 @@ def test_warm_start_detection(self):
323323
early_stopping=True,
324324
n_jobs=1,
325325
max_iters=10,
326-
local_dir="./test-result")
326+
local_dir=os.path.abspath("./test-result"))
327327
self.assertEqual(tune_search4.early_stop_type,
328328
EarlyStopping.WARM_START_ITER)
329329

@@ -334,7 +334,7 @@ def test_warm_start_detection(self):
334334
early_stopping=True,
335335
n_jobs=1,
336336
max_iters=10,
337-
local_dir="./test-result")
337+
local_dir=os.path.abspath("./test-result"))
338338
self.assertEqual(tune_search5.early_stop_type,
339339
EarlyStopping.WARM_START_ENSEMBLE)
340340

@@ -349,7 +349,7 @@ def test_warm_start_error(self):
349349
n_jobs=1,
350350
early_stopping=False,
351351
max_iters=10,
352-
local_dir="./test-result")
352+
local_dir=os.path.abspath("./test-result"))
353353
self.assertFalse(tune_search._can_early_stop())
354354
with self.assertRaises(ValueError):
355355
tune_search = TuneSearchCV(
@@ -358,7 +358,7 @@ def test_warm_start_error(self):
358358
n_jobs=1,
359359
early_stopping=True,
360360
max_iters=10,
361-
local_dir="./test-result")
361+
local_dir=os.path.abspath("./test-result"))
362362

363363
from sklearn.linear_model import LogisticRegression
364364
clf = LogisticRegression()
@@ -370,7 +370,7 @@ def test_warm_start_error(self):
370370
early_stopping=True,
371371
n_jobs=1,
372372
max_iters=10,
373-
local_dir="./test-result")
373+
local_dir=os.path.abspath("./test-result"))
374374

375375
from sklearn.ensemble import RandomForestClassifier
376376
clf = RandomForestClassifier()
@@ -382,21 +382,24 @@ def test_warm_start_error(self):
382382
early_stopping=True,
383383
n_jobs=1,
384384
max_iters=10,
385-
local_dir="./test-result")
385+
local_dir=os.path.abspath("./test-result"))
386386

387387
def test_warn_reduce_maxiters(self):
388388
parameter_grid = {"alpha": Real(1e-4, 1e-1, prior="log-uniform")}
389389
from sklearn.ensemble import RandomForestClassifier
390390
clf = RandomForestClassifier(max_depth=2, random_state=0)
391391
with self.assertWarnsRegex(UserWarning, "max_iters is set"):
392392
TuneSearchCV(
393-
clf, parameter_grid, max_iters=10, local_dir="./test-result")
393+
clf,
394+
parameter_grid,
395+
max_iters=10,
396+
local_dir=os.path.abspath("./test-result"))
394397
with self.assertWarnsRegex(UserWarning, "max_iters is set"):
395398
TuneSearchCV(
396399
SGDClassifier(),
397400
parameter_grid,
398401
max_iters=10,
399-
local_dir="./test-result")
402+
local_dir=os.path.abspath("./test-result"))
400403

401404
def test_warn_early_stop(self):
402405
X, y = make_classification(
@@ -893,9 +896,9 @@ def testHyperoptPointsToEvaluate(self):
893896
from ray.tune.search.hyperopt import HyperOptSearch
894897
# Skip test if category conversion is not available
895898
if not hasattr(HyperOptSearch, "_convert_categories_to_indices"):
896-
self.skipTest(f"The current version of Ray does not support the "
897-
f"`points_to_evaluate` argument for search method "
898-
f"`hyperopt`. Skipping test.")
899+
self.skipTest("The current version of Ray does not support the "
900+
"`points_to_evaluate` argument for search method "
901+
"`hyperopt`. Skipping test.")
899902
return
900903
self._test_points_to_evaluate("hyperopt")
901904

tune_sklearn/_trainable.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def _train(self):
319319
return_train_score=self.return_train_score,
320320
error_score=self.error_score)
321321
except ValueError as e:
322-
if ("It is very likely that your"
323-
"model is misconfigured") not in str(e):
322+
if ("It is very likely that your model is misconfigured" not in
323+
str(e)):
324324
raise e
325325
fit_failed = True
326326

@@ -367,9 +367,9 @@ def _train(self):
367367

368368
return ret
369369

370-
def save_checkpoint(self, checkpoint_dir):
370+
def save_checkpoint(self, checkpoint_dir: str):
371371
# forward-compatibility
372-
return self._save(checkpoint_dir)
372+
self._save(checkpoint_dir)
373373

374374
def _save(self, checkpoint_dir):
375375
"""Creates a checkpoint in ``checkpoint_dir``, creating a pickle file.
@@ -387,21 +387,21 @@ def _save(self, checkpoint_dir):
387387
cpickle.dump(self.estimator_list, f)
388388
except Exception:
389389
warnings.warn("Unable to save estimator.", category=RuntimeWarning)
390-
return path
391390

392-
def load_checkpoint(self, checkpoint):
391+
def load_checkpoint(self, checkpoint_dir: str):
393392
# forward-compatibility
394-
return self._restore(checkpoint)
393+
self._restore(checkpoint_dir)
395394

396-
def _restore(self, checkpoint):
395+
def _restore(self, checkpoint_dir):
397396
"""Loads a checkpoint created from `save`.
398397
399398
Args:
400399
checkpoint (str): file path to pickled checkpoint file.
401400
402401
"""
402+
path = os.path.join(checkpoint_dir, "checkpoint")
403403
try:
404-
with open(checkpoint, "rb") as f:
404+
with open(path, "rb") as f:
405405
self.estimator_list = cpickle.load(f)
406406
except Exception:
407407
warnings.warn("No estimator restored", category=RuntimeWarning)

tune_sklearn/tune_basesearch.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def __init__(self,
327327
verbose=0,
328328
error_score="raise",
329329
return_train_score=False,
330-
local_dir="~/ray_results",
330+
local_dir=None,
331331
name=None,
332332
max_iters=1,
333333
use_gpu=False,
@@ -773,32 +773,31 @@ def _format_results(self, n_splits, out):
773773
trials = [
774774
trial for trial in out.trials if trial.status == Trial.TERMINATED
775775
]
776-
trial_dirs = [trial.logdir for trial in trials]
777-
# The result dataframes are indexed by their trial logdir
778-
trial_dfs = out.fetch_trial_dataframes()
776+
trial_dfs = out.trial_dataframes
777+
trial_ids = list(trial_dfs)
779778

780779
# Try to find a template df to use for trials that did not return
781780
# any results. These trials should copy the structure and fill it
782781
# with NaNs so that the later reshape actions work.
783782
template_df = None
784-
fix_trial_dirs = [] # Holds trial dirs with no results
785-
for trial_dir in trial_dirs:
786-
if trial_dir in trial_dfs and template_df is None:
787-
template_df = trial_dfs[trial_dir]
788-
elif trial_dir not in trial_dfs:
789-
fix_trial_dirs.append(trial_dir)
783+
fix_trial_ids = [] # Holds trial_ids with no results
784+
for trial_id, trial_df in trial_dfs.items():
785+
if template_df is None and not trial_df.empty:
786+
template_df = trial_df
787+
elif trial_df.empty:
788+
fix_trial_ids.append(trial_id)
790789

791790
# Create NaN dataframes for trials without results
792-
if fix_trial_dirs:
791+
if fix_trial_ids:
793792
if template_df is None:
794793
# No trial returned any results
795794
return {}
796-
for trial_dir in fix_trial_dirs:
795+
for trial_id in fix_trial_ids:
797796
trial_df = pd.DataFrame().reindex_like(template_df)
798-
trial_dfs[trial_dir] = trial_df
797+
trial_dfs[trial_id] = trial_df
799798

800799
# Keep right order
801-
dfs = [trial_dfs[trial_dir] for trial_dir in trial_dirs]
800+
dfs = [trial_dfs[trial_id] for trial_id in trial_ids]
802801
finished = [df.iloc[[-1]] for df in dfs]
803802
test_scores = {}
804803
train_scores = {}

tune_sklearn/tune_gridsearch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def __init__(self,
157157
verbose=0,
158158
error_score="raise",
159159
return_train_score=False,
160-
local_dir="~/ray_results",
160+
local_dir=None,
161161
name=None,
162162
max_iters=1,
163163
use_gpu=False,

tune_sklearn/tune_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def __init__(self,
316316
random_state=None,
317317
error_score=np.nan,
318318
return_train_score=False,
319-
local_dir="~/ray_results",
319+
local_dir=None,
320320
name=None,
321321
max_iters=1,
322322
search_optimization="random",

0 commit comments

Comments
 (0)