Skip to content

Commit

Permalink
ruff: fix zip strict argument
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Nov 4, 2024
1 parent aaef961 commit d4e0de9
Show file tree
Hide file tree
Showing 25 changed files with 82 additions and 60 deletions.
3 changes: 2 additions & 1 deletion examples/moving-geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def source(t, x):
gradx = sum(
num_reference_derivative(discr, (i,), x)
for i in range(discr.dim))
intx = sum(actx.np.sum(xi * wi) for xi, wi in zip(x, discr.quad_weights()))
intx = sum(actx.np.sum(xi * wi)
for xi, wi in zip(x, discr.quad_weights(), strict=True))

assert gradx is not None
assert intx is not None
Expand Down
8 changes: 5 additions & 3 deletions examples/simple-dg.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def grad(self, vec):
for idim in range(self.volume_discr.dim)]

return make_obj_array([
sum(dref_i*ipder_i for dref_i, ipder_i in zip(dref, ipder[iambient]))
sum(dref_i*ipder_i
for dref_i, ipder_i in zip(dref, ipder[iambient], strict=True))
for iambient in range(self.volume_discr.ambient_dim)])

def div(self, vecs):
Expand Down Expand Up @@ -259,7 +260,7 @@ def inverse_mass(self, vec):
vec_i,
arg_names=("mass_inv_mat", "vec"),
tagged=(FirstAxisIsElementsTag(),)
) for grp, vec_i in zip(discr.groups, vec)
) for grp, vec_i in zip(discr.groups, vec, strict=True)
)
) / actx.thaw(self.vol_jacobian())

Expand Down Expand Up @@ -321,7 +322,8 @@ def face_mass(self, vec):
),
tagged=(FirstAxisIsElementsTag(),))
for afgrp, volgrp, vec_i in zip(all_faces_discr.groups,
vol_discr.groups, vec)
vol_discr.groups,
vec, strict=True)
)
)

Expand Down
2 changes: 1 addition & 1 deletion meshmode/discretization/connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def check_connection(actx: ArrayContext, connection: DirectDiscretizationConnect

assert len(connection.groups) == len(to_discr.groups)

for cgrp, tgrp in zip(connection.groups, to_discr.groups):
for cgrp, tgrp in zip(connection.groups, to_discr.groups, strict=True):
for batch in cgrp.batches:
fgrp = from_discr.groups[batch.from_group_index]

Expand Down
4 changes: 2 additions & 2 deletions meshmode/discretization/connection/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _build_batches(actx, from_bins, to_bins, batch):
def to_device(x):
return actx.freeze(actx.from_numpy(np.asarray(x)))

for ibatch, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins)):
for ibatch, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins, strict=True)):
yield InterpolationBatch(
from_group_index=batch[ibatch].from_group_index,
from_element_indices=to_device(from_bin),
Expand Down Expand Up @@ -248,7 +248,7 @@ def flatten_chained_connection(actx, connection):

# build new groups
groups = []
for igrp, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins)):
for igrp, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins, strict=True)):
groups.append(DiscretizationConnectionElementGroup(
list(_build_batches(actx, from_bin, to_bin,
batch_info[igrp]))))
Expand Down
4 changes: 2 additions & 2 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def group_pick_knl(is_surjective: bool):

group_arrays = []
for i_tgrp, (cgrp, group_pick_info) in enumerate(
zip(self.groups, self._global_point_pick_info(actx))):
zip(self.groups, self._global_point_pick_info(actx), strict=True)):

group_array_contributions = []

Expand Down Expand Up @@ -926,7 +926,7 @@ def knl():
tgt_node_nr_base = 0
mats = []
for i_tgrp, (tgrp, cgrp) in enumerate(
zip(conn.to_discr.groups, conn.groups)):
zip(conn.to_discr.groups, conn.groups, strict=True)):
for i_batch, batch in enumerate(cgrp.batches):
if not len(batch.from_element_indices):
continue
Expand Down
7 changes: 4 additions & 3 deletions meshmode/discretization/connection/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def make_face_restriction(
connection_data = {}

for igrp, (grp, fagrp_list) in enumerate(
zip(discr.groups, discr.mesh.facial_adjacency_groups)):
zip(discr.groups, discr.mesh.facial_adjacency_groups, strict=True)):

