###############################################################################
#  dcfl.py: Module for experiments with DCFL algorithm
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Vaclav Bartos <xbarto11@stud.fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id: dcfl.py 485 2010-09-21 13:54:44Z xbarto11 $

"""
Module for experiments with Distributed Crossproducting of Field Labels
algorithm.
"""

import sys
import copy
import math
import bclassification
from netbench.classification.rule import *
from netbench.classification.ruleset import *
from netbench.classification.prefix import *
from netbench.classification.prefixset import *
from netbench.classification.common import *
from netbench.common.packetheader import *

class DCFL(bclassification.BClassification):
    """
    Distributed Crossproducting of Field Labels algorithm.

    load_ruleset() tries orderings of the dimensions and, for each of them,
    picks a type for every aggregation node so that the number of sequential
    memory accesses (SMA) is minimal; memory size breaks ties.  classify()
    then pushes a packet header through the best network found.
    """

    def __init__(self, mem_width=72, verbose=True):
        """
        Constructor

        mem_width - Width of one memory word in bits, handed over to every
                    aggregation node.
        verbose   - If True, load_ruleset() prints every configuration tried.
        """
        
        # Given ruleset
        self._ruleset = None
        
        # List of all condition names (dimensions) used in ruleset
        self._cond_names = []
        
        # List of all prefixes (one PrefixSet for each dimension)
        self._prefixes = {}
        
        # Set memory width
        self._mem_width = mem_width
        
        # Set verbosity
        self._verbose = verbose
        
        # Results of load_ruleset(). They are created here so that the
        # attributes exist even before any ruleset has been loaded.
        self._perm = None        # Best order of dimensions
        self._sma = None         # SMA of the best configuration
        self._mem = None         # Memory (bits) of the best configuration
        self._nodes = []         # Aggregation nodes of the best configuration
        self._nodes_print = ""   # Node types of the best configuration (",M,F,...")
        self._lpm_sets = []      # Prefixes searched in each dimension
    

    def load_ruleset(self, ruleset, best_seq=None):
        """
        Load rules and create optimal aggregation network.
        
        ruleset  - Instance of Ruleset class
        best_seq - Optional sequence of dimension indexes. If given, only
                   this order of dimensions is evaluated instead of all
                   permutations.
        
        If the ruleset is empty, an error is printed and nothing is loaded.
        """
        
        # Function-scope import: the module header does not import itertools
        import itertools
        
        vrbs = self._verbose
        
        # -v---- Common preprocessing ----v-
        orig_count = ruleset.count_rules()
        
        if (orig_count == 0):
            print "ERROR: Ruleset is empty"
            return
        
        # Copy ruleset
        rs = RuleSet()
        for rule in ruleset.get_rules():
            rs.add_rule(rule)
        
        # Filter out duplicates and rules covered by other rules
        rs.remove_covered()
        
        print "Number of rules:", rs.count_rules(),
        print "(%d loaded, %d filtered out)"%(orig_count,orig_count-rs.count_rules())
        
        
        # Names of the dimensions. They are hard-coded here;
        # rs.get_field_names() is not used.
        cond_names = ['protocol','src_port','dst_port','srcipv4','dstipv4']
        
        # Add zero-length prefixes to non-specified fields and expand prefixsets 
        rs.add_universal_prefixes(cond_names)
        rs.expand_prefixsets()
        print "After expanding prefixsets:",rs.count_rules()
        print
        
        # -^------------------------------^-
        
        if len(cond_names) > 7:
            print "WARNING: There is more than 7 dimensions, finding the best configuration may take a very long time and huge amount of memory."
        elif len(cond_names) > 5:
            print "WARNING: There is more than 5 dimensions, finding the best configuration may take a long time."
        
        # Create lists of unique prefixes (dict with one PrefixSet for every dimension)
        prefixes = rs.get_prefixes(cond_names)
        
        dims = len(cond_names)  # Number of dimensions
        node_count = dims-1     # Number of aggregation nodes
        
        if best_seq:
            # If best sequence was given, use it
            permutations = [best_seq]
        else:
            # All orders of dimensions, in lexicographic order
            permutations = list(itertools.permutations(range(dims)))
        
        # Get sets of possible LPM outputs in each dimension
        # (every prefix is wrapped in a 1-tuple, the format covers() expects)
        all_lpm_sets = [ [(x,) for x in prefixes[cn]] for cn in cond_names ]
        
        # Get maximum size of LPM outputs
        lpm_max_output_size = [max_num_of_matching_fields(s) for s in all_lpm_sets]
        
        mem_width = self._mem_width
        
        # Initialize best SMA and MEM to very high values
        self._sma = 100000000
        self._mem = 1000000000000
        
        print "Trying all configurations of aggregation network ..."
        if vrbs:
            print "Dimensions:"
            for i,cn in enumerate(cond_names):
                print "   %i:"%i,cn
            print "(Sequence of dims - node types - sequential memory accesses - bits of memory)"
        
        for perm in permutations:
            
            if vrbs: print perm,
            
            meta_sets    = [ [] for i in range(dims)]
            lpm_sets     = [ [] for i in range(dims)]
            meta_labels  = [ [] for i in range(dims)]
            lpm_labels   = [ [] for i in range(dims)]
            valid_combs  = [ [] for i in range(dims)]
            
            # The last level holds one tuple of prefixes per rule, with the
            # dimensions ordered by the current permutation
            for r in rs.get_rules():
                t = tuple([r.get_condition(cond_names[i])[0] for i in perm])
                meta_sets[dims-1].append(t)
            
            # Walk from the last level down to the first one. Every tuple is
            # split into its head (input meta-label of the node) and its last
            # item (input field label of the node); both are stored once and
            # the pair of their positions is a valid combination of the node.
            for dim in range(dims-2,-1,-1):
                for combination in meta_sets[dim+1]:
                    left_slice  = combination[0:dim+1]
                    right_slice = combination[dim+1]
                    try:
                        vc1 = meta_sets[dim].index(left_slice)
                    except ValueError:   # Not seen yet
                        vc1 = len(meta_sets[dim])
                        meta_sets[dim].append(left_slice)
                    try:
                        vc2 = lpm_sets[dim].index(right_slice)
                    except ValueError:   # Not seen yet
                        vc2 = len(lpm_sets[dim])
                        lpm_sets[dim].append(right_slice)
                    valid_combs[dim].append( (vc1,vc2) )
            
            for dim in range(dims):
                meta_labels[dim] = range(len(meta_sets[dim]))
                lpm_labels[dim]  = range(len(lpm_sets[dim]))
            
            # Get maximum size of outputs of aggregation nodes and maximum size 
            # of queries in aggregation nodes
            node_max_query_size = []
            node_max_output_size = []
            node_max_query_size.append(lpm_max_output_size[perm[0]] * lpm_max_output_size[perm[1]])
            node_max_output_size.append(max_num_of_matching_fields(meta_sets[0]))
            for i in range(1,node_count):
                node_max_query_size.append(node_max_output_size[i-1] * lpm_max_output_size[perm[i+1]])
                node_max_output_size.append(min(max_num_of_matching_fields(meta_sets[i]),node_max_query_size[i]))
                
            # Generate aggregation nodes
            nodes = [ [None,None,None] for i in range(node_count)]
            nodes_sma = []
            nodes_mem = []
            for i in range(0,node_count):
                # Generate all types of nodes and get their memory requirements
                nodes[i][0] = MetaLabelIndexingAggNode(len(meta_labels[i]),len(lpm_labels[i]),valid_combs[i],mem_width)
                nodes[i][1] = FieldLabelIndexingAggNode(len(meta_labels[i]),len(lpm_labels[i]),valid_combs[i],mem_width)
                nodes[i][2] = BloomFilterArrayAggNode(len(meta_labels[i]),len(lpm_labels[i]),valid_combs[i],mem_width)
                
                # The first node is fed by two LPM engines, every other node
                # by the previous node and one LPM engine
                if (i == 0):
                    max_meta_in = lpm_max_output_size[perm[0]]
                else:
                    max_meta_in = node_max_output_size[i-1]
                max_field_in = lpm_max_output_size[perm[i+1]]
                
                nodes_sma.append(tuple([n.get_sma(max_meta_in,max_field_in) for n in nodes[i]]))
                nodes_mem.append(tuple([n.get_memory_req() for n in nodes[i]]))
            
            # Last node can't be Bloom Filter Array, so set its SMA to very high value
            nodes_sma[-1] = (nodes_sma[-1][0],nodes_sma[-1][1],100000000)
            
            sma_limit = max(map(min,nodes_sma))  # Best achievable SMA
            
            total_mem = 0
            nodes_print = ""
            
            for i in range(0,node_count):
                # Types of nodes, that fit into SMA limit
                possible_types = [t for t in range(3) if nodes_sma[i][t] <= sma_limit]
                if not possible_types:
                    raise RuntimeError("Error (%s)"%possible_types.__repr__()) #this shouldn't happen
                
                # Choose the node with minimal memory requirements among the
                # types that fit into the limit only. Ties go to the lowest
                # type (M, then F, then B). Comparing the minimum with the
                # memory of each type in turn could pick a type that does not
                # fit into the limit but needs the same amount of memory.
                node_mem = nodes_mem[i]
                best_type = min(possible_types, key=lambda t: node_mem[t])
                nodes[i] = nodes[i][best_type]
                total_mem += node_mem[best_type]
                nodes_print += "," + "MFB"[best_type]
            
            if vrbs: print "", nodes_print[1:], " SMA:", sma_limit, " MEM:", total_mem
            
            # If this configuration is better than the best found, update the best
            if (sma_limit < self._sma or (sma_limit == self._sma and total_mem < self._mem)):
                self._perm = perm
                self._sma = sma_limit
                self._mem = total_mem
                self._nodes = nodes
                self._nodes_print = nodes_print
                # Prefixes of the first dimension come from the first level of
                # meta-labels, the rest from the field labels of the nodes
                self._lpm_sets = [ [x[0] for x in meta_sets[0]] ]
                self._lpm_sets.extend(lpm_sets)
        
        print
        print "Best configuration:"
        print self._perm, "", self._nodes_print[1:], " SMA:", self._sma, " MEM:", self._mem 
        
        self._ruleset = rs
        self._cond_names = [cond_names[i] for i in self._perm]
        

    def classify(self, packetheader):
        """
        Classify packet by the algorithm.
        Return the list of all rules that matches given packet, sorted by
        Rule.get_priority.
        
        packetheader: Instance of the PacketHeader class.
        
        Raises BadPacketError if some field can't be extracted from the
        packet header.
        """
        
        lpm_output = [] # Labels of matched prefixes (a list for each dimension)
        
        # Do LPM on every header field
        for i,c in enumerate(self._cond_names): # For every dimension (condition)
            
            try:
                # Get value from packet and convert it to integer
                field = extract_field(packetheader, c)
            except FieldNotFoundError:
                raise BadPacketError("Field '%s' not found."%c)
            
            # Find all matching prefixes and save their labels
            # (a label is the position of the prefix in the set)
            matching = []
            for j,prefix in enumerate(self._lpm_sets[i]):
                if prefix.match(field):
                    matching.append(j)
            lpm_output.append(matching)
        
        
        # Now we have set of labels of matching prefixes for each dimension        
        
        # Put labels into the aggregation network to get matching rules
        last_meta = lpm_output[0]
        for i,node in enumerate(self._nodes):
            last_meta = node.process(last_meta,lpm_output[i+1])
         
        # Outputs of the last node are indexes of rules in ruleset
        output = [self._ruleset.get_rules()[i] for i in last_meta]
        
        # Sort output rules by their priority
        return sorted(output, key=Rule.get_priority)
        

    def report_memory(self):
        """
        Print detailed info about algorithm memory requirements.
        """
        
        print "========== DCFL memory report =========="
        print "Memory width was set to:", self._mem_width
        print "Maximal number of memory accesses per lookup:", self._sma
        print "Total memory requirement:", self._mem, "bits"


