Source code for torchvideo.transforms.transforms.normalize_video

import numbers
from typing import Union, Sequence

import torch

from .. import functional as VF
from .transform import Transform


class NormalizeVideo(Transform[torch.Tensor, torch.Tensor, None]):
    r"""
    Normalise ``torch.*Tensor`` :math:`t` given mean:
    :math:`M = (\mu_1, \ldots, \mu_n)` and std:
    :math:`\Sigma = (\sigma_1, \ldots, \sigma_n)`:
    :math:`t'_c = \frac{t_c - M_c}{\Sigma_c}`

    Args:
        mean: Sequence of means for each channel, or a single mean applying to
            all channels.
        std: Sequence of standard deviations for each channel, or a single
            standard deviation applying to all channels.
        channel_dim: Index of channel dimension. 0 for ``'CTHW'`` tensors and
            1 for ``'TCHW'`` tensors.
        inplace: Whether or not to perform the operation in place without
            allocating a new tensor.

    Raises:
        ValueError: If ``std`` is 0, or is a sequence containing a 0 entry
            (normalising would divide by zero).
    """

    def __init__(
        self,
        mean: Union[Sequence[numbers.Number], numbers.Number],
        std: Union[Sequence[numbers.Number], numbers.Number],
        channel_dim: int = 0,
        inplace: bool = False,
    ):
        self.mean = mean
        self.std = std
        self.inplace = inplace
        self.channel_dim = channel_dim

        # Reject zero standard deviations up front rather than failing (or
        # silently producing inf/nan) when the transform is applied.
        if isinstance(std, numbers.Number) and std == 0:
            raise ValueError("std cannot be 0")
        if isinstance(std, Sequence) and any(s == 0 for s in std):
            raise ValueError("std {} contained 0 value, cannot be 0".format(std))

    def _gen_params(self, frames: torch.Tensor) -> None:
        """Return ``None``: this transform has no per-call parameters."""
        return None

    def __repr__(self) -> str:
        return (
            self.__class__.__name__
            + "(mean={mean!r}, std={std!r}, channel_dim={channel_dim!r})".format(
                mean=self.mean, std=self.std, channel_dim=self.channel_dim
            )
        )

    def _transform(self, frames: torch.Tensor, params: None) -> torch.Tensor:
        """Normalise ``frames`` along ``self.channel_dim``.

        Scalar ``mean``/``std`` values are expanded to one entry per channel
        (the size of ``frames`` along ``self.channel_dim``) before being
        handed to ``VF.normalize``.

        Args:
            frames: Tensor to normalise.
            params: Unused; always ``None`` for this transform.

        Returns:
            The result of ``VF.normalize``.
        """
        channel_count = frames.shape[self.channel_dim]
        mean = self._broadcast_to_seq(self.mean, channel_count)
        std = self._broadcast_to_seq(self.std, channel_count)
        return VF.normalize(
            frames, mean, std, inplace=self.inplace, channel_dim=self.channel_dim
        )

    @staticmethod
    def _broadcast_to_seq(
        x: Union[numbers.Number, Sequence], channel_count: int
    ) -> Sequence[numbers.Number]:
        """Expand a scalar to a per-channel list; pass sequences through.

        Args:
            x: A single number, or a value assumed to already be a sequence.
            channel_count: Number of times to repeat ``x`` when it is a scalar.

        Returns:
            ``[x] * channel_count`` if ``x`` is a number, otherwise ``x``
            unchanged (its length is not checked here).
        """
        if isinstance(x, numbers.Number):
            return [x] * channel_count
        # else assume already a sequence
        return x