Skip to content

Commit

Permalink
Support checking service ports with ssl connection
Browse files Browse the repository at this point in the history
By default netcat is used to check if a service is
listening on a port. This is generally ok except
for services expecting SSL connections which need
to be properly closed and netcat can't do that. So
here we add support for optionally using the python
ssl library to create an ssl connection to the port
and close it properly once finished.

Related-Bug: #1920770
  • Loading branch information
dosaboy committed Apr 12, 2024
1 parent d3a8682 commit 78d3e5d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 13 deletions.
40 changes: 35 additions & 5 deletions charmhelpers/contrib/network/ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import subprocess
import socket
import ssl

from functools import partial

Expand Down Expand Up @@ -527,19 +528,48 @@ def get_hostname(address, fqdn=True):
return result.split('.')[0]


def port_has_listener(address, port):
class SSLPortCheckInfo(object):

def __init__(self, key, cert, ca_cert):
self.key = key
self.cert = cert
self.ca_cert = ca_cert

@property
def ssl_context(self):
context = ssl.create_default_context()
context.check_hostname = False
context.load_cert_chain(self.cert, self.key)
context.load_verify_locations(self.ca_cert)
return context


def port_has_listener(address, port, sslinfo=None):
"""
Returns True if the address:port is open and being listened to,
else False.
else False. By default uses netcat to check ports but if sslinfo is
provided will use an SSL connection instead.
@param address: an IP address or hostname
@param port: integer port
@param sslinfo: optional SSLPortCheckInfo object.
If provided, the check is performed using an ssl
connection.
Note calls 'zc' via a subprocess shell
"""
cmd = ['nc', '-z', address, str(port)]
result = subprocess.call(cmd)
return not (bool(result))
if not sslinfo:
cmd = ['nc', '-z', address, str(port)]
result = subprocess.call(cmd)
return not (bool(result))

try:
with socket.create_connection((address, port)) as sock:
with sslinfo.ssl_context.wrap_socket(sock,
server_hostname=address):
return True
except ConnectionRefusedError:
return False


def assert_charm_supports_ipv6():
Expand Down
29 changes: 22 additions & 7 deletions charmhelpers/contrib/openstack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from charmhelpers.contrib.network.ip import (
get_ipv6_addr,
is_ipv6,
SSLPortCheckInfo,
port_has_listener,
)

Expand Down Expand Up @@ -1207,16 +1208,22 @@ def _ows_check_services_running(services, ports):
return ows_check_services_running(services, ports)


def ows_check_services_running(services, ports):
def ows_check_services_running(services, ports, ssl_check_info=None):
"""Check that the services that should be running are actually running
and that any ports specified are being listened to.
@param services: list of strings OR dictionary specifying services/ports
@param ports: list of ports
@param ssl_check_info: Optional dict of {key: <path to key>, cert: <path
to cert>}. If provided, port checks will be done
using an SSL connection.
@returns state, message: strings or None, None
"""
messages = []
state = None
if ssl_check_info:
ssl_check_info = SSLPortCheckInfo(**ssl_check_info)

if services is not None:
services = _extract_services_list_helper(services)
services_running, running = _check_running_services(services)
Expand All @@ -1228,7 +1235,7 @@ def ows_check_services_running(services, ports):
# also verify that the ports that should be open are open
# NB, that ServiceManager objects only OPTIONALLY have ports
map_not_open, ports_open = (
_check_listening_on_services_ports(services))
_check_listening_on_services_ports(services, ssl_check_info))
if not all(ports_open):
# find which service has missing ports. They are in service
# order which makes it a bit easier.
Expand All @@ -1243,7 +1250,8 @@ def ows_check_services_running(services, ports):

