# torch-subscriber-simple/test/test_integration.py

import contextlib
import json
import os
import ssl
import sys
import threading
import time
import unittest
from datetime import datetime
from unittest.case import TestCase

import paho.mqtt.client as mqtt

from torch_sub import torch_sub
# Broker endpoint shared by the simulated agent and the subscriber under test.
broker_hostname = "mqtt.example.com"
broker_port = 8883

# TLS material the simulated agent (publisher side) presents to the broker.
agent_config_path = "test/agent-config/"
mqtt_ca_file = os.path.join(agent_config_path, "ca.crt")
mqtt_cert_file = os.path.join(agent_config_path, "vagrant.crt")
mqtt_key_file = os.path.join(agent_config_path, "vagrant.key")

# TLS material handed to the subscriber under test.
subscriber_config_path = "test/subscriber-config/"
subscriber_ca_file = os.path.join(subscriber_config_path, "ca.crt")
subscriber_cert_file = os.path.join(subscriber_config_path, "subscriber.crt")
subscriber_key_file = os.path.join(subscriber_config_path, "subscriber.key")
def agent_connect(use_tls=True):
    """Connect to the test broker as the simulated agent and return the client.

    :param use_tls: when true, configure TLS with the agent's CA, certificate
        and key, and require the broker's certificate to verify against
        ``mqtt_ca_file``.
    :returns: the connected ``mqtt.Client``; no network loop is started.
    """
    agent = mqtt.Client()
    if use_tls:
        # Present the agent's certificate/key and refuse a broker whose
        # certificate does not verify against the agent's CA file.
        agent.tls_set(
            ca_certs=mqtt_ca_file,
            certfile=mqtt_cert_file,
            keyfile=mqtt_key_file,
            cert_reqs=ssl.CERT_REQUIRED,
        )
    agent.connect(broker_hostname, port=broker_port, keepalive=60)
    return agent
def publish(client_id, onion_hostname, use_tls=True):
    """Publish one wake-up message as agent *client_id*.

    The JSON payload carries the client id, a local timestamp, the onion
    address and SSH port 22, and is sent to ``torch/<client_id>/wake`` over a
    fresh connection from ``agent_connect``. The connection is always closed,
    even if publishing fails.

    :param client_id: agent identifier; also used as the middle topic level.
    :param onion_hostname: value stored under ``onionAddress`` in the payload.
    :param use_tls: forwarded to ``agent_connect``.
    """
    payload = {
        'clientId': client_id,
        'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
        'onionAddress': onion_hostname,
        'sshPort': 22
    }
    topic = "torch/" + client_id + "/wake"

    # NOTE(review): fixed sleeps are the only synchronisation with the
    # subscriber; presumably they give it time to connect and then to persist
    # the message before the caller reads the database.
    time.sleep(1)
    client = agent_connect(use_tls=use_tls)
    try:
        client.publish(topic, json.dumps(payload))
    finally:
        # Release the socket even if publish() raises (e.g. ValueError for a
        # topic containing wildcards), so a failing test does not leak it.
        client.disconnect()
    time.sleep(1)
