Integration tests can test insecure communications with broker

master
B.J. Dweck 2020-10-16 10:57:38 +02:00
parent e061b64ec1
commit 985d373b74
9 changed files with 115 additions and 52 deletions

View File

@ -0,0 +1,4 @@
listener 8883
connection_messages true
log_type all
websockets_log_level 9

View File

@ -1 +1,3 @@
docker run --rm --name mosquitto -p 8883:8883 -v %cd%\test\broker-config:/mosquitto/config eclipse-mosquitto @echo off
if %1%==secure (set config=broker-tls-config) else (set config=broker-no-tls-config)
docker run --rm -d --name mosquitto -p 8883:8883 -v "%cd%\test\%config%:/mosquitto/config" eclipse-mosquitto

View File

@ -1,8 +1,14 @@
#!/bin/bash #!/bin/bash
docker run -it --rm \ CONFIG=broker-no-tls-config
if [[ $1 == "secure" ]]; then
CONFIG=broker-tls-config
fi
docker run --rm -d \
--user "$UID" \ --user "$UID" \
-p 8883:8883 \ -p 8883:8883 \
-v "$(pwd)/test/broker-config:/mosquitto/config" \ -v "$(pwd)/test/$CONFIG:/mosquitto/config" \
--name mosquitto \ --name mosquitto \
eclipse-mosquitto eclipse-mosquitto

View File

