diff --git a/examples/reference/chat/ChatFeed.ipynb b/examples/reference/chat/ChatFeed.ipynb index decb3fd775..cf3fc7e2b3 100644 --- a/examples/reference/chat/ChatFeed.ipynb +++ b/examples/reference/chat/ChatFeed.ipynb @@ -54,6 +54,7 @@ "* **`placeholder_text`** (str): The text to display next to the placeholder icon.\n", "* **`placeholder_params`** (dict) Defaults to `{\"user\": \" \", \"reaction_icons\": {}, \"show_copy_icon\": False, \"show_timestamp\": False}` Params to pass to the placeholder `ChatMessage`, like `reaction_icons`, `timestamp_format`, `show_avatar`, `show_user`, `show_timestamp`.\n", "* **`placeholder_threshold`** (float): Min duration in seconds of buffering before displaying the placeholder. If 0, the placeholder will be disabled. Defaults to 0.2.\n", + "* **`post_hook`** (callable): A hook to execute after a new message is *completely* added, i.e. the generator is exhausted. The `stream` method will trigger this callback on every call. The signature must include the `message` and `instance` arguments.\n", "* **`auto_scroll_limit`** (int): Max pixel distance from the latest object in the Column to activate automatic scrolling upon update. Setting to 0 disables auto-scrolling.\n", "* **`scroll_button_threshold`** (int): Min pixel distance from the latest object in the Column to display the scroll button. Setting to 0 disables the scroll button.\n", "* **`load_buffer`** (int): The number of objects loaded on each side of the visible objects. When scrolled halfway into the buffer, the feed will automatically load additional objects while unloading objects on the opposite side.\n", diff --git a/panel/chat/feed.py b/panel/chat/feed.py index 8c0147db9f..f458100b02 100644 --- a/panel/chat/feed.py +++ b/panel/chat/feed.py @@ -81,7 +81,7 @@ class ChatFeed(ListPanel): auto_scroll_limit = param.Integer(default=200, bounds=(0, None), doc=""" Max pixel distance from the latest object in the Column to activate automatic scrolling upon update. Setting to 0 - disables auto-scrolling.""",) + disables auto-scrolling.""") callback = param.Callable(allow_refs=False, doc=""" Callback to execute when a user sends a message or @@ -133,6 +133,11 @@ class ChatFeed(ListPanel): `help` as the user. This is useful for providing instructions, and will not be included in the `serialize` method by default.""") + load_buffer = param.Integer(default=50, bounds=(0, None), doc=""" + The number of objects loaded on each side of the visible objects. + When scrolled halfway into the buffer, the feed will automatically + load additional objects while unloading objects on the opposite side.""") + placeholder_text = param.String(default="", doc=""" The text to display next to the placeholder icon.""") @@ -148,6 +153,12 @@ class ChatFeed(ListPanel): Min duration in seconds of buffering before displaying the placeholder. If 0, the placeholder will be disabled.""") + post_hook = param.Callable(allow_refs=False, doc=""" + A hook to execute after a new message is *completely* added, + i.e. the generator is exhausted. The `stream` method will trigger + this callback on every call. The signature must include the + `message` and `instance` arguments.""") + renderers = param.HookList(doc=""" A callable or list of callables that accept the value and return a Panel object to render the value. If a list is provided, will @@ -155,11 +166,6 @@ class ChatFeed(ListPanel): exception. If None, will attempt to infer the renderer from the value.""") - load_buffer = param.Integer(default=50, bounds=(0, None), doc=""" - The number of objects loaded on each side of the visible objects. - When scrolled halfway into the buffer, the feed will automatically - load additional objects while unloading objects on the opposite side.""") - scroll_button_threshold = param.Integer(default=100, bounds=(0, None),doc=""" Min pixel distance from the latest object in the Column to display the scroll button. Setting to 0 @@ -182,6 +188,8 @@ class ChatFeed(ListPanel): _callback_trigger = param.Event(doc="Triggers the callback to respond.") + _post_hook_trigger = param.Event(doc="Triggers the append callback.") + _disabled_stack = param.List(doc=""" The previous disabled state of the feed.""") @@ -262,6 +270,7 @@ def __init__(self, *objects, **params): # handle async callbacks using this trick self.param.watch(self._prepare_response, '_callback_trigger') + self.param.watch(self._after_append_completed, '_post_hook_trigger') def _get_model( self, doc: Document, root: Model | None = None, @@ -430,6 +439,7 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None: response_message = self._upsert_message(await response, response_message) else: response_message = self._upsert_message(response, response_message) + self.param.trigger("_post_hook_trigger") finally: if response_message: response_message.show_activity_dot = False @@ -484,6 +494,7 @@ async def _handle_callback(self, message, loop: asyncio.BaseEventLoop): else: response = await asyncio.to_thread(self.callback, *callback_args) await self._serialize_response(response) + return response async def _prepare_response(self, *_) -> None: """ @@ -580,6 +591,7 @@ def send( value = {"object": value} message = self._build_message(value, user=user, avatar=avatar) self.append(message) + self.param.trigger("_post_hook_trigger") if respond: self.respond() return message @@ -644,6 +656,8 @@ def stream( value = {"object": value} message = self._build_message(value, user=user, avatar=avatar) self._replace_placeholder(message) + + self.param.trigger("_post_hook_trigger") return message def respond(self): @@ -758,6 +772,19 @@ def _serialize_for_transformers( serialized_messages.append({"role": role, "content": content}) return serialized_messages + async def _after_append_completed(self, message): + """ + Trigger the append callback after a message is added to the chat feed. + """ + if self.post_hook is None: + return + + message = self._chat_log.objects[-1] + if iscoroutinefunction(self.post_hook): + await self.post_hook(message, self) + else: + self.post_hook(message, self) + def serialize( self, exclude_users: List[str] | None = None, diff --git a/panel/tests/chat/test_feed.py b/panel/tests/chat/test_feed.py index 6050976442..3d7207da2f 100644 --- a/panel/tests/chat/test_feed.py +++ b/panel/tests/chat/test_feed.py @@ -997,3 +997,85 @@ def test_invalid(self): chat_feed = ChatFeed() chat_feed.send("I'm a user", user="user") chat_feed.serialize(format="atransform") + + +@pytest.mark.xdist_group("chat") +class TestChatFeedPostHook: + + def test_return_string(self, chat_feed): + def callback(contents, user, instance): + yield f"Echo: {contents}" + + def append_callback(message, instance): + logs.append(message.object) + + logs = [] + chat_feed.callback = callback + chat_feed.post_hook = append_callback + chat_feed.send("Hello World!") + wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!") + assert logs == ["Hello World!", "Echo: Hello World!"] + + def test_yield_string(self, chat_feed): + def callback(contents, user, instance): + yield f"Echo: {contents}" + + def append_callback(message, instance): + logs.append(message.object) + + logs = [] + chat_feed.callback = callback + chat_feed.post_hook = append_callback + chat_feed.send("Hello World!") + wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!") + assert logs == ["Hello World!", "Echo: Hello World!"] + + def test_generator(self, chat_feed): + def callback(contents, user, instance): + message = "Echo: " + for char in contents: + message += char + yield message + + def append_callback(message, instance): + logs.append(message.object) + + logs = [] + chat_feed.callback = callback + chat_feed.post_hook = append_callback + chat_feed.send("Hello World!") + wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!") + assert logs == ["Hello World!", "Echo: Hello World!"] + + def test_async_generator(self, chat_feed): + async def callback(contents, user, instance): + message = "Echo: " + for char in contents: + message += char + yield message + + async def append_callback(message, instance): + logs.append(message.object) + + logs = [] + chat_feed.callback = callback + chat_feed.post_hook = append_callback + chat_feed.send("Hello World!") + wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!") + assert logs == ["Hello World!", "Echo: Hello World!"] + + def test_stream(self, chat_feed): + def callback(contents, user, instance): + message = instance.stream("Echo: ") + for char in contents: + message = instance.stream(char, message=message) + + def append_callback(message, instance): + logs.append(message.object) + + logs = [] + chat_feed.callback = callback + chat_feed.post_hook = append_callback + chat_feed.send("AB") + wait_until(lambda: chat_feed.objects[-1].object == "Echo: AB") + assert logs == ["AB", "Echo: ", "Echo: AB"]