#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2022  Michael Tremer                                          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import argparse
import base64
import csv
import daemon
import logging
import logging.handlers
import signal
import socket
import subprocess
import sys

# Path of the CSV file in which OpenVPN connections are looked up by
# their common name
OPENVPN_CONFIG = "/var/ipfire/ovpn/ovpnconfig"

# Text that is sent as the last field of the CRV1 dynamic challenge
CHALLENGETEXT = "One Time Token: "

# Everything is logged through the root logger. It starts out at DEBUG
# until setup_logging() applies the requested log level.
log = logging.getLogger()
log.setLevel(logging.DEBUG)

def setup_logging(daemon=True, loglevel=logging.INFO):
	"""
		Configures the root logger and returns it

		Messages are always sent to syslog (via /dev/log, facility "daemon").
		If daemon is false, they are written to the console as well.
		loglevel is applied to the logger and to every handler added here.
	"""
	log.setLevel(loglevel)

	# Syslog is always enabled and is the only handler with a custom format
	syslog = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
	syslog.setLevel(loglevel)
	syslog.setFormatter(
		logging.Formatter("%(name)s[%(process)d]: %(message)s"),
	)
	log.addHandler(syslog)

	# Nothing else to do when running in the background
	if daemon:
		return log

	# In foreground, write everything to the console, too
	console = logging.StreamHandler()
	console.setLevel(loglevel)
	log.addHandler(console)

	return log

