#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2016  Michael Tremer                                          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import argparse
import datetime
import daemon
import functools
import ipaddress
import logging
import logging.handlers
import os
import re
import signal
import stat
import subprocess
import sys
import tempfile
import time

import inotify.adapters

# TTL that is put into every resource record generated for a lease (see Lease.rrset)
LOCAL_TTL = 60

# Logger shared by the whole module. The level set here applies until
# setup_logging() is called, which sets the requested level.
log = logging.getLogger("dhcp")
log.setLevel(logging.DEBUG)

def setup_logging(daemon=True, loglevel=logging.INFO):
	"""
		Configures the module logger and returns it.

		Messages are always sent to syslog (/dev/log, facility "daemon").
		If daemon is False, they are written to the console as well.
		loglevel is applied to the logger and to every handler added here.
	"""
	log.setLevel(loglevel)

	# Syslog is always enabled and gets a "name[pid]: message" prefix
	syslog = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
	syslog.setLevel(loglevel)
	syslog.setFormatter(
		logging.Formatter("%(name)s[%(process)d]: %(message)s"),
	)
	log.addHandler(syslog)

	# Nothing else to do when we are running in the background
	if daemon:
		return log

	# In foreground mode, everything goes to the console, too
	console = logging.StreamHandler()
	console.setLevel(loglevel)
	log.addHandler(console)

	return log

def ip_address_to_reverse_pointer(address):
	"""
		Returns the in-addr.arpa name for a dotted IPv4 address,
		i.e. the labels in reverse order followed by ".in-addr.arpa".
	"""
	octets = reversed(address.split("."))

	return ".".join(octets) + ".in-addr.arpa"

def reverse_pointer_to_ip_address(rr):
	"""
		Returns the dotted IPv4 address for an in-addr.arpa name by
		taking the first four labels and putting them in reverse order.
	"""
	# Everything after the first four labels is the in-addr.arpa suffix
	octets = rr.split(".")[:4]

	return ".".join(octets[::-1])

