class ProgressBar(BaseLogger):
    """
    TQDM progress bar handler to log training progress and computed metrics.

    Args:
        persist: set to ``True`` to persist the progress bar after completion (default = ``False``)
        bar_format : Specify a custom bar string formatting. May impact performance.
            [default: '{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]'].
            Set to ``None`` to use ``tqdm`` default bar formatting: '{l_bar}{bar}{r_bar}', where
            l_bar='{desc}: {percentage:3.0f}%|' and
            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'. For more details on the
            formatting, see `tqdm docs <https://tqdm.github.io/docs/tqdm/>`_.
        tqdm_kwargs: kwargs passed to tqdm progress bar.
            By default, progress bar description displays "Epoch [5/10]" where 5 is the current epoch and 10 is the
            number of epochs; however, if ``max_epochs`` are set to 1, the progress bar instead displays
            "Iteration: [5/10]". If tqdm_kwargs defines `desc`, e.g. "Predictions", then the description is
            "Predictions [5/10]" if number of epochs is more than one otherwise it is simply "Predictions".

    Examples:
        Simple progress bar

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer)

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      [06:17<12:34]

        Log output to a file instead of stderr (tqdm's default output)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            log_file = open("output.log", "w")
            pbar = ProgressBar(file=log_file)
            pbar.attach(trainer)

        Attach metrics that already have been computed at
        :attr:`~ignite.engine.events.Events.ITERATION_COMPLETED`
        (such as :class:`~ignite.metrics.RunningAverage`)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

            pbar = ProgressBar()
            pbar.attach(trainer, ['loss'])

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]

        Directly attach the engine's output

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {'loss': x})

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]

        Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
        are also logged along with the NLL and Accuracy after each iteration:

        .. code-block:: python

            pbar.attach(
                trainer,
                metric_names=["nll", "accuracy"],
                state_attributes=["alpha", "beta"],
            )

    Note:
        When attaching the progress bar to an engine, it is recommended that you replace
        every print operation in the engine's handlers triggered every iteration with
        ``pbar.log_message`` to guarantee the correct format of the stdout.

    Note:
        When using inside jupyter notebook, `ProgressBar` automatically uses `tqdm_notebook`. For correct rendering,
        please install `ipywidgets <https://ipywidgets.readthedocs.io/en/stable/user_install.html#installation>`_.
        Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, bar format may be needed to be set
        to an empty string value.

    .. versionchanged:: 0.5.0
        `attach` now accepts an optional list of `state_attributes`

    """

    # Chronological order of the supported events; `_compare_lt` uses the
    # position in this list to check that logging happens before closing.
    _events_order: List[Union[Events, CallableEventWithFilter]] = [
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
        Events.COMPLETED,
    ]

    def __init__(
        self,
        persist: bool = False,
        bar_format: Optional[str] = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
        **tqdm_kwargs: Any,
    ):
        # tqdm is an optional dependency, hence the lazy import. The original
        # ImportError is chained so the root cause stays visible in tracebacks.
        try:
            from tqdm.autonotebook import tqdm
        except ImportError as err:
            raise RuntimeError(
                "This contrib module requires tqdm to be installed. "
                "Please install it with command: \n pip install tqdm"
            ) from err

        self.pbar_cls = tqdm
        # The bar itself is created lazily by `_reset` and dropped by `_close`.
        self.pbar = None
        self.persist = persist
        self.bar_format = bar_format
        self.tqdm_kwargs = tqdm_kwargs

    def _reset(self, pbar_total: Optional[int]) -> None:
        """Create a fresh tqdm bar with ``pbar_total`` steps and store it in ``self.pbar``."""
        # NOTE(review): `initial=1` presumably mirrors the 1-based event counters
        # of the engine state -- confirm against the handler that updates the bar.
        self.pbar = self.pbar_cls(
            total=pbar_total, leave=self.persist, bar_format=self.bar_format, initial=1, **self.tqdm_kwargs
        )

    def _close(self, engine: Engine) -> None:
        """Close and discard the current bar, if any. ``engine`` is unused but required by the handler signature."""
        if self.pbar is not None:
            # https://github.com/tqdm/notebook.py#L240-L250
            # issue #1115 : notebook backend of tqdm checks if n < total (error or KeyboardInterrupt)
            # and the bar persists in 'danger' mode
            if self.pbar.total is not None:
                self.pbar.n = self.pbar.total
            self.pbar.close()
        self.pbar = None

    @staticmethod
    def _compare_lt(
        event1: Union[Events, CallableEventWithFilter], event2: Union[Events, CallableEventWithFilter]
    ) -> bool:
        """Return ``True`` if ``event1`` comes strictly before ``event2`` in ``_events_order``.

        Raises:
            ValueError: if either event is not listed in ``_events_order`` (raised by ``list.index``).
        """
        i1 = ProgressBar._events_order.index(event1)
        i2 = ProgressBar._events_order.index(event2)
        return i1 < i2

    def log_message(self, message: str) -> None:
        """
        Logs a message, preserving the progress bar correct output format.

        Args:
            message: string you wish to log.
        """
        from tqdm import tqdm

        # Write to the same stream the bar uses (`file` tqdm kwarg) when one was given.
        tqdm.write(message, file=self.tqdm_kwargs.get("file", None))

    def attach(  # type: ignore[override]
        self,
        engine: Engine,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED,
        closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
        state_attributes: Optional[List[str]] = None,
    ) -> None:
        """
        Attaches the progress bar to an engine object.

        Args:
            engine: engine object.
            metric_names: list of metric names to plot or a string "all" to plot all available
                metrics.
            output_transform: a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.
            event_name: event's name on which the progress bar advances. Valid events are from
                :class:`~ignite.engine.events.Events`.
            closing_event_name: event's name on which the progress bar is closed. Valid events are from
                :class:`~ignite.engine.events.Events`.
            state_attributes: list of attributes of the ``trainer.state`` to plot.

        Raises:
            ValueError: if ``event_name`` is not allowed for ``engine``, if ``closing_event_name`` is a
                filtered event, or if ``event_name`` does not come before ``closing_event_name``.

        Note:
            Accepted output value types are numbers, 0d and 1d torch tensors and strings.

        """
        desc = self.tqdm_kwargs.get("desc", None)

        if event_name not in engine._allowed_events:
            # Fall back to the object itself when it has no `.name` (e.g. a plain
            # string), so the caller gets the intended ValueError, not AttributeError.
            shown_name = getattr(event_name, "name", event_name)
            raise ValueError(f"Logging event {shown_name} is not in allowed events for this engine")

        if isinstance(closing_event_name, CallableEventWithFilter):
            if closing_event_name.filter != CallableEventWithFilter.default_event_filter:
                raise ValueError("Closing Event should not be a filtered event")

        if not self._compare_lt(event_name, closing_event_name):
            raise ValueError(f"Logging event {event_name} should be called before closing event {closing_event_name}")

        log_handler = _OutputHandler(
            desc,
            metric_names,
            output_transform,
            closing_event_name=closing_event_name,
            state_attributes=state_attributes,
        )

        super().attach(engine, log_handler, event_name)
        engine.add_event_handler(closing_event_name, self._close)

