torch-subscriber-simple/test/test_integration.py

168 lines
5.7 KiB
Python

import json
import os
import ssl
import sys
import threading
import time
import unittest
from datetime import datetime
from unittest.case import TestCase
import paho.mqtt.client as mqtt
from torchsub import torch_sub
# Broker that both the simulated agent and the subscriber under test use.
broker_hostname = "mqtt.example.com"
broker_port = 8883

# Credentials presented by the simulated agent (the publisher).
agent_config_path = "test/agent-config/"
mqtt_ca_file = f"{agent_config_path}ca.crt"
mqtt_cert_file = f"{agent_config_path}vagrant.crt"
mqtt_key_file = f"{agent_config_path}vagrant.key"

# Credentials handed to torch_sub.subscribe for the TLS test variant.
subscriber_config_path = "test/subscriber-config/"
subscriber_ca_file = f"{subscriber_config_path}ca.crt"
subscriber_cert_file = f"{subscriber_config_path}subscriber.crt"
subscriber_key_file = f"{subscriber_config_path}subscriber.key"
def agent_connect(use_tls=True):
    """Create the simulated agent's MQTT client and connect it to the broker.

    Args:
        use_tls: When true (the default), configure TLS with the agent's CA,
            certificate and key files, requiring the broker certificate to
            verify (``ssl.CERT_REQUIRED``). When false, connect in plain text.

    Returns:
        The ``paho.mqtt.client.Client`` after ``connect()`` has been called.
        No network loop is started on it.
    """
    agent = mqtt.Client()
    if use_tls:
        tls_options = {
            "ca_certs": mqtt_ca_file,
            "certfile": mqtt_cert_file,
            "keyfile": mqtt_key_file,
            "cert_reqs": ssl.CERT_REQUIRED,
        }
        agent.tls_set(**tls_options)
    # Third positional argument is the keepalive interval (seconds).
    agent.connect(broker_hostname, broker_port, 60)
    return agent
def publish(client_id, onion_hostname, use_tls=True):
    """Publish a single wake message for ``client_id``, as the agent would.

    The JSON payload holds the client id, a local-time timestamp, the onion
    address and SSH port 22. It is sent to ``torch/<client_id>/wake``.

    The fixed one-second pauses before connecting and after disconnecting
    appear to give the broker and subscriber time to process the message.
    The callers read the database immediately afterwards, with no other
    synchronisation.

    Args:
        client_id: Identifier placed in both the topic and the payload.
        onion_hostname: Value sent under the ``onionAddress`` key.
        use_tls: Forwarded to ``agent_connect``.
    """
    message = json.dumps({
        'clientId': client_id,
        'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
        'onionAddress': onion_hostname,
        'sshPort': 22,
    })
    topic = f"torch/{client_id}/wake"

    time.sleep(1)
    agent = agent_connect(use_tls=use_tls)
    agent.publish(topic, message)
    agent.disconnect()
    time.sleep(1)
class GivenBroker(TestCase):
    """Base fixture for the broker integration tests.

    Removes the subscriber's database file before and after each test.
    After each test it also stops the ``mosquitto`` Docker container.
    Subclasses start the broker and a subscriber in their own ``setUp``.
    """

    def setUp(self) -> None:
        # Start from an empty database so assertions only see this test's messages.
        self._remove_database()

    def tearDown(self) -> None:
        # Best effort: the docker command's exit status is ignored.
        os.system("docker container stop mosquitto")
        self._remove_database()

    @staticmethod
    def _remove_database():
        """Delete the subscriber's database file; a missing file is not an error."""
        try:
            os.remove(torch_sub.database_filename)
        except FileNotFoundError:
            pass

    @staticmethod
    def run_broker(tls):
        """Start the test broker with the platform's helper script.

        Args:
            tls: Mode argument appended to the script command line. The
                subclasses pass ``"secure"`` or ``"insecure"``.
        """
        if sys.platform.startswith('win32'):
            cli = "test\\run-broker.bat " + tls
        else:
            cli = "test/run-broker.sh " + tls
        os.system(cli)
        # Give the broker a moment to come up before clients connect.
        time.sleep(2)

    @staticmethod
    def _start_subscriber(tls_config):
        """Run ``torch_sub.subscribe`` on the wake topic in a daemon thread.

        Args:
            tls_config: Forwarded as the fourth argument of
                ``torch_sub.subscribe``. It is a dict of CA/cert/key paths for
                the TLS variant, or ``None`` for the insecure variant.
        """
        # NOTE(review): the thread is never stopped, so a subscriber from an
        # earlier test keeps running (daemon=True only ends it at interpreter
        # exit). Whether it reconnects to the next test's broker depends on
        # torch_sub, so confirm.
        threading.Thread(target=torch_sub.subscribe,
                         args=(broker_hostname,
                               broker_port,
                               "torch/+/wake",
                               tls_config),
                         daemon=True).start()

    @staticmethod
    def run_subscriber():
        """Start a subscriber that connects with the subscriber TLS credentials."""
        GivenBroker._start_subscriber({
            'ca_certs': subscriber_ca_file,
            'certfile': subscriber_cert_file,
            'keyfile': subscriber_key_file
        })

    @staticmethod
    def run_insecure_subscriber():
        """Start a subscriber with no TLS configuration (``None``)."""
        GivenBroker._start_subscriber(None)

    @staticmethod
    def loadDatabase():
        """Read and return the JSON database written by the subscriber.

        Raises:
            FileNotFoundError: If the subscriber has not written the file yet.
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        with open(torch_sub.database_filename, "r") as database_file:
            return json.load(database_file)
class GivenTlsBroker(GivenBroker):
    """Tests against a broker started in ``secure`` mode with a TLS subscriber."""

    def setUp(self) -> None:
        # Run the base setUp first. It deletes any database file left behind
        # by an earlier, aborted run, so assertions cannot read stale data.
        super().setUp()
        self.run_broker("secure")
        self.run_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        # client2 and client3 publish twice; the database should keep the
        # most recently published address for each client.
        publish("client2", "crazy_onion2-34.onion")
        publish("client3", "crazy_onion3.onion")
        publish("client1", "crazy_onion1.onion")
        publish("client2", "crazy_onion2-56.onion")
        publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")
class GivenNonTlsBroker(GivenBroker):
    """Tests against a broker started in ``insecure`` mode with a plain subscriber."""

    def setUp(self) -> None:
        # Run the base setUp first. It deletes any database file left behind
        # by an earlier, aborted run, so assertions cannot read stale data.
        super().setUp()
        self.run_broker("insecure")
        self.run_insecure_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        self.insecure_publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        self.insecure_publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        # client2 and client3 publish twice; the database should keep the
        # most recently published address for each client.
        self.insecure_publish("client2", "crazy_onion2-34.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        self.insecure_publish("client1", "crazy_onion1.onion")
        self.insecure_publish("client2", "crazy_onion2-56.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")

    @staticmethod
    def insecure_publish(client_id, onion_address):
        """Publish a wake message over a plain (non-TLS) agent connection."""
        publish(client_id, onion_address, use_tls=False)
# Allow running this module directly, in addition to collection by a test runner.
if __name__ == "__main__":
    unittest.main()