Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,58 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
"""


def in_marimo_notebook() -> bool:
try:
import marimo as mo

return mo.running_in_notebook()
except ImportError:
return False


def _mo_write_internal(cell_id, stream, value: object) -> None:
"""Write to marimo cell given cell_id and stream."""
from marimo._output import formatting
from marimo._messaging.ops import CellOp
from marimo._messaging.tracebacks import write_traceback
from marimo._messaging.cell_output import CellChannel

output = formatting.try_format(value)
if output.traceback is not None:
write_traceback(output.traceback)
CellOp.broadcast_output(
channel=CellChannel.OUTPUT,
mimetype=output.mimetype,
data=output.data,
cell_id=cell_id,
status=None,
stream=stream,
)


def _mo_create_replace():
"""Create mo.output.replace with current context pinned."""
from marimo._runtime.context import get_context
from marimo._runtime.context.types import ContextNotInitializedError
from marimo._output import formatting

try:
ctx = get_context()
except ContextNotInitializedError:
return

cell_id = ctx.execution_context.cell_id
execution_context = ctx.execution_context
stream = ctx.stream

def replace(value):
execution_context.output = [formatting.as_html(value)]

_mo_write_internal(cell_id=cell_id, value=value, stream=stream)

return replace


# Adapted from fastprogress
def in_notebook():
def in_colab():
Expand Down Expand Up @@ -362,6 +414,28 @@ def callback(formatted):
self._html = formatted
self.display_id.update(self)

progress_type = _lib.ProgressType.template_callback(
progress_rate, progress_template, cores, callback
)
elif in_marimo_notebook():
import marimo as mo

if progress_template is None:
progress_template = _progress_template

if progress_style is None:
progress_style = _progress_style

self._html = ""

mo.output.clear()
mo_output_replace = _mo_create_replace()

def callback(formatted):
self._html = formatted
html = mo.Html(f"{progress_style}\n{formatted}")
mo_output_replace(html)

progress_type = _lib.ProgressType.template_callback(
progress_rate, progress_template, cores, callback
)
Expand Down