#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2016  Michael Tremer                                          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import argparse
import datetime
import daemon
import filecmp
import functools
import ipaddress
import logging
import logging.handlers
import os
import queue
import re
import signal
import socket
import stat
import subprocess
import sys
import tempfile
import threading

# TTL that is written into every DNS record generated for a lease
LOCAL_TTL = 60

# Module-wide logger. It accepts everything until setup_logging() applies
# the level that was requested on the command line.
log = logging.getLogger("dhcp")
log.setLevel(logging.DEBUG)

def setup_logging(daemon=True, loglevel=logging.INFO):
	"""
		Configures the module logger and returns it.

		Messages are always sent to syslog through /dev/log. When the
		process is not running as a daemon, an additional (unformatted)
		stream handler writes them to the console, too.
	"""
	log.setLevel(loglevel)

	# Syslog is always active and carries the process name and PID
	syslog = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
	syslog.setFormatter(logging.Formatter("%(name)s[%(process)d]: %(message)s"))
	syslog.setLevel(loglevel)
	log.addHandler(syslog)

	# Nothing else to do when we are running in the background
	if daemon:
		return log

	# In foreground mode, everything goes to the console as well
	console = logging.StreamHandler()
	console.setLevel(loglevel)
	log.addHandler(console)

	return log

class UnboundDHCPLeasesBridge(object):
	def __init__(self, dhcp_leases_file, fix_leases_file, unbound_leases_file, hosts_file, socket_path):
		"""
			Prepares the bridge. Nothing is read, opened or started here;
			that happens in run() and reload().

			dhcp_leases_file    - path of the leases file written by the DHCP server
			fix_leases_file     - path of the file with the fixed leases
			unbound_leases_file - path of the Unbound configuration file to generate
			hosts_file          - path of the file with the static hosts
			socket_path         - path of the UNIX socket on which events are received
		"""
		self.leases_file = dhcp_leases_file
		self.fix_leases_file = fix_leases_file
		self.hosts_file = hosts_file
		self.socket_path = socket_path

		self.socket = None

		# Static hosts by FQDN. They are only loaded by reload(), but the worker
		# may call _add_lease() before that has happened for the first time, so
		# the attribute has to exist right from the start.
		self.hosts = {}

		# Store all known leases
		self.leases = set()

		# Create a queue for all received events
		self.queue = queue.Queue()

		# Initialize the worker
		self.worker = Worker(self.queue, callback=self._handle_message)

		# Initialize the watcher
		self.watcher = Watcher(reload=self.reload)

		self.unbound = UnboundConfigWriter(unbound_leases_file)

	def run(self):
		"""
			Runs the main loop of the bridge.

			Every connection on the UNIX socket delivers one message which
			is decoded and queued for the worker. The client receives "OK"
			or "ERROR" as reply. The loop ends as soon as accepting fails,
			which is what happens after terminate() has closed the socket.
		"""
		log.info("Unbound DHCP Leases Bridge started on %s" % self.leases_file)

		# Open the server socket first. If this fails, SystemExit is raised
		# and no (non-daemon) thread has been started yet which would keep
		# the process alive.
		self.socket = self._open_socket(self.socket_path)

		# Launch the worker
		self.worker.start()

		# Launch the watcher
		self.watcher.start()

		while True:
			# Accept any incoming connections
			try:
				conn, peer = self.socket.accept()
			except OSError:
				break

			try:
				# Receive what the client is sending
				data, *_ = conn.recvmsg(4096)

				# Log that we have received some data
				log.debug("Received message of %s byte(s)" % len(data))

				# Decode the data
				message = self._decode_message(data)

				# Add the message to the queue
				self.queue.put(message)

				reply = b"OK\n"

			# Send ERROR to the client if something went wrong
			except Exception as e:
				log.error("Could not handle message: %s" % e)

				reply = b"ERROR\n"

			# The client might have hung up already. That must not take
			# down the whole daemon, so failures are only logged.
			try:
				conn.send(reply)
			except OSError as e:
				log.warning("Could not send reply to client: %s" % e)

			# Close the connection
			finally:
				conn.close()

		# Terminate the worker
		self.queue.put(None)

		# Terminate the watcher
		self.watcher.terminate()

		# Wait for the worker and watcher to finish
		self.worker.join()
		self.watcher.join()

		log.info("Unbound DHCP Leases Bridge terminated")

	def _open_socket(self, path):
		"""
			Creates a listening UNIX stream socket at the given path and
			returns it. A stale socket file at that path is removed first.

			Raises SystemExit if the socket could not be bound.
		"""
		# Allocate a new socket
		s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)

		# Unlink any old sockets
		try:
			os.unlink(path)
		except FileNotFoundError:
			pass

		# Bind the socket to the very same path that has just been cleaned up
		try:
			s.bind(path)
		except OSError as e:
			log.error("Could not open socket at %s: %s" % (path, e))

			# Don't leak the file descriptor
			s.close()

			raise SystemExit(1) from e

		# Listen
		s.listen(128)

		return s

	def _decode_message(self, data):
		"""
			Turns the received bytes into a dictionary.

			The message consists of lines in KEY=VALUE format. Empty lines
			are ignored. A line that cannot be decoded raises UnicodeError,
			and a line without "=" raises ValueError.
		"""
		attributes = {}

		# Empty lines are dropped right away
		for raw in filter(None, data.splitlines()):
			try:
				text = raw.decode()
			except UnicodeError as e:
				log.error("Could not decode %r: %s" % (raw, e))

				raise

			# Everything after the first "=" is the value
			key, separator, value = text.partition("=")

			# Refuse any lines that don't have a separator at all
			if not separator:
				raise ValueError("No value given")

			attributes[key] = value

		return attributes

	def _handle_message(self, message):
		"""
			Processes one decoded message (called by the worker).

			"commit" creates or updates the lease for ADDRESS using NAME as
			hostname, "release" and "expiry" remove the lease for ADDRESS.
			Fixed leases are never touched.

			Raises ValueError if EVENT or ADDRESS is missing, or if the
			event is not supported.
		"""
		log.debug("Handling message:")
		for key in message:
			log.debug("  %-20s = %s" % (key, message[key]))

		# Extract the event type
		event = message.get("EVENT")

		# Check if event is set
		if not event:
			raise ValueError("The message does not have EVENT set")

		# All supported events refer to an address
		address = message.get("ADDRESS")

		# COMMIT
		if event == "commit":
			if not address:
				raise ValueError("The message does not have ADDRESS set")

			name = message.get("NAME")

			# Find the old lease
			old_lease = self._find_lease(address)

			# Don't update fixed leases as they might clear the hostname
			if old_lease and old_lease.fixed:
				log.debug("Won't update fixed lease %s" % old_lease)
				return

			# Create a new lease
			lease = Lease(address, {
				"client-hostname" : name,
			})

			if old_lease:
				# Can we skip the update?
				if lease.rrset == old_lease.rrset:
					# Keep track of the most recent lease nevertheless
					self._add_lease(lease)

					log.debug("Won't update %s as nothing has changed" % lease)
					return

				# Remove the old lease first. Leases compare equal if they
				# have the same IP address, so this must happen before the
				# new lease is stored or the new lease would be dropped.
				self.unbound.remove_lease(old_lease)
				self._remove_lease(old_lease)

			# Store the new lease
			self._add_lease(lease)

			# Apply the lease
			self.unbound.apply_lease(lease)

		# RELEASE/EXPIRY
		elif event in ("release", "expiry"):
			if not address:
				raise ValueError("The message does not have ADDRESS set")

			# Find the lease
			lease = self._find_lease(address)

			if not lease:
				log.warning("Could not find lease for %s" % address)
				return

			# Remove the lease
			self.unbound.remove_lease(lease)
			self._remove_lease(lease)

		# Raise an error if the event is not supported
		else:
			raise ValueError("Unsupported event: %s" % event)

	def update_dhcp_leases(self):
		"""
			Forgets all known leases, reads the dynamic and the fixed
			leases again and passes those that have not expired on to
			Unbound.
		"""
		# Drop all known leases
		self.leases.clear()

		# Dynamic leases are read first, fixed leases afterwards. Each file
		# is only parsed once the previous one has been fully processed.
		sources = (
			(DHCPLeases, self.leases_file),
			(FixLeases,  self.fix_leases_file),
		)

		for reader, path in sources:
			for lease in reader(path):
				self._add_lease(lease)

		# Dump leases
		if self.leases:
			log.debug("DHCP Leases:")

			for lease in self.leases:
				log.debug("  %s:" % lease.fqdn)
				log.debug("    Start: %s" % lease.time_starts)
				log.debug("    End  : %s" % lease.time_ends)

				if lease.has_expired():
					log.debug("    Expired")

		# Only leases that are still valid are sent to Unbound
		valid_leases = [lease for lease in self.leases if not lease.has_expired()]

		self.unbound.update_dhcp_leases(valid_leases)

	def _add_lease(self, lease):
		"""
			Stores the given lease and replaces any lease that compares
			equal to it.

			Leases without a FQDN, leases whose FQDN belongs to a static
			host and expired leases are not stored.
		"""
		fqdn = lease.fqdn

		# Find out whether there is a reason to skip this lease
		if not fqdn:
			reason = "Skipping lease without a FQDN: %s"
		elif fqdn in self.hosts:
			reason = "Skipping lease for which a static host exists: %s"
		elif lease.has_expired():
			reason = "Skipping expired lease: %s"
		else:
			reason = None

		if reason:
			log.debug(reason % lease)
			return

		# Replace any previous lease by the new one
		self._remove_lease(lease)
		self.leases.add(lease)

	def _find_lease(self, ipaddr):
		"""
			Looks up the lease that belongs to the given IP address.

			The address may be passed as IPv4Address or as anything that
			IPv4Address accepts. Returns None if no lease was found.
		"""
		if isinstance(ipaddr, ipaddress.IPv4Address):
			wanted = ipaddr
		else:
			wanted = ipaddress.IPv4Address(ipaddr)

		# Return the first match, or None if there is none
		return next((lease for lease in self.leases if lease.ipaddr == wanted), None)

	def _remove_lease(self, lease):
		"""
			Forgets the given lease. Unknown leases are silently ignored.
		"""
		self.leases.discard(lease)

	def read_static_hosts(self):
		"""
			Reads the static hosts file and returns a dictionary which maps
			each FQDN to the sorted list of its IP addresses.

			Each line is expected to have five comma-separated fields:
			enabled, IP address, hostname, domain name, generate PTR.
			Lines that cannot be parsed and enabled entries without any
			name are skipped with a warning. Disabled entries are skipped
			silently.
		"""
		log.info("Reading static hosts from %s" % self.hosts_file)

		hosts = {}
		with open(self.hosts_file) as f:
			for line in f:
				line = line.rstrip()

				try:
					enabled, ipaddr, hostname, domainname, generateptr = line.split(",")
				except ValueError:
					log.warning("Could not parse line: %s" % line)
					continue

				# Skip any disabled entries
				if not enabled == "on":
					continue

				# Compose the FQDN from whatever parts are set
				fqdn = ".".join(part for part in (hostname, domainname) if part)

				# Without any name, there is nothing this address could be
				# stored under (and it must not end up with the name of the
				# previous entry)
				if not fqdn:
					log.warning("Skipping static host without a name: %s" % line)
					continue

				addresses = hosts.setdefault(fqdn, [])
				addresses.append(ipaddr)
				addresses.sort()

		# Dump everything in the logs
		log.debug("Static hosts:")
		for name in hosts:
			log.debug("  %-20s : %s" % (name, ", ".join(hosts[name])))

		return hosts

	def reload(self, *args, **kwargs):
		"""
			Reads the static hosts again and then rebuilds all leases.

			Any arguments are accepted and ignored so that this method can
			be used as a signal handler.
		"""
		# The static hosts have to be loaded first because they decide
		# which leases are going to be skipped
		self.hosts = self.read_static_hosts()

		# Rebuild all leases in any case
		self.update_dhcp_leases()

	def terminate(self, *args, **kwargs):
		"""
			Closes the server socket if it has been opened.

			Any arguments are accepted and ignored so that this method can
			be used as a signal handler.
		"""
		if not self.socket:
			return

		self.socket.close()


