diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index c07b0a0c..b23a0ab1 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -376,7 +376,9 @@ def message_received(self, topic: str, timeout: int = 1): return msg - def publish(self, topic, payload=None, qos=None, retain=None): + def publish( + self, topic: str, payload=None, qos=None, retain=None + ) -> paho.MQTTMessageInfo: """publish message using paho library""" self._wait_for_subscriptions() @@ -389,7 +391,14 @@ def publish(self, topic, payload=None, qos=None, retain=None): kwargs["retain"] = retain msg = self._client.publish(topic, payload, **kwargs) - if not msg.is_published: + # Wait for 2*connect timeout which should be plenty to publish the message even with qos 2 + # TODO: configurable + try: + msg.wait_for_publish(self._connect_timeout * 2) + except (RuntimeError, ValueError) as e: + raise exceptions.MQTTError("could not publish message") from e + + if not msg.is_published(): raise exceptions.MQTTError( "err {:s}: {:s}".format( _err_vals.get(msg.rc, "unknown"), paho.error_string(msg.rc) diff --git a/tests/unit/test_mqtt.py b/tests/unit/test_mqtt.py index 1d55b06f..4b83bc2c 100644 --- a/tests/unit/test_mqtt.py +++ b/tests/unit/test_mqtt.py @@ -1,3 +1,4 @@ +import time from typing import Dict from unittest.mock import MagicMock, Mock, patch @@ -71,28 +72,67 @@ def test_context_connection_success(self, fake_client): with fake_client as x: assert fake_client == x - def test_assert_message_published(self, fake_client): - """If it couldn't immediately publish the message, error out""" + def test_assert_message_published_error(self, fake_client): + """Error waiting for it to publish""" + + class FakeMessage(paho.MQTTMessageInfo): + def wait_for_publish(self, timeout=None): + raise RuntimeError + + rc = 1 + + with patch.object(fake_client._client, "subscribe"), patch.object( + fake_client._client, "publish", return_value=FakeMessage(10) + ): + with pytest.raises(exceptions.MQTTError): + fake_client.publish("abc", "123") + + def test_assert_message_published_failure(self, fake_client): + """If it couldn't publish the message, error out""" + + class FakeMessage(paho.MQTTMessageInfo): + def wait_for_publish(self, timeout=None): + return + + def is_published(self): + return False - class FakeMessage: - is_published = False rc = 1 with patch.object(fake_client._client, "subscribe"), patch.object( - fake_client._client, "publish", return_value=FakeMessage() + fake_client._client, "publish", return_value=FakeMessage(10) ): with pytest.raises(exceptions.MQTTError): fake_client.publish("abc", "123") + def test_assert_message_published_delay(self, fake_client): + """Published but only after a small delay""" + + class FakeMessage(paho.MQTTMessageInfo): + def wait_for_publish(self, timeout=None): + time.sleep(0.5) + + def is_published(self): + return True + + rc = 1 + + with patch.object(fake_client._client, "subscribe"), patch.object( + fake_client._client, "publish", return_value=FakeMessage(10) + ): + fake_client.publish("abc", "123") + def test_assert_message_published_unknown_err(self, fake_client): """Same, but with an unknown error code""" - class FakeMessage: - is_published = False + class FakeMessage(paho.MQTTMessageInfo): + def is_published(self): + return False + rc = 2342423 with patch.object(fake_client._client, "subscribe"), patch.object( - fake_client._client, "publish", return_value=FakeMessage() + fake_client._client, "publish", return_value=FakeMessage(10) ): with pytest.raises(exceptions.MQTTError): fake_client.publish("abc", "123")