#!/usr/bin/python3
# Copyright 2017 Facebook.
# Licensed under the same terms as memcached itself.

import argparse
import socket
import sys
import re
import traceback
from time import sleep

# Command-line interface. Integer defaults are given as ints rather than
# strings so the declared type and the default agree at a glance.
parser = argparse.ArgumentParser(description="daemon for rebalancing slabs")
parser.add_argument("--host", help="host to connect to",
                    default="localhost:11211", metavar="HOST:PORT")
parser.add_argument("-s", "--sleep", help="seconds between runs",
                    type=int, default=1)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-a", "--automove", action="store_true", default=False,
                    help="enable automatic page rebalancing")
parser.add_argument("-w", "--window", type=int, default=30,
                    help="rolling window size for decision history")
parser.add_argument("-r", "--ratio", type=float, default=0.8,
                    help="ratio limiting distance between low/high class ages")

# TODO:
# - age adjuster function
#   - by ratio of get_hits
#   - by ratio of chunk size

args = parser.parse_args()

# Split on the LAST ':' only. A value with exactly one colon behaves as
# before; a value with no colon at all (e.g. "--host cache01") falls back to
# port 11211 instead of dying with an unpacking ValueError.
host, _sep, port = args.host.rpartition(':')
if not _sep:
    host, port = args.host, "11211"

def window_check(history, sid, key):
    """Add up the `key` marker stored for slab class `sid` in every window.

    Windows that hold no entry for `sid`, or whose entry has no truthy value
    under `key`, contribute nothing to the sum.
    """
    entries = (window.get(sid) for window in history['w'])
    return sum(entry.get(key) for entry in entries
               if entry and entry.get(key))

def determine_move(history, diffs, totals):
    """Decide whether a page should be moved, and between which classes.

    While walking the slab classes in `diffs`:

    - Every class gets a fresh entry in the newest window of `history`.
      Page growth or evictions since the last sample mark it 'dirty';
      evictions additionally mark it 'ev'.
    - A class with more than 2.5 pages worth of free chunks and no 'dirty'
      mark anywhere in the window history is chosen to give a page back to
      the global pool (destination 0), and the walk stops there.
    - Otherwise the walk tracks the oldest class (by 'age') owning more
      than 5 pages, and the youngest class that was marked 'ev' in more
      than half of `args.window` windows.

    After the walk, if both an oldest and a youngest class were found and
    the youngest age is below `oldest age * args.ratio`, the decision
    becomes "oldest -> youngest". This check also runs when the walk
    stopped early, using the classes seen up to that point.

    Finally a new empty window is appended to `history` and the oldest one
    is dropped once more than `args.window` windows are held.

    `totals` is not used by this function.

    Returns a `(source, destination)` tuple; `(-1, -1)` means no move.
    """
    current = history['w'][-1]
    decision = (-1, -1)
    oldest = (-1, 0)
    youngest = (-1, sys.maxsize)

    for sid, slab in diffs.items():
        marks = {}
        current[sid] = marks
        if not ('evicted_d' in slab and 'total_pages_d' in slab):
            continue

        # Page growth or evictions make this window dirty for the class.
        grew = slab['total_pages_d'] > 0
        evicting = slab['evicted_d'] > 0
        if grew or evicting:
            marks['dirty'] = 1
        if evicting:
            marks['ev'] = 1

        # More than 2.5 pages free and a clean history: hand a page back to
        # the global pool and stop looking.
        has_spare = slab['free_chunks'] > slab['chunks_per_page'] * 2.5
        if has_spare and window_check(history, sid, 'dirty') == 0:
            decision = (sid, 0)
            break

        age = slab['age']

        # Oldest class so far that is also a valid source of pages.
        if age > oldest[1] and slab['total_pages'] > 5:
            oldest = (sid, age)

        # Youngest class so far that evicted in over half of the window.
        if age < youngest[1] and window_check(history, sid, 'ev') > args.window / 2:
            youngest = (sid, age)

    if args.verbose:
        print("old, young:", oldest, youngest)

    # Move a page only when both ends were found and the youngest class is
    # younger than the configured ratio allows.
    found_both = youngest[0] != -1 and oldest[0] != -1
    if found_both and youngest[1] < oldest[1] * args.ratio:
        decision = (oldest[0], youngest[0])

    # Rotate windows: open a new one, drop the oldest past the limit.
    history['w'].append({})
    if len(history['w']) > args.window:
        del history['w'][0]
    return decision


def run_move(s, decision):
    """Send a `slabs reassign` command for `decision` and read the reply.

    `s` is the text stream connected to the server; `decision` is indexed
    as (source, destination). The single reply line is printed only when
    running verbosely.
    """
    source = str(decision[0])
    destination = str(decision[1])
    s.write("slabs reassign " + source + " " + destination + "\r\n")
    reply = s.readline().rstrip()
    if args.verbose:
        print("move result:", reply)


