Skip to content

Commit

Permalink
feat: Introduce cache_key to sdk
Browse files Browse the repository at this point in the history
Signed-off-by: Ze Mao <[email protected]>
  • Loading branch information
Ze Mao authored and Ubuntu committed Dec 13, 2024
1 parent 0eb67e1 commit 6ca07d1
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 4 deletions.
23 changes: 22 additions & 1 deletion sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def run_pipeline(
version_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1Run:
"""Runs a specified pipeline.
Expand All @@ -709,6 +710,8 @@ def run_pipeline(
is ``True`` for all tasks by default. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
Expand All @@ -721,6 +724,7 @@ def run_pipeline(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -806,6 +810,7 @@ def create_recurring_run(
enabled: bool = True,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1RecurringRun:
"""Creates a recurring run.
Expand Down Expand Up @@ -850,6 +855,8 @@ def create_recurring_run(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account this recurring run uses.
Returns:
Expand All @@ -862,6 +869,7 @@ def create_recurring_run(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -908,6 +916,7 @@ def _create_job_config(
pipeline_id: Optional[str],
version_id: Optional[str],
enable_caching: Optional[bool],
cache_key: Optional[str],
pipeline_root: Optional[str],
) -> _JobConfig:
"""Creates a JobConfig with spec and resource_references.
Expand All @@ -928,6 +937,8 @@ def _create_job_config(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
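cache_key (optional): Customized cache key shared by the pipeline's tasks.
If set, the cache_key is written to every task in the pipeline's root DAG. It is only applied when enable_caching is not None, because the caching override is skipped otherwise.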
pipeline_root: Root path of the pipeline outputs.
Returns:
Expand Down Expand Up @@ -956,7 +967,7 @@ def _create_job_config(
# settings.
if enable_caching is not None:
_override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
enable_caching, cache_key)
pipeline_spec = pipeline_doc.to_dict()

pipeline_version_reference = None
Expand All @@ -983,6 +994,7 @@ def create_run_from_pipeline_func(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1004,6 +1016,8 @@ def create_run_from_pipeline_func(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1032,6 +1046,7 @@ def create_run_from_pipeline_func(
namespace=namespace,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)

Expand All @@ -1044,6 +1059,7 @@ def create_run_from_pipeline_package(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1065,6 +1081,8 @@ def create_run_from_pipeline_package(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1105,6 +1123,7 @@ def create_run_from_pipeline_package(
params=arguments,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)
return RunPipelineResult(self, run_info)
Expand Down Expand Up @@ -1681,6 +1700,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
def _override_caching_options(
    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
    enable_caching: bool,
    cache_key: str = '',
) -> None:
    """Overrides caching options on every task of the pipeline's root DAG.

    The pipeline spec is modified in place.

    Args:
        pipeline_spec: The pipeline spec whose root DAG tasks are updated.
        enable_caching: Value written to ``caching_options.enable_cache`` of
            each root DAG task.
        cache_key: Customized cache key written to
            ``caching_options.cache_key`` of each root DAG task. When empty
            (the default) or ``None``, each task keeps the cache key it
            already has.
    """
    for task_spec in pipeline_spec.root.dag.tasks.values():
        caching_options = task_spec.caching_options
        caching_options.enable_cache = enable_caching
        # Only override when a key was actually supplied: writing the empty
        # default would erase per-task keys set at compile time, and callers
        # type this argument as Optional[str], so None must not be assigned.
        if cache_key:
            caching_options.cache_key = cache_key
9 changes: 8 additions & 1 deletion sdk/python/kfp/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,19 @@ def pipeline_with_two_component(text: str = 'hi there'):
pipeline_obj = yaml.safe_load(f)
pipeline_spec = json_format.ParseDict(
pipeline_obj, pipeline_spec_pb2.PipelineSpec())
client._override_caching_options(pipeline_spec, True)
client._override_caching_options(
pipeline_spec, True, cache_key='OVERRIDE_KEY')
pipeline_obj = json_format.MessageToDict(pipeline_spec)
self.assertTrue(pipeline_obj['root']['dag']['tasks']
['hello-word']['cachingOptions']['enableCache'])
self.assertEqual(
pipeline_obj['root']['dag']['tasks']['hello-word']
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')
self.assertTrue(pipeline_obj['root']['dag']['tasks']['to-lower']
['cachingOptions']['enableCache'])
self.assertEqual(
pipeline_obj['root']['dag']['tasks']['to-lower']
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')


class TestExtractPipelineYAML(parameterized.TestCase):
Expand Down
26 changes: 26 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,32 @@ def my_pipeline():

self.assertTrue(caching_options['enableCache'])

def test_compile_pipeline_with_cache_key(self):
    """Checks that a task-level cache key is emitted in the compiled YAML."""

    @dsl.component
    def my_component():
        pass

    @dsl.pipeline(name='tiny-pipeline')
    def my_pipeline():
        task = my_component()
        task.set_caching_options(True, cache_key='MY_KEY')

    with tempfile.TemporaryDirectory() as workdir:
        package_path = os.path.join(workdir, 'pipeline.yaml')
        compiler.Compiler().compile(
            pipeline_func=my_pipeline, package_path=package_path)

        # Read the compiled package back as a plain dict.
        with open(package_path, 'r') as spec_file:
            compiled_spec = yaml.safe_load(spec_file)

        root_tasks = compiled_spec['root']['dag']['tasks']
        options = root_tasks['my-component']['cachingOptions']

        self.assertTrue(options['enableCache'])
        self.assertEqual(options['cacheKey'], 'MY_KEY')

def test_compile_pipeline_with_caching_disabled(self):
"""Test pipeline compilation with caching disabled."""

Expand Down
1 change: 1 addition & 0 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def build_task_spec_for_task(
utils.sanitize_component_name(task.name))
pipeline_task_spec.caching_options.enable_cache = (
task._task_spec.enable_caching)
pipeline_task_spec.caching_options.cache_key = (task._task_spec.cache_key)

if task._task_spec.retry_policy is not None:
pipeline_task_spec.retry_policy.CopyFrom(
Expand Down
10 changes: 8 additions & 2 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
args: Dict[str, Any],
execute_locally: bool = False,
execution_caching_default: bool = True,
execution_cache_key: str = '',
) -> None:
"""Initializes a PipelineTask instance."""
# import within __init__ to avoid circular import
Expand Down Expand Up @@ -131,7 +132,8 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=execution_caching_default)
enable_caching=execution_caching_default,
cache_key=execution_cache_key)
self._run_after: List[str] = []

self.importer_spec = None
Expand Down Expand Up @@ -301,16 +303,20 @@ def _extract_container_spec_and_convert_placeholders(
return container_spec

@block_if_final()
def set_caching_options(self, enable_caching: bool) -> 'PipelineTask':
def set_caching_options(self,
enable_caching: bool,
cache_key: str = '') -> 'PipelineTask':
"""Sets caching options for the task.
Args:
enable_caching: Whether to enable caching.
cache_key: Customized cache key for this task.
Returns:
Self return to allow chained setting calls.
"""
self._task_spec.enable_caching = enable_caching
self._task_spec.cache_key = cache_key
return self

def _ensure_container_spec_exists(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ class TaskSpec:
from the [items][] collection.
enable_caching (optional): whether or not to enable caching for the task.
Default is True.
cache_key (optional): Customized cache key for this task.
Default is empty string.
display_name (optional): the display name of the task. If not specified,
the task name will be used as the display name.
"""
Expand All @@ -421,6 +423,7 @@ class TaskSpec:
iterator_items: Optional[Any] = None
iterator_item_input: Optional[str] = None
enable_caching: bool = True
cache_key: str = ''
display_name: Optional[str] = None
retry_policy: Optional[RetryPolicy] = None

Expand Down

0 comments on commit 6ca07d1

Please sign in to comment.