class Watcher(threading.Thread):
	"""
		Watches if Unbound is still running.

		Whenever an Unbound process has been found (i.e. on startup and
		after each restart), the reload callback is called.
	"""
	def __init__(self, reload, *args, **kwargs):
		"""
			reload - callable without arguments that is called every time
			         a new Unbound process has been discovered

			All other arguments are passed on to threading.Thread.
		"""
		super().__init__(*args, **kwargs)

		self.reload = reload

		# Set to true if this thread should be terminated
		self._terminated = threading.Event()

	def run(self):
		log.debug("Watcher launched")

		pidfd = None

		try:
			while True:
				# One iteration takes 30 seconds unless we don't know the process
				# when we try to find it once a second.
				if self._terminated.wait(1 if pidfd is None else 30):
					break

				# Fetch a PIDFD for Unbound
				if pidfd is None:
					pidfd = self._get_pidfd()

					# If we could not acquire a PIDFD, we will try again soon...
					if pidfd is None:
						log.warning("Cannot find Unbound...")
						continue

					# Since Unbound has been restarted, we need to reload it all...
					# A failure is logged, but must not end this thread because
					# nobody would notice any further restarts of Unbound.
					try:
						self.reload()
					except Exception as e:
						log.error("Reload failed: %s" % e, exc_info=True)

				log.debug("Checking if Unbound is still alive...")

				# Send the process the null signal (0) which delivers nothing,
				# but fails if the process no longer exists
				try:
					signal.pidfd_send_signal(pidfd, 0)

				# If the process has died, we land here and will have to wait until Unbound
				# has come back and reload it...
				except ProcessLookupError:
					log.error("Unbound has died")

					# Release the file descriptor and reset the PIDFD
					os.close(pidfd)
					pidfd = None

				else:
					log.debug("Unbound is alive")

		# Never leave the file descriptor open behind us
		finally:
			if pidfd is not None:
				os.close(pidfd)

		log.debug("Watcher terminated")

	def terminate(self):
		"""
			Called to signal this thread to terminate
		"""
		self._terminated.set()

	def _get_pidfd(self):
		"""
			Returns a PIDFD for unbound if it is running, otherwise None.
		"""
		# Try to find the PID
		pid = pidof("unbound")

		if not pid:
			return None

		log.debug("Unbound is running as PID %s" % pid)

		# Open a PIDFD. The process might have gone away in the meantime.
		try:
			pidfd = os.pidfd_open(pid)
		except ProcessLookupError:
			log.debug("PID %s has gone away already" % pid)
			return None

		log.debug("Acquired PIDFD %s for PID %s" % (pidfd, pid))

		return pidfd


