Skip to content

Commit

Permalink
Merge branch 'main' into sys_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen authored Jul 12, 2024
2 parents 86e572e + e53cc17 commit e695127
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 69 deletions.
10 changes: 6 additions & 4 deletions nvflare/app_common/executors/in_process_client_api_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLMetaKey, ReturnCode
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.apis.workspace import Workspace
from nvflare.app_common.abstract.params_converter import ParamsConverter
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner
Expand Down Expand Up @@ -107,10 +108,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self._fl_ctx = fl_ctx
self._init_converter(fl_ctx)

workspace: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
custom_dir = workspace.get_app_custom_dir(job_id)
self._task_fn_wrapper = TaskScriptRunner(
site_name=fl_ctx.get_identity_name(),
script_path=self._task_script_path,
script_args=self._task_script_args,
custom_dir=custom_dir, script_path=self._task_script_path, script_args=self._task_script_args
)

self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
Expand Down
31 changes: 16 additions & 15 deletions nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@
class TaskScriptRunner:
logger = logging.getLogger(__name__)

def __init__(self, site_name: str, script_path: str, script_args: str = None, redirect_print_to_log=True):
def __init__(self, custom_dir: str, script_path: str, script_args: str = None, redirect_print_to_log=True):
"""Wrapper for function given function path and args
Args:
site_name (str): site name
custom_dir (str): site name
script_path (str): script file name, such as train.py
script_args (str, Optional): script arguments to pass in.
"""

self.redirect_print_to_log = redirect_print_to_log
self.event_manager = EventManager(DataBus())
self.script_args = script_args
self.site_name = site_name
self.custom_dir = custom_dir
self.logger = logging.getLogger(self.__class__.__name__)
self.script_path = script_path
self.script_full_path = self.get_script_full_path(self.site_name, self.script_path)
self.script_full_path = self.get_script_full_path(self.custom_dir, self.script_path)

def run(self):
"""Call the task_fn with any required arguments."""
Expand Down Expand Up @@ -77,7 +77,12 @@ def get_sys_argv(self):
args_list = [] if not self.script_args else self.script_args.split()
return [self.script_full_path] + args_list

def get_script_full_path(self, site_name, script_path) -> str:
def get_script_full_path(self, custom_dir, script_path) -> str:
if not custom_dir:
raise ValueError("custom_dir must be not empty")
if not script_path:
raise ValueError("script_path must be not empty")

target_file = None
script_filename = os.path.basename(script_path)
script_dirs = os.path.dirname(script_path)
Expand All @@ -87,18 +92,14 @@ def get_script_full_path(self, site_name, script_path) -> str:
raise ValueError(f"script_path='{script_path}' not found")
return script_path

for r, dirs, files in os.walk(os.getcwd()):
for r, dirs, files in os.walk(custom_dir):
for f in files:
absolute_path = os.path.join(r, f)
if absolute_path.endswith(script_path):
parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep)
if os.path.isdir(parent_dir):
path_components = parent_dir.split(os.path.sep)
if site_name in path_components:
target_file = absolute_path
break

if not site_name and not script_dirs and f == script_filename:
if absolute_path.endswith(os.sep + script_path):
target_file = absolute_path
break

if not custom_dir and not script_dirs and f == script_filename:
target_file = absolute_path
break

Expand Down
Loading

0 comments on commit e695127

Please sign in to comment.