Skip to content

Commit

Permalink
simplify the code
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffnvidia committed Aug 8, 2024
1 parent c0400b6 commit 6d680f7
Showing 1 changed file with 33 additions and 64 deletions.
97 changes: 33 additions & 64 deletions src/cloudai/systems/slurm/slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,37 +335,20 @@ def get_group_node_names(self, partition_name: str, group_name: str) -> List[str
"""
return [node.name for node in self.get_group_nodes(partition_name, group_name)]

def get_available_nodes_from_group_with_reservation(
self, partition_name: str, group_name: str, number_of_nodes: int
) -> List[SlurmNode]:
def _get_reserved_nodes(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]:
"""
Retrieve a specific number of potentially available nodes from a group within a partition.
Prioritizes nodes by their current state, preferring idle nodes first, then completing nodes, and finally
allocated nodes, while excluding nodes that are down and allocated nodes to the current user.
If a reservation was queried, then cloudAI will take from the reserved nodes according to the reservation name.
Return the reserved nodes corresponding to the given reservation name.
Args:
partition_name (str): The name of the partition.
group_name (str): The name of the group.
number_of_nodes (int): The number of nodes to retrieve.
Returns:
List[SlurmNode]: Objects that are potentially available for use.
Raises:
ValueError: If the partition or group is not found, or if the requested number of nodes exceeds the
available nodes.
Dict[str, str]: Names of nodes within the specified group and partition and reservation.
"""
if partition_name not in self.groups:
raise ValueError(f"Partition '{partition_name}' not found.")
if group_name not in self.groups[partition_name]:
raise ValueError(f"Group '{group_name}' not found in partition '{partition_name}'.")

self.update_node_states()

# Group nodes by their states
reservation_key = "--reservation "
if not self.extra_srun_args:
raise ValueError("extra_srun_args shouldn't be empty")
reservation_name = self.extra_srun_args.split(reservation_key, 1)[1].split(" ", 1)[0]
reservation_output = self.get_reservation(reservation_name)
reserved_nodes = self.parse_reservation_output(reservation_output, reservation_name)
Expand All @@ -375,30 +358,30 @@ def get_available_nodes_from_group_with_reservation(
for node in self.groups[partition_name][group_name]:
if node.state in grouped_nodes and node.name in reserved_nodes:
grouped_nodes[node.state].append(node)

# Allocate nodes based on priority: idle, then completing, then allocated
allocated_nodes = []
for state in grouped_nodes:
while grouped_nodes[state] and len(allocated_nodes) < number_of_nodes:
allocated_nodes.append(grouped_nodes[state].pop(0))

if len(allocated_nodes) < number_of_nodes:
raise ValueError(
"Requested number of nodes ({}) exceeds the number of " "available nodes in group '{}'.".format(
number_of_nodes, group_name
)
)
return grouped_nodes

# Log allocation details
logging.info(
"Allocated nodes from group '{}' in partition '{}': {}".format(
group_name,
partition_name,
[node.name for node in allocated_nodes],
)
)
def _get_available_nodes(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]:
    """
    Return the group's nodes that are idle or completing, bucketed by state.

    Args:
        partition_name (str): The name of the partition.
        group_name (str): The name of the group.

    Returns:
        Dict[SlurmNodeState, List[SlurmNode]]: Nodes of the group keyed by their state. Only the IDLE and
            COMPLETING keys are present (each list may be empty); nodes in any other state are left out.
    """
    # IDLE is inserted before COMPLETING on purpose: dicts keep insertion order, so a caller that
    # walks the result in order sees idle nodes first (presumably the intended allocation priority).
    grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]] = {
        SlurmNodeState.IDLE: [],
        SlurmNodeState.COMPLETING: [],
    }

    # No validation happens here: self.groups is indexed directly with the given names.
    for node in self.groups[partition_name][group_name]:
        if node.state in grouped_nodes:
            grouped_nodes[node.state].append(node)

    return grouped_nodes

def get_available_nodes_from_group(
self, partition_name: str, group_name: str, number_of_nodes: int
Expand Down Expand Up @@ -429,19 +412,10 @@ def get_available_nodes_from_group(

self.update_node_states()

grouped_nodes = {
SlurmNodeState.IDLE: [],
SlurmNodeState.COMPLETING: [],
SlurmNodeState.ALLOCATED: [],
}

for node in self.groups[partition_name][group_name]:
if node.state in grouped_nodes:
# Exclude nodes allocated to the current user
if node.state == SlurmNodeState.ALLOCATED and node.user == current_user:
continue
if node.state in grouped_nodes:
grouped_nodes[node.state].append(node)
if self.extra_srun_args and "reservation" in self.extra_srun_args:
grouped_nodes = self._get_reserved_nodes(partition_name, group_name)
else:
grouped_nodes = self._get_available_nodes(partition_name, group_name)

# Allocate nodes based on priority: idle, then completing, then allocated
allocated_nodes = []
Expand Down Expand Up @@ -714,8 +688,8 @@ def parse_reservation_output(self, reservation_output: str, reservation_name: st
if reservation_name in reservation:
nodes = reservation.split("Nodes=")[1].split(" ")[0]
node_list = self.parse_node_list(nodes)

return node_list
return node_list
raise ValueError("wrong reservation specified \n. Reservation should be in the form \"--reservation reservation_name\"")

def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
"""
Expand Down Expand Up @@ -803,12 +777,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
raise ValueError("Format should be partition:group:num_nodes")
partition_name, group_name, num_nodes_str = parts
num_nodes = int(num_nodes_str)
if self.extra_srun_args and "reservation" in self.extra_srun_args:
group_nodes = self.get_available_nodes_from_group_with_reservation(
partition_name, group_name, num_nodes
)
else:
group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes)
group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes)
parsed_nodes += [node.name for node in group_nodes]
else:
# Handle both individual node names and ranges
Expand Down

0 comments on commit 6d680f7

Please sign in to comment.