diff --git a/README.md b/README.md index 79d47f0..dce523d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,165 @@ -# implementation from https://github.com/soCzech/TransNetV2 -# Pytorch implementation of TransNet V2 +# TransNetV2 (PyTorch) — Scene / Shot Boundary Detection with NVDEC (optional) + PyAV fallback + +This repository is a fork of [soCzech/TransNetV2](https://github.com/soCzech/TransNetV2) with a **PyTorch inference pipeline** and a clean, OOP-based API. + +It supports two decoding backends: + +- **NVIDIA NVDEC (GPU decode)** via **PyNvVideoCodec** *(optional)* — fastest path when the codec is supported by your GPU. +- **PyAV (FFmpeg, CPU decode)** — always available fallback (and the default when you run on CPU). + +> When you run on **CUDA**, the library tries NVDEC first and **automatically falls back** to PyAV if NVDEC can’t decode the input (for example: unsupported codec/profile/chroma on this GPU). + +--- + +## Table of Contents + +- [Features](#features) +- [Installation](#installation) +- [Usage](#usage) +- [Examples](#examples) +- [Configuration](#configuration) +- [Logging](#logging) +- [License](#license) + +--- + +## Features + +- **PyTorch model**: loads TransNetV2 weights with PyTorch and runs inference on CPU or CUDA. +- **Automatic backend selection**: + - `device="cpu"` → **PyAV** + - `device="cuda"` → try **NVDEC (PyNvVideoCodec)**, otherwise **PyAV** +- **Progress bars**: optional `tqdm` progress bars during decoding / window processing. +- **Clean API**: a single entry point class: `SceneDetector`. 
+ +--- + +## Installation + +### 1) Install Python deps + +```bash +pip install -r requirements.txt +``` + +### 2) Install PyTorch + +Use the official selector to pick the correct CUDA / CPU build: +- https://pytorch.org/get-started/locally/ + +### 3) Install PyAV (CPU backend) + +PyAV provides binary wheels on PyPI for Windows / Linux / macOS: + +```bash +pip install av +``` + +Docs: +- https://pyav.org/docs/develop/overview/installation.html + +### 4) (Optional) Install PyNvVideoCodec (NVDEC backend) + +If you want **GPU-accelerated decode**, install NVIDIA **PyNvVideoCodec** (requires NVIDIA driver + compatible GPU): + +- https://developer.nvidia.com/pynvvideocodec +- API Programming Guide: https://docs.nvidia.com/video-technologies/pynvvideocodec/pynvc-api-prog-guide/index.html + +> NVDEC codec support depends on the GPU and the codec/profile of the input video. If NVDEC can’t decode your video, the library will fall back to PyAV automatically. + +--- + +## Usage + +### Basic (auto device selection) ```python -from transnetv2pt import predict_video -scenes = predict_video('video.mp4') +from transnetv2pt import SceneDetector + +detector = SceneDetector() # CUDA if available else CPU +scenes = detector.predict("path/to/video.mp4", show_progressbar=True) + +print(scenes) # [[start_frame, end_frame], ...] 
``` + +### Force CUDA (NVDEC first, fallback to PyAV if unsupported) + +```python +from transnetv2pt import SceneDetector +import torch + +detector = SceneDetector(torch.device("cuda")) +scenes = detector.predict("path/to/video.mp4", show_progressbar=True) +``` + +### Force CPU (PyAV only) + +```python +from transnetv2pt import SceneDetector +import torch + +detector = SceneDetector(torch.device("cpu")) +scenes = detector.predict("path/to/video.mp4", show_progressbar=True) +``` + +--- + +## Examples + +### Extract keyframes at scene starts (OpenCV) + +```python +from pathlib import Path +import cv2 +import torch +from transnetv2pt import SceneDetector + +video_path = Path("video.mkv") + +detector = SceneDetector(torch.device("cuda")) # or "cpu" +scenes = detector.predict(str(video_path), show_progressbar=True) + +cap = cv2.VideoCapture(str(video_path)) +for i, (start, end) in enumerate(scenes): + cap.set(cv2.CAP_PROP_POS_FRAMES, int(start)) + ok, frame = cap.read() + if ok: + cv2.imwrite(f"scene_{i:04d}_start.png", frame) +cap.release() +``` + +--- + +## Configuration + +### Device selection + +- `SceneDetector()`: + - uses **CUDA** if `torch.cuda.is_available()` else CPU +- `SceneDetector(torch.device("cpu"))`: + - always uses **PyAV** +- `SceneDetector(torch.device("cuda"))`: + - tries **NVDEC** first (if PyNvVideoCodec installed), falls back to **PyAV** on decode errors. + +### Progress bars + +- `show_progressbar=True` will enable `tqdm` for: + - NVDEC window iteration (GPU backend) + - frame extraction / window iteration (PyAV backend) + +--- + +## Logging + +The library uses the standard Python `logging` module. To see logs: + +```python +import logging +logging.basicConfig(level=logging.INFO) +``` + +--- + +## License + +This project inherits the MIT License from the original TransNetV2 repository. See `LICENSE` for details. 
# ---------------------------------------------------------------------------
# transnetv2pt/backend_nvvc.py
# ---------------------------------------------------------------------------