class GivenBroker(TestCase):
    """Base fixture with helpers to start the broker and the subscriber under test.

    This class's ``setUp``/``tearDown`` delete the subscriber's database file;
    ``tearDown`` also stops the ``mosquitto`` docker container. Concrete test
    classes choose the broker mode and subscriber flavour in their own
    ``setUp``.
    """

    def setUp(self) -> None:
        self._remove_database()

    def tearDown(self) -> None:
        # Stop the broker container (presumably the one the run-broker script
        # starts under the name "mosquitto"), then discard the database.
        os.system("docker container stop mosquitto")
        self._remove_database()

    @staticmethod
    def _remove_database():
        """Delete ``torch_sub.database_file``; a missing file is not an error."""
        # EAFP instead of exists()-then-remove(): the subscriber runs on its own
        # thread, so checking first and deleting afterwards is inherently racy.
        with contextlib.suppress(FileNotFoundError):
            os.remove(torch_sub.database_file)

    @staticmethod
    def run_broker(tls):
        """Start the broker via the platform's helper script, then wait 2 s.

        :param tls: mode argument passed to the script ("secure" or
            "insecure" in this module).
        """
        if sys.platform.startswith('win32'):
            script = "test\\run-broker.bat"
        else:
            script = "test/run-broker.sh"
        os.system(script + " " + tls)
        # Fixed grace period for the broker to start accepting connections.
        time.sleep(2)

    @staticmethod
    def _start_subscriber(tls_options):
        """Run ``torch_sub.subscribe`` on a daemon thread for ``torch/+/wake``.

        :param tls_options: TLS settings dict passed through to the subscriber,
            or ``None`` for a plain connection.
        """
        # NOTE(review): these threads are never stopped, so each test leaves one
        # behind; daemon=True only keeps them from blocking interpreter exit.
        threading.Thread(target=torch_sub.subscribe,
                         args=(broker_hostname,
                               broker_port,
                               "torch/+/wake",
                               tls_options),
                         daemon=True).start()

    @staticmethod
    def run_subscriber():
        """Start the subscriber with its TLS CA, certificate and key."""
        GivenBroker._start_subscriber({
            'ca_certs': subscriber_ca_file,
            'certfile': subscriber_cert_file,
            'keyfile': subscriber_key_file
        })

    @staticmethod
    def run_insecure_subscriber():
        """Start the subscriber with no TLS options (``None``)."""
        GivenBroker._start_subscriber(None)

    @staticmethod
    def loadDatabase():
        """Parse and return the subscriber's JSON database file."""
        with open(torch_sub.database_file, "r") as database_file:
            return json.load(database_file)
class GivenTlsBroker(GivenBroker):
    """Subscriber scenarios with the broker in "secure" mode and TLS clients."""

    def setUp(self) -> None:
        # Run the base fixture first so a database file left over from an
        # earlier (e.g. aborted) run is removed before the subscriber starts.
        super().setUp()
        self.run_broker("secure")
        self.run_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        # client2 publishes twice; only its most recent address should remain.
        publish("client2", "crazy_onion2-34.onion")
        publish("client3", "crazy_onion3.onion")
        publish("client1", "crazy_onion1.onion")
        publish("client2", "crazy_onion2-56.onion")
        publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")
class GivenNonTlsBroker(GivenBroker):
    """Subscriber scenarios with the broker in "insecure" mode and plain clients."""

    def setUp(self) -> None:
        # Run the base fixture first so a database file left over from an
        # earlier (e.g. aborted) run is removed before the subscriber starts.
        super().setUp()
        self.run_broker("insecure")
        self.run_insecure_subscriber()

    def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
        self.insecure_publish("client1", "crazy_onion.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")

    def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
        self.insecure_publish("client2", "crazy_onion2.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")

    def test_when_agent_publishes_multiple_hosts_should_provide_latest(self):
        # client2 publishes twice; only its most recent address should remain.
        self.insecure_publish("client2", "crazy_onion2-34.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        self.insecure_publish("client1", "crazy_onion1.onion")
        self.insecure_publish("client2", "crazy_onion2-56.onion")
        self.insecure_publish("client3", "crazy_onion3.onion")
        database = self.loadDatabase()
        self.assertEqual(database['client1']['onionAddress'], "crazy_onion1.onion")
        self.assertEqual(database['client2']['onionAddress'], "crazy_onion2-56.onion")
        self.assertEqual(database['client3']['onionAddress'], "crazy_onion3.onion")

    @staticmethod
    def insecure_publish(client_id, onion_address):
        """Publish a wake-up message over a plain (non-TLS) connection."""
        publish(client_id, onion_address, use_tls=False)
if __name__ == '__main__':
    # Executed directly (not imported by a runner): hand control to unittest's
    # command-line runner, which loads the tests defined in this module.
    unittest.main()