# torch-subscriber-simple/test/test_integration_pub_sub.py

import json
import os
import ssl
import threading
import time
import unittest
from datetime import datetime
import paho.mqtt.client as mqtt
from torch_sub import torch_sub
# MQTT broker that both the simulated agent and the subscriber connect to over TLS.
broker_hostname = "mqtt.example.com"
broker_port = 8883

# TLS material used by agent_connect() for the simulated torch agent.
# These paths are relative, so the suite has to be run from the project root.
agent_config_path = "test/agent-config/"
mqtt_ca_file = os.path.join(agent_config_path, "ca.crt")
mqtt_cert_file = os.path.join(agent_config_path, "vagrant.crt")
mqtt_key_file = os.path.join(agent_config_path, "vagrant.key")

# TLS material passed to torch_sub.subscribe() by the tests.
subscriber_config_path = "test/subscriber-config/"
subscriber_ca_file = os.path.join(subscriber_config_path, "ca.crt")
subscriber_cert_file = os.path.join(subscriber_config_path, "subscriber.crt")
subscriber_key_file = os.path.join(subscriber_config_path, "subscriber.key")
def agent_connect():
    """Connect to the broker over TLS as the simulated torch agent.

    The agent presents its own certificate/key pair and requires the broker's
    certificate to verify against the agent-side CA file.

    Returns:
        A connected ``mqtt.Client``; its network loop has not been started.
    """
    agent = mqtt.Client()
    agent.tls_set(
        ca_certs=mqtt_ca_file,
        certfile=mqtt_cert_file,
        keyfile=mqtt_key_file,
        cert_reqs=ssl.CERT_REQUIRED,
    )
    keepalive_seconds = 60
    agent.connect(broker_hostname, port=broker_port, keepalive=keepalive_seconds)
    return agent
def agent_publish(client_id, onion_hostname, publish_timeout=5.0):
    """Publish one wake message as agent *client_id* and wait until it is sent.

    The JSON payload carries the client id, a local timestamp, the onion address
    and ``sshPort`` 22. It is published to ``torch/<client_id>/wake``.

    Args:
        client_id: agent identifier, used in both the topic and the payload.
        onion_hostname: onion address to announce.
        publish_timeout: seconds to wait for paho to report the message as
            published before giving up.

    Raises:
        TimeoutError: if the message is not reported as published within
            ``publish_timeout`` seconds.
    """
    payload = {
        'clientId': client_id,
        'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
        'onionAddress': onion_hostname,
        'sshPort': 22,
    }
    # NOTE(review): presumably gives a just-started subscriber thread time to
    # subscribe before anything is published -- confirm.
    time.sleep(0.2)
    client = agent_connect()
    # Per the paho docs, outgoing data may not be sent in a timely fashion unless
    # the network loop runs, so drive it in a background thread while publishing.
    client.loop_start()
    try:
        info = client.publish(f"torch/{client_id}/wake", json.dumps(payload))
        deadline = time.monotonic() + publish_timeout
        while not info.is_published():
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"wake message for {client_id!r} not published "
                    f"within {publish_timeout}s")
            time.sleep(0.01)
    finally:
        # Release the connection and stop the loop thread even if publishing failed.
        client.disconnect()
        client.loop_stop()
    # NOTE(review): presumably gives the subscriber time to handle the message
    # before the caller inspects its results -- confirm.
    time.sleep(0.2)
class GivenBrokerAndTorchAgent(unittest.TestCase):
    """End-to-end tests: a simulated agent publishes, torch_sub records the result.

    Each test starts ``torch_sub.subscribe`` in a background thread, publishes
    wake messages with ``agent_publish`` and then checks the JSON file at
    ``torch_sub.database_file``. A reachable broker at
    ``broker_hostname:broker_port`` and the TLS fixture files are required.
    """

    # Upper bound, in seconds, on how long to wait for the database file to
    # contain the expected entries.
    database_timeout = 5.0
    # Pause, in seconds, between successive reads of the database file.
    database_poll_interval = 0.05

    def setUp(self) -> None:
        # Every test starts without a database file from an earlier run.
        self._remove_database()

    def tearDown(self) -> None:
        self._remove_database()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        self.run_subscriber()
        agent_publish("client1", "crazy_onion.onion")
        self.assertOnionAddresses({"client1": "crazy_onion.onion"})

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        self.run_subscriber()
        agent_publish("client2", "crazy_onion2.onion")
        self.assertOnionAddresses({"client2": "crazy_onion2.onion"})

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        self.run_subscriber()
        agent_publish("client2", "crazy_onion2-34.onion")
        agent_publish("client3", "crazy_onion3.onion")
        agent_publish("client1", "crazy_onion1.onion")
        agent_publish("client2", "crazy_onion2-56.onion")
        agent_publish("client3", "crazy_onion3.onion")
        # client2 was announced twice; only the later address is expected.
        self.assertOnionAddresses({
            "client1": "crazy_onion1.onion",
            "client2": "crazy_onion2-56.onion",
            "client3": "crazy_onion3.onion",
        })

    def assertOnionAddresses(self, expected):
        """Assert the database maps each client in *expected* to its onion address.

        The subscriber runs in another thread, so a single read right after
        publishing may see a missing, partially written or stale file. The file
        is re-read until it matches *expected* or ``database_timeout`` seconds
        pass; the assertions then run against the last snapshot read.

        Args:
            expected: mapping of client id to the expected ``onionAddress``.
        """
        deadline = time.monotonic() + self.database_timeout
        database = None
        while time.monotonic() < deadline:
            try:
                database = self.loadDatabase()
            except (FileNotFoundError, json.JSONDecodeError):
                # Not written yet, or read while it was being written.
                database = None
            else:
                if self._has_onion_addresses(database, expected):
                    break
            time.sleep(self.database_poll_interval)
        if database is None:
            # No readable snapshot was obtained: read once more so the real
            # error (e.g. FileNotFoundError) propagates to the test report.
            database = self.loadDatabase()
        for client_id, onion_address in expected.items():
            self.assertEqual(database[client_id]['onionAddress'], onion_address)

    @staticmethod
    def _has_onion_addresses(database, expected):
        """Return True if *database* maps every client in *expected* to its address."""
        try:
            return all(database[client_id]['onionAddress'] == onion_address
                       for client_id, onion_address in expected.items())
        except (KeyError, TypeError):
            # Missing client or unexpected structure: treat as "not there yet".
            return False

    @staticmethod
    def _remove_database():
        """Delete the database file at torch_sub.database_file if it exists."""
        try:
            os.remove(torch_sub.database_file)
        except FileNotFoundError:
            pass

    @staticmethod
    def run_subscriber():
        """Start torch_sub.subscribe on topic "torch/+/wake" in a daemon thread.

        NOTE(review): the thread is never stopped or joined. If
        torch_sub.subscribe blocks, subscribers started by earlier tests keep
        running alongside later ones -- confirm this cannot leak state between
        tests.
        """
        threading.Thread(target=torch_sub.subscribe,
                         args=(broker_hostname,
                               broker_port,
                               "torch/+/wake",
                               subscriber_ca_file,
                               subscriber_cert_file,
                               subscriber_key_file),
                         daemon=True).start()

    @staticmethod
    def loadDatabase():
        """Read and parse the JSON file at torch_sub.database_file."""
        with open(torch_sub.database_file, "r") as database_file:
            return json.load(database_file)
if __name__ == "__main__":
    # Executed directly as a script: run every test case defined in this module.
    unittest.main(module="__main__")