Add NannyPlugins (#5118)
This is like WorkerPlugin, but allows for code to run before the Worker
starts up.

Unfortunately this requires the Nanny to check in with the scheduler
before starting the Worker.  In principle this should be fast, but it
does delay the common case for the uncommon case.

This PR includes an Environ nanny-plugin.
If we go with this, I think that we should move over PipInstall.
We might also move over UploadFile and make a new UploadDirectory.
mrocklin authored Jul 27, 2021
1 parent cf1e412 commit 50fd3ff
Showing 8 changed files with 316 additions and 23 deletions.
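As context for the diffs below, here is a minimal sketch (not part of the commit) of how the new plugin type is intended to be used. The `Hello` plugin and the scheduler address are hypothetical placeholders; the registration call mirrors the client changes shown later.

from dask.distributed import Client
from distributed.diagnostics.plugin import NannyPlugin

client = Client("tcp://127.0.0.1:8786")  # placeholder scheduler address

class Hello(NannyPlugin):          # hypothetical example plugin
    restart = False                # set to True to restart running Workers on registration

    def setup(self, nanny):
        # Runs inside the Nanny process, before (and independently of) the Worker process.
        print("plugin attached, nanny working dir:", nanny.local_directory)

    def teardown(self, nanny):
        print("nanny closing")

# Because Hello is a NannyPlugin, the client routes it to the nannies;
# nanny=True can also be passed explicitly (see the client.py changes below).
client.register_worker_plugin(Hello(), name="hello")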
10 changes: 9 additions & 1 deletion distributed/__init__.py
@@ -21,7 +21,15 @@
)
from .core import Status, connect, rpc
from .deploy import Adaptive, LocalCluster, SpecCluster, SSHCluster
from .diagnostics.plugin import PipInstall, SchedulerPlugin, WorkerPlugin
from .diagnostics.plugin import (
Environ,
NannyPlugin,
PipInstall,
SchedulerPlugin,
UploadDirectory,
UploadFile,
WorkerPlugin,
)
from .diagnostics.progressbar import progress
from .event import Event
from .lock import Lock
41 changes: 28 additions & 13 deletions distributed/client.py
@@ -59,7 +59,12 @@
connect,
rpc,
)
from .diagnostics.plugin import UploadFile, WorkerPlugin, _get_worker_plugin_name
from .diagnostics.plugin import (
NannyPlugin,
UploadFile,
WorkerPlugin,
_get_worker_plugin_name,
)
from .metrics import time
from .objects import HasWhat, SchedulerInfo, WhoHas
from .protocol import to_serialize
@@ -4089,18 +4094,21 @@ def register_worker_callbacks(self, setup=None):
"""
return self.register_worker_plugin(_WorkerSetupPlugin(setup))

async def _register_worker_plugin(self, plugin=None, name=None):
responses = await self.scheduler.register_worker_plugin(
plugin=dumps(plugin, protocol=4), name=name
)
async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
if nanny or nanny is None and isinstance(plugin, NannyPlugin):
method = self.scheduler.register_nanny_plugin
else:
method = self.scheduler.register_worker_plugin

responses = await method(plugin=dumps(plugin, protocol=4), name=name)
for response in responses.values():
if response["status"] == "error":
exc = response["exception"]
tb = response["traceback"]
raise exc.with_traceback(tb)
return responses

def register_worker_plugin(self, plugin=None, name=None, **kwargs):
def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs):
"""
Registers a lifecycle worker plugin for all current and future workers.
@@ -4124,12 +4132,14 @@ def register_worker_plugin(self, plugin=None, name=None, **kwargs):
Parameters
----------
plugin : WorkerPlugin
The plugin object to pass to the workers
plugin : WorkerPlugin or NannyPlugin
The plugin object to register.
name : str, optional
A name for the plugin.
Registering a plugin with the same name will have no effect.
If plugin has no name attribute a random name is used.
nanny : bool, optional
Whether to register the plugin with workers or nannies.
**kwargs : optional
If you pass a class as the plugin, instead of a class instance, then the
class will be instantiated with any extra keyword arguments.
@@ -4174,10 +4184,15 @@ class will be instantiated with any extra keyword arguments.

assert name

return self.sync(self._register_worker_plugin, plugin=plugin, name=name)
return self.sync(
self._register_worker_plugin, plugin=plugin, name=name, nanny=nanny
)

async def _unregister_worker_plugin(self, name):
responses = await self.scheduler.unregister_worker_plugin(name=name)
async def _unregister_worker_plugin(self, name, nanny=None):
if nanny:
responses = await self.scheduler.unregister_nanny_plugin(name=name)
else:
responses = await self.scheduler.unregister_worker_plugin(name=name)

for response in responses.values():
if response["status"] == "error":
@@ -4186,7 +4201,7 @@ async def _unregister_worker_plugin(self, name):
raise exc.with_traceback(tb)
return responses

def unregister_worker_plugin(self, name):
def unregister_worker_plugin(self, name, nanny=None):
"""Unregisters a lifecycle worker plugin
This unregisters an existing worker plugin. As part of the unregistration process
@@ -4220,7 +4235,7 @@ def unregister_worker_plugin(self, name):
--------
register_worker_plugin
"""
return self.sync(self._unregister_worker_plugin, name=name)
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)


class _WorkerSetupPlugin(WorkerPlugin):
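To summarize the behaviour added in this file: the new ``nanny`` keyword selects between the scheduler's ``register_worker_plugin`` and ``register_nanny_plugin`` handlers, defaulting to an ``isinstance`` check on the plugin. A short sketch, where ``client`` is an existing connected client and ``MyWorkerPlugin`` / ``MyNannyPlugin`` are hypothetical subclasses of WorkerPlugin and NannyPlugin:

# With nanny left as None (the default), the plugin's type decides the target:
client.register_worker_plugin(MyWorkerPlugin())            # -> workers
client.register_worker_plugin(MyNannyPlugin(), name="np")  # -> nannies (isinstance(plugin, NannyPlugin))

# The target can also be forced explicitly:
client.register_worker_plugin(MyNannyPlugin(), name="np", nanny=True)

# Unregistration never sees the plugin object, so the target must be stated:
client.unregister_worker_plugin(name="np", nanny=True)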
117 changes: 116 additions & 1 deletion distributed/diagnostics/plugin.py
@@ -4,8 +4,9 @@
import subprocess
import sys
import uuid
import zipfile

from dask.utils import funcname
from dask.utils import funcname, tmpfile

logger = logging.getLogger(__name__)

@@ -175,6 +176,41 @@ def release_key(self, key, state, cause, reason, report):
"""


class NannyPlugin:
"""Interface to extend the Nanny
A worker plugin enables custom code to run at different stages of the Workers'
lifecycle. A nanny plugin does the same thing, but benefits from being able
to run code before the worker is started, or to restart the worker if
necessary.
To implement a plugin, implement some of the methods of this class and register
the plugin with your client in order to have it attached to every existing and
future nanny by passing ``nanny=True`` to
:meth:`Client.register_worker_plugin<distributed.Client.register_worker_plugin>`.
The ``restart`` attribute is used to control whether or not a running ``Worker``
needs to be restarted when registering the plugin.
See Also
--------
WorkerPlugin
SchedulerPlugin
"""

restart = False

def setup(self, nanny):
"""
Run when the plugin is attached to a nanny. This happens when the plugin is registered
and attached to existing nannies, or when a nanny is created after the plugin has been
registered.
"""

def teardown(self, nanny):
"""Run when the nanny to which the plugin is attached to is closed"""


def _get_worker_plugin_name(plugin) -> str:
"""Returns the worker plugin name. If plugin has no name attribute
a random name is used."""
@@ -289,3 +325,82 @@ async def setup(self, worker):
comm=None, filename=self.filename, data=self.data, load=True
)
assert len(self.data) == response["nbytes"]


class Environ(NannyPlugin):
restart = True

def __init__(self, environ={}):
self.environ = {k: str(v) for k, v in environ.items()}

async def setup(self, nanny):
nanny.env.update(self.environ)


class UploadDirectory(NannyPlugin):
"""A NannyPlugin to upload a local file to workers.
Parameters
----------
path: str
A path to the directory to upload
Examples
--------
>>> from distributed.diagnostics.plugin import UploadDirectory
>>> client.register_worker_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP
"""

def __init__(
self,
path,
restart=False,
update_path=False,
skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
):
"""
Initialize the plugin by reading in the data from the given file.
"""
path = os.path.expanduser(path)
self.path = os.path.split(path)[-1]
self.restart = restart
self.update_path = update_path

self.name = "upload-directory-" + os.path.split(path)[-1]

with tmpfile(extension="zip") as fn:
with zipfile.ZipFile(fn, "w", zipfile.ZIP_DEFLATED) as z:
for root, dirs, files in os.walk(path):
for file in files:
filename = os.path.join(root, file)
if any(predicate(filename) for predicate in skip):
continue
dirs = filename.split(os.sep)
if any(word in dirs for word in skip_words):
continue

archive_name = os.path.relpath(
os.path.join(root, file), os.path.join(path, "..")
)
z.write(filename, archive_name)

with open(fn, "rb") as f:
self.data = f.read()

async def setup(self, nanny):
fn = os.path.join(nanny.local_directory, f"tmp-{str(uuid.uuid4())}.zip")
with open(fn, "wb") as f:
f.write(self.data)

import zipfile

with zipfile.ZipFile(fn) as z:
z.extractall(path=nanny.local_directory)

if self.update_path:
path = os.path.join(nanny.local_directory, self.path)
if path not in sys.path:
sys.path.insert(0, path)

os.remove(fn)
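``Environ`` ships here without a docstring, so a brief usage sketch may help (illustrative only; the variable name is arbitrary and a process-based local cluster is assumed): values are stringified into ``nanny.env``, and because ``restart = True`` existing workers are restarted so their processes pick up the new environment.

import os

from dask.distributed import Client, LocalCluster
from distributed.diagnostics.plugin import Environ

cluster = LocalCluster(n_workers=2)   # processes=True by default -> one Nanny per worker
client = Client(cluster)

# Merged into nanny.env; restart=True restarts already-running workers so the
# worker processes actually see the variable.
client.register_worker_plugin(Environ({"MY_SETTING": 1}), nanny=True)

# Every worker process should now report the (stringified) value.
print(client.run(lambda: os.environ.get("MY_SETTING")))   # e.g. {'tcp://127.0.0.1:...': '1'}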
67 changes: 64 additions & 3 deletions distributed/nanny.py
@@ -8,6 +8,7 @@
import warnings
import weakref
from contextlib import suppress
from inspect import isawaitable
from multiprocessing.queues import Empty
from time import sleep as sync_sleep

@@ -22,15 +23,18 @@
from . import preloading
from .comm import get_address_host, unparse_host_port
from .comm.addressing import address_from_user_args
from .core import CommClosedError, RPCClosed, Status, coerce_to_address
from .core import CommClosedError, RPCClosed, Status, coerce_to_address, error_message
from .diagnostics.plugin import _get_worker_plugin_name
from .node import ServerNode
from .process import AsyncProcess
from .proctitle import enable_proctitle_on_children
from .protocol import pickle
from .security import Security
from .utils import (
TimeoutError,
get_ip,
json_load_robust,
log_errors,
mp_context,
parse_ports,
silence_logging,
@@ -110,14 +114,14 @@ def __init__(

if local_directory is None:
local_directory = dask.config.get("temporary-directory") or os.getcwd()
if not os.path.exists(local_directory):
os.makedirs(local_directory)
self._original_local_dir = local_directory
local_directory = os.path.join(local_directory, "dask-worker-space")
else:
self._original_local_dir = local_directory

self.local_directory = local_directory
if not os.path.exists(self.local_directory):
os.makedirs(self.local_directory, exist_ok=True)

self.preload = preload
if self.preload is None:
@@ -205,8 +209,12 @@ def __init__(
"terminate": self.close,
"close_gracefully": self.close_gracefully,
"run": self.run,
"plugin_add": self.plugin_add,
"plugin_remove": self.plugin_remove,
}

self.plugins = {}

super().__init__(
handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
)
@@ -300,6 +308,10 @@ async def start(self):
for preload in self.preloads:
await preload.start()

msg = await self.scheduler.register_nanny()
for name, plugin in msg["nanny-plugins"].items():
await self.plugin_add(plugin=plugin, name=name)

logger.info(" Start Nanny at: %r", self.address)
response = await self.instantiate()
if response == Status.running:
@@ -390,6 +402,47 @@ async def instantiate(self, comm=None) -> Status:
raise
return result

async def plugin_add(self, comm=None, plugin=None, name=None):
with log_errors(pdb=False):
if isinstance(plugin, bytes):
plugin = pickle.loads(plugin)

if name is None:
name = _get_worker_plugin_name(plugin)

assert name

self.plugins[name] = plugin

logger.info("Starting Nanny plugin %s" % name)
if hasattr(plugin, "setup"):
try:
result = plugin.setup(nanny=self)
if isawaitable(result):
result = await result
except Exception as e:
msg = error_message(e)
return msg
if getattr(plugin, "restart", False):
await self.restart()

return {"status": "OK"}

async def plugin_remove(self, comm=None, name=None):
with log_errors(pdb=False):
logger.info(f"Removing Nanny plugin {name}")
try:
plugin = self.plugins.pop(name)
if hasattr(plugin, "teardown"):
result = plugin.teardown(nanny=self)
if isawaitable(result):
result = await result
except Exception as e:
msg = error_message(e)
return msg

return {"status": "OK"}

async def restart(self, comm=None, timeout=30, executor_wait=True):
async def _():
if self.process is not None:
@@ -514,6 +567,14 @@ async def close(self, comm=None, timeout=5, report=None):
for preload in self.preloads:
await preload.teardown()

teardowns = [
plugin.teardown(self)
for plugin in self.plugins.values()
if hasattr(plugin, "teardown")
]

await asyncio.gather(*[td for td in teardowns if isawaitable(td)])

self.stop()
try:
if self.process is not None:
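The new ``plugin_add`` / ``plugin_remove`` handlers can also be exercised directly against a ``Nanny`` object, which is roughly what the scheduler broadcast does. A test-style sketch, assuming a scheduler is reachable at the placeholder address and that the plugin name is arbitrary:

import asyncio

from distributed import Nanny
from distributed.diagnostics.plugin import Environ

async def main():
    async with Nanny("tcp://127.0.0.1:8786") as nanny:   # placeholder scheduler address
        # plugin_add accepts a pickled plugin (as sent over the wire) or a plain
        # instance; Environ has restart=True, so the worker process is restarted
        # once setup() has updated nanny.env.
        response = await nanny.plugin_add(plugin=Environ({"FOO": "1"}), name="environ-foo")
        assert response == {"status": "OK"}

        # plugin_remove calls teardown() if present and drops the plugin.
        response = await nanny.plugin_remove(name="environ-foo")
        assert response == {"status": "OK"}

asyncio.run(main())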