快捷方式

如何取樣影片剪輯

在此示例中,我們將學習如何從影片中取樣影片 剪輯。剪輯通常指幀的序列或批次,並且通常作為影片模型的輸入傳遞。

首先,是一些樣板程式碼:我們將從網上下載一個影片,並定義一個繪圖工具。您可以忽略這部分,直接跳轉到 建立解碼器

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title : Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

建立解碼器

從影片中取樣剪輯總是從建立一個 VideoDecoder 物件開始。如果您還不熟悉 VideoDecoder,請快速檢視:使用 VideoDecoder 解碼影片

from torchcodec.decoders import VideoDecoder

# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)

取樣基礎知識

我們現在可以使用解碼器來取樣剪輯。讓我們先看一個簡單的例子:所有其他取樣器都遵循類似的 API 和原理。我們將使用 clips_at_random_indices() 來取樣從隨機索引開始的剪輯。

from torchcodec.samplers import clips_at_random_indices

# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)

clips = clips_at_random_indices(
    decoder,
    num_clips=5,
    num_frames_per_clip=4,
    num_indices_between_frames=3,
)
clips
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

取樣器的輸出是一系列剪輯,表示為 FrameBatch 物件。在此物件中,我們有不同的欄位

  • data: 一個 5D uint8 張量,表示幀資料。其形狀為 (num_clips, num_frames_per_clip, …),其中 … 是 (C, H, W) 或 (H, W, C),具體取決於 VideoDecoderdimension_order 引數。這通常會傳遞給模型。

  • pts_seconds: 一個形狀為 (num_clips, num_frames_per_clip) 的 2D 浮點張量,給出每個剪輯中每幀的起始時間戳(以秒為單位)。

  • duration_seconds: 一個形狀為 (num_clips, num_frames_per_clip) 的 2D 浮點張量,給出每個剪輯中每幀的時長(以秒為單位)。

plot(clips[0].data)
sampling

索引和操作剪輯

剪輯是 FrameBatch 物件,它們支援原生的 PyTorch 索引語義(包括花式索引)。這使得根據給定標準輕鬆過濾剪輯變得容易。例如,從上面的剪輯中,我們可以輕鬆地過濾掉那些在特定時間戳之後開始的剪輯

tensor([11.3600, 10.2000,  9.8000,  9.6000,  8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

注意

獲取給定時間戳之後剪輯的一種更自然、更有效的方法是依賴取樣範圍引數,我們將在後面的 高階引數:取樣範圍 中介紹。

基於索引和基於時間的取樣器

到目前為止,我們使用了 clips_at_random_indices()。Torchcodec 支援其他取樣器,它們分為兩大類:

基於索引的取樣器

基於時間的取樣器

所有這些取樣器都遵循類似的 API,並且基於時間的取樣器具有與基於索引的取樣器類似的引數。兩種取樣器型別在速度方面通常具有相當的效能。

注意

使用基於時間的取樣器還是基於索引的取樣器更好?基於索引的取樣器具有更簡單的 API,並且由於索引的離散性質,其行為可能更容易理解和控制。對於具有恆定幀率的影片,基於索引的取樣器與基於時間的取樣器行為完全相同。但是,對於具有可變幀率的影片(這很常見),依賴索引可能會對影片的某些區域進行欠取樣/過取樣,這可能導致模型訓練時產生不良的副作用。使用基於時間的取樣器可確保時間維度的統一取樣特性。

高階引數:取樣範圍

有時,我們可能不希望從整個影片中取樣剪輯。我們可能只對在較小區間內開始的剪輯感興趣。在取樣器中,sampling_range_startsampling_range_end 引數控制取樣範圍:它們定義了允許剪輯 *開始* 的位置。有兩件重要的事情需要牢記:

  • sampling_range_end 是一個 *開放* 的上限:剪輯只能在 [sampling_range_start, sampling_range_end) 範圍內開始。

  • 由於這些引數定義了剪輯可以開始的位置,剪輯可能包含 sampling_range_end 之後的幀!

from torchcodec.samplers import clips_at_regular_timestamps

clips = clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1,
    num_frames_per_clip=4,
    seconds_between_frames=0.5,
    sampling_range_start=2,
    sampling_range_end=5
)
clips
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
        [3.0000, 3.4800, 4.0000, 4.4800],
        [4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

高階引數:策略

根據影片的時長或持續時間以及取樣引數,取樣器可能會嘗試取樣影片末尾之外的幀。policy 引數定義瞭如何用有效幀替換此類無效幀。

from torchcodec.samplers import clips_at_random_timestamps

end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)

我們上面看到影片的末尾在 13.8 秒。取樣器嘗試在時間戳 [13.28, 13.68, 14.08, …] 處取樣幀,但 14.08 是一個無效的時間戳,超出了影片末尾。使用“repeat_last”策略(這是預設策略),取樣器會簡單地重複 13.68 秒的最後一幀來構建剪輯。

另一種策略是“wrap”:取樣器然後圍繞剪輯進行包裝,並在必要時重複前幾幀有效幀

torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)

預設情況下,sampling_range_end 的值會自動設定為使取樣器 *不* 嘗試取樣影片末尾之外的幀:預設值可確保剪輯在末尾之前足夠早地開始。這意味著預設情況下,policy 引數很少生效,大多數使用者可能不必過多擔心它。

指令碼的總執行時間: (0 分鐘 0.612 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源