mgrp = grp.mesh_el_group

Expand All @@ -251,7 +251,7 @@ def make_face_restriction(
if isinstance(fagrp, InteriorAdjacencyGroup)]
for fagrp in int_grps:
group_boundary_faces.extend(
zip(fagrp.elements, fagrp.element_faces))
zip(fagrp.elements, fagrp.element_faces, strict=True))

elif boundary_tag is FACE_RESTR_ALL:
group_boundary_faces.extend(
Expand All @@ -270,7 +270,8 @@ def make_face_restriction(
group_boundary_faces.extend(
zip(
bdry_grp.elements,
bdry_grp.element_faces))
bdry_grp.element_faces,
strict=True))

# }}}

Expand Down
2 changes: 1 addition & 1 deletion meshmode/discretization/connection/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def vandermonde_matrix(grp):
c_i,
arg_names=("vdm", "coeffs"),
tagged=(FirstAxisIsElementsTag(),))
for grp, c_i in zip(self.to_discr.groups, coefficients)
for grp, c_i in zip(self.to_discr.groups, coefficients, strict=True)
)
)

Expand Down
15 changes: 10 additions & 5 deletions meshmode/discretization/connection/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def _build_interpolation_batches_for_group(
assert len(refinement_result) == num_children
# Refined -> interpolates to children
for from_bin, to_bin, child_idx in zip(
from_bins[1:], to_bins[1:], refinement_result):
from_bins[1:], to_bins[1:], refinement_result,
strict=True):
from_bin.append(elt_idx)
to_bin.append(child_idx)

