# wg-manager/wg_connection_manager/dj_wg_manager_task.py
#
# 69 lines
# 2.7 KiB
# Python

from django.db import transaction
from .models import UserConnection
from pywireguard.factory import Peer
import environ
from celery import Celery
from .wg_manager import WGManager
# Declare each setting with its cast type and fallback default.
env = environ.Env(
    WG_INTERFACE=(str, "wg0"),
    CELERY_BROKER=(str, "redis://localhost"),
    CELERY_BACKEND=(str, "redis://localhost"),
)
# Load ".env" into the environment before any value below is read. The path
# is relative, so it is resolved against the process's working directory.
environ.Env.read_env(".env")

WG_INTERFACE = env("WG_INTERFACE")
CELERY_BROKER = env("CELERY_BROKER")
CELERY_BACKEND = env("CELERY_BACKEND")
class DJWGManager:
    """Django-side client that drives a remote WireGuard worker over Celery.

    Peer operations are sent as Celery tasks to the
    ``wg_connection_manager_worker`` and their results are awaited
    synchronously. After every mutation, the ``UserConnection`` table is
    re-synchronised with the peers the worker reports.

    Security note: tasks and results are (de)serialised with pickle, and
    ``application/x-python-serialize`` is listed in ``accept_content``.
    Unpickling can execute arbitrary code, so the broker and result backend
    must only be writable by trusted parties.
    """

    # Declared for type checkers only; nothing in this class assigns it.
    wg_manager: WGManager
    app: Celery

    def __init__(self, *, result_timeout=None):
        """Create the Celery client used to talk to the worker.

        Args:
            result_timeout: Seconds to wait for each task result before
                ``AsyncResult.get`` raises. ``None`` (the default) waits
                indefinitely, which matches the previous behaviour.
        """
        self.result_timeout = result_timeout
        self.app = Celery("wg_manager_tasks", broker=CELERY_BROKER, backend=CELERY_BACKEND)
        # Pickle is needed because pywireguard ``Peer`` objects (with bytes
        # keys) are sent to and returned from the worker as-is. Setting the
        # event serializer is optional; the client worked without it.
        self.app.conf.event_serializer = "pickle"
        self.app.conf.task_serializer = "pickle"
        self.app.conf.result_serializer = "pickle"
        self.app.conf.accept_content = ["application/json", "application/x-python-serialize"]

    def _run_task(self, name, args=None):
        """Send task ``name`` with positional ``args`` and block until its result arrives."""
        result = self.app.send_task(name, args)
        return result.get(timeout=self.result_timeout)

    def sync(self):
        """Mirror the worker's current peer list into ``UserConnection``.

        All active connections are first marked inactive. Then each peer
        reported by the worker that has at least one allowed IP is upserted by
        public key and marked active. Peers without allowed IPs are skipped,
        so their rows stay inactive.
        """
        # Query the worker before opening the transaction so the transaction
        # is not held open while waiting on a remote task.
        peers = self._run_task("wg_connection_manager_worker.tasks.get_peers")
        # All DB writes happen in one transaction, so no reader sees the
        # intermediate "everything inactive" state.
        with transaction.atomic():
            UserConnection.objects.filter(active=True).update(active=False)
            for peer in peers:
                if not peer.allowed_ips:
                    continue
                pk = peer.public_key.decode("ascii")
                psk = peer.preshared_key.decode("ascii")
                # Reuse the first row for this key. Creating a new row
                # whenever more than one already exists would add yet another
                # duplicate on every sync.
                connection = UserConnection.objects.filter(public_key=pk).first()
                if connection is None:
                    connection = UserConnection()
                connection.public_key = pk
                connection.preshared_key = psk
                connection.active = True
                # Only the first allowed IP is stored.
                connection.vpn_ip = peer.allowed_ips[0]
                connection.save()

    def add_peer(self, user_connection: UserConnection):
        """Register ``user_connection`` as a peer on the worker, then resync."""
        peer = Peer(
            public_key=user_connection.public_key,
            preshared_key=user_connection.preshared_key,
            allowed_ips=[user_connection.vpn_ip],
        )
        self._run_task("wg_connection_manager_worker.tasks.add_peer", [peer])
        self.sync()

    def remove_peer(self, user_connection: UserConnection):
        """Remove the worker peer whose public key matches ``user_connection``, then resync.

        If the worker reports no such peer, this returns without sending a
        removal task and without resyncing.
        """
        pk = user_connection.public_key
        peers = self._run_task("wg_connection_manager_worker.tasks.get_peers")
        peer = next((p for p in peers if p.public_key.decode("ascii") == pk), None)
        if peer is None:
            # TODO: decide whether a missing peer should raise instead of being ignored.
            return
        self._run_task("wg_connection_manager_worker.tasks.remove_peer", [peer])
        self.sync()