......@@ -5,7 +5,6 @@ postfix-mta-sts-resolver
Daemon which provides TLS client policy for Postfix via socketmap, according to domain MTA-STS policy. Current support of RFC 8461 is limited; the daemon lacks some minor features:
* Proactive policy fetch
* Fetch error reporting
* Fetch rate limiting (though the actual fetch rate is partially restricted by the `cache_grace` config option).
......
......@@ -43,6 +43,14 @@ The file is in YAML syntax with the following elements:
** Options for _redis_ type:
*** All parameters are passed to `aioredis.create_redis_pool` [0]. See its documentation for a parameter reference.
*proactive_policy_fetching*::
* *enabled*: (_bool_) enable proactive policy fetching in the background. Default: false
* *interval*: (_int_) if proactive policy fetching is enabled, a refresh of all cached policies is scheduled every this many seconds.
The interval is independent of `cache_grace` and vice versa. Default: 86400
* *concurrency_limit*: (_int_) the maximum number of concurrent domain updates. Default: 100
* *grace_ratio*: (_float_) the proactive fetch for a particular domain is skipped if its cached policy age is less than `interval/grace_ratio` (with the defaults, policies younger than 43200 seconds are left untouched). Default: 2.0
*default_zone*::
* *strict_testing*: (_bool_) enforce policy for testing domains
......@@ -65,6 +73,11 @@ domains operate under "testing" mode.
port: 8461
reuse_port: true
shutdown_timeout: 20
proactive_policy_fetching:
enabled: true
interval: 86400
concurrency_limit: 100
grace_ratio: 2
cache:
type: internal
options:
......
import asyncio
import collections
from abc import ABC, abstractmethod
......@@ -19,6 +20,26 @@ class BaseCache(ABC):
async def set(self, key, value):
""" Abstract method """
async def safe_set(self, domain, entry, logger):
""" Like set(), but logs and suppresses any error except task cancellation """
try:
await self.set(domain, entry)
except asyncio.CancelledError: # pragma: no cover pylint: disable=try-except-raise
raise
except Exception as exc: # pragma: no cover
logger.exception("Cache set failed: %s", str(exc))
@abstractmethod
async def scan(self, token, amount_hint):
""" Abstract method """
@abstractmethod
async def get_proactive_fetch_ts(self):
""" Abstract method """
@abstractmethod
async def set_proactive_fetch_ts(self, timestamp):
""" Abstract method """
@abstractmethod
async def teardown(self):
""" Abstract method """
......@@ -2,3 +2,5 @@ HARD_RESP_LIMIT = 64 * 1024
CHUNK = 4096
QUEUE_LIMIT = 128
REQUEST_LIMIT = 1024
DOMAIN_QUEUE_LIMIT = 1000
MIN_PROACTIVE_FETCH_INTERVAL = 1
......@@ -10,7 +10,9 @@ from functools import partial
from .asdnotify import AsyncSystemdNotifier
from . import utils
from . import defaults
from .proactive_fetcher import STSProactiveFetcher
from .responder import STSSocketmapResponder
from .utils import create_cache
def parse_args():
......@@ -61,12 +63,28 @@ async def heartbeat():
async def amain(cfg, loop): # pragma: no cover
logger = logging.getLogger("MAIN")
# Construct request handler instance
responder = STSSocketmapResponder(cfg, loop)
proactive_fetch_enabled = cfg['proactive_policy_fetching']['enabled']
# Create policy cache
cache = create_cache(cfg["cache"]["type"],
cfg["cache"]["options"])
await cache.setup()
# Construct request handler
responder = STSSocketmapResponder(cfg, loop, cache)
await responder.start()
logger.info("Server started.")
# Conditionally construct proactive policy fetcher
proactive_fetcher = None
if proactive_fetch_enabled:
proactive_fetcher = STSProactiveFetcher(cfg, loop, cache)
await proactive_fetcher.start()
logger.info("Proactive policy fetcher started.")
else:
logger.info("Proactive policy fetching is disabled.")
exit_event = asyncio.Event()
beat = asyncio.ensure_future(heartbeat())
sig_handler = partial(exit_handler, exit_event)
......@@ -79,6 +97,9 @@ async def amain(cfg, loop): # pragma: no cover
await notifier.notify(b"STOPPING=1")
beat.cancel()
await responder.stop()
if proactive_fetch_enabled:
await proactive_fetcher.stop()
await cache.teardown()
def main(): # pragma: no cover
......@@ -87,6 +108,7 @@ def main(): # pragma: no cover
with utils.AsyncLoggingHandler(args.logfile) as log_handler:
logger = utils.setup_logger('MAIN', args.verbosity, log_handler)
utils.setup_logger('STS', args.verbosity, log_handler)
utils.setup_logger('PF', args.verbosity, log_handler)
logger.info("MTA-STS daemon starting...")
# Read config and populate with defaults
......
......@@ -13,4 +13,8 @@ SQLITE_THREADS = cpu_count()
SQLITE_TIMEOUT = 5
REDIS_TIMEOUT = 5
CACHE_GRACE = 60
PROACTIVE_FETCH_ENABLED = False
PROACTIVE_FETCH_INTERVAL = 86400
PROACTIVE_FETCH_CONCURRENCY_LIMIT = 100
PROACTIVE_FETCH_GRACE_RATIO = 2.0
USER_AGENT = "postfix-mta-sts-resolver"
import collections
from itertools import islice
from .base_cache import BaseCache
......@@ -7,6 +8,7 @@ class InternalLRUCache(BaseCache):
def __init__(self, cache_size=10000):
self._cache_size = cache_size
self._cache = collections.OrderedDict()
self._proactive_fetch_ts = 0
async def setup(self):
pass
......@@ -29,3 +31,25 @@ class InternalLRUCache(BaseCache):
if len(self._cache) >= self._cache_size:
self._cache.popitem(last=False)
self._cache[key] = value
async def scan(self, token, amount_hint):
if token is None:
token = 0
total = len(self._cache)
left = total - token
if left > 0:
amount = min(left, amount_hint)
new_token = token + amount if token + amount < total else None
# Take the "amount" oldest entries
result = list(islice(self._cache.items(), amount))
# Touch the scanned keys so they move to the MRU end; the next scan
# then starts at the remaining oldest entries, keeping the token
# arithmetic consistent with the LRU order.
for key, _ in result:
await self.get(key)
return new_token, result
return None, []
async def get_proactive_fetch_ts(self):
return self._proactive_fetch_ts
async def set_proactive_fetch_ts(self, timestamp):
self._proactive_fetch_ts = timestamp
import asyncio
import logging
import time
from postfix_mta_sts_resolver import constants
from postfix_mta_sts_resolver.base_cache import CacheEntry
from postfix_mta_sts_resolver.resolver import STSResolver, STSFetchResult
# pylint: disable=too-many-instance-attributes
class STSProactiveFetcher:
def __init__(self, cfg, loop, cache):
self._shutdown_timeout = cfg['shutdown_timeout']
self._pf_interval = cfg['proactive_policy_fetching']['interval']
self._pf_concurrency_limit = cfg['proactive_policy_fetching']['concurrency_limit']
self._pf_grace_ratio = cfg['proactive_policy_fetching']['grace_ratio']
self._logger = logging.getLogger("PF")
self._loop = loop
self._cache = cache
self._periodic_fetch_task = None
self._resolver = STSResolver(loop=loop,
timeout=cfg["default_zone"]["timeout"])
async def process_domain(self, domain_queue):
async def update(cached):
status, policy = await self._resolver.resolve(domain, cached.pol_id)
if status is STSFetchResult.VALID:
pol_id, pol_body = policy
updated = CacheEntry(ts, pol_id, pol_body)
await self._cache.safe_set(domain, updated, self._logger)
elif status is STSFetchResult.NOT_CHANGED:
updated = CacheEntry(ts, cached.pol_id, cached.pol_body)
await self._cache.safe_set(domain, updated, self._logger)
else:
self._logger.warning("Domain %s does not have a valid policy.", domain)
while True: # Run until cancelled
cache_item = await domain_queue.get()
ts = time.time() # pylint: disable=invalid-name
try:
domain, cached = cache_item
if ts - cached.ts < self._pf_interval / self._pf_grace_ratio:
self._logger.debug("Domain %s skipped (cache recent enough).", domain)
else:
await update(cached)
except asyncio.CancelledError: # pragma: no cover pylint: disable=try-except-raise
raise
except Exception as exc: # pragma: no cover
self._logger.exception("Unhandled exception: %s", exc)
finally:
domain_queue.task_done()
async def iterate_domains(self):
self._logger.info("Proactive policy fetching "
"for all domains in cache started...")
# Create domain processor tasks
domain_processors = []
domain_queue = asyncio.Queue(maxsize=constants.DOMAIN_QUEUE_LIMIT)
for _ in range(self._pf_concurrency_limit):
domain_processor = self._loop.create_task(self.process_domain(domain_queue))
domain_processors.append(domain_processor)
# Produce work for domain processors
try:
token = None
while True:
token, cache_items = await self._cache.scan(token, constants.DOMAIN_QUEUE_LIMIT)
self._logger.debug("Enqueued %d domains for processing.", len(cache_items))
for cache_item in cache_items:
await domain_queue.put(cache_item)
if token is None:
break
# Wait for queue to clear
await domain_queue.join()
# Clean up the domain processors
finally:
for domain_processor in domain_processors:
domain_processor.cancel()
await asyncio.gather(*domain_processors, return_exceptions=True)
# Update the proactive fetch timestamp
await self._cache.set_proactive_fetch_ts(time.time())
self._logger.info("Proactive policy fetching "
"for all domains in cache finished.")
async def fetch_periodically(self):
while True: # Run until cancelled
next_fetch_ts = await self._cache.get_proactive_fetch_ts() + self._pf_interval
sleep_duration = max(constants.MIN_PROACTIVE_FETCH_INTERVAL,
next_fetch_ts - time.time() + 1)
self._logger.debug("Sleeping for %ds until next fetch.", sleep_duration)
await asyncio.sleep(sleep_duration)
await self.iterate_domains()
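# Hypothetical walk-through of the sleep computation above: if the last
# proactive fetch recorded in the cache finished at t=1000, interval=86400
# and time.time() returns 1500, then next_fetch_ts = 1000 + 86400 = 87400
# and sleep_duration = max(1, 87400 - 1500 + 1) = 85901 seconds.
# MIN_PROACTIVE_FETCH_INTERVAL (1s) only takes effect when the computed
# duration would be zero or negative, e.g. after downtime longer than the
# interval.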
async def start(self):
self._periodic_fetch_task = self._loop.create_task(self.fetch_periodically())
async def stop(self):
self._periodic_fetch_task.cancel()
try:
self._logger.warning("Waiting for the periodic fetch task to finish...")
await self._periodic_fetch_task
except asyncio.CancelledError: # pragma: no cover
pass
......@@ -55,6 +55,32 @@ class RedisCache(BaseCache):
pipe.zremrangebyrank(key, 0, -2)
await pipe.execute()
async def scan(self, token, amount_hint):
assert self._pool is not None
if token is None:
token = b'0'
new_token, keys = await self._pool.scan(cursor=token, count=amount_hint)
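# Note: Redis treats SCAN's "count" as a hint only, and a full scan
# iteration may return the same key more than once, so callers should
# tolerate duplicate entries.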
if not new_token:
new_token = None
result = []
for key in keys:
key = key.decode('utf-8')
if key != '_metadata':
result.append((key, await self.get(key)))
return new_token, result
async def get_proactive_fetch_ts(self):
assert self._pool is not None
val = await self._pool.hget('_metadata', 'proactive_fetch_ts')
return 0 if not val else float(val.decode('utf-8'))
async def set_proactive_fetch_ts(self, timestamp):
assert self._pool is not None
val = str(timestamp).encode('utf-8')
await self._pool.hset('_metadata', 'proactive_fetch_ts', val)
async def teardown(self):
assert self._pool is not None
self._pool.close()
await self._pool.wait_closed()
......
......@@ -9,7 +9,7 @@ from functools import partial
from .resolver import STSResolver, STSFetchResult
from .constants import QUEUE_LIMIT, CHUNK, REQUEST_LIMIT
from .utils import create_custom_socket, create_cache, filter_domain, is_ipaddr
from .utils import create_custom_socket, filter_domain, is_ipaddr
from .base_cache import CacheEntry
from . import netstring
......@@ -19,7 +19,7 @@ ZoneEntry = collections.namedtuple('ZoneEntry', ('strict', 'resolver'))
# pylint: disable=too-many-instance-attributes
class STSSocketmapResponder:
def __init__(self, cfg, loop):
def __init__(self, cfg, loop, cache):
self._logger = logging.getLogger("STS")
self._loop = loop
if cfg.get('path') is not None:
......@@ -44,12 +44,28 @@ class STSSocketmapResponder:
timeout=zone["timeout"])))
for k, zone in cfg["zones"].items())
# Construct cache
self._cache = create_cache(cfg["cache"]["type"],
cfg["cache"]["options"])
self._cache = cache
self._children = set()
self._server = None
# Check if cached record is nonexistent or stale
def is_stale(self, cached):
ts = time.time() # pylint: disable=invalid-name
# Nonexistent ?
if cached is None:
return True
# Expired grace period ?
if ts - cached.ts > self._grace:
return True
# Expired policy ?
if cached.pol_body['max_age'] + cached.ts < ts:
return True
return False
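# Worked example (hypothetical numbers): with cache_grace=60, an entry
# cached at ts=1000 whose policy carries max_age=604800 becomes stale as
# soon as now > 1060 (grace period expired), even though the policy
# itself would remain valid until ts=605800.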
async def start(self):
def _spawn(reader, writer):
def done_cb(task, fut):
......@@ -59,8 +75,6 @@ class STSSocketmapResponder:
self._children.add(task)
self._logger.debug("len(self._children) = %d", len(self._children))
await self._cache.setup()
if self._unix:
self._server = await asyncio.start_unix_server(_spawn, path=self._path)
if self._sockmode is not None:
......@@ -113,7 +127,6 @@ class STSSocketmapResponder:
await asyncio.sleep(1)
if not self._children:
break
await self._cache.teardown()
async def sender(self, queue, writer):
def cleanup_queue():
......@@ -146,15 +159,6 @@ class STSSocketmapResponder:
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
async def process_request(self, raw_req):
# Update local cache
async def cache_set(domain, entry):
try:
await self._cache.set(domain, entry)
except asyncio.CancelledError: # pragma: no cover pylint: disable=try-except-raise
raise
except Exception as exc: # pragma: no cover
self._logger.exception("Cache set failed: %s", str(exc))
have_policy = True
# Parse request and canonicalize domain
......@@ -181,10 +185,9 @@ class STSSocketmapResponder:
self._logger.exception("Cache get failed: %s", str(exc))
cached = None
ts = time.time() # pylint: disable=invalid-name
# Check if cached record exists and recent enough to omit
# DNS lookup and cache update
if cached is None or ts - cached.ts > self._grace:
if self.is_stale(cached):
ts = time.time() # pylint: disable=invalid-name
self._logger.debug("Lookup PERFORMED: domain = %s", domain)
# Check if newer policy exists or
# retrieve policy from scratch if there is no cached one
......@@ -193,11 +196,11 @@ class STSSocketmapResponder:
if status is STSFetchResult.NOT_CHANGED:
cached = CacheEntry(ts, cached.pol_id, cached.pol_body)
await cache_set(domain, cached)
await self._cache.safe_set(domain, cached, self._logger)
elif status is STSFetchResult.VALID:
pol_id, pol_body = policy
cached = CacheEntry(ts, pol_id, pol_body)
await cache_set(domain, cached)
await self._cache.safe_set(domain, cached, self._logger)
else:
if cached is None:
have_policy = False
......@@ -208,7 +211,6 @@ class STSSocketmapResponder:
else:
self._logger.debug("Lookup skipped: domain = %s", domain)
if have_policy:
mode = cached.pol_body['mode']
# pylint: disable=no-else-return
......
......@@ -79,6 +79,7 @@ class SqliteCache(BaseCache):
self._filename = filename
self._threads = threads
self._timeout = timeout
self._last_proactive_fetch_ts_id = 1
sqlitelogger = logging.getLogger("aiosqlite")
if not sqlitelogger.hasHandlers(): # pragma: no cover
sqlitelogger.addHandler(logging.NullHandler())
......@@ -97,6 +98,8 @@ class SqliteCache(BaseCache):
init_queries=conn_init)
await self._pool.prepare()
queries = [
"create table if not exists proactive_fetch_ts "
"(id integer primary key, last_fetch_ts integer)",
"create table if not exists sts_policy_cache "
"(domain text, ts integer, pol_id text, pol_body text)",
"create unique index if not exists sts_policy_domain on sts_policy_cache (domain)",
......@@ -108,6 +111,28 @@ class SqliteCache(BaseCache):
await cur.execute(q)
await conn.commit()
async def get_proactive_fetch_ts(self):
async with self._pool.borrow(self._timeout) as conn:
async with conn.execute('select last_fetch_ts from '
'proactive_fetch_ts where id = ?',
(self._last_proactive_fetch_ts_id,)) as cur:
res = await cur.fetchone()
return int(res[0]) if res is not None else 0
async def set_proactive_fetch_ts(self, timestamp):
async with self._pool.borrow(self._timeout) as conn:
try:
await conn.execute('insert into proactive_fetch_ts (last_fetch_ts, id) '
'values (?, ?)',
(int(timestamp), self._last_proactive_fetch_ts_id))
await conn.commit()
except sqlite3.IntegrityError:
await conn.execute('update proactive_fetch_ts '
'set last_fetch_ts = ? where id = ?',
(int(timestamp), self._last_proactive_fetch_ts_id))
await conn.commit()
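# A hypothetical single-statement alternative to the insert-then-update
# fallback above, assuming SQLite >= 3.24 (native upsert support):
#
#     await conn.execute(
#         'insert into proactive_fetch_ts (id, last_fetch_ts) '
#         'values (?, ?) on conflict (id) do update '
#         'set last_fetch_ts = excluded.last_fetch_ts',
#         (self._last_proactive_fetch_ts_id, int(timestamp)))
#     await conn.commit()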
async def get(self, key):
async with self._pool.borrow(self._timeout) as conn:
async with conn.execute('select ts, pol_id, pol_body from '
......@@ -138,5 +163,29 @@ class SqliteCache(BaseCache):
(int(ts), pol_id, pol_body, key, int(ts)))
await conn.commit()
async def scan(self, token, amount_hint):
if token is None:
token = 1
async with self._pool.borrow(self._timeout) as conn:
async with conn.execute('select rowid, ts, pol_id, pol_body, domain from '
'sts_policy_cache where rowid between ? and ?',
(token, token + amount_hint - 1)) as cur:
res = await cur.fetchall()
if res:
result = []
new_token = token
for row in res:
rowid, ts, pol_id, pol_body, domain = row
ts = int(ts)
rowid = int(rowid)
new_token = max(new_token, rowid)
pol_body = json.loads(pol_body)
result.append((domain, CacheEntry(ts, pol_id, pol_body)))
new_token += 1
return new_token, result
else:
return None, []
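# Note: a rowid range may contain gaps (e.g. after deletions), so a batch
# can hold fewer rows than amount_hint; that is fine, since amount_hint is
# only a hint. This scheme does assume rowids are reasonably dense: a gap
# wider than amount_hint would end the scan early.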
async def teardown(self):
await self._pool.stop()
......@@ -86,6 +86,17 @@ def populate_cfg_defaults(cfg):
defaults.SHUTDOWN_TIMEOUT)
cfg['cache_grace'] = cfg.get('cache_grace', defaults.CACHE_GRACE)
if 'proactive_policy_fetching' not in cfg:
cfg['proactive_policy_fetching'] = {}
cfg['proactive_policy_fetching']['enabled'] = cfg['proactive_policy_fetching'].\
get('enabled', defaults.PROACTIVE_FETCH_ENABLED)
cfg['proactive_policy_fetching']['interval'] = cfg['proactive_policy_fetching'].\
get('interval', defaults.PROACTIVE_FETCH_INTERVAL)
cfg['proactive_policy_fetching']['concurrency_limit'] = cfg['proactive_policy_fetching'].\
get('concurrency_limit', defaults.PROACTIVE_FETCH_CONCURRENCY_LIMIT)
cfg['proactive_policy_fetching']['grace_ratio'] = cfg['proactive_policy_fetching'].\
get('grace_ratio', defaults.PROACTIVE_FETCH_GRACE_RATIO)
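# A hypothetical, more compact equivalent of the per-key defaulting above:
#
#     pf = cfg.setdefault('proactive_policy_fetching', {})
#     for key, default in (
#             ('enabled', defaults.PROACTIVE_FETCH_ENABLED),
#             ('interval', defaults.PROACTIVE_FETCH_INTERVAL),
#             ('concurrency_limit', defaults.PROACTIVE_FETCH_CONCURRENCY_LIMIT),
#             ('grace_ratio', defaults.PROACTIVE_FETCH_GRACE_RATIO)):
#         pf.setdefault(key, default)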
if 'cache' not in cfg:
cfg['cache'] = {}
......
......@@ -7,7 +7,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read() # pylint: disable=invalid-name
setup(name='postfix_mta_sts_resolver',
version='0.7.5',
version='0.8.0',
description='Daemon which provides TLS client policy for Postfix '
'via socketmap, according to domain MTA-STS policy',
url='https://github.com/Snawoot/postfix-mta-sts-resolver',
......
......@@ -3,8 +3,8 @@ import os
import pytest
from postfix_mta_sts_resolver.utils import enable_uvloop
from postfix_mta_sts_resolver.utils import enable_uvloop, create_cache, populate_cfg_defaults
from async_generator import yield_, async_generator
@pytest.fixture(scope="session")
def event_loop():
......@@ -14,3 +14,24 @@ def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function")
@async_generator
async def function_cache_fixture():
cfg = populate_cfg_defaults(None)
cache = create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
await yield_(cache)
await cache.teardown()
@pytest.fixture(scope="module")
@async_generator
async def module_cache_fixture():
cfg = populate_cfg_defaults(None)
cache = create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
await yield_(cache)
await cache.teardown()
\ No newline at end of file
......@@ -2,6 +2,46 @@ import tempfile
import pytest
import postfix_mta_sts_resolver.utils as utils
import postfix_mta_sts_resolver.base_cache as base_cache
from postfix_mta_sts_resolver import constants
async def setup_cache(cache_type, cache_opts):
tmpfile = None
if cache_type == 'sqlite':
tmpfile = tempfile.NamedTemporaryFile()
cache_opts["filename"] = tmpfile.name
cache = utils.create_cache(cache_type, cache_opts)
await cache.setup()
if cache_type == 'redis':
await cache._pool.flushdb()
return cache, tmpfile
@pytest.mark.parametrize("cache_type,cache_opts,safe_set", [
("internal", {}, True),
("internal", {}, False),
("sqlite", {}, True),
("sqlite", {}, False),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, True),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, False)
])
@pytest.mark.asyncio
async def test_cache_lifecycle(cache_type, cache_opts, safe_set):
cache, tmpfile = await setup_cache(cache_type, cache_opts)
try:
assert await cache.get("nonexistent") is None
stored = base_cache.CacheEntry(0, "pol_id", "pol_body")
if safe_set:
await cache.safe_set("test", stored, None)
await cache.safe_set("test", stored, None) # second time for testing conflicting insert
else:
await cache.set("test", stored)
await cache.set("test", stored) # second time for testing conflicting insert
assert await cache.get("test") == stored
finally:
await cache.teardown()
if cache_type == 'sqlite':
tmpfile.close()
@pytest.mark.parametrize("cache_type,cache_opts", [
("internal", {}),
......@@ -9,20 +49,74 @@ import postfix_mta_sts_resolver.base_cache as base_cache
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}),
])
@pytest.mark.asyncio
async def test_cache_lifecycle(cache_type, cache_opts):
if cache_type == 'sqlite':
tmpfile = tempfile.NamedTemporaryFile()
cache_opts["filename"] = tmpfile.name
cache = utils.create_cache(cache_type, cache_opts)
await cache.setup()
assert await cache.get("nonexistent") == None
stored = base_cache.CacheEntry(0, "pol_id", "pol_body")
await cache.set("test", stored)
await cache.set("test", stored) # second time for testing conflicting insert
assert await cache.get("test") == stored
await cache.teardown()
if cache_type == 'sqlite':
tmpfile.close()
async def test_proactive_fetch_ts_lifecycle(cache_type, cache_opts):
cache, tmpfile = await setup_cache(cache_type, cache_opts)
try:
assert await cache.get_proactive_fetch_ts() >= 0 # works with empty db
await cache.set_proactive_fetch_ts(123)
await cache.set_proactive_fetch_ts(123) # second time for testing conflicting insert
assert await cache.get_proactive_fetch_ts() == 123
await cache.set_proactive_fetch_ts(321) # updating the db works
assert await cache.get_proactive_fetch_ts() == 321
finally:
await cache.teardown()
if cache_type == 'sqlite':
tmpfile.close()
@pytest.mark.parametrize("cache_type,cache_opts,n_items,batch_size_limit", [
("internal", {}, 3, 1),
("internal", {}, 3, 2),
("internal", {}, 3, 3),
("internal", {}, 3, 4),
("internal", {}, 0, 4),
("internal", {}, constants.DOMAIN_QUEUE_LIMIT*2, constants.DOMAIN_QUEUE_LIMIT),
("sqlite", {}, 3, 1),
("sqlite", {}, 3, 2),
("sqlite", {}, 3, 3),
("sqlite", {}, 3, 4),
("sqlite", {}, 0, 4),
("sqlite", {}, constants.DOMAIN_QUEUE_LIMIT*2, constants.DOMAIN_QUEUE_LIMIT),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, 3, 1),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, 3, 2),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, 3, 3),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, 3, 4),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, 0, 4),
("redis", {"address": "redis://127.0.0.1/0?timeout=5"}, constants.DOMAIN_QUEUE_LIMIT*2, constants.DOMAIN_QUEUE_LIMIT),
])
@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test_scanning_in_batches(cache_type, cache_opts, n_items, batch_size_limit):
# Prepare
cache, tmpfile = await setup_cache(cache_type, cache_opts)
data = []
for n in range(n_items):
item = ("test{:04d}".format(n+1), base_cache.CacheEntry(n+1, "pol_id", "pol_body"))
data.append(item)
await cache.set(*item)
# Test (scan)
token = None
scanned = []
while True:
token, cache_items = await cache.scan(token, batch_size_limit)
for cache_item in cache_items:
scanned.append(cache_item)
if token is None:
break
try:
# Verify scanned data is same as inserted (order agnostic)
assert len(scanned) == len(data)
assert sorted(scanned) == sorted(data)
# For internal LRU, verify it's scanned from LRU to MRU record
if cache_type == "internal":
assert scanned == data
finally:
await cache.teardown()
if cache_type == 'sqlite':
tmpfile.close()
@pytest.mark.asyncio
async def test_capped_cache():
......
import asyncio
import time
import pytest
from postfix_mta_sts_resolver import base_cache, utils
from postfix_mta_sts_resolver.proactive_fetcher import STSProactiveFetcher
from async_generator import yield_, async_generator
from postfix_mta_sts_resolver.utils import populate_cfg_defaults, create_cache
@pytest.fixture
@async_generator
async def cache():
cfg = populate_cfg_defaults(None)
cache = create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
await yield_(cache)
await cache.teardown()
@pytest.mark.parametrize("domain, init_policy_id, expected_policy_id, expected_update",
[("good.loc", "19990907T090909", "20180907T090909", True),
("good.loc", "20180907T090909", "20180907T090909", True),
("valid-none.loc", "19990907T090909", "20180907T090909", True),
("blackhole.loc", "19990907T090909", "19990907T090909", False),
("bad-record1.loc", "19990907T090909", "19990907T090909", False),
("bad-policy1.loc", "19990907T090909", "19990907T090909", False)
])
@pytest.mark.asyncio
@pytest.mark.timeout(10)
async def test_cache_update(event_loop, cache,
domain, init_policy_id, expected_policy_id, expected_update):
cfg = utils.populate_cfg_defaults(None)
cfg['proactive_policy_fetching']['enabled'] = True
cfg['proactive_policy_fetching']['interval'] = 1
cfg['proactive_policy_fetching']['grace_ratio'] = 1000
cfg["default_zone"]["timeout"] = 1
cfg['shutdown_timeout'] = 1
await cache.set(domain, base_cache.CacheEntry(0, init_policy_id, {}))
pf = STSProactiveFetcher(cfg, event_loop, cache)
await pf.start()
# Wait for policy fetcher to do its rounds
await asyncio.sleep(3)
# Verify
assert time.time() - await cache.get_proactive_fetch_ts() < 10
result = await cache.get(domain)
assert result
assert result.pol_id == expected_policy_id
if expected_update:
assert time.time() - result.ts < 10 # update
# Due to an id change, a new body must be fetched
if init_policy_id != expected_policy_id:
assert result.pol_body
# Otherwise we don't fetch a new policy body
else:
assert not result.pol_body
else:
assert result.ts == 0
assert not result.pol_body
await pf.stop()
@pytest.mark.asyncio
@pytest.mark.timeout(10)
async def test_no_cache_update_during_grace_period(event_loop, cache):
cfg = utils.populate_cfg_defaults(None)
cfg['proactive_policy_fetching']['enabled'] = True
cfg['proactive_policy_fetching']['interval'] = 86400
cfg['proactive_policy_fetching']['grace_ratio'] = 2.0
cfg['shutdown_timeout'] = 1
init_record = base_cache.CacheEntry(time.time() - 1, "19990907T090909", {})
await cache.set("good.loc", init_record)
pf = STSProactiveFetcher(cfg, event_loop, cache)
await pf.start()
# Wait for policy fetcher to do its round
await asyncio.sleep(3)
# Verify
assert time.time() - await cache.get_proactive_fetch_ts() < 10
result = await cache.get("good.loc")
assert result == init_record # no update (cached entry is fresh enough)
await pf.stop()
@pytest.mark.asyncio
@pytest.mark.timeout(10)
async def test_respect_previous_proactive_fetch_ts(event_loop, cache):
cfg = utils.populate_cfg_defaults(None)
cfg['proactive_policy_fetching']['enabled'] = True
cfg['proactive_policy_fetching']['interval'] = 86400
cfg['proactive_policy_fetching']['grace_ratio'] = 2.0
cfg['shutdown_timeout'] = 1
previous_proactive_fetch_ts = time.time() - 1
init_record = base_cache.CacheEntry(0, "19990907T090909", {})
await cache.set("good.loc", init_record)
await cache.set_proactive_fetch_ts(previous_proactive_fetch_ts)
pf = STSProactiveFetcher(cfg, event_loop, cache)
await pf.start()
# Wait for policy fetcher to do its potential work
await asyncio.sleep(3)
# Verify
assert previous_proactive_fetch_ts == await cache.get_proactive_fetch_ts()
result = await cache.get("good.loc")
assert result == init_record # no update
await pf.stop()
......@@ -19,11 +19,15 @@ async def responder(event_loop):
import postfix_mta_sts_resolver.utils as utils
cfg = utils.populate_cfg_defaults(None)
cfg["zones"]["test2"] = cfg["default_zone"]
resp = STSSocketmapResponder(cfg, event_loop)
cache = utils.create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
resp = STSSocketmapResponder(cfg, event_loop, cache)
await resp.start()
result = resp, cfg['host'], cfg['port']
await yield_(result)
await resp.stop()
await cache.teardown()
@pytest.fixture(scope="module")
@async_generator
......@@ -31,11 +35,15 @@ async def unix_responder(event_loop):
import postfix_mta_sts_resolver.utils as utils
cfg = utils.populate_cfg_defaults({'path': '/tmp/mta-sts.sock', 'mode': 0o666})
cfg["zones"]["test2"] = cfg["default_zone"]
resp = STSSocketmapResponder(cfg, event_loop)
cache = utils.create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
resp = STSSocketmapResponder(cfg, event_loop, cache)
await resp.start()
result = resp, cfg['path']
await yield_(result)
await resp.stop()
await cache.teardown()
buf_sizes = [4096, 128, 16, 1]
reqresps = list(load_testdata('refdata'))
......
......@@ -66,12 +66,12 @@ async def test_responder_expiration(event_loop):
"max_age": 1,
}
await cache.set("no-record.loc", base_cache.CacheEntry(0, "0", pol_body))
await cache.teardown()
resp = STSSocketmapResponder(cfg, event_loop)
resp = STSSocketmapResponder(cfg, event_loop, cache)
await resp.start()
try:
result = await query(cfg['host'], cfg['port'], 'no-record.loc')
assert result == b'NOTFOUND '
finally:
await resp.stop()
await cache.teardown()
......@@ -19,11 +19,15 @@ async def responder(event_loop):
cfg = utils.populate_cfg_defaults({"default_zone": {"strict_testing": True}})
cfg["zones"]["test2"] = cfg["default_zone"]
cfg["port"] = 28461
resp = STSSocketmapResponder(cfg, event_loop)
cache = utils.create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
resp = STSSocketmapResponder(cfg, event_loop, cache)
await resp.start()
result = resp, cfg['host'], cfg['port']
await yield_(result)
await resp.stop()
await cache.teardown()
buf_sizes = [4096, 128, 16, 1]
reqresps = list(load_testdata('refdata_strict'))
......
......@@ -19,11 +19,15 @@ async def responder(event_loop):
cfg["shutdown_timeout"] = 1
cfg["cache_grace"] = 0
cfg["zones"]["test2"] = cfg["default_zone"]
resp = STSSocketmapResponder(cfg, event_loop)
cache = utils.create_cache(cfg['cache']['type'],
cfg['cache']['options'])
await cache.setup()
resp = STSSocketmapResponder(cfg, event_loop, cache)
await resp.start()
result = resp, cfg['host'], cfg['port']
await yield_(result)
await resp.stop()
await cache.teardown()
@pytest.mark.asyncio
@pytest.mark.timeout(5)
......@@ -121,4 +125,3 @@ async def test_fast_expire(responder):
assert answer_a == answer_b == b'OK secure match=mail.loc'
finally:
writer.close()
......@@ -22,6 +22,10 @@ def test_populate_cfg_defaults(cfg):
assert isinstance(res['port'], int)
assert 0 < res['port'] < 65536
assert isinstance(res['cache_grace'], (int, float))
assert isinstance(res['proactive_policy_fetching']['enabled'], bool)
assert isinstance(res['proactive_policy_fetching']['interval'], int)
assert isinstance(res['proactive_policy_fetching']['concurrency_limit'], int)
assert isinstance(res['proactive_policy_fetching']['grace_ratio'], (int, float))
assert isinstance(res['cache'], collections.abc.Mapping)
assert res['cache']['type'] in ('redis', 'sqlite', 'internal')
assert isinstance(res['default_zone'], collections.abc.Mapping)
......