###############################################################################
#  sst.py: Module for shape-shifting trie LPM algorithm
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Martin Skacan <xskaca00@stud.fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module for shape-shifting trie LPM algorithm
"""

import sys
import blpm
import trie
from netbench.classification import prefixset
from netbench.classification import maskedint

class SST(blpm.BLPM):
    """
    Shape-shifting trie (SST) algorithm for longest-prefix-match (LPM)
    lookup of prefixes
    """

    def __init__(self, K = 3):
        """
        Set up an empty SST instance.

        K: largest number of unibit trie nodes that load_prefixset() packs
           into a single SSNode
        """
        blpm.BLPM.__init__(self)
        # pruning threshold and the unibit trie the SSTree is derived from
        self.K = K
        self.binTrie = trie.Trie()
        # statistics returned by report_memory()
        self.prefixes = self.nodes = 0
        self.child_pointers = self.prefix_pointers = 0
        self.max_depth = 0

    def report_memory(self):
        """
        Return detailed info about algorithm memory requirements.
        """
        
        report = {'prefixes': self.prefixes, 'nodes': self.nodes,
                  'child_pointers': self.child_pointers,
                  'prefix_pointers': self.prefix_pointers,
                  'depth': self.max_depth}
        return report
        
    def load_prefixset(self, prefixset):
        """
        Load prefixes and generate all necessary data structures. 
        """
        
        self.prefixset = prefixset
        # generate the unibit Trie
        self.binTrie.load_prefixset(self.prefixset)
        self.prefixes = self.binTrie.prefixes
        
        # count the number of nodes in subtrees      
        self.count_nodes(self.binTrie.tree.root)
        # order the SSNodes in breadth-first order
        bfs_list = self.prepare_bfs(self.binTrie.tree.root)
        
        # prune the nodes of unibit Trie and create SSTree
        while bfs_list:
            for node in bfs_list:
                # prune this subtree
                if node.subtree_nodes <= self.K:
                    # create new SSNode
                    new_SSNode = self.make_SSNode(node, bfs_list)
                    self.nodes += 1
                    # bind it to parent
                    if node.parent:
                        if node.parent.lchild == node:
                            node.parent.lchild = new_SSNode
                        elif node.parent.rchild == node:
                            node.parent.rchild = new_SSNode
                    current_node = node.parent
                    # decrease the number of subtree nodes in all ancestors
                    while current_node:
                        current_node.subtree_nodes -= node.subtree_nodes
                        current_node = current_node.parent
                    break
                    
        # create new SSTree             
        self.tree = SSTree(self.K)
        # in new_SSNode is still stored the root of the pruned tree
        self.tree.root = new_SSNode
        self.tree.root.depth = 0
        # compute depth of the nodes
        queue = []
        queue.append(self.tree.root)
        while(queue):
            item = queue.pop(0)
            for i in item.children:
                i.depth = item.depth + 1
                if i.depth > self.max_depth:
                    self.max_depth = i.depth
                queue.append(i)

        
    def make_SSNode(self, root, tree_list):
        """
        Create SSNode from a (unibit trie) subtree

        root: unibit trie node at the top of the subtree to be packed
        tree_list: list of the not yet packed nodes of the whole trie;
                   every node packed here is removed from it (the list is
                   modified in place)

        The nodes of the subtree are visited in breadth-first order and
        contribute to the bitmaps of the new SSNode as follows:
          SBM - two entries per node (left, right): 1 if that child is a
                unibit trie node, otherwise 0
          IBM - one entry per node: 1 if the node carries a value; the
                value is then appended to the SSNode's prefixes
          EBM - one entry per child that is an already built SSNode (1, the
                SSNode is appended to children) or that is missing (0)

        self.prefix_pointers / self.child_pointers are increased by one if
        the new SSNode holds any prefixes / children.
        Return the new SSNode.
        """

        packed = SSNode()
        for member in self.prepare_bfs(root):
            # left child first, then right child
            for child in (member.lchild, member.rchild):
                if isinstance(child, trie.Node):
                    # the child stays inside this SSNode
                    packed.SBM.append(1)
                    continue
                packed.SBM.append(0)
                if isinstance(child, SSNode):
                    # exit to an SSNode that has been built earlier
                    packed.EBM.append(1)
                    packed.children.append(child)
                elif child is None:
                    # exit that leads nowhere
                    packed.EBM.append(0)

            if member.value:
                packed.IBM.append(1)
                packed.prefixes.append(member.value)
            else:
                packed.IBM.append(0)

            # the node is packed now, drop it from the list of the whole tree
            if member in tree_list:
                tree_list.remove(member)

        # update count of prefix_pointers
        if packed.prefixes:
            self.prefix_pointers += 1
        # update count of children_pointers
        if packed.children:
            self.child_pointers += 1
        return packed
        
        
        
    def prepare_bfs(self, node = None):
        """
        Order the nodes of the subtree in breadth-first order

        node: top of the subtree; anything that is not a trie.Node
              (including None) gives an empty result

        Return the list of trie.Node objects reachable from node through
        lchild/rchild links. Children that are not trie.Node instances
        (missing children, SSNodes) are neither listed nor followed.
        """

        bfs_list = []
        if isinstance(node, trie.Node):
            bfs_list.append(node)
        # The result list doubles as the queue: everything before the cursor
        # has been expanded, everything from the cursor on is still waiting.
        # This avoids list.pop(0), which shifts the whole list on each call.
        cursor = 0
        while cursor < len(bfs_list):
            current = bfs_list[cursor]
            cursor += 1
            for child in (current.lchild, current.rchild):
                if isinstance(child, trie.Node):
                    bfs_list.append(child)

        return bfs_list
        
    def left_most(self, node, stack, stackB):
        """
        Traverse to the left most node of current branch

        Starting at node, follow the lchild links as long as there is a
        node. Every visited node is appended to stack and, for each of
        them, True is appended to stackB. Both lists are modified in place;
        nothing happens if node is empty.
        """

        visited = node
        while visited:
            stack.append(visited)
            stackB.append(True)
            visited = visited.lchild
            
           
    def count_nodes(self, node = None):
        """
        Count the nodes in every subtree of the tree
        """
        
        stack = []
        stackB = []
        
        self.left_most(node, stack, stackB)
        # post-order traverse
        while stack:
            current_node = stack[-1]
            from_left = stackB.pop()
            if from_left:
                stackB.append(False)
                self.left_most(current_node.rchild, stack, stackB)
            else:
                stack.pop()
                tmpcount = 0
                if current_node.lchild:
                    tmpcount += current_node.lchild.subtree_nodes
                if current_node.rchild:
                    tmpcount += current_node.rchild.subtree_nodes
                current_node.subtree_nodes = 1 + tmpcount
        
        
    def lookup(self, ip):
        """
        Lookup prefixes that match ip.
        Return the list of matched prefixes, the prefix found last (deepest
        in the tree) first.
        If ip matches no prefix, the list will be empty.

        ip: value of the ip address according to maskedint.py

        The address is turned into a string of binary digits, left-padded
        with zeros to 2**maskedint.log2(number of digits) characters.
        NOTE(review): the key width is derived from the digits of ip itself,
        not from a fixed address width - confirm that maskedint.log2 gives
        the intended width for addresses with many leading zero bits.
        """

        def count_set(bitmap, start, stop):
            # number of set entries at positions start .. stop-1
            return sum(1 for pos in range(start, stop) if bitmap[pos])

        # prefixes that match ip, in the order they are found
        matched = []
        # ip in binary notation
        bits = bin(ip)[2:]
        bits = (2**maskedint.log2(len(bits)) - len(bits))*"0" + bits
        width = len(bits)

        # we start in the root, which may hold a valid prefix itself
        node = self.tree.root
        if node.IBM[0]:
            matched.append(node.prefixes[0])

        # SBM window of the current level inside the current SSNode
        level_start = 0
        level_width = 2
        # position in the SBM selected by the current bit
        slot = int(bits[0])

        # go through the IP (in the binary notation)
        for i in range(width):
            last_bit = (width <= i + 1)

            if not node.SBM[slot]:
                # nowhere to go inside this SSNode: try to leave through the
                # exit with the same rank as this zero has among the zeros
                exit_index = sum(1 for pos in range(slot)
                                 if node.SBM[pos] == 0)
                if not node.EBM[exit_index]:
                    break
                # jump on the next SSNode
                node = node.children[count_set(node.EBM, 0, exit_index)]
                # check if there is a valid prefix in the root of the SSNode
                if node.IBM[0]:
                    matched.append(node.prefixes[0])
                # re-init the variables for the next node
                if not last_bit:
                    level_start = 0
                    level_width = 2
                    slot = int(bits[i + 1])
                continue

            # we stay in the current SSNode; the number of ones up to and
            # including slot is the index of the reached node in the IBM
            inner = count_set(node.SBM, 0, slot + 1)
            if node.IBM[inner]:
                # valid prefix: its index is the number of ones before it
                matched.append(node.prefixes[count_set(node.IBM, 0, inner)])

            # window of the next level: it starts right behind the current
            # one and has two slots per node present on the current level
            next_start = level_start + level_width
            level_width = 2 * count_set(node.SBM, level_start, next_start)
            if last_bit:
                break
            next_bit = int(bits[i + 1])
            # the next index in the SBM
            slot = (next_start
                    + 2 * count_set(node.SBM, level_start, slot)
                    + next_bit)
            level_start = next_start

        # reverse the list to be sorted in descending order
        matched.reverse()
        return matched


    def display(self, node):
        """Display the structure of the tree"""
        
        # create shortcuts for prefixes in the PrefixSet
        shortcut = {}
        i = 0
        for prefix in self.prefixset.get_prefixes():
            i += 1
            shortcut[prefix] = "P" + str(i)
        # traverse the tree in pre-order and print the stucture of the tree
        if node:
            print node.depth *"  ", "---- SSNODE --------------------"
            print node.depth *"  ", "- SBM ", node.SBM
            print node.depth *"  ", "- IBM ", node.IBM
            print node.depth *"  ", "- EBM ", node.EBM
            print node.depth *"  ", "- prefixes in node:",
            for prefix in node.prefixes:
                print shortcut[prefix],
            print "\n",node.depth *"  ", "------------------------------"
            for i in node.children:
                self.display(i)
                
                       
class SSNode(object):
    """
    Common class necessary for the class SST
    One node of the SSTree. It is described by three bitmaps (SBM, IBM,
    EBM), kept as lists of 0/1 values, plus the lists of its child SSNodes
    and of the prefixes stored in it.
    """
    def __init__(self):
        # distance from the root of the SSTree; unknown until it is assigned
        self.depth = None
        # the three bitmaps start empty and are filled when the node is built
        self.SBM, self.IBM, self.EBM = [], [], []
        # child SSNodes and the prefixes stored in this node
        self.children, self.prefixes = [], []
        
    

class SSTree(object):
    """
    Common class necessary for the class SST
    The whole tree of SSNodes in which the prefixes are looked up.
    """
    def __init__(self, K = 3):
        # K as given by the creator of the tree, and the root SSNode
        # (None until one is assigned)
        self.K, self.root = K, None
        # counters, all starting at zero
        self.nodes = self.child_pointers = self.prefix_pointers = 0
