59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from rich.console import ConsoleRenderable, Group, RichCast
|
|
from rich.progress import Progress, Task, TaskID
|
|
|
|
|
|
class CustomProgress(Progress):
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
cust_objs: list[ConsoleRenderable] | None = None,
|
|
cust_callables: list[Callable[[], ConsoleRenderable]] | None = None,
|
|
ft_callback: Callable[[Task], None] | None = None,
|
|
**kwargs,
|
|
) -> None:
|
|
self._cust_objs = cust_objs or []
|
|
self._cust_callables = cust_callables or []
|
|
self._ft_callback = ft_callback
|
|
if self._ft_callback:
|
|
kwargs["disable"] = True
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
def update(
|
|
self,
|
|
task_id: TaskID,
|
|
*,
|
|
total: float | None = None,
|
|
completed: float | None = None,
|
|
advance: float | None = None,
|
|
description: str | None = None,
|
|
visible: bool | None = None,
|
|
refresh: bool = False,
|
|
**fields: Any,
|
|
) -> None:
|
|
t = super().update(
|
|
task_id,
|
|
total=total,
|
|
completed=completed,
|
|
advance=advance,
|
|
description=description,
|
|
visible=visible,
|
|
refresh=refresh,
|
|
**fields,
|
|
)
|
|
if self._ft_callback:
|
|
self._ft_callback(
|
|
self.tasks[task_id],
|
|
)
|
|
return t
|
|
|
|
def get_renderable(self) -> ConsoleRenderable | RichCast | str:
|
|
objs = [obj for obj in self._cust_objs]
|
|
for cust_call in self._cust_callables:
|
|
objs.append(cust_call())
|
|
renderable = Group(*objs, *self.get_renderables())
|
|
return renderable
|