# DragStream / tensor_utils.py
# bowmanchow's picture
# add code
# 0328207
import numpy as np
from PIL import Image, ImageDraw, ImageColor
import scipy
import cv2
import torch
import kornia
import torch.nn.functional as F
def image_to_pil(image):
    """Return ``image`` as a PIL image.

    PIL images are handed back untouched; numpy arrays are wrapped with
    ``Image.fromarray``.

    Raises:
        ValueError: if ``image`` is neither a numpy array nor a PIL image.
    """
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    raise ValueError("Unsupported image type")
def image_to_np(image):
    """Return ``image`` as a numpy array.

    Numpy arrays are handed back untouched (same object, no copy); PIL images
    are converted with ``np.array``.

    Raises:
        ValueError: if ``image`` is neither a numpy array nor a PIL image.
    """
    if isinstance(image, np.ndarray):
        return image
    if not isinstance(image, Image.Image):
        raise ValueError("Unsupported image type")
    return np.array(image)
def get_bbox_center(bbox):
    """Return the integer centre ``(x, y)`` of a box given as ``(x1, y1, x2, y2)``.

    Each coordinate is the floor of the midpoint of the two opposite edges.
    """
    left, top, right, bottom = bbox
    return (int((left + right) // 2), int((top + bottom) // 2))
def save_mask_to_file(
    mask,
    file_path,
):
    """Write ``mask`` to ``file_path`` as an 8-bit image.

    The mask is cast to ``uint8`` first. If the largest value after the cast is
    0 or 1 the mask is scaled by 255 so that foreground is saved as white;
    masks whose maximum is already above 1 are written unchanged.
    """
    pixels = mask.astype(np.uint8)  # astype copies, so the caller's array is not modified
    if pixels.max() <= 1:
        pixels = pixels * 255
    Image.fromarray(pixels).save(file_path)
def read_mask_from_file(
    file_path,
):
    """Load an image file as a boolean mask.

    The image is converted to 8-bit grayscale (mode ``"L"``); every pixel with
    a value greater than zero is foreground.

    Returns:
        Boolean numpy array, ``True`` where the grayscale value is > 0.
    """
    # The context manager closes the underlying file deterministically instead
    # of relying on PIL / the garbage collector to release it later.
    with Image.open(file_path) as source:
        gray = np.array(source.convert("L"))
    return gray > 0
def bbox_from_mask(
    mask: np.ndarray | Image.Image | torch.Tensor,
):
    """Compute the axis-aligned bounding box of the non-zero pixels of a mask.

    Layout rules (they differ between the two code paths):
      - torch tensors are read as ``(..., H, W)``: the last two dims are
        spatial and every leading dim (channels, batch, ...) is collapsed with
        a logical "any".
      - everything else (numpy arrays, PIL images, array-likes) goes through
        ``np.asarray`` and is read as ``(H, W)`` or channel-last ``(H, W, C)``,
        where a pixel is foreground if any channel is non-zero.

    Returns:
        ``(min_x, min_y, max_x, max_y)`` with inclusive coordinates, or
        ``None`` when the mask has fewer than 2 dims or no non-zero entry.

    Raises:
        ValueError: for a non-empty numpy-path mask with more than 3 dims,
            whose layout is ambiguous.
    """
    if isinstance(mask, torch.Tensor):
        if mask.ndim < 2:
            return None
        fg = mask != 0
        if fg.ndim > 2:
            # Merge all leading dims into one, then reduce it away.
            fg = fg.flatten(0, -3).any(dim=0)
        if not fg.any():
            return None
        # Only the small 1-D index vectors are read back as Python ints.
        y_idx = torch.nonzero(fg.any(dim=1), as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(fg.any(dim=0), as_tuple=False).squeeze(1)
        return (int(x_idx[0]), int(y_idx[0]), int(x_idx[-1]), int(y_idx[-1]))

    # PIL images expose the numpy array interface, so np.asarray converts them
    # just like the explicit np.array(...) call would.
    mask_np = np.asarray(mask)
    if mask_np.ndim < 2:
        return None
    fg = mask_np != 0
    if fg.ndim == 3:
        fg = fg.any(axis=2)
    if not fg.any():
        return None
    if fg.ndim != 2:
        raise ValueError(
            f"bbox_from_mask expects a 2D or (H, W, C) array, got shape {mask_np.shape}"
        )
    ys = np.flatnonzero(fg.any(axis=1))
    xs = np.flatnonzero(fg.any(axis=0))
    return (int(xs[0]), int(ys[0]), int(xs[-1]), int(ys[-1]))
def remove_small_components(mask, min_size=10):
    """Drop connected components with fewer than ``min_size`` pixels.

    Components are found with ``scipy.ndimage.label`` using its default
    connectivity.

    Args:
        mask: array whose non-zero entries are foreground.
        min_size: smallest component size (in pixels) that is kept.

    Returns:
        ``uint8`` array of the same shape as ``mask``: 255 on the surviving
        components, 0 elsewhere.
    """
    # Import the submodule explicitly: a bare ``import scipy`` is not
    # guaranteed to expose ``scipy.ndimage`` on every SciPy version.
    from scipy import ndimage

    labeled, _ = ndimage.label(mask)
    # One pass to size every component instead of one full scan per label.
    sizes = np.bincount(labeled.ravel(), minlength=1)
    keep = sizes >= min_size
    keep[0] = False  # label 0 is the background
    return keep[labeled].astype(np.uint8) * 255
def draw_bbox_on_image(
    image: np.ndarray | Image.Image,
    bbox,
    color="yellow",
    width=3,
):
    """Return a copy of ``image`` (as a PIL image) with ``bbox`` outlined.

    Args:
        image: numpy array or PIL image; the input itself is not drawn on.
        bbox: ``(x1, y1, x2, y2)`` rectangle corners.
        color: outline colour passed to ``ImageDraw.rectangle``.
        width: outline width passed to ``ImageDraw.rectangle``.

    Returns:
        The annotated PIL image, or ``image`` unchanged when either ``image``
        or ``bbox`` is ``None``.
    """
    if bbox is None or image is None:
        return image
    canvas = image_to_pil(image.copy())
    pen = ImageDraw.Draw(canvas)
    left, top, right, bottom = bbox
    pen.rectangle([left, top, right, bottom], outline=color, width=width)
    return canvas
def draw_mask_on_image(
    image: np.ndarray | Image.Image | None,
    mask: np.ndarray | Image.Image | None,
    mask_color: str | list[int] | tuple[int, int, int] = (30, 255, 144),
    alpha: float = 0.3,
):
    """Blend a translucent colour over the foreground of ``mask`` on ``image``.

    Args:
        image: numpy array or PIL image; the input itself is not modified.
        mask: 2D numpy array or PIL image; values > 0 are foreground.
        mask_color: a colour string understood by ``ImageColor.getrgb``
            (e.g. ``"red"``, ``"#ff0000"``) or a list/tuple/ndarray of three
            numbers, which are rounded and clipped to 0..255.
        alpha: overlay opacity in 0..1.

    Returns:
        An RGBA PIL image, or ``image`` unchanged when ``image`` or ``mask``
        is ``None``.

    Raises:
        ValueError: for ``alpha`` outside 0..1, an unsupported or wrongly
            sized ``mask_color``, a mask that is not 2D, or a mask whose size
            differs from the image.
    """
    if image is None or mask is None:
        return image
    if not (0.0 <= alpha <= 1.0):
        raise ValueError("alpha must be between 0 and 1")

    # Normalise mask_color to integer channel values.
    if isinstance(mask_color, str):
        rgb = ImageColor.getrgb(mask_color)
    elif isinstance(mask_color, (list, tuple, np.ndarray)):
        if len(mask_color) != 3:
            raise ValueError("mask_color list/tuple must have length 3")
        rgb = tuple(int(round(float(c))) for c in mask_color)
    else:
        raise ValueError("Unsupported mask_color type")
    # getrgb may return a 4-tuple for colours carrying alpha; only R, G, B are used.
    red, green, blue = (int(c) for c in np.clip(rgb[:3], 0, 255))

    base = image_to_pil(image.copy()).convert("RGBA")
    foreground = image_to_np(mask) > 0
    if foreground.ndim != 2:
        raise ValueError("mask must be 2D after binarization")
    h, w = foreground.shape
    if (w, h) != base.size:
        # alpha_composite would reject the mismatch anyway; fail with a clearer message.
        raise ValueError(f"mask size {(w, h)} does not match image size {base.size}")

    opacity = (alpha * 255).astype(np.uint8) if isinstance(alpha, np.ndarray) else int(alpha * 255)

    # RGBA overlay: solid colour everywhere, but opaque only on the foreground.
    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    overlay[..., 0] = red
    overlay[..., 1] = green
    overlay[..., 2] = blue
    overlay[foreground, 3] = opacity

    return Image.alpha_composite(base, Image.fromarray(overlay))
def draw_mask_bbox_on_image(
image,
mask,
mask_color: list[int] = [30, 255, 144],
mask_alpha: float = 0.3,
bbox_color="yellow",
bbox_width=3,
):
"""Draw a mask and its bounding box on an image."""
image = draw_mask_on_image(
image,
mask,
mask_color=mask_color,
alpha=mask_alpha,
)
bbox = bbox_from_mask(mask)
if bbox is None:
return image, None
image = draw_bbox_on_image(
image,
bbox,
color=bbox_color,
width=bbox_width,
)
return image, bbox
def draw_points_on_image(
    image,
    points: list[tuple],
    color="red",
    radius=5,
):
    """Draw one filled circle per point on a copy of ``image``.

    Args:
        image: numpy array or PIL image; the input itself is not drawn on.
        points: list of ``(x, y)`` positions.
        color: a single colour used for every point, or a list with one
            colour per point.
        radius: a single radius used for every point, or a list with one
            radius per point.

    Returns:
        The annotated PIL image, or ``None`` when ``image`` is ``None``
        (same convention as the other draw helpers in this file).

    Raises:
        AssertionError: if ``points`` is not a list, or a ``color``/``radius``
            list does not have one entry per point.
    """
    if image is None:
        return image
    assert isinstance(points, list), "points must be a list of tuples"
    # Broadcast scalar colour / radius to one entry per point.
    colors = color if isinstance(color, list) else [color] * len(points)
    assert len(colors) == len(points), "color must be a list of the same length as points"
    radii = radius if isinstance(radius, list) else [radius] * len(points)
    assert len(radii) == len(points), "radius must be a list of the same length as points"

    canvas = image_to_pil(image.copy())
    draw = ImageDraw.Draw(canvas)
    # Distinct loop names so the ``color`` / ``radius`` parameters are not shadowed.
    for (x, y), point_color, point_radius in zip(points, colors, radii):
        draw.circle(
            (x, y),
            radius=point_radius,
            fill=point_color,
            outline=point_color,
        )
    return canvas
def draw_lines_on_image(
    image,
    points: list[tuple],
    color="red",
    width=3,
):
    """Draw a polyline through ``points`` on a copy of ``image``.

    ``color`` is either one colour spec for every segment or a list of specs.
    A list may hold one entry per segment (``len(points) - 1``) or one per
    point, in which case the last entry is ignored. String specs are resolved
    with ``ImageColor.getrgb``; a bare six-digit hex string ("rrggbb") is
    accepted and treated as "#rrggbb". Non-string specs are used as given.

    Returns:
        The annotated PIL image, or ``image`` unchanged when it is ``None`` or
        ``points`` is not a list of at least two points.

    Raises:
        ValueError: if a colour list has neither of the accepted lengths.
    """
    if image is None or not isinstance(points, list) or len(points) < 2:
        return image

    canvas = image_to_pil(image.copy())
    segment_count = len(points) - 1

    if isinstance(color, list):
        specs = color[:-1] if len(color) == len(points) else color
        if len(specs) != segment_count:
            raise ValueError("color list length must be len(points)-1 or len(points)")
    else:
        specs = [color] * segment_count

    def resolve(spec):
        # Non-string specs are assumed to already be colour tuples.
        if not isinstance(spec, str):
            return spec
        text = spec.strip()
        bare_hex = len(text) == 6 and all(ch in "0123456789abcdefABCDEF" for ch in text)
        return ImageColor.getrgb("#" + text if bare_hex else text)

    fills = [resolve(spec) for spec in specs]
    pen = ImageDraw.Draw(canvas)
    for start, end, fill in zip(points, points[1:], fills):
        pen.line([start, end], fill=fill, width=width)
    return canvas
def draw_arrow_on_image(
    image,
    start_point: tuple,
    end_point: tuple,
    color: str = "white",
    thickness: int = 5,
):
    """Draw an arrow from ``start_point`` to ``end_point`` on a copy of ``image``.

    Args:
        image: numpy array or PIL image; the input itself is not drawn on.
        start_point: ``(x, y)`` of the arrow tail.
        end_point: ``(x, y)`` of the arrow head.
        color: a colour string understood by ``ImageColor.getrgb`` or a
            numeric colour that is handed to OpenCV as given.
        thickness: line thickness passed to ``cv2.arrowedLine``.

    Returns:
        A PIL image built from the annotated array.
    """
    canvas = np.array(image.copy())

    # cv2 drawing functions take numeric colours, not colour names, so string
    # specs (including the default "white") are resolved to a tuple first.
    if isinstance(color, str):
        color = ImageColor.getrgb(color)
    # On a 4-channel canvas a 3-component colour would leave the 4th channel
    # at 0; assuming that channel is alpha (RGBA), make the arrow opaque.
    if (
        isinstance(color, (tuple, list))
        and len(color) == 3
        and canvas.ndim == 3
        and canvas.shape[-1] == 4
    ):
        color = (*color, 255)

    # OpenCV requires integer pixel coordinates.
    tail = tuple(int(round(v)) for v in start_point)
    head = tuple(int(round(v)) for v in end_point)

    canvas = cv2.arrowedLine(canvas, tail, head, color, thickness)
    return Image.fromarray(canvas)
def trajectory_interpolate_1d(
    trajectory: list[float],
    scale: int,
) -> list[float]:
    """Linearly upsample a 1D trajectory.

    Args:
        trajectory: list of scalar samples (at least two).
        scale: number of output intervals per input interval (> 0).

    Returns:
        List of ``(L - 1) * scale + 1`` floats, where ``L`` is the number of
        input samples; the original samples are kept at every ``scale``-th
        position.
    """
    assert isinstance(trajectory, list), "trajectory must be a list"
    assert len(trajectory) > 1, "trajectory must have at least 2 points"
    assert isinstance(scale, int), "scale must be an integer"
    assert scale > 0, "scale must be greater than 0"

    samples = np.asarray(trajectory, dtype=np.float32).reshape(-1)
    count = samples.shape[0]
    knots = np.arange(count, dtype=np.float32)
    queries = np.linspace(0, count - 1, (count - 1) * scale + 1, dtype=np.float32)
    return np.interp(queries, knots, samples).tolist()
def trajectory_interpolate(
    trajectory: list[tuple],
    scale: int,
):
    """Linearly upsample a list of 2D points and truncate the result to ints.

    Args:
        trajectory: list of ``(x, y)`` points (at least two).
        scale: number of output intervals per input interval (> 0).

    Returns:
        List of ``(len(trajectory) - 1) * scale + 1`` integer ``(x, y)``
        tuples. Coordinates are converted with ``int()``, i.e. truncated
        toward zero rather than rounded.
    """
    assert isinstance(trajectory, list), "trajectory must be a list of tuples"
    assert len(trajectory) > 1, "trajectory must have at least 2 points"
    assert isinstance(scale, int), "scale must be an integer"
    assert scale > 0, "scale must be greater than 0"

    out_len = (len(trajectory) - 1) * scale + 1
    # Treat the (L, 2) point table as a 1x1xLx2 "image" and stretch it only
    # along the L axis; the width stays 2 so x and y are not mixed.
    table = torch.tensor(np.array(trajectory), dtype=torch.float32)[None, None]
    dense = torch.nn.functional.interpolate(
        table,
        size=(out_len, 2),
        mode="bilinear",
        align_corners=True,
    ).squeeze()
    return [(int(row[0]), int(row[1])) for row in dense.tolist()]
def dilate_mask(
mask: np.ndarray | None,
dilate_factor: int = 15,
):
if mask is None:
return None
mask = mask.astype(np.uint8)
mask = cv2.dilate(mask, np.ones((dilate_factor, dilate_factor), np.uint8), iterations=1)
return mask
def dilate_masks(
    masks: list[np.ndarray],
    dilate_factor: int = 15,
):
    """Apply ``dilate_mask`` to every mask and return the results as a list."""
    dilated = []
    for mask in masks:
        dilated.append(dilate_mask(mask, dilate_factor))
    return dilated
def shift_masks(
    ref_mask,
    deltas: list[tuple[float, float]],
):
    """Build one shifted copy of ``ref_mask`` per delta.

    For each delta, the indices of the positive entries of ``ref_mask`` are
    offset by ``int(delta[0])`` along axis 0 and ``int(delta[1])`` along
    axis 1, clipped to the array bounds, and set to 1 in a fresh ``uint8``
    array of the same shape.

    NOTE(review): indices that leave the array are clipped onto the border,
    not discarded, so a mask pushed past an edge piles up on that edge. An
    earlier comment described this step as "filter out-of-bounds indices";
    confirm which behaviour callers expect.

    Returns:
        ``(shifted_masks, shifted_masks_indices)``: the list of shifted masks
        and, for each, the ``(axis0_indices, axis1_indices)`` tuple used to
        fill it.
    """
    positive = np.where(ref_mask > 0)
    rows, cols = positive[0], positive[1]
    last_row = ref_mask.shape[0] - 1
    last_col = ref_mask.shape[1] - 1

    shifted_masks = []
    shifted_masks_indices = []
    for delta in deltas:
        target = (
            np.clip(rows + int(delta[0]), 0, last_row),
            np.clip(cols + int(delta[1]), 0, last_col),
        )
        canvas = np.zeros_like(ref_mask, dtype=np.uint8)
        canvas[target] = 1
        shifted_masks.append(canvas)
        shifted_masks_indices.append(target)
    return shifted_masks, shifted_masks_indices
def rotate_points(points, angle, center=(0.0, 0.0), degrees=True):
    """Rotate 2D point(s) about ``center``.

    Args:
        points: array-like of shape ``(2,)`` or ``(N, 2)`` holding ``[x, y]``.
        angle: rotation angle; positive angles apply the standard rotation
            matrix ``[[cos, -sin], [sin, cos]]``.
        center: pivot ``[cx, cy]``.
        degrees: interpret ``angle`` in degrees when True, radians otherwise.

    Returns:
        Float numpy array of the rotated point(s), same shape as ``points``.
    """
    theta = np.deg2rad(angle) if degrees else angle
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

    pivot = np.asarray(center, dtype=float)
    offsets = np.asarray(points, dtype=float) - pivot
    # Row vectors are rotated by multiplying with the transposed matrix.
    return offsets @ rotation.T + pivot
def calculate_angle(vector_1: torch.Tensor, vector_2: torch.Tensor):
    """Signed angle in degrees from ``vector_1`` to ``vector_2``.

    The magnitude comes from the arccosine of the normalised dot product; the
    sign is negative when the 2D cross product
    ``v1.x * v2.y - v1.y * v2.x`` is negative.

    Raises:
        ValueError: if either vector has zero length.
    """
    dot = torch.dot(vector_1, vector_2)
    length_1 = torch.norm(vector_1)
    length_2 = torch.norm(vector_2)
    if length_1 == 0 or length_2 == 0:
        raise ValueError("One of the vectors has zero magnitude, cannot calculate angle.")

    # Clamp guards acos against values slightly outside [-1, 1] from round-off.
    cosine = torch.clamp(dot / (length_1 * length_2), -1.0, 1.0)
    unsigned_deg = torch.rad2deg(torch.acos(cosine))

    cross_z = vector_1[0] * vector_2[1] - vector_1[1] * vector_2[0]
    return -unsigned_deg if cross_z < 0 else unsigned_deg
def calculate_angle_from_points(
    center_points: torch.Tensor,
    handle_points: torch.Tensor,
    target_points: torch.Tensor,
):
    """Signed angle (degrees) between centre->handle and centre->target.

    Args:
        center_points: ``(x, y)`` of the pivot.
        handle_points: ``(x, y)`` of the starting point.
        target_points: ``(x, y)`` of the end point.

    Each argument may be a tensor or any array-like accepted by
    ``torch.as_tensor``; all are converted to float32.

    Returns:
        The result of ``calculate_angle`` on the two centre-relative vectors.
    """
    # torch.as_tensor with an explicit dtype replaces the legacy
    # ``torch.Tensor(...)`` constructor, which treats a bare int as a size
    # rather than as data.
    center, handle, target = (
        torch.as_tensor(p, dtype=torch.float32)
        for p in (center_points, handle_points, target_points)
    )
    return calculate_angle(handle - center, target - center)
def tensor_2d_translation(
    tensor: torch.Tensor,
    translation: tuple[float, float] | torch.Tensor,
    mode: str = "bilinear",
):
    """Translate a 2D map with ``kornia.geometry.transform.translate``.

    The work is done in float32 and the result is cast back to the input
    dtype (float32 when the input was not a tensor). Inputs with 2 or 3 dims
    get temporary leading dims so kornia receives a 4D tensor, and those dims
    are removed again from the result.

    Args:
        tensor: tensor (or tensor-like) to shift.
        translation: shift forwarded to kornia; a 1-D value gets a leading
            batch dim.
        mode: interpolation mode forwarded to kornia.
    """
    if isinstance(tensor, torch.Tensor):
        result_dtype = tensor.dtype
    else:
        result_dtype = torch.float32
        tensor = torch.tensor(tensor)

    work = tensor.to(torch.float32)
    rank = work.ndim
    if rank == 2:
        work = work[None, None, ...]
    elif rank == 3:
        work = work[None, ...]

    shift = translation
    if not isinstance(shift, torch.Tensor):
        shift = torch.tensor(shift, device=work.device)
    shift = shift.to(dtype=torch.float32, device=work.device)
    if shift.ndim == 1:
        shift = shift.unsqueeze(0)

    moved = kornia.geometry.transform.translate(
        work,
        translation=shift,
        mode=mode,
    )

    if rank == 2:
        moved = moved[0, 0, ...]
    elif rank == 3:
        moved = moved[0, ...]
    return moved.to(result_dtype)
def tensor_2d_rotation(
    tensor: torch.Tensor,
    angle: float,
    center=None,
    mode: str = "bilinear",
):
    """Rotate a 2D map with ``kornia.geometry.transform.rotate``.

    ``angle`` is negated before it is handed to kornia; per the original
    author's note this makes a positive ``angle`` rotate clockwise. The work
    is done in float32 (tensor, angle and centre) and the result is cast back
    to the input dtype (float32 when the input was not a tensor). Inputs with
    2 or 3 dims get temporary leading dims so kornia receives a 4D tensor.

    Args:
        tensor: tensor (or tensor-like) to rotate.
        angle: rotation angle, number or tensor.
        center: optional rotation centre forwarded to kornia; ``None`` lets
            kornia choose its default.
        mode: interpolation mode forwarded to kornia.
    """
    if isinstance(tensor, torch.Tensor):
        result_dtype = tensor.dtype
    else:
        result_dtype = torch.float32
        tensor = torch.tensor(tensor)

    work = tensor.to(torch.float32)
    rank = work.ndim
    if rank == 2:
        work = work[None, None, ...]
    elif rank == 3:
        work = work[None, ...]

    # Clockwise -> negate
    theta = -angle
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor(theta, device=work.device)
    theta = theta.to(dtype=torch.float32, device=work.device)
    if theta.ndim == 0:
        theta = theta.unsqueeze(0)

    pivot = center
    if pivot is not None:
        if not isinstance(pivot, torch.Tensor):
            pivot = torch.tensor(pivot, device=work.device)
        pivot = pivot.to(dtype=torch.float32, device=work.device)

    turned = kornia.geometry.transform.rotate(
        work,
        theta,
        center=pivot,
        mode=mode,
    )

    if rank == 2:
        turned = turned[0, 0, ...]
    elif rank == 3:
        turned = turned[0, ...]
    return turned.to(result_dtype)
def resize_tensor(
tensor: torch.Tensor,
size: int | tuple[int, int] = None,
scale_factor: float | tuple[float, float] = None,
mode: str = "bilinear",
) -> torch.Tensor:
"""
Resize a 2D tensor to a given size.
Args:
tensor (torch.Tensor): The input tensor to be resized.
size (Union[int, Tuple[int, int]]): The target size. If an int is provided, it will be used for both dimensions.
scale_factor (Union[float, Tuple[float, float]]): The scale factor for resizing. If provided, it will override the size argument.
Returns:
torch.Tensor: The resized tensor.
"""
# if not isinstance(tensor, torch.Tensor):
# tensor = torch.tensor(tensor, dtype=torch.float32)
origin_shape = tensor.shape
if len(origin_shape) == 2:
tensor = tensor[None, None, ...]
elif len(origin_shape) == 3:
tensor = tensor[None, ...]
resized_tensor = F.interpolate(
tensor,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=(True if mode in ["linear", "bilinear", "bicubic", "trilinear"] else None),
)
if len(origin_shape) == 2:
resized_tensor = resized_tensor[0, 0, ...]
elif len(origin_shape) == 3:
resized_tensor = resized_tensor[0, ...]
return resized_tensor
def warp_tensor(
tensor: torch.Tensor,
is_rotation: bool,
delta,
rotation_center: tuple[float, float] | torch.Tensor | None = None,
original_height: int | None = None,
mode: str = "nearest",
) -> torch.Tensor:
"""
Warp a tensor by translation or rotation based on a trajectory step.
Args:
tensor: Tensor to warp. Can be (H, W), (C, H, W), or (B, C, H, W).
is_rotation: If True, warp by rotation; otherwise by translation.
delta: The delta for this step. For rotation: scalar angle (degrees).
For translation: (dx, dy) in original image pixel coordinates.
Can be a torch.Tensor, tuple, list, or scalar.
rotation_center: (x, y) center of rotation in original image pixel coordinates.
Required when is_rotation is True.
original_height: The height of the original image at which delta was computed.
If provided and differs from tensor's spatial height, delta and
rotation_center are rescaled accordingly.
If None, no rescaling is applied.
mode: Interpolation mode for warping.
Returns:
Warped tensor with the same shape as input.
"""
tensor_height = tensor.shape[-2]
if original_height is not None and original_height != tensor_height:
scale = original_height / tensor_height
else:
scale = 1.0
if is_rotation:
if rotation_center is None:
raise ValueError("rotation_center is required when is_rotation is True")
if not isinstance(rotation_center, torch.Tensor):
rotation_center = torch.tensor(
rotation_center, dtype=tensor.dtype, device=tensor.device
)
center = rotation_center.to(dtype=tensor.dtype, device=tensor.device) / scale
return tensor_2d_rotation(tensor, angle=delta, center=center, mode=mode)
else:
# delta can be a tuple/list/tensor; tensor_2d_translation handles conversion
if isinstance(delta, torch.Tensor):
return tensor_2d_translation(tensor, translation=delta / scale, mode=mode)
else:
# For tuple/list/scalar, scale manually before passing
delta_scaled = tuple(d / scale for d in delta)
return tensor_2d_translation(tensor, translation=delta_scaled, mode=mode)
def warp_tensor_sequence(
tensor: torch.Tensor,
is_rotation: bool,
deltas: list,
rotation_center: tuple[float, float] | torch.Tensor | None = None,
original_height: int | None = None,
mode: str = "nearest",
cumulative: bool = False,
) -> list[torch.Tensor]:
"""
Warp a tensor by a sequence of deltas, returning a list of warped tensors.
Args:
tensor: Tensor to warp. Can be (H, W), (C, H, W), or (B, C, H, W).
is_rotation: If True, warp by rotation; otherwise by translation.
deltas: List of deltas for each step. For rotation: each is a scalar angle (degrees).
For translation: each is (dx, dy) in original image pixel coordinates.
Each delta can be a torch.Tensor, tuple, list, or scalar.
rotation_center: (x, y) center of rotation in original image pixel coordinates.
Required when is_rotation is True.
original_height: The height of the original image at which deltas were computed.
If provided and differs from tensor's spatial height, deltas and
rotation_center are rescaled accordingly.
If None, no rescaling is applied.
mode: Interpolation mode for warping.
cumulative: If True, each warp is applied on top of the previous result
(i.e. sequential composition). If False, each delta is applied
independently to the original tensor.
Returns:
List of warped tensors, one per delta, each with the same shape as input.
"""
warped_tensors = []
current = tensor
for delta in deltas:
source = current if cumulative else tensor
warped = warp_tensor(
source,
is_rotation=is_rotation,
delta=delta,
rotation_center=rotation_center,
original_height=original_height,
mode=mode,
)
warped_tensors.append(warped)
if cumulative:
current = warped
return warped_tensors
def combine_masks_or(
masks: list[torch.Tensor | np.ndarray],
) -> torch.Tensor | np.ndarray:
"""
Combine a list of binary masks using logical OR (union).
Each mask is assumed to be a 2D tensor/array with values in [0, 1].
The result is clamped to [0, 1].
Returns a tensor if any input is a tensor, otherwise a numpy array.
"""
if len(masks) == 0:
raise ValueError("masks list is empty")
result = masks[0].clone() if isinstance(masks[0], torch.Tensor) else masks[0].copy()
for m in masks[1:]:
result = result + m
if isinstance(result, torch.Tensor):
result = torch.clamp(result, 0, 1)
else:
result = np.clip(result, 0, 1)
return result
def record_tensor_statics(
    tensor: torch.Tensor,
    axis=None,
    keepdim=False,
):
    """Return ``(mean, std, max, min)`` of ``tensor`` along ``axis``.

    The statistics are computed on a detached view, so they carry no
    gradient. ``axis`` and ``keepdim`` are forwarded to each reduction.
    """
    values = tensor.detach()
    return (
        values.mean(axis, keepdim=keepdim),
        values.std(axis, keepdim=keepdim),
        values.amax(axis, keepdim=keepdim),
        values.amin(axis, keepdim=keepdim),
    )
def normalize_tensor(
    tensor,
    dim,
    target_mean,
    target_std,
):
    """Re-standardise ``tensor`` along ``dim`` to a target mean and std.

    The tensor is shifted and scaled to zero mean / unit std along ``dim``
    (using ``tensor.mean`` and ``tensor.std`` with ``keepdim=True``), then
    scaled by ``target_std`` and shifted by ``target_mean``.

    Raises:
        AssertionError: if ``target_mean`` / ``target_std`` do not have the
            same shape as the keepdim statistics of ``tensor``.
    """
    current_mean = tensor.mean(dim=dim, keepdim=True)
    current_std = tensor.std(dim=dim, keepdim=True)
    assert current_mean.shape == target_mean.shape == current_std.shape == target_std.shape

    standardized = (tensor - current_mean) / current_std
    return standardized * target_std + target_mean
def normalize_tensor_to_match_tensor(
    target_tensor,
    dim,
    reference_tensor,
):
    """Give ``target_tensor`` the mean and std of ``reference_tensor`` along ``dim``.

    The reference statistics come from ``record_tensor_statics`` (with
    ``keepdim=True``); only its mean and std are used. The rescaling itself
    is done by ``normalize_tensor``.
    """
    reference_mean, reference_std, _, _ = record_tensor_statics(
        reference_tensor,
        axis=dim,
        keepdim=True,
    )
    return normalize_tensor(
        target_tensor,
        dim=dim,
        target_mean=reference_mean,
        target_std=reference_std,
    )
def build_gaussian_focus_map(
h: int,
w: int,
center_y: float,
center_x: float,
radius: float,
sigma: float | None = None,
device: torch.device | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Build a (h, w) gaussian focus map:
- Inside circle (dist <= r): weight = 1
- Outside: weight = exp(- ((dist - r)^2) / (2 * sigma^2))
sigma defaults to radius / 2 if not provided.
Returned shape: [1, 1, 1, h, w] ready for broadcasting over [B, F, C, h, w].
"""
if sigma is None:
sigma = max(1e-6, radius / 2.0)
yy = torch.arange(h, device=device, dtype=dtype).view(h, 1)
xx = torch.arange(w, device=device, dtype=dtype).view(1, w)
dist = torch.sqrt((yy - center_y) ** 2 + (xx - center_x) ** 2)
outside = (dist - radius).clamp_min(0.0)
outside_weight = torch.exp(-(outside**2) / (2.0 * sigma**2))
weight = torch.where(dist <= radius, torch.ones_like(dist), outside_weight)
return weight.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1,1,1,h,w]
def build_anisotropic_gaussian(
H: int,
W: int,
center_x: float,
center_y: float,
sigma_x: float,
sigma_y: float,
# *,
clamp: bool = True,
normalize: bool = True,
min_value: float = 0.0,
device: torch.device | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Core builder: create anisotropic Gaussian over (H,W).
G(y,x) = exp( - ( (x-cx)^2 / (2 sigma_x^2) + (y-cy)^2 / (2 sigma_y^2) ) )
Returns shape [H,W].
center_x, center_y: float (pixel coordinates)
sigma_x, sigma_y: positive float
"""
sigma_x = max(1e-6, float(sigma_x))
sigma_y = max(1e-6, float(sigma_y))
yy = torch.arange(H, device=device, dtype=dtype).view(H, 1)
xx = torch.arange(W, device=device, dtype=dtype).view(1, W)
gx = (xx - center_x) ** 2 / (2.0 * sigma_x * sigma_x)
gy = (yy - center_y) ** 2 / (2.0 * sigma_y * sigma_y)
gauss = torch.exp(-(gx + gy))
if normalize:
m = gauss.max()
if m > 0:
gauss = gauss / m
if clamp:
gauss = gauss.clamp_(min_value, 1.0)
return gauss
def build_anisotropic_gaussian_from_bbox(
    H: int,
    W: int,
    y_min: int,
    y_max: int,
    x_min: int,
    x_max: int,
    # *,
    padding_scale: float = 0.15,
    sigma_scale: float = 0.5,
    min_sigma: float = 1.0,
    clamp: bool = True,
    normalize: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build an anisotropic Gaussian sized to an inclusive bounding box.

    The centre is the box midpoint. With ``extent = max - min + 1`` along an
    axis, the standard deviation on that axis is

        ``max(min_sigma, 0.5 * extent * (1 + padding_scale) * sigma_scale)``

    The map itself is produced by ``build_anisotropic_gaussian``.
    """
    padding = 1.0 + padding_scale
    padded_h = (y_max - y_min + 1) * padding
    padded_w = (x_max - x_min + 1) * padding

    return build_anisotropic_gaussian(
        H=H,
        W=W,
        center_x=0.5 * (x_min + x_max),
        center_y=0.5 * (y_min + y_max),
        sigma_x=max(min_sigma, 0.5 * padded_w * sigma_scale),
        sigma_y=max(min_sigma, 0.5 * padded_h * sigma_scale),
        clamp=clamp,
        normalize=normalize,
        device=device,
        dtype=dtype,
    )
def build_anisotropic_gaussian_from_mask(
    mask: np.ndarray | Image.Image | torch.Tensor,
    # *,
    padding_scale: float = 0.15,
    sigma_scale: float = 0.5,
    min_sigma: float = 1.0,
    clamp: bool = True,
    normalize: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor | None:
    """Build an anisotropic Gaussian covering the bounding box of ``mask``.

    The box comes from ``bbox_from_mask`` and the map from
    ``build_anisotropic_gaussian_from_bbox``. The output grid has the mask's
    spatial size: the last two dims for torch tensors, the first two dims for
    numpy arrays / PIL images (which ``bbox_from_mask`` reads as ``(H, W)``
    or channel-last ``(H, W, C)``).

    For tensor masks the map is created on the mask's device and ``device``
    is ignored; otherwise ``device`` is used.

    Returns:
        ``[H, W]`` tensor, or ``None`` when ``bbox_from_mask`` returns ``None``.
    """
    bbox = bbox_from_mask(mask)
    if bbox is None:
        return None
    x_min, y_min, x_max, y_max = bbox

    if isinstance(mask, torch.Tensor):
        H, W = mask.shape[-2], mask.shape[-1]
        target_device = mask.device
    else:
        # Channel-last layout: taking the last two dims would yield (W, C)
        # for an (H, W, C) array.
        H, W = np.asarray(mask).shape[:2]
        target_device = device

    return build_anisotropic_gaussian_from_bbox(
        H=H,
        W=W,
        y_min=y_min,
        y_max=y_max,
        x_min=x_min,
        x_max=x_max,
        padding_scale=padding_scale,
        sigma_scale=sigma_scale,
        min_sigma=min_sigma,
        clamp=clamp,
        normalize=normalize,
        device=target_device,
        dtype=dtype,
    )
def combine_gaussian_maps(
    maps: list[torch.Tensor],
    mode: str = "prob_or",
    clamp: bool = True,
) -> torch.Tensor:
    """Combine several weight maps of identical shape into one.

    Args:
        maps: non-empty list of tensors with identical shape.
        mode: how two or more maps are merged (ignored for a single map):
            - "prob_or":   ``1 - prod(1 - g)``
            - "sum_clamp": ``sum(g)``
            - "sum_norm":  ``sum(g) / max(sum(g))`` (skipped when the max is
              not positive)
            - "max":       element-wise maximum
        clamp: clamp the final result to ``[0, 1]``.

    Returns:
        The combined tensor. The input tensors are never modified; a single
        map is returned as-is when ``clamp`` is False.

    Raises:
        ValueError: for an unknown ``mode`` (with two or more maps).
    """
    assert len(maps) > 0
    if len(maps) == 1:
        only = maps[0]
        # Out-of-place clamp: an in-place clamp_ here would overwrite the
        # caller's tensor.
        return only.clamp(0.0, 1.0) if clamp else only

    stacked = torch.stack(maps, dim=0)
    if mode == "prob_or":
        out = 1.0 - torch.prod(1.0 - stacked, dim=0)
    elif mode == "sum_clamp":
        out = stacked.sum(dim=0)
    elif mode == "sum_norm":
        out = stacked.sum(dim=0)
        maxv = out.max()
        if maxv > 0:
            out = out / maxv
    elif mode == "max":
        out, _ = stacked.max(dim=0)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # One final clamp shared by every mode.
    return out.clamp(0.0, 1.0) if clamp else out