def max_num_of_matching_fields(set):
    """
    Return the highest number of items of the given collection that cover
    one and the same item of it (see covers()). An empty collection gives 0.
    
    set - Collection of sequences accepted by covers(); it is iterated
          once for every item it contains.
    """
    
    # For every item, count the items that cover it
    hits_per_item = [sum(1 for candidate in set if covers(candidate, target))
                     for target in set]
    
    return max(hits_per_item) if hits_per_item else 0

def covers(covering, covered):
    """
    Return True if every item of 'covering' covers the item of 'covered'
    at the same position, False otherwise.
    
    Only the positions present in 'covering' are checked; items are compared
    by calling the covers() method of the item from 'covering'.
    """
    
    return all(covering[pos].covers(covered[pos])
               for pos in range(len(covering)))
    

class MetaLabelIndexingAggNode:
    """
    Aggregation node indexed by input meta-labels.
    
    For every input meta-label the node keeps a list of pairs
    (field label, output meta-label). The output meta-label of a valid
    combination is its position in valid_combinations.
    """
   
    def __init__(self, meta_labels_count, labels_count, valid_combinations, mem_width):
        """
        valid combinations format: [(meta_label, label),...]
        
        meta_labels_count - Number of input meta-labels
        labels_count      - Number of input field labels
        mem_width         - Width of one memory word in bits
        
        Raises ValueError if some index in valid_combinations is out of
        range.
        """
        
        self.meta_labels_count = meta_labels_count
        self.labels_count = labels_count
        self.out_labels_count = len(valid_combinations)
        self.mem_width = mem_width
        
        # Check indexes in valid_combinations. Negative indexes are refused
        # as well, they would silently address the array from its end.
        for ml, l in valid_combinations:
            if not (0 <= ml < meta_labels_count and 0 <= l < labels_count):
                raise ValueError("MetaLabelIndexingAggNode: Some index in valid_combinations is out of range.") 
        
        # Create an array of lists
        self._array = [ [] for i in range(meta_labels_count)]
        
        for i, (ml, l) in enumerate(valid_combinations):
            self._array[ml].append((l,i))
        
    def display(self):
        "Print the list stored for each meta-label, one per line"
        for entries in self._array:
            # One parenthesised argument prints the same with the print
            # statement and with the print function
            print(entries)
   
    def process(self, meta_labels, labels):
        """
        metalabels and labels are arrays of numbers
        
        Returns the list of output meta-labels of all valid combinations
        of the given meta-labels and labels.
        """
        
        wanted = set(labels)   # Constant-time membership tests
        
        output = []
        for ml in meta_labels:
            for l, out in self._array[ml]:
                if l in wanted:
                    output.append(out)
        
        return output
   
    def _entry_size(self):
        "Returns size of one list entry (field label + output meta-label) in bits"
        return max(1,log2(self.labels_count)) + max(1,log2(self.out_labels_count))
   
    def _items_per_word(self):
        """
        Returns number of list entries that fit in one memory word.
        Raises ValueError if not even one entry fits.
        """
        # Only whole entries can be stored in a word, so round down
        n = int(self.mem_width // self._entry_size())
        if (n == 0):
            raise ValueError("Memory width (%i) too low, not even one list item can fit in it."%self.mem_width)
        return n
   
    def get_memory_req(self, used_bits_only=False):
        """
        Returns number of bits of memory needed by the node. Some bits may
        be unused because of alignment to memory width. If you don't want to
        count these bits, set used_bits_only to True.
        
        Raises ValueError if the alignment is counted and not even one list
        entry fits in a memory word.
        """
        
        if used_bits_only:
            num_of_items = sum(map(len,self._array))
            return self._entry_size() * num_of_items
        
        n = self._items_per_word()
        # Every list starts at a new memory word
        mem_words = sum([(len(x)+n-1)//n for x in self._array])
        return self.mem_width * mem_words

    def get_sma(self, max_meta_labels, max_labels):
        """
        Returns maximal number of sequential memory accesses per item
        
        max_meta_labels - Maximal number of meta-labels on input
        max_labels      - Maximal number of field labels on input (not used
                          by this type of node)
        
        Raises ValueError if not even one list entry fits in a memory word.
        """
        
        # To get accurate value, we need to know all possible input sets 
        # of meta_labels, compute SMA for each one and get maximum.
        # This would be very difficult and slow to compute, so we use maximum 
        # number of memory words per an array item.
        
        if not self._array:
            return 0   # No meta-labels, nothing to read
        
        n = self._items_per_word()
        max_words_per_meta_label = max([(len(x)+n-1)//n for x in self._array])
        
        return max_meta_labels * max_words_per_meta_label

class FieldLabelIndexingAggNode:
    """
    Aggregation node indexed by input field labels.
    
    For every input field label the node keeps a list of pairs
    (meta-label, output meta-label). The output meta-label of a valid
    combination is its position in valid_combinations.
    """
   
    def __init__(self, meta_labels_count, labels_count, valid_combinations, mem_width):
        """
        valid combinations format: [(meta_label, label),...]
        
        meta_labels_count - Number of input meta-labels
        labels_count      - Number of input field labels
        mem_width         - Width of one memory word in bits
        
        Raises ValueError if some index in valid_combinations is out of
        range.
        """
        
        self.meta_labels_count = meta_labels_count
        self.labels_count = labels_count
        self.out_labels_count = len(valid_combinations)
        self.mem_width = mem_width
        
        # Check indexes in valid_combinations. Negative indexes are refused
        # as well, they would silently address the array from its end.
        for ml, l in valid_combinations:
            if not (0 <= ml < meta_labels_count and 0 <= l < labels_count):
                raise ValueError("FieldLabelIndexingAggNode: Some index in valid_combinations is out of range.") 
        
        # Create an array of lists
        self._array = [ [] for i in range(labels_count)]
        
        for i, (ml, l) in enumerate(valid_combinations):
            self._array[l].append((ml,i))

    def display(self):
        "Print the list stored for each field label, one per line"
        for entries in self._array:
            # One parenthesised argument prints the same with the print
            # statement and with the print function
            print(entries)
    
    def process(self, meta_labels, labels):
        """
        metalabels and labels are arrays of numbers
        
        Returns the list of output meta-labels of all valid combinations
        of the given meta-labels and labels.
        """
        
        wanted = set(meta_labels)   # Constant-time membership tests
        
        output = []
        for l in labels:
            for ml, out in self._array[l]:
                if ml in wanted:
                    output.append(out)
        
        return output

    def _entry_size(self):
        "Returns size of one list entry (meta-label + output meta-label) in bits"
        return max(1,log2(self.meta_labels_count)) + max(1,log2(self.out_labels_count))

    def _items_per_word(self):
        """
        Returns number of list entries that fit in one memory word.
        Raises ValueError if not even one entry fits.
        """
        # Only whole entries can be stored in a word, so round down
        n = int(self.mem_width // self._entry_size())
        if (n == 0):
            raise ValueError("Memory width too low, not even one list item can fit in it.")
        return n

    def get_memory_req(self, used_bits_only=False):
        """
        Returns number of bits of memory needed by the node. Some bits may
        be unused because of alignment to memory width. If you don't want to
        count these bits, set used_bits_only to True.
        
        Raises ValueError if the alignment is counted and not even one list
        entry fits in a memory word.
        """
        
        if used_bits_only:
            num_of_items = sum(map(len,self._array))
            return self._entry_size() * num_of_items
        
        n = self._items_per_word()
        # Every list starts at a new memory word
        mem_words = sum([(len(x)+n-1)//n for x in self._array])
        return self.mem_width * mem_words
    
    def get_sma(self, max_meta_labels, max_labels):
        """
        Returns maximal number of sequential memory accesses per item
        
        max_meta_labels - Maximal number of meta-labels on input (not used
                          by this type of node)
        max_labels      - Maximal number of field labels on input
        
        Raises ValueError if not even one list entry fits in a memory word.
        """
        
        # To get accurate value, we need to know all possible input sets 
        # of labels, compute SMA for each one and get maximum.
        # This would be very difficult and slow to compute, so we use maximum 
        # number of memory words per an array item.
        
        if not self._array:
            return 0   # No field labels, nothing to read
        
        # The entry size is the one used by get_memory_req(): the lists of
        # this node store meta-labels, not field labels.
        n = self._items_per_word()
        max_words_per_label = max([(len(x)+n-1)//n for x in self._array])
        
        # process() reads one list for every input field label
        return max_labels * max_words_per_label

class BloomFilterArrayAggNode:
    """
    Aggregation node that stores valid combinations in an array of Bloom
    filters. Each filter is one memory word wide.
    
    The output meta-label of a valid combination is its position in
    valid_combinations.
    """
   
    def __init__(self, meta_labels_count, labels_count, valid_combinations, mem_width):
        """
        valid combinations format: [(meta_label, label),...]
        
        meta_labels_count - Number of input meta-labels
        labels_count      - Number of input field labels
        mem_width         - Width of one memory word (one Bloom filter) in bits
        
        Raises ValueError if some index in valid_combinations is out of
        range.
        """
        
        # Check indexes in valid_combinations (negative ones are refused too)
        for ml, l in valid_combinations:
            if not (0 <= ml < meta_labels_count and 0 <= l < labels_count):
                raise ValueError("BloomFilterArrayAggNode: Some index in valid_combinations is out of range.") 
        
        # Number of hash function used to index bits in bloom filter.
        # Determines false positive probability which is (1/2)^k (for k=4 it's 0.06)
        k = 4
        
        self.valid_combs = valid_combinations
        self.num_of_hashes = k
        
        # Position of each valid combination, so that process() doesn't have
        # to search the list. The first position wins for duplicates, which
        # is what list.index() returns.
        self._comb_index = {}
        for i, comb in enumerate(valid_combinations):
            self._comb_index.setdefault(comb, i)
        
        # Compute needed size of array
        array_width = mem_width # Each bloom filter is stored in one memory word
        array_len = int(math.ceil((k * len(valid_combinations)) / (array_width * 0.69314718))) # Equation from original DCFL paper
        
        # Initialize bloom filter array (it is empty if there is no valid
        # combination)
        self.array = [ [0 for x in range(array_width)] for y in range(array_len) ]
        
        # Put all valid combinations into bloom filters (hash it and set corresponding bits to one)
        for comb in valid_combinations:
            bfilter = self.array[hash(comb) % array_len]
            for i in range(self.num_of_hashes):
                bfilter[hash((comb,i)) % array_width] = 1
        
    def display(self):
        "Print the bits of each Bloom filter, one filter per line"
        for bfilter in self.array:
            # One parenthesised argument prints the same with the print
            # statement and with the print function
            print(bfilter)
    
    def process(self, meta_labels, labels):
        """
        metalabels and labels are arrays of numbers
        
        Returns the list of output meta-labels of all valid combinations
        of the given meta-labels and labels. Combinations reported by a
        Bloom filter that are not valid (false positives) are dropped.
        """
        
        output = []
        
        if not self.array:
            return output   # No valid combination was stored
        
        array_len = len(self.array)
        array_width = len(self.array[0])
        
        # Ask bloom filter array about all possible combinations of
        # meta_labels and labels
        for l in labels:
            for ml in meta_labels:
                query = (ml,l)
                bfilter = self.array[hash(query) % array_len]
                for i in range(self.num_of_hashes):
                    if bfilter[hash((query,i)) % array_width] != 1:
                        break
                else:
                    # Convert found combination to output meta_label.
                    # A combination which is not in valid_combs is a false
                    # positive of the bloom filter and is ignored.
                    #
                    # TODO: A Bloom filter can return any combination of
                    # inputs and the output meta-labels have to reflect it.
                    # The generation of nodes will have to be rewritten so
                    # that the node following a Bloom filter node gets
                    # a different set of possible inputs. Or some other
                    # solution has to be found, but a value belonging to
                    # other prefixes can't be returned on a false positive.
                    found = self._comb_index.get(query)
                    if found is not None:
                        output.append(found)
        
        return output

    def get_memory_req(self, used_bits_only=False):
        """
        Returns number of bits of memory needed by the node.
        
        used_bits_only is accepted only to match the other types of nodes;
        the result is width * length of the array in both cases.
        """
        
        if not self.array:
            return 0
        
        # Return width * length of the array
        return len(self.array[0]) * len(self.array)

    def get_sma(self, max_meta_labels, max_labels):
        """
        Returns maximal number of sequential memory accesses per item
        
        One filter is read for every combination of an input meta-label
        and an input field label.
        """
        return max_meta_labels * max_labels

