Skip to content

Commit

Permalink
Merge pull request #903 from scverse/bugfix/902-numpy-2-compat
Browse files Browse the repository at this point in the history
numpy 2.0 and python >= 3.10 compatibility changes
  • Loading branch information
ilan-gold authored Nov 6, 2024
2 parents 28d5077 + d62b3d0 commit 26c2693
Show file tree
Hide file tree
Showing 30 changed files with 188 additions and 165 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

# C extensions
*.so
*.pyc
.DS_Store
*/.DS_Store
.idea
Expand Down
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
mypy_path = squidpy
python_version = 3.9
python_version = 3.10
plugins = numpy.typing.mypy_plugin

ignore_errors = False
Expand Down
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
fail_fast: false
default_language_version:
python: python3
default_stages:
- pre-commit
- pre-push
Expand All @@ -24,3 +22,4 @@ repos:
- id: mypy
additional_dependencies: [numpy, pandas, types-requests]
exclude: .scripts/ci/download_data.py|squidpy/datasets/_(dataset|image).py # See https://github.com/

2 changes: 1 addition & 1 deletion .scripts/ci/download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main(args: argparse.Namespace) -> None:
obj = _maybe_download_data(func_name, path)

# we could do without the AnnData check as well (1 less req. in tox.ini), but it's better to be safe
assert isinstance(obj, (AnnData, sq.im.ImageContainer)), type(obj)
assert isinstance(obj, AnnData | sq.im.ImageContainer), type(obj)
assert path.is_file(), path


Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "squidpy"
dynamic = ["version"]
description = "Spatial Single Cell Analysis in Python"
readme = "README.rst"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down Expand Up @@ -42,7 +42,8 @@ authors = [
]
maintainers = [
{name = "Giovanni Palla", email = "[email protected]"},
{name = "Michal Klein", email = "[email protected]"}
{name = "Michal Klein", email = "[email protected]"},
{name = "Tim Treis", email = "[email protected]"}
]

