Skip to content

Commit

Permalink
- _optimize_tile_size function added to calculate the best tiles size…
Browse files Browse the repository at this point in the history
… to minimize the partial uncovered area

- aligned_min_x and aligned_min_y added to _calculate_window_corners function to remove partial regions at the beginning of the slides
- max_n_cells, n_splits, and split_line parameters added to the sliding_window function
  • Loading branch information
mossishahi committed Dec 1, 2024
1 parent 8f4b2cd commit a8cc17e
Showing 1 changed file with 204 additions and 62 deletions.
266 changes: 204 additions & 62 deletions src/squidpy/tl/_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,36 @@

from collections import defaultdict
from itertools import product

import time
import numpy as np
import math
import pandas as pd
from anndata import AnnData
from scanpy import logging as logg
from spatialdata import SpatialData
import math
import math

from squidpy._docs import d
from squidpy.gr._utils import _save_data

__all__ = ["sliding_window"]


@d.dedent
def sliding_window(
adata: AnnData | SpatialData,
library_key: str | None = None,
window_size: int | None = None,
overlap: int = 0,
coord_columns: tuple[str, str] = ("globalX", "globalY"),
sliding_window_key: str = "sliding_window_assignment",
window_size: int | tuple[int, int] | None = None,
spatial_key: str = "spatial",
sliding_window_key: str = "sliding_window_assignment",
overlap: int = 0,
max_n_cells = None,
split_line: str = 'h',
n_splits = None,
drop_partial_windows: bool = False,
square: bool = False,
window_size_per_library_key: str = 'equal',
copy: bool = False,
) -> pd.DataFrame | None:
"""
Expand All @@ -33,35 +40,47 @@ def sliding_window(
Parameters
----------
%(adata)s
window_size: int
Size of the sliding window.
%(library_key)s
coord_columns: Tuple[str, str]
Tuple of column names in `adata.obs` that specify the coordinates (x, y), e.i. ('globalX', 'globalY')
window_size: int
Size of the sliding window.
%(spatial_key)s
sliding_window_key: str
Base name for sliding window columns.
overlap: int
Overlap size between consecutive windows. (0 = no overlap)
%(spatial_key)s
drop_partial_windows: bool
If True, drop windows that are smaller than the window size at the borders.
overlap: int
Overlap size between consecutive windows. (0 = no overlap)
max_n_cells: int
If window_size is None, either 'n_split' or 'max_n_cells' can be set.
max_n_cells sets an upper limit for the number of cells within each region.
n_splits: int
This can be used to split the entire region to some splits.
copy: bool
If True, return the result, otherwise save it to the adata object.
split_line: str
If 'square' is False, this set's the orientation for rectanglular regions. `h` : Horizontal, `v`: Vertical
Returns
-------
If ``copy = True``, returns the sliding window annotation(s) as pandas dataframe
Otherwise, stores the sliding window annotation(s) in .obs.
"""

if overlap < 0:
raise ValueError("Overlap must be non-negative.")

if isinstance(adata, SpatialData):
adata = adata.table

assert max_n_cells is None or n_splits is None, "You can specify only one from the parameters 'n_split' and 'max_n_cells' "

# we don't want to modify the original adata in case of copy=True
if copy:
adata = adata.copy()

