使用 VideoDecoder 解碼影片¶
在此示例中,我們將學習如何使用 VideoDecoder 類來解碼影片。
首先,一些樣板程式碼:我們將從網上下載一個影片,並定義一個繪圖實用程式。您可以忽略這部分,直接跳轉到 建立解碼器。
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title : Optional[str] = None):
try:
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
import matplotlib.pyplot as plt
except ImportError:
print("Cannot plot, please run `pip install torchvision matplotlib`")
return
plt.rcParams["savefig.bbox"] = 'tight'
fig, ax = plt.subplots()
ax.imshow(to_pil_image(make_grid(frames)))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if title is not None:
ax.set_title(title)
plt.tight_layout()
建立解碼器¶
我們現在可以從原始(編碼)影片位元組建立解碼器。您當然也可以使用本地影片檔案並將路徑作為輸入,而不是下載影片。
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)
影片尚未被解碼器解碼,但我們已經可以透過 metadata 屬性訪問一些元資料,該屬性是一個 VideoStreamMetadata 物件。
print(decoder.metadata)
VideoStreamMetadata:
duration_seconds_from_header: 13.8
begin_stream_seconds_from_header: 0.0
bit_rate: 505790.0
codec: h264
stream_index: 0
begin_stream_seconds_from_content: 0.0
end_stream_seconds_from_content: 13.8
width: 640
height: 360
num_frames_from_header: 345
num_frames_from_content: 345
average_fps_from_header: 25.0
pixel_aspect_ratio: 1
duration_seconds: 13.8
begin_stream_seconds: 0.0
end_stream_seconds: 13.8
num_frames: 345
average_fps: 25.0
透過索引解碼器來解碼幀¶
first_frame = decoder[0] # using a single int index
every_twenty_frame = decoder[0 : -1 : 20] # using slices
print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8
索引解碼器會將幀作為 torch.Tensor 物件返回。預設情況下,幀的形狀為 (N, C, H, W),其中 N 是批次大小,C 是通道數,H 是高度,W 是幀的寬度。批次維度 N 僅在我們解碼多個幀時存在。可以使用 VideoDecoder 的 dimension_order 引數將維度順序更改為 N, H, W, C。幀的 dtype 始終是 torch.uint8。
注意
如果您需要解碼多個幀,我們建議使用批處理方法,因為它們速度更快:get_frames_at()、get_frames_in_range()、get_frames_played_at() 和 get_frames_played_in_range()。它們在下面進行了描述。
plot(first_frame, "First frame")

plot(every_twenty_frame, "Every 20 frame")

遍歷幀¶
解碼器是一個正常的迭代物件,可以像這樣進行迭代
for frame in decoder:
assert (
isinstance(frame, torch.Tensor)
and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
)
檢索幀的 pts 和持續時間¶
索引解碼器會返回純粹的 torch.Tensor 物件。有時,檢索幀的額外資訊(例如它們的 pts(顯示時間戳)和持續時間)會很有用。這可以透過 get_frame_at() 和 get_frames_at() 方法實現,它們將分別返回 Frame 和 FrameBatch 物件。
last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec._frame.Frame'>
Frame:
data (shape): torch.Size([3, 360, 640])
pts_seconds: 13.76
duration_seconds: 0.04
other_frames = decoder.get_frames_at([10, 0, 50])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
data (shape): torch.Size([3, 3, 360, 640])
pts_seconds: tensor([0.4000, 0.0000, 2.0000], dtype=torch.float64)
duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(other_frames.data, "Other frames")
Frame 和 FrameBatch 都有一個 data 欄位,其中包含解碼後的張量資料。它們還具有 pts_seconds 和 duration_seconds 欄位,對於 Frame 是單個整數,對於 FrameBatch 是 1D torch.Tensor(批次中的每個幀一個值)。
使用基於時間的索引¶
到目前為止,我們都是根據幀的索引來檢索幀的。我們也可以使用 get_frame_played_at() 和 get_frames_played_at() 根據幀的播放時間來檢索幀,它們也分別返回 Frame 和 FrameBatch。
frame_at_2_seconds = decoder.get_frame_played_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec._frame.Frame'>
Frame:
data (shape): torch.Size([3, 360, 640])
pts_seconds: 2.0
duration_seconds: 0.04
other_frames = decoder.get_frames_played_at(seconds=[10.1, 0.3, 5])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
data (shape): torch.Size([3, 3, 360, 640])
pts_seconds: tensor([10.0800, 0.2800, 5.0000], dtype=torch.float64)
duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame played at 2 seconds")
plot(other_frames.data, "Other frames")
指令碼總執行時間: (0 分 3.125 秒)