if ports is not None:
# and we can also check ports which we don't know the service for
ports_open, ports_open_bools = _check_listening_on_ports_list(ports)
ports_open, ports_open_bools = \
_check_listening_on_ports_list(ports, ssl_check_info)
if not all(ports_open_bools):
messages.append(
"Ports which should be open, but are not: {}"
Expand Down Expand Up @@ -1302,7 +1310,8 @@ def _check_running_services(services):
return list(zip(services, services_running)), services_running


def _check_listening_on_services_ports(services, test=False):
def _check_listening_on_services_ports(services, test=False,
ssl_check_info=None):
"""Check that the unit is actually listening (has the port open) on the
ports that the service specifies are open. If test is True then the
function returns the services with ports that are open rather than
Expand All @@ -1312,11 +1321,14 @@ def _check_listening_on_services_ports(services, test=False):
@param services: OrderedDict(service: [port, ...], ...)
@param test: default=False, if False, test for closed, otherwise open.
@param ssl_check_info: SSLPortCheckInfo object. If provided, port checks
will be done using an SSL connection.
@returns OrderedDict(service: [port-not-open, ...]...), [boolean]
"""
test = not (not (test)) # ensure test is True or False
all_ports = list(itertools.chain(*services.values()))
ports_states = [port_has_listener('0.0.0.0', p) for p in all_ports]
ports_states = [port_has_listener('0.0.0.0', p, ssl_check_info)
for p in all_ports]
map_ports = OrderedDict()
matched_ports = [p for p, opened in zip(all_ports, ports_states)
if opened == test] # essentially opened xor test
Expand All @@ -1327,16 +1339,19 @@ def _check_listening_on_services_ports(services, test=False):
return map_ports, ports_states


def _check_listening_on_ports_list(ports):
def _check_listening_on_ports_list(ports, ssl_check_info=None):
"""Check that the ports list given are being listened to
Returns a list of ports being listened to and a list of the
booleans.
@param ssl_check_info: SSLPortCheckInfo object. If provided, port checks
will be done using an SSL connection.
@param ports: LIST of port numbers.
@returns [(port_num, boolean), ...], [boolean]
"""
ports_open = [port_has_listener('0.0.0.0', p) for p in ports]
ports_open = [port_has_listener('0.0.0.0', p, ssl_check_info)
for p in ports]
return zip(ports, ports_open), ports_open


Expand Down
46 changes: 46 additions & 0 deletions tests/contrib/network/test_ip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import subprocess
import unittest
from contextlib import contextmanager

import mock
import netifaces
Expand Down Expand Up @@ -784,6 +785,51 @@ def test_port_has_listener(self, subprocess_call):
self.assertEqual(net_ip.port_has_listener('ip-address', 70), True)
subprocess_call.assert_called_with(['nc', '-z', 'ip-address', '70'])

@patch('charmhelpers.contrib.network.ip.socket')
@patch('charmhelpers.contrib.network.ip.ssl')
def test_port_has_listener_ssl(self, mock_ssl, mock_socket):
ctxt = mock.MagicMock()
mock_ssl.create_default_context.return_value = ctxt

@contextmanager
def mock_create_connection(*args, **kwargs):
for x in [1]:
yield x

@contextmanager
def mock_wrap_socket(*args, **kwargs):
for x in [1]:
yield x

ctxt.wrap_socket = mock_wrap_socket
mock_socket.create_connection = mock_create_connection
sslinfo = net_ip.SSLPortCheckInfo('/etc/ssl/key', '/etc/ssl/cert',
'/etc/ssl/ca_cert')
self.assertEqual(net_ip.port_has_listener('10.0.0.1', 50, sslinfo),
True)

@patch('charmhelpers.contrib.network.ip.socket')
@patch('charmhelpers.contrib.network.ip.ssl')
def test_port_has_listener_ssl_false(self, mock_ssl, mock_socket):
ctxt = mock.MagicMock()
mock_ssl.create_default_context.return_value = ctxt

@contextmanager
def mock_create_connection(*args, **kwargs):
raise ConnectionRefusedError

@contextmanager
def mock_wrap_socket(*args, **kwargs):
for x in [1]:
yield x

ctxt.wrap_socket = mock_wrap_socket
mock_socket.create_connection = mock_create_connection
sslinfo = net_ip.SSLPortCheckInfo('/etc/ssl/key', '/etc/ssl/cert',
'/etc/ssl/ca_cert')
self.assertEqual(net_ip.port_has_listener('10.0.0.1', 50, sslinfo),
False)

@patch.object(net_ip, 'log', lambda *args, **kwargs: None)
@patch.object(net_ip, 'config')
@patch.object(net_ip, 'network_get_primary_address')
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/openstack/test_openstack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,7 @@ def test_pause_unit_retry_port_check_retries(
port_has_listener.side_effect = [True, False]
wait_for_ports_func = openstack.make_wait_for_ports_barrier([77])
openstack.pause_unit(None, services=['service1'], ports=[77], charm_func=wait_for_ports_func)
port_has_listener.assert_has_calls([call('0.0.0.0', 77), call('0.0.0.0', 77)])
port_has_listener.assert_has_calls([call('0.0.0.0', 77, None), call('0.0.0.0', 77, None)])

@patch('charmhelpers.contrib.openstack.utils.service_resume')
@patch('charmhelpers.contrib.openstack.utils.clear_unit_paused')
Expand Down

0 comments on commit 78d3e5d

Please sign in to comment.