Expand All @@ -97,8 +98,10 @@ def _build_interpolation_batches_for_group(

from itertools import chain
for from_bin, to_bin, unit_nodes in zip(
from_bins, to_bins,
chain([fine_unit_nodes], mapped_unit_nodes)):
from_bins,
to_bins,
chain([fine_unit_nodes], mapped_unit_nodes),
strict=True):
if not from_bin:
continue
yield InterpolationBatch(
Expand Down Expand Up @@ -148,8 +151,10 @@ def make_refinement_connection(actx, refiner, coarse_discr, group_factory):

groups = []
for group_idx, (coarse_discr_group, fine_discr_group, record) in \
enumerate(zip(coarse_discr.groups, fine_discr.groups,
refiner.group_refinement_records)):
enumerate(zip(coarse_discr.groups,
fine_discr.groups,
refiner.group_refinement_records,
strict=True)):
groups.append(
DiscretizationConnectionElementGroup(
list(_build_interpolation_batches_for_group(
Expand Down
3 changes: 2 additions & 1 deletion meshmode/discretization/connection/same_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def make_same_mesh_connection(actx, to_discr, from_discr):
return IdentityDiscretizationConnection(from_discr)

groups = []
for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)):
for igrp, (fgrp, tgrp) in enumerate(
zip(from_discr.groups, to_discr.groups, strict=True)):
from arraycontext.metadata import NameHint
all_elements = actx.tag(NameHint(f"all_el_ind_grp{igrp}"),
actx.tag_axis(0,
Expand Down
3 changes: 2 additions & 1 deletion meshmode/discretization/poly_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ def quadrature_rule(self):
else:
nodes_tp = self._nodes

for idim, (nodes, basis) in enumerate(zip(nodes_tp, self._basis.bases)):
for idim, (nodes, basis) in enumerate(
zip(nodes_tp, self._basis.bases, strict=True)):
# get current dimension's nodes
iaxis = (*(0,)*idim, slice(None), *(0,)*(self.dim-idim-1))
nodes = nodes[iaxis]
Expand Down
12 changes: 7 additions & 5 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _check_discr_same_connectivity(discr, other):
if not all(
sg.discretization_key() == og.discretization_key()
and sg.nelements == og.nelements
for sg, og in zip(discr.groups, other.groups)):
for sg, og in zip(discr.groups, other.groups, strict=True)):
return False

return True
Expand Down Expand Up @@ -482,7 +482,8 @@ def cells(self):
grp.nunit_dofs,
grp.nelements * grp.nunit_dofs + 1,
grp.nunit_dofs)
for grp_offset, grp in zip(grp_offsets, self.vis_discr.groups)
for grp_offset, grp in zip(grp_offsets[:-1], self.vis_discr.groups,
strict=True)
])

return self.vis_discr.mesh.nelements, connectivity, offsets
Expand Down Expand Up @@ -1161,7 +1162,8 @@ def write_xdmf_file(self, file_name, names_and_fields,

grids = []
node_nr_base = 0
for igrp, (vgrp, gnodes) in enumerate(zip(connectivity.groups, nodes)):
for igrp, (vgrp, gnodes) in enumerate(
zip(connectivity.groups, nodes, strict=True)):
grp_name = f"Group_{igrp:05d}"
h5grp = h5grid.create_group(grp_name)

Expand Down Expand Up @@ -1318,7 +1320,7 @@ def make_visualizer(actx, discr, vis_order=None,
vis_discr = discr.copy(actx=actx, group_factory=VisGroupFactory(vis_order))

if all(grp.discretization_key() == vgrp.discretization_key()
for grp, vgrp in zip(discr.groups, vis_discr.groups)):
for grp, vgrp in zip(discr.groups, vis_discr.groups, strict=True)):
from warnings import warn
warn("Visualization discretization is identical to base discretization. "
"To avoid the creation of a separate discretization for "
Expand Down Expand Up @@ -1383,7 +1385,7 @@ def write_nodal_adjacency_vtk_file(file_name, mesh,
(mesh.ambient_dim, mesh.nelements),
dtype=mesh.vertices.dtype)

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
centroids[:, base_element_nr:base_element_nr + grp.nelements] = (
np.sum(mesh.vertices[:, grp.vertex_indices], axis=-1)
/ grp.vertex_indices.shape[-1])
Expand Down
2 changes: 1 addition & 1 deletion meshmode/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def complete_some(self):
raise ValueError(
"duplicate local/remote part pair in inter_rank_bdry_info")

for i_src_rank, recvd in zip(source_ranks, data):
for i_src_rank, recvd in zip(source_ranks, data, strict=True):
(remote_part_id, local_part_id,
remote_bdry_mesh, remote_group_infos) = recvd

Expand Down
2 changes: 1 addition & 1 deletion meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def check_dofarray_against_discr(discr, dof_ary: DOFArray):
"DOFArray has unexpected number of groups "
f"({len(dof_ary)}, expected: {len(discr.groups)})")

for i, (grp, grp_ary) in enumerate(zip(discr.groups, dof_ary)):
for i, (grp, grp_ary) in enumerate(zip(discr.groups, dof_ary, strict=True)):
expected_shape = (grp.nelements, grp.nunit_dofs)
if grp_ary.shape != expected_shape:
raise InconsistentDOFArray(
Expand Down
20 changes: 12 additions & 8 deletions meshmode/interop/firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def _get_firedrake_facial_adjacency_groups(fdrake_mesh_topology,
to_keep = np.isin(int_elements, cells_to_use)
cells_to_use_inv = dict(zip(cells_to_use,
np.arange(np.size(cells_to_use),
dtype=IntType)))
dtype=IntType),
strict=True))

# Keep the cells that we are using and change old cell index
# to new cell index
Expand Down Expand Up @@ -459,7 +460,8 @@ def _get_firedrake_orientations(fdrake_mesh, unflipped_group, vertices,
orient = np.ones(num_cells)
if normals:
for i, (normal, vert_indices) in enumerate(
zip(np.array(normals), unflipped_group.vertex_indices)):
zip(np.array(normals), unflipped_group.vertex_indices,
strict=True)):
edge = vertices[:, vert_indices[1]] - vertices[:, vert_indices[0]]
if np.cross(normal, edge) < 0:
orient[i] = -1.0
Expand Down Expand Up @@ -612,18 +614,19 @@ def import_firedrake_mesh(fdrake_mesh, cells_to_use=None,
# Get all the nodal information we can from the topology
with ProcessLogger(logger, "Retrieving vertex indices and computing "
"NodalAdjacency from firedrake mesh"):
vertex_indices, nodal_adjacency = \
_get_firedrake_nodal_info(fdrake_mesh, cells_to_use=cells_to_use)
vertex_indices, nodal_adjacency = (
_get_firedrake_nodal_info(fdrake_mesh, cells_to_use=cells_to_use))

# If only using some cells, vertices may need new indices as many
# will be removed
if cells_to_use is not None:
vert_ndx_new2old = np.unique(vertex_indices.flatten())
vert_ndx_old2new = dict(zip(vert_ndx_new2old,
np.arange(np.size(vert_ndx_new2old),
dtype=vertex_indices.dtype)))
vertex_indices = \
np.vectorize(vert_ndx_old2new.__getitem__)(vertex_indices)
dtype=vertex_indices.dtype),
strict=True))
vertex_indices = (
np.vectorize(vert_ndx_old2new.__getitem__)(vertex_indices))

