"""Integration tests for the torch MQTT wake-message subscriber.

Each test starts a broker through the ``test/run-broker`` helper script, runs
``torch_sub.subscribe`` on a background thread, publishes wake messages the way
an agent would, and then checks the JSON database file the subscriber writes.
"""
import contextlib
import json
import os
import ssl
import sys
import threading
import time
import unittest
from datetime import datetime
from unittest.case import TestCase

import paho.mqtt.client as mqtt

from torchsub import torch_sub

broker_hostname = "mqtt.example.com"
broker_port = 8883

# Credentials used when acting as an agent that publishes wake messages.
agent_config_path = "test/agent-config/"
mqtt_ca_file = agent_config_path + "ca.crt"
mqtt_cert_file = agent_config_path + "vagrant.crt"
mqtt_key_file = agent_config_path + "vagrant.key"

# Credentials handed to the subscriber under test.
subscriber_config_path = "test/subscriber-config/"
subscriber_ca_file = subscriber_config_path + "ca.crt"
subscriber_cert_file = subscriber_config_path + "subscriber.crt"
subscriber_key_file = subscriber_config_path + "subscriber.key"


def agent_connect(use_tls=True):
    """Create an MQTT client and connect it to the test broker.

    :param use_tls: when true, present the agent's client certificate and
        require the broker's certificate to verify against the test CA.
    :return: the connected ``paho.mqtt.client.Client``.
    """
    client = mqtt.Client()
    if use_tls:
        client.tls_set(
            ca_certs=mqtt_ca_file,
            certfile=mqtt_cert_file,
            keyfile=mqtt_key_file,
            cert_reqs=ssl.CERT_REQUIRED)
    client.connect(broker_hostname, broker_port, 60)  # 60 s keepalive
    return client


def publish(client_id, onion_hostname, use_tls=True):
    """Publish one wake message for ``client_id`` to ``torch/<client_id>/wake``.

    :param client_id: agent identifier, used in the topic and the payload.
    :param onion_hostname: value sent as ``onionAddress``.
    :param use_tls: forwarded to :func:`agent_connect`.
    """
    payload = {
        'clientId': client_id,
        'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
        'onionAddress': onion_hostname,
        'sshPort': 22
    }
    # Presumably gives the background subscriber time to connect/subscribe.
    time.sleep(1)
    client = agent_connect(use_tls=use_tls)
    try:
        client.publish("torch/" + client_id + "/wake", json.dumps(payload))
    finally:
        # Always release the connection, even if publishing raised.
        client.disconnect()
    # Presumably gives the subscriber time to record the message.
    time.sleep(1)


class GivenBroker(TestCase):
    """Shared fixture: broker/subscriber launchers and database cleanup.

    Subclasses overriding ``setUp`` must call ``super().setUp()`` so a
    database file left behind by an earlier (possibly aborted) run is removed
    before the broker and subscriber start.
    """

    def setUp(self) -> None:
        self._remove_database()

    def tearDown(self) -> None:
        os.system("docker container stop mosquitto")
        self._remove_database()

    @staticmethod
    def _remove_database():
        """Delete the subscriber's database file if it exists."""
        with contextlib.suppress(FileNotFoundError):
            os.remove(torch_sub.database_filename)

    @staticmethod
    def run_broker(tls):
        """Start the test broker with the platform's helper script.

        :param tls: argument passed to the script (``"secure"`` or
            ``"insecure"`` in this file).
        """
        if sys.platform.startswith('win32'):
            cli = "test\\run-broker.bat " + tls
        else:
            cli = "test/run-broker.sh " + tls
        # NOTE(review): the script's exit status is ignored, so a failed start
        # only surfaces later as connection errors.
        os.system(cli)
        # Give the broker time to start accepting connections.
        time.sleep(2)

    @staticmethod
    def run_subscriber():
        """Run ``torch_sub.subscribe`` over TLS on a daemon thread."""
        # NOTE(review): the thread is never stopped, so each test leaves one
        # more subscriber behind; whether it reconnects depends on torch_sub.
        threading.Thread(
            target=torch_sub.subscribe,
            args=(broker_hostname, broker_port, "torch/+/wake", {
                'ca_certs': subscriber_ca_file,
                'certfile': subscriber_cert_file,
                'keyfile': subscriber_key_file
            }),
            daemon=True).start()

    @staticmethod
    def run_insecure_subscriber():
        """Run ``torch_sub.subscribe`` without TLS settings on a daemon thread."""
        threading.Thread(
            target=torch_sub.subscribe,
            args=(broker_hostname, broker_port, "torch/+/wake", None),
            daemon=True).start()

    @staticmethod
    def loadDatabase(timeout=5.0):
        """Read and parse the subscriber's JSON database file.

        The subscriber writes this file from a background thread, so it may
        not exist yet, or (depending on how torch_sub writes it) may be
        incomplete, when a test reads it. Retry until ``timeout`` seconds
        have passed, then re-raise the last error.

        :param timeout: seconds to keep retrying; ``0`` tries exactly once.
        :return: the decoded JSON object.
        """
        deadline = time.monotonic() + timeout
        while True:
            try:
                with open(torch_sub.database_filename, "r") as database_file:
                    return json.load(database_file)
            except (FileNotFoundError, json.JSONDecodeError):
                if time.monotonic() >= deadline:
                    raise
                time.sleep(0.1)


class GivenTlsBroker(GivenBroker):
    """Tests against a broker started in ``secure`` mode with a TLS subscriber."""

    def setUp(self) -> None:
        # Run the base cleanup first; the original override skipped it.
        super().setUp()
        self.run_broker("secure")
        self.run_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        publish("client2", "crazy_onion2-34.onion")
        publish("client3", "crazy_onion3.onion")
        publish("client1", "crazy_onion1.onion")
        publish("client2", "crazy_onion2-56.onion")
        publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")


class GivenNonTlsBroker(GivenBroker):
    """Tests against a broker started in ``insecure`` mode without TLS."""

    def setUp(self) -> None:
        # Run the base cleanup first; the original override skipped it.
        super().setUp()
        self.run_broker("insecure")
        self.run_insecure_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        self.insecure_publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        self.insecure_publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        self.insecure_publish("client2", "crazy_onion2-34.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        self.insecure_publish("client1", "crazy_onion1.onion")
        self.insecure_publish("client2", "crazy_onion2-56.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")

    @staticmethod
    def insecure_publish(client_id, onion_address):
        """Publish a wake message over a plain (non-TLS) connection."""
        publish(client_id, onion_address, use_tls=False)


if __name__ == '__main__':
    unittest.main()