class UnboundDHCPLeasesBridge(object):
	"""
		Keeps Unbound in sync with the DHCP server.

		The directories of the DHCP leases file, the fix leases file and the
		static hosts file are watched with inotify. Whenever something changes,
		all leases are read again and handed to an UnboundConfigWriter.
	"""

	def __init__(self, dhcp_leases_file, fix_leases_file, unbound_leases_file, hosts_file):
		self.leases_file = dhcp_leases_file
		self.fix_leases_file = fix_leases_file
		self.hosts_file = hosts_file

		# Additional inotify events to subscribe to for each file. run() watches
		# the directory of each file and always adds IN_CLOSE_WRITE and IN_MOVED_TO.
		self.watches = {
			self.leases_file     : inotify.constants.IN_MODIFY,
			self.fix_leases_file : 0,
			self.hosts_file      : 0,
		}

		# Static hosts (FQDN -> list of IP addresses). This is filled by run(),
		# but it has to exist before that so that update_dhcp_leases() cannot
		# fail if it is called early (e.g. by SIGHUP).
		self.hosts = {}

		self.unbound = UnboundConfigWriter(unbound_leases_file)
		self.running = False

	def run(self):
		"""
			Runs the main loop until terminate() is called.

			Static hosts and leases are loaded on startup and then again
			whenever inotify reports an event in one of the watched directories.
		"""
		log.info("Unbound DHCP Leases Bridge started on %s", self.leases_file)
		self.running = True

		i = inotify.adapters.Inotify()

		# Add watches for the directories of every relevant file
		for f, mask in self.watches.items():
			i.add_watch(
				os.path.dirname(f),
				mask | inotify.constants.IN_CLOSE_WRITE | inotify.constants.IN_MOVED_TO,
			)

		# Enabled so that we update hosts and leases on startup
		update_hosts = update_leases = True

		while self.running:
			log.debug("Wakeup of main loop")

			# Process the entire inotify queue and identify what we need to do
			for event in i.event_gen():
				# Nothing (more) to do
				if event is None:
					break

				# Decode the event
				header, type_names, path, filename = event

				changed = os.path.join(path, filename)

				log.debug("inotify event received for %s: %s", changed, " ".join(type_names))

				# Did the hosts file change?
				if self.hosts_file == changed:
					update_hosts = True

				# We will need to update the leases on any change
				update_leases = True

			# Update hosts (if needed)
			if update_hosts:
				self.hosts = self.read_static_hosts()

			# Update leases (if needed)
			if update_leases:
				self.update_dhcp_leases()

			# Reset
			update_hosts = update_leases = False

			# Wait a moment before we start the next iteration
			time.sleep(5)

		log.info("Unbound DHCP Leases Bridge terminated")

	def update_dhcp_leases(self, *args):
		"""
			Reads all leases and passes those that should be resolvable
			to Unbound.

			Any positional arguments are ignored. They are only accepted so that
			this method can be used as a signal handler, which is called with
			the signal number and the current stack frame.
		"""
		# Don't bother with any dynamic leases that don't have a hostname
		leases = [l for l in DHCPLeases(self.leases_file) if l.fqdn]

		# Fix leases are always considered
		leases.extend(FixLeases(self.fix_leases_file))

		# Skip any leases that also are a static host
		leases = [l for l in leases if l.fqdn not in self.hosts]

		# Remove any inactive or expired leases
		leases = [l for l in leases if l.active and not l.expired]

		# Dump leases
		if leases:
			log.debug("DHCP Leases:")
			for lease in leases:
				log.debug("  %s:", lease.fqdn)
				log.debug("    State: %s", lease.binding_state)
				log.debug("    Start: %s", lease.time_starts)
				log.debug("    End  : %s", lease.time_ends)
				if lease.expired:
					log.debug("    Expired")

		self.unbound.update_dhcp_leases(leases)

	def read_static_hosts(self):
		"""
			Reads the static hosts file and returns a dictionary that maps
			the FQDN of every enabled host to the sorted list of its addresses.

			Each line is expected to consist of exactly five comma-separated
			fields: enabled, IP address, hostname, domain name, generate PTR.
			Lines that do not match are logged and skipped.
		"""
		log.info("Reading static hosts from %s", self.hosts_file)

		hosts = {}
		with open(self.hosts_file) as f:
			for line in f:
				line = line.rstrip()

				# Unpacking fails with ValueError on a wrong number of fields
				try:
					enabled, ipaddr, hostname, domainname, generateptr = line.split(",")
				except ValueError:
					log.warning("Could not parse line: %s", line)
					continue

				# Skip any disabled entries
				if not enabled == "on":
					continue

				if hostname and domainname:
					fqdn = "%s.%s" % (hostname, domainname)
				else:
					fqdn = hostname or domainname

				# Without any name there is nothing that we could match leases
				# against (and we must not reuse the name of the previous line)
				if not fqdn:
					log.warning("Skipping static host without a name: %s", line)
					continue

				addresses = hosts.setdefault(fqdn, [])
				addresses.append(ipaddr)
				addresses.sort()

		# Dump everything in the logs
		log.debug("Static hosts:")
		for name, addresses in hosts.items():
			log.debug("  %-20s : %s", name, ", ".join(addresses))

		return hosts

	def terminate(self, *args):
		"""
			Asks the main loop to exit after its current iteration.

			Any positional arguments are ignored so that this method can be
			used as a signal handler (called with signal number and stack frame).
		"""
		self.running = False


class DHCPLeases(object):
	"""
		Parses the leases file of the DHCP server.

		Iterating over an instance yields one Lease per IP address.
		The file is read once when the object is created.
	"""

	regex_leaseblock = re.compile(r"lease (?P<ipaddr>\d+\.\d+\.\d+\.\d+) {(?P<config>[\s\S]+?)\n}")

	def __init__(self, path):
		self.path = path

		self._leases = self._parse()

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
			Reads the leases file and returns a list of Lease objects
			in the order in which their IP address first appeared.
		"""
		log.info("Reading DHCP leases from %s", self.path)

		# Maps IP address -> Lease. A dictionary keeps the order of insertion
		# and spares us searching the whole list for every lease block.
		leases = {}

		# Read entire leases file
		with open(self.path) as f:
			data = f.read()

		for match in self.regex_leaseblock.finditer(data):
			ipaddr = match.group("ipaddr")
			properties = self._parse_block(match.group("config"))

			# Skip any abandoned leases
			if "hardware" not in properties:
				continue

			lease = Lease(ipaddr, properties)

			# Check if a lease for this IP address already exists in the
			# list of known leases. If so, replace it with the most recent
			# lease. The new lease wins unless the known one compares greater.
			try:
				known = leases[ipaddr]
			except KeyError:
				leases[ipaddr] = lease
			else:
				leases[ipaddr] = max(lease, known)

		return list(leases.values())

	def _parse_block(self, block):
		"""
			Parses the body of a lease block and returns its statements
			as a dictionary which maps the first word of each statement to
			the remainder of it.

			"option" and "set" statements and lines that do not end with
			a semicolon are ignored. Statements which consist of a single
			word only (e.g. "reserved;") are stored with an empty value.
		"""
		properties = {}

		for line in block.splitlines():
			# Invalid (or empty) line if it doesn't end with ;
			if not line.endswith(";"):
				continue

			# Remove the trailing ; and any leading whitespace
			line = line[:-1].lstrip()

			# We skip all options and sets
			if line.startswith(("option", "set")):
				continue

			# Split by first space. The value is empty if there is no space.
			key, _, val = line.partition(" ")

			# Nothing left (the line was just a semicolon)
			if not key:
				continue

			properties[key] = val

		return properties


class FixLeases(object):
	"""
		Parses the file with the fix leases.

		Iterating over an instance yields a Lease for every enabled entry
		of the file, followed by freed leases for entries that were returned
		by an earlier run for the same path but are no longer in the file.
	"""

	# Result of the most recent run for each path (shared by all instances)
	cache = {}

	def __init__(self, path):
		self.path = path

		# _parse() compares against the cache, so it must run before
		# the cache is overwritten with the new result
		leases = self._parse()

		self._leases = leases
		self.cache[self.path] = leases

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
			Reads the file and returns the list of leases.

			Each line has to consist of exactly seven comma-separated fields
			of which the first three (hardware address, IP address, enabled)
			and the last one (hostname) are used.
		"""
		log.info("Reading fix leases from %s" % self.path)

		current = []

		# All leases start "now" and never end
		starts = datetime.datetime.utcnow().strftime("%w %Y/%m/%d %H:%M:%S")

		with open(self.path) as f:
			for entry in f:
				entry = entry.rstrip()

				fields = entry.split(",")
				if len(fields) != 7:
					log.warning("Could not parse line: %s" % entry)
					continue

				hwaddr, ipaddr, enabled = fields[:3]
				hostname = fields[-1]

				# Skip any disabled leases
				if enabled != "on":
					continue

				properties = {
					"binding"         : "state active",
					"client-hostname" : hostname,
					"hardware"        : "ethernet %s" % hwaddr,
					"starts"          : starts,
					"ends"            : "never",
				}
				current.append(Lease(ipaddr, properties))

		# Anything that we returned the last time but that is gone now
		# has been deleted and is passed on as a freed lease
		previous = self.cache.get(self.path, [])

		for lease in previous:
			if lease not in current:
				lease.free()
				current.append(lease)

		return current


class Lease(object):
	"""
		A single DHCP lease.

		ipaddr is the leased IPv4 address as a string and properties is
		a dictionary with the statements of the lease ("binding", "hardware",
		"client-hostname", "starts", "ends") as they appear in the leases file.

		Two leases are equal if they have the same IP and hardware address.
		Since leases can be changed by free(), they are not hashable.
	"""

	def __init__(self, ipaddr, properties):
		self.ipaddr = ipaddr
		self._properties = properties

	def __repr__(self):
		return "<%s %s for %s (%s)>" % (self.__class__.__name__,
			self.ipaddr, self.hwaddr, self.hostname)

	def __eq__(self, other):
		# Let Python handle the comparison with anything that is not a lease
		if not isinstance(other, Lease):
			return NotImplemented

		return self.ipaddr == other.ipaddr and self.hwaddr == other.hwaddr

	def __gt__(self, other):
		"""
			A lease is greater than another lease for the same IP and
			hardware address if it has started later.
			Leases for different addresses are never greater than each other.
		"""
		if not isinstance(other, Lease):
			return NotImplemented

		if self != other:
			return False

		return self.time_starts > other.time_starts

	@property
	def binding_state(self):
		"""
			The binding state (e.g. "active" or "free") or None if unknown
		"""
		binding = self._properties.get("binding")

		if not binding:
			return

		# The value looks like "state active"
		keyword, _, state = binding.partition(" ")

		return state or None

	def free(self):
		"""
			Marks this lease as free
		"""
		self._properties.update({
			"binding" : "state free",
		})

	@property
	def active(self):
		return self.binding_state == "active"

	@property
	def hwaddr(self):
		"""
			The hardware address of the client or None if unknown
		"""
		hardware = self._properties.get("hardware")

		if not hardware:
			return

		# The value looks like "ethernet 00:11:22:33:44:55"
		ethernet, _, address = hardware.partition(" ")

		return address or None

	@property
	def hostname(self):
		"""
			The hostname that has been sent by the client or
			None if there is none or if it is not a valid hostname
		"""
		hostname = self._properties.get("client-hostname")

		if hostname is None:
			return

		# Remove any ""
		hostname = hostname.replace("\"", "")

		# Only return valid hostnames. The whole string has to match since
		# it ends up in the configuration of Unbound ($ would tolerate
		# a trailing newline).
		if re.fullmatch(r"[A-Z0-9\-]{1,63}", hostname, re.I):
			return hostname

	@property
	def domain(self):
		"""
			The domain name of the DHCP zone (GREEN or BLUE) that this lease
			belongs to, or the domain of the host if no zone matches or if
			the zone has no domain name configured.
		"""
		# Load ethernet settings
		ethernet_settings = self.read_settings("/var/ipfire/ethernet/settings")

		# Load DHCP settings
		dhcp_settings = self.read_settings("/var/ipfire/dhcp/settings")

		subnets = {}
		for zone in ("GREEN", "BLUE"):
			if not dhcp_settings.get("ENABLE_%s" % zone) == "on":
				continue

			netaddr = ethernet_settings.get("%s_NETADDRESS" % zone)
			submask = ethernet_settings.get("%s_NETMASK" % zone)

			# Skip zones which are enabled but don't have an address
			if not netaddr or not submask:
				continue

			subnet = ipaddress.ip_network("%s/%s" % (netaddr, submask))
			domain = dhcp_settings.get("DOMAIN_NAME_%s" % zone)

			subnets[subnet] = domain

		address = ipaddress.ip_address(self.ipaddr)

		for subnet, domain in subnets.items():
			# A zone without a domain name must not result in "host.None"
			if address in subnet and domain:
				return domain

		# Load main settings
		settings = self.read_settings("/var/ipfire/main/settings")

		# Fall back to the host domain if no match could be found
		return settings.get("DOMAINNAME", "localdomain")

	@staticmethod
	@functools.cache
	def read_settings(filename):
		"""
			Reads a file with KEY=VALUE lines and returns it as a dictionary.

			The result is cached for the lifetime of the process, i.e. each
			file is read only once and the same dictionary is returned to
			every caller. Lines without "=" are ignored.
		"""
		settings = {}

		with open(filename) as f:
			for line in f:
				# Remove line-breaks
				line = line.rstrip()

				k, separator, v = line.partition("=")

				# Skip empty lines and anything else that is not an assignment
				if not separator:
					continue

				settings[k] = v

		return settings

	@property
	def fqdn(self):
		"""
			Hostname and domain, or None if there is no valid hostname
		"""
		hostname = self.hostname

		if hostname:
			return "%s.%s" % (hostname, self.domain)

	@staticmethod
	def _parse_time(s):
		# Format: weekday year/month/day hour:minute:second
		return datetime.datetime.strptime(s, "%w %Y/%m/%d %H:%M:%S")

	@property
	def time_starts(self):
		"""
			Start of the lease as naive datetime or None if unknown
		"""
		starts = self._properties.get("starts")

		if starts:
			return self._parse_time(starts)

	@property
	def time_ends(self):
		"""
			End of the lease as naive datetime or None if it never ends
		"""
		ends = self._properties.get("ends")

		if not ends or ends == "never":
			return

		return self._parse_time(ends)

	@property
	def expired(self):
		"""
			True if the lease is not valid right now, i.e. if it has
			not started yet or if it has already ended.

			Timestamps are compared against the current time in UTC.
		"""
		# Timestamps of leases are naive, so the current time has to be, too
		now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

		# Leases from the future are not valid, yet. A lease without
		# a start time is considered to have started.
		starts = self.time_starts
		if starts is not None and starts > now:
			return True

		# Leases that never end cannot expire
		ends = self.time_ends
		if ends is None:
			return False

		return now > ends

	@property
	def rrset(self):
		"""
			The resource records for this lease as a list of tuples of
			strings (name, TTL, class and type, data)
		"""
		fqdn = self.fqdn

		# If the lease does not have a valid FQDN, we cannot create any RRs
		if fqdn is None:
			return []

		ttl = "%s" % LOCAL_TTL

		return [
			# Forward record
			(fqdn, ttl, "IN A", self.ipaddr),

			# Reverse record
			(ip_address_to_reverse_pointer(self.ipaddr), ttl, "IN PTR", fqdn),
		]


