###############################################################################
#  msca.py: Module for experiments with Multi-subset Crossproduct Algorithm
#  Copyright (C) 2009 Brno University of Technology, ANT @ FIT
#  Author(s): Martin Spinler <xspinl00@stud.fit.vutbr.cz>, Viktor Pus <ipus@fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module for experiments with Multi-subset Crossproduct Algorithm
"""

import sys
import math
from netbench.classification.rule import *
from netbench.classification.ruleset import *
from netbench.classification.prefix import *
from netbench.classification.prefixset import *
from netbench.classification.common import *

from netbench.common.packetheader import *
from struct import *
import bclassification

class BloomFilter:
    """
    Simple Bloom filter over lists of hashable values.

    The number of hash functions ("slices") is derived from the requested
    false-positive rate, and the number of bits per slice from the expected
    capacity. All hashes index into one shared bit array.
    """
    def __init__(self, capacity, error_rate = 0.0005):
        # Number of hash functions: ceil(log2(1 / error_rate)).
        self._slices = int(math.ceil(math.log(1 / error_rate, 2)))
        # Bits per slice for the requested capacity and error rate.
        self._bits = int(math.ceil((2 * capacity * abs(math.log(error_rate))) / (self._slices * (math.log(2) ** 2))))
        self._arraysize = self._slices * self._bits
        self._bitarray = [False] * self._arraysize

    def get_array_size(self):
        """Return the total number of bits in the filter."""
        return self._arraysize

    def add(self, item):
        """Insert *item* (a list of hashable values) into the filter."""
        for i in self.hashItem(item):
            self._bitarray[i] = True

    def __contains__(self, item):
        """
        Return True if *item* may have been added (false positives are
        possible, false negatives are not).
        """
        if self._arraysize == 0:
            # A zero-capacity filter holds nothing; also avoids modulo by zero.
            return False
        return all(self._bitarray[i] for i in self.hashItem(item))

    def hashItem(self, item):
        """
        Return one bit index per slice for *item*.

        Index i is the hash of item's elements extended with the strings
        "0" .. str(i), reduced modulo the array size. *item* is left
        unmodified (hashing works on a private copy), so repeated calls on
        the same item return the same indices.
        """
        key = list(item)
        hashes = []
        for i in range(self._slices):
            key.append(str(i))
            hashes.append(hash(tuple(key)) % self._arraysize)
        return hashes

def NLT_compare(a, b):
    """
    cmp-style comparison of two NLT tuples by the length of their rule lists.

    Returns 1 when *a* holds more rules than *b*, -1 when fewer, 0 when equal.
    """
    size_a = len(a[0])
    size_b = len(b[0])
    return (size_a > size_b) - (size_a < size_b)

def _build_nl_tree(fields, prefixes):
    """
    Build one Nested Level Tree per field.

    Each node is a tuple (list of children, prefix, nested level); the root
    holds Prefix(0,0,0) at level 0, and a prefix is inserted below the
    deepest existing node whose prefix covers it.
    """
    NLTree = {}
    for c in fields:
        NLTree[c] = ([], Prefix(0,0,0), 0)
        # We need prefixes from general to specific.
        # NOTE(review): this reverses in place whatever get_prefixes() returns;
        # it assumes that list is specific-to-general and is the stored list — confirm.
        prefixes[c].get_prefixes().reverse()
        for p in prefixes[c].get_prefixes():
            node = NLTree[c]
            nl = 1
            while True:
                for child in node[0]:
                    if child[1].covers(p):
                        nl += 1
                        node = child
                        break
                else:
                    # No child covers p: it becomes a new child of this node.
                    node[0].append(([], p, nl))
                    break
    return NLTree

def _find_nested_levels(fields, NLTree, rule):
    """Return {field: nested level of the deepest tree node covering the rule's prefix}."""
    nl = {}
    for c in fields:
        cond = rule.get_condition(c)[0]
        node = NLTree[c]
        while True:
            for child in node[0]:
                if child[1].covers(cond):
                    node = child
                    break
            else:
                nl[c] = node[2]
                break
    return nl

def _group_rules_by_nlt(ruleset, fields, NLTree):
    """
    Group rules sharing the same nested-level vector into NLTs.

    Returns a list of (rules list, {field: nested level}) in first-seen order.
    """
    NLTs = []
    position = {}
    for r in ruleset.get_rules():
        nl = _find_nested_levels(fields, NLTree, r)
        # Dict lookup instead of a linear scan over all NLTs found so far.
        key = tuple([nl[c] for c in fields])
        if key in position:
            NLTs[position[key]][0].append(r)
        else:
            position[key] = len(NLTs)
            NLTs.append(([r], nl))
    return NLTs

def subset_NLTSS(ruleset, subset_count):
    """
    Split ruleset into subset_count RuleSets using Nested Level Tuples.

    Rules are grouped by their per-field nested levels (NLTs). The largest
    NLTs seed the subsets; every remaining NLT joins the subset whose seed is
    closest (sum of absolute nested-level differences), preferring the
    subset currently holding the fewest rules on ties.

    Returns a list of exactly subset_count RuleSets (some may be empty).
    """
    fields = ruleset.get_field_names()
    NLTree = _build_nl_tree(fields, ruleset.get_prefixes())
    NLTs = _group_rules_by_nlt(ruleset, fields, NLTree)

    # Largest NLTs first; the sort is stable, so equal-sized NLTs keep
    # their first-seen order.
    NLTs.sort(key=lambda n: len(n[0]), reverse=True)

    subset = [RuleSet() for i in range(subset_count)]
    seeds = min(subset_count, len(NLTs))
    # First assign the (seeds) largest NLTs, one per subset.
    for i in range(seeds):
        for r in NLTs[i][0]:
            subset[i].add_rule(r)

    # Now assign the rest of the NLTs.
    for i in range(seeds, len(NLTs)):
        levels = NLTs[i][1]
        distance = [sum([abs(NLTs[s][1][c] - levels[c]) for c in fields])
                    for s in range(seeds)]
        best = min(distance)
        # Among the closest subsets, pick the one currently holding the
        # fewest rules (the first such subset wins exact ties).
        target = None
        for s in range(seeds):
            if distance[s] == best and (target is None or
                    subset[s].count_rules() < subset[target].count_rules()):
                target = s
        for r in NLTs[i][0]:
            subset[target].add_rule(r)

    return subset
    
def subset_simple(ruleset, subset_count):
    """
    Split ruleset into subset_count consecutive chunks of nearly equal size.

    Each chunk holds ceil(rule count / subset_count) rules; trailing subsets
    may be short or empty. Returns a list of subset_count RuleSets.
    """
    rules = ruleset.get_rules()
    # Integer ceiling division; '//' keeps the slice bounds integers even
    # under true division.
    chunk = (ruleset.count_rules() + subset_count - 1) // subset_count
    subset = []
    for i in range(subset_count):
        part = RuleSet()
        for rule in rules[i * chunk:(i + 1) * chunk]:
            part.add_rule(rule)
        subset.append(part)
    return subset

class MSCA(bclassification.BClassification):
    """
    Multi-subset Crossproduct Algorithm.
    """
    def __init__(self, subset_count = 3):
        self._subset_count = subset_count
        self._cond_names = []
        self._prefixes = []
        self._subset = []

    def load_ruleset(self, ruleset):
        self._filters = []
        self._table = []
        self._table_size = []

        # Divide one ruleset to n independent subsets
#        self._subset = subset_simple(ruleset, self._subset_count)

        ruleset.add_universal_prefixes()
        print "Number of rules:",ruleset.count_rules()
        ruleset.expand_prefixsets()
        print "After expanding prefixsets:",ruleset.count_rules()
        if ruleset.count_rules() > 8:
            print "Removing 8 rules, that generates most pseudorules..."
            s = ruleset.remove_spoilers(8)
            print "Removed rules priorities:", [r.get_priority() for r in s]
            ruleset.remove_covered()


        self._subset = subset_NLTSS(ruleset, self._subset_count)
#        self._subset = subset_simple(ruleset, self._subset_count)
        self._cond_names = ruleset.get_field_names()

        self._headerSize = {}
        for c in self._cond_names:
            fp = cond2field(c)
            if (fp == None):
                raise NotImplemented("Condition \"" + c + "\" is not specified in cond2field function.")
            else:
                fp = fp[0]
                if (fp[2] == "ipv4str" or fp[2] == "int"):
                    self._headerSize[c] = 32
                elif (fp[2] == "int16"):
                    self._headerSize[c] = 16
                elif (fp[2] == "int8"):
                    self._headerSize[c] = 8
                elif (fp[2] == "macstr"):
                    self._headerSize[c] = 48
                else:
                    raise NotImplemented("Size of header type \"" + c + "\" is not specified.")

        for s in range(self._subset_count):
            # Expand pseudorules
            print "Before: ", self._subset[s].count_rules()
            self._subset[s].expand_pseudorules()
            print "After: " , self._subset[s].count_rules()

            # Create hash table with rules
            # FIXME: Change table size and Bloom filter size
            self._table_size.append(self._subset[s].count_rules())
            self._table.append([ [] for i in range(self._table_size[s]) ])

            # Create Bloom filter
            self._filters.append(BloomFilter(self._table_size[s]))
            self._prefixes.append(self._subset[s].get_prefixes())
            
            # Store each rule into Bloom filter
            for r in self._subset[s].get_rules():
                # Create tuple from all used fields in rule
                t = []
                for k in self._cond_names:
                    if k == 'protocol':
                        rr = ( r.get_condition(k)[0].get_value(), r.get_condition(k)[0].get_value())
                    else:
                        rr = r.get_condition(k)[0].get_range()
                    t.append(rr[0])
                    t.append(rr[1])
                # Store tuple in Bloom filter
                self._filters[s].add(t)
                # Compute hash for tuple
                hashes = self._filters[s].hashItem(t)
                hs = sum(hashes) % self._table_size[s]
                # Add rule into hash table
                self._table[s][hs].append(r)

    def classify(self, packetheader):
        rule = None
        prefix = [ {} for i in range(self._subset_count)]
        for c in self._cond_names:
            try:
                # Get value from packet and convert it to integer
                field = extract_field(packetheader, c)
            except FieldNotFoundError:
                raise BadPacketError("Field '%s' not found."%c)
            
            # Do a LPM for each field
            for s in range(self._subset_count):
                if(self._prefixes[s].has_key(c)):
                    prefix[s][c] = self._prefixes[s][c].match_longest(field)
                else:
                    prefix[s][c] = None

        priority = None
        for s in range(self._subset_count):
            # Empty tuple
            t = []
            skip = False
            for c in self._cond_names:
                p = prefix[s][c]
                if(p is None):
                    skip = True
                    break
                # Add LPM value into tuple
                if c == 'protocol':
                    t.append(p.get_value())
                    t.append(p.get_value())
                else:
                    t.append(p.get_range()[0])
                    t.append(p.get_range()[1])
            # If an field doesn't have a LPM, skip whole subset
            if(skip == True):
                continue
            # Is item in Bloom filter?
            if t in self._filters[s]:
                # Compute hash = index to hash table
                hashes = self._filters[s].hashItem(t)
                hs = sum(hashes) % self._table_size[s]
                for r in self._table[s][hs]:
                    # Which rule matches the packet header?
                    if(r.match(packetheader) and (priority is None or priority > r.get_priority())):
                        priority = r.get_priority()
                        rule = r
        if rule:
            return [rule]
        else:
            return []

    def report_memory(self):
        """
        Print detailed info about algorithm memory requirements.
        """
        print "========== MSCA memory report =========="
        rulesCount = sum([s.count_rules() for s in self._subset])
        bloomFilter = sum([f.get_array_size() for f in self._filters])
        headerSize = 127 #sum(self._headerSize[c] for c in self._cond_names) +...
        rulesInMemory = sum([p.count_rules() for p in self._subset]) * headerSize

        prefixSize = 0
        for c in self._cond_names:
            ps = 0
            for s in range(self._subset_count):
                if(self._prefixes[s].has_key(c)):
                    ps = ps + len(self._prefixes[s][c]) * math.ceil(self._headerSize[c]+1)
            prefixSize = prefixSize + ps

        print "Number of all rules:                ", rulesCount
        print "Header size (bits):                 ", headerSize
        print "Size of Bloom filter table (bits):  ", bloomFilter
        print "Size in memory for all rules (bits):", rulesInMemory
        print "------------------------------------"
        print "Size in memory (bloom+rules):       ", rulesInMemory + bloomFilter
        print "Added bits to LPM:                  ", prefixSize


