From a601b148513afc289197ee4776bf2a089cd35659 Mon Sep 17 00:00:00 2001
From: Joao Felipe Guedes
Date: Sun, 19 Apr 2020 13:56:56 -0300
Subject: [PATCH] Including method SVD.get_utility_matrix

---
 funk_svd/svd.py   | 16 ++++++++++++++++
 run_experiment.py | 15 ++++++++++++---
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/funk_svd/svd.py b/funk_svd/svd.py
index 16ec092..7c3e2fc 100644
--- a/funk_svd/svd.py
+++ b/funk_svd/svd.py
@@ -253,3 +253,19 @@ def _on_epoch_end(self, start, val_loss=None, val_rmse=None, val_mae=None):
             print('val_mae: {:.2f}'.format(val_mae), end=' - ')
 
         print('took {:.1f} sec'.format(end - start))
+
+    def get_utility_matrix(self, X, fillna=0):
+        """Create a utility matrix from a [u_id, i_id, rating] dataframe.
+
+        Args:
+            X {pd.DataFrame} -- dataframe with columns u_id, i_id and rating
+            fillna {int} -- value used for the non-existing ratings
+
+        Returns:
+            pd.DataFrame -- utility matrix with users as index and items as
+                columns
+        """
+
+        return X.pivot(
+            index='u_id', columns='i_id', values='rating'
+        ).fillna(fillna)
diff --git a/run_experiment.py b/run_experiment.py
index 4c353ab..ad8c1df 100644
--- a/run_experiment.py
+++ b/run_experiment.py
@@ -6,7 +6,6 @@
 from sklearn.metrics import mean_absolute_error
 
 
-
 df = fetch_ml_ratings(variant='100k')
 
 train = df.sample(frac=0.8, random_state=7)
@@ -16,9 +15,19 @@
 svd = SVD(learning_rate=0.001, regularization=0.005, n_epochs=100,
           n_factors=15, min_rating=1, max_rating=5)
 
+df_matrix_original = svd.get_utility_matrix(df)
+print("Original Utility Matrix: \n", df_matrix_original.values)
+
+# Getting all u_id and i_id combinations
+df_user_item = pd.melt(df_matrix_original.reset_index(drop=False), id_vars='u_id')
+
 svd.fit(X=train, X_val=val, early_stopping=True, shuffle=False)
 
-pred = svd.predict(test)
-mae = mean_absolute_error(test["rating"], pred)
+pred_test = svd.predict(test)
+df_user_item["rating"] = svd.predict(df_user_item)
+
+print("Predicted Utility Matrix: \n", svd.get_utility_matrix(df_user_item).values)
+
+mae = mean_absolute_error(test["rating"], pred_test)
 
 print(f'Test MAE: {mae:.2f}')