class Worker(threading.Thread):
	"""
		Processes queued messages in a thread of its own so that the
		sender of a message does not have to wait for them to be handled.
	"""
	def __init__(self, queue, callback):
		"""
			queue    - the queue from which the messages are taken
			callback - called with each message as its only argument
		"""
		super().__init__()

		self.queue = queue
		self.callback = callback

	def run(self):
		log.debug("Worker %s launched" % self.native_id)

		# None is the signal to stop; everything else is a message
		while (message := self.queue.get()) is not None:
			# A failing callback must not stop the worker
			try:
				self.callback(message)
			except Exception as e:
				log.error("Callback failed: %s" % e, exc_info=True)

		log.debug("Worker %s terminated" % self.native_id)


class DHCPLeases(object):
	"""
		Reads the leases file of the DHCP server. Iterating over an
		instance yields one Lease object for every active lease block,
		in the order in which the blocks appear in the file.
	"""
	regex_leaseblock = re.compile(r"lease (?P<ipaddr>\d+\.\d+\.\d+\.\d+) {(?P<config>[\s\S]+?)\n}")

	def __init__(self, path):
		self.path = path

		self._leases = self._parse()

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
			Parses the whole file and returns a list of leases.
		"""
		log.info("Reading DHCP leases from %s" % self.path)

		leases = []

		# Read entire leases file
		with open(self.path) as f:
			data = f.read()

		for match in self.regex_leaseblock.finditer(data):
			ipaddr = match.group("ipaddr")
			properties = self._parse_block(match.group("config"))

			# Skip any abandoned leases
			if "hardware" not in properties:
				continue

			# Skip inactive leases. Blocks that don't mention a binding
			# state at all are considered to be active.
			if properties.get("binding", "state active") != "state active":
				continue

			leases.append(Lease(ipaddr, properties))

		return leases

	def _parse_block(self, block):
		"""
			Parses the body of a lease block into a dictionary.

			Each statement is split at the first space into key and value.
			Statements without a value are stored with an empty string.
			Lines that don't end with a semicolon and all "option" and
			"set" statements are ignored.
		"""
		properties = {}

		for line in block.splitlines():
			# Invalid line if it doesn't end with ;
			if not line.endswith(";"):
				continue

			# Remove the trailing ; and any leading whitespace
			line = line[:-1].lstrip()

			# Skip anything that is empty now
			if not line:
				continue

			# We skip all options and sets
			if line.startswith(("option", "set")):
				continue

			# Split by first space. Flags like "abandoned" don't have a
			# value, which must not raise an exception.
			key, _, val = line.partition(" ")
			properties[key] = val

		return properties


class FixLeases(object):
	"""
		Reads the file with the fixed leases. Iterating over an instance
		yields one Lease object (marked as fixed) for every enabled entry.
	"""
	def __init__(self, path):
		self.path = path

		self._leases = self._parse()

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
			Parses the file and returns a list of leases.

			Each line is expected to have seven comma-separated fields of
			which the hardware address, the IP address, the enabled flag
			and the hostname (last field) are extracted.
		"""
		log.info("Reading fix leases from %s" % self.path)

		# All fixed leases start now and never end
		starts = datetime.datetime.utcnow().strftime("%w %Y/%m/%d %H:%M:%S")

		result = []

		with open(self.path) as f:
			for line in f:
				line = line.rstrip()

				fields = line.split(",")

				if len(fields) != 7:
					log.warning("Could not parse line: %s" % line)
					continue

				hwaddr, ipaddr, enabled, _, _, _, hostname = fields

				# Skip any disabled leases
				if enabled != "on":
					continue

				properties = {
					"binding"         : "state active",
					"client-hostname" : hostname,
					"starts"          : starts,
					"ends"            : "never",
				}

				result.append(Lease(ipaddr, properties, fixed=True))

		return result


