Skip to content
Open
58 changes: 58 additions & 0 deletions python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import uuid4
from ..models import message_asdict_factory, Message, NewMessage, User
from ..ychat import YChat
from ..utils import find_mentions_callback

USER = User(
username=str(uuid4()),
Expand Down Expand Up @@ -167,6 +168,63 @@ def test_update_message_should_append_content():
assert message_dict["sender"] == msg.sender


def test_update_message_includes_mentions():
chat = YChat()
chat.set_user(USER)
chat.set_user(USER2)
chat.set_user(USER3)

new_msg = create_new_message(f"@{USER2.mention_name} Hello!")
msg_id = chat.add_message(new_msg)
msg = chat.get_message(msg_id)
assert msg
assert msg.mentions == [USER2.username]

msg.body = f"@{USER3.mention_name} Goodbye!"
chat.update_message(msg, trigger_actions=[find_mentions_callback])
updated_msg = chat.get_message(msg_id)
assert updated_msg
assert updated_msg.mentions == [USER3.username]


def test_update_message_append_includes_mentions():
chat = YChat()
chat.set_user(USER)
chat.set_user(USER2)
chat.set_user(USER3)

new_msg = create_new_message(f"@{USER2.mention_name} Hello!")
msg_id = chat.add_message(new_msg)
msg = chat.get_message(msg_id)
assert msg
assert msg.mentions == [USER2.username]

msg.body = f" and @{USER3.mention_name}!"
chat.update_message(msg, append=True, trigger_actions=[find_mentions_callback])
updated_msg = chat.get_message(msg_id)
assert updated_msg
assert sorted(updated_msg.mentions) == sorted([USER2.username, USER3.username])


def test_update_message_append_no_duplicate_mentions():
chat = YChat()
chat.set_user(USER)
chat.set_user(USER2)

new_msg = create_new_message(f"@{USER2.mention_name} Hello!")
msg_id = chat.add_message(new_msg)
msg = chat.get_message(msg_id)
assert msg
assert msg.mentions == [USER2.username]

msg.body = f" @{USER2.mention_name} again!"
chat.update_message(msg, append=True, trigger_actions=[find_mentions_callback])
updated_msg = chat.get_message(msg_id)
assert updated_msg
assert updated_msg.mentions == [USER2.username]
assert len(updated_msg.mentions) == 1


def test_indexes_by_id():
chat = YChat()
msg = create_new_message()
Expand Down
33 changes: 33 additions & 0 deletions python/jupyterlab-chat/jupyterlab_chat/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

"""Utility functions for jupyter-chat."""

import re
from typing import TYPE_CHECKING, Set

if TYPE_CHECKING:
from .models import Message
from .ychat import YChat


def find_mentions(message: "Message", chat: "YChat") -> None:
"""
Callback to extract and update mentions in a message.

Finds all @mentions in the message body and updates the message's mentions list
with the corresponding usernames.

Args:
message: The message object to update
chat: The YChat instance for accessing user data
"""
mention_pattern = re.compile(r"@([\w-]+):?")
mentioned_names: Set[str] = set(re.findall(mention_pattern, message.body))
users = chat.get_users()
mentioned_usernames = []
for username, user in users.items():
if user.mention_name in mentioned_names and user.username not in mentioned_usernames:
mentioned_usernames.append(username)

message.mentions = mentioned_usernames
37 changes: 25 additions & 12 deletions python/jupyterlab-chat/jupyterlab_chat/ychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import re

from .models import message_asdict_factory, FileAttachment, NotebookAttachment, Message, NewMessage, User
from .utils import find_mentions_callback


class YChat(YBaseDoc):
Expand Down Expand Up @@ -124,10 +125,18 @@ def _get_messages(self) -> list[dict]:
"""
return self._ymessages.to_py() or []

def add_message(self, new_message: NewMessage) -> str:
def add_message(self, new_message: NewMessage, trigger_actions: list[Callable] | None = None) -> str:
"""
Append a message to the document.

Args:
new_message: The message to add
trigger_actions: List of callbacks to execute on the message. Defaults to [find_mentions_callback].
Each callback receives (message, chat) as arguments.
"""
if trigger_actions is None:
trigger_actions = [find_mentions_callback]

timestamp: float = time.time()
uid = str(uuid4())
message = Message(
Expand All @@ -136,15 +145,9 @@ def add_message(self, new_message: NewMessage) -> str:
id=uid,
)

# find all mentioned users and add them as message mentions
mention_pattern = re.compile("@([\w-]+):?")
mentioned_names: Set[str] = set(re.findall(mention_pattern, message.body))
users = self.get_users()
mentioned_usernames = []
for username, user in users.items():
if user.mention_name in mentioned_names and user.username not in mentioned_usernames:
mentioned_usernames.append(username)
message.mentions = mentioned_usernames
# Execute all trigger action callbacks
for callback in trigger_actions:
callback(message, self)

with self._ydoc.transaction():
index = len(self._ymessages) - next((i for i, v in enumerate(self._get_messages()[::-1]) if v["time"] < timestamp), len(self._ymessages))
Expand All @@ -155,17 +158,27 @@ def add_message(self, new_message: NewMessage) -> str:

return uid

def update_message(self, message: Message, append: bool = False):
def update_message(self, message: Message, append: bool = False, trigger_actions: list[Callable] | None = None):
"""
Update a message of the document.
If append is True, the content will be append to the previous content.

Args:
message: The message to update
append: If True, the content will be appended to the previous content
trigger_actions: List of callbacks to execute on the message. Each callback receives (message, chat) as arguments.
"""
with self._ydoc.transaction():
index = self._indexes_by_id[message.id]
initial_message = self._ymessages[index]
message.time = initial_message["time"] # type:ignore[index]
if append:
message.body = initial_message["body"] + message.body # type:ignore[index]

# Execute all trigger action callbacks
if trigger_actions:
for callback in trigger_actions:
callback(message, self)

self._ymessages[index] = asdict(message, dict_factory=message_asdict_factory)

def get_attachments(self) -> dict[str, Union[FileAttachment, NotebookAttachment]]:
Expand Down