#!/usr/bin/env python3
'''
This is used to set and unset the DNS.

On Ubuntu and Fedora, systemd uses its own resolver, systemd-resolved.

On Debian and RHEL, network-manager uses resolv.conf.

In this case, an instance of dnsmasq is created.

It requires systemd and dbus to work.

It supports systems that use systemd, dbus and network-manager
and also works if network-manager is not used.

If a different network manager daemon is in use (wicd), it probably
doesn't work.
'''
from getopt import getopt
import os
import sys
import socket
from itertools import chain, product
from syslog import syslog, openlog, LOG_INFO, LOG_ERR, LOG_WARNING

import fcntl
import struct
import hashlib
import subprocess

import nm

# Marker file written by create_run_file() once the DNS has been changed;
# chk_preconditions() refuses to run while it exists and reset() removes it.
RUN_FILE = '/run/cz_resolver'
# Holds the "search ..." directive written by create_search_file().
SEARCH_FILE = '/run/appgatesearch'
# Holds the "nameserver ..." lines written by main() before it starts the
# appgate-dumb-resolver unit.
RESOLVE_FILE = '/run/appgateresolve'
# Only removed (never written) in this file, by reset().
# NOTE(review): presumably created by the nm helper module - confirm.
DNSMASQ_FILE = '/etc/NetworkManager/dnsmasq.d/appgate.conf'
# Not referenced elsewhere in this file.
# NOTE(review): look like the D-Bus name/object path of systemd-resolved - confirm.
RESOLVED_NAME = "org.freedesktop.resolve1"
RESOLVED_DBUS_PATH = "/org/freedesktop/resolve1"

# True when stat() of file descriptor 1 (stdout) equals stat() of /dev/null;
# main() calls openlog() only in that case.
PRINT_TO_LOG = os.stat(1) == os.stat('/dev/null')


def lprint(string, level=LOG_INFO):
    '''
    Send a message to syslog and echo it on the console.

    :param string: the message text
    :param level: syslog priority; messages at LOG_ERR or numerically
                  lower are echoed on stderr, everything else on stdout
    '''
    syslog(level, string)
    if level <= LOG_ERR:
        print(string, file=sys.stderr)
    else:
        print(string)


def systemcheck(args):
    '''
    Run a command and log an error when it exits with a non-zero status.

    :param args: argument list handed to subprocess.call
    '''
    status = subprocess.call(args)
    if status == 0:
        return
    lprint(f'command failed: {args}', LOG_ERR)


def create_search_file(search):
    '''
    Creates a file with a search directive to append to our
    resolv.conf. Should only add plain domains, not the
    ones added with dns.server... syntax

    :param search: iterable of (server, domain) tuples; only entries whose
                   server is falsy contribute their domain

    The domains already listed in /etc/resolv.conf are appended after the
    new ones. Duplicates are dropped and the order is kept stable, because
    the resolver tries search domains in the order they are listed.
    Nothing is written when there is no domain at all.
    '''
    domains = [i[1] for i in search if not i[0]]

    # Add the current search domains too. The keyword must start the line
    # and may be followed by a space or a tab; when several search lines
    # exist the last one wins.
    prevsearch = []
    with open('/etc/resolv.conf') as f:
        for line in f:
            if line.startswith(('search ', 'search\t')):
                # split() without argument copes with repeated blanks
                prevsearch = line.split()[1:]

    # dict.fromkeys() removes duplicates while preserving insertion order
    merged = ' '.join(dict.fromkeys(domains + prevsearch))
    if merged:
        with open(SEARCH_FILE, 'w') as f:
            lprint(f'Adding search domains {merged}')
            f.write(f'search {merged}\n')

def set_default_resolved(tundev, search, servers):
    """
    Configure systemd-resolved for the tun link, including a '.' domain.

    :param tundev: The name of the TUN device
    :param search: List of (server, domain) tuples; only the domains are used
    :param servers: List of DNS servers handed to nm.set_resolved

    A ('.', 'true') entry is placed in front of the per-domain entries.
    NOTE(review): '.' presumably makes this link the catch-all for lookups,
    and 'true' presumably marks each domain as routing-only - confirm
    against nm.set_resolved.
    """
    lprint('Setting DNS for systemd-resolved')
    route_only = 'true'
    routing = [('.', route_only)]
    # set() drops duplicated (server, domain) pairs
    for _server, domain in set(search):
        routing.append((domain, route_only))
    link_index = socket.if_nametoindex(tundev)
    nm.set_resolved(link_index, servers, routing)


def set_dns_resolved(tundev, search, servers):
    """
    Configure systemd-resolved for the tun link with per-domain entries only.

    :param tundev: The name of the TUN device
    :param search: List of (server, domain) tuples; only the domains are used
    :param servers: List of DNS servers handed to nm.set_resolved

    NOTE(review): the 'true' flag presumably marks each domain as
    routing-only - confirm against nm.set_resolved.
    """
    lprint('Setting DNS for systemd-resolved')
    route_only = 'true'
    routing = []
    # set() drops duplicated (server, domain) pairs
    for _server, domain in set(search):
        routing.append((domain, route_only))
    link_index = socket.if_nametoindex(tundev)
    nm.set_resolved(link_index, servers, routing)


def set_dns_tun(tundev, search, servers):
    '''
    Configure the DNS of the tun connection through nmcli.

    Because the settings live on the connection they go away when
    the tun is disconnected.

    :param tundev: name of the connection to modify
    :param search: (server, domain) tuples; only domains without a
                   server are used as search domains
    :param servers: DNS server addresses; anything containing ':' is
                    treated as IPv6, the rest as IPv4

    Does nothing besides logging when nm.has_nm() is false.
    '''
    lprint('Use nm to set dns for the tun')
    plain_domains = ' '.join(entry[1] for entry in search if not entry[0])
    if not nm.has_nm():
        return

    families = (
        ('ipv4.dns', 'ipv4.dns-search', [addr for addr in servers if ':' not in addr]),
        ('ipv6.dns', 'ipv6.dns-search', [addr for addr in servers if ':' in addr]),
    )
    for dns_key, search_key, addrs in families:
        if not addrs:
            # No server of this family: leave its settings untouched
            continue
        systemcheck(['nmcli', 'c', 'modify', tundev, dns_key, ' '.join(addrs)])
        systemcheck(['nmcli', 'c', 'modify', tundev, search_key, plain_domains])
    systemcheck(['nmcli', 'c', 'up', tundev])


def add_dns(new_dns):
    '''
    Hand a list of DNS entries to nm.set_dnsmasq.

    new_dns is an iterable.

    for every element,
    i[0] MUST exist and is the DNS address expressed as a string
    i[1] MAY exist, and is the domain name to use for that DNS

    An element may also be a bare address string, which is treated as
    a one-element entry.

    For example:
    (('10.0.0.1', 'mycompany.whatever'), ('8.8.8.8',), '9.9.9.9')
    '''
    lprint('Using appgate-resolver')
    # NOTE(review): presumably the D-Bus name of the resolver - confirm
    path = 'com.appgate.resolver'
    # list() on a bare string would split the address into single
    # characters, so strings are wrapped instead of converted.
    servers = [[i] if isinstance(i, str) else list(i) for i in new_dns]
    nm.set_dnsmasq(path, servers)


def create_run_file():
    '''Create (or truncate) RUN_FILE and write a single space into it.'''
    with open(RUN_FILE, 'w') as marker:
        marker.write(' ')


def printhelp():
    '''Print the command line usage on stdout and exit with status 0.'''
    # The original printed a literal "%s"; show the program name instead
    prog = os.path.basename(sys.argv[0])
    print(f'Usage: {prog} [OPTIONS] [SERVERS]')
    print()
    print('  -h, --help           Show this help and exit')
    print('      --reset          Resets the DNS to their previous value and exit')
    print('      --get            List the current DNS servers and exit')
    print('      --hash           Print a hash of the current DNS settings')
    print('      --servers        List of DNS servers, comma separated')
    print('      --domains        List of DNS domains, comma separated')
    print('      --tundev         The name of the tun device')
    sys.exit(0)


def _unlink_if_present(path):
    '''Remove path; a missing file is not an error.'''
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


def reset(tundev):
    '''
    Reset the DNS status and exit with status 0.

    :param tundev: name of the tun connection to delete through nmcli,
                   or None to leave the connections alone

    Stops the resolver unit, removes the files created by a previous run
    and deletes the tun connection when nm.has_nm() is true. When
    RUN_FILE does not exist, only the unit is stopped before exiting.
    '''
    # Probe dnsmasq once: exactly one of the two units is stopped
    if nm.has_dnsmasq():
        systemcheck(['systemctl', 'stop', 'appgate-resolver'])
    else:
        systemcheck(['systemctl', 'stop', 'appgate-dumb-resolver'])

    try:
        os.unlink(RUN_FILE)
    except FileNotFoundError:
        lprint('DNS has not been changed')
        sys.exit(0)

    # Try-and-ignore instead of exists()+unlink(): no window between the
    # check and the removal, and the same handling for all three files.
    for stale in (SEARCH_FILE, RESOLVE_FILE, DNSMASQ_FILE):
        _unlink_if_present(stale)

    if nm.has_nm() and tundev is not None:
        systemcheck(['nmcli', 'c', 'delete', tundev])

    sys.exit(0)


def print_hash():
    '''
    Log "CONFIG: <sha256>" for the current DNS settings and exit with 0.

    When the systemd-resolved unit is active and /etc/resolv.conf is a
    symlink, the digest covers the concatenated string form of the entries
    returned by nm.get_resolved_dns(); otherwise it covers the raw content
    of /etc/resolv.conf.
    '''
    resolved_in_use = (
        nm.is_service_active('systemd_2dresolved')
        and os.path.islink('/etc/resolv.conf')
    )
    if resolved_in_use:
        joined = ''.join(str(entry) for entry in nm.get_resolved_dns())
        digest = hashlib.sha256(joined.encode('utf-8')).hexdigest()
    else:
        hasher = hashlib.sha256()
        with open('/etc/resolv.conf', 'rb') as conf:
            # read() returns b'' at end of file, which stops the iteration
            for chunk in iter(lambda: conf.read(hasher.block_size), b''):
                hasher.update(chunk)
        digest = hasher.hexdigest()
    lprint(f'CONFIG: {digest}')
    sys.exit(0)


def get():
    '''
    Log one "DNS: <address>" line per current DNS server and exit with 0.

    When the systemd-resolved unit is active and /etc/resolv.conf is a
    symlink, the servers come from nm.get_resolved_dns(). Otherwise the
    nameserver lines of /etc/resolv.conf are used, skipping addresses
    that start with "127.".
    '''
    if nm.is_service_active('systemd_2dresolved') and os.path.islink('/etc/resolv.conf'):
        for ip in nm.get_resolved_dns():
            lprint(f'DNS: {ip}')
        sys.exit(0)

    if os.path.exists('/etc/resolv.conf'):
        with open("/etc/resolv.conf", "r") as conf:
            for raw in conf:
                # Keyword and value are separated by the first space
                keyword, sep, value = raw.partition(' ')
                if not sep:
                    continue
                value = value.strip()
                if keyword.strip() != 'nameserver' or value.startswith("127."):
                    continue
                lprint(f'DNS: {value}')
    sys.exit(0)


def get_ip_address(ifname):
    '''
    Return the IPv4 address of a network interface as a dotted string.

    :param ifname: interface name, e.g. 'docker0'
    :raises OSError: when the socket or the ioctl fails
    '''
    siocgifaddr = 0x8915  # SIOCGIFADDR
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        # The request buffer starts with the interface name, padded to 256 bytes
        request = struct.pack('256s', ifname.encode())
        reply = fcntl.ioctl(sock, siocgifaddr, request)
        # Bytes 20..23 of the reply hold the four address bytes
        packed_addr = reply[20:24]
        return socket.inet_ntoa(packed_addr)


def chk_preconditions():
    '''
    Various checks
    Returns if systemd-resolved is being used or not

    :returns: True when the systemd-resolved unit is active and
              /etc/resolv.conf contains 'nameserver 127.0.0.53',
              False otherwise
    :raises Exception: when one of the preconditions is not met (already
                       ran, not root, no systemd, conflicting local
                       resolver, no busctl, unsupported nm/dnsmasq setup)
    '''
    use_resolved = False
    if os.path.exists(RUN_FILE):
        raise Exception('set_dns already ran. Run with --reset first')
    if os.getuid() != 0:
        raise Exception('Needs root')
    if not nm.has_systemd():
        raise Exception('systemd as pid 1 is required')
    if nm.has_system_dnsmasq():
        if nm.is_service_active('dnsmasq'):
            raise Exception('Local DNS configuration not supported. dnsmasq in use.')
        lprint('dnsmasq systemd unit installed but not running, continuing', level=LOG_WARNING)
    if nm.nm_has_dnsmasq():
        raise Exception('Local DNS configuration not supported. nm has private instance of dnsmasq.')
    if nm.has_resolver('127.0.0.1'):
        raise Exception('Local DNS resolver already running.')
    if nm.is_service_active('systemd_2dresolved'):
        with open('/etc/resolv.conf', 'rt') as f:
            if 'nameserver 127.0.0.53' in f.read():
                lprint('systemd-resolved running')
                use_resolved = True
            else:
                lprint('systemd-resolved running, but resolv.conf is not using it for dns resolution')
    if nm.is_service_active('docker'):
        # Purely informational: a failure to read the docker0 address is
        # ignored. Only OSError is caught so that KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        try:
            docker_ip = get_ip_address('docker0')
            lprint(f'docker running at {docker_ip}')
        except OSError:
            pass
    try:
        subprocess.check_call(['busctl', '--version'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError: busctl cannot be executed; CalledProcessError: it failed
        raise Exception('DBus system bus not available') from err

    # No need for other checks if we're using systemd-resolved
    if use_resolved:
        return use_resolved

    # Probe each of network-manager and dnsmasq only once
    has_nm = nm.has_nm()
    if has_nm and nm.version_minimum(nm.nmcli_version(), [1, 2, 2]) is False:
        raise Exception('nmcli version is too old')

    has_dnsmasq = nm.has_dnsmasq()
    if has_nm and not has_dnsmasq:
        raise Exception('nm and no dnsmasq combination is not supported')

    if has_dnsmasq:
        if nm.version_minimum(nm.dnsmasq_version(), [2, 75]) is False:
            raise Exception('dnsmasq version too old')
        if not nm.dnsmasq_has_capability('DBus'):
            raise Exception('dnsmasq compiled without DBus support')
    else:
        lprint('No dnsmasq found', level=LOG_WARNING)

    return use_resolved


def load_settings():
    '''
    Read the key=value pairs at the top of /etc/appgate.conf.

    Blank lines and lines starting with ';' or '#' are skipped and
    reading stops at the first line starting with '['.

    :returns: dict mapping stripped keys to stripped values; empty when
              the file does not exist
    :raises Exception: for a line without '='
    '''
    settings = {}

    try:
        with open('/etc/appgate.conf', 'rt') as conf:
            for raw in conf:
                entry = raw.strip()
                if not entry or entry.startswith((';', '#')):
                    continue
                if entry.startswith('['):
                    break
                # Everything after the first '=' belongs to the value
                key, sep, value = entry.partition('=')
                if not sep:
                    raise Exception(f'Invalid configuration: {entry}')
                settings[key.strip()] = value.strip()
    except FileNotFoundError:
        lprint('No appgate.conf')
    return settings


def parse_domains(par):
    '''
    Parse the value of the --domains option.

    Example value of par:

    "devops,appgate.com,dns.server.1.2.3.4.two.com,dns.server.4.3.2.1.two.com,dns.server.5.6.7.8.one.com"

    Return a list of tuples like
    ('1.1.1.1', 'domain')
    (None, 'domain2')

    Items starting with 'dns.server.' carry the DNS server in front of
    the domain; every other item is a plain domain with no server.
    Surrounding whitespace is removed and empty items (empty input,
    doubled or trailing commas) are ignored.
    '''
    r = []
    for item in par.split(','):
        item = item.strip()
        if not item:
            # Avoid producing a bogus (None, '') entry
            continue
        if not item.startswith('dns.server.'):
            r.append((None, item))
            continue
        components = item.split('.')
        if ':' in item:
            # IPv6: the address contains no dots, so it is one component
            server = components[2]
            domain = components[3:]
        else:
            # IPv4: the address takes the four components after the prefix
            server = '.'.join(components[2:6])
            domain = components[6:]
        r.append((server, '.'.join(domain)))
    return r


def get_conf_nameservers(path='/etc/resolv.conf'):
    '''
    return a list of name servers defined in /etc/resolv.conf

    :param path: file to read, defaults to /etc/resolv.conf
    :returns: the address of every well-formed nameserver line, in file order

    Lines whose keyword is not exactly "nameserver" and nameserver lines
    without an address are ignored.
    '''
    nameservers = []
    with open(path, 'r') as resolvconf:
        for line in resolvconf:
            if not line.startswith('nameserver'):
                continue
            fields = line.split()
            # Check the token count, not the line length, so that a bare
            # "nameserver" line cannot raise IndexError
            if len(fields) > 1 and fields[0] == 'nameserver':
                nameservers.append(fields[1])
    return nameservers

def main():
    '''
    Entry point: parse the command line and apply the DNS configuration.

    When /etc/appgate.conf defines dns_script, that script replaces this
    process and receives the same arguments. --reset, --get and --hash are
    handled by reset(), get() and print_hash(), which all exit. Otherwise
    the preconditions are checked and the DNS is set through
    systemd-resolved, dnsmasq plus nmcli, or the appgate-dumb-resolver
    unit, and the resulting hash is printed.

    :raises Exception: on a failed precondition, when the external script
                       cannot be run, or when no DNS server is provided
    :raises getopt.GetoptError: on an unknown command line option
    '''
    if PRINT_TO_LOG:
        openlog('appgate-set-dns')
    settings = load_settings()

    # Run custom set_dns script if set
    external_script = settings.get('dns_script')
    if external_script:
        lprint(f'Running external script {external_script}')
        if external_script == sys.argv[0]:
            raise Exception('dns_script pointing to self')
        # execv only returns by raising, so the failure has to be
        # reported from the exception handler.
        try:
            os.execv(external_script, [external_script] + sys.argv[1:])
        except OSError as err:
            raise Exception(f'Unable to run {external_script}: {err}') from err

    optlist, args = getopt(
        sys.argv[1:],
        'h',
        (
            'help',
            'reset',
            'get',
            'hash',
            'servers=',
            'domains=',
            'tundev=',
        )
    )

    servers = []
    domains = []

    doreset = False
    doget = False
    dohash = False
    tundev = None

    for opt, par in optlist:
        if opt in ('-h', '--help'):
            printhelp()
        elif opt == '--reset':
            doreset = True
        elif opt == '--get':
            doget = True
        elif opt == '--hash':
            dohash = True
        elif opt == '--servers':
            servers = par.split(',') if par else []
        elif opt == '--domains':
            domains = parse_domains(par)
        elif opt == '--tundev':
            tundev = par

    if doreset:
        reset(tundev)

    if doget:
        get()

    if dohash:
        print_hash()

    use_resolved = chk_preconditions()

    if use_resolved:
        # Add servers set using dns.server syntax for resolved
        # ... but only if we didn't get any servers
        if servers:
            extra = []
        else:
            extra = [i[0] for i in domains if i[1] != "default"]
    else:
        # when not using resolved, set global servers set using .default
        extra = [i[0] for i in domains if i[1] == "default"]
    # Append one by one so that a server shared by several domains is
    # only added once.
    for server in extra:
        if server is not None and server not in servers:
            servers.append(server)

    # Servers given with the dns.server.<address>.default syntax; a plain
    # domain called "default" has no server and must not contribute None.
    default = [i[0] for i in domains if i[1] == "default" and i[0] is not None]

    # remove all domains with .default syntax
    domains = [i for i in domains if i[1] != "default"]

    if domains:
        create_search_file(domains)
    create_run_file()

    if use_resolved:
        if default:
            set_default_resolved(tundev, domains, default)
        else:
            set_dns_resolved(tundev, domains, servers)
    elif nm.has_dnsmasq():
        # if we don't have any default servers set, read users nameserver from resolv.conf
        if not default:
            default = get_conf_nameservers()
        if not servers and not domains:
            raise Exception('No DNS server provided')
        if domains:
            plaindomains = [i[1] for i in domains if not i[0]]
            add_dns(chain(
                [[i] for i in default],
                product(servers, plaindomains),
                [i for i in domains if i[0]]
            ))
        else:
            # Wrap every address in a list: add_dns expects one sequence
            # per entry, and a bare string would be split into characters.
            add_dns([[i] for i in servers])

        set_dns_tun(tundev, domains, servers)
    else:
        # Just direct to the Appgate DNS
        if len(servers) > 3:
            lprint('Too many DNS servers', level=LOG_WARNING)
        with open(RESOLVE_FILE, 'wt') as f:
            f.write('\n'.join(f'nameserver {i}' for i in servers))
            f.write('\n')
        systemcheck(['systemctl', 'start', 'appgate-dumb-resolver'])

    # get hash after setting dns
    print_hash()


if __name__ == '__main__':
    # Top-level boundary: any error is logged at LOG_ERR and turned into
    # exit status 1.
    try:
        main()
    except Exception as err:
        lprint(str(err), LOG_ERR)
        sys.exit(1)
