diff --git a/flocker/testtools.py b/flocker/testtools.py index f303a3a3d5..a39362c5ca 100644 --- a/flocker/testtools.py +++ b/flocker/testtools.py @@ -8,19 +8,29 @@ import io import socket import sys +import os +import pwd from collections import namedtuple from contextlib import contextmanager from random import random +from subprocess import check_call from functools import wraps from zope.interface import implementer from zope.interface.verify import verifyClass +from ipaddr import IPAddress + from twisted.internet.interfaces import IProcessTransport, IReactorProcess from twisted.python.filepath import FilePath, Permissions from twisted.internet.task import Clock, deferLater from twisted.internet.defer import maybeDeferred from twisted.internet import reactor +from twisted.cred.portal import IRealm, Portal +from twisted.conch.ssh.keys import Key +from twisted.conch.checkers import SSHPublicKeyDatabase +from twisted.conch.openssh_compat.factory import OpenSSHFactory +from twisted.conch.unix import UnixConchUser from twisted.trial.unittest import SynchronousTestCase, SkipTest from . import __version__ @@ -318,6 +328,151 @@ def test_verbosity_multiple(self): self.assertEqual(2, options['verbosity']) +class _InMemoryPublicKeyChecker(SSHPublicKeyDatabase): + """ + Check SSH public keys in-memory. + """ + + def __init__(self, public_key): + """ + :param Key public_key: The public key we will accept. + """ + self._key = public_key + + def checkKey(self, credentials): + """ + Validate some SSH key credentials. + + Access is granted to the name of the user running the current process + for the key this checker was initialized with. + """ + # It would probably be better for the username to be another `__init__` + # argument. https://github.com/ClusterHQ/flocker/issues/189 + return (self._key.blob() == credentials.blob and + pwd.getpwuid(os.getuid()).pw_name == credentials.username) + + +class _FixedHomeConchUser(UnixConchUser): + """ + An SSH user with a fixed, configurable home directory. + + This is like a normal UNIX SSH user except the user's home directory is not + determined by the ``pwd`` database. + """ + def __init__(self, username, home): + """ + :param FilePath home: The path of the directory to use as this user's + home directory. + """ + UnixConchUser.__init__(self, username) + self._home = home + + def getHomeDir(self): + """ + Give back the pre-determined home directory. + """ + return self._home.path + + +@implementer(IRealm) +class UnixSSHRealm(object): + """ + An ``IRealm`` for a Conch server which gives out ``_FixedHomeConchUser`` + users. + """ + def __init__(self, home): + self.home = home + + def requestAvatar(self, username, mind, *interfaces): + user = _FixedHomeConchUser(username, self.home) + return interfaces[0], user, user.logout + + +class _ConchServer(object): + """ + A helper for a test fixture to run an SSH server using Twisted Conch. + + :ivar IPv4Address ip: The address the server is listening on. + :ivar int port: The port number the server is listening on. + :ivar _port: An object which provides ``IListeningPort`` and represents the + listening Conch server. + + :ivar FilePath home_path: The path of the home directory of the user which + is allowed to authenticate against this server. + + :ivar FilePath key_path: The path of an SSH private key which can be used + to authenticate against the server. + + :ivar FilePath host_key_path: The path of the server's private host key. + """ + def __init__(self, base_path): + """ + :param FilePath base_path: The path beneath which all of the temporary + SSH server-related files will be created. An ``ssh`` directory + will be created as a child of this directory to hold the key pair + that is generated. An ``sshd`` directory will also be created here + to hold the generated host key. A ``home`` directory is also + created here and used as the home directory for shell logins to the + server. + """ + self.home = base_path.child(b"home") + self.home.makedirs() + + ssh_path = base_path.child(b"ssh") + ssh_path.makedirs() + self.key_path = ssh_path.child(b"key") + check_call( + [b"ssh-keygen", + # Specify the path where the generated key is written. + b"-f", self.key_path.path, + # Specify an empty passphrase. + b"-N", b"", + # Generate as little output as possible. + b"-q"]) + key = Key.fromFile(self.key_path.path) + + sshd_path = base_path.child(b"sshd") + sshd_path.makedirs() + self.host_key_path = sshd_path.child(b"ssh_host_key") + check_call( + [b"ssh-keygen", + # See above for option explanations. + b"-f", self.host_key_path.path, + b"-N", b"", + b"-q"]) + + factory = OpenSSHFactory() + realm = UnixSSHRealm(self.home) + checker = _InMemoryPublicKeyChecker(public_key=key.public()) + factory.portal = Portal(realm, [checker]) + factory.dataRoot = sshd_path.path + factory.moduliRoot = b"/etc/ssh" + + self._port = reactor.listenTCP(0, factory, interface=b"127.0.0.1") + self.ip = IPAddress(self._port.getHost().host) + self.port = self._port.getHost().port + + def restore(self): + """ + Shut down the SSH server. + + :return: A ``Deferred`` that fires when this has been done. + """ + return self._port.stopListening() + + +def create_ssh_server(base_path): + """ + :py:func:`create_ssh_server` is a fixture which creates and runs a new SSH + server and stops it later. Use the :py:meth:`restore` method of the + returned object to stop the server. + + :param FilePath base_path: The path to a directory in which key material + will be generated. + """ + return _ConchServer(base_path) + + def make_with_init_tests(record_type, kwargs): """ Return a ``TestCase`` which tests that ``record_type.__init__`` accepts the diff --git a/flocker/volume/functional/test_ipc.py b/flocker/volume/functional/test_ipc.py index 0dfc16da42..aabec41b07 100644 --- a/flocker/volume/functional/test_ipc.py +++ b/flocker/volume/functional/test_ipc.py @@ -2,24 +2,17 @@ """Functional tests for IPC.""" -import subprocess import os -import pwd +from getpass import getuser from unittest import skipIf from twisted.trial.unittest import TestCase from twisted.python.filepath import FilePath -from twisted.internet import reactor -from twisted.cred.portal import Portal -from twisted.conch.ssh.keys import Key -from twisted.conch.unix import UnixSSHRealm -from twisted.conch.checkers import SSHPublicKeyDatabase -from twisted.conch.openssh_compat.factory import OpenSSHFactory from twisted.internet.threads import deferToThread from .._ipc import ProcessNode from ..test.test_ipc import make_inode_tests - +from ...testtools import create_ssh_server _if_root = skipIf(os.getuid() != 0, "Must run as root.") @@ -72,53 +65,21 @@ def test_bad_exit(self): self.fail("No IOError") -class InMemoryPublicKeyChecker(SSHPublicKeyDatabase): - """Check SSH public keys in-memory.""" - - def __init__(self, public_key): - """ - :param bytes public_key: The public key we will accept. - """ - self._key = Key.fromString(data=public_key) - - def checkKey(self, credentials): - return (self._key.blob() == credentials.blob and - pwd.getpwuid(os.getuid()).pw_name == credentials.username) - - @_if_root def make_sshnode(test_case): - """Create a ``ProcessNode`` that can SSH into the local machine. + """ + Create a ``ProcessNode`` that can SSH into the local machine. :param TestCase test_case: The test case to use. :return: A ``ProcessNode`` instance. """ - sshd_path = FilePath(test_case.mktemp()) - sshd_path.makedirs() - subprocess.check_call( - [b"ssh-keygen", b"-f", sshd_path.child(b"ssh_host_key").path, - b"-N", b"", b"-q"]) - - ssh_path = FilePath(test_case.mktemp()) - ssh_path.makedirs() - subprocess.check_call( - [b"ssh-keygen", b"-f", ssh_path.child(b"key").path, - b"-N", b"", b"-q"]) - - factory = OpenSSHFactory() - realm = UnixSSHRealm() - checker = InMemoryPublicKeyChecker(ssh_path.child(b"key.pub").getContent()) - factory.portal = Portal(realm, [checker]) - factory.dataRoot = sshd_path.path - factory.moduliRoot = b"/etc/ssh" - - port = reactor.listenTCP(0, factory, interface=b"127.0.0.1") - test_case.addCleanup(port.stopListening) - - return ProcessNode.using_ssh(b"127.0.0.1", port.getHost().port, - pwd.getpwuid(os.getuid()).pw_name, - ssh_path.child(b"key")) + server = create_ssh_server(FilePath(test_case.mktemp())) + test_case.addCleanup(server.restore) + + return ProcessNode.using_ssh( + host=unicode(server.ip).encode("ascii"), port=server.port, + username=getuser(), private_key=server.key_path) class SSHProcessNodeTests(TestCase):