Skip to content

Commit

Permalink
Fix the simulator worker sys path (#2561)
Browse files Browse the repository at this point in the history
* Fixed the simulator worker sys path.

* fixed the get_new_sys_path() logic, added in unit test.

* fixed isort.

* Changed the _get_new_sys_path() implementation.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
yhwen and YuanTingHsieh authored May 14, 2024
1 parent 4025b7c commit eca7e12
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
13 changes: 9 additions & 4 deletions nvflare/private/fed/app/simulator/simulator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,8 @@ def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name
if gpu:
command += " --gpu " + str(gpu)
new_env = os.environ.copy()
if not sys.path[0]:
new_env["PYTHONPATH"] = os.pathsep.join(sys.path[1:])
else:
new_env["PYTHONPATH"] = os.pathsep.join(sys.path)
new_env["PYTHONPATH"] = os.pathsep.join(self._get_new_sys_path())

_ = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env)

conn = self._create_connection(open_port, timeout=timeout)
Expand Down Expand Up @@ -696,6 +694,13 @@ def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name

return stop_run, next_client, end_run_client

def _get_new_sys_path(self):
new_sys_path = []
for i in range(0, len(sys.path) - 1):
if sys.path[i]:
new_sys_path.append(sys.path[i])
return new_sys_path

def _create_connection(self, open_port, timeout=60.0):
conn = None
start = time.time()
Expand Down
7 changes: 3 additions & 4 deletions nvflare/private/fed/app/simulator/simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def run(self, args, conn):

client = self._create_client(args, build_ctx, deploy_args)

app_root = get_simulator_app_root(args.workspace, client.client_name)
app_root = get_simulator_app_root(args.simulator_root, client.client_name)
app_custom_folder = os.path.join(app_root, "custom")
sys.path.append(app_custom_folder)
if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path:
sys.path.append(app_custom_folder)

self.create_client_engine(client, deploy_args)

Expand Down Expand Up @@ -235,8 +236,6 @@ def main(args):
log_file = os.path.join(args.workspace, WorkspaceConstants.LOG_FILE_NAME)
add_logfile_handler(log_file)

app_custom_folder = os.path.join(args.workspace, "custom")
sys.path.append(app_custom_folder)
os.chdir(args.workspace)
startup = os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME)
os.makedirs(startup, exist_ok=True)
Expand Down
40 changes: 38 additions & 2 deletions tests/unit_test/private/fed/app/simulator/simulator_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import shutil
import sys
import threading
import time
import uuid
from argparse import Namespace
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch

import pytest

from nvflare.apis.fl_constant import FLContextKey, MachineStatus, WorkspaceConstants
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorClientRunner, SimulatorRunner
from nvflare.private.fed.utils.fed_utils import split_gpus


Expand Down Expand Up @@ -155,3 +157,37 @@ def test_start_server_app(self, mock_deploy, mock_admin, mock_register, mock_cel

runner.server.logger = Mock()
runner.server.engine.asked_to_stop = True

def test_get_new_sys_path_with_empty(self):
args = Namespace(workspace="/tmp")
args.set = []
runner = SimulatorClientRunner(args, [], None, None, None)
old_sys_path = copy.deepcopy(sys.path)
sys.path.insert(0, "")
sys.path.append("/temp/test")
new_sys_path = runner._get_new_sys_path()
assert old_sys_path == new_sys_path
sys.path = old_sys_path

def test_get_new_sys_path_with_multiple_empty(self):
args = Namespace(workspace="/tmp")
args.set = []
runner = SimulatorClientRunner(args, [], None, None, None)
old_sys_path = copy.deepcopy(sys.path)
sys.path.insert(0, "")
if len(sys.path) > 2:
sys.path.insert(2, "")
sys.path.append("/temp/test")
new_sys_path = runner._get_new_sys_path()
assert old_sys_path == new_sys_path
sys.path = old_sys_path

def test_get_new_sys_path(self):
args = Namespace(workspace="/tmp")
args.set = []
runner = SimulatorClientRunner(args, [], None, None, None)
old_sys_path = copy.deepcopy(sys.path)
sys.path.append("/temp/test")
new_sys_path = runner._get_new_sys_path()
assert old_sys_path == new_sys_path
sys.path = old_sys_path

0 comments on commit eca7e12

Please sign in to comment.