dependencies = [
Expand Down Expand Up @@ -117,6 +118,9 @@ include-package-data = true
[tool.hatch.version]
source = "vcs"

[tool.hatch.metadata]
allow-direct-references = true

[tool.ruff]
line-length = 120
exclude = [
Expand Down
9 changes: 5 additions & 4 deletions src/squidpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from importlib import metadata
from importlib.metadata import PackageMetadata

from squidpy import datasets, gr, im, pl, read, tl

try:
md = metadata.metadata(__name__)
__version__ = md.get("version", "")
__author__ = md.get("Author", "")
__maintainer__ = md.get("Maintainer-email", "")
md: PackageMetadata = metadata.metadata(__name__)
__version__ = md["Version"] if "Version" in md else ""
__author__ = md["Author"] if "Author" in md else ""
__maintainer__ = md["Maintainer-email"] if "Maintainer-email" in md else ""
except ImportError:
md = None # type: ignore[assignment]

Expand Down
4 changes: 2 additions & 2 deletions src/squidpy/_constants/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from abc import ABC, ABCMeta
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from enum import Enum, EnumMeta
from functools import wraps
from typing import Any, Callable
from typing import Any


def _pretty_raise_enum(cls: type[ModeEnum], fun: Callable[..., Any]) -> Callable[..., Any]:
Expand Down
3 changes: 2 additions & 1 deletion src/squidpy/_docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from textwrap import dedent
from typing import Any, Callable
from typing import Any

from docrep import DocstringProcessor

Expand Down
4 changes: 2 additions & 2 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import functools
import inspect
import warnings
from collections.abc import Generator, Hashable, Iterable, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from contextlib import contextmanager
from enum import Enum
from multiprocessing import Manager, cpu_count
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

import joblib as jl
import numpy as np
Expand Down
8 changes: 4 additions & 4 deletions src/squidpy/datasets/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

import os
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from inspect import Parameter, Signature, signature
from pathlib import Path
from typing import Any, Callable, Union
from typing import Any, TypeAlias, Union

import anndata
from anndata import AnnData
from scanpy import logging as logg
from scanpy import read
from scanpy._utils import check_presence_download

PathLike = Union[os.PathLike[str], str]
Function_t = Callable[..., Union[AnnData, Any]]
PathLike: TypeAlias = os.PathLike[str] | str
Function_t: TypeAlias = Callable[..., AnnData | Any]


@dataclass(frozen=True)
Expand Down
6 changes: 3 additions & 3 deletions src/squidpy/gr/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def spatial_neighbors(
for lib in libs:
ixs.extend(np.where(adata.obs[library_key] == lib)[0])
mats.append(_build_fun(adata[adata.obs[library_key] == lib]))
ixs = np.argsort(ixs) # invert
ixs = np.argsort(ixs).tolist() # invert
Adj = block_diag([m[0] for m in mats], format="csr")[ixs, :][:, ixs]
Dst = block_diag([m[1] for m in mats], format="csr")[ixs, :][:, ixs]
else:
Expand Down Expand Up @@ -400,7 +400,7 @@ def _build_connectivity(
Dst = csr_matrix((dists, indices, indptr), shape=(N, N))
# fmt: on
else:
r = 1 if radius is None else radius if isinstance(radius, (int, float)) else max(radius)
r = 1 if radius is None else radius if isinstance(radius, int | float) else max(radius)
tree = NearestNeighbors(n_neighbors=n_neighs, radius=r, metric="euclidean")
tree.fit(coords)

Expand Down Expand Up @@ -519,7 +519,7 @@ def mask_graph(
dists_key = Key.obsp.spatial_dist(spatial_key)

# check polygon type
if not isinstance(polygon_mask, (Polygon, MultiPolygon)):
if not isinstance(polygon_mask, Polygon | MultiPolygon):
raise ValueError(f"`polygon_mask` should be of type `Polygon` or `MultiPolygon`, got {type(polygon_mask)}")

# get elements
Expand Down
17 changes: 9 additions & 8 deletions src/squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial
from itertools import product
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, Union
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, Union

import numpy as np
import pandas as pd
Expand All @@ -33,10 +33,11 @@

__all__ = ["ligrec", "PermutationTest"]

StrSeq = Sequence[str]
SeqTuple = Sequence[tuple[str, str]]
Interaction_t = Union[pd.DataFrame, Mapping[str, StrSeq], StrSeq, tuple[StrSeq, StrSeq], SeqTuple]
Cluster_t = Union[StrSeq, tuple[StrSeq, StrSeq], SeqTuple]
StrSeq: TypeAlias = Sequence[str]
SeqTuple: TypeAlias = Sequence[tuple[str, str]]
Interaction_t: TypeAlias = pd.DataFrame | Mapping[str, StrSeq] | StrSeq | tuple[StrSeq, StrSeq] | SeqTuple

Cluster_t: TypeAlias = StrSeq | tuple[StrSeq, StrSeq] | SeqTuple

SOURCE = "source"
TARGET = "target"
Expand Down Expand Up @@ -263,7 +264,7 @@ def prepare(
if isinstance(interactions[0], str):
interactions = list(product(interactions, repeat=2))
elif len(interactions) == 2:
interactions = tuple(zip(*interactions))
interactions = tuple(zip(*interactions, strict=False))

if not all(len(i) == 2 for i in interactions):
raise ValueError("Not all interactions are of length `2`.")
Expand Down Expand Up @@ -392,8 +393,8 @@ def test(
data["clusters"] = data["clusters"].cat.remove_unused_categories()
cat = data["clusters"].cat

cluster_mapper = dict(zip(cat.categories, range(len(cat.categories))))
gene_mapper = dict(zip(data.columns[:-1], range(len(data.columns) - 1))) # -1 for 'clusters'
cluster_mapper = dict(zip(cat.categories, range(len(cat.categories)), strict=False))
gene_mapper = dict(zip(data.columns[:-1], range(len(data.columns) - 1), strict=False)) # -1 for 'clusters'

data.columns = [gene_mapper[c] if c != "clusters" else c for c in data.columns]
clusters_ = np.array([[cluster_mapper[c1], cluster_mapper[c2]] for c1, c2 in clusters], dtype=np.uint32)
Expand Down
8 changes: 4 additions & 4 deletions src/squidpy/gr/_nhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from __future__ import annotations

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from typing import Any, Callable
from typing import Any

import networkx as nx
import numba.types as nt
Expand Down Expand Up @@ -259,7 +259,7 @@ def centrality_scores(
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)

if isinstance(score, (str, Centrality)):
if isinstance(score, str | Centrality):
centrality = [score]
elif score is None:
centrality = [c.s for c in Centrality]
Expand Down Expand Up @@ -386,7 +386,7 @@ def _interaction_matrix(
cur_row = cats[i]
cur_indices = indices_list[i]
cur_data = data_list[i]
for j, val in zip(cur_indices, cur_data):
for j, val in zip(cur_indices, cur_data): # noqa: B905
cur_col = cats[j]
output[cur_row, cur_col] += val
return output
Expand Down
28 changes: 14 additions & 14 deletions src/squidpy/gr/_ppatterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,12 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr
if layer not in adata.obsm:
raise KeyError(f"Key `{layer!r}` not found in `adata.obsm`.")
if ixs is None:
ixs = np.arange(adata.obsm[layer].shape[1])
ixs = list(np.ravel([ixs]))
ixs = list(range(adata.obsm[layer].shape[1]))
elif isinstance(ixs, int):
ixs = [ixs]
else:
ixs = list(ixs)

return adata.obsm[layer][:, ixs].T, ixs

if attr == "X":
Expand Down Expand Up @@ -297,11 +301,9 @@ def _occur_count(
y = clust_y[idx_y]
# Treat computing co-occurrence using the same split and different splits differently
# Pairwise distance matrix for between the same split is symmetric and therefore only needs to be counted once
for i, j in zip(x, y):
if same_split:
co_occur[i, j] += 1
else:
co_occur[i, j] += 1
for i, j in zip(x, y): # noqa: B905 # cannot use strict=False because of numba
co_occur[i, j] += 1
if not same_split:
co_occur[j, i] += 1

# Prevent divison by zero errors when we have low cell counts/small intervals
Expand Down Expand Up @@ -416,24 +418,22 @@ def co_occurrence(
n_obs = spatial.shape[0]
if n_splits is None:
size_arr = (n_obs**2 * spatial.itemsize) / 1024 / 1024 # calc expected mem usage
n_splits = 1
if size_arr > 2000:
n_splits = 1
while 2048 < (n_obs / n_splits):
while (n_obs / n_splits) > 2048:
n_splits += 1
logg.warning(
f"`n_splits` was automatically set to `{n_splits}` to "
f"prevent `{n_obs}x{n_obs}` distance matrix from being created"
)
else:
n_splits = 1
n_splits = max(min(n_splits, n_obs), 1)
n_splits = int(max(min(n_splits, n_obs), 1))

# split array and labels
spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s))
labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s))
# create idx array including unique combinations and self-comparison
x, y = np.triu_indices_from(np.empty((n_splits, n_splits)))
idx_splits = list(zip(x, y))
idx_splits = list(zip(x, y, strict=False))

n_jobs = _get_n_cores(n_jobs)
start = logg.info(
Expand Down Expand Up @@ -578,7 +578,7 @@ def _g_moments(w: spmatrix | NDArrayA) -> tuple[float, float, float]:

# s1
t = w.transpose() + w
t2 = t.multiply(t)
t2 = t.multiply(t) if isinstance(t, spmatrix) else t * t
s1 = t2.sum() / 2.0

# s2
Expand Down
14 changes: 8 additions & 6 deletions src/squidpy/gr/_sepal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Callable, Literal
from collections.abc import Callable, Sequence
from typing import Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -151,7 +151,7 @@ def sepal(

if sepal_score[key_added].isna().any():
logg.warning("Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value")
sepal_score.sort_values(by=key_added, ascending=False, inplace=True)
sepal_score = sepal_score.sort_values(by=key_added, ascending=False)

if copy:
logg.info("Finish", time=start)
Expand Down Expand Up @@ -180,10 +180,12 @@ def _score_helper(
else:
raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.")

score, sparse = [], issparse(vals)
score = []
for i in ixs:
conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy()
conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy()
if isinstance(vals, spmatrix):
conc = vals[:, i].toarray().flatten()
else:
conc = vals[:, i].copy()
time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
score.append(dt * time_iter)

Expand Down
6 changes: 3 additions & 3 deletions src/squidpy/gr/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _create_sparse_df(
)

if not issparse(data):
pred = (lambda col: ~np.isnan(col)) if fill_value is np.nan else (lambda col: ~np.isclose(col, fill_value))
pred = (lambda col: ~np.isnan(col)) if np.isnan(fill_value) else (lambda col: ~np.isclose(col, fill_value))
dtype = SparseDtype(data.dtype, fill_value=fill_value)
n_rows, n_cols = data.shape
arrays = []
Expand Down Expand Up @@ -236,13 +236,13 @@ def _extract_expression(
res = adata[:, genes].layers[layer]
if isinstance(res, AnnData):
res = res.X
elif not isinstance(res, (np.ndarray, spmatrix)):
elif not isinstance(res, np.ndarray | spmatrix):
raise TypeError(f"Invalid expression type `{type(res).__name__}`.")

# handle views
if isinstance(res, ArrayView):
return np.asarray(res), genes
if isinstance(res, (SparseCSRView, SparseCSCView)):
if isinstance(res, SparseCSRView | SparseCSCView):
mro = type(res).mro()
if csr_matrix in mro:
return csr_matrix(res), genes
Expand Down
Loading

0 comments on commit 26c2693

Please sign in to comment.