diff --git a/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py b/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py index b812549e..30bfa010 100644 --- a/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py +++ b/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py @@ -9,6 +9,7 @@ from uuid import uuid4 from ..models import message_asdict_factory, Message, NewMessage, User from ..ychat import YChat +from ..utils import find_mentions USER = User( username=str(uuid4()), @@ -167,6 +168,63 @@ def test_update_message_should_append_content(): assert message_dict["sender"] == msg.sender +def test_update_message_includes_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + chat.set_user(USER3) + + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert msg.mentions == [USER2.username] + + msg.body = f"@{USER3.mention_name} Goodbye!" + chat.update_message(msg, trigger_actions=[find_mentions]) + updated_msg = chat.get_message(msg_id) + assert updated_msg + assert updated_msg.mentions == [USER3.username] + + +def test_update_message_append_includes_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + chat.set_user(USER3) + + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert msg.mentions == [USER2.username] + + msg.body = f" and @{USER3.mention_name}!" + chat.update_message(msg, append=True, trigger_actions=[find_mentions]) + updated_msg = chat.get_message(msg_id) + assert updated_msg + assert sorted(updated_msg.mentions) == sorted([USER2.username, USER3.username]) + + +def test_update_message_append_no_duplicate_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert msg.mentions == [USER2.username] + + msg.body = f" @{USER2.mention_name} again!" + chat.update_message(msg, append=True, trigger_actions=[find_mentions]) + updated_msg = chat.get_message(msg_id) + assert updated_msg + assert updated_msg.mentions == [USER2.username] + assert len(updated_msg.mentions) == 1 + + def test_indexes_by_id(): chat = YChat() msg = create_new_message() diff --git a/python/jupyterlab-chat/jupyterlab_chat/utils.py b/python/jupyterlab-chat/jupyterlab_chat/utils.py new file mode 100644 index 00000000..3046b9c1 --- /dev/null +++ b/python/jupyterlab-chat/jupyterlab_chat/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +"""Utility functions for jupyter-chat.""" + +import re +from typing import TYPE_CHECKING, Set + +if TYPE_CHECKING: + from .models import Message + from .ychat import YChat + + +def find_mentions(message: "Message", chat: "YChat") -> None: + """ + Callback to extract and update mentions in a message. + + Finds all @mentions in the message body and updates the message's mentions list + with the corresponding usernames. + + Args: + message: The message object to update + chat: The YChat instance for accessing user data + """ + mention_pattern = re.compile(r"@([\w-]+):?") + mentioned_names: Set[str] = set(re.findall(mention_pattern, message.body)) + users = chat.get_users() + mentioned_usernames = [] + for username, user in users.items(): + if user.mention_name in mentioned_names and user.username not in mentioned_usernames: + mentioned_usernames.append(username) + + message.mentions = mentioned_usernames diff --git a/python/jupyterlab-chat/jupyterlab_chat/ychat.py b/python/jupyterlab-chat/jupyterlab_chat/ychat.py index a89e4fe0..38814fc3 100644 --- a/python/jupyterlab-chat/jupyterlab_chat/ychat.py +++ b/python/jupyterlab-chat/jupyterlab_chat/ychat.py @@ -15,6 +15,7 @@ import re from .models import message_asdict_factory, FileAttachment, NotebookAttachment, Message, NewMessage, User +from .utils import find_mentions class YChat(YBaseDoc): @@ -124,10 +125,18 @@ def _get_messages(self) -> list[dict]: """ return self._ymessages.to_py() or [] - def add_message(self, new_message: NewMessage) -> str: + def add_message(self, new_message: NewMessage, trigger_actions: list[Callable] | None = None) -> str: """ Append a message to the document. + + Args: + new_message: The message to add + trigger_actions: List of callbacks to execute on the message. Defaults to [find_mentions]. + Each callback receives (message, chat) as arguments. """ + if trigger_actions is None: + trigger_actions = [find_mentions] + timestamp: float = time.time() uid = str(uuid4()) message = Message( @@ -136,15 +145,9 @@ def add_message(self, new_message: NewMessage) -> str: id=uid, ) - # find all mentioned users and add them as message mentions - mention_pattern = re.compile("@([\w-]+):?") - mentioned_names: Set[str] = set(re.findall(mention_pattern, message.body)) - users = self.get_users() - mentioned_usernames = [] - for username, user in users.items(): - if user.mention_name in mentioned_names and user.username not in mentioned_usernames: - mentioned_usernames.append(username) - message.mentions = mentioned_usernames + # Execute all trigger action callbacks + for callback in trigger_actions: + callback(message, self) with self._ydoc.transaction(): index = len(self._ymessages) - next((i for i, v in enumerate(self._get_messages()[::-1]) if v["time"] < timestamp), len(self._ymessages)) @@ -155,10 +158,14 @@ def add_message(self, new_message: NewMessage) -> str: return uid - def update_message(self, message: Message, append: bool = False): + def update_message(self, message: Message, append: bool = False, trigger_actions: list[Callable] | None = None): """ Update a message of the document. - If append is True, the content will be append to the previous content. + + Args: + message: The message to update + append: If True, the content will be appended to the previous content + trigger_actions: List of callbacks to execute on the message. Each callback receives (message, chat) as arguments. """ with self._ydoc.transaction(): index = self._indexes_by_id[message.id] @@ -166,6 +173,12 @@ def update_message(self, message: Message, append: bool = False): message.time = initial_message["time"] # type:ignore[index] if append: message.body = initial_message["body"] + message.body # type:ignore[index] + + # Execute all trigger action callbacks + if trigger_actions: + for callback in trigger_actions: + callback(message, self) + self._ymessages[index] = asdict(message, dict_factory=message_asdict_factory) def get_attachments(self) -> dict[str, Union[FileAttachment, NotebookAttachment]]: