Skip to content

Commit

Permalink
Allow different number of grid points along each axis
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Feb 13, 2024
1 parent cd6b0a9 commit 923833f
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions pysages/approxfun/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,11 @@ def compute_mesh(grid: Grid):
Returns a dense mesh with the same shape as `grid`, but on the hypercube
[-1, 1]ⁿ, where `n` is the dimensionality of `grid`.
"""
h = 2 / grid.shape
o = -1 + h / 2

nodes = o + h * np.hstack([np.arange(i).reshape(-1, 1) for i in grid.shape])
def transform(n):
return vmap(lambda k: -1 + (2 * k + 1) / n)

return _compute_mesh(nodes)
return _compute_mesh(transform, grid.shape)


@dispatch
Expand All @@ -128,16 +127,20 @@ def compute_mesh(grid: Grid[Chebyshev]): # noqa: F811 # pylint: disable=C0116,E
def transform(n):
return vmap(lambda k: -np.cos((k + 1 / 2) * np.pi / n))

nodes = np.hstack([transform(i)(np.arange(i).reshape(-1, 1)) for i in grid.shape])
return _compute_mesh(transform, grid.shape)

return _compute_mesh(nodes)

def _compute_mesh(transform, shape):
axes = (transform(i)(np.arange(i)) for i in np.flip(shape))
return cartesian_product(*axes)

def _compute_mesh(nodes):
components = np.meshgrid(
*nodes.T,
)
return np.hstack([v.reshape(-1, 1) for v in components])

def cartesian_product(*collections):
"""
Given a set of `collections`, returns an array with their [Cartesian
Product](https://en.wikipedia.org/wiki/Cartesian_product).
"""
return np.dstack(np.meshgrid(*collections)).reshape(-1, len(collections))


def vander_builder(grid, exponents):
Expand Down

0 comments on commit 923833f

Please sign in to comment.