class UnboundConfigWriter(object):
	"""
		Writes leases as local-data statements into a configuration
		file for Unbound (path) and makes Unbound load it.
	"""

	def __init__(self, path):
		self.path = path

	def update_dhcp_leases(self, leases):
		"""
			Writes all leases to the configuration file and reloads Unbound.

			Raises subprocess.CalledProcessError if Unbound could not be reloaded.
		"""
		# Write out all leases
		self.write_dhcp_leases(leases)

		log.debug("Reloading Unbound...")

		# Reload the configuration without dropping the cache
		self._control("reload_keep_cache")

	def write_dhcp_leases(self, leases):
		"""
			Replaces the configuration file with one that contains
			the resource records of all given leases.

			The new content is written to a temporary file which is then moved
			to its destination, so that the file is never seen half-written.
		"""
		log.debug("Writing DHCP leases...")

		# The temporary file has to be on the same filesystem as the destination
		# or it cannot be renamed, so we create it in the same directory
		directory = os.path.dirname(self.path) or "."

		with tempfile.NamedTemporaryFile(mode="w", dir=directory, delete=False) as f:
			try:
				for l in leases:
					for rr in l.rrset:
						f.write("local-data: \"%s\"\n" % " ".join(rr))

				# Make sure that everything has been written before
				# the file becomes visible under its final name
				f.flush()

				# Make file readable for everyone
				os.fchmod(f.fileno(), stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)

				# Move the file to its destination
				os.rename(f.name, self.path)

			# Don't leave the temporary file behind if anything went wrong
			except BaseException:
				os.unlink(f.name)
				raise

	def _control(self, *args):
		"""
			Runs unbound-control with the given arguments.

			Raises subprocess.CalledProcessError if the command returned
			a non-zero exit code.
		"""
		command = ["unbound-control", *args]

		try:
			# Capture stderr, too, so that error messages end up in our log
			subprocess.check_output(command, stderr=subprocess.STDOUT)

		# Log any errors
		except subprocess.CalledProcessError as e:
			log.critical("Could not run %s, error code: %s: %s" % (
				" ".join(command), e.returncode, e.output))

			raise


# Entry point: parses the command line, creates the bridge and runs it
# inside of a daemon context (in the background if --daemon has been passed)
if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="Bridge for DHCP Leases and Unbound DNS")

	# Daemon Stuff
	parser.add_argument("--daemon", "-d", action="store_true",
		help="Launch as daemon in background")
	parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

	# Paths
	parser.add_argument("--dhcp-leases", default="/var/state/dhcp/dhcpd.leases",
		metavar="PATH", help="Path to the DHCPd leases file")
	parser.add_argument("--unbound-leases", default="/etc/unbound/dhcp-leases.conf",
		metavar="PATH", help="Path to the unbound configuration file")
	parser.add_argument("--fix-leases", default="/var/ipfire/dhcp/fixleases",
		metavar="PATH", help="Path to the fix leases file")
	parser.add_argument("--hosts", default="/var/ipfire/main/hosts",
		metavar="PATH", help="Path to static hosts file")

	# Parse command line arguments
	args = parser.parse_args()

	# Setup logging (args.verbose is None if -v has not been passed at all)
	if not args.verbose:
		loglevel = logging.WARN
	elif args.verbose == 1:
		loglevel = logging.INFO
	else:
		loglevel = logging.DEBUG

	bridge = UnboundDHCPLeasesBridge(args.dhcp_leases, args.fix_leases,
		args.unbound_leases, args.hosts)

	# Signal handlers are called with the signal number and the current stack
	# frame, which the methods of the bridge are not interested in
	def reload_leases(signum, frame):
		bridge.update_dhcp_leases()

	def terminate(signum, frame):
		bridge.terminate()

	# The context is not bound to a name so that the daemon module is not shadowed
	with daemon.DaemonContext(
		detach_process=args.daemon,
		stderr=None if args.daemon else sys.stderr,
		signal_map = {
			signal.SIGHUP  : reload_leases,
			signal.SIGTERM : terminate,
		},
	):
		setup_logging(daemon=args.daemon, loglevel=loglevel)

		bridge.run()
