"""Video datasets stored as folders of frame images (``torchvideo.datasets.image_folder_video_dataset``)."""

import torch
from pathlib import Path
from typing import Union, Optional, Callable, Tuple, List, Iterator

import PIL.Image
from PIL.Image import Image

from torchvideo.samplers import FrameSampler, frame_idx_to_list, _default_sampler
from torchvideo.transforms import PILVideoToTensor
from .video_dataset import VideoDataset
from .types import Label, empty_label, PILVideoTransform
from .helpers import invoke_transform
from .label_sets import LabelSet


class ImageFolderVideoDataset(VideoDataset):
    """Dataset stored as a folder containing folders of images, where each folder
    represents a video.

    The folder hierarchy should look something like this: ::

        root/video1/frame_000001.jpg
        root/video1/frame_000002.jpg
        root/video1/frame_000003.jpg
        ...

        root/video2/frame_000001.jpg
        root/video2/frame_000002.jpg
        root/video2/frame_000003.jpg
        root/video2/frame_000004.jpg
        ...

    Only sub-directories of ``root`` are treated as videos; plain files placed
    directly inside ``root`` (e.g. ``.DS_Store`` or a label file) are ignored.
    """

    def __init__(
        self,
        root_path: Union[str, Path],
        filename_template: str,
        filter: Optional[Callable[[Path], bool]] = None,
        label_set: Optional[LabelSet] = None,
        sampler: FrameSampler = _default_sampler(),
        transform: Optional[PILVideoTransform] = None,
        frame_counter: Optional[Callable[[Path], int]] = None,
    ):
        """
        Args:
            root_path: Path to dataset on disk. Contents of this folder should be
                example folders, each with frames named according to the
                ``filename_template`` argument.
            filename_template: Python 3 style formatting string describing frame
                filenames: e.g. ``"frame_{:06d}.jpg"`` for the example dataset in
                the class docstring. Each sampled frame index ``i`` is rendered as
                ``filename_template.format(i + 1)``, so frame numbering on disk
                starts at 1.
            filter: Optional filter callable that decides whether a given example
                folder is to be included in the dataset or not. It is only called
                on directories inside ``root_path``.
            label_set: Optional label set for labelling examples. Labels are
                looked up by video folder name.
            sampler: Optional sampler for drawing frames from each video.
            transform: Optional transform performed over the loaded clip. When
                not given, :class:`PILVideoToTensor` is used.
            frame_counter: Optional callable used to determine the number of
                frames each video contains. The callable will be passed the path
                to a video folder and should return a positive integer
                representing the number of frames. This tends to be useful if
                you've precomputed the number of frames in a dataset. When not
                given, every entry inside a video folder is counted as a frame.
        """
        super().__init__(root_path, label_set, sampler=sampler, transform=transform)
        # Only directories can contain frames. A plain file in the root would
        # otherwise be treated as a video and fail either when counting its
        # frames (``iterdir()`` on a file) or when loading frames from it.
        self._video_dirs = sorted(
            d
            for d in self.root_path.iterdir()
            if d.is_dir() and (filter is None or filter(d))
        )
        self.labels = self._label_examples(self._video_dirs, label_set)
        self.video_lengths = self._measure_video_lengths(
            self._video_dirs, frame_counter
        )
        self.filename_template = filename_template
        if self.transform is None:
            self.transform = PILVideoToTensor()

    @property
    def video_ids(self) -> List[Path]:
        """Paths of the video folders, in dataset index order."""
        return self._video_dirs

    def __len__(self) -> int:
        """Return the number of videos (example folders) in the dataset."""
        return len(self._video_dirs)

    def __getitem__(
        self, index: int
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Label]]:
        """Load, sample and transform the clip at ``index``.

        Args:
            index: Position of the video in :attr:`video_ids`.

        Returns:
            The transformed clip alone when its (post-transform) label equals
            ``empty_label``, otherwise a ``(clip, label)`` tuple.

        Raises:
            ValueError: If a sampled frame file does not exist on disk.
        """
        video_folder = self._video_dirs[index]
        video_length = self.video_lengths[index]
        frames_idx = self.sampler.sample(video_length)
        frames = self._load_frames(frames_idx, video_folder)
        if self.labels is not None:
            label = self.labels[index]
        else:
            label = empty_label
        frames_tensor, label = invoke_transform(self.transform, frames, label)
        if label == empty_label:
            return frames_tensor
        return frames_tensor, label

    @staticmethod
    def _measure_video_lengths(
        video_dirs, frame_counter: Optional[Callable[[Path], int]]
    ) -> List[int]:
        """Return the frame count of each folder in ``video_dirs``.

        Without a ``frame_counter`` this counts every directory entry, so it
        assumes each video folder contains only frame files.
        """
        if frame_counter is None:
            return [len(list(video_dir.iterdir())) for video_dir in video_dirs]
        return [frame_counter(video_dir) for video_dir in video_dirs]

    @staticmethod
    def _label_examples(video_dirs, label_set: Optional[LabelSet]):
        """Return the label of each video folder (keyed by folder name), or
        ``None`` when no label set is given."""
        if label_set is not None:
            return [label_set[video_dir.name] for video_dir in video_dirs]
        return None

    def _load_frames(
        self, frames_idx: Union[slice, List[slice], List[int]], video_folder: Path
    ) -> Iterator[Image]:
        """Lazily open the frames selected by ``frames_idx`` from ``video_folder``.

        Returns:
            A generator of PIL images. Each file is only opened when the
            generator is consumed (typically by the transform).
        """
        frame_numbers = frame_idx_to_list(frames_idx)
        # Frame files on disk are numbered from 1, hence ``index + 1``.
        filepaths = [
            video_folder / self.filename_template.format(index + 1)
            for index in frame_numbers
        ]
        return (self._load_image(path) for path in filepaths)

    def _load_image(self, path: Path) -> Image:
        """Open the image at ``path`` with PIL.

        Raises:
            ValueError: If ``path`` does not exist.
        """
        if not path.exists():
            raise ValueError("Image path {} does not exist".format(path))
        return PIL.Image.open(str(path))