Skip to content

Commit

Permalink
Add test to verify setting of SemaphoreKey and MutexName fields in DSL
Browse files Browse the repository at this point in the history
Signed-off-by: ddalvi <[email protected]>
  • Loading branch information
DharmitD committed Dec 16, 2024
1 parent 6af6032 commit b659f37
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
60 changes: 60 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3847,6 +3847,66 @@ def outer():
foo_platform_set_bar_feature(task, 12)


class TestPipelineSemaphoreMutex(unittest.TestCase):

def test_pipeline_with_semaphore(self):
"""Test that pipeline config correctly sets the semaphore key."""
config = PipelineConfig()
config.set_semaphore_key('semaphore')

@dsl.pipeline(pipeline_config=config)
def my_pipeline():
task = comp()

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml
)

with open(output_yaml, 'r') as f:
pipeline_docs = list(yaml.safe_load_all(f))

pipeline_spec = None
for doc in pipeline_docs:
if 'platforms' in doc:
pipeline_spec = doc
break

if pipeline_spec:
kubernetes_spec = pipeline_spec['platforms']['kubernetes']['pipelineConfig']
assert kubernetes_spec['semaphoreKey'] == 'semaphore'
self.assertNotIn('mutexName', kubernetes_spec)

def test_pipeline_with_mutex(self):
"""Test that pipeline config correctly sets the mutex name."""
config = PipelineConfig()
config.set_mutex_name('mutex')

@dsl.pipeline(pipeline_config=config)
def my_pipeline():
task = comp()

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml
)

with open(output_yaml, 'r') as f:
pipeline_docs = list(yaml.safe_load_all(f))

pipeline_spec = None
for doc in pipeline_docs:
if 'platforms' in doc:
pipeline_spec = doc
break

if pipeline_spec:
kubernetes_spec = pipeline_spec['platforms']['kubernetes']['pipelineConfig']
assert kubernetes_spec['mutexName'] == 'mutex'


class ExtractInputOutputDescription(unittest.TestCase):

def test_no_descriptions(self):
Expand Down
10 changes: 9 additions & 1 deletion sdk/python/kfp/dsl/pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ def __init__(self):
def set_semaphore_key(self, semaphore_key: str):
"""Set the name of the semaphore to control pipeline concurrency.
The semaphore is configured via a ConfigMap. By default, the ConfigMap is
named "semaphore-config", but this name can be specified through the APIServer
deployment manifests using an environment variable named SEMAPHORE_CONFIGMAP_NAME.
If the environment variable is not specified, the default name "semaphore-config"
is used. The semaphore key is provided through the pipeline configuration.
If a pipeline has a semaphore, the backend maps the semaphore to the ConfigMap
using the key provided by the user.
Args:
semaphore_key (str): Name of the semaphore.
semaphore_key (str): The key used to map to the ConfigMap.
"""
self.semaphore_key = semaphore_key.strip()

Expand Down

0 comments on commit b659f37

Please sign in to comment.