Commits (5)
import os
import socket
class AsyncSystemdNotifier:
""" Boilerplate for proper implementation. This one, however,
also will work. """
def __init__(self):
env_var = os.getenv('NOTIFY_SOCKET')
self._addr = ('\0' + env_var[1:]
if env_var is not None and env_var.startswith('@')
else env_var)
self._sock = None
self._started = False
async def start(self):
if self._addr is None:
return False
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
self._started = True
except socket.error:
return False
return True
async def notify(self, status):
if self._started:
self._sock.sendto(status, socket.MSG_NOSIGNAL, self._addr)
except socket.error:
async def stop(self):
if self._started:
async def __aenter__(self):
await self.start()
return self
async def __aexit__(self, exc_type, exc, traceback):
await self.stop()
......@@ -7,7 +7,7 @@ import logging
import signal
from functools import partial
from sdnotify import SystemdNotifier
from .asdnotify import AsyncSystemdNotifier
from . import utils
from . import defaults
from .responder import STSSocketmapResponder
......@@ -72,11 +72,11 @@ async def amain(cfg, loop): # pragma: no cover
sig_handler = partial(exit_handler, exit_event)
signal.signal(signal.SIGTERM, sig_handler)
signal.signal(signal.SIGINT, sig_handler)
notifier = await loop.run_in_executor(None, SystemdNotifier)
await loop.run_in_executor(None, notifier.notify, "READY=1")
await exit_event.wait()
logger.debug("Eventloop interrupted. Shutting down server...")
await loop.run_in_executor(None, notifier.notify, "STOPPING=1")
async with AsyncSystemdNotifier() as notifier:
await notifier.notify(b"READY=1")
await exit_event.wait()
logger.debug("Eventloop interrupted. Shutting down server...")
await notifier.notify(b"STOPPING=1")
await responder.stop()
......@@ -49,7 +49,7 @@ class SingleNetstringFetcher:
def done(self):
return self._done
def read(self):
def read(self, nbytes=65536):
# pylint: disable=too-many-branches
if not self._len_known:
# reading length
......@@ -70,7 +70,7 @@ class SingleNetstringFetcher:
raise TooLong("Netstring length is over limit.")
# reading data
if self._len:
buf = self._incoming.read(self._len)
buf = self._incoming.read(min(nbytes, self._len))
if not buf:
raise WantRead()
self._len -= len(buf)
......@@ -7,7 +7,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read() # pylint: disable=invalid-name
description='Daemon which provides TLS client policy for Postfix '
'via socketmap, according to domain MTA-STS policy',
......@@ -23,7 +23,6 @@ setup(name='postfix_mta_sts_resolver',
'sqlite': 'aiosqlite>=0.10.0',