class _OutputHandler(BaseOutputHandler):
    """Internal handler that renders engine output, metrics and state attributes on a ``ProgressBar``.

    Args:
        description: progress bar description.
        metric_names: list of metric names to plot or a string "all" to plot all available
            metrics.
        output_transform: output transform function to prepare `engine.state.output` as a number.
            For example, `output_transform = lambda output: output`
            This function can also return a dictionary, e.g `{'loss': loss1, 'another_loss': loss2}` to label the plot
            with corresponding keys.
        closing_event_name: event's name on which the progress bar is closed. Valid events are from
            :class:`~ignite.engine.events.Events` or any `event_name` added by
            :meth:`~ignite.engine.engine.Engine.register_events`.
        state_attributes: list of attributes of the ``trainer.state`` to plot.

    """

    def __init__(
        self,
        description: str,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
        state_attributes: Optional[List[str]] = None,
    ):
        nothing_requested = metric_names is None and output_transform is None
        if nothing_requested:
            # An empty list sidesteps the 'Either metric_names or output_transform
            # should be defined' check of BaseOutputHandler.
            metric_names = []
        super().__init__(
            description,
            metric_names,
            output_transform,
            global_step_transform=None,
            state_attributes=state_attributes,
        )
        self.closing_event_name = closing_event_name

    @staticmethod
    def get_max_number_events(event_name: Union[str, Events, CallableEventWithFilter], engine: Engine) -> Optional[int]:
        """Return how many times ``event_name`` fires per cycle.

        Iteration events map to ``engine.state.epoch_length``, epoch events to
        ``engine.state.max_epochs``, and every other event to ``1``.
        """
        iteration_events = (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED)
        epoch_events = (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED)
        if event_name in iteration_events:
            return engine.state.epoch_length
        return engine.state.max_epochs if event_name in epoch_events else 1

    def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, Events]) -> None:
        """Refresh the bar's description, postfix and position for the current event."""
        pbar_total = self.get_max_number_events(event_name, engine)
        # The bar is created lazily, on the first event after start-up or a close.
        if logger.pbar is None:
            logger._reset(pbar_total=pbar_total)

        # Description: user tag if set, otherwise "Iteration" for single-epoch runs and "Epoch" for the rest.
        fallback = "Iteration" if engine.state.max_epochs == 1 else "Epoch"
        text = self.tag or fallback
        closing_total = self.get_max_number_events(self.closing_event_name, engine)
        if closing_total and closing_total > 1:
            closing_step = engine.state.get_event_attrib_value(self.closing_event_name)
            text = text + f" [{closing_step}/{closing_total}]"
        logger.pbar.set_description(text)  # type: ignore[attr-defined]

        # Postfix: drop the first key component, since tqdm already shows the tag as description.
        rendered = self._setup_output_metrics_state_attrs(engine, log_text=True)
        postfix = OrderedDict(("_".join(parts[1:]), shown) for parts, shown in rendered.items())
        if postfix:
            logger.pbar.set_postfix(postfix)  # type: ignore[attr-defined]

        # Position: wrap the event counter into the range 1..pbar_total when the total is known.
        raw_step = engine.state.get_event_attrib_value(event_name)
        step = raw_step if pbar_total is None else (raw_step - 1) % pbar_total + 1
        logger.pbar.update(step - logger.pbar.n)  # type: ignore[attr-defined]