Instructions to use Justin331/sam3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Justin331/sam3 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("mask-generation", model="Justin331/sam3")
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Justin331/sam3")
model = AutoModel.from_pretrained("Justin331/sam3")
```

- Notebooks
- Google Colab
- Kaggle
| import os | |
| import io | |
| import base64 | |
| import tempfile | |
| import zipfile | |
| import logging | |
| import sys | |
| import time | |
| from typing import Dict, Any, Optional | |
| from pathlib import Path | |
| import json | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| # CRITICAL: Patch torch.autocast BEFORE any SAM3 imports | |
| # SAM3 uses @torch.autocast decorators that get applied at import time | |
| # We must patch torch.autocast before the decorators are evaluated | |
class Float32Autocast:
    """No-op stand-in for ``torch.autocast`` and its ``torch.amp`` / ``torch.cuda.amp`` aliases.

    The requested ``dtype``, ``enabled`` and ``cache_enabled`` values are
    accepted and ignored: the instance always reports ``dtype=torch.float32``
    and ``enabled=False``, and entering or leaving the context does nothing.

    The object can be used both as a context manager and as a decorator,
    because autocast is applied in both forms by the code being patched.

    Args:
        device_type: Stored as-is. Defaults to ``"cuda"`` so the class can
            also be constructed without a device type, which is how the
            ``torch.cuda.amp.autocast`` alias is called.
        dtype: Ignored.
        enabled: Ignored.
        cache_enabled: Ignored; accepted for signature compatibility.
    """

    def __init__(self, device_type="cuda", dtype=None, enabled=True, cache_enabled=None):
        self.device_type = device_type
        self.dtype = torch.float32
        self.enabled = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Implicitly returns None, so exceptions raised inside the block propagate.
        pass

    def __call__(self, func):
        """Decorator form: return ``func`` untouched, since there is nothing to wrap."""
        return func
# Remember the real autocast, then point every public alias at the no-op class.
# This has to run before the sam3 import further down, as the comments above explain.
_ORIGINAL_AUTOCAST = torch.autocast
torch.autocast = Float32Autocast
for _amp_parent in (torch.cuda, torch):
    if hasattr(_amp_parent, 'amp'):
        setattr(_amp_parent.amp, 'autocast', Float32Autocast)
del _amp_parent

# Send INFO-and-above records to stdout with a timestamped format
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
)
logger = logging.getLogger(__name__)
logger.info("✓ Patched torch.autocast globally before SAM3 import")
| # SAM3 imports - using local sam3 package in repository | |
| # This will now use our patched autocast for all @torch.autocast decorators | |
| from sam3.model_builder import build_sam3_video_predictor | |
| # HuggingFace Hub for uploads | |
| try: | |
| from huggingface_hub import HfApi | |
| HF_HUB_AVAILABLE = True | |
| except ImportError: | |
| HF_HUB_AVAILABLE = False | |
| class EndpointHandler: | |
| """ | |
| SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints | |
| Processes video with text prompts and returns segmentation masks. | |
| Uses SAM3 repository code directly from local sam3/ package. | |
| """ | |
| def __init__(self, path: str = ""): | |
| """ | |
| Initialize SAM3 video predictor. | |
| Args: | |
| path: Path to model repository (not used - model loads from HF automatically) | |
| """ | |
| logger.info("="*80) | |
| logger.info("INITIALIZING SAM3 VIDEO SEGMENTATION HANDLER") | |
| logger.info("="*80) | |
| # Set device | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| logger.info(f"Device detection: {self.device}") | |
| if self.device != "cuda": | |
| logger.error("FATAL: SAM3 requires GPU acceleration. No CUDA device found.") | |
| raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.") | |
| # Log GPU information | |
| if torch.cuda.is_available(): | |
| logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}") | |
| logger.info(f"CUDA Version: {torch.version.cuda}") | |
| logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") | |
| # Build SAM3 video predictor | |
| # Note: torch.autocast was already patched at module import time | |
| try: | |
| logger.info("Building SAM3 video predictor...") | |
| start_time = time.time() | |
| # Ensure BPE tokenizer file exists | |
| bpe_path = self._ensure_bpe_file() | |
| logger.info(f"BPE tokenizer path: {bpe_path}") | |
| # Build predictor with explicit bpe_path | |
| self.predictor = build_sam3_video_predictor( | |
| gpus_to_use=[0], | |
| bpe_path=bpe_path | |
| ) | |
| # Fix dtype mismatch: Convert all model parameters and buffers to float32 | |
| # This fixes: "Input type (c10::BFloat16) and bias type (float) should be the same" | |
| logger.info("Converting model to float32 to avoid dtype mismatch...") | |
| def convert_model_to_float32(model): | |
| """Recursively convert all model components to float32.""" | |
| conversion_count = 0 | |
| # Convert the model itself | |
| model.float() | |
| # Convert all parameters | |
| for name, param in model.named_parameters(): | |
| if param.dtype != torch.float32: | |
| param.data = param.data.float() | |
| conversion_count += 1 | |
| logger.debug(f" Converted parameter: {name}") | |
| # Convert all buffers (batch norm running stats, etc.) | |
| for buffer_name, buffer in model.named_buffers(): | |
| if buffer.dtype != torch.float32 and buffer.dtype in [torch.float16, torch.bfloat16]: | |
| model.register_buffer(buffer_name, buffer.float()) | |
| conversion_count += 1 | |
| logger.debug(f" Converted buffer: {buffer_name}") | |
| # Also convert submodules explicitly | |
| for name, module in model.named_modules(): | |
| if module is not model: # Skip the root module | |
| try: | |
| module.float() | |
| except Exception: | |
| pass # Some modules may not support .float() | |
| return conversion_count | |
| total_conversions = 0 | |
| # Convert the main model | |
| if hasattr(self.predictor, 'model') and self.predictor.model is not None: | |
| logger.info(" Converting main model...") | |
| total_conversions += convert_model_to_float32(self.predictor.model) | |
| # SAM3 may have additional models (detector, tracker, etc.) | |
| # Check for other potential model attributes | |
| for attr_name in ['detector', 'tracker', 'image_encoder', 'text_encoder']: | |
| if hasattr(self.predictor, attr_name): | |
| attr = getattr(self.predictor, attr_name) | |
| if attr is not None and hasattr(attr, 'float'): | |
| logger.info(f" Converting {attr_name}...") | |
| try: | |
| total_conversions += convert_model_to_float32(attr) | |
| except Exception as e: | |
| logger.warning(f" Could not convert {attr_name}: {e}") | |
| # Check if model has nested models | |
| if hasattr(self.predictor, 'model') and self.predictor.model is not None: | |
| model = self.predictor.model | |
| for attr_name in dir(model): | |
| if not attr_name.startswith('_'): | |
| try: | |
| attr = getattr(model, attr_name) | |
| if hasattr(attr, 'parameters') and hasattr(attr, 'float'): | |
| # This looks like a submodel | |
| if attr_name not in ['model', 'detector', 'tracker']: | |
| logger.debug(f" Found submodel: {attr_name}") | |
| try: | |
| convert_model_to_float32(attr) | |
| except Exception: | |
| pass | |
| except Exception: | |
| pass | |
| if total_conversions > 0: | |
| logger.info(f"✓ Model converted to float32 ({total_conversions} tensors converted)") | |
| else: | |
| logger.warning("⚠ No tensors were converted - dtype fix may not have been applied correctly") | |
| # Additional safety: Wrap handle_request to ensure inputs are float32 | |
| original_handle_request = self.predictor.handle_request | |
| def float32_handle_request(request): | |
| """Wrapper to ensure all tensor inputs are float32.""" | |
| # Recursively convert any tensors in the request to float32 | |
| def ensure_float32(obj): | |
| if isinstance(obj, torch.Tensor): | |
| if obj.dtype in [torch.float16, torch.bfloat16]: | |
| return obj.float() | |
| return obj | |
| elif isinstance(obj, dict): | |
| return {k: ensure_float32(v) for k, v in obj.items()} | |
| elif isinstance(obj, (list, tuple)): | |
| return type(obj)(ensure_float32(item) for item in obj) | |
| return obj | |
| request = ensure_float32(request) | |
| return original_handle_request(request) | |
| self.predictor.handle_request = float32_handle_request | |
| # Also wrap handle_stream_request if it exists | |
| if hasattr(self.predictor, 'handle_stream_request'): | |
| original_handle_stream_request = self.predictor.handle_stream_request | |
| def float32_handle_stream_request(request): | |
| """Wrapper to ensure all tensor inputs are float32.""" | |
| def ensure_float32(obj): | |
| if isinstance(obj, torch.Tensor): | |
| if obj.dtype in [torch.float16, torch.bfloat16]: | |
| return obj.float() | |
| return obj | |
| elif isinstance(obj, dict): | |
| return {k: ensure_float32(v) for k, v in obj.items()} | |
| elif isinstance(obj, (list, tuple)): | |
| return type(obj)(ensure_float32(item) for item in obj) | |
| return obj | |
| request = ensure_float32(request) | |
| for response in original_handle_stream_request(request): | |
| yield response | |
| self.predictor.handle_stream_request = float32_handle_stream_request | |
| logger.info("✓ Added float32 enforcement wrappers to predictor methods") | |
| elapsed = time.time() - start_time | |
| logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Failed to load SAM3 predictor: {type(e).__name__}: {e}") | |
| logger.exception("Full traceback:") | |
| raise | |
| # Initialize HuggingFace API for uploads (if available) | |
| self.hf_api = None | |
| hf_token = os.getenv("HF_TOKEN") | |
| if HF_HUB_AVAILABLE and hf_token: | |
| try: | |
| self.hf_api = HfApi(token=hf_token) | |
| logger.info("✓ HuggingFace Hub API initialized") | |
| except Exception as e: | |
| logger.warning(f"Failed to initialize HF API: {e}") | |
| else: | |
| reasons = [] | |
| if not HF_HUB_AVAILABLE: | |
| reasons.append("huggingface_hub not installed") | |
| if not hf_token: | |
| reasons.append("HF_TOKEN not set") | |
| logger.info(f"HuggingFace Hub uploads disabled ({', '.join(reasons)})") | |
| logger.info("="*80) | |
| logger.info("INITIALIZATION COMPLETE - READY FOR REQUESTS") | |
| logger.info("="*80) | |
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Process video segmentation request using SAM3 video predictor API. | |
| Expected input format (HuggingFace Inference Toolkit standard): | |
| { | |
| "inputs": <base64_encoded_video>, | |
| "parameters": { | |
| "text_prompt": "object to segment", | |
| "return_format": "download_url" or "base64" or "metadata_only", # optional | |
| "output_repo": "username/dataset-name", # optional, for HF upload | |
| } | |
| } | |
| Returns: | |
| { | |
| "download_url": "https://...", # if uploaded to HF | |
| "frame_count": 120, | |
| "video_metadata": {...}, | |
| "compressed_size_mb": 15.3, | |
| "objects_detected": [1, 2, 3] # object IDs | |
| } | |
| """ | |
| request_start = time.time() | |
| logger.info("") | |
| logger.info("="*80) | |
| logger.info("NEW REQUEST RECEIVED") | |
| logger.info("="*80) | |
| try: | |
| # Extract and validate parameters | |
| logger.info("Parsing request parameters...") | |
| # DEBUG: Log the exact structure we received | |
| logger.info(f" Received keys: {list(data.keys())}") | |
| if "parameters" in data: | |
| logger.info(f" parameters dict keys: {list(data['parameters'].keys())}") | |
| # Video comes from "inputs" (HF toolkit standard) | |
| video_data = data.get("inputs") | |
| # Parameters might be at top level (flattened) or in "parameters" dict | |
| # HF Inference Toolkit doesn't always flatten, so check both locations | |
| parameters = data.get("parameters", {}) | |
| text_prompt = data.get("text_prompt") or parameters.get("text_prompt", "") | |
| output_repo = data.get("output_repo") or parameters.get("output_repo") | |
| return_format = data.get("return_format") or parameters.get("return_format", "metadata_only") | |
| # DEBUG: Log what we extracted | |
| logger.info(f" Extracted text_prompt: '{text_prompt}'") | |
| # Log request details | |
| logger.info(f" text_prompt: '{text_prompt}'") | |
| logger.info(f" return_format: {return_format}") | |
| logger.info(f" output_repo: {output_repo if output_repo else 'None'}") | |
| logger.info(f" video_data: {'Present' if video_data else 'Missing'} ({len(video_data) if video_data else 0} chars)") | |
| # Validate inputs | |
| if not video_data: | |
| logger.error("✗ Validation failed: No video data provided") | |
| return {"error": "No video data provided. Include video as 'inputs' in request."} | |
| if not text_prompt: | |
| logger.error("✗ Validation failed: No text prompt provided") | |
| return {"error": "No text prompt provided. Include 'text_prompt' in 'parameters'."} | |
| if return_format not in ["metadata_only", "base64", "download_url"]: | |
| logger.warning(f"Invalid return_format '{return_format}', defaulting to 'metadata_only'") | |
| return_format = "metadata_only" | |
| if return_format == "download_url" and not output_repo: | |
| logger.error("✗ Validation failed: download_url requires output_repo") | |
| return {"error": "return_format='download_url' requires 'output_repo' parameter"} | |
| logger.info("✓ Request validation passed") | |
| # Process video in temporary directory | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| tmpdir_path = Path(tmpdir) | |
| logger.info(f"Created temporary directory: {tmpdir}") | |
| # STEP 1: Decode and save video | |
| logger.info("") | |
| logger.info("STEP 1/9: Decoding video data...") | |
| step_start = time.time() | |
| try: | |
| video_path = self._prepare_video(video_data, tmpdir_path) | |
| video_size_mb = video_path.stat().st_size / 1e6 | |
| logger.info(f" Video saved to: {video_path}") | |
| logger.info(f" Video size: {video_size_mb:.2f} MB") | |
| logger.info(f"✓ Step 1 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 1 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 2: Start SAM3 session | |
| logger.info("") | |
| logger.info("STEP 2/9: Starting SAM3 session...") | |
| step_start = time.time() | |
| try: | |
| response = self.predictor.handle_request( | |
| request=dict( | |
| type="start_session", | |
| resource_path=str(video_path), | |
| ) | |
| ) | |
| session_id = response["session_id"] | |
| logger.info(f" Session ID: {session_id}") | |
| logger.info(f"✓ Step 2 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 2 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 3: Add text prompt | |
| logger.info("") | |
| logger.info("STEP 3/9: Adding text prompt to first frame...") | |
| step_start = time.time() | |
| try: | |
| response = self.predictor.handle_request( | |
| request=dict( | |
| type="add_prompt", | |
| session_id=session_id, | |
| frame_index=0, | |
| text=text_prompt, | |
| ) | |
| ) | |
| logger.info(f" Prompt: '{text_prompt}'") | |
| logger.info(f" Frame: 0") | |
| logger.info(f"✓ Step 3 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 3 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 4: Propagate through video | |
| logger.info("") | |
| logger.info("STEP 4/9: Propagating segmentation through video...") | |
| step_start = time.time() | |
| try: | |
| outputs_per_frame = {} | |
| last_log_frame = -1 | |
| log_interval = 10 # Log every 10 frames | |
| for stream_response in self.predictor.handle_stream_request( | |
| request=dict( | |
| type="propagate_in_video", | |
| session_id=session_id, | |
| ) | |
| ): | |
| frame_idx = stream_response["frame_index"] | |
| outputs_per_frame[frame_idx] = stream_response["outputs"] | |
| # Log progress every N frames | |
| if frame_idx - last_log_frame >= log_interval: | |
| logger.info(f" Processing frame {frame_idx}...") | |
| last_log_frame = frame_idx | |
| logger.info(f" Total frames processed: {len(outputs_per_frame)}") | |
| logger.info(f"✓ Step 4 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 4 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 5: Save masks to PNG files | |
| logger.info("") | |
| logger.info("STEP 5/9: Saving masks to PNG files...") | |
| step_start = time.time() | |
| try: | |
| masks_dir = tmpdir_path / "masks" | |
| masks_dir.mkdir() | |
| all_object_ids = set() | |
| mask_count = 0 | |
| for frame_idx, frame_output in outputs_per_frame.items(): | |
| frame_masks = self._save_frame_masks(frame_output, masks_dir, frame_idx) | |
| mask_count += frame_masks | |
| # Collect object IDs | |
| if "object_ids" in frame_output and frame_output["object_ids"] is not None: | |
| obj_ids = frame_output["object_ids"] | |
| if torch.is_tensor(obj_ids): | |
| obj_ids = obj_ids.cpu().tolist() | |
| elif isinstance(obj_ids, np.ndarray): | |
| obj_ids = obj_ids.tolist() | |
| if isinstance(obj_ids, list): | |
| all_object_ids.update(obj_ids) | |
| else: | |
| all_object_ids.add(obj_ids) | |
| logger.info(f" Masks directory: {masks_dir}") | |
| logger.info(f" Total mask files: {mask_count}") | |
| logger.info(f" Unique objects: {sorted(list(all_object_ids))}") | |
| logger.info(f"✓ Step 5 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 5 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 6: Create ZIP archive | |
| logger.info("") | |
| logger.info("STEP 6/9: Creating ZIP archive...") | |
| step_start = time.time() | |
| try: | |
| zip_path = tmpdir_path / "masks.zip" | |
| self._create_zip(masks_dir, zip_path) | |
| zip_size_mb = zip_path.stat().st_size / 1e6 | |
| logger.info(f" ZIP path: {zip_path}") | |
| logger.info(f" ZIP size: {zip_size_mb:.2f} MB") | |
| logger.info(f" Compression ratio: {(1 - zip_size_mb / video_size_mb) * 100:.1f}%") | |
| logger.info(f"✓ Step 6 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.error(f"✗ Step 6 failed: {type(e).__name__}: {e}") | |
| raise | |
| # STEP 7: Get video metadata | |
| logger.info("") | |
| logger.info("STEP 7/9: Extracting video metadata...") | |
| step_start = time.time() | |
| try: | |
| video_metadata = self._get_video_metadata(video_path) | |
| for key, value in video_metadata.items(): | |
| logger.info(f" {key}: {value}") | |
| logger.info(f"✓ Step 7 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.warning(f"Step 7 partial failure: {e}") | |
| video_metadata = {} | |
| # STEP 8: Prepare response | |
| logger.info("") | |
| logger.info("STEP 8/9: Preparing response...") | |
| step_start = time.time() | |
| response = { | |
| "frame_count": len(outputs_per_frame), | |
| "objects_detected": sorted(list(all_object_ids)) if all_object_ids else [], | |
| "compressed_size_mb": round(zip_size_mb, 2), | |
| "video_metadata": video_metadata | |
| } | |
| if return_format == "download_url" and output_repo: | |
| logger.info(f" Uploading to HuggingFace dataset: {output_repo}") | |
| try: | |
| download_url = self._upload_to_hf(zip_path, output_repo) | |
| response["download_url"] = download_url | |
| logger.info(f" ✓ Upload successful: {download_url}") | |
| except Exception as e: | |
| logger.error(f" ✗ Upload failed: {e}") | |
| raise | |
| elif return_format == "base64": | |
| logger.info(" Encoding ZIP to base64...") | |
| try: | |
| with open(zip_path, "rb") as f: | |
| zip_bytes = f.read() | |
| response["masks_zip_base64"] = base64.b64encode(zip_bytes).decode("utf-8") | |
| logger.info(f" ✓ Encoded {len(response['masks_zip_base64'])} characters") | |
| except Exception as e: | |
| logger.error(f" ✗ Encoding failed: {e}") | |
| raise | |
| else: | |
| logger.info(" Returning metadata only (no mask data)") | |
| logger.info(f"✓ Step 8 completed in {time.time() - step_start:.2f}s") | |
| # STEP 9: Close session | |
| logger.info("") | |
| logger.info("STEP 9/9: Closing SAM3 session...") | |
| step_start = time.time() | |
| try: | |
| self.predictor.handle_request( | |
| request=dict( | |
| type="close_session", | |
| session_id=session_id, | |
| ) | |
| ) | |
| logger.info(f"✓ Step 9 completed in {time.time() - step_start:.2f}s") | |
| except Exception as e: | |
| logger.warning(f"Step 9 partial failure (non-critical): {e}") | |
| # Final summary | |
| total_time = time.time() - request_start | |
| logger.info("") | |
| logger.info("="*80) | |
| logger.info("REQUEST COMPLETED SUCCESSFULLY") | |
| logger.info(f"Total processing time: {total_time:.2f}s") | |
| logger.info(f"Frames processed: {len(outputs_per_frame)}") | |
| logger.info(f"Objects detected: {len(all_object_ids)}") | |
| logger.info("="*80) | |
| logger.info("") | |
| return response | |
| except Exception as e: | |
| total_time = time.time() - request_start | |
| logger.error("") | |
| logger.error("="*80) | |
| logger.error("REQUEST FAILED") | |
| logger.error(f"Error type: {type(e).__name__}") | |
| logger.error(f"Error message: {str(e)}") | |
| logger.error(f"Time elapsed: {total_time:.2f}s") | |
| logger.error("="*80) | |
| logger.exception("Full traceback:") | |
| logger.error("") | |
| return { | |
| "error": str(e), | |
| "error_type": type(e).__name__ | |
| } | |
| def _ensure_bpe_file(self) -> str: | |
| """ | |
| Ensure BPE tokenizer file exists. Download from HuggingFace if missing. | |
| Returns path to the BPE file. | |
| """ | |
| logger.info("Checking for BPE tokenizer file...") | |
| # Try multiple possible paths | |
| possible_paths = [ | |
| Path("/repository/assets/bpe_simple_vocab_16e6.txt.gz"), | |
| Path("./assets/bpe_simple_vocab_16e6.txt.gz"), | |
| Path("../assets/bpe_simple_vocab_16e6.txt.gz"), | |
| Path("/app/assets/bpe_simple_vocab_16e6.txt.gz"), | |
| ] | |
| for bpe_file in possible_paths: | |
| if bpe_file.exists(): | |
| logger.info(f" ✓ BPE file found: {bpe_file}") | |
| return str(bpe_file) | |
| logger.warning(" BPE file not found in any expected location") | |
| # Use first path as default for download | |
| assets_dir = Path("/repository/assets") | |
| bpe_file = assets_dir / "bpe_simple_vocab_16e6.txt.gz" | |
| logger.warning(f" BPE file not found at {bpe_file}") | |
| logger.info(" Downloading from HuggingFace...") | |
| # Create assets directory | |
| assets_dir.mkdir(parents=True, exist_ok=True) | |
| # Try primary method: hf_hub_download | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| logger.info(" Attempting download via hf_hub_download...") | |
| downloaded_path = hf_hub_download( | |
| repo_id="facebook/sam3", | |
| filename="assets/bpe_simple_vocab_16e6.txt.gz", | |
| local_dir="/repository", | |
| local_dir_use_symlinks=False | |
| ) | |
| logger.info(f" ✓ BPE file downloaded: {downloaded_path}") | |
| return downloaded_path | |
| except Exception as e: | |
| logger.warning(f" Primary download failed: {e}") | |
| logger.info(" Trying fallback download method...") | |
| # Fallback: download directly from raw URL | |
| import urllib.request | |
| url = "https://huggingface.co/facebook/sam3/resolve/main/assets/bpe_simple_vocab_16e6.txt.gz" | |
| try: | |
| logger.info(f" Downloading from: {url}") | |
| urllib.request.urlretrieve(url, str(bpe_file)) | |
| logger.info(f" ✓ BPE file downloaded: {bpe_file}") | |
| return str(bpe_file) | |
| except Exception as e2: | |
| logger.error(f" ✗ Fallback download failed: {e2}") | |
| raise ValueError( | |
| f"Could not download BPE tokenizer file. Please add assets/bpe_simple_vocab_16e6.txt.gz " | |
| f"to your repository. Download from: {url}" | |
| ) | |
| def _prepare_video(self, video_data: str, tmpdir: Path) -> Path: | |
| """Decode base64 video and save to file.""" | |
| try: | |
| logger.info(" Decoding base64 data...") | |
| video_bytes = base64.b64decode(video_data) | |
| logger.info(f" Decoded {len(video_bytes)} bytes") | |
| except Exception as e: | |
| logger.error(f" Base64 decode failed: {e}") | |
| raise ValueError(f"Failed to decode base64 video: {e}") | |
| video_path = tmpdir / "input_video.mp4" | |
| video_path.write_bytes(video_bytes) | |
| return video_path | |
| def _save_frame_masks(self, frame_output: Dict, masks_dir: Path, frame_idx: int) -> int: | |
| """ | |
| Save masks for a frame as PNG files. | |
| Each object gets its own mask file: frame_XXXX_obj_Y.png | |
| Returns the number of masks saved. | |
| """ | |
| if "masks" not in frame_output or frame_output["masks"] is None: | |
| return 0 | |
| masks = frame_output["masks"] | |
| object_ids = frame_output.get("object_ids", []) | |
| # Handle different types of object_ids | |
| if torch.is_tensor(object_ids): | |
| object_ids = object_ids.cpu().tolist() | |
| elif isinstance(object_ids, np.ndarray): | |
| object_ids = object_ids.tolist() | |
| elif not isinstance(object_ids, list): | |
| object_ids = list(object_ids) if object_ids is not None else [] | |
| # Convert masks to numpy if tensor | |
| if torch.is_tensor(masks): | |
| masks = masks.cpu().numpy() | |
| # Ensure masks is 3D array [num_objects, height, width] | |
| if len(masks.shape) == 4: | |
| masks = masks[0] | |
| # Save each object's mask | |
| saved_count = 0 | |
| for i, obj_id in enumerate(object_ids): | |
| if i < len(masks): | |
| mask = masks[i] | |
| # Convert to binary (0 or 255) | |
| mask_binary = (mask > 0.5).astype(np.uint8) * 255 | |
| # Save as PNG | |
| mask_img = Image.fromarray(mask_binary) | |
| mask_filename = f"frame_{frame_idx:05d}_obj_{obj_id}.png" | |
| mask_img.save(masks_dir / mask_filename, compress_level=9) | |
| saved_count += 1 | |
| return saved_count | |
| def _create_zip(self, masks_dir: Path, zip_path: Path): | |
| """Create ZIP archive of all mask PNGs.""" | |
| mask_files = sorted(masks_dir.glob("*.png")) | |
| logger.info(f" Creating ZIP with {len(mask_files)} files...") | |
| with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED, compresslevel=9) as zipf: | |
| for mask_file in mask_files: | |
| zipf.write(mask_file, mask_file.name) | |
| def _get_video_metadata(self, video_path: Path) -> Dict[str, Any]: | |
| """Extract video metadata using OpenCV.""" | |
| try: | |
| cap = cv2.VideoCapture(str(video_path)) | |
| if not cap.isOpened(): | |
| logger.warning(f" Could not open video file: {video_path}") | |
| return {} | |
| metadata = { | |
| "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), | |
| "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), | |
| "fps": float(cap.get(cv2.CAP_PROP_FPS)), | |
| "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), | |
| } | |
| cap.release() | |
| return metadata | |
| except Exception as e: | |
| logger.warning(f" Could not extract video metadata: {e}") | |
| return {} | |
| def _upload_to_hf(self, zip_path: Path, repo_id: str) -> str: | |
| """Upload ZIP file to HuggingFace dataset repository.""" | |
| if not self.hf_api: | |
| raise ValueError("HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.") | |
| try: | |
| # Generate unique filename | |
| import time | |
| timestamp = int(time.time()) | |
| filename = f"masks_{timestamp}.zip" | |
| logger.info(f" Uploading {zip_path.stat().st_size / 1e6:.2f} MB...") | |
| # Upload file | |
| url = self.hf_api.upload_file( | |
| path_or_fileobj=str(zip_path), | |
| path_in_repo=filename, | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| ) | |
| # Return download URL | |
| download_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}" | |
| return download_url | |
| except Exception as e: | |
| logger.error(f" Upload error: {e}") | |
| raise ValueError(f"Failed to upload to HuggingFace: {e}") |