@@ -267,6 +267,58 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
267267"""
268268
269269
270+ def in_marimo_notebook () -> bool :
271+ try :
272+ import marimo as mo
273+
274+ return mo .running_in_notebook ()
275+ except ImportError :
276+ return False
277+
278+
279+ def _mo_write_internal (cell_id , stream , value : object ) -> None :
280+ """Write to marimo cell given cell_id and stream."""
281+ from marimo ._output import formatting
282+ from marimo ._messaging .ops import CellOp
283+ from marimo ._messaging .tracebacks import write_traceback
284+ from marimo ._messaging .cell_output import CellChannel
285+
286+ output = formatting .try_format (value )
287+ if output .traceback is not None :
288+ write_traceback (output .traceback )
289+ CellOp .broadcast_output (
290+ channel = CellChannel .OUTPUT ,
291+ mimetype = output .mimetype ,
292+ data = output .data ,
293+ cell_id = cell_id ,
294+ status = None ,
295+ stream = stream ,
296+ )
297+
298+
299+ def _mo_create_replace ():
300+ """Create mo.output.replace with current context pinned."""
301+ from marimo ._runtime .context import get_context
302+ from marimo ._runtime .context .types import ContextNotInitializedError
303+ from marimo ._output import formatting
304+
305+ try :
306+ ctx = get_context ()
307+ except ContextNotInitializedError :
308+ return
309+
310+ cell_id = ctx .execution_context .cell_id
311+ execution_context = ctx .execution_context
312+ stream = ctx .stream
313+
314+ def replace (value ):
315+ execution_context .output = [formatting .as_html (value )]
316+
317+ _mo_write_internal (cell_id = cell_id , value = value , stream = stream )
318+
319+ return replace
320+
321+
270322# Adapted from fastprogress
271323def in_notebook ():
272324 def in_colab ():
@@ -362,6 +414,28 @@ def callback(formatted):
362414 self ._html = formatted
363415 self .display_id .update (self )
364416
417+ progress_type = _lib .ProgressType .template_callback (
418+ progress_rate , progress_template , cores , callback
419+ )
420+ elif in_marimo_notebook ():
421+ import marimo as mo
422+
423+ if progress_template is None :
424+ progress_template = _progress_template
425+
426+ if progress_style is None :
427+ progress_style = _progress_style
428+
429+ self ._html = ""
430+
431+ mo .output .clear ()
432+ mo_output_replace = _mo_create_replace ()
433+
434+ def callback (formatted ):
435+ self ._html = formatted
436+ html = mo .Html (f"{ progress_style } \n { formatted } " )
437+ mo_output_replace (html )
438+
365439 progress_type = _lib .ProgressType .template_callback (
366440 progress_rate , progress_template , cores , callback
367441 )
0 commit comments