Skip to content

Commit

Permalink
Remove one call to get_ssh_client()
Browse files Browse the repository at this point in the history
Since in ssh_reachable() we already get a SSH client connection,
let's save it in the (unused so far) _ssh_client var.
Then reuse it, in _scp() command.
  • Loading branch information
Yaniv Kaul committed Apr 24, 2018
1 parent 306d9aa commit 652856e
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions lago/plugins/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def extract_paths(self, paths, ignore_nopath):
path was found on the VM, and ``ignore_nopath`` is True.
:exc:`~lago.plugins.vm.ExtractPathError`: on all other failures.
"""
if self.vm.alive() and self.vm.ssh_reachable(
tries=5, propagate_fail=False
):
if self.vm.ssh_reachable(tries=5, propagate_fail=False):
self._extract_paths_scp(paths=paths, ignore_nopath=ignore_nopath)
else:
raise ExtractPathError(
Expand Down Expand Up @@ -558,7 +556,7 @@ def ssh_reachable(self, tries=None, propagate_fail=True):
"""

try:
ssh.get_ssh_client(
self._ssh_client = ssh.get_ssh_client(
ip_addr=self.ip(),
host_name=self.name(),
ssh_tries=tries,
Expand Down Expand Up @@ -686,19 +684,23 @@ def _normalize_spec(cls, spec):

@contextlib.contextmanager
def _scp(self, propagate_fail=True):
client = ssh.get_ssh_client(
propagate_fail=propagate_fail,
ip_addr=self.ip(),
host_name=self.name(),
ssh_key=self.virt_env.prefix.paths.ssh_id_rsa(),
username=self._spec.get('ssh-user'),
password=self._spec.get('ssh-password'),
)
if self._ssh_client is not None:
client = self._ssh_client
else:
client = ssh.get_ssh_client(
propagate_fail=propagate_fail,
ip_addr=self.ip(),
host_name=self.name(),
ssh_key=self.virt_env.prefix.paths.ssh_id_rsa(),
username=self._spec.get('ssh-user'),
password=self._spec.get('ssh-password'),
)
scp = SCPClient(client.get_transport())
try:
yield scp
finally:
client.close()
self._ssh_client = None

def _detect_service_provider(self):
LOGGER.debug('Detecting service provider for %s', self.name())
Expand Down

0 comments on commit 652856e

Please sign in to comment.