import numpy as np
from PIL import Image, ImageDraw, ImageColor
import scipy
import scipy.ndimage  # explicit: ``import scipy`` alone does not guarantee the submodule is loaded
import cv2
import torch
import kornia
import torch.nn.functional as F


def image_to_pil(image):
    """Return ``image`` as a PIL image.

    Numpy arrays are converted with ``Image.fromarray``; PIL images are
    returned unchanged.

    Raises:
        ValueError: If ``image`` is neither a numpy array nor a PIL image.
    """
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    if isinstance(image, Image.Image):
        return image
    raise ValueError("Unsupported image type")


def image_to_np(image):
    """Return ``image`` as a numpy array.

    PIL images are converted with ``np.array``; numpy arrays are returned
    unchanged.

    Raises:
        ValueError: If ``image`` is neither a numpy array nor a PIL image.
    """
    if isinstance(image, np.ndarray):
        return image
    if isinstance(image, Image.Image):
        return np.array(image)
    raise ValueError("Unsupported image type")


def get_bbox_center(bbox):
    """Return the integer center ``(x, y)`` of a ``(x1, y1, x2, y2)`` box."""
    x1, y1, x2, y2 = bbox
    return (int((x1 + x2) // 2), int((y1 + y2) // 2))


def save_mask_to_file(
    mask,
    file_path,
):
    """Save a mask array to ``file_path`` as an 8-bit image.

    The mask is cast to ``uint8`` (on a copy, so the caller's array is not
    modified). If its maximum is <= 1 it is scaled by 255 so that foreground
    becomes white; otherwise the values are written as they are.
    """
    mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask *= 255
    Image.fromarray(mask).save(file_path)


def read_mask_from_file(
    file_path,
):
    """Load an image as 8-bit grayscale and return a boolean mask (pixel > 0)."""
    gray = Image.open(file_path).convert("L")
    return image_to_np(gray) > 0


def bbox_from_mask(
    mask: np.ndarray | Image.Image | torch.Tensor,
):
    """
    Compute the axis-aligned bounding box of the foreground of a mask.

    Args:
        mask: numpy array, PIL image or torch tensor. Non-zero (or True)
            entries are foreground.

    Returns:
        (min_x, min_y, max_x, max_y) with inclusive coordinates, or None if
        the mask has fewer than 2 dimensions or no foreground.

    Layout rules:
        - numpy / PIL: 2D (H, W) or channels-last (H, W, C); for (H, W, C)
          a pixel is foreground if any channel is non-zero.
        - torch: the last two dims are taken as (H, W); all leading dims
          (e.g. C, or B and C) are reduced with "any". A channels-last
          torch tensor is therefore NOT interpreted as (H, W, C).
        - torch reductions run on the tensor's device; only the four
          resulting indices are read back.

    Raises:
        ValueError: For a non-empty-foreground numpy mask whose rank is
            not 2 or 3.
    """
    # Convert PIL to numpy
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    # Torch path
    if isinstance(mask, torch.Tensor):
        if mask.ndim < 2:
            return None
        fg = mask != 0
        if fg.numel() == 0:
            return None
        if fg.ndim > 2:
            # Collapse every leading (non-spatial) dim into one and reduce it.
            fg = fg.reshape(-1, fg.shape[-2], fg.shape[-1]).any(dim=0)
        if not fg.any():
            return None
        y_idx = torch.nonzero(fg.any(dim=1), as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(fg.any(dim=0), as_tuple=False).squeeze(1)
        # nonzero returns indices in ascending order, so first/last are min/max.
        return (
            int(x_idx[0].item()),
            int(y_idx[0].item()),
            int(x_idx[-1].item()),
            int(y_idx[-1].item()),
        )

    # Numpy path
    mask_np = np.asarray(mask)
    if mask_np.ndim < 2:
        return None
    fg = mask_np != 0
    if fg.ndim == 3:
        fg = fg.any(axis=2)
    if not fg.any():
        return None
    if fg.ndim != 2:
        raise ValueError(
            f"numpy mask must be (H, W) or (H, W, C), got shape {mask_np.shape}"
        )
    y_idx = np.flatnonzero(fg.any(axis=1))
    x_idx = np.flatnonzero(fg.any(axis=0))
    return (int(x_idx[0]), int(y_idx[0]), int(x_idx[-1]), int(y_idx[-1]))


def remove_small_components(mask, min_size=10):
    """Drop connected components with fewer than ``min_size`` pixels.

    Components are found with ``scipy.ndimage.label`` (its default
    connectivity). Returns a ``uint8`` array of the same shape with 255 on
    the kept components and 0 elsewhere.
    """
    labeled, _ = scipy.ndimage.label(mask)
    # One pass over the label image gives every component's pixel count.
    sizes = np.bincount(labeled.ravel(), minlength=1)
    keep = sizes >= min_size
    keep[0] = False  # label 0 is background
    return keep[labeled].astype(np.uint8) * 255


def draw_bbox_on_image(
    image: np.ndarray | Image.Image,
    bbox,
    color="yellow",
    width=3,
):
    """Draw a ``(x1, y1, x2, y2)`` rectangle outline on a copy of ``image``.

    Returns a PIL image, or ``image`` unchanged when ``image`` or ``bbox``
    is None.
    """
    if image is None or bbox is None:
        return image
    canvas = image_to_pil(image.copy())
    x1, y1, x2, y2 = bbox
    ImageDraw.Draw(canvas).rectangle([x1, y1, x2, y2], outline=color, width=width)
    return canvas


def draw_mask_on_image(
    image: np.ndarray | Image.Image | None,
    mask: np.ndarray | Image.Image | None,
    mask_color: str | list[int] | tuple[int, int, int] = (30, 255, 144),
    alpha: float = 0.3,
):
    """
    Draw a binary mask overlay on a copy of an image.

    Args:
        image: Base image (numpy array or PIL image).
        mask: 2D mask; entries > 0 are overlaid.
        mask_color: Either a color string understood by
            ``ImageColor.getrgb`` (e.g. "red", "#ff0000", "#f00") or a
            list/tuple/np.ndarray of 3 numbers in 0..255 (R, G, B).
        alpha: Overlay opacity in 0..1.

    Returns:
        An RGBA PIL image, or ``image`` unchanged if ``image`` or ``mask``
        is None.

    Raises:
        ValueError: If ``alpha`` is out of range, ``mask_color`` is
            malformed, the mask is not 2D, or the mask size differs from
            the image size.
    """
    if image is None or mask is None:
        return image
    if not (0.0 <= alpha <= 1.0):
        raise ValueError("alpha must be between 0 and 1")

    # Normalize mask_color to (R,G,B)
    if isinstance(mask_color, str):
        rgb = ImageColor.getrgb(mask_color)
    elif isinstance(mask_color, (list, tuple, np.ndarray)):
        if len(mask_color) != 3:
            raise ValueError("mask_color list/tuple must have length 3")
        rgb = tuple(int(round(float(c))) for c in mask_color)
    else:
        raise ValueError("Unsupported mask_color type")
    rgb = tuple(np.clip(rgb, 0, 255))

    base = image_to_pil(image.copy())
    mask_bin = (image_to_np(mask) > 0).astype(np.uint8)
    if mask_bin.ndim != 2:
        raise ValueError("mask must be 2D after binarization")
    h, w = mask_bin.shape
    if base.size != (w, h):
        # alpha_composite would reject this too; fail with a clearer message.
        raise ValueError(
            f"mask size {(w, h)} does not match image size {base.size}"
        )

    # Build RGBA overlay: constant color, alpha only where the mask is set.
    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    overlay[..., 0] = rgb[0]
    overlay[..., 1] = rgb[1]
    overlay[..., 2] = rgb[2]
    overlay[..., 3] = (
        (alpha * 255).astype(np.uint8) if isinstance(alpha, np.ndarray) else int(alpha * 255)
    )
    overlay[mask_bin == 0, 3] = 0

    return Image.alpha_composite(base.convert("RGBA"), Image.fromarray(overlay))


def draw_mask_bbox_on_image(
    image,
    mask,
    mask_color: str | list[int] | tuple[int, int, int] = (30, 255, 144),
    mask_alpha: float = 0.3,
    bbox_color="yellow",
    bbox_width=3,
):
    """Overlay a mask and the bounding box of its foreground on an image.

    Returns:
        ``(image, bbox)`` where ``bbox`` is the result of ``bbox_from_mask``
        (None when the mask has no foreground, in which case only the mask
        overlay is applied).
    """
    image = draw_mask_on_image(image, mask, mask_color=mask_color, alpha=mask_alpha)
    bbox = bbox_from_mask(mask)
    if bbox is None:
        return image, None
    image = draw_bbox_on_image(image, bbox, color=bbox_color, width=bbox_width)
    return image, bbox


def draw_points_on_image(
    image,
    points: list[tuple],
    color="red",
    radius=5,
):
    """Draw filled circles at ``points`` on a copy of ``image``.

    Args:
        image: numpy array or PIL image. None is returned unchanged.
        points: List of ``(x, y)`` tuples.
        color: A single color, or a list with one color per point.
        radius: A single radius, or a list with one radius per point.

    Returns:
        A PIL image with the points drawn.

    Raises:
        AssertionError: If ``points`` is not a list, or a ``color`` /
            ``radius`` list does not have one entry per point.
    """
    if image is None:
        return image
    assert isinstance(points, list), "points must be a list of tuples"

    # Broadcast scalar color / radius to one entry per point.
    colors = color if isinstance(color, list) else [color] * len(points)
    assert len(colors) == len(points), "color must be a list of the same length as points"
    radii = radius if isinstance(radius, list) else [radius] * len(points)
    assert len(radii) == len(points), "radius must be a list of the same length as points"

    canvas = image_to_pil(image.copy())
    draw = ImageDraw.Draw(canvas)
    for (x, y), point_color, r in zip(points, colors, radii):
        # ellipse over the circle's bounding box; avoids requiring the newer
        # ImageDraw.circle method.
        draw.ellipse((x - r, y - r, x + r, y + r), fill=point_color, outline=point_color)
    return canvas


def draw_lines_on_image(
    image,
    points: list[tuple],
    color="red",
    width=3,
):
    """
    Draw a polyline through ``points`` on a copy of ``image``.

    Args:
        image: numpy array or PIL image.
        points: List of at least two ``(x, y)`` points.
        color: A single color spec, or a list of specs with one entry per
            segment (``len(points) - 1``) or per point (``len(points)``;
            the last entry is then ignored). A string spec may be a color
            name, "#rrggbb", or a bare 6-digit hex "rrggbb"; non-string
            specs are passed to PIL as they are.
        width: Line width in pixels.

    Returns:
        A PIL image, or ``image`` unchanged if it is None or ``points`` is
        not a list of at least two points.

    Raises:
        ValueError: If a color list has an incompatible length.
    """
    if image is None:
        return image
    if not isinstance(points, list) or len(points) < 2:
        return image

    n_segments = len(points) - 1
    if not isinstance(color, list):
        segment_colors = [color] * n_segments
    elif len(color) == len(points):
        segment_colors = color[:-1]
    else:
        segment_colors = color
    if len(segment_colors) != n_segments:
        raise ValueError("color list length must be len(points)-1 or len(points)")

    def to_pil_color(spec):
        if not isinstance(spec, str):
            return spec  # assume tuple
        spec = spec.strip()
        if len(spec) == 6 and all(ch in "0123456789abcdefABCDEF" for ch in spec):
            spec = "#" + spec
        return ImageColor.getrgb(spec)

    canvas = image_to_pil(image.copy())
    draw = ImageDraw.Draw(canvas)
    for start, end, spec in zip(points, points[1:], segment_colors):
        draw.line([start, end], fill=to_pil_color(spec), width=width)
    return canvas


def draw_arrow_on_image(
    image,
    start_point: tuple,
    end_point: tuple,
    color: str = "white",
    thickness: int = 5,
):
    """Draw an arrow from ``start_point`` to ``end_point`` on a copy of ``image``.

    Args:
        image: numpy array or PIL image. None is returned unchanged.
        start_point: ``(x, y)`` of the arrow tail.
        end_point: ``(x, y)`` of the arrow head.
        color: A PIL color string, or a numeric tuple handed to OpenCV
            unchanged. Strings are resolved with ``ImageColor.getcolor``
            for the image's channel count because ``cv2.arrowedLine`` only
            accepts numeric colors.
        thickness: Line thickness in pixels.

    Returns:
        A PIL image with the arrow drawn.
    """
    if image is None:
        return image
    canvas = np.array(image)  # np.array copies, so the input is untouched

    if isinstance(color, str):
        if canvas.ndim == 2:
            mode = "L"
        elif canvas.shape[2] == 4:
            mode = "RGBA"
        else:
            mode = "RGB"
        color = ImageColor.getcolor(color, mode)
        if isinstance(color, int):
            color = (color,)

    # OpenCV requires integer pixel coordinates.
    start = tuple(int(round(v)) for v in start_point)
    end = tuple(int(round(v)) for v in end_point)
    canvas = cv2.arrowedLine(canvas, start, end, color, thickness)
    return Image.fromarray(canvas)


def trajectory_interpolate_1d(
    trajectory: list[float],
    scale: int,
) -> list[float]:
    """
    Linearly upsample a 1D trajectory.

    Args:
        trajectory (List[float]): Sequence of scalar values (len >= 2).
        scale (int): Number of interpolated steps between original samples.

    Returns:
        List[float]: Interpolated 1D trajectory of length (L-1)*scale + 1,
        where L is the input length. Original samples are preserved at
        every ``scale``-th position.

    Raises:
        AssertionError: If the arguments violate the constraints above.
    """
    assert isinstance(trajectory, list), "trajectory must be a list"
    assert len(trajectory) > 1, "trajectory must have at least 2 points"
    assert isinstance(scale, int), "scale must be an integer"
    assert scale > 0, "scale must be greater than 0"

    values = np.asarray(trajectory, dtype=np.float32).reshape(-1)
    n_in = values.shape[0]
    positions = np.arange(n_in, dtype=np.float32)
    dense_positions = np.linspace(0, n_in - 1, (n_in - 1) * scale + 1, dtype=np.float32)
    return np.interp(dense_positions, positions, values).tolist()


def trajectory_interpolate(
    trajectory: list[tuple],
    scale: int,
):
    """Linearly upsample a 2D trajectory of ``(x, y)`` points.

    Args:
        trajectory: List of at least two ``(x, y)`` points.
        scale: Number of interpolated steps between original samples.

    Returns:
        List of ``(x, y)`` integer tuples of length (L-1)*scale + 1.
        Coordinates are truncated toward zero by ``int``.

    Raises:
        AssertionError: If the arguments violate the constraints above.
    """
    assert isinstance(trajectory, list), "trajectory must be a list of tuples"
    assert len(trajectory) > 1, "trajectory must have at least 2 points"
    assert isinstance(scale, int), "scale must be an integer"
    assert scale > 0, "scale must be greater than 0"

    n_out = (len(trajectory) - 1) * scale + 1
    points = torch.tensor(np.array(trajectory), dtype=torch.float32)
    # Treat the (L, 2) trajectory as a 1-channel image and stretch only its
    # height; the width stays 2, so x and y are interpolated independently.
    dense = torch.nn.functional.interpolate(
        points[None, None],
        size=(n_out, 2),
        mode="bilinear",
        align_corners=True,
    )[0, 0]
    return [(int(x), int(y)) for x, y in dense.tolist()]


def dilate_mask(
    mask: np.ndarray | None,
    dilate_factor: int = 15,
):
    """Dilate a mask once with a square ``dilate_factor`` x ``dilate_factor`` kernel.

    The mask is cast to ``uint8`` first. Returns None when ``mask`` is None.
    """
    if mask is None:
        return None
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)


def dilate_masks(
    masks: list[np.ndarray],
    dilate_factor: int = 15,
):
    """Apply ``dilate_mask`` to every mask in ``masks``."""
    return [dilate_mask(mask, dilate_factor) for mask in masks]


def shift_masks(
    ref_mask,
    deltas: list[tuple[float, float]],
):
    """Shift the foreground of ``ref_mask`` by each delta.

    Args:
        ref_mask: Array whose entries > 0 are foreground.
        deltas: Per-step offsets. ``delta[0]`` is added to the axis-0
            indices and ``delta[1]`` to the axis-1 indices; both are
            truncated with ``int``.

    Returns:
        ``(shifted_masks, shifted_masks_indices)``:
        one ``uint8`` 0/1 mask per delta, and the matching
        ``(axis0_indices, axis1_indices)`` tuples. Indices that would leave
        the array are clamped to the nearest border (not dropped), so each
        index array keeps the same length as the reference foreground.
    """
    fg_indices = np.where(ref_mask > 0)
    fg_rows, fg_cols = fg_indices[0], fg_indices[1]
    last_row = ref_mask.shape[0] - 1
    last_col = ref_mask.shape[1] - 1

    shifted_masks = []
    shifted_masks_indices = []
    for delta in deltas:
        # Clamp out-of-bounds indices onto the border.
        index = (
            np.clip(fg_rows + int(delta[0]), 0, last_row),
            np.clip(fg_cols + int(delta[1]), 0, last_col),
        )
        shifted = np.zeros_like(ref_mask, dtype=np.uint8)
        shifted[index] = 1
        shifted_masks.append(shifted)
        shifted_masks_indices.append(index)
    return shifted_masks, shifted_masks_indices


def rotate_points(points, angle, center=(0.0, 0.0), degrees=True):
    """
    Rotate 2D point(s) around a center by angle.

    Uses the rotation matrix [[cos, -sin], [sin, cos]] applied to
    ``points - center``.

    Args:
        points: array-like of shape (2,) or (N, 2) as [x, y]
        angle: rotation angle (degrees by default)
        center: rotation center [cx, cy]
        degrees: if True, angle is in degrees; otherwise radians

    Returns:
        numpy float array with the same shape as ``points``.
    """
    pts = np.asarray(points, dtype=float)
    ctr = np.asarray(center, dtype=float)
    theta = np.deg2rad(angle) if degrees else angle
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return (pts - ctr) @ rotation.T + ctr


def calculate_angle(vector_1: torch.Tensor, vector_2: torch.Tensor):
    """Return the signed angle in degrees from ``vector_1`` to ``vector_2``.

    The magnitude comes from the arccosine of the normalized dot product;
    the sign is negative when the 2D cross product
    ``v1.x * v2.y - v1.y * v2.x`` is negative.

    Raises:
        ValueError: If either vector has zero magnitude.
    """
    magnitude_1 = torch.norm(vector_1)
    magnitude_2 = torch.norm(vector_2)
    if magnitude_1 == 0 or magnitude_2 == 0:
        raise ValueError("One of the vectors has zero magnitude, cannot calculate angle.")

    cos_theta = torch.dot(vector_1, vector_2) / (magnitude_1 * magnitude_2)
    # Guard acos against rounding slightly outside [-1, 1].
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
    angle_deg = torch.rad2deg(torch.acos(cos_theta))

    cross_product = vector_1[0] * vector_2[1] - vector_1[1] * vector_2[0]
    if cross_product < 0:
        angle_deg = -angle_deg
    return angle_deg


def calculate_angle_from_points(
    center_points: torch.Tensor,
    handle_points: torch.Tensor,
    target_points: torch.Tensor,
):
    """Return the signed angle (degrees) at ``center_points`` from handle to target.

    Each argument is one ``(x, y)`` point (tensor, tuple or list). Points
    are converted to float32 tensors, and the angle between
    ``handle - center`` and ``target - center`` is computed with
    ``calculate_angle``.
    """
    center = torch.as_tensor(center_points, dtype=torch.float32)
    handle = torch.as_tensor(handle_points, dtype=torch.float32)
    target = torch.as_tensor(target_points, dtype=torch.float32)
    return calculate_angle(handle - center, target - center)


def _add_batch_dims(tensor: torch.Tensor):
    """Pad a 2D or 3D tensor with leading singleton dims to 4D; return (tensor, original ndim)."""
    ndim = tensor.ndim
    if ndim == 2:
        tensor = tensor[None, None, ...]
    elif ndim == 3:
        tensor = tensor[None, ...]
    return tensor, ndim


def _strip_batch_dims(tensor: torch.Tensor, ndim: int):
    """Undo ``_add_batch_dims`` given the original ndim."""
    if ndim == 2:
        return tensor[0, 0, ...]
    if ndim == 3:
        return tensor[0, ...]
    return tensor


def tensor_2d_translation(
    tensor: torch.Tensor,
    translation: tuple[float, float] | torch.Tensor,
    mode: str = "bilinear",
):
    """
    Translate a 2D tensor by a given translation vector.

    The work is done in float32 via ``kornia.geometry.transform.translate``
    and the result is cast back to the input dtype (float32 for non-tensor
    input). 2D and 3D inputs are temporarily padded to 4D and returned in
    their original rank.

    Args:
        tensor: Tensor (or array-like) to translate.
        translation: Translation vector, passed to kornia; a 1D value is
            given a leading batch dim.
        mode: Interpolation mode passed to kornia.
    """
    original_dtype = tensor.dtype if isinstance(tensor, torch.Tensor) else torch.float32
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.tensor(tensor)

    batched, ndim = _add_batch_dims(tensor.to(torch.float32))

    if not isinstance(translation, torch.Tensor):
        translation = torch.tensor(translation, device=batched.device)
    translation = translation.to(dtype=torch.float32, device=batched.device)
    if translation.ndim == 1:
        translation = translation.unsqueeze(0)

    translated = kornia.geometry.transform.translate(
        batched,
        translation=translation,
        mode=mode,
    )
    return _strip_batch_dims(translated, ndim).to(original_dtype)


def tensor_2d_rotation(
    tensor: torch.Tensor,
    angle: float,
    center=None,
    mode: str = "bilinear",
):
    """
    Rotate a 2D tensor by a given angle (clockwise).

    The angle is negated before being handed to
    ``kornia.geometry.transform.rotate``. Computation happens in float32
    and the result is cast back to the input dtype (float32 for non-tensor
    input). 2D and 3D inputs are temporarily padded to 4D and returned in
    their original rank.

    Args:
        tensor: Tensor (or array-like) to rotate.
        angle: Rotation angle; scalar, array-like or tensor.
        center: Optional rotation center passed to kornia (promoted to
            float32). None lets kornia choose its default.
        mode: Interpolation mode passed to kornia.
    """
    original_dtype = tensor.dtype if isinstance(tensor, torch.Tensor) else torch.float32
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.tensor(tensor)

    batched, ndim = _add_batch_dims(tensor.to(torch.float32))

    if not isinstance(angle, torch.Tensor):
        angle = torch.tensor(angle, device=batched.device)
    # Clockwise -> negate (after the float conversion so lists and unsigned
    # tensors negate correctly).
    angle = -angle.to(dtype=torch.float32, device=batched.device)
    if angle.ndim == 0:
        angle = angle.unsqueeze(0)

    if center is not None:
        if not isinstance(center, torch.Tensor):
            center = torch.tensor(center, device=batched.device)
        center = center.to(dtype=torch.float32, device=batched.device)

    rotated = kornia.geometry.transform.rotate(
        batched,
        angle,
        center=center,
        mode=mode,
    )
    return _strip_batch_dims(rotated, ndim).to(original_dtype)


def resize_tensor(
    tensor: torch.Tensor,
    size: int | tuple[int, int] | None = None,
    scale_factor: float | tuple[float, float] | None = None,
    mode: str = "bilinear",
) -> torch.Tensor:
    """
    Resize the spatial dims of a tensor with ``F.interpolate``.

    Args:
        tensor (torch.Tensor): The input tensor. 2D and 3D inputs are
            temporarily padded to 4D and returned in their original rank.
        size (Union[int, Tuple[int, int]]): The target size.
        scale_factor (Union[float, Tuple[float, float]]): The scale factor
            for resizing. If provided, it overrides ``size``.
        mode (str): Interpolation mode. ``align_corners=True`` is used for
            the linear-family and bicubic modes, and left unset otherwise.

    Returns:
        torch.Tensor: The resized tensor.
    """
    if scale_factor is not None:
        # F.interpolate rejects size and scale_factor together; honor the
        # documented precedence instead of raising.
        size = None

    batched, ndim = _add_batch_dims(tensor)
    uses_corners = mode in ["linear", "bilinear", "bicubic", "trilinear"]
    resized = F.interpolate(
        batched,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=(True if uses_corners else None),
    )
    return _strip_batch_dims(resized, ndim)


def warp_tensor(
    tensor: torch.Tensor,
    is_rotation: bool,
    delta,
    rotation_center: tuple[float, float] | torch.Tensor | None = None,
    original_height: int | None = None,
    mode: str = "nearest",
) -> torch.Tensor:
    """
    Warp a tensor by translation or rotation based on a trajectory step.

    Args:
        tensor: Tensor to warp. Can be (H, W), (C, H, W), or (B, C, H, W).
        is_rotation: If True, warp by rotation; otherwise by translation.
        delta: The delta for this step.
            For rotation: scalar angle, forwarded unchanged to
            ``tensor_2d_rotation``.
            For translation: (dx, dy) in original image pixel coordinates,
            as a torch.Tensor, tuple or list.
        rotation_center: (x, y) center of rotation in original image pixel
            coordinates. Required when is_rotation is True.
        original_height: The height of the original image at which delta
            was computed. If provided and different from the tensor's
            height (dim -2), translation deltas and the rotation center are
            divided by ``original_height / tensor_height``. If None, no
            rescaling is applied.
        mode: Interpolation mode for warping.

    Returns:
        Warped tensor with the same shape as input.

    Raises:
        ValueError: If ``is_rotation`` is True and ``rotation_center`` is None.
    """
    tensor_height = tensor.shape[-2]
    if original_height is not None and original_height != tensor_height:
        scale = original_height / tensor_height
    else:
        scale = 1.0

    if is_rotation:
        if rotation_center is None:
            raise ValueError("rotation_center is required when is_rotation is True")
        # Keep the center in float32: casting it to an integer/bool mask
        # dtype would truncate the coordinates.
        center = (
            torch.as_tensor(rotation_center, dtype=torch.float32, device=tensor.device)
            / scale
        )
        return tensor_2d_rotation(tensor, angle=delta, center=center, mode=mode)

    if isinstance(delta, torch.Tensor):
        scaled_delta = delta / scale
    else:
        scaled_delta = tuple(d / scale for d in delta)
    return tensor_2d_translation(tensor, translation=scaled_delta, mode=mode)


def warp_tensor_sequence(
    tensor: torch.Tensor,
    is_rotation: bool,
    deltas: list,
    rotation_center: tuple[float, float] | torch.Tensor | None = None,
    original_height: int | None = None,
    mode: str = "nearest",
    cumulative: bool = False,
) -> list[torch.Tensor]:
    """
    Warp a tensor by a sequence of deltas, returning a list of warped tensors.

    Args:
        tensor: Tensor to warp. Can be (H, W), (C, H, W), or (B, C, H, W).
        is_rotation: If True, warp by rotation; otherwise by translation.
        deltas: One delta per step, in the form ``warp_tensor`` accepts.
        rotation_center: (x, y) center of rotation in original image pixel
            coordinates. Required when is_rotation is True.
        original_height: Forwarded to ``warp_tensor`` for rescaling.
        mode: Interpolation mode for warping.
        cumulative: If True, each warp is applied on top of the previous
            result (sequential composition). If False, each delta is
            applied independently to the original tensor.

    Returns:
        List of warped tensors, one per delta.
    """
    warped_tensors = []
    source = tensor
    for delta in deltas:
        warped = warp_tensor(
            source,
            is_rotation=is_rotation,
            delta=delta,
            rotation_center=rotation_center,
            original_height=original_height,
            mode=mode,
        )
        warped_tensors.append(warped)
        if cumulative:
            source = warped
    return warped_tensors


def combine_masks_or(
    masks: list[torch.Tensor | np.ndarray],
) -> torch.Tensor | np.ndarray:
    """
    Combine a list of masks into their union.

    Masks are summed element-wise and the sum is clamped to [0, 1], so for
    0/1 masks the result is the logical OR. Boolean torch masks are already
    OR-ed by the sum and are returned without clamping (``torch.clamp``
    does not accept bool tensors).

    Returns a tensor if any input is a tensor (numpy inputs are moved to
    the first tensor's device), otherwise a numpy array. The inputs are not
    modified.

    Raises:
        ValueError: If ``masks`` is empty.
    """
    if len(masks) == 0:
        raise ValueError("masks list is empty")

    if any(isinstance(m, torch.Tensor) for m in masks):
        device = next(m.device for m in masks if isinstance(m, torch.Tensor))
        tensors = [
            m if isinstance(m, torch.Tensor) else torch.as_tensor(m, device=device)
            for m in masks
        ]
        result = tensors[0].clone()
        for t in tensors[1:]:
            result = result + t
        if result.dtype == torch.bool:
            return result
        return torch.clamp(result, 0, 1)

    result = masks[0].copy()
    for m in masks[1:]:
        result = result + m
    return np.clip(result, 0, 1)


def record_tensor_statics(
    tensor: torch.Tensor,
    axis=None,
    keepdim=False,
):
    """Return ``(mean, std, max, min)`` of the detached tensor along ``axis``."""
    detached = tensor.detach()
    mean = detached.mean(axis, keepdim=keepdim)
    std = detached.std(axis, keepdim=keepdim)
    tensor_max = detached.amax(axis, keepdim=keepdim)
    tensor_min = detached.amin(axis, keepdim=keepdim)
    return mean, std, tensor_max, tensor_min


def normalize_tensor(
    tensor,
    dim,
    target_mean,
    target_std,
):
    """
    Re-standardize a tensor along ``dim`` to a target mean and std.

    The tensor is standardized with its own mean/std along ``dim``
    (``keepdim=True``) and then scaled by ``target_std`` and shifted by
    ``target_mean``. No epsilon is added, so a zero std along ``dim``
    yields non-finite values.

    Raises:
        AssertionError: If ``target_mean`` / ``target_std`` do not have the
            same shape as the computed mean / std.
    """
    mean = tensor.mean(dim=dim, keepdim=True)
    std = tensor.std(dim=dim, keepdim=True)
    assert mean.shape == target_mean.shape == std.shape == target_std.shape
    standardized = (tensor - mean) / std
    return standardized * target_std + target_mean


def normalize_tensor_to_match_tensor(
    target_tensor,
    dim,
    reference_tensor,
):
    """Re-standardize ``target_tensor`` along ``dim`` to the mean/std of ``reference_tensor``."""
    reference_mean, reference_std, _, _ = record_tensor_statics(
        reference_tensor,
        axis=dim,
        keepdim=True,
    )
    return normalize_tensor(
        target_tensor,
        dim=dim,
        target_mean=reference_mean,
        target_std=reference_std,
    )


def build_gaussian_focus_map(
    h: int,
    w: int,
    center_y: float,
    center_x: float,
    radius: float,
    sigma: float | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Build a gaussian focus map over an (h, w) grid:
      - Inside circle (dist <= r): weight = 1
      - Outside: weight = exp(- ((dist - r)^2) / (2 * sigma^2))

    sigma defaults to radius / 2 (floored at 1e-6) if not provided.

    Returned shape: [1, 1, 1, h, w] ready for broadcasting over [B, F, C, h, w].
    """
    if sigma is None:
        sigma = max(1e-6, radius / 2.0)

    rows = torch.arange(h, device=device, dtype=dtype).view(h, 1)
    cols = torch.arange(w, device=device, dtype=dtype).view(1, w)
    dist = torch.sqrt((rows - center_y) ** 2 + (cols - center_x) ** 2)

    overshoot = (dist - radius).clamp_min(0.0)
    falloff = torch.exp(-(overshoot**2) / (2.0 * sigma**2))
    weight = torch.where(dist <= radius, torch.ones_like(dist), falloff)
    return weight[None, None, None]  # [1,1,1,h,w]


def build_anisotropic_gaussian(
    H: int,
    W: int,
    center_x: float,
    center_y: float,
    sigma_x: float,
    sigma_y: float,
    # *,
    clamp: bool = True,
    normalize: bool = True,
    min_value: float = 0.0,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Core builder: create anisotropic Gaussian over (H,W).

    G(y,x) = exp( - ( (x-cx)^2 / (2 sigma_x^2) + (y-cy)^2 / (2 sigma_y^2) ) )

    Args:
        H, W: Output grid size.
        center_x, center_y: Gaussian center in pixel coordinates.
        sigma_x, sigma_y: Standard deviations; floored at 1e-6.
        clamp: If True, clamp the result to [min_value, 1].
        normalize: If True, divide by the grid maximum (when it is > 0) so
            the peak on the grid is 1.
        min_value: Lower clamp bound.
        device, dtype: Placement and dtype of the result.

    Returns:
        Tensor of shape [H, W].
    """
    sigma_x = max(1e-6, float(sigma_x))
    sigma_y = max(1e-6, float(sigma_y))

    rows = torch.arange(H, device=device, dtype=dtype).view(H, 1)
    cols = torch.arange(W, device=device, dtype=dtype).view(1, W)
    x_term = (cols - center_x) ** 2 / (2.0 * sigma_x * sigma_x)
    y_term = (rows - center_y) ** 2 / (2.0 * sigma_y * sigma_y)
    gauss = torch.exp(-(x_term + y_term))

    if normalize:
        peak = gauss.max()
        if peak > 0:
            gauss = gauss / peak
    if clamp:
        gauss = gauss.clamp_(min_value, 1.0)
    return gauss


def build_anisotropic_gaussian_from_bbox(
    H: int,
    W: int,
    y_min: int,
    y_max: int,
    x_min: int,
    x_max: int,
    # *,
    padding_scale: float = 0.15,
    sigma_scale: float = 0.5,
    min_sigma: float = 1.0,
    clamp: bool = True,
    normalize: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Compute center & (sigma_x, sigma_y) from a bounding box, then call
    build_anisotropic_gaussian.

    The box is inclusive, so its width is ``x_max - x_min + 1`` (likewise
    for height), and the center is the midpoint of the min/max coordinates.

    sigma_x = ( (bbox_width  * (1+padding_scale))/2 ) * sigma_scale
    sigma_y = ( (bbox_height * (1+padding_scale))/2 ) * sigma_scale
    Both are floored at min_sigma.

    Returns:
        Tensor of shape [H, W].
    """
    center_y = 0.5 * (y_min + y_max)
    center_x = 0.5 * (x_min + x_max)

    padded_h = (y_max - y_min + 1) * (1.0 + padding_scale)
    padded_w = (x_max - x_min + 1) * (1.0 + padding_scale)

    sigma_y = max(min_sigma, 0.5 * padded_h * sigma_scale)
    sigma_x = max(min_sigma, 0.5 * padded_w * sigma_scale)

    return build_anisotropic_gaussian(
        H=H,
        W=W,
        center_x=center_x,
        center_y=center_y,
        sigma_x=sigma_x,
        sigma_y=sigma_y,
        clamp=clamp,
        normalize=normalize,
        device=device,
        dtype=dtype,
    )


def build_anisotropic_gaussian_from_mask(
    mask: np.ndarray | Image.Image | torch.Tensor,
    # *,
    padding_scale: float = 0.15,
    sigma_scale: float = 0.5,
    min_sigma: float = 1.0,
    clamp: bool = True,
    normalize: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor | None:
    """
    Compute bounding box from mask, then call build_anisotropic_gaussian_from_bbox.

    The output grid size follows the same layout convention as
    ``bbox_from_mask``: the last two dims for torch tensors, and the first
    two dims for channels-last (H, W, C) numpy / PIL masks. For torch masks
    the result is created on the mask's device; otherwise on ``device``.

    Returns None if mask has no positive pixels.
    """
    bbox = bbox_from_mask(mask)
    if bbox is None:
        return None
    x_min, y_min, x_max, y_max = bbox

    if isinstance(mask, torch.Tensor):
        H, W = mask.shape[-2], mask.shape[-1]
        device = mask.device
    else:
        mask_np = np.asarray(mask)
        if mask_np.ndim == 3:
            # channels-last: (H, W, C)
            H, W = mask_np.shape[0], mask_np.shape[1]
        else:
            H, W = mask_np.shape[-2], mask_np.shape[-1]

    return build_anisotropic_gaussian_from_bbox(
        H=H,
        W=W,
        y_min=y_min,
        y_max=y_max,
        x_min=x_min,
        x_max=x_max,
        padding_scale=padding_scale,
        sigma_scale=sigma_scale,
        min_sigma=min_sigma,
        clamp=clamp,
        normalize=normalize,
        device=device,
        dtype=dtype,
    )


def combine_gaussian_maps(
    maps: list[torch.Tensor],
    mode: str = "prob_or",
    clamp: bool = True,
) -> torch.Tensor:
    """
    Combine multiple Gaussian (or weight) maps into one.

    Args:
        maps: non-empty list of tensors with identical shape
            (e.g. [1,1,1,H,W] or [H,W]).
        mode:
            - "prob_or":  1 - prod(1 - g)   (smooth union, fast saturation)
            - "sum_clamp": sum(g), clamped when ``clamp`` is True
            - "sum_norm": sum(g) / max(sum(g)) (skipped when the max is <= 0)
            - "max": elementwise max
            A single-element list is returned as-is (clamped if requested)
            without consulting ``mode``.
        clamp: If True, clamp the combined map to [0, 1].

    Returns:
        Combined tensor. The input tensors are never modified in place.

    Raises:
        AssertionError: If ``maps`` is empty.
        ValueError: If ``mode`` is unknown (and more than one map is given).
    """
    assert len(maps) > 0
    if len(maps) == 1:
        out = maps[0]
        # Out-of-place clamp: an in-place clamp_ here would silently modify
        # the caller's tensor.
        return out.clamp(0, 1) if clamp else out

    stacked = torch.stack(maps, dim=0)
    if mode == "prob_or":
        out = 1.0 - torch.prod(1.0 - stacked, dim=0)
    elif mode == "sum_clamp":
        out = stacked.sum(dim=0)
    elif mode == "sum_norm":
        out = stacked.sum(dim=0)
        maxv = out.max()
        if maxv > 0:
            out = out / maxv
    elif mode == "max":
        out, _ = stacked.max(dim=0)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    if clamp:
        out = out.clamp(0.0, 1.0)
    return out