diff --git a/torch_sub/test/hosts_test b/torch_sub/test/hosts_test deleted file mode 100644 index 9a95b8e..0000000 --- a/torch_sub/test/hosts_test +++ /dev/null @@ -1 +0,0 @@ -127.0.0.1 mqtt.example.com diff --git a/torch_sub/test/integration_pub_sub.py b/torch_sub/test/integration_pub_sub.py index 981f2f4..718a80f 100644 --- a/torch_sub/test/integration_pub_sub.py +++ b/torch_sub/test/integration_pub_sub.py @@ -1,4 +1,5 @@ import json +import os import ssl import threading import time @@ -9,8 +10,8 @@ import paho.mqtt.client as mqtt from torch_sub import torch_sub -host = "mqtt.example.com" -port = 8883 +broker_hostname = "mqtt.example.com" +broker_port = 8883 agent_config_path = "agent-config/" mqtt_ca_file = agent_config_path + "ca.crt" mqtt_cert_file = agent_config_path + "vagrant.crt" @@ -29,11 +30,12 @@ def agent_connect(): certfile=mqtt_cert_file, keyfile=mqtt_key_file, cert_reqs=ssl.CERT_REQUIRED) - client.connect(host, port, 60) + client.connect(broker_hostname, broker_port, 60) return client def agent_publish(client_id, onion_hostname): + payload = { 'clientId': client_id, 'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"), @@ -41,54 +43,50 @@ def agent_publish(client_id, onion_hostname): 'sshPort': 22 } + time.sleep(0.2) client = agent_connect() - - topic = "torch/" + client_id + "/wake" - - client.publish(topic, json.dumps(payload)) - print("Debug: Connected to MQTT Broker at %s://%s:%s/%s" % ("mqtts", host, port, topic)) - print("Debug: Published payload: " + json.dumps(payload)) - + client.publish("torch/" + client_id + "/wake", json.dumps(payload)) client.disconnect() - print("Debug: Disconnected from MQTT Broker") - - pass - - -def subscriber_thread(host, port, filename, topic, cafile, certfile, keyfile): - torch_sub.attach(host, - port, - filename, - topic=topic, - mqtt_ca_file=cafile, - mqtt_cert_file=certfile, - mqtt_key_file=keyfile) + time.sleep(0.2) class GivenBrokerAndTorchAgent(unittest.TestCase): - def test_when_agent_publishes_should_get_hostname_from_subscriber(self): - outfile = "clients.json" - threading.Thread(target=subscriber_thread, - args=(host, - port, - outfile, + def setUp(self) -> None: + if os.path.exists(torch_sub.database_file): + os.remove(torch_sub.database_file) + + def tearDown(self) -> None: + if os.path.exists(torch_sub.database_file): + os.remove(torch_sub.database_file) + + def test_when_agent_publishes_should_get_hostname_from_subscriber(self): + self.run_subscriber() + agent_publish("client1", "crazy_onion.onion") + database = self.loadDatabase() + self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion") + + def test_when_agent_publishes_should_get_hostname_from_subscriber2(self): + self.run_subscriber() + agent_publish("client2", "crazy_onion2.onion") + database = self.loadDatabase() + self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion") + + @staticmethod + def run_subscriber(): + threading.Thread(target=torch_sub.subscribe, + args=(broker_hostname, + broker_port, "torch/+/wake", subscriber_ca_file, subscriber_cert_file, subscriber_key_file), daemon=True).start() - time.sleep(0.5) - - agent_publish("client1", "crazyonion.onion") - - time.sleep(0.5) - - file = open(outfile, "r") - response = json.load(file) - - self.assertEqual(response['client1']['onionAddress'], "crazyonion.onion") + @staticmethod + def loadDatabase(): + with open(torch_sub.database_file, "r") as database_file: + return json.load(database_file) if __name__ == '__main__': diff --git a/torch_sub/torch_sub.py b/torch_sub/torch_sub.py index 553bd28..b1bbce4 100644 --- a/torch_sub/torch_sub.py +++ b/torch_sub/torch_sub.py @@ -1,44 +1,37 @@ import json +import os -import paho.mqtt.client as cl import paho.mqtt.subscribe as mqtt -datafile = "clients.json" +database_file = "clients.json" -def updateOnion(client, userdata, message): - with open(datafile, 'r') as infile: +# noinspection PyUnusedLocal +def update_client_record(client, userdata, message): + + if not os.path.exists(database_file): + with open(database_file, 'w') as database_blank: + database_blank.write("{}") + + with open(database_file, 'r') as infile: database = json.load(infile) payload = message.payload.decode('utf-8') - - print("Payload: %s" % (payload)) - response = json.loads(payload) - print(response) - print("Response: %s" % (response)) - - print("Database: %s" % (database)) database[response['clientId']] = response - print("got one! %s %s %s" % (client, userdata, payload)) - - with open(datafile, 'w') as outfile: + with open(database_file, 'w') as outfile: json.dump(database, outfile) - -def attach(host, port, datafileIn, mqtt_ca_file=None, mqtt_cert_file=None, mqtt_key_file=None, topic=None): - - datafile = datafileIn - - mqtt.callback(updateOnion, +def subscribe(broker_hostname, broker_port, topic="torch", ca_file=None, cert_file=None, key_file=None): + mqtt.callback(update_client_record, topic, - hostname=host, - port=port, + hostname=broker_hostname, + port=broker_port, tls={ - 'ca_certs': mqtt_ca_file, - 'certfile': mqtt_cert_file, - 'keyfile': mqtt_key_file - }) \ No newline at end of file + 'ca_certs': ca_file, + 'certfile': cert_file, + 'keyfile': key_file + })