# Optional NVDEC bindings. The name is left as None when PyNvVideoCodec cannot
# be imported so that NVVCBackend can raise a catchable ImportError when it is
# constructed instead of failing at module import time.
try:
    import PyNvVideoCodec as nvc
except ImportError:
    nvc = None


class NVVCBackend:
    """
    Scene detection backend that decodes frames through PyNvVideoCodec (NVDEC).

    Frames are requested from the decoder with ``use_device_memory=True``,
    resized to the model input size and fed to the model in windows of 100
    frames with a stride of 50. Only the central 50 predictions of each window
    are kept.
    """

    # Model input resolution (height, width).
    _TARGET_H = 27
    _TARGET_W = 48
    # Sliding-window geometry: 100-frame windows, stride 50, 25 context frames
    # on each side of the 50 predictions that are kept per window.
    _WINDOW = 100
    _STRIDE = 50
    _CONTEXT = 25

    def __init__(self):
        """
        Create the backend.

        Raises:
            ImportError: If PyNvVideoCodec could not be imported.
        """
        self.logger = logging.getLogger(__name__)
        if nvc is None:
            # Callers (e.g. SceneDetector) catch this and fall back to another backend.
            self.logger.error("PyNvVideoCodec (NVDEC) is not available")
            raise ImportError("PyNvVideoCodec module is not installed or could not be imported")

    def predict_video(self, video_path: str, model: torch.nn.Module, device: torch.device,
                      show_progressbar: bool = False):
        """
        Decode the video with NVDEC and predict scene boundaries with the model.

        Parameters:
            video_path (str): Path to the video file.
            model (torch.nn.Module): Model called as ``model(batch)``; it must return
                a pair whose first element holds the per-frame cut logits.
            device (torch.device): Device of the model; must be a CUDA device.
            show_progressbar (bool): If True, wrap the window loop in a tqdm bar.

        Returns:
            np.ndarray: int32 array of [start_frame, end_frame] pairs. Indices never
            exceed ``total_frames - 1``.

        Raises:
            RuntimeError: If ``device`` is not a CUDA device, or the decoder returns
                fewer frames than requested.
            ValueError: If the decoder reports no frames.
        """
        if device.type != "cuda":
            raise RuntimeError("NVVCBackend requires a CUDA device for decoding.")

        decoder = nvc.SimpleDecoder(enc_file_path=video_path,
                                    gpu_id=self._get_cuda_gpu_id(device),
                                    use_device_memory=True,
                                    output_color_type=nvc.OutputColorType.RGB)
        total_frames = len(decoder)
        if total_frames <= 0:
            raise ValueError(f"Empty or invalid video stream: {video_path}")

        # 25 padding frames in front; at the back 25 plus whatever is needed to
        # round the frame count up to a multiple of the stride.
        pad_start = self._CONTEXT
        remainder = total_frames % self._STRIDE
        pad_end = self._CONTEXT + self._STRIDE - (remainder if remainder != 0 else self._STRIDE)
        total_virtual = pad_start + total_frames + pad_end
        num_windows = (total_virtual - self._WINDOW) // self._STRIDE + 1

        self.logger.info("NVDEC open: %s | frames=%s | windows=%s", video_path, total_frames, num_windows)

        # The first and last frames (already resized) are reused as padding.
        start_frame_rgb = self._resize_frame(torch.from_dlpack(decoder[0]),
                                             target_h=self._TARGET_H, target_w=self._TARGET_W)
        end_frame_rgb = self._resize_frame(torch.from_dlpack(decoder[total_frames - 1]),
                                           target_h=self._TARGET_H, target_w=self._TARGET_W)

        # Geometry shared by every _append_frames call below.
        layout = dict(target_h=self._TARGET_H, target_w=self._TARGET_W,
                      pad_start=pad_start, total_frames=total_frames, pad_end=pad_end)

        preds_list: List[np.ndarray] = []
        buffer: List[torch.Tensor] = self._append_frames(
            decoder, start_frame_rgb, end_frame_rgb, [],
            vi_start=0, count=self._WINDOW, **layout)
        next_vi = self._WINDOW

        frame_windows = range(num_windows)
        if show_progressbar:
            frame_windows = tqdm(frame_windows, total=num_windows, desc="NVDEC windows", unit="win")
        for window_index in frame_windows:
            # [100, H, W, 3] -> [1, 100, H, W, 3]
            batch = torch.stack(buffer, dim=0).unsqueeze(0)
            with torch.inference_mode():
                one_hot, _ = model(batch)
                # Keep only the central 50 predictions of the window.
                probs = torch.sigmoid(one_hot)[0, self._CONTEXT:self._CONTEXT + self._STRIDE, 0]
            preds_list.append(probs.cpu().numpy())

            # Slide the window only when another one follows; refilling after the
            # last window would just build a buffer nobody reads.
            if window_index + 1 < num_windows:
                buffer = self._append_frames(
                    decoder, start_frame_rgb, end_frame_rgb, buffer[self._STRIDE:],
                    vi_start=next_vi, count=self._STRIDE, **layout)
                next_vi += self._STRIDE

        # The windows also cover the trailing padding frames; drop those
        # predictions so that scene indices stay inside the real video.
        single_frame_pred = np.concatenate(preds_list, axis=0)[:total_frames]
        scenes = self._predictions_to_scenes(single_frame_pred)
        self.logger.info("Detected %s scenes", len(scenes))
        return scenes

    def _get_cuda_gpu_id(self, device: torch.device) -> int:
        """Return the CUDA index of ``device``; 0 for non-CUDA devices or a missing index."""
        if device.type != "cuda" or device.index is None:
            return 0
        return int(device.index)

    def _resize_frame(self, frame_tensor: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
        """
        Resize one HWC frame to (target_h, target_w) with bilinear interpolation.

        The frame is converted to float32 NCHW for ``F.interpolate``, clamped to
        [0, 255], cast to uint8 and returned as a contiguous HWC tensor on the
        same device as the input.
        """
        nchw = frame_tensor.permute(2, 0, 1).unsqueeze(0).to(dtype=torch.float32)
        nchw = F.interpolate(nchw, size=(target_h, target_w), mode="bilinear", align_corners=False)
        nchw = nchw.clamp_(0.0, 255.0).to(dtype=torch.uint8)
        return nchw.squeeze(0).permute(1, 2, 0).contiguous()

    def _append_frames(self, decoder: "nvc.SimpleDecoder", start_frame_rgb: torch.Tensor,
                       end_frame_rgb: torch.Tensor, buffer: List[torch.Tensor],
                       vi_start: int, count: int, target_h: int, target_w: int,
                       pad_start: int, total_frames: int, pad_end: int) -> List[torch.Tensor]:
        """
        Append the frames for virtual indices [vi_start, vi_start + count) to ``buffer``.

        Virtual index 0 is the first leading padding frame and the real video
        occupies [pad_start, pad_start + total_frames). Indices before that range
        are filled with ``start_frame_rgb``, indices after it with
        ``end_frame_rgb``, and indices inside it are decoded and resized.

        ``buffer`` is extended in place and also returned. ``pad_end`` is accepted
        for signature symmetry but is not needed for the computation.

        Raises:
            RuntimeError: If the decoder yields a different number of frames than
                requested, which would produce a window of the wrong length.
        """
        if count == 0:
            return buffer

        # Leading padding: the part of the range that lies before the video.
        prefix = min(pad_start - vi_start, count) if vi_start < pad_start else 0
        # Trailing padding: the part of the range that lies after the video.
        overshoot = (vi_start + count) - (pad_start + total_frames)
        suffix = min(overshoot, count - prefix) if overshoot > 0 else 0
        real_count = count - prefix - suffix

        buffer.extend([start_frame_rgb] * prefix)

        if real_count > 0:
            decoder.seek_to_index(max(0, vi_start - pad_start))
            decoded = 0
            for frame in decoder.get_batch_frames(batch_size=real_count):
                buffer.append(self._resize_frame(torch.from_dlpack(frame),
                                                 target_h=target_h, target_w=target_w))
                decoded += 1
            if decoded != real_count:
                # An explicit error (rather than an assert) survives `python -O`
                # and lets the caller fall back to another backend.
                raise RuntimeError(
                    f"NVDEC returned {decoded} frames, expected {real_count} "
                    f"(virtual index {vi_start})")

        buffer.extend([end_frame_rgb] * suffix)
        return buffer

    def _predictions_to_scenes(self, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """
        Convert per-frame cut probabilities into [start_frame, end_frame] pairs.

        A scene is closed at index ``i`` when the thresholded signal rises from 0
        to 1 there (except at index 0), and the next scene starts where it falls
        back to 0.

        Parameters:
            predictions (np.ndarray): 1D array of probabilities.
            threshold (float): Values strictly above it count as a cut.

        Returns:
            np.ndarray: int32 array of shape (num_scenes, 2). Empty input gives an
            array of shape (0, 2); input without any closed scene (for example
            all values above the threshold) gives a single scene spanning
            everything.
        """
        pred = (predictions > threshold).astype(np.uint8)
        if len(pred) == 0:
            return np.empty((0, 2), dtype=np.int32)

        scenes = []
        t_prev = 0
        start = 0
        for i, t in enumerate(pred):
            if t_prev == 1 and t == 0:
                start = i
            if t_prev == 0 and t == 1 and i != 0:
                scenes.append([start, i])
            t_prev = t
        # The video ended outside a cut: close the scene that is still open.
        if t_prev == 0:
            scenes.append([start, len(pred) - 1])
        if not scenes:
            return np.array([[0, len(pred) - 1]], dtype=np.int32)
        return np.array(scenes, dtype=np.int32)


# ---------------------------------------------------------------------------
# transnetv2pt/backend_pyav.py
# ---------------------------------------------------------------------------

class PyAVBackend:
    """
    Scene detection backend that decodes frames with PyAV (FFmpeg).

    All frames are decoded and resized first, then the model is run over them in
    windows of 100 frames with a stride of 50.
    """

    # Model input resolution (height, width).
    _TARGET_H = 27
    _TARGET_W = 48
    # Sliding-window geometry (see _predict_frames).
    _WINDOW = 100
    _STRIDE = 50
    _CONTEXT = 25

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def predict_video(self, video_path: str, model: torch.nn.Module, device: torch.device,
                      show_progressbar: bool = False):
        """
        Decode the video with PyAV and predict scene boundaries with the model.

        Parameters:
            video_path (str): Path to the video file.
            model (torch.nn.Module): Model called as ``model(batch)``; it must return
                a pair whose first element holds the per-frame cut logits.
            device (torch.device): Device the input batches are moved to.
            show_progressbar (bool): If True, show tqdm bars for decoding and inference.

        Returns:
            np.ndarray: int32 array of [start_frame, end_frame] pairs. Indices never
            exceed ``num_frames - 1``.

        Raises:
            ValueError: If no frames could be extracted.
        """
        frames = self._extract_frames(video_path, self._TARGET_W, self._TARGET_H, show_progressbar)
        if frames.shape[0] == 0:
            raise ValueError(f"No frames extracted from video: {video_path}")

        single_frame_pred = self._predict_frames(frames, model, device, show_progressbar)
        scenes = self._predictions_to_scenes(single_frame_pred)
        self.logger.info("Detected %s scenes", len(scenes))
        return scenes

    def _predict_frames(self, frames: np.ndarray, model: torch.nn.Module, device: torch.device,
                        show_progressbar: bool = False) -> np.ndarray:
        """
        Run the model over ``frames`` and return one cut probability per frame.

        The sequence is padded with 25 copies of the first frame in front and with
        copies of the last frame at the back (25 plus the amount needed to reach a
        multiple of 50). Each 100-frame window contributes its central 50
        predictions; predictions that belong to trailing padding are dropped.

        Parameters:
            frames (np.ndarray): Non-empty array whose first axis is the frame index.
            model (torch.nn.Module): See ``predict_video``.
            device (torch.device): Device the batches are moved to.
            show_progressbar (bool): If True, wrap the window loop in a tqdm bar.

        Returns:
            np.ndarray: 1D array with ``frames.shape[0]`` probabilities.
        """
        num_frames = frames.shape[0]
        pad_start = self._CONTEXT
        remainder = num_frames % self._STRIDE
        pad_end = self._CONTEXT + self._STRIDE - (remainder if remainder != 0 else self._STRIDE)
        num_windows = (pad_start + num_frames + pad_end - self._WINDOW) // self._STRIDE + 1

        padded_frames = np.concatenate([
            np.repeat(frames[0:1], pad_start, axis=0),
            frames,
            np.repeat(frames[-1:], pad_end, axis=0),
        ], axis=0)

        frame_windows = range(num_windows)
        if show_progressbar:
            frame_windows = tqdm(frame_windows, total=num_windows, desc="Processing windows", unit="win")

        preds_list = []
        for i in frame_windows:
            start_idx = i * self._STRIDE
            # (100, H, W, 3) -> (1, 100, H, W, 3)
            batch_frames = padded_frames[start_idx:start_idx + self._WINDOW][np.newaxis, ...]
            batch_tensor = torch.from_numpy(batch_frames).to(device)
            with torch.inference_mode():
                one_hot, _ = model(batch_tensor)
                probs = torch.sigmoid(one_hot)[0, self._CONTEXT:self._CONTEXT + self._STRIDE, 0]
            preds_list.append(probs.cpu().numpy())

        # Drop the predictions produced for trailing padding frames so that the
        # result lines up one-to-one with the real frames.
        return np.concatenate(preds_list, axis=0)[:num_frames]

    def _extract_frames(self, video_path: str, target_width: int, target_height: int,
                        show_progressbar: bool = False) -> np.ndarray:
        """
        Decode every frame of the first video stream as RGB at the target size.

        Parameters:
            video_path (str): Path to the video file.
            target_width (int): Width to scale frames to.
            target_height (int): Height to scale frames to.
            show_progressbar (bool): If True, show a tqdm bar while decoding.

        Returns:
            np.ndarray: uint8 array built from the decoded frames; its first axis is
            the frame index.

        Raises:
            ValueError: If the file has no video stream.
            av.FFmpegError, OSError: Re-raised after logging when opening or
                decoding fails.
        """
        self.logger.info("Opening video: %s", video_path)
        frames_list = []
        try:
            with av.open(video_path) as container:
                if not container.streams.video:
                    raise ValueError(f"No video stream found in file: {video_path}")
                stream = container.streams.video[0]
                stream.thread_type = "AUTO"
                # A frame count of 0 is treated as unknown, so tqdm shows no total.
                total_frames = stream.frames or None
                frame_iterator = container.decode(video=0)
                if show_progressbar:
                    frame_iterator = tqdm(frame_iterator, total=total_frames,
                                          desc="Extracting frames", unit="frame")
                for frame in frame_iterator:
                    resized = frame.reformat(width=target_width, height=target_height, format="rgb24")
                    frames_list.append(resized.to_ndarray())
        except (av.FFmpegError, OSError, ValueError) as e:
            # Log for context, then let the caller decide how to handle the failure.
            self.logger.error("Failed to open or decode video: %s. PyAV error: %s", video_path, e)
            raise

        self.logger.info("Extracted %s frames from %s", len(frames_list), video_path)
        return np.asarray(frames_list, dtype=np.uint8)

    def _predictions_to_scenes(self, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """
        Convert per-frame cut probabilities into [start_frame, end_frame] pairs.

        A scene is closed at index ``i`` when the thresholded signal rises from 0
        to 1 there (except at index 0), and the next scene starts where it falls
        back to 0.

        Parameters:
            predictions (np.ndarray): 1D array of probabilities.
            threshold (float): Values strictly above it count as a cut.

        Returns:
            np.ndarray: int32 array of shape (num_scenes, 2). Empty input gives an
            array of shape (0, 2); input without any closed scene (for example
            all values above the threshold) gives a single scene spanning
            everything.
        """
        pred = (predictions > threshold).astype(np.uint8)
        if len(pred) == 0:
            return np.empty((0, 2), dtype=np.int32)

        scenes = []
        t_prev = 0
        start = 0
        for i, t in enumerate(pred):
            if t_prev == 1 and t == 0:
                start = i
            if t_prev == 0 and t == 1 and i != 0:
                scenes.append([start, i])
            t_prev = t
        # The video ended outside a cut: close the scene that is still open.
        if t_prev == 0:
            scenes.append([start, len(pred) - 1])
        if not scenes:
            return np.array([[0, len(pred) - 1]], dtype=np.int32)
        return np.array(scenes, dtype=np.int32)
0 and t == 1 and i != 0: - scenes.append([start, i]) - t_prev = t - if t == 0: - scenes.append([start, i]) - - # just fix if all predictions are 1 - if len(scenes) == 0: - return np.array([[0, len(predictions) - 1]], dtype=np.int32) - - return np.array(scenes, dtype=np.int32) - - -def predict_raw(model, video, device=torch.device('cuda:0')): - model.to(device) - with torch.no_grad(): - predictions = [] - for inp in input_iterator(video): - video_tensor = torch.from_numpy(inp) - # shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels - video_tensor = video_tensor.to(device) - - single_frame_pred, all_frame_pred = model(video_tensor) - - single_frame_pred = torch.sigmoid(single_frame_pred).cpu().numpy() - all_frame_pred = torch.sigmoid( - all_frame_pred["many_hot"]).cpu().numpy() - predictions.append( - (single_frame_pred[0, 25:75, 0], all_frame_pred[0, 25:75, 0])) - print("\r[TransNetV2] Processing video frames {}/{}".format( - min(len(predictions) * 50, len(video)), len(video) - ), end="") - single_frame_pred = np.concatenate( - [single_ for single_, all_ in predictions]) - all_frames_pred = np.concatenate( - [all_ for single_, all_ in predictions]) - - return video.shape[0], single_frame_pred[:len(video)], all_frames_pred[:len(video)] - -def predict_video(filename_or_video): - if isinstance(filename_or_video, str): - video_stream, err = ffmpeg.input(filename_or_video).output( - "pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27" - ).run(capture_stdout=True, capture_stderr=True) - video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3]) - else: - assert filename_or_video.shape[1] == 27 and filename_or_video.shape[2] == 48 and filename_or_video.shape[3] == 3 - video = filename_or_video - _, single_frame_pred, _ = predict_raw(model, video) - scenes = predictions_to_scenes(single_frame_pred) - return scenes +# Import backend classes for decoding and inference +from . import backend_pyav +from . 
class SceneDetector:
    """
    Detect scene boundaries in a video with the TransNetV2 model.

    On a CUDA device the NVDEC backend (``backend_nvvc.NVVCBackend``) is tried
    first and the PyAV backend (``backend_pyav.PyAVBackend``) is used if it
    fails; on any other device the PyAV backend is used directly.
    """

    # The annotation is a string so that it is not evaluated at definition time
    # (``X | None`` between types needs Python 3.10+).
    def __init__(self, device: "torch.device | str | None" = None):
        """
        Create a detector.

        Parameters:
            device: Device to run on; anything accepted by ``torch.device``. When
                None, CUDA is used if ``torch.cuda.is_available()``, otherwise CPU.

        The model is not loaded here; it is created on the first ``predict`` call.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = None
        self.logger = logging.getLogger(__name__)

    def _init_model(self):
        """
        Build TransNetV2, load the weights shipped next to this module and move
        the model to ``self.device``.

        On CUDA the model is additionally wrapped with ``torch.compile`` when the
        installed PyTorch provides it.
        """
        model = TransNetV2()
        state_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  "transnetv2-pytorch-weights.pth")
        # Load on CPU first so that the checkpoint does not depend on the device
        # it was saved from.
        state_dict = torch.load(state_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        model.to(self.device)

        if self.device.type == "cuda":
            # Both calls only exist in newer PyTorch releases; skip them on older
            # installs instead of failing with AttributeError.
            if hasattr(torch, "set_float32_matmul_precision"):
                torch.set_float32_matmul_precision("high")
            if hasattr(torch, "compile"):
                model = torch.compile(model, mode="max-autotune-no-cudagraphs")
        return model

    def _predict_with_pyav(self, video_path: str, show_progressbar: bool):
        """Run the PyAV backend on ``video_path`` with the already loaded model."""
        backend = backend_pyav.PyAVBackend()
        return backend.predict_video(video_path, self.model, device=self.device,
                                     show_progressbar=show_progressbar)

    def predict(self, video_path: str, show_progressbar: bool = False):
        """
        Detect scene boundaries in the given video file.

        Parameters:
            video_path (str): Path to the video file to process.
            show_progressbar (bool): If True, display progress bars during processing.

        Returns:
            The value returned by the selected backend's ``predict_video``
            ([start_frame, end_frame] pairs).
        """
        # Lazy initialisation: the model is created once and reused afterwards.
        if self.model is None:
            self.model = self._init_model()
            self.logger.info("Initialized TransNetV2 model on %s device", self.device.type.upper())

        if self.device.type != "cuda":
            self.logger.debug("Using PyAV backend (CPU decoding)")
            return self._predict_with_pyav(video_path, show_progressbar)

        try:
            self.logger.debug("Attempting NVDEC backend (GPU decoding)")
            backend = backend_nvvc.NVVCBackend()
            return backend.predict_video(video_path, self.model, device=self.device,
                                         show_progressbar=show_progressbar)
        except Exception as e:
            # Deliberately broad: any NVDEC problem (bindings missing, unsupported
            # codec, decode error) should degrade to PyAV rather than abort.
            self.logger.warning("NVDEC backend failed (error: %s). Falling back to PyAV backend.", e)
            self.logger.debug("NVDEC failure details", exc_info=True)
            return self._predict_with_pyav(video_path, show_progressbar)