...
 
Commits (3)
......@@ -200,7 +200,7 @@ ignore-docstrings=yes
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
min-similarity-lines=10
[MISCELLANEOUS]
......
......@@ -4,10 +4,17 @@ import argparse
import asyncio
from .resolver import STSResolver
from . import utils
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-v", "--verbosity",
help="logging verbosity",
type=utils.check_loglevel,
choices=utils.LogLevel,
default=utils.LogLevel.info)
parser.add_argument("domain",
help="domain to fetch MTA-STS policy from")
parser.add_argument("known_version",
......@@ -20,10 +27,11 @@ def parse_args():
def main(): # pragma: no cover
args = parse_args()
loop = asyncio.get_event_loop()
resolver = STSResolver(loop=loop)
result = loop.run_until_complete(resolver.resolve(args.domain, args.known_version))
with utils.AsyncLoggingHandler(None) as log_handler:
utils.setup_logger('RES', args.verbosity, log_handler)
loop = asyncio.get_event_loop()
resolver = STSResolver(loop=loop)
result = loop.run_until_complete(resolver.resolve(args.domain, args.known_version))
print(result)
......
......@@ -12,21 +12,14 @@ from . import utils
from . import defaults
from .proactive_fetcher import STSProactiveFetcher
from .responder import STSSocketmapResponder
from .utils import create_cache
def parse_args():
def check_loglevel(arg):
try:
return utils.LogLevel[arg]
except (IndexError, KeyError):
raise argparse.ArgumentTypeError("%s is not valid loglevel" % (repr(arg),))
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-v", "--verbosity",
help="logging verbosity",
type=check_loglevel,
type=utils.check_loglevel,
choices=utils.LogLevel,
default=utils.LogLevel.info)
parser.add_argument("-c", "--config",
......@@ -67,8 +60,8 @@ async def amain(cfg, loop): # pragma: no cover
proactive_fetch_enabled = cfg['proactive_policy_fetching']['enabled']
# Create policy cache
cache = create_cache(cfg["cache"]["type"],
cfg["cache"]["options"])
cache = utils.create_cache(cfg["cache"]["type"],
cfg["cache"]["options"])
await cache.setup()
# Construct request handler
......@@ -109,6 +102,7 @@ def main(): # pragma: no cover
logger = utils.setup_logger('MAIN', args.verbosity, log_handler)
utils.setup_logger('STS', args.verbosity, log_handler)
utils.setup_logger('PF', args.verbosity, log_handler)
utils.setup_logger('RES', args.verbosity, log_handler)
logger.info("MTA-STS daemon starting...")
# Read config and populate with defaults
......
import asyncio
import enum
import logging
from io import BytesIO
import aiodns
......@@ -25,6 +26,8 @@ class STSFetchResult(enum.Enum):
_HEADERS = {"User-Agent": defaults.USER_AGENT}
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-statements
class STSResolver:
def __init__(self, *, timeout=defaults.TIMEOUT, loop):
self._loop = loop
......@@ -32,6 +35,7 @@ class STSResolver:
self._resolver = aiodns.DNSResolver(timeout=timeout, loop=loop)
self._http_timeout = aiohttp.ClientTimeout(total=timeout)
self._proxy_info = aiohttp.helpers.proxies_from_env().get('https', None)
self._logger = logging.getLogger("RES")
if self._proxy_info is None:
self._proxy = None
......@@ -49,6 +53,8 @@ class STSResolver:
# Construct name of corresponding MTA-STS DNS record for domain
sts_txt_domain = '_mta-sts.' + domain
self._logger.debug("Got STS resolve request: sts_txt_domain=%s, "
"known_id=%s", sts_txt_domain, last_known_id)
# Try to fetch it
try:
......@@ -89,6 +95,9 @@ class STSResolver:
or 'id' not in mta_sts_record):
return STSFetchResult.NONE, None
self._logger.debug("Parsed STS record for domain %s: %s",
repr(domain), repr(mta_sts_record))
# Obtain policy ID and return NOT_CHANGED if ID is equal to last known
if mta_sts_record['id'] == last_known_id:
return STSFetchResult.NOT_CHANGED, None
......@@ -125,12 +134,16 @@ class STSResolver:
charset = (resp.charset if resp.charset is not None
else 'ascii')
policy_text = policy_file.getvalue().decode(charset)
except Exception:
except Exception as exc:
self._logger.warning("STS policy fetch for domain %s failed with "
"error: %s", repr(domain), str(exc))
return STSFetchResult.FETCH_ERROR, None
# Parse policy
pol = parse_mta_sts_policy(policy_text)
self._logger.debug("Parsed policy for domain %s: %s", domain, repr(pol))
# Validate policy
if pol.get('version', None) != 'STSv1':
return STSFetchResult.FETCH_ERROR, None
......
......@@ -4,6 +4,7 @@ import logging.handlers
import asyncio
import socket
import queue
import argparse
import yaml
......@@ -226,3 +227,10 @@ def create_cache(cache_type, options):
else:
raise NotImplementedError("Unsupported cache type!")
return cache
def check_loglevel(arg):
try:
return LogLevel[arg]
except (IndexError, KeyError):
raise argparse.ArgumentTypeError("%s is not valid loglevel" % (repr(arg),))
......@@ -7,7 +7,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read() # pylint: disable=invalid-name
setup(name='postfix_mta_sts_resolver',
version='0.8.1',
version='0.8.2',
description='Daemon which provides TLS client policy for Postfix '
'via socketmap, according to domain MTA-STS policy',
url='https://github.com/Snawoot/postfix-mta-sts-resolver',
......