Source code for torchvideo.datasets.label_sets.csv_label_set

from typing import Optional

from .label_set import LabelSet, Label


[docs]class CsvLabelSet(LabelSet): """LabelSet for a pandas DataFrame or Series. The index of the DataFrame/Series is assumed to be the set of video names and the values in a series the label. For a dataframe the ``field`` kwarg specifies which field to use as the label Examples: >>> import pandas as pd >>> df = pd.DataFrame({'video': ['video1', 'video2'], ... 'label': [1, 2]}).set_index('video') >>> label_set = CsvLabelSet(df, col='label') >>> label_set['video1'] 1 """ def __init__(self, df, col: Optional[str] = "label"): """ Args: df: pandas DataFrame or Series containing video names/ids and their corresponding labels. col: The column to read the label from when df is a DataFrame. """ self.df = df self._field = col
[docs] def __getitem__(self, video_name: str) -> Label: try: return self.df[self._field].loc[video_name] except KeyError: return self.df.loc[video_name]