###############################################################################
#  multimatch.py: Module for multi-match LPM algorithm
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Martin Skacan <xskaca00@stud.fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module for multi-match LPM algorithm
"""

import sys
from math import ceil
import itertools
import blpm
from netbench.classification import prefixset
from netbench.classification import maskedint

class MultiMatch(blpm.BLPM):
    """
    Multi-match LPM algorithm for lookup prefixes.

    Prefixes from a prefixset are sorted into classes according to their
    length. Every prefix is expanded to the length of its class and stored
    in a hash table keyed by the expanded bit string. The count of classes
    and the way the class lengths are chosen can be set as parameters.
    """

    def __init__(self, class_count = 4, variant = 1, mask = None):
        """
        Constructor

        class_count: number of classes the prefixes are sorted into
        variant: method used to set up the class lengths (expand mask)
            variant 1 - each class has fixed size (domain_size/class_count)
            variant 2 - the best distribution of prefixes into classes
                        will be computed
            variant 3 - classes are set up explicitly by parameter 'mask'
        mask: class lengths in bits, one entry per class; used only by
              variant 3
        """
        blpm.BLPM.__init__(self)
        self.variant = variant
        self.class_count = class_count
        # expanded prefix (bit string) -> list of prefixes stored under it
        self.hashtable = {}
        self.max_domain = 0
        if variant == 3:
            # copy the mask, so that the caller's object is never shared
            # between instances (a mutable default argument would be)
            self.expand_mask = [] if mask is None else list(mask)
        else:
            self.expand_mask = [0] * self.class_count
        # per class: count of distinct keys stored in the hash table
        self.prefixes = [0] * self.class_count
        # per class: table size computed from the count of keys
        self.tables = [0] * self.class_count

    def report_memory(self):
        """
        Return detailed info about algorithm memory requirements.

        The result is a dictionary with two lists indexed by class number:
        'prefixes' - count of distinct hash-table keys of the class
        'tables'   - table size computed for the class in load_prefixset()
        """

        report = {'prefixes': self.prefixes, 'tables': self.tables}
        return report

    def load_prefixset(self, prefixset):
        """
        Load prefixes and generate all necessary data structures.

        Sets up the expand mask (variants 1 and 2), expands every prefix
        to the length of its class and stores it in the hash table.

        Raises ValueError if a prefix is longer than every class length
        of the expand mask (possible with an explicit variant 3 mask).
        """

        self.prefixset = prefixset
        # find max domain_size of prefixset
        for prefix in self.prefixset.get_prefixes():
            if prefix.get_domain_size() > self.max_domain:
                self.max_domain = prefix.get_domain_size()

        if self.variant == 1:
            # variant 1 - uniform distribution into classes
            self.expand_mask = self._uniform_mask()
        elif self.variant == 2:
            # variant 2 - the best possible distribution into classes
            self.expand_mask = self._optimal_mask()

        # the expand mask is set up, now.
        # go through the whole prefixset
        for prefix in self.prefixset.get_prefixes():
            length = prefix.get_length()
            # find the number of the first class, that fits
            for i in range(0, self.class_count):
                if length <= self.expand_mask[i]:
                    position = i
                    break
            else:
                # without this check the class of the previous prefix would
                # be silently reused and the prefix could never be found
                raise ValueError(
                    "prefix of length %d does not fit into any class of "
                    "the expand mask %r" % (length, self.expand_mask))
            # binary notation of the prefix, padded to its full length
            key = ""
            if length:
                shift = prefix.get_domain_size() - length
                key = bin(prefix.get_value() >> shift)[2:].zfill(length)
            # expand the prefix, if needed, and add the expanded prefixes
            # into hash-table
            for expanded in expand(key, self.expand_mask[position]):
                if expanded in self.hashtable:
                    self.hashtable[expanded].append(prefix)
                else:
                    self.hashtable[expanded] = [prefix]
                    # only a new key occupies a new entry of the table
                    self.prefixes[position] += 1

        # count the required memory of every table, for statistics
        for i in range(0, self.class_count):
            self.tables[i] = 2**(maskedint.log2(self.prefixes[i]))

    def _uniform_mask(self):
        """
        Return the expand mask of variant 1: class lengths are multiples
        of ceil(max_domain/class_count), limited by max_domain.
        """

        class_length = ceil(float(self.max_domain) / self.class_count)
        return [min(int((i + 1) * class_length), self.max_domain)
                for i in range(0, self.class_count)]

    def _optimal_mask(self):
        """
        Return the expand mask of variant 2: the combination of class
        lengths that requires the minimal count of expansions.
        """

        # get the histogram of the prefixset
        histogram = self.prefixset.get_histogram()
        last = self.class_count - 1
        # find the size of the longest prefix
        # NOTE(review): assumes the histogram has max_domain + 1 entries,
        # indexed by prefix length - confirm against prefixset
        for i in range(1, self.max_domain - 1):
            if histogram[-i] != 0:
                last = self.max_domain - i + 1
                break

        # the first combination; the last value in the expand mask is
        # the length of the longest prefix
        best_mask = list(range(self.class_count - 1))
        best_mask.append(last)
        minimum = self._expansion_count(histogram, best_mask, last)

        # try all combinations and find the minimum of required expansions
        for lengths in itertools.combinations(range(last),
                                              self.class_count - 1):
            candidate = list(lengths)
            candidate.append(last)
            count = self._expansion_count(histogram, candidate, last)
            if count < minimum:
                minimum = count
                best_mask = candidate
        return best_mask

    @staticmethod
    def _expansion_count(histogram, mask, last):
        """
        Return the count of expansions required by the expand mask 'mask'
        for prefix lengths 0..last, i.e. the sum of
        histogram[length] * 2**(class_length - length).
        Return 0 if the histogram has no entry for the length 'last'.
        """

        total = 0
        if len(histogram) > last:
            current = 0
            for length in range(last + 1):
                total += histogram[length] * (2**(mask[current] - length))
                # this class is full, longer prefixes go to the next one
                if length >= mask[current]:
                    current += 1
        return total

    def lookup(self, ip):
        """
        Lookup prefixes that match ip.
        Return the list of matched prefixes, the longest prefix first.
        If ip matches no prefix, the list will be empty.

        ip: value of the ip address according to maskedint.py
        """

        # list of prefixes that match ip
        matched = []
        # ip in binary notation, padded with leading zeros to the width of
        # the domain; padding according to the bit length of the value
        # itself would misalign addresses with many leading zero bits
        key = bin(ip)[2:].zfill(self.max_domain)

        # try to find prefixes in every class
        for i in range(0, self.class_count):
            # cut the first n bits from the IP and use it as the key into
            # hash-table
            hashkey = key[0:self.expand_mask[i]]
            # if there are prefixes, add them into list of matched prefixes
            if hashkey in self.hashtable:
                matched.extend(self.hashtable[hashkey])

        # sort the list of prefixes according to the length
        matched.sort(key = lambda x: x.get_length(), reverse=True)
        return matched

    def display(self):
        """Display the structure of the tree (not implemented, does nothing)"""

        pass
   
    
########################
# module functions
########################
    
def expand(prefix = "", n = 8):
    """
    Expand the prefix to n bits.

    prefix: string of '0' and '1' characters
    n: required length in bits (converted by int())

    Return the list of all bit strings of length n that start with prefix,
    in ascending order of the appended suffix. If the prefix already has
    n or more bits, the list contains only the unchanged prefix.
    """

    # count of bits we need to append
    missing = int(n) - len(prefix)
    if missing <= 0:
        return [prefix]
    # every suffix value in binary notation, padded to the missing bits
    return [prefix + bin(tail)[2:].zfill(missing)
            for tail in range(2**missing)]
        