if 'sliding_window_assignment_colors' in adata.uns:
del adata.uns['sliding_window_assignment_colors']
# extract coordinates of observations
x_col, y_col = coord_columns
if x_col in adata.obs and y_col in adata.obs:
Expand All @@ -76,52 +95,95 @@ def sliding_window(
raise ValueError(
f"Coordinates not found. Provide `{coord_columns}` in `adata.obs` or specify a suitable `spatial_key` in `adata.obsm`."
)

if library_key is not None and library_key not in adata.obs:
raise ValueError(f"Library key '{library_key}' not found in adata.obs")

if library_key is None and not 'fov' in adata.obs.columns:
adata.obs['fov'] = 'fov1'

# infer window size if not provided
if window_size is None:
coord_range = max(
coords[x_col].max() - coords[x_col].min(),
coords[y_col].max() - coords[y_col].min(),
)
# mostly arbitrary choice, except that full integers usually generate windows with 1-2 cells at the borders
window_size = max(int(np.floor(coord_range // 3.95)), 1)
libraries = adata.obs[library_key].unique()

if window_size <= 0:
raise ValueError("Window size must be larger than 0.")
fovs_x_range = [(adata.obs[adata.obs[library_key]==key][x_col].max(), adata.obs[adata.obs[library_key]==key][x_col].min()) for key in libraries]
fovs_y_range = [(adata.obs[adata.obs[library_key]==key][y_col].max(), adata.obs[adata.obs[library_key]==key][y_col].min()) for key in libraries]
fovs_width = [i-j for (i, j) in fovs_x_range]
fovs_height = [i-j for (i, j) in fovs_y_range]
fovs_n_cell = [adata[adata.obs[library_key]==key].shape[0] for key in libraries]
fovs_area = [i*j for i, j in zip(fovs_width, fovs_height)]
fovs_density = [i/j for i, j in zip(fovs_n_cell, fovs_area)]
window_sizes = []

if library_key is not None and library_key not in adata.obs:
raise ValueError(f"Library key '{library_key}' not found in adata.obs")
if window_size is None:
if max_n_cells is None and n_splits is None:
n_splits = 2

libraries = [None] if library_key is None else adata.obs[library_key].unique()
if window_size_per_library_key == 'equal':
if max_n_cells is not None:
n_splits = max(2, int(max(fovs_n_cell)/max_n_cells))
else:
max_n_cells = int(max(fovs_n_cell) / n_splits)
min_n_cells = int(min(fovs_n_cell) / n_splits)
maximum_region_area = max_n_cells / max(fovs_density)
minimum_region_area = min_n_cells / max(fovs_density)
window_size = _optimize_tile_size(min(fovs_width),
min(fovs_height),
minimum_region_area,
maximum_region_area,
square,
split_line)
window_sizes = [window_size]*len(libraries)
else:
for i, lib in enumerate(libraries):
if max_n_cells is not None:
n_splits = max(2, int(fovs_n_cell[i]/max_n_cells))
else:
max_n_cells = int(fovs_n_cell[i] / n_splits)
min_n_cells = int(fovs_n_cell[i] / n_splits)
minimum_region_area = min_n_cells / max(fovs_density)
maximum_region_area = fovs_area[i]/fovs_density[i]
window_sizes.append(
_optimize_tile_size(fovs_width[i],
fovs_height[i],
minimum_region_area,
maximum_region_area,
square,
split_line
)
)
else:
assert split_line is None, logg.warning("'split' ignored as window_size is specified for square regions")
assert n_splits is None, logg.warning("'n_split' ignored as window_size is specified for square regions")
assert max_n_cells is None, logg.warning("'max_n_cells' ignored as window_size is specified")
if isinstance(window_size, (int, float)):
if window_size<= 0:
raise ValueError("Window size must be larger than 0.")
else:
window_size = (window_size, window_size)
elif isinstance(window_size, tuple):
for i in window_size:
if i<= 0:
raise ValueError("Window size must be larger than 0.")

window_sizes = [window_size]*len(libraries)

# Create a DataFrame to store the sliding window assignments
sliding_window_df = pd.DataFrame(index=adata.obs.index)

if sliding_window_key in adata.obs:
logg.warning(f"Overwriting existing column '{sliding_window_key}' in adata.obs.")
adata.obs[sliding_window_key] = 'window_0'

for lib in libraries:
if lib is not None:
lib_mask = adata.obs[library_key] == lib
lib_coords = coords.loc[lib_mask]
else:
lib_mask = np.ones(len(adata), dtype=bool)
lib_coords = coords

min_x, max_x = lib_coords[x_col].min(), lib_coords[x_col].max()
min_y, max_y = lib_coords[y_col].min(), lib_coords[y_col].max()

for i, lib in enumerate(libraries):
lib_mask = adata.obs[library_key] == lib
lib_coords = coords.loc[lib_mask]

# precalculate windows
windows = _calculate_window_corners(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
window_size=window_size,
fovs_x_range[i],
fovs_y_range[i],
window_size=window_sizes[i],
overlap=overlap,
drop_partial_windows=drop_partial_windows,
)

lib_key = f"{lib}_" if lib is not None else ""

# assign observations to windows
Expand All @@ -131,6 +193,11 @@ def sliding_window(
y_start = window["y_start"]
y_end = window["y_end"]

if drop_partial_windows:
# Check if the window is within the bounds
if x_end > fovs_x_range[i][0] or y_end > fovs_y_range[i][0]:
continue # Skip windows that extend beyond the region

mask = (
(lib_coords[x_col] >= x_start)
& (lib_coords[x_col] <= x_end)
Expand All @@ -156,30 +223,26 @@ def sliding_window(

if overlap == 0:
# create categorical variable for ordered windows
# Ensure the column is a string type
sliding_window_df[sliding_window_key] = pd.Categorical(
sliding_window_df[sliding_window_key],
ordered=True,
categories=sorted(
sliding_window_df[sliding_window_key].unique(),
key=lambda x: int(x.split("_")[-1]),
),
)

sliding_window_df[x_col] = coords[x_col]
sliding_window_df[y_col] = coords[y_col]
)

if copy:
return sliding_window_df
for col_name, col_data in sliding_window_df.items():
_save_data(adata, attr="obs", key=col_name, data=col_data)
sliding_window_df = sliding_window_df.loc[adata.obs.index]
_save_data(adata, attr="obs", key=sliding_window_key, data=sliding_window_df[sliding_window_key])


def _calculate_window_corners(
min_x: int,
max_x: int,
min_y: int,
max_y: int,
window_size: int,
x_range: int,
y_range: int,
window_size: int = None,
overlap: int = 0,
drop_partial_windows: bool = False,
) -> pd.DataFrame:
Expand Down Expand Up @@ -209,17 +272,22 @@ def _calculate_window_corners(
-------
windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end']
"""
x_window_size, y_window_size = window_size

if overlap < 0:
raise ValueError("Overlap must be non-negative.")
if overlap >= window_size:
if overlap >= x_window_size or overlap >= y_window_size:
raise ValueError("Overlap must be less than the window size.")

x_step = window_size - overlap
y_step = window_size - overlap

max_x, min_x = x_range
max_y, min_y = y_range

x_step = x_window_size - overlap
y_step = y_window_size - overlap

# Align min_x and min_y to ensure that the first window starts properly
aligned_min_x = min_x - (min_x % window_size) if min_x % window_size != 0 else min_x
aligned_min_y = min_y - (min_y % window_size) if min_y % window_size != 0 else min_y
aligned_min_x = min_x - (min_x % x_window_size) if min_x % x_window_size != 0 else min_x
aligned_min_y = min_y - (min_y % y_window_size) if min_y % y_window_size != 0 else min_y

# Generate starting points starting from the aligned minimum values
x_starts = np.arange(aligned_min_x, max_x, x_step)
Expand All @@ -228,16 +296,90 @@ def _calculate_window_corners(
# Create all combinations of x and y starting points
starts = list(product(x_starts, y_starts))
windows = pd.DataFrame(starts, columns=["x_start", "y_start"])
windows["x_end"] = windows["x_start"] + window_size
windows["y_end"] = windows["y_start"] + window_size
windows["x_end"] = windows["x_start"] + x_window_size
windows["y_end"] = windows["y_start"] + y_window_size

# Adjust windows that extend beyond the bounds
# if drop_partial_windows:
# # Remove windows that go beyond the max_x or max_y
# windows = windows[
# (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y)
# ]
# else:
# # If not dropping partial windows, clip the end points to max_x and max_y
# windows["x_end"] = windows["x_end"].clip(upper=max_x)
# windows["y_end"] = windows["y_end"].clip(upper=max_y)

if not drop_partial_windows:
windows["x_end"] = windows["x_end"].clip(upper=max_x)
windows["y_end"] = windows["y_end"].clip(upper=max_y)
else:
valid_windows = (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y)
windows = windows[valid_windows]

windows = windows.reset_index(drop=True)
return windows[["x_start", "x_end", "y_start", "y_end"]]

def _optimize_tile_size(L, W, A_min=None, A_max=None, square=False, split_line='v'):
"""
This function optimizes the tile size for covering a rectangle of dimensions LxW.
It returns a tuple (x, y) where x and y are the dimensions of the optimal tile.
Parameters:
- L (int): Length of the rectangle.
- W (int): Width of the rectangle.
- A_min (int, optional): Minimum allowed area of each tile. If None, no minimum area limit is applied.
- A_max (int, optional): Maximum allowed area of each tile. If None, no maximum area limit is applied.
- square (bool, optional): If True, tiles will be square (x = y).
Returns:
- tuple: (x, y) representing the optimal tile dimensions.
"""
best_tile_size = None
min_uncovered_area = float('inf')
if square:
# Calculate square tiles
max_side = int(math.sqrt(A_max)) if A_max else int(min(L, W))
min_side = int(math.sqrt(A_min)) if A_min else 1
# Try all square tile sizes from min_side to max_side
for side in range(min_side, max_side + 1):
if (A_min and side * side < A_min) or (A_max and side * side > A_max):
continue # Skip sizes that are out of the area limits

# Calculate number of tiles that fit in the rectangle
num_tiles_x = L // side
num_tiles_y = W // side
uncovered_area = L * W - (num_tiles_x * num_tiles_y * side * side)

# Track the best tile size
if uncovered_area < min_uncovered_area:
min_uncovered_area = uncovered_area
best_tile_size = (side, side)
else:
# For non-square tiles, optimize both dimensions independently
if split_line == 'v':
max_tile_length = A_max/W if A_max else int(L)
max_tile_width = W
min_tile_length = A_min/W
min_tile_width = W
if split_line == 'h':
max_tile_length = L
max_tile_width = A_max/L if A_max else int()
min_tile_width = A_min/L
min_tile_length = L
# Try all combinations of width and height within the bounds
for width in range(int(min_tile_width), int(max_tile_width) + 1):
for height in range(int(min_tile_length), int(max_tile_length) + 1):
if (A_min and width * height < A_min) or (A_max and width * height > A_max):
continue # Skip sizes that are out of the area limits

# Calculate number of tiles that fit in the rectangle
num_tiles_x = L // width
num_tiles_y = W // height
uncovered_area = L * W - (num_tiles_x * num_tiles_y * width * height)

# Track the best tile size (minimizing uncovered area)
if uncovered_area < min_uncovered_area:
min_uncovered_area = uncovered_area
best_tile_size = (height, width)

return best_tile_size

0 comments on commit a8cc17e

Please sign in to comment.