#!/usr/bin/python3
#
# Search Apache logs for high-bandwidth users and create a list of suspicious IPs.
# There are three states: bulk, block, ban. 'bulk' marks bulk requesters that
# need throttling, 'block' bulk requesters that have overdone it, and 'ban'
# manually banned IPs.
#
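# The resulting state file (settings/ip_blocks.map) contains one
# "<ip> <state>" pair per line, for example (IPs purely illustrative):
#
#   192.0.2.17 bulk
#   192.0.2.99 block
#   203.0.113.5 ban
#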

import re
import os
import sys
import subprocess
from datetime import datetime, timedelta
from collections import defaultdict

#
# DEFAULT SETTINGS
#
# Copy into settings/ip_blocks.conf and adapt as required.
#
BASEDIR = os.path.normpath(os.path.join(os.path.realpath(__file__), '../..'))
BLOCKEDFILE = BASEDIR + '/settings/ip_blocks.map'
LOGFILE = BASEDIR + '/log/restricted_ip.log'

# space-separated list of IPs that are never banned
WHITELIST = ''
# space-separated list of IPs manually blocked
BLACKLIST = ''
# user-agents that should be blocked from bulk mode
# (matched with startswith)
UA_BLOCKLIST = ()
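# Example with made-up prefixes:
# UA_BLOCKLIST = ('Java/', 'python-requests')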

# time before an automatically blocked IP is allowed back
BLOCKCOOLOFF_DELTA = timedelta(hours=1)
# quiet time before an IP is released from the bulk pool
BULKCOOLOFF_DELTA = timedelta(minutes=15)
# time window for checking whether new accesses appear despite being blocked
BLOCKCHECK_DELTA = timedelta(minutes=1)

# API requests in the long window before an IP is moved to the bulk pool
BULKLONG_LIMIT = 8000
# API requests in the short window before an IP is moved to the bulk pool
BULKSHORT_LIMIT = 2000
# absolute API request limit above which an IP is blocked outright
BLOCK_UPPER = 19000
# lower bound for the load-adjusted block limit
BLOCK_LOWER = 4000
# how strongly DB load reduces the block limit
BLOCK_LOADFAC = 380
# how strongly CPU load reduces the bulk limit
BULK_LOADFAC = 160
# lower bound for the load-adjusted bulk limit
BULK_LOWER = 1500
# maximum number of IPs in the bulk pool before limits are tightened
MAX_BULK_IPS = 85

#
# END OF DEFAULT SETTINGS
#

try:
    with open(BASEDIR + "/settings/ip_blocks.conf") as f:
        code = compile(f.read(), BASEDIR + "/settings/ip_blocks.conf", 'exec')
        exec(code)
except IOError:
    pass
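
# Note: the settings file is plain Python exec'd in this module's namespace,
# so overrides are simple assignments, e.g. (values purely illustrative):
#
#   WHITELIST = '192.0.2.1 192.0.2.2'
#   BLACKLIST = '203.0.113.7'
#   BLOCKCOOLOFF_DELTA = timedelta(hours=2)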

BLOCK_LIMIT = BLOCK_LOWER

time_regex = r'(?P<t_day>\d\d)/(?P<t_month>[A-Za-z]+)/(?P<t_year>\d\d\d\d):(?P<t_hour>\d\d):(?P<t_min>\d\d):(?P<t_sec>\d\d) [+-]\d\d\d\d'

format_pat = re.compile(r'(?P<ip>[a-f\d\.:]+) - - \[' + time_regex + r'] "(?P<query>.*?)" (?P<return>\d+) (?P<bytes>\d+) "(?P<referer>.*?)" "(?P<ua>.*?)"')
time_pat = re.compile(r'[a-f\d:\.]+ - - \[' + time_regex + r'\] ')

logtime_pat = "%d/%b/%Y:%H:%M:%S %z"

MONTHS = { 'Jan' : 1, 'Feb' : 2, 'Mar' : 3, 'Apr' : 4, 'May' : 5, 'Jun' : 6,
           'Jul' : 7, 'Aug' : 8, 'Sep' : 9, 'Oct' : 10, 'Nov' : 11, 'Dec' : 12 }
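
# Example of an Apache "combined" log line that format_pat matches
# (IP and user agent purely illustrative):
#
#   192.0.2.17 - - [12/Mar/2024:13:45:01 +0000] "GET /search?q=berlin HTTP/1.1" 200 512 "-" "Mozilla/5.0"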

class LogEntry:
    def __init__(self, logline):
        e = format_pat.match(logline)
        if e is None:
            raise ValueError("Invalid log line: %s" % logline)
        e = e.groupdict()
        self.ip = e['ip']
        self.date = datetime(int(e['t_year']), MONTHS[e['t_month']], int(e['t_day']),
                             int(e['t_hour']), int(e['t_min']), int(e['t_sec']))
        qp = e['query'].split(' ', 2)
        if len(qp) < 2 or qp[0] == 'OPTIONS':
            self.request = None
        elif '/?' in qp[1] or '/search' in qp[1]:
            self.request = 'S'
        elif '/reverse' in qp[1]:
            self.request = 'R'
        elif '/details' in qp[1]:
            self.request = 'D'
        elif '/lookup' in qp[1]:
            self.request = 'L'
        else:
            self.request = None
        self.query = e['query']
        self.retcode = int(e['return'])
        self.referer = e['referer'] if e['referer'] != '-' else None
        self.ua = e['ua'] if e['ua'] != '-' else None

    @staticmethod
    def get_log_time(logline):
        e = format_pat.match(logline)
        if e is None:
            return None
        e = e.groupdict()
        return datetime(int(e['t_year']), MONTHS[e['t_month']], int(e['t_day']),
                        int(e['t_hour']), int(e['t_min']), int(e['t_sec']))


class LogFile:
    """ An apache log file, unpacked. """

    def __init__(self, filename):
        self.fd = open(filename)
        self.len = os.path.getsize(filename)

    def __del__(self):
        self.fd.close()

    def seek_next(self, abstime):
        # the interpolation search may hand us a float position
        self.fd.seek(int(abstime))
        # skip the (most likely partial) line at the seek position
        self.fd.readline()
        l = self.fd.readline()
        return LogEntry.get_log_time(l) if l else None

    def seek_to_date(self, target):
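        """ Seek to the approximate position of the first log entry with a
            timestamp of at least `target`, using interpolation search over
            the log timestamps. Returns False if the log ends before the
            target date, True otherwise.
        """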
        # start position for binary search
        fromseek = 0
        fromdate = self.seek_next(0)
        if fromdate is None:
            return False
        if fromdate > target:
            return True
        # end position for binary search
        toseek = -100
        while -toseek < self.len:
            todate = self.seek_next(self.len + toseek)
            if todate is not None:
                break
            toseek -= 100
        if todate is None or todate < target:
            return False
        toseek = self.len + toseek


        while True:
            bps = (toseek - fromseek) / (todate - fromdate).total_seconds()
            newseek = fromseek + int((target - fromdate).total_seconds() * bps)
            newdate = self.seek_next(newseek)
            if newdate is None:
                return False
            error = abs((target - newdate).total_seconds())
            if error < 1:
                return True
            if newdate > target:
                toseek = newseek
                todate = newdate
                oldfromseek = fromseek
                fromseek = toseek - error * bps
                while True:
                    if fromseek <= oldfromseek:
                        fromseek = oldfromseek
                        fromdate = self.seek_next(fromseek)
                        break
                    fromdate = self.seek_next(fromseek)
                    if fromdate < target:
                        break
                    bps *= 2
                    fromseek -= error * bps
            else:
                fromseek = newseek
                fromdate = newdate
                oldtoseek = toseek
                toseek = fromseek + error * bps
                while True:
                    if toseek > oldtoseek:
                        toseek = oldtoseek
                        todate = self.seek_next(toseek)
                        break
                    todate = self.seek_next(toseek)
                    if todate > target:
                        break
                    bps *= 2
                    toseek += error * bps
            if toseek - fromseek < 500:
                return True


    def loglines(self):
        for l in self.fd:
            try:
                yield LogEntry(l)
            except ValueError:
                pass # ignore invalid lines

class BlockList:

    def __init__(self):
        self.whitelist = set(WHITELIST.split()) if WHITELIST else set()
        self.blacklist = set(BLACKLIST.split()) if BLACKLIST else set()
        self.prevblocks = set()
        self.prevbulks = set()

        try:
            with open(BLOCKEDFILE) as fd:
                for line in fd:
                    ip, typ = line.strip().split(' ')
                    if ip not in self.blacklist:
                        if typ == 'block':
                            self.prevblocks.add(ip)
                        elif typ == 'bulk':
                            self.prevbulks.add(ip)
        except IOError:
            pass  # ignore a non-existing state file


class IPstats:
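    """ Per-IP request counters. The 'long' counters cover all scanned log
        entries, the 'short' ones only entries after the bulk cool-off
        boundary (add_short also feeds the long counters), and block_total
        counts recent 403/429 responses.
    """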

    def __init__(self):
        self.redirected = 0
        self.short_total = 0
        self.short_api = 0
        self.long_total = 0
        self.long_api = 0
        self.block_total = 0
        self.bad_ua = False

    def add_long(self, logentry):
        self.long_total += 1
        if logentry.retcode == 301:
            return
        if logentry.request is not None:
            self.long_api += 1
        if not self.bad_ua:
            if logentry.ua is None:
                self.bad_ua = True

    def add_short(self, logentry):
        self.short_total += 1
        if logentry.retcode == 301:
            self.redirected += 1
            return
        if logentry.request is not None:
            self.short_api += 1
        self.add_long(logentry)

    def add_block(self, logentry):
        self.block_total += 1

    def ignores_warnings(self, wasblocked):
        return self.block_total > 5 or (wasblocked and self.redirected > 5)

    def new_state(self, was_blocked, was_bulked):
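        """ Compute the new state for this IP: 'block'/'emblock' for clients
            that must be blocked, 'bulk' for clients that need throttling and
            None for no restriction. 'emblock' (and the currently disabled
            'uablock') are written out as 'block' but logged separately.
        """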
        if was_blocked:
            # deblock only if the IP has been really quiet
            # (properly catches the ones that simply ignore the HTTP error)
            return None if self.long_total < 20 else 'block'
        if self.long_api > BLOCK_UPPER \
           or self.short_api > BLOCK_UPPER / 3 \
           or (self.redirected > 100 and self.short_total == self.redirected):
            # client totally overdoing it
            return 'block'
        if was_bulked:
            if self.short_total < 20:
                # client has stopped, debulk
                return None
            if self.long_api > BLOCK_LIMIT or self.short_api > BLOCK_LIMIT / 3:
                # client is still hammering us, block
                return 'emblock'
            return 'bulk'

        if self.long_api > BULKLONG_LIMIT or self.short_api > BULKSHORT_LIMIT:
            #if self.bad_ua:
            #    return 'uablock' # bad useragent
            return 'bulk'

        return None



if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("Usage: %s logfile startdate" % sys.argv[0])
        sys.exit(-1)
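    # Typical invocation (script name, log path and date purely illustrative):
    #   ./ip_blocks.py /var/log/apache2/access.log "2024-03-12 12:00:00"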

    if len(sys.argv) == 2:
        dt = datetime.now() - BLOCKCOOLOFF_DELTA
    else:
        dt = datetime.strptime(sys.argv[2], "%Y-%m-%d %H:%M:%S")

    if os.path.getsize(sys.argv[1]) < 2*1030*1024:
        sys.exit(0) # not enough data

    lf = LogFile(sys.argv[1])
    if not lf.seek_to_date(dt):
        sys.exit(0)

    bl = BlockList()

    shortstart = dt + BLOCKCOOLOFF_DELTA - BULKCOOLOFF_DELTA
    blockstart = dt + BLOCKCOOLOFF_DELTA - BLOCKCHECK_DELTA
    notlogged = bl.whitelist | bl.blacklist

    stats = defaultdict(IPstats)

    for l in lf.loglines():
        if l.ip not in notlogged:
            stats[l.ip].add_long(l)
        if l.date > shortstart:
            break

    total200 = 0
    for l in lf.loglines():
        if l.ip not in notlogged:
            stats[l.ip].add_short(l)
        if l.request is not None and l.retcode == 200:
            total200 += 1
        if l.date > blockstart and l.retcode in (403, 429):
            stats[l.ip].add_block(l)

    # adapt limits according to CPU and DB load
    with open("/proc/loadavg") as fd:
        cpuload = int(float(fd.readline().split()[2]))  # 15-minute load average
    # estimate DB load from the number of established connections
    # that exceeds the number of php-fpm workers
    dbcons = int(subprocess.check_output(r"netstat -s | grep 'connections established' | sed 's:^\s*::;s: .*::'", shell=True))
    fpms = int(subprocess.check_output('ps -Af | grep php-fpm | wc -l', shell=True))
    dbload = max(0, dbcons - fpms)

    numbulks = len(bl.prevbulks)
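    # reduce the block limit from BLOCK_UPPER according to DB load and the
    # bulk limit according to CPU load, but never below the configured
    # lower bounds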
    BLOCK_LIMIT = max(BLOCK_LIMIT, BLOCK_UPPER - BLOCK_LOADFAC * dbload)
    BULKLONG_LIMIT = max(BULK_LOWER, BULKLONG_LIMIT - BULK_LOADFAC * cpuload)
    if numbulks > MAX_BULK_IPS:
        BLOCK_LIMIT = max(3600, BLOCK_LOWER - (numbulks - MAX_BULK_IPS)*10)
    # if the bulk pool is still empty, clients will be faster, avoid having
    # them blocked in this case
    if numbulks < 10:
        BLOCK_UPPER *= 2
        BLOCK_LIMIT = BLOCK_UPPER


    # collecting statistics
    unblocked = []
    debulked = []
    bulked = []
    blocked = []
    uablocked = []
    emblocked = []
    # write out new state file
    fd = open(BLOCKEDFILE, 'w')
    for k,v in stats.items():
        wasblocked = k in bl.prevblocks
        wasbulked = k in bl.prevbulks
        state = v.new_state(wasblocked, wasbulked)
        if state is not None:
            if state == 'uablock':
                uablocked.append(k)
                state = 'block'
            elif state == 'emblock':
                emblocked.append(k)
                state = 'block'
            elif state == 'block':
                if not wasblocked:
                    blocked.append(k)
            elif state == 'bulk':
                if not wasbulked:
                    bulked.append(k)
            fd.write("%s %s\n" % (k, state))
        else:
            if wasblocked:
                unblocked.append(k)
            elif wasbulked:
                debulked.append(k)
    for i in bl.blacklist:
        fd.write("%s ban\n" % i)
    fd.close()

    # TODO write logs (need to collect some statistics)
    logstr = datetime.now().strftime('%d/%b/%Y:%H:%M:%S') + ' %s %s\n'
    fd = open(LOGFILE, 'a')
    if unblocked:
        fd.write(logstr % ('unblocked:', ', '.join(unblocked)))
    if debulked:
        fd.write(logstr % (' debulked:', ', '.join(debulked)))
    if bulked:
        fd.write(logstr % ('new bulks:', ', '.join(bulked)))
    if emblocked:
        fd.write(logstr % ('dir.block:', ', '.join(emblocked)))
    if uablocked:
        fd.write(logstr % (' ua block:', ', '.join(uablocked)))
    if blocked:
        fd.write(logstr % ('new block:', ', '.join(blocked)))
    #for k,v in stats.items():
    #    if v.ignores_warnings(k in bl.prevblocks) and k not in notlogged and ':' not in k:
    #        fd.write(logstr % ('Warning ignored:', k))
    fd.close()