@ -6,6 +6,7 @@ import threading
import time import time
import unittest import unittest
from datetime import datetime from datetime import datetime
from unittest.case import TestCase
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
@ -25,18 +26,19 @@ subscriber_cert_file = subscriber_config_path + "subscriber.crt"
subscriber_key_file = subscriber_config_path + "subscriber.key" subscriber_key_file = subscriber_config_path + "subscriber.key"
def agent_connect(): def agent_connect(use_tls=True):
client = mqtt.Client() client = mqtt.Client()
client.tls_set( if use_tls:
ca_certs=mqtt_ca_file, client.tls_set(
certfile=mqtt_cert_file, ca_certs=mqtt_ca_file,
keyfile=mqtt_key_file, certfile=mqtt_cert_file,
cert_reqs=ssl.CERT_REQUIRED) keyfile=mqtt_key_file,
cert_reqs=ssl.CERT_REQUIRED)
client.connect(broker_hostname, broker_port, 60) client.connect(broker_hostname, broker_port, 60)
return client return client
def agent_publish(client_id, onion_hostname): def publish(client_id, onion_hostname, use_tls=True):
payload = { payload = {
'clientId': client_id, 'clientId': client_id,
'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"), 'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
@ -45,21 +47,14 @@ def agent_publish(client_id, onion_hostname):
} }
time.sleep(1) time.sleep(1)
client = agent_connect() client = agent_connect(use_tls=use_tls)
client.publish("torch/" + client_id + "/wake", json.dumps(payload)) client.publish("torch/" + client_id + "/wake", json.dumps(payload))
client.disconnect() client.disconnect()
time.sleep(1) time.sleep(1)
class GivenBrokerAndTorchAgent(unittest.TestCase): class GivenBroker(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
cli = "test/run-broker.sh"
if sys.platform.startswith('win32'):
cli = "test\\run-broker.bat"
threading.Thread(target=os.system, args=(cli,), daemon=True).start()
time.sleep(2)
if os.path.exists(torch_sub.database_file): if os.path.exists(torch_sub.database_file):
os.remove(torch_sub.database_file) os.remove(torch_sub.database_file)
@ -68,29 +63,13 @@ class GivenBrokerAndTorchAgent(unittest.TestCase):
if os.path.exists(torch_sub.database_file): if os.path.exists(torch_sub.database_file):
os.remove(torch_sub.database_file) os.remove(torch_sub.database_file)
def test_when_agent_publishes_should_get_hostname_from_subscriber(self): @staticmethod
self.run_subscriber() def run_broker(tls):
agent_publish("client1", "crazy_onion.onion") cli = "test/run-broker.sh " + tls
database = self.loadDatabase() if sys.platform.startswith('win32'):
self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion") cli = "test\\run-broker.bat " + tls
os.system(cli)
def test_when_agent_publishes_should_get_hostname_from_subscriber2(self): time.sleep(2)
self.run_subscriber()
agent_publish("client2", "crazy_onion2.onion")
database = self.loadDatabase()
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")
def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
self.run_subscriber()
agent_publish("client2", "crazy_onion2-34.onion")
agent_publish("client3", "crazy_onion3.onion")
agent_publish("client1", "crazy_onion1.onion")
agent_publish("client2", "crazy_onion2-56.onion")
agent_publish("client3", "crazy_onion3.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")
@staticmethod @staticmethod
def run_subscriber(): def run_subscriber():
@ -98,9 +77,20 @@ class GivenBrokerAndTorchAgent(unittest.TestCase):
args=(broker_hostname, args=(broker_hostname,
broker_port, broker_port,
"torch/+/wake", "torch/+/wake",
subscriber_ca_file, {
subscriber_cert_file, 'ca_certs': subscriber_ca_file,
subscriber_key_file), 'certfile': subscriber_cert_file,
'keyfile': subscriber_key_file
}),
daemon=True).start()
@staticmethod
def run_insecure_subscriber():
threading.Thread(target=torch_sub.subscribe,
args=(broker_hostname,
broker_port,
"torch/+/wake",
None),
daemon=True).start() daemon=True).start()
@staticmethod @staticmethod
@ -109,5 +99,69 @@ class GivenBrokerAndTorchAgent(unittest.TestCase):
return json.load(database_file) return json.load(database_file)
class GivenTlsBroker(GivenBroker):
def setUp(self) -> None:
self.run_broker("secure")
self.run_subscriber()
def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
publish("client1", "crazy_onion.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")
def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
publish("client2", "crazy_onion2.onion")
database = self.loadDatabase()
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")
def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
publish("client2", "crazy_onion2-34.onion")
publish("client3", "crazy_onion3.onion")
publish("client1", "crazy_onion1.onion")
publish("client2", "crazy_onion2-56.onion")
publish("client3", "crazy_onion3.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")
class GivenNonTlsBroker(GivenBroker):
def setUp(self) -> None:
self.run_broker("insecure")
self.run_insecure_subscriber()
def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
self.insecure_publish("client1", "crazy_onion.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")
def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
self.insecure_publish("client2", "crazy_onion2.onion")
database = self.loadDatabase()
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")
def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
self.insecure_publish("client2", "crazy_onion2-34.onion")
self.insecure_publish("client3", "crazy_onion3.onion")
self.insecure_publish("client1", "crazy_onion1.onion")
self.insecure_publish("client2", "crazy_onion2-56.onion")
self.insecure_publish("client3", "crazy_onion3.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")
@staticmethod
def insecure_publish(client_id, onion_address):
publish(client_id, onion_address, use_tls=False)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -14,7 +14,7 @@ def update_client_record(client, userdata, message):
database_blank.write("{}") database_blank.write("{}")
with open(database_file, 'r') as infile: with open(database_file, 'r') as infile:
database = json.load(infile) database = json.loads(infile.read())
payload = message.payload.decode('utf-8') payload = message.payload.decode('utf-8')
response = json.loads(payload) response = json.loads(payload)
@ -25,13 +25,10 @@ def update_client_record(client, userdata, message):
json.dump(database, outfile) json.dump(database, outfile)
def subscribe(broker_hostname, broker_port, topic="torch", ca_file=None, cert_file=None, key_file=None): def subscribe(broker_hostname, broker_port, topic="torch", tls=None, auth=None):
mqtt.callback(update_client_record, mqtt.callback(update_client_record,
topic, topic,
hostname=broker_hostname, hostname=broker_hostname,
port=broker_port, port=broker_port,
tls={ tls=tls,
'ca_certs': ca_file, auth=auth)
'certfile': cert_file,
'keyfile': key_file
})