with ProcessLogger(logger, "Building (possibly) unflipped "
"SimplexElementGroup from firedrake unit nodes/nodes"):
Expand Down Expand Up @@ -872,7 +875,8 @@ def export_mesh_to_firedrake(mesh, group_nr=None, comm=None):
group = mesh.groups[group_nr]
fd2mm_indices = np.unique(group.vertex_indices.flatten())
coords = mesh.vertices[:, fd2mm_indices].T
mm2fd_indices = dict(zip(fd2mm_indices, np.arange(np.size(fd2mm_indices))))
mm2fd_indices = dict(zip(fd2mm_indices, np.arange(np.size(fd2mm_indices)),
strict=True))
cells = np.vectorize(mm2fd_indices.__getitem__)(group.vertex_indices)

# Get a dmplex object and then a mesh topology
Expand Down
4 changes: 2 additions & 2 deletions meshmode/mesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ def _compute_nodal_adjacency_from_vertices(mesh: Mesh) -> NodalAdjacency:
_, nvertices = mesh.vertices.shape
vertex_to_element: list[list[int]] = [[] for i in range(nvertices)]

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
if grp.vertex_indices is None:
raise ValueError("unable to compute nodal adjacency without vertices")

Expand All @@ -1565,7 +1565,7 @@ def _compute_nodal_adjacency_from_vertices(mesh: Mesh) -> NodalAdjacency:
vertex_to_element[ivertex].append(base_element_nr + iel_grp)

element_to_element: list[set[int]] = [set() for i in range(mesh.nelements)]
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
assert grp.vertex_indices is not None

for iel_grp in range(grp.nelements):
Expand Down
5 changes: 3 additions & 2 deletions meshmode/mesh/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ def generate_regular_rect_mesh(
"lower topological dimension and map it.)")

axis_coords = [np.linspace(a_i, b_i, npoints_i)
for a_i, b_i, npoints_i in zip(a, b, npoints_per_axis)]
for a_i, b_i, npoints_i in zip(a, b, npoints_per_axis, strict=False)]

return generate_box_mesh(axis_coords, order=order,
periodic=periodic,
Expand Down Expand Up @@ -1655,7 +1655,8 @@ def warp_and_refine_until_resolved(
"(NaN or Inf)")

for base_element_nr, egrp in zip(
warped_mesh.base_element_nrs, warped_mesh.groups):
warped_mesh.base_element_nrs, warped_mesh.groups,
strict=True):
if not isinstance(egrp, SimplexElementGroup):
raise TypeError(
f"Unsupported element group type: '{type(egrp).__name__}'")
Expand Down
6 changes: 4 additions & 2 deletions meshmode/mesh/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ def get_mesh(self, return_tag_to_elements_map=False):
i = 0

for el_vertices, el_nodes, el_type, el_markers in zip(
self.element_vertices, self.element_nodes, self.element_types,
self.element_markers):
self.element_vertices,
self.element_nodes,
self.element_types,
self.element_markers, strict=True):
if el_type is not group_el_type:
continue

Expand Down
Loading

0 comments on commit d4e0de9

Please sign in to comment.