class Lease(object):
	"""
		A DHCP lease for a single IPv4 address.

		Two leases are considered equal (and have the same hash) if they
		have the same IP address, no matter what their properties are.
	"""
	def __init__(self, ipaddr, properties, fixed=False):
		"""
			ipaddr     - the address as IPv4Address or anything IPv4Address accepts
			properties - dictionary with the properties of the lease
			             ("client-hostname", "starts", "ends", ...)
			fixed      - True if this is a fixed lease
		"""
		if not isinstance(ipaddr, ipaddress.IPv4Address):
			ipaddr = ipaddress.IPv4Address(ipaddr)

		self.ipaddr = ipaddr
		self._properties = properties
		self.fixed = fixed

	def __repr__(self):
		return "<%s for %s (%s)>" % (self.__class__.__name__, self.ipaddr, self.hostname)

	def __eq__(self, other):
		if isinstance(other, self.__class__):
			return self.ipaddr == other.ipaddr

		return NotImplemented

	def __gt__(self, other):
		# Only leases for the same address can be ordered (by their start time)
		if isinstance(other, self.__class__):
			if not self.ipaddr == other.ipaddr:
				return NotImplemented

			return self.time_starts > other.time_starts

		return NotImplemented

	def __hash__(self):
		return hash(self.ipaddr)

	@property
	def hostname(self):
		"""
			The hostname sent by the client without any quotes, or None if
			there is none or if it is not a valid hostname.
		"""
		hostname = self._properties.get("client-hostname")

		if hostname is None:
			return None

		# Remove any ""
		hostname = hostname.replace("\"", "")

		# Only return valid hostnames. The whole string has to match (which
		# "$" does not guarantee as it tolerates a trailing newline), and
		# only ASCII letters are allowed.
		if re.fullmatch(r"[A-Z0-9\-]{1,63}", hostname, re.I | re.A):
			return hostname

		return None

	@property
	def domain(self):
		"""
			The domain name configured for the DHCP zone this address
			belongs to. Falls back to the domain of the host itself and
			finally to "localdomain".
		"""
		# Load ethernet settings
		ethernet_settings = self.read_settings("/var/ipfire/ethernet/settings")

		# Load DHCP settings
		dhcp_settings = self.read_settings("/var/ipfire/dhcp/settings")

		subnets = {}
		for zone in ("GREEN", "BLUE"):
			if not dhcp_settings.get("ENABLE_%s" % zone) == "on":
				continue

			netaddr = ethernet_settings.get("%s_NETADDRESS" % zone)
			submask = ethernet_settings.get("%s_NETMASK" % zone)
			domain  = dhcp_settings.get("DOMAIN_NAME_%s" % zone)

			# Ignore zones which are not fully configured
			if not netaddr or not submask or not domain:
				continue

			# Ignore zones with an invalid network configuration
			try:
				subnet = ipaddress.ip_network("%s/%s" % (netaddr, submask))
			except ValueError:
				continue

			subnets[subnet] = domain

		for subnet in subnets:
			if self.ipaddr in subnet:
				return subnets[subnet]

		# Load main settings
		settings = self.read_settings("/var/ipfire/main/settings")

		# Fall back to the host domain if no match could be found
		return settings.get("DOMAINNAME", "localdomain")

	@staticmethod
	@functools.cache
	def read_settings(filename):
		"""
			Reads a file with KEY=VALUE lines and returns it as dictionary.
			Lines without "=" are ignored.

			The result is cached per filename, i.e. every file is read only
			once for the lifetime of the process.
		"""
		settings = {}

		with open(filename) as f:
			for line in f:
				# Remove line-breaks
				line = line.rstrip()

				# Skip empty lines and anything else that is not an assignment
				k, separator, v = line.partition("=")
				if not separator:
					continue

				settings[k] = v

		return settings

	@property
	def fqdn(self):
		"""
			Hostname and domain joined by a dot, or None without a valid hostname.
		"""
		hostname = self.hostname

		if hostname:
			return "%s.%s" % (hostname, self.domain)

		return None

	@staticmethod
	def _parse_time(s):
		return datetime.datetime.strptime(s, "%w %Y/%m/%d %H:%M:%S")

	@staticmethod
	def _utcnow():
		"""
			Returns the current time in UTC as naive datetime object so that
			it can be compared with what _parse_time() returns.
		"""
		return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

	@property
	def time_starts(self):
		"""
			The time when the lease starts, or None if not known.
		"""
		starts = self._properties.get("starts")

		if starts:
			return self._parse_time(starts)

		return None

	@property
	def time_ends(self):
		"""
			The time when the lease ends, or None if it never ends or if
			the end is not known.
		"""
		ends = self._properties.get("ends")

		if not ends or ends == "never":
			return None

		return self._parse_time(ends)

	def has_expired(self):
		"""
			Returns True if the current time is outside of the validity
			period of the lease. Leases without a start time never expire.
		"""
		starts = self.time_starts

		if not starts:
			return False

		now = self._utcnow()
		ends = self.time_ends

		# Without an end, the lease is only invalid before it has started
		if not ends:
			return starts > now

		return not starts < now < ends

	@property
	def rrset(self):
		"""
			The DNS records for this lease as list of tuples with
			(name, TTL, class and type, content).
		"""
		fqdn = self.fqdn

		# If the lease does not have a valid FQDN, we cannot create any RRs
		if fqdn is None:
			return []

		return [
			# Forward record
			(fqdn, "%s" % LOCAL_TTL, "IN A", "%s" % self.ipaddr),

			# Reverse record
			(self.ipaddr.reverse_pointer, "%s" % LOCAL_TTL,
				"IN PTR", fqdn),
		]


class UnboundConfigWriter(object):
	"""
		Writes the leases into a configuration file for Unbound and
		changes records at runtime using unbound-control.
	"""
	def __init__(self, path):
		# Path of the configuration file that is generated
		self.path = path

	def update_dhcp_leases(self, leases):
		"""
			Writes all given leases to the configuration file and reloads
			Unbound if the file has changed.
		"""
		# Write out all leases
		if self.write_dhcp_leases(leases):
			log.debug("Reloading Unbound...")

			# Reload the configuration without dropping the cache
			self._control("reload_keep_cache")

	def write_dhcp_leases(self, leases):
		"""
			Generates the configuration file from the given leases.

			Returns True if the file has been (re-)written and False if
			the existing file has the same content already.
		"""
		log.debug("Writing DHCP leases...")

		# The temporary file is created next to the destination so that both
		# are on the same filesystem and the file can be moved atomically
		directory = os.path.dirname(self.path) or "."

		fd, tmp_path = tempfile.mkstemp(dir=directory,
			prefix=".%s." % os.path.basename(self.path))

		try:
			with os.fdopen(fd, "w") as f:
				for l in sorted(leases, key=lambda x: x.ipaddr):
					for rr in l.rrset:
						f.write("local-data: \"%s\"\n" % " ".join(rr))

				# Make file readable for everyone
				os.fchmod(f.fileno(), stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)

			# Compare if the new leases file has changed from the previous version
			try:
				if filecmp.cmp(tmp_path, self.path, shallow=False):
					log.debug("The generated leases file has not changed")

					return False

			# If the previous file did not exist, just keep falling through
			except FileNotFoundError:
				pass

			# Move the file to its destination. This replaces the previous
			# version in one step so that there always is a complete file.
			os.replace(tmp_path, self.path)
			tmp_path = None

		# Remove the temporary file unless it has been moved away
		finally:
			if tmp_path is not None:
				try:
					os.unlink(tmp_path)
				except FileNotFoundError:
					pass

		return True

	def _control(self, *args):
		"""
			Runs unbound-control with the given arguments.

			Raises subprocess.CalledProcessError if the command has failed.
		"""
		command = ["unbound-control"]
		command.extend(args)

		# Log what we are doing
		log.debug("Running %s" % " ".join(command))

		try:
			subprocess.check_output(command)

		# Log any errors
		except subprocess.CalledProcessError as e:
			log.critical("Could not run %s, error code: %s: %s" % (
				" ".join(command), e.returncode, e.output))

			raise

	def apply_lease(self, lease):
		"""
			This method takes a lease and updates Unbound at runtime.
		"""
		log.debug("Applying lease %s" % lease)

		for rr in lease.rrset:
			log.debug("Adding new record %s" % " ".join(rr))

			self._control("local_data", *rr)

	def remove_lease(self, lease):
		"""
			This method takes a lease and removes it from Unbound at runtime.
		"""
		log.debug("Removing lease %s" % lease)

		# Only the name of each record is needed to remove it
		for name, *_ in lease.rrset:
			log.debug("Removing records for %s" % name)

			self._control("local_data_remove", name)


