diff --git a/sdk/python/kfp/client/client.py b/sdk/python/kfp/client/client.py index f8897236343..46f19e7de96 100644 --- a/sdk/python/kfp/client/client.py +++ b/sdk/python/kfp/client/client.py @@ -686,6 +686,7 @@ def run_pipeline( version_id: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = '', service_account: Optional[str] = None, ) -> kfp_server_api.V2beta1Run: """Runs a specified pipeline. @@ -709,6 +710,8 @@ def run_pipeline( is ``True`` for all tasks by default. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. @@ -721,6 +724,7 @@ def run_pipeline( pipeline_id=pipeline_id, version_id=version_id, enable_caching=enable_caching, + cache_key=cache_key, pipeline_root=pipeline_root, ) @@ -806,6 +810,7 @@ def create_recurring_run( enabled: bool = True, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = '', service_account: Optional[str] = None, ) -> kfp_server_api.V2beta1RecurringRun: """Creates a recurring run. @@ -850,6 +855,8 @@ def create_recurring_run( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account this recurring run uses. Returns: @@ -862,6 +869,7 @@ def create_recurring_run( pipeline_id=pipeline_id, version_id=version_id, enable_caching=enable_caching, + cache_key=cache_key, pipeline_root=pipeline_root, ) @@ -908,6 +916,7 @@ def _create_job_config( pipeline_id: Optional[str], version_id: Optional[str], enable_caching: Optional[bool], + cache_key: Optional[str], pipeline_root: Optional[str], ) -> _JobConfig: """Creates a JobConfig with spec and resource_references. @@ -928,6 +937,8 @@ def _create_job_config( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. pipeline_root: Root path of the pipeline outputs. Returns: @@ -956,7 +967,7 @@ def _create_job_config( # settings. if enable_caching is not None: _override_caching_options(pipeline_doc.pipeline_spec, - enable_caching) + enable_caching, cache_key) pipeline_spec = pipeline_doc.to_dict() pipeline_version_reference = None @@ -983,6 +994,7 @@ def create_run_from_pipeline_func( namespace: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = '', service_account: Optional[str] = None, experiment_id: Optional[str] = None, ) -> RunPipelineResult: @@ -1004,6 +1016,8 @@ def create_run_from_pipeline_func( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name. @@ -1032,6 +1046,7 @@ def create_run_from_pipeline_func( namespace=namespace, pipeline_root=pipeline_root, enable_caching=enable_caching, + cache_key=cache_key, service_account=service_account, ) @@ -1044,6 +1059,7 @@ def create_run_from_pipeline_package( namespace: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = '', service_account: Optional[str] = None, experiment_id: Optional[str] = None, ) -> RunPipelineResult: @@ -1065,6 +1081,8 @@ def create_run_from_pipeline_package( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name. @@ -1105,6 +1123,7 @@ def create_run_from_pipeline_package( params=arguments, pipeline_root=pipeline_root, enable_caching=enable_caching, + cache_key=cache_key, service_account=service_account, ) return RunPipelineResult(self, run_info) @@ -1681,6 +1700,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc: def _override_caching_options( pipeline_spec: pipeline_spec_pb2.PipelineSpec, enable_caching: bool, + cache_key: str = '', ) -> None: """Overrides caching options. @@ -1690,3 +1710,4 @@ def _override_caching_options( """ for _, task_spec in pipeline_spec.root.dag.tasks.items(): task_spec.caching_options.enable_cache = enable_caching + task_spec.caching_options.cache_key = cache_key diff --git a/sdk/python/kfp/client/client_test.py b/sdk/python/kfp/client/client_test.py index 301ec6d119b..839510d55ff 100644 --- a/sdk/python/kfp/client/client_test.py +++ b/sdk/python/kfp/client/client_test.py @@ -88,12 +88,19 @@ def pipeline_with_two_component(text: str = 'hi there'): pipeline_obj = yaml.safe_load(f) pipeline_spec = json_format.ParseDict( pipeline_obj, pipeline_spec_pb2.PipelineSpec()) - client._override_caching_options(pipeline_spec, True) + client._override_caching_options( + pipeline_spec, True, cache_key='OVERRIDE_KEY') pipeline_obj = json_format.MessageToDict(pipeline_spec) self.assertTrue(pipeline_obj['root']['dag']['tasks'] ['hello-word']['cachingOptions']['enableCache']) + self.assertEqual( + pipeline_obj['root']['dag']['tasks']['hello-word'] + ['cachingOptions']['cacheKey'], 'OVERRIDE_KEY') self.assertTrue(pipeline_obj['root']['dag']['tasks']['to-lower'] ['cachingOptions']['enableCache']) + self.assertEqual( + pipeline_obj['root']['dag']['tasks']['to-lower'] + ['cachingOptions']['cacheKey'], 'OVERRIDE_KEY') class TestExtractPipelineYAML(parameterized.TestCase): diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 898187e36c3..cde6372279f 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -1001,6 +1001,32 @@ def my_pipeline(): self.assertTrue(caching_options['enableCache']) + def test_compile_pipeline_with_cache_key(self): + """Test pipeline compilation with cache key.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name='tiny-pipeline') + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(True, cache_key='MY_KEY') + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec['cachingOptions'] + + self.assertTrue(caching_options['enableCache']) + self.assertEqual(caching_options['cacheKey'], 'MY_KEY') + def test_compile_pipeline_with_caching_disabled(self): """Test pipeline compilation with caching disabled.""" diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 3061faab5e4..28dfd049d77 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -124,6 +124,7 @@ def build_task_spec_for_task( utils.sanitize_component_name(task.name)) pipeline_task_spec.caching_options.enable_cache = ( task._task_spec.enable_caching) + pipeline_task_spec.caching_options.cache_key = (task._task_spec.cache_key) if task._task_spec.retry_policy is not None: pipeline_task_spec.retry_policy.CopyFrom( diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index b41a14ef82d..fc43d5cf321 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -99,6 +99,7 @@ def __init__( args: Dict[str, Any], execute_locally: bool = False, execution_caching_default: bool = True, + execution_cache_key: str = '', ) -> None: """Initilizes a PipelineTask instance.""" # import within __init__ to avoid circular import @@ -131,7 +132,8 @@ def __init__( inputs=dict(args.items()), dependent_tasks=[], component_ref=component_spec.name, - enable_caching=execution_caching_default) + enable_caching=execution_caching_default, + cache_key=execution_cache_key) self._run_after: List[str] = [] self.importer_spec = None @@ -301,16 +303,20 @@ def _extract_container_spec_and_convert_placeholders( return container_spec @block_if_final() - def set_caching_options(self, enable_caching: bool) -> 'PipelineTask': + def set_caching_options(self, + enable_caching: bool, + cache_key: str = '') -> 'PipelineTask': """Sets caching options for the task. Args: enable_caching: Whether to enable caching. + cache_key: Customized cache key for this task. Returns: Self return to allow chained setting calls. """ self._task_spec.enable_caching = enable_caching + self._task_spec.cache_key = cache_key return self def _ensure_container_spec_exists(self) -> None: diff --git a/sdk/python/kfp/dsl/structures.py b/sdk/python/kfp/dsl/structures.py index 5a73d93b35c..007016ac82f 100644 --- a/sdk/python/kfp/dsl/structures.py +++ b/sdk/python/kfp/dsl/structures.py @@ -409,6 +409,8 @@ class TaskSpec: from the [items][] collection. enable_caching (optional): whether or not to enable caching for the task. Default is True. + cache_key (optional): Customized cache key for this task. + Default is empty string. display_name (optional): the display name of the task. If not specified, the task name will be used as the display name. """ @@ -421,6 +423,7 @@ class TaskSpec: iterator_items: Optional[Any] = None iterator_item_input: Optional[str] = None enable_caching: bool = True + cache_key: str = '' display_name: Optional[str] = None retry_policy: Optional[RetryPolicy] = None