91 changes: 71 additions & 20 deletions rectools/model_selection/cross_validate.py
@@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import typing as tp
from contextlib import contextmanager

import pandas as pd

from rectools.columns import Columns
from rectools.dataset import Dataset
@@ -24,6 +27,26 @@
from .splitter import Splitter


@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
Collaborator:

It seems this function accepts the label and timings params but doesn't really need them.

Let's please rewrite it in one of the following ways:

  1. Remove both params and simply return the elapsed time, without dictionaries (a sketch of this option is added after the example below)
  2. Rewrite it as a class

I personally prefer the second option since it's clearer.
But in any case, let's not use these labels and the dictionary inside; we can easily fill them outside the class.

An example (it's simplified a bit; please add __init__ if required by linters, and also add types):

class Timer:        
    def __enter__(self):
        self._start = time.perf_counter()
        self._end = None
        return self

    def __exit__(self, *args):
        self._end = time.perf_counter()

    @property
    def elapsed(self):
        return self._end - self._start
    
    
with Timer() as timer:
    # code
    pass

fit_time = timer.elapsed
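
For completeness, a minimal sketch of option 1 (a bare context manager with no label or dict; the name `timing` and the one-element-list trick are only illustrative, not part of this PR) could be:

import time
import typing as tp
from contextlib import contextmanager


@contextmanager
def timing() -> tp.Iterator[tp.List[float]]:
    # Yield a one-element list that is filled with the elapsed time on exit
    result = [0.0]
    start = time.perf_counter()
    try:
        yield result
    finally:
        result[0] = time.perf_counter() - start


with timing() as elapsed:
    pass  # code to measure

fit_time = elapsed[0]

The caller would then fill the timings dict itself, e.g. timings["fit_time"] = fit_time.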

"""
Context manager to compute timing for a code block.

Parameters
----------
label : str
Label to store the timing result in the timings dictionary.
timings : dict, optional
Dictionary to store the timing results. If None, timing is not recorded.
"""
if timings is not None:
start_time = time.time()
Collaborator:

Please use time.perf_counter instead; it is better suited to measuring time intervals (a sketch follows after this function).

yield
timings[label] = round(time.time() - start_time, 2)
else:
yield
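
If the context-manager form is kept instead of the Timer class, a sketch of the same function using time.perf_counter, as requested in the comment above (same signature as in this diff, not the author's final code), might look like:

@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
    # Skip measurement entirely when no timings dict is passed
    if timings is None:
        yield
        return
    start_time = time.perf_counter()  # monotonic clock, preferred over time.time() for intervals
    try:
        yield
    finally:
        timings[label] = round(time.perf_counter() - start_time, 2)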


def cross_validate( # pylint: disable=too-many-locals
Collaborator:

Please update CHANGELOG.md.

dataset: Dataset,
splitter: Splitter,
@@ -36,6 +59,7 @@ def cross_validate( # pylint: disable=too-many-locals
ref_models: tp.Optional[tp.List[str]] = None,
validate_ref_models: bool = False,
on_unsupported_targets: ErrorBehaviour = "warn",
compute_timings: bool = False,
Collaborator:

Please add the new argument to the docstring (a possible entry is sketched below).
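
A possible wording for the docstring entry (phrasing here is only a suggestion, not from the diff):

compute_timings : bool, default ``False``
    Whether to measure ``fit_time`` and ``recommend_time`` for each model on each split
    and add them to the per-split metric rows.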

Collaborator:

Do we really need this param? What's wrong with always measuring the time?

) -> tp.Dict[str, tp.Any]:
"""
Run cross validation on multiple models with multiple metrics.
@@ -100,6 +124,47 @@
]
}
"""

def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.DataFrame, tp.Dict[str, tp.Any]]:
"""
Trains the given recommendation model on a dataset split and generates recommendations.

Parameters
----------
model : ModelBase
The recommendation model to be trained and used for generating recommendations.
Must be an instance of a subclass of `rectools.models.base.ModelBase`.
ref_model : bool, optional, default False
Indicates whether the model is a reference model used for comparison. If True,
and `validate_ref_models` is False, this model's recommendations may be reused
across splits without being refitted.

Returns
-------
tuple(pd.DataFrame, dict)
- A DataFrame with recommendations.
- A dictionary containing timing metrics (`fit_time` and `recommend_time`), if
`compute_timings` is enabled; otherwise, an empty dictionary.
"""
timings: tp.Optional[tp.Dict[str, float]] = (
{} if compute_timings and (validate_ref_models or not ref_model) else None
)

with compute_timing("fit_time", timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

return reco, (timings or {})

split_iterator = splitter.split(dataset.interactions, collect_fold_stats=True)

split_infos = []
@@ -123,35 +188,20 @@

# ### Train ref models if any
ref_reco = {}
ref_res = {}
for model_name in ref_models or []:
model = models[model_name]
model.fit(fold_dataset)
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)
ref_reco[model_name], ref_res[model_name] = fit_recommend(model, ref_model=True)

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue

if model_name in ref_reco:
reco = ref_reco[model_name]
model_res = ref_res[model_name]
else:
model.fit(fold_dataset)
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)
reco, model_res = fit_recommend(model)

metric_values = calc_metrics(
metrics,
@@ -163,6 +213,7 @@
)
res = {"model": model_name, "i_split": split_info["i_split"]}
res.update(metric_values)
res.update(model_res)
metrics_all.append(res)

result = {"splits": split_infos, "metrics": metrics_all}
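
As a usage sketch of the new flag (the model mapping and the final DataFrame step are hypothetical, not part of this diff):

cv_results = cross_validate(
    dataset=dataset,
    splitter=splitter,
    metrics=metrics,
    models={"popular": PopularModel()},  # hypothetical model mapping
    k=10,
    filter_viewed=True,
    compute_timings=True,
)
# With compute_timings=True each row in cv_results["metrics"] also carries
# "fit_time" and "recommend_time" (seconds, rounded to 2 decimals)
timings = pd.DataFrame(cv_results["metrics"])[["model", "i_split", "fit_time", "recommend_time"]]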