def pidof(program):
	"""
		Returns the first PID of the given program.

		Returns None if the program is not running or if that could not
		be determined because the pidof command could not be executed.
	"""
	try:
		output = subprocess.check_output(["pidof", program])

	# pidof exits with an error if there is no such process. OSError is
	# raised if the command itself could not be launched at all.
	except (subprocess.CalledProcessError, OSError):
		return None

	# Return the first token that is a valid PID
	for token in output.decode().split():
		try:
			return int(token)
		except ValueError:
			continue

	return None


if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="Bridge for DHCP Leases and Unbound DNS")

	# Daemon Stuff
	parser.add_argument("--daemon", "-d", action="store_true",
		help="Launch as daemon in background")
	parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

	# Paths: option, default value, help text
	path_options = (
		("--dhcp-leases",    "/var/state/dhcp/dhcpd.leases",
			"Path to the DHCPd leases file"),
		("--unbound-leases", "/etc/unbound/dhcp-leases.conf",
			"Path to the unbound configuration file"),
		("--fix-leases",     "/var/ipfire/dhcp/fixleases",
			"Path to the fix leases file"),
		("--hosts",          "/var/ipfire/main/hosts",
			"Path to static hosts file"),
		("--socket-path",    "/var/run/unbound-dhcp-leases-bridge.sock",
			"Socket Path"),
	)

	for option, default, help_text in path_options:
		parser.add_argument(option, default=default, metavar="PATH", help=help_text)

	# Parse command line arguments
	args = parser.parse_args()

	# Each -v makes us more talkative: warnings only, then info, then debug
	verbosity = args.verbose or 0

	if verbosity >= 2:
		loglevel = logging.DEBUG
	elif verbosity == 1:
		loglevel = logging.INFO
	else:
		loglevel = logging.WARN

	bridge = UnboundDHCPLeasesBridge(args.dhcp_leases, args.fix_leases,
		args.unbound_leases, args.hosts, socket_path=args.socket_path)

	# SIGHUP reloads everything, SIGINT and SIGTERM shut the bridge down
	signal_map = {
		signal.SIGHUP  : bridge.reload,
		signal.SIGINT  : bridge.terminate,
		signal.SIGTERM : bridge.terminate,
	}

	with daemon.DaemonContext(
		detach_process=args.daemon,
		stderr=None if args.daemon else sys.stderr,
		signal_map=signal_map,
	) as daemon:
		# Logging is set up only after the process has been daemonized
		setup_logging(daemon=args.daemon, loglevel=loglevel)

		bridge.run()
