diff --git a/examples/hetero/recommender_system.py b/examples/hetero/recommender_system.py index 8dd5d8349123..824a19fffb8a 100644 --- a/examples/hetero/recommender_system.py +++ b/examples/hetero/recommender_system.py @@ -23,7 +23,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') -data = MovieLens(path, model_name='all-MiniLM-L6-v2')[0] +data = MovieLens(path)[0] # Add user node features for message passing: data['user'].x = torch.eye(data['user'].num_nodes) diff --git a/torch_geometric/datasets/movie_lens.py b/torch_geometric/datasets/movie_lens.py index 1c9f435c6a74..f73ec687b42d 100644 --- a/torch_geometric/datasets/movie_lens.py +++ b/torch_geometric/datasets/movie_lens.py @@ -42,7 +42,7 @@ def __init__( root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, - model_name: Optional[str] = 'all-MiniLM-L6-v2', + model_name: Optional[str] = 'Alibaba-NLP/gte-modernbert-base', force_reload: bool = False, ) -> None: self.model_name = model_name @@ -68,7 +68,8 @@ def download(self) -> None: def process(self) -> None: import pandas as pd - from sentence_transformers import SentenceTransformer + + from torch_geometric.llm.models import SentenceTransformer data = HeteroData() @@ -78,10 +79,11 @@ def process(self) -> None: genres = df['genres'].str.get_dummies('|').values genres = torch.from_numpy(genres).to(torch.float) - model = SentenceTransformer(self.model_name) + model = SentenceTransformer(model_name=self.model_name) + if torch.cuda.is_available(): + model = model.cuda() with torch.no_grad(): - emb = model.encode(df['title'].values, show_progress_bar=True, - convert_to_tensor=True).cpu() + emb = model.encode(text=df['title'].values).cpu() data['movie'].x = torch.cat([emb, genres], dim=-1)