# Source code for torchvideo.transforms.transforms.ndarray_to_pil_video

import numpy as np
import PIL
import PIL.Image

from .transform import Transform
from .types import PILVideoI


class NDArrayToPILVideo(Transform[np.ndarray, PILVideoI, None]):
    """Convert :py:class:`numpy.ndarray` of the format :math:`(T, H, W, C)` or
    :math:`(C, T, H, W)` to a PIL video (an iterator of PIL images)

    Each frame is converted with :py:func:`PIL.Image.fromarray`, so the
    frame's dtype and channel count must be ones that function accepts.
    Frames with a single trailing channel, i.e. shape :math:`(H, W, 1)`, are
    squeezed to :math:`(H, W)` before conversion.
    """

    def __init__(self, format="thwc"):
        """
        Args:
            format: dimensional layout of array, one of ``"thwc"`` or
                ``"cthw"`` (case-insensitive).

        Raises:
            ValueError: if ``format`` is not one of the supported layouts.
        """
        normalized_format = format.lower()
        if normalized_format not in {"thwc", "cthw"}:
            raise ValueError(
                "Invalid format {!r}, only 'thwc' and 'cthw' are "
                "supported".format(format)
            )
        # Store the normalized layout. Validation is case-insensitive, so
        # _transform must not compare against the caller's original casing.
        # Previously "CTHW" was accepted here but then treated as "thwc".
        self.format = normalized_format

    def _transform(self, frames: np.ndarray, params: None) -> PILVideoI:
        # This is a generator: frames are converted lazily, one per iteration.
        if self.format == "cthw":
            # Move axis 0 to the end: (C, T, H, W) -> (T, H, W, C), so that
            # iterating over the array walks the time axis.
            frames = np.moveaxis(frames, 0, -1)
        for frame in frames:
            if frame.ndim == 3 and frame.shape[-1] == 1:
                # PIL.Image.fromarray has no mode for (H, W, 1) arrays and
                # raises TypeError. Drop the singleton channel so a 2-D
                # (single-band) image is built instead.
                frame = frame[..., 0]
            yield PIL.Image.fromarray(frame)

    def _gen_params(self, frames: np.ndarray) -> None:
        # _transform ignores its params, so nothing is generated per call.
        return None

    def __repr__(self):
        return self.__class__.__name__ + "(format={!r})".format(self.format)