Integration of Agent/Environment with runner #332

Open. Wants to merge 58 commits into base: main.
Changes from 53 commits
Commits
d9e854a
Fix ABC usage issues
amaslenn Dec 19, 2024
34cda1f
Add test hooks to USER_GUIDE.md (#322)
TaekyungHeo Jan 6, 2025
0cfd5e6
remove the default condition check
srivatsankrishnan Jan 8, 2025
baa2925
preserves lists in cmd_args as is (for pydantic validation)
srivatsankrishnan Jan 8, 2025
1a26f7f
propate cmd_args type to all places in cloudAI for pyright errors
srivatsankrishnan Jan 8, 2025
aa8a38c
Add ClassVar to remove pydantic annonation error
srivatsankrishnan Jan 8, 2025
c95fefa
fix pytest
srivatsankrishnan Jan 8, 2025
cd8fbbf
Do not check image accessibility using "local" enroot
amaslenn Jan 6, 2025
3f358ea
Do not require enroot binary on head node
amaslenn Jan 6, 2025
41694b7
Pass SlurmSystem into DockerImageCacheManager
amaslenn Jan 6, 2025
1bcb661
Specify account while caching images
amaslenn Jan 6, 2025
2fb9b92
Reduce noise in CLI output
amaslenn Jan 6, 2025
a1ed291
Make ruff happy
amaslenn Jan 6, 2025
f85961b
more unit tests for parser with Grok Test definition + pydantic of on…
srivatsankrishnan Jan 8, 2025
aa2d0d9
ruffing
srivatsankrishnan Jan 8, 2025
3e0864d
Add more test to have ranges for FDL flags.
srivatsankrishnan Jan 8, 2025
e2fe3aa
More test for XLA flags as list other fixed + fixing typing in Grok/J…
srivatsankrishnan Jan 8, 2025
8367cb7
All static values (benchmarking scenarios in CloudAI)
srivatsankrishnan Jan 8, 2025
a7b1633
negative tests with various types in the list
srivatsankrishnan Jan 8, 2025
695a5d0
remove the unit tests
srivatsankrishnan Jan 8, 2025
44a9fda
remove instance check (assuming model_dump() never fails)
srivatsankrishnan Jan 8, 2025
6124236
fix the typing for slurm_args
srivatsankrishnan Jan 8, 2025
f229433
removing the old _parser_cmd method that is not used.
srivatsankrishnan Jan 8, 2025
fd83f65
test and test scenario for environment configuration
srivatsankrishnan Jan 9, 2025
118638a
Add configurable gym environment from test run object
srivatsankrishnan Jan 9, 2025
532c9ad
Configurable cloudaigym environment and tests
srivatsankrishnan Jan 9, 2025
db516e3
reorg the environment under _core directory
srivatsankrishnan Jan 10, 2025
6bfe9e2
fix pyright and pytest issues
srivatsankrishnan Jan 10, 2025
4fc907e
Remove conf/common/test/chakra_replay.toml (#328)
TaekyungHeo Jan 9, 2025
c0f8f6c
checkpoint policy serializer for list/ranges
srivatsankrishnan Jan 10, 2025
dae998f
Add farma gym to requirements
srivatsankrishnan Jan 10, 2025
6487344
vulture check
srivatsankrishnan Jan 10, 2025
cdfb826
fix pyproject.toml
srivatsankrishnan Jan 10, 2025
76f1f2f
Fix the test package errors
srivatsankrishnan Jan 10, 2025
897c4b9
taplo
srivatsankrishnan Jan 10, 2025
bcd45da
port agent interface and grid search
srivatsankrishnan Jan 10, 2025
d31412c
Ignore vulture for grid search
srivatsankrishnan Jan 10, 2025
3ad97cd
vulture and ruff fixes
srivatsankrishnan Jan 10, 2025
bf6c96a
remove comments
srivatsankrishnan Jan 10, 2025
7e10b22
removed the fixed value
srivatsankrishnan Jan 10, 2025
dff5c64
Merge branch 'main' into config-agent
srivatsankrishnan Jan 10, 2025
275bf39
remove the setter
srivatsankrishnan Jan 11, 2025
63efa41
Not introduce range as of now. Stick to static lists
srivatsankrishnan Jan 11, 2025
d945201
agent environment intergation with runner
srivatsankrishnan Jan 11, 2025
1be4398
more fixes
srivatsankrishnan Jan 11, 2025
4df4ab9
Remove Farma gym dependies for more control over types + other fixes …
srivatsankrishnan Jan 11, 2025
e6905f7
vulture fix
srivatsankrishnan Jan 11, 2025
177694f
remove farma gym dependencies + update the pytest for cloudai_gym
srivatsankrishnan Jan 11, 2025
b10dbfb
remove farma gym from pyproject
srivatsankrishnan Jan 11, 2025
15be693
fix the copyright headers checks
srivatsankrishnan Jan 11, 2025
d5d1e14
use iterators to avoid indexing errors.
srivatsankrishnan Jan 11, 2025
96ab055
helper method for manipulating the TestRun object directly
srivatsankrishnan Jan 11, 2025
aa9dcea
Merge branch 'main' into agent-env-integration
srivatsankrishnan Jan 13, 2025
e6f833c
Removing the agent's configuration and instead query from the environ…
srivatsankrishnan Jan 14, 2025
e211b30
Merge branch 'main' into agent-env-integration
srivatsankrishnan Jan 14, 2025
d9b77e5
Merge branch 'agent-env-integration' of https://github.com/srivatsank…
srivatsankrishnan Jan 14, 2025
24fe925
Fix the testing code
srivatsankrishnan Jan 14, 2025
5eb42d0
fix the configurator structure
srivatsankrishnan Jan 14, 2025
34 changes: 34 additions & 0 deletions conf/common/test/dse_jaxtoolbox_grok.toml
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name = "dse_jaxtoolbox_grok"
description = "DSE JaxToolbox Grok"
test_template_name = "JaxToolboxGrok"

[cmd_args]
docker_image_url = "https://docker/url"
[cmd_args.fdl]
num_gpus = [1, 8, 16]
checkpoint_policy = ["save_iteration_input", "save_none"]
num_groups = "16"
use_fp8 = "1"
use_repeated_layer = "True"

[extra_env_vars]
"ENABLE_TE" = "0"
"NVTE_FUSED_ATTN" = "1"
"COMBINE_THRESHOLD" = "301989888"
"XLA_PYTHON_CLIENT_MEM_FRACTION" = "0.9"
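Review note: in this test definition the list-valued entries under `[cmd_args.fdl]` (`num_gpus`, `checkpoint_policy`) are the swept parameters, while the scalar entries stay fixed. The sketch below estimates the size of the resulting search space. It assumes the TOML has already been parsed into a plain dict; the dict literal is hand-written to mirror the file, and `sweep_size` is a hypothetical helper, not part of CloudAI.

```python
from typing import Any, Dict

# Hand-written mirror of [cmd_args] in dse_jaxtoolbox_grok.toml (assumption: plain dict form).
cmd_args: Dict[str, Any] = {
    "docker_image_url": "https://docker/url",
    "fdl": {
        "num_gpus": [1, 8, 16],
        "checkpoint_policy": ["save_iteration_input", "save_none"],
        "num_groups": "16",
        "use_fp8": "1",
        "use_repeated_layer": "True",
    },
}


def sweep_size(args: Dict[str, Any]) -> int:
    """Hypothetical helper: multiply the lengths of all list-valued entries."""
    size = 1
    for value in args.values():
        if isinstance(value, list):
            size *= len(value)
        elif isinstance(value, dict):
            size *= sweep_size(value)  # recurse into nested tables such as [cmd_args.fdl]
    return size


print(sweep_size(cmd_args))  # 3 num_gpus values x 2 checkpoint policies
```

With the values in this file, a full grid search would therefore launch six configurations.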
22 changes: 22 additions & 0 deletions conf/common/test_scenario/dse_jaxtoolbox.toml
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name = "dse_jaxtoolbox_grok"

[[Tests]]
id = "Tests.1"
test_name = "dse_jaxtoolbox_grok"
num_nodes = "1"
102 changes: 102 additions & 0 deletions src/cloudai/_core/configurator/agents/base_agent.py
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict

from cloudai._core.test_scenario import TestRun


class BaseAgent(ABC):
    """
    Base class for all agents in the CloudAI framework.

    Provides a unified interface and parameter management for action spaces.
    Automatically infers parameter types from TestRun's cmd_args.
    """

    def __init__(self, test_run: TestRun):
        """
        Initialize the agent with the TestRun object.

        Args:
            test_run (TestRun): The TestRun object containing cmd_args and test state.
        """
        self.test_run = test_run
        self.action_space = self.extract_action_space(test_run.test.cmd_args)

    def extract_action_space(self, cmd_args: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract the action space from cmd_args by inferring parameter types.

        Args:
            cmd_args (Dict[str, Any]): The command arguments from TestRun.

        Returns:
            Dict[str, Any]: Action space keyed by dotted parameter name, with inferred parameter types.
        """
        action_space = {}

        for key, value in cmd_args.items():
            self._process_value(action_space, key, value)

        return action_space

    def _process_value(self, action_space: Dict[str, Any], key: str, value: Any) -> None:
        if isinstance(value, list):
            self._process_list(action_space, key, value)
        elif isinstance(value, dict):
            for sub_key, sub_value in value.items():
                full_key = f"{key}.{sub_key}"
                self._process_value(action_space, full_key, sub_value)

    def _process_list(self, action_space: Dict[str, Any], key: str, value: list) -> None:
        if all(isinstance(v, int) for v in value):
            action_space[key] = {"type": "int", "values": value}
        elif all(isinstance(v, float) for v in value):
            action_space[key] = {"type": "float", "values": value}
        else:
            action_space[key] = {"type": "categorical", "categories": value}

    @abstractmethod
    def configure(self, config: Dict[str, Any]) -> None:
        """
        Configure the agent with additional settings.

        Args:
            config (Dict[str, Any]): Configuration settings for the agent.
        """
        pass

    @abstractmethod
    def select_action(self) -> Dict[str, Any]:
        """
        Select an action from the action space.

        Returns:
            Dict[str, Any]: A dictionary mapping action keys to selected values.
        """
        pass

    @abstractmethod
    def update_policy(self, _feedback: Dict[str, Any]) -> None:
        """
        Update the agent state based on feedback from the environment.

        Args:
            _feedback (Dict[str, Any]): Feedback information from the environment.
        """
        pass
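Review note: `extract_action_space` keeps only list-valued entries, flattens nested tables into dotted keys, and skips scalars. The standalone sketch below restates that logic as a free function so it can run without CloudAI installed. The function is an illustrative rewrite, not the PR's code, and the `flags` key is a hypothetical input added to show how booleans are classified.

```python
from typing import Any, Dict


def extract_action_space(cmd_args: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Illustrative restatement of BaseAgent.extract_action_space (not the PR's code)."""
    space: Dict[str, Any] = {}
    for key, value in cmd_args.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, list):
            if all(isinstance(v, int) for v in value):
                space[full_key] = {"type": "int", "values": value}
            elif all(isinstance(v, float) for v in value):
                space[full_key] = {"type": "float", "values": value}
            else:
                space[full_key] = {"type": "categorical", "categories": value}
        elif isinstance(value, dict):
            # Nested tables contribute dotted keys such as "fdl.num_gpus".
            space.update(extract_action_space(value, prefix=f"{full_key}."))
        # Scalars are fixed settings and never enter the action space.
    return space


space = extract_action_space(
    {
        "docker_image_url": "https://docker/url",
        "fdl": {
            "num_gpus": [1, 8, 16],
            "checkpoint_policy": ["save_iteration_input", "save_none"],
            "use_fp8": "1",
        },
        "flags": [True, False],  # hypothetical key, not in the PR's TOML
    }
)
for name, spec in space.items():
    print(name, spec["type"])
```

Two edge cases may deserve a test in the PR: `bool` is a subclass of `int` in Python, so a list of booleans is classified as `"int"`, and an empty list also lands in the `"int"` branch because `all()` over an empty iterable is `True`. A list that mixes ints and floats, such as `[1, 2.5]`, falls through to `"categorical"`.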
91 changes: 91 additions & 0 deletions src/cloudai/_core/configurator/agents/grid_search.py
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Any, Dict, List

from cloudai._core.configurator.agents.base_agent import BaseAgent
from cloudai._core.test_scenario import TestRun


class GridSearchAgent(BaseAgent):
    """
    Agent implementing a grid search over the action space.

    Iterates through all possible parameter combinations.
    """

    def __init__(self, test_run: TestRun):
        """
        Initialize the GridSearchAgent with the TestRun object.

        Args:
            test_run (TestRun): The TestRun object containing cmd_args and test state.
        """
        super().__init__(test_run)
        self.action_combinations = []
        self.index = 0

    def configure(self, config: Dict[str, Any]) -> None:
        """
        Configure the grid search by precomputing all parameter combinations.

        Args:
            config (Dict[str, Any]): Additional configuration settings (currently unused).
        """
        parameter_values = []
        for param in self.action_space.values():
            if param["type"] in ("int", "float"):
                parameter_values.append(param["values"])
            elif param["type"] == "categorical":
                parameter_values.append(param["categories"])

        self.action_combinations = list(itertools.product(*parameter_values))
        self.index = 0

    def get_all_combinations(self) -> List[Dict[str, Any]]:
        """
        Get all possible combinations of the action space parameters.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries, each representing a unique combination of parameters.
        """
        keys = list(self.action_space.keys())
        return [dict(zip(keys, combination)) for combination in self.action_combinations]

    def select_action(self) -> Dict[str, Any]:
        """
        Select the next action from the grid.

        Returns:
            Dict[str, Any]: A dictionary mapping action keys to selected values.
        """
        if self.index >= len(self.action_combinations):
            raise StopIteration("Grid search completed.")

        action = dict(zip(self.action_space.keys(), self.action_combinations[self.index]))
        self.index += 1
        return action

    def update_policy(self, _feedback: Dict[str, Any]) -> None:
        """
        Update the agent based on feedback (not used in grid search).

        Args:
            _feedback (Dict[str, Any]): Feedback information from the environment.
        """
        # Grid search ignores feedback; its only state is the current index.
        pass
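Review note: `configure` materializes the full Cartesian product up front and `select_action` walks it by index, raising `StopIteration` when the grid is exhausted. The sketch below shows the same enumeration as a generator, using a hand-written action space in the shape `BaseAgent` produces. `iter_grid` is a hypothetical name for illustration, not something this PR defines.

```python
import itertools
from typing import Any, Dict, Iterator

# Hand-written action space in the shape produced by BaseAgent.extract_action_space.
action_space: Dict[str, Dict[str, Any]] = {
    "fdl.num_gpus": {"type": "int", "values": [1, 8, 16]},
    "fdl.checkpoint_policy": {
        "type": "categorical",
        "categories": ["save_iteration_input", "save_none"],
    },
}


def iter_grid(space: Dict[str, Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Hypothetical generator form of the grid walk: yield one action dict per combination."""
    keys = list(space)
    value_lists = [
        spec["values"] if spec["type"] in ("int", "float") else spec["categories"]
        for spec in space.values()
    ]
    # itertools.product varies the last parameter fastest.
    for combination in itertools.product(*value_lists):
        yield dict(zip(keys, combination))


actions = list(iter_grid(action_space))
for action in actions:
    print(action)
```

One caveat with the index-based version: since PEP 479, a `StopIteration` that escapes inside a generator body is converted to `RuntimeError`, so callers that invoke `select_action` from within a generator should catch the exception explicitly. A generator like the one above, or a dedicated exception type, would sidestep that.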
102 changes: 102 additions & 0 deletions src/cloudai/_core/configurator/base_gym.py
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class BaseGym(ABC):
    """Base class for CloudAI Gym environments."""

    def __init__(self):
        """Initialize the CloudAIGym environment."""
        self.action_space = self.define_action_space()
        self.observation_space = self.define_observation_space()

    @abstractmethod
    def define_action_space(self) -> Dict[str, Any]:
        """
        Define the action space for the environment.

        Returns:
            Dict[str, Any]: The action space.
        """
        pass

    @abstractmethod
    def define_observation_space(self) -> list:
        """
        Define the observation space for the environment.

        Returns:
            list: The observation space.
        """
        pass

    @abstractmethod
    def reset(
        self, seed: Optional[int] = None, _options: Optional[dict[str, Any]] = None
    ) -> Tuple[list, dict[str, Any]]:
        """
        Reset the environment.

        Args:
            seed (Optional[int]): Seed for the environment's random number generator.
            _options (Optional[dict]): Additional options for reset.

        Returns:
            Tuple: A tuple containing:
                - observation (list): Initial observation.
                - info (dict): Additional info for debugging.
        """
        pass

    @abstractmethod
    def step(self, action: Any) -> Tuple[list, float, bool, dict]:
        """
        Execute one step in the environment.

        Args:
            action (Any): Action chosen by the agent.

        Returns:
            Tuple: A tuple containing:
                - observation (list): Updated system state.
                - reward (float): Reward for the action taken.
                - done (bool): Whether the episode is done.
                - info (dict): Additional info for debugging.
        """
        pass

    @abstractmethod
    def render(self, mode: str = "human"):
        """
        Render the current state of the environment.

        Args:
            mode (str): The mode to render with. Default is "human".
        """
        pass

    @abstractmethod
    def seed(self, seed: Optional[int] = None):
        """
        Set the seed for the environment's random number generator.

        Args:
            seed (Optional[int]): Seed for the environment's random number generator.
        """
        pass
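Review note: the `reset`/`step` contract above can be exercised with a small agent-environment loop. The sketch below is self-contained and does not import CloudAI: `ToyGym` is a hypothetical class that only mimics the `BaseGym` method signatures, and its reward (equal to the GPU count) is made up for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyGym:
    """Hypothetical environment mimicking the BaseGym signatures; not CloudAI code."""

    def __init__(self) -> None:
        self.action_space: Dict[str, Any] = {"num_gpus": [1, 8, 16]}
        self.observation_space: list = [0.0]

    def reset(self, seed: Optional[int] = None, _options: Optional[dict] = None) -> Tuple[list, dict]:
        # Return the initial observation and an empty info dict.
        return [0.0], {}

    def step(self, action: Dict[str, Any]) -> Tuple[list, float, bool, dict]:
        # Made-up reward: pretend the score grows with the number of GPUs.
        reward = float(action["num_gpus"])
        return [reward], reward, False, {"action": action}


env = ToyGym()
observation, info = env.reset(seed=0)

results: List[Tuple[Dict[str, Any], float]] = []
for num_gpus in env.action_space["num_gpus"]:
    action = {"num_gpus": num_gpus}
    observation, reward, done, info = env.step(action)
    results.append((action, reward))

# Pick the action with the highest reward seen during the sweep.
best_action, best_reward = max(results, key=lambda item: item[1])
print(best_action, best_reward)
```

Note that `step` here returns a 4-tuple `(observation, reward, done, info)`. Gymnasium's current `Env.step` returns five values, splitting `done` into `terminated` and `truncated`, so agents written against Gymnasium would need a small adapter if they are ever pointed at a `BaseGym` subclass.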