Can create database file if doesn't exist; Refactored

This commit is contained in:
B.J. Dweck 2020-10-15 19:48:48 +02:00
parent e580fd343b
commit 53a0b3cbb6
3 changed files with 56 additions and 66 deletions

View File

@ -1 +0,0 @@
127.0.0.1 mqtt.example.com

View File

@ -1,4 +1,5 @@
import json import json
import os
import ssl import ssl
import threading import threading
import time import time
@ -9,8 +10,8 @@ import paho.mqtt.client as mqtt
from torch_sub import torch_sub from torch_sub import torch_sub
host = "mqtt.example.com" broker_hostname = "mqtt.example.com"
port = 8883 broker_port = 8883
agent_config_path = "agent-config/" agent_config_path = "agent-config/"
mqtt_ca_file = agent_config_path + "ca.crt" mqtt_ca_file = agent_config_path + "ca.crt"
mqtt_cert_file = agent_config_path + "vagrant.crt" mqtt_cert_file = agent_config_path + "vagrant.crt"
@ -29,11 +30,12 @@ def agent_connect():
certfile=mqtt_cert_file, certfile=mqtt_cert_file,
keyfile=mqtt_key_file, keyfile=mqtt_key_file,
cert_reqs=ssl.CERT_REQUIRED) cert_reqs=ssl.CERT_REQUIRED)
client.connect(host, port, 60) client.connect(broker_hostname, broker_port, 60)
return client return client
def agent_publish(client_id, onion_hostname): def agent_publish(client_id, onion_hostname):
payload = { payload = {
'clientId': client_id, 'clientId': client_id,
'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"), 'timestamp': datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)"),
@ -41,54 +43,50 @@ def agent_publish(client_id, onion_hostname):
'sshPort': 22 'sshPort': 22
} }
time.sleep(0.2)
client = agent_connect() client = agent_connect()
client.publish("torch/" + client_id + "/wake", json.dumps(payload))
topic = "torch/" + client_id + "/wake"
client.publish(topic, json.dumps(payload))
print("Debug: Connected to MQTT Broker at %s://%s:%s/%s" % ("mqtts", host, port, topic))
print("Debug: Published payload: " + json.dumps(payload))
client.disconnect() client.disconnect()
print("Debug: Disconnected from MQTT Broker") time.sleep(0.2)
pass
def subscriber_thread(host, port, filename, topic, cafile, certfile, keyfile):
torch_sub.attach(host,
port,
filename,
topic=topic,
mqtt_ca_file=cafile,
mqtt_cert_file=certfile,
mqtt_key_file=keyfile)
class GivenBrokerAndTorchAgent(unittest.TestCase): class GivenBrokerAndTorchAgent(unittest.TestCase):
def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
outfile = "clients.json"
threading.Thread(target=subscriber_thread, def setUp(self) -> None:
args=(host, if os.path.exists(torch_sub.database_file):
port, os.remove(torch_sub.database_file)
outfile,
def tearDown(self) -> None:
if os.path.exists(torch_sub.database_file):
os.remove(torch_sub.database_file)
def test_when_agent_publishes_should_get_hostname_from_subscriber(self):
self.run_subscriber()
agent_publish("client1", "crazy_onion.onion")
database = self.loadDatabase()
self.assertEqual(database['client1']['onionAddress'], "crazy_onion.onion")
def test_when_agent_publishes_should_get_hostname_from_subscriber2(self):
self.run_subscriber()
agent_publish("client2", "crazy_onion2.onion")
database = self.loadDatabase()
self.assertEqual(database['client2']['onionAddress'], "crazy_onion2.onion")
@staticmethod
def run_subscriber():
threading.Thread(target=torch_sub.subscribe,
args=(broker_hostname,
broker_port,
"torch/+/wake", "torch/+/wake",
subscriber_ca_file, subscriber_ca_file,
subscriber_cert_file, subscriber_cert_file,
subscriber_key_file), subscriber_key_file),
daemon=True).start() daemon=True).start()
time.sleep(0.5) @staticmethod
def loadDatabase():
agent_publish("client1", "crazyonion.onion") with open(torch_sub.database_file, "r") as database_file:
return json.load(database_file)
time.sleep(0.5)
file = open(outfile, "r")
response = json.load(file)
self.assertEqual(response['client1']['onionAddress'], "crazyonion.onion")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,44 +1,37 @@
import json import json
import os
import paho.mqtt.client as cl
import paho.mqtt.subscribe as mqtt import paho.mqtt.subscribe as mqtt
datafile = "clients.json" database_file = "clients.json"
def updateOnion(client, userdata, message):
with open(datafile, 'r') as infile: # noinspection PyUnusedLocal
def update_client_record(client, userdata, message):
if not os.path.exists(database_file):
with open(database_file, 'w') as database_blank:
database_blank.write("{}")
with open(database_file, 'r') as infile:
database = json.load(infile) database = json.load(infile)
payload = message.payload.decode('utf-8') payload = message.payload.decode('utf-8')
print("Payload: %s" % (payload))
response = json.loads(payload) response = json.loads(payload)
print(response)
print("Response: %s" % (response))
print("Database: %s" % (database))
database[response['clientId']] = response database[response['clientId']] = response
print("got one! %s %s %s" % (client, userdata, payload)) with open(database_file, 'w') as outfile:
with open(datafile, 'w') as outfile:
json.dump(database, outfile) json.dump(database, outfile)
def subscribe(broker_hostname, broker_port, topic="torch", ca_file=None, cert_file=None, key_file=None):
def attach(host, port, datafileIn, mqtt_ca_file=None, mqtt_cert_file=None, mqtt_key_file=None, topic=None): mqtt.callback(update_client_record,
datafile = datafileIn
mqtt.callback(updateOnion,
topic, topic,
hostname=host, hostname=broker_hostname,
port=port, port=broker_port,
tls={ tls={
'ca_certs': mqtt_ca_file, 'ca_certs': ca_file,
'certfile': mqtt_cert_file, 'certfile': cert_file,
'keyfile': mqtt_key_file 'keyfile': key_file
}) })