#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2007-2022  IPFire Team  <info@ipfire.org>                     #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import argparse
import logging
import logging.handlers
import os
import sqlite3
import sys

def _(x):
	# Identity function wrapped around the user-visible strings below
	# (presumably a stand-in for a gettext-style translation function)
	return x

# Database that is opened when no other path is passed to _open_database()
DEFAULT_DATABASE_PATH = "/var/ipfire/ovpn/clients.db"

def setup_logging(level=logging.INFO):
	"""
		Configures the "openvpn-metrics" logger and returns it.

		level is the threshold of the logger itself. Two handlers are
		attached in this order: a console handler accepting everything from
		DEBUG upwards, and a syslog handler (daemon facility, /dev/log)
		accepting INFO and above.
	"""
	logger = logging.getLogger("openvpn-metrics")
	logger.setLevel(level)

	# Console output, without any special formatting
	console = logging.StreamHandler()
	console.setLevel(logging.DEBUG)
	logger.addHandler(console)

	# Syslog output, prefixed with the program name and the PID
	syslog = logging.handlers.SysLogHandler(
		address="/dev/log",
		facility=logging.handlers.SysLogHandler.LOG_DAEMON,
	)
	syslog.setLevel(logging.INFO)
	syslog.setFormatter(
		logging.Formatter("openvpn-metrics[%(process)d]: %(message)s"),
	)
	logger.addHandler(syslog)

	return logger

# Initialise logging once when the module is loaded; the resulting logger
# is the module-wide "log" used by the code below
log = setup_logging()

class OpenVPNMetrics(object):
	"""
		Command line tool that records OpenVPN client sessions in a SQLite
		database.

		An instance is callable: it parses the command line, runs the
		selected action ("client-connect" or "client-disconnect") and then
		exits the process. The actions take their input from environment
		variables.
	"""

	def __init__(self):
		# Opens the database at the default path and creates the schema
		# if it does not exist yet
		self.db = self._open_database()

	def parse_cli(self):
		"""
			Parses the command line and returns the resulting namespace.

			The handler of the chosen action is stored as "func" in the
			namespace. If no action was given, the usage is printed and the
			process exits with code 2.
		"""
		parser = argparse.ArgumentParser(
			description=_("Tool that collects metrics of OpenVPN Clients"),
		)
		subparsers = parser.add_subparsers()

		# client-connect
		client_connect = subparsers.add_parser("client-connect",
			help=_("Called when a client connects"),
		)
		client_connect.add_argument("file", nargs="?",
			help=_("Configuration file")
		)
		client_connect.set_defaults(func=self.client_connect)

		# client-disconnect
		client_disconnect = subparsers.add_parser("client-disconnect",
			help=_("Called when a client disconnects"),
		)
		client_disconnect.add_argument("file", nargs="?",
			help=_("Configuration file")
		)
		client_disconnect.set_defaults(func=self.client_disconnect)

		# Parse CLI
		args = parser.parse_args()

		# Print usage if no action was given
		if "func" not in args:
			parser.print_usage()
			sys.exit(2)

		return args

	def __call__(self):
		"""
			Runs the action selected on the command line and exits the
			process.

			The exit code is whatever the action returned (zero if it
			returned nothing), or 1 if the action raised an exception.
		"""
		# Parse command line arguments
		args = self.parse_cli()

		# Call function
		try:
			ret = args.func(args)
		except Exception as e:
			log.critical(e)

			# Without this assignment "ret" would be unbound below and the
			# resulting UnboundLocalError would bury the error that has just
			# been logged. The exit status stays non-zero as before.
			ret = 1

		# Return with exit code
		sys.exit(ret or 0)

	def _open_database(self, path=DEFAULT_DATABASE_PATH):
		"""
			Opens the SQLite database at path and returns the connection.

			The "sessions" table and its index on common_name are created
			if they do not exist already.
		"""
		db = sqlite3.connect(path)

		# Create schema if it doesn't exist already
		db.executescript("""
			CREATE TABLE IF NOT EXISTS sessions(
				common_name TEXT NOT NULL,
				connected_at TEXT NOT NULL,
				disconnected_at TEXT,
				bytes_received INTEGER,
				bytes_sent INTEGER
			);

			-- Create index for speeding up searches
			CREATE INDEX IF NOT EXISTS sessions_common_name ON sessions(common_name);
		""")

		return db

	def _get_environ(self, key):
		"""
			Returns the value of the environment variable key.

			If the variable is not set, a message is written to standard
			error and SystemExit(1) is raised.
		"""
		try:
			return os.environ[key]
		except KeyError:
			sys.stderr.write("%s missing from environment\n" % key)
			raise SystemExit(1) from None

	def client_connect(self, args):
		"""
			Opens a new session for the connecting client.

			Reads common_name, time_ascii and time_unix from the environment
			and inserts a row whose connected_at is derived from time_unix.
			args is the parsed command line and is not used.
		"""
		common_name = self._get_environ("common_name")

		# Time
		time_ascii = self._get_environ("time_ascii")
		time_unix  = self._get_environ("time_unix")

		log.info("Opening session for %s at %s", common_name, time_ascii)

		c = self.db.cursor()
		c.execute("INSERT INTO sessions(common_name, connected_at) \
			VALUES(?, DATETIME(?, 'unixepoch'))", (common_name, time_unix))
		self.db.commit()

	def client_disconnect(self, args):
		"""
			Closes the open session(s) of the disconnecting client.

			Reads common_name, time_duration, bytes_received and bytes_sent
			from the environment. Every row of that common name which has no
			disconnected_at yet gets the byte counters and a disconnected_at
			of connected_at plus the duration in seconds.
			args is the parsed command line and is not used.
		"""
		common_name = self._get_environ("common_name")
		duration    = self._get_environ("time_duration")

		# Collect some usage statistics
		bytes_received = self._get_environ("bytes_received")
		bytes_sent     = self._get_environ("bytes_sent")

		log.info("Closing session for %s after %ss and receiving/sending %s/%s bytes",
			common_name, duration, bytes_received, bytes_sent)

		c = self.db.cursor()
		c.execute("UPDATE sessions SET disconnected_at = DATETIME(connected_at, '+'  || ? || ' seconds'), \
                        bytes_received = ?, bytes_sent = ? \
                        WHERE common_name = ? AND disconnected_at IS NULL",
			(duration, bytes_received, bytes_sent, common_name))
		self.db.commit()

def main():
	"""
		Entry point: creates the tool and runs it.

		Calling the tool ends in sys.exit(), so this does not return
		normally.
	"""
	m = OpenVPNMetrics()
	m()

# Only run when executed as a script, so that merely importing this file
# does not parse the command line, open the database and exit
if __name__ == "__main__":
	main()
