Source code for ignite.handlers.terminate_on_nan
import logging
import numbers
from typing import Callable, Union
import torch
from ignite.engine import Engine
from ignite.utils import apply_to_type
# Names exported by ``from <this module> import *``.
__all__ = ["TerminateOnNan"]
class TerminateOnNan:
    """Stop the training if the ``process_function``'s output contains a NaN or infinite value.

    The output can be of type: number, ``torch.Tensor`` or a collection of them. The training
    is stopped if at least one number/tensor has a NaN or infinite value. For example, if the
    output is ``[1.23, torch.tensor(...), torch.tensor(float('nan'))]`` the handler will stop
    the training.

    Args:
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into a number
            or ``torch.Tensor`` or collection of them. This can be useful if, for example, you
            have a multi-output model and you want to check one or multiple values of the output.

    Examples:

    .. code-block:: python

        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    """

    def __init__(self, output_transform: Callable = lambda x: x):
        self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        # ``logging.getLogger`` hands back the same logger object for the same name, so all
        # instances of this class share one logger. Attaching a new StreamHandler on every
        # construction would make each warning print once per instance ever created; only
        # attach one when no plain StreamHandler is present yet. The exact-type check is
        # deliberate: subclasses such as FileHandler write elsewhere and must not suppress
        # the default stream output.
        if not any(type(handler) is logging.StreamHandler for handler in self.logger.handlers):
            self.logger.addHandler(logging.StreamHandler())
        self._output_transform = output_transform

    def __call__(self, engine: Engine) -> None:
        """Check the engine's current output and terminate the engine on NaN/Inf.

        Args:
            engine: the engine whose ``state.output`` is inspected (after being passed through
                ``output_transform``) and on which ``terminate()`` is called when a
                non-finite value is found.
        """
        output = self._output_transform(engine.state.output)

        def raise_error(x: Union[float, torch.Tensor]) -> None:
            # Wrap plain Python numbers in a tensor so that a single finiteness check
            # covers both numbers and tensors.
            if isinstance(x, numbers.Number):
                x = torch.tensor(x)

            if isinstance(x, torch.Tensor) and not bool(torch.isfinite(x).all()):
                # The exception is only a signal to abort the traversal below; it is
                # caught right there and never reaches the caller.
                raise RuntimeError("Infinite or NaN tensor found.")

        try:
            # Presumably ``apply_to_type`` walks ``output`` and calls ``raise_error`` on every
            # number/tensor it contains -- see ``ignite.utils.apply_to_type``.
            apply_to_type(output, (numbers.Number, torch.Tensor), raise_error)
        except RuntimeError:
            self.logger.warning(f"{self.__class__.__name__}: Output '{output}' contains NaN or Inf. Stop training")
            engine.terminate()