...
 
Commits (2)
import os
import socket
import asyncio
MAX_QLEN = 128
class AsyncSystemdNotifier:
""" Boilerplate for proper implementation. This one, however,
......@@ -12,27 +15,50 @@ class AsyncSystemdNotifier:
else env_var)
self._sock = None
self._started = False
self._loop = None
self._queue = asyncio.Queue(MAX_QLEN)
@property
def started(self):
return self._started
def _drain(self):
try:
while not self._queue.empty():
msg = self._queue.get_nowait()
self._queue.task_done()
self._send(msg)
except BlockingIOError: # pragma: no cover
pass
except OSError:
pass
def _send(self, data):
return self._sock.sendto(data, socket.MSG_NOSIGNAL, self._addr)
async def start(self):
if self._addr is None:
return False
self._loop = asyncio.get_event_loop()
try:
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
self._sock.setblocking(0)
self._loop.add_writer(self._sock.fileno(), self._drain)
self._started = True
except socket.error:
except OSError:
return False
return True
async def notify(self, status):
if self._started:
try:
self._sock.sendto(status, socket.MSG_NOSIGNAL, self._addr)
except socket.error:
pass
await self._queue.put(status)
self._drain()
async def stop(self):
if self._started:
self._started = False
await self._queue.join()
self._loop.remove_writer(self._sock.fileno())
self._sock.close()
async def __aenter__(self):
......
......@@ -49,6 +49,9 @@ class SingleNetstringFetcher:
def done(self):
return self._done
def pending(self):
return self._len is not None
def read(self, nbytes=65536):
# pylint: disable=too-many-branches
if not self._len_known:
......@@ -109,6 +112,9 @@ class StreamReader:
self._incoming = ssl.MemoryBIO()
self._fetcher = None
def pending(self):
return self._fetcher is not None and self._fetcher.pending()
def feed(self, data):
self._incoming.write(data)
......@@ -138,4 +144,5 @@ def decode(data):
res.append(buf)
yield b''.join(res)
except WantRead:
pass
if reader.pending():
raise IncompleteNetstring("Input ends on unfinished string.")
import contextlib
import socket
import asyncio
import os
import sys
import pytest
from postfix_mta_sts_resolver.asdnotify import AsyncSystemdNotifier
@contextlib.contextmanager
def set_env(**environ):
old_environ = dict(os.environ)
os.environ.update(environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(old_environ)
class UnixDatagramReceiver:
def __init__(self, loop):
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
self._sock.setblocking(0)
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._sock.bind('')
self._name = self._sock.getsockname()
self._incoming = asyncio.Queue()
self._loop = loop
loop.add_reader(self._sock.fileno(), self._read_handler)
def _read_handler(self):
try:
while True:
msg = self._sock.recv(4096)
self._incoming.put_nowait(msg)
except BlockingIOError: # pragma: no cover
pass
async def recvmsg(self):
return await self._incoming.get()
@property
def name(self):
return self._name
@property
def asciiname(self):
sockname = self.name
if isinstance(sockname, bytes):
sockname = sockname.decode('ascii')
if sockname.startswith('\x00'):
sockname = '@' + sockname[1:]
return sockname
def close(self):
self._loop.remove_reader(self._sock.fileno())
self._sock.close()
self._sock = None
pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
@pytest.fixture(scope="module")
def unix_dgram_receiver(event_loop):
udr = UnixDatagramReceiver(event_loop)
yield udr
udr.close()
@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_message_sent(unix_dgram_receiver):
sockname = unix_dgram_receiver.asciiname
msg = b"READY=1"
with set_env(NOTIFY_SOCKET=sockname):
async with AsyncSystemdNotifier() as notifier:
await notifier.notify(msg)
assert await unix_dgram_receiver.recvmsg() == msg
@pytest.mark.timeout(5)
@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
@pytest.mark.asyncio
async def test_message_flow(unix_dgram_receiver):
sockname = unix_dgram_receiver.asciiname
msgs = [b"READY=1", b'STOPPING=1'] * 500
with set_env(NOTIFY_SOCKET=sockname):
async with AsyncSystemdNotifier() as notifier:
for msg in msgs:
await notifier.notify(msg)
assert await unix_dgram_receiver.recvmsg() == msg
@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_not_started():
async with AsyncSystemdNotifier() as notifier:
assert not notifier.started
@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_started(unix_dgram_receiver):
with set_env(NOTIFY_SOCKET=unix_dgram_receiver.asciiname):
async with AsyncSystemdNotifier() as notifier:
assert notifier.started
@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_send_never_fails():
with set_env(NOTIFY_SOCKET='abc'):
async with AsyncSystemdNotifier() as notifier:
await notifier.notify(b'!!!')
@pytest.mark.timeout(5)
@pytest.mark.asyncio
async def test_socket_create_failure(monkeypatch):
class mocksock:
def __init__(self, *args, **kwargs):
raise OSError()
monkeypatch.setattr(socket, "socket", mocksock)
with set_env(NOTIFY_SOCKET='abc'):
async with AsyncSystemdNotifier() as notifier:
await notifier.notify(b'!!!')
......@@ -18,6 +18,7 @@ def test_leading_zeroes(reference, sample):
assert reference == list(netstring.decode(sample))
@pytest.mark.parametrize("reference,sample", [
pytest.param([], b'', id="nodata"),
pytest.param([b''], b'0:,', id="empty"),
pytest.param([b'5:Hello,6:World!,'], b'17:5:Hello,6:World!,,', id="nested"),
])
......@@ -29,6 +30,18 @@ def test_bad_length(encoded):
with pytest.raises(netstring.BadLength):
list(netstring.decode(encoded))
@pytest.mark.parametrize("encoded", [b'3', b'3:', b'3:a', b'3:aa', b'3:aaa'])
def test_decode_incomplete_string(encoded):
with pytest.raises(netstring.IncompleteNetstring):
list(netstring.decode(encoded))
def test_abandoned_string_reader_handles():
stream_reader = netstring.StreamReader()
stream_reader.feed(b'0:,')
string_reader = stream_reader.next_string()
with pytest.raises(netstring.InappropriateParserState):
string_reader = stream_reader.next_string()
@pytest.mark.parametrize("encoded", [b'0:_', b'3:aaa_'])
def test_bad_terminator(encoded):
with pytest.raises(netstring.BadTerminator):
......