如何取樣影片剪輯¶
在此示例中,我們將學習如何從影片中取樣影片 剪輯。剪輯通常指幀的序列或批次,並且通常作為影片模型的輸入傳遞。
首先,是一些樣板程式碼:我們將從網上下載一個影片,並定義一個繪圖工具。您可以忽略這部分,直接跳轉到 建立解碼器。
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title : Optional[str] = None):
try:
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
import matplotlib.pyplot as plt
except ImportError:
print("Cannot plot, please run `pip install torchvision matplotlib`")
return
plt.rcParams["savefig.bbox"] = 'tight'
fig, ax = plt.subplots()
ax.imshow(to_pil_image(make_grid(frames)))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if title is not None:
ax.set_title(title)
plt.tight_layout()
建立解碼器¶
從影片中取樣剪輯總是從建立一個 VideoDecoder 物件開始。如果您還不熟悉 VideoDecoder,請快速檢視:使用 VideoDecoder 解碼影片。
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)
取樣基礎知識¶
我們現在可以使用解碼器來取樣剪輯。讓我們先看一個簡單的例子:所有其他取樣器都遵循類似的 API 和原理。我們將使用 clips_at_random_indices() 來取樣從隨機索引開始的剪輯。
from torchcodec.samplers import clips_at_random_indices
# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)
clips = clips_at_random_indices(
decoder,
num_clips=5,
num_frames_per_clip=4,
num_indices_between_frames=3,
)
clips
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
取樣器的輸出是一系列剪輯,表示為 FrameBatch 物件。在此物件中,我們有不同的欄位
data: 一個 5D uint8 張量,表示幀資料。其形狀為 (num_clips, num_frames_per_clip, …),其中 … 是 (C, H, W) 或 (H, W, C),具體取決於VideoDecoder的dimension_order引數。這通常會傳遞給模型。pts_seconds: 一個形狀為 (num_clips, num_frames_per_clip) 的 2D 浮點張量,給出每個剪輯中每幀的起始時間戳(以秒為單位)。duration_seconds: 一個形狀為 (num_clips, num_frames_per_clip) 的 2D 浮點張量,給出每個剪輯中每幀的時長(以秒為單位)。
plot(clips[0].data)

索引和操作剪輯¶
剪輯是 FrameBatch 物件,它們支援原生的 PyTorch 索引語義(包括花式索引)。這使得根據給定標準輕鬆過濾剪輯變得容易。例如,從上面的剪輯中,我們可以輕鬆地過濾掉那些在特定時間戳之後開始的剪輯
clip_starts = clips.pts_seconds[:, 0]
clip_starts
tensor([11.3600, 10.2000, 9.8000, 9.6000, 8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
注意
獲取給定時間戳之後剪輯的一種更自然、更有效的方法是依賴取樣範圍引數,我們將在後面的 高階引數:取樣範圍 中介紹。
基於索引和基於時間的取樣器¶
到目前為止,我們使用了 clips_at_random_indices()。Torchcodec 支援其他取樣器,它們分為兩大類:
基於索引的取樣器
基於時間的取樣器
所有這些取樣器都遵循類似的 API,並且基於時間的取樣器具有與基於索引的取樣器類似的引數。兩種取樣器型別在速度方面通常具有相當的效能。
注意
使用基於時間的取樣器還是基於索引的取樣器更好?基於索引的取樣器具有更簡單的 API,並且由於索引的離散性質,其行為可能更容易理解和控制。對於具有恆定幀率的影片,基於索引的取樣器與基於時間的取樣器行為完全相同。但是,對於具有可變幀率的影片(這很常見),依賴索引可能會對影片的某些區域進行欠取樣/過取樣,這可能導致模型訓練時產生不良的副作用。使用基於時間的取樣器可確保時間維度的統一取樣特性。
高階引數:取樣範圍¶
有時,我們可能不希望從整個影片中取樣剪輯。我們可能只對在較小區間內開始的剪輯感興趣。在取樣器中,sampling_range_start 和 sampling_range_end 引數控制取樣範圍:它們定義了允許剪輯 *開始* 的位置。有兩件重要的事情需要牢記:
sampling_range_end是一個 *開放* 的上限:剪輯只能在 [sampling_range_start, sampling_range_end) 範圍內開始。由於這些引數定義了剪輯可以開始的位置,剪輯可能包含
sampling_range_end之後的幀!
from torchcodec.samplers import clips_at_regular_timestamps
clips = clips_at_regular_timestamps(
decoder,
seconds_between_clip_starts=1,
num_frames_per_clip=4,
seconds_between_frames=0.5,
sampling_range_start=2,
sampling_range_end=5
)
clips
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
[3.0000, 3.4800, 4.0000, 4.4800],
[4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
高階引數:策略¶
根據影片的時長或持續時間以及取樣引數,取樣器可能會嘗試取樣影片末尾之外的幀。policy 引數定義瞭如何用有效幀替換此類無效幀。
from torchcodec.samplers import clips_at_random_timestamps
end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)
我們上面看到影片的末尾在 13.8 秒。取樣器嘗試在時間戳 [13.28, 13.68, 14.08, …] 處取樣幀,但 14.08 是一個無效的時間戳,超出了影片末尾。使用“repeat_last”策略(這是預設策略),取樣器會簡單地重複 13.68 秒的最後一幀來構建剪輯。
另一種策略是“wrap”:取樣器然後圍繞剪輯進行包裝,並在必要時重複前幾幀有效幀
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)
預設情況下,sampling_range_end 的值會自動設定為使取樣器 *不* 嘗試取樣影片末尾之外的幀:預設值可確保剪輯在末尾之前足夠早地開始。這意味著預設情況下,policy 引數很少生效,大多數使用者可能不必過多擔心它。
指令碼的總執行時間: (0 分鐘 0.612 秒)