From 1f9d0fc2f2deab77e017a5c2a154db6c2847cc1a Mon Sep 17 00:00:00 2001 From: Richard R <58728519+rrjbca@users.noreply.github.com> Date: Wed, 8 Sep 2021 15:19:08 +0100 Subject: [PATCH] BUG: Move handling of context arguments after handling of .depends keyword (#465) --- skypy/pipeline/_pipeline.py | 6 ++---- skypy/pipeline/tests/test_pipeline.py | 24 +++++++++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/skypy/pipeline/_pipeline.py b/skypy/pipeline/_pipeline.py index 7b2d73e76..2c52227b8 100644 --- a/skypy/pipeline/_pipeline.py +++ b/skypy/pipeline/_pipeline.py @@ -101,8 +101,6 @@ def __init__(self, configuration): self.dag.add_node(job, skip=False) if isinstance(settings, Item): items[job] = settings - # infer additional item properties from context - settings.infer(context) for table, columns in self.table_config.items(): table_complete = '.'.join((table, 'complete')) self.dag.add_node(table_complete) @@ -114,8 +112,6 @@ def __init__(self, configuration): self.dag.add_edge(job, table_complete) if isinstance(settings, Item): items[job] = settings - # infer additional item properties from context - settings.infer(context) # DAG nodes for individual columns in multi-column assignment names = [n.strip() for n in column.split(',')] if len(names) > 1: @@ -137,6 +133,8 @@ def __init__(self, configuration): while c: self.dag.add_edge(c, d) c, d = c.rpartition('.')[0], c + # infer additional item properties from context + settings.infer(context) def execute(self, parameters={}): r'''Run a pipeline. diff --git a/skypy/pipeline/tests/test_pipeline.py b/skypy/pipeline/tests/test_pipeline.py index 7c9674e4e..3b83968c9 100644 --- a/skypy/pipeline/tests/test_pipeline.py +++ b/skypy/pipeline/tests/test_pipeline.py @@ -1,7 +1,7 @@ from astropy.cosmology import FlatLambdaCDM, default_cosmology from astropy.cosmology.core import Cosmology from astropy.io import fits -from astropy.table import Table +from astropy.table import Table, vstack from astropy.table.column import Column from astropy.units import Quantity from astropy.utils.data import get_pkg_data_filename @@ -237,6 +237,28 @@ def value_in_cm(q): np.testing.assert_array_less(pipeline['test_table.lengths_in_cm'], 100) +def test_depends(): + + # Regression test for GitHub Issue #464 + # Previously the .depends keyword was also being passed to functions as a + # keyword argument. This was because Pipeline was executing Item.infer to + # handle additional function arguments from context before handling + # additional dependencies specified using the .depends keyword. The + # .depends keyword is now handled first. + + config = {'tables': { + 'table_1': { + 'column1': Call(np.random.uniform, [0, 1, 10])}, + 'table_2': { + '.init': Call(vstack, [], { + 'tables': [Ref('table_1')], + '.depends': ['table_1.complete']})}}} + + pipeline = Pipeline(config) + pipeline.execute() + assert np.all(pipeline['table_1'] == pipeline['table_2']) + + def teardown_module(module): # Remove fits file generated in test_pipeline