diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index bbb7b9c..0655173 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -267,6 +267,58 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): """ +def in_marimo_notebook() -> bool: + try: + import marimo as mo + + return mo.running_in_notebook() + except ImportError: + return False + + +def _mo_write_internal(cell_id, stream, value: object) -> None: + """Write to marimo cell given cell_id and stream.""" + from marimo._output import formatting + from marimo._messaging.ops import CellOp + from marimo._messaging.tracebacks import write_traceback + from marimo._messaging.cell_output import CellChannel + + output = formatting.try_format(value) + if output.traceback is not None: + write_traceback(output.traceback) + CellOp.broadcast_output( + channel=CellChannel.OUTPUT, + mimetype=output.mimetype, + data=output.data, + cell_id=cell_id, + status=None, + stream=stream, + ) + + +def _mo_create_replace(): + """Create mo.output.replace with current context pinned.""" + from marimo._runtime.context import get_context + from marimo._runtime.context.types import ContextNotInitializedError + from marimo._output import formatting + + try: + ctx = get_context() + except ContextNotInitializedError: + return + + cell_id = ctx.execution_context.cell_id + execution_context = ctx.execution_context + stream = ctx.stream + + def replace(value): + execution_context.output = [formatting.as_html(value)] + + _mo_write_internal(cell_id=cell_id, value=value, stream=stream) + + return replace + + # Adapted from fastprogress def in_notebook(): def in_colab(): @@ -362,6 +414,28 @@ def callback(formatted): self._html = formatted self.display_id.update(self) + progress_type = _lib.ProgressType.template_callback( + progress_rate, progress_template, cores, callback + ) + elif in_marimo_notebook(): + import marimo as mo + + if progress_template is None: + progress_template = _progress_template + + if progress_style is None: + progress_style = _progress_style + + self._html = "" + + mo.output.clear() + mo_output_replace = _mo_create_replace() + + def callback(formatted): + self._html = formatted + html = mo.Html(f"{progress_style}\n{formatted}") + mo_output_replace(html) + progress_type = _lib.ProgressType.template_callback( progress_rate, progress_template, cores, callback )