class OpenVPNAuthenticator(object):
	"""
		Authenticates OpenVPN clients over the OpenVPN management socket

		The authenticator connects to the UNIX socket at socket_path, waits
		for ">CLIENT:" notifications and answers CONNECT and REAUTH events
		with either "client-auth-nt" or "client-deny". Connections which
		have TOTP enabled in OPENVPN_CONFIG have to pass a dynamic challenge
		(CRV1) which is verified using oathtool.
	"""

	def __init__(self, socket_path):
		self.socket_path = socket_path

		# The socket is created by run()
		self.sock = None

	def _read_line(self):
		"""
			Reads one line from the socket and returns it as a string
			without any trailing whitespace

			Raises EOFError if the other side has closed the connection.
		"""
		buf = []

		# Read byte by byte so that nothing beyond the newline is consumed
		while True:
			char = self.sock.recv(1)

			# Break if we could not read from the socket
			if not char:
				raise EOFError("Could not read from socket")

			# Append to buffer
			buf.append(char)

			# Reached end of line
			if char == b"\n":
				break

		# Undecodable bytes are replaced instead of raising
		# UnicodeDecodeError which would terminate the daemon
		line = b"".join(buf).decode(errors="replace").rstrip()

		log.debug("< %s", line)

		return line

	def _write_line(self, line):
		"""
			Sends one line to the socket, appending a newline if necessary
		"""
		log.debug("> %s", line)

		if not line.endswith("\n"):
			line = "%s\n" % line

		# send() may transmit only a part of the buffer, sendall() does not
		self.sock.sendall(line.encode())

	def _send_command(self, command):
		"""
			Sends a command to OpenVPN
		"""
		self._write_line(command)

	def run(self):
		"""
			Connects to the management socket and processes client events
			until the connection to OpenVPN is lost
		"""
		self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

		# Always release the socket, no matter how we leave this method
		try:
			self.sock.connect(self.socket_path)

			log.info("OpenVPN Authenticator started")

			try:
				while True:
					line = self._read_line()

					if line.startswith(">CLIENT"):
						self._client_event(line)

			# Terminate the daemon when it loses its connection to the
			# OpenVPN daemon. ConnectionError includes ConnectionResetError
			# as well as BrokenPipeError which is raised when writing fails.
			except (ConnectionError, EOFError) as e:
				log.error("Connection to OpenVPN has been lost: %s", e)
		finally:
			self.sock.close()

		log.info("OpenVPN Authenticator terminated")

	def terminate(self, *args):
		"""
			Terminates the process by raising SystemExit

			Any arguments are accepted and ignored so that this method can
			be registered in the signal map of the daemon context.
		"""
		# XXX TODO
		raise SystemExit

	def _client_event(self, line):
		"""
			Parses a ">CLIENT:" notification and calls the matching handler
		"""
		# Strip away ">CLIENT:"
		_, _, line = line.partition(":")

		# Extract the event & split any arguments
		event, _, arguments = line.partition(",")
		arguments = arguments.split(",")

		# Only these events are followed by an environment block
		if event not in ("CONNECT", "DISCONNECT", "REAUTH", "ESTABLISHED"):
			log.debug("Unhandled event: %s", event)
			return

		# The environment has to be consumed even if it is not used
		environ = self._read_env({})

		if event == "CONNECT":
			self._client_connect(*arguments, environ=environ)
		elif event == "DISCONNECT":
			self._client_disconnect(*arguments, environ=environ)
		elif event == "REAUTH":
			self._client_reauth(*arguments, environ=environ)

	def _read_env(self, environ):
		"""
			Reads ">CLIENT:ENV,key=value" lines up to ">CLIENT:ENV,END",
			stores them in environ and returns it

			Raises RuntimeError if any other line is received.
		"""
		prefix = ">CLIENT:ENV,"

		while True:
			line = self._read_line()

			if not line.startswith(prefix):
				raise RuntimeError("Unexpected environment line: %s" % line)

			# Strip >CLIENT:ENV,
			line = line[len(prefix):]

			# Done
			if line == "END":
				break

			# Parse environment
			key, _, value = line.partition("=")
			environ[key] = value

		return environ

	def _client_connect(self, cid, kid, environ=None):
		"""
			Handles CLIENT:CONNECT events

			Every event is answered: clients without TOTP are accepted,
			clients with TOTP have to pass the dynamic challenge, and
			everything else is denied.
		"""
		if environ is None:
			environ = {}

		log.debug("Received client connect (cid=%s, kid=%s)", cid, kid)
		for key in sorted(environ):
			log.debug("  %s : %s", key, environ[key])

		# Fetch common name
		common_name = environ.get("common_name")

		# Find connection details
		conn = self._find_connection(common_name)
		if not conn:
			log.warning("Could not find connection '%s'", common_name)

			# Fail closed instead of crashing on the missing connection
			return self._client_auth_denied(cid, kid, "Unknown connection")

		log.debug("Found connection:")
		for key in conn:
			log.debug("  %s : %s", key, conn[key])

		# Perform no further checks if TOTP is disabled for this client
		if not conn.get("totp_status") == "on":
			return self._client_auth_successful(cid, kid)

		# Fetch the password (which is None if the client did not send any)
		password = environ.get("password")

		# Client sent the special password TOTP to start challenge authentication
		if password == "TOTP":
			return self._client_auth_challenge(cid, kid,
				username=common_name, password="TOTP")

		if password and password.startswith("CRV1:"):
			log.debug("Received dynamic challenge response %s", password)

			# Decode the string
			response = self._parse_challenge_response(password)

			if response is None:
				log.warning("Received malformed dynamic challenge response")
			else:
				username, token = response

				# Check if username matches common name
				if username == common_name:
					# Check if TOTP token matches
					if self._check_totp_token(token, conn.get("totp_secret")):
						return self._client_auth_successful(cid, kid)

			# Restart authentication
			return self._client_auth_challenge(cid, kid,
				username=common_name, password="TOTP")

		# Never leave the client without a response
		log.warning("Client '%s' sent no usable credentials", common_name)

		return self._client_auth_denied(cid, kid, "Invalid credentials")

	def _client_disconnect(self, cid, environ=None):
		"""
			Handles CLIENT:DISCONNECT events
		"""
		pass

	def _client_reauth(self, cid, kid, environ=None):
		"""
			Handles CLIENT:REAUTH events
		"""
		# Perform no checks
		self._client_auth_successful(cid, kid)

	def _client_auth_challenge(self, cid, kid, username, password):
		"""
			Initiates a dynamic challenge authentication with the client
		"""
		log.debug("Sending request for dynamic challenge...")

		self._send_command(
			"client-deny %s %s \"CRV1\" \"CRV1:R,E:%s:%s:%s\"" % (
				cid,
				kid,
				self._b64encode(username),
				self._b64encode(password),
				self._escape(CHALLENGETEXT),
			),
		)

	def _client_auth_successful(self, cid, kid):
		"""
			Sends a positive authentication response
		"""
		log.debug("Client Authentication Successful (cid=%s, kid=%s)", cid, kid)

		self._send_command(
			"client-auth-nt %s %s" % (cid, kid),
		)

	def _client_auth_denied(self, cid, kid, reason):
		"""
			Sends a negative authentication response with the given reason
		"""
		log.debug("Client Authentication Denied (cid=%s, kid=%s): %s",
			cid, kid, reason)

		# The reason is sent in double quotes and must not break out of them
		reason = reason.replace("\\", "\\\\").replace("\"", "\\\"")

		self._send_command(
			"client-deny %s %s \"%s\"" % (cid, kid, reason),
		)

	@classmethod
	def _parse_challenge_response(cls, password):
		"""
			Parses a dynamic challenge response which consists of five
			colon-separated fields: command, flags, username (base64),
			password and token

			Returns a tuple (username, token) or None if the response is
			malformed.
		"""
		# The token is the last field and takes whatever is left
		fields = password.split(":", 4)

		if len(fields) != 5 or fields[0] != "CRV1":
			return None

		# Both binascii.Error and UnicodeDecodeError are ValueErrors
		try:
			username = cls._b64decode(fields[2])
		except ValueError:
			return None

		return username, fields[4]

	@staticmethod
	def _b64encode(s):
		"""
			Encodes a string as base64 and returns the result as string
		"""
		return base64.b64encode(s.encode()).decode()

	@staticmethod
	def _b64decode(s):
		"""
			Decodes a base64-encoded string and returns the result as string
		"""
		return base64.b64decode(s.encode()).decode()

	@staticmethod
	def _escape(s):
		"""
			Escapes all spaces with a backslash
		"""
		return s.replace(" ", "\\ ")

	def _find_connection(self, common_name):
		"""
			Searches OPENVPN_CONFIG for an enabled host connection with the
			given common name

			Returns a dictionary with the connection details or None if
			there is no such connection.
		"""
		# The csv module wants to handle line endings itself
		with open(OPENVPN_CONFIG, "r", newline="") as f:
			for row in csv.reader(f, dialect="unix"):
				# Skip empty rows or rows that are too short
				if len(row) < 5:
					continue

				# Skip disabled connections
				if not row[1] == "on":
					continue

				# Skip any net-2-net connections
				if not row[4] == "host":
					continue

				# Skip if common name does not match
				if not row[3] == common_name:
					continue

				# Return match!
				conn = {
					"name"        : row[2],
					"common_name" : row[3],
				}

				# TOTP options (only if the row has all of them)
				if len(row) > 45:
					conn.update({
						"totp_protocol" : row[43],
						"totp_status"   : row[44],
						"totp_secret"   : row[45],
					})

				return conn

		# No match
		return None

	def _check_totp_token(self, token, secret):
		"""
			Returns True if token is one of the tokens that oathtool
			generates for secret
		"""
		# There is nothing to compare without a token or a secret
		if not token or not secret:
			log.debug("Cannot check TOTP token without token and secret")

			return False

		try:
			p = subprocess.run(
				["oathtool", "--totp", "-w", "3", "%s" % secret],
				capture_output=True,
			)

		# oathtool might not be installed or not be executable
		except OSError as e:
			log.error("Could not run oathtool: %s", e)

			return False

		# Catch any errors if we could not run the command
		if p.returncode:
			log.error("Could not run oathtool: %s" % p.stderr)

			return False

		# Reading returned tokens looking for a match
		return self._token_in_output(p.stdout, token)

	@staticmethod
	def _token_in_output(output, token):
		"""
			Returns True if token equals one of the lines in output (bytes)
		"""
		# An empty token must never match an empty line
		if not token:
			return False

		return token in output.decode(errors="replace").split("\n")


# Entry point: parse the command line, then run the authenticator inside
# a daemon context (optionally detached into the background)
if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="OpenVPN Authenticator")

	# Daemon Stuff
	parser.add_argument("--daemon", "-d", action="store_true",
		help="Launch as daemon in background")
	parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

	# Paths
	parser.add_argument("--socket", default="/var/run/openvpn.sock",
		metavar="PATH", help="Path to OpenVPN Management Socket")

	# Parse command line arguments
	args = parser.parse_args()

	# Map the number of -v switches to a log level
	# (args.verbose is None if the switch was not passed at all)
	verbosity = args.verbose or 0

	if verbosity >= 2:
		loglevel = logging.DEBUG
	elif verbosity == 1:
		loglevel = logging.INFO
	else:
		loglevel = logging.WARN

	# Create an authenticator
	authenticator = OpenVPNAuthenticator(args.socket)

	# The context is deliberately not bound to a name. Binding it to
	# "daemon" would overwrite the imported daemon module.
	with daemon.DaemonContext(
		detach_process=args.daemon,
		stderr=None if args.daemon else sys.stderr,
		signal_map={
			signal.SIGINT  : authenticator.terminate,
			signal.SIGTERM : authenticator.terminate,
		},
	):
		# Logging is set up inside of the context, right before the
		# authenticator connects to OpenVPN
		setup_logging(daemon=args.daemon, loglevel=loglevel)

		authenticator.run()