def diff_stats(before, after):
    """Compare two stats snapshots and compute per-class and overall deltas.

    Only slab classes that are present and non-empty in both snapshots are
    reported. Each reported class starts as a copy of its `after` stats.
    Every stat that also exists in `before` gains a `<stat>_d` entry; when
    the new value consists solely of digits, the stat is converted to int
    and `<stat>_d` holds the change since `before`. Other shared stats keep
    their original value and a delta of 0. Stats missing from `before` are
    copied through untouched.

    Returns `(diffs, totals)`: `diffs` maps slab id to its stats dict (with
    the id repeated under 'slab'); `totals` holds, for each shared stat, the
    sum of its numeric values and of its deltas over all reported classes.
    """
    is_number = re.compile(r"^\d+$").search
    diffs = {}
    totals = {}
    for slabid, current in after.items():
        previous = before.get(slabid)
        if not previous or not current:
            continue

        slab = current.copy()
        for stat, raw in current.items():
            if stat not in previous:
                continue
            delta_key = stat + '_d'
            if stat not in totals:
                totals[stat] = 0
                totals[delta_key] = 0
            slab.setdefault(delta_key, 0)
            if not is_number(raw):
                continue

            value = int(raw)
            delta = value - int(previous[stat])
            slab[stat] = value
            slab[delta_key] = delta
            totals[stat] += value
            totals[delta_key] += delta

        slab['slab'] = slabid
        diffs[slabid] = slab
    return (diffs, totals)


def read_stats(s):
    """Fetch `stats items` and `stats slabs` and merge them per slab class.

    `s` is a text stream connected to the server. For each of the two
    commands, reply lines are read until one starting with "END". Lines of
    the form `STAT [items:]<id>:<name> <value>` are stored; anything else
    is ignored.

    Returns a dict mapping slab id (a string of digits) to a dict of stat
    name -> raw string value, with both commands' stats merged together.

    Raises ConnectionError if the stream reaches end-of-file before the
    terminating "END" line. Without this check a closed connection would
    make readline() return '' forever and the loop would never exit.
    """
    stat_line = re.compile(r"^STAT (?:items:)?(\d+):(\S+) (\S+)")
    slabs = {}
    for statcmd in ['items', 'slabs']:
        # FIXME: Formatting
        s.write("stats " + statcmd + "\r\n")
        while True:
            raw = s.readline()
            if not raw:
                # readline() yields '' only at EOF; a blank line would
                # still carry its line terminator.
                raise ConnectionError(
                    "connection closed while reading stats " + statcmd)
            line = raw.rstrip()
            if line.startswith("END"):
                break

            m = stat_line.match(line)
            if m:
                slab, var, val = m.groups()
                slabs.setdefault(slab, {})[var] = val
    return slabs


def pct(num, divisor):
    """Return `num` as a percentage of `divisor` (0 when `divisor` is falsy).

    The result is on a 0-100 scale: show_detail() prints it followed by a
    literal '%', so a plain fraction would display 50% as "0.50%".
    """
    if not divisor:
        return 0
    return 100 * num / divisor


def show_detail(diffs, totals):
    """Print a per-class table of evictions, items and pages.

    Each row shows the class's 'evicted_d', 'number' and 'total_pages'
    values, each followed by pct() of that value against the matching
    entry in `totals`. Classes without an 'evicted_d' entry are skipped.
    """
    header = "  {:2s}: {:8s} (pct  ) {:10s} (pct    ) {:6s} (pct)"
    row = "  {:2d}: {:8d} ({:.2f}%) {:10d} ({:.4f}%) {:6d} ({:.2f}%)"
    print(header.format('sb', 'evicted', 'items', 'pages'))

    for sid, slab in diffs.items():
        if 'evicted_d' not in slab:
            continue
        class_id = int(sid)
        evicted = slab['evicted_d']
        evicted_share = pct(evicted, totals['evicted_d'])
        items = slab['number']
        items_share = pct(items, totals['number'])
        pages = slab['total_pages']
        pages_share = pct(pages, totals['total_pages'])
        print(row.format(class_id, evicted, evicted_share,
                         items, items_share, pages, pages_share))


# Main daemon loop: connect, sample stats forever, and on any error drop all
# accumulated state, wait, and reconnect.
stats_pre = {}
history = {'w': [{}]}
while True:
    try:
        with socket.create_connection((host, port), 5) as c:
            # Close the file wrapper explicitly: the socket's descriptor is
            # only released once every makefile() object on it is closed.
            with c.makefile(mode="rw", buffering=1) as s:
                # Ask the server to set its own automover to level 0; the
                # reply line is echoed.
                s.write("slabs automove 0\r\n")
                print(s.readline().rstrip())
                while True:
                    stats_post = read_stats(s)
                    (diffs, totals) = diff_stats(stats_pre, stats_post)
                    if args.verbose:
                        show_detail(diffs, totals)
                    decision = determine_move(history, diffs, totals)
                    if decision[0] != -1 and decision[1] != -1:
                        print("moving page from, to:", decision)
                        if args.automove:
                            run_move(s, decision)

                    # Minimize sleeping if we just moved a page to global
                    # pool. Improves responsiveness during flushes/quick
                    # changes.
                    if decision[1] == 0:
                        sleep(0.05)
                    else:
                        sleep(args.sleep)
                    stats_pre = stats_post
    except Exception as err:
        # Catch Exception, not everything: KeyboardInterrupt and SystemExit
        # must propagate so the daemon can actually be stopped.
        print("disconnected:", type(err), err)
        traceback.print_exc()
        # Deltas and window history are meaningless across a reconnect.
        stats_pre = {}
        history = {'w': [{}]}
        sleep(args.sleep)

