# depth_test.py
#
# Author: Vojtěch Dvořák
#
# Script for depth testing - testing the whole structure of the syntax tree
# (in contrast with the tree-sitter CLI test, which tests just the named node tree)
#
# Created using tree-sitter python bindings (see https://github.com/tree-sitter/py-tree-sitter)
#
# Compares the result of the parsing with the syntax tree described by the
# S-expression in the corresponding .out file (see test/depth/trivial.out for
# a simple example)
#
# Requirements: tree_sitter (run pip install tree_sitter)


from tree_sitter import Language, Parser, Node
from abc import ABC, abstractmethod
import os

# Paths below are relative, so they are resolved against the current working
# directory of the process that runs this script
TEST_DIRECTORY = 'test/depth/' # Directory with the test cases (walked recursively by run())
TEST_CASE_EXT = '.yar' # Test case file extension
EXP_RES_EXT = '.out' # Extension of file with the expected result


# Build .so library for YARA parser
# NOTE: this is module-level code, so the build is triggered whenever this
# module is imported or executed, not only when main() is called
# NOTE(review): '.' is presumably the root of the grammar repository (the
# script seems to expect to be started from there) - TODO confirm
Language.build_library(
    'build/yara-language.so',
    ['.']
)

# Language object loaded from the library built above, it is given to the
# parser in main()
YARA_LANGUAGE = Language('build/yara-language.so', 'yara')


class InOrderCursor(ABC):
    """
    Common super class for the concrete cursors, it provides a unified
    interface for walking through differently represented trees

    Concrete cursors maintain the 'nesting_lvl' attribute, its value is
    always by one greater than the nesting level reported by the
    'get_nesting_lvl' method
    """

    def __init__(self) -> None:
        # 0 means that no node was returned yet (get_nesting_lvl returns -1)
        self.nesting_lvl = 0

    @abstractmethod
    def restart(self) -> None:
        """
        Restarts the cursor, so the traversal can be executed again
        """

    @abstractmethod
    def next(self) -> object:
        """
        Returns the next node of the tree or None if there is no node left
        (the concrete type of the node depends on the concrete cursor)
        """

    @abstractmethod
    def is_finished(self) -> bool:
        """
        Returns True if there are no more nodes to be visited
        """

    def get_nesting_lvl(self) -> int:
        """
        Returns current nesting level (nesting level of the node returned by
        last call of the 'next' function), if there was no call of 'next' yet,
        the -1 is returned 
        """

        return self.nesting_lvl - 1


class TreeInOrderCursor(InOrderCursor):
    """
    Class that provides simple interface for the traversal of the syntax
    tree returned by the tree-sitter

    Every node is returned before its children and the children are returned
    from the first one to the last one. Objects in the tree only need the
    'children' attribute, so any tree of such objects can be traversed.
    """

    def __init__(self, root : Node) -> None:
        super().__init__()
        self.stack_start_state = [root] # Kept untouched for the restart
        self.stack = [root] # Nodes waiting to be visited (the next one is on the top)
        self.children_cnt_stack = [] # Counts of not yet visited children of the nodes on the current path


    def restart(self) -> None:
        """
        Restarts cursor, finished cursor can be executed again
        """

        # The copy is necessary - the working stack is modified by 'next' and
        # sharing one list would destroy the stored start state
        self.stack = list(self.stack_start_state)
        self.children_cnt_stack = []
        self.nesting_lvl = 0


    def next(self) -> Node|None:
        """
        Returns next Node in syntax tree or return None (if all nodes were 
        visited)
        """

        if not self.stack:
            return None

        # Updating nesting lvl
        self.nesting_lvl += 1 # Always presume, the node is the child node of previous node (get_nesting_lvl makes correction)
        while self.children_cnt_stack:
            if self.children_cnt_stack[-1] == 0: # All children on this level were visited, go to upper level
                self.children_cnt_stack.pop()
                self.nesting_lvl -= 1
            else:
                self.children_cnt_stack[-1] -= 1 # The returned node is one of the not yet visited children
                break

        current = self.stack.pop() # Popping the node on the top of the stack

        # Children are pushed in reversed order, so the first child is on the top
        reversed_children = current.children[::-1]

        self.stack.extend(reversed_children) # Appending all children to stack
        self.children_cnt_stack.append(len(reversed_children)) # Storing number of current node children

        return current


    def is_finished(self) -> bool:
        """
        Returns true if there are no nodes to be visited in the tree
        """

        return len(self.stack) == 0



class PseudoNode:
    """
    Writable stand-in for the Node class from the tree_sitter module

    Attributes of the real Node cannot be changed, so this class holds the
    attributes that are important for the comparison ('type' and 'is_named')
    in a writable form. Thanks to duck-typing it can be used in the places
    where only these attributes of a node are read.
    """

    def __init__(self) -> None:
        # Both attributes are filled in later by the code that creates the node
        self.type, self.is_named = None, None

    def __str__(self) -> str:
        # Types of unnamed (and not yet initialized) nodes are printed in quotes
        if not self.is_named:
            return f'<Node type="{self.type}">'

        return f'<Node type={self.type}>'


class StringInOrderCursor(InOrderCursor):
    """
    Cursor for reading file with syntax tree, that is written in InOrder way
    (see test/depth/*.out to see examples) in S-Expression format

    Opening parenthesis increases the nesting level, closing parenthesis
    decreases it. Nodes written in single or double quotes are unnamed nodes,
    the other nodes are named.
    """

    def __init__(self, string : str) -> None:
        super().__init__()
        self.orig_string = string.strip()
        self.string = self.orig_string # Part of the input that was not read yet
        self.cur_line_index = 0 # Number of newlines that were read so far


    def restart(self) -> None:
        """
        Restarts the cursor, so it can be executed again
        """

        self.string = self.orig_string
        # Position related state must be reset too, otherwise the nesting
        # levels and line numbers of the second run would be shifted
        self.nesting_lvl = 0
        self.cur_line_index = 0


    def next(self) -> PseudoNode|None:
        """
        Returns the next node in the syntax tree or None if there is no node
        left

        Raises ValueError if an unnamed node is not terminated by the
        matching quote
        """

        if self.is_finished():
            return None

        text = self.string
        length = len(text)

        # Skipping to the start of the next node, parentheses and newlines
        # are registered on the way
        pos = 0
        while pos < length:
            char = text[pos]
            if char == '\n':
                self.cur_line_index += 1
            elif char.isspace(): # Skip the whitespaces
                pass
            elif char == '(': # Go deeper to the hierarchy
                self.nesting_lvl += 1
            elif char == ')': # Go to upper level of the hierarchy
                self.nesting_lvl -= 1
            else:
                break

            pos += 1
        else:
            # Only whitespaces and parentheses were left (is_finished should
            # prevent this, it is kept just for safety)
            self.string = ''
            return None

        node = PseudoNode()
        if char == '"' or char == "'":
            node.is_named = False # The node is unnamed
            end_index = text.find(char, pos + 1)
            if end_index == -1:
                self.string = text[pos:]
                raise ValueError(f'Unmatched \'{char}\' (l. {self.get_cur_line()})!')

            node.type = text[pos + 1:end_index]
            self.string = text[end_index + 1:]

        else:
            node.is_named = True # The node is named
            special = ('(', ')', '"', "'")
            end_index = pos
            # The name ends with the first whitespace or special character
            while end_index < length and not text[end_index].isspace() and text[end_index] not in special:
                end_index += 1

            node.type = text[pos:end_index]
            self.string = text[end_index:]

        return node


    def is_finished(self) -> bool:
        """
        Returns True if cursor reaches the end (only parentheses and
        whitespaces are left in the input)
        """

        return not self.string.replace('(', '').replace(')', '').strip()


    def get_cur_line(self) -> int:
        """
        Returns the line of the node that was returned by the last call of the
        next function (lines are numbered from 1)
        """

        return self.cur_line_index + 1



def test(sut : Parser, path_to_yar : str, path_to_exp_res : str) -> tuple[bool, str]:
    """
    Tries to parse the file given by path_to_yar with the sut (system under test)
    and compares it with the expected result in the file given by path_to_exp_res

    Returns tuple with the result of the test (True means success) and with
    the message, that describes the result (the reason of the failure or
    'Success')
    """

    # Files are closed immediately after reading, the input is encoded as
    # UTF-8 for the parser, so it is also read as UTF-8
    with open(path_to_yar, 'r', encoding='utf8') as inp_file:
        source = inp_file.read()

    with open(path_to_exp_res, 'r', encoding='utf8') as exp_file:
        expected = exp_file.read()

    tree = sut.parse(source.encode('utf8'))

    tree_cursor = TreeInOrderCursor(tree.root_node)
    string_cursor = StringInOrderCursor(expected)

    # Both trees are walked simultaneously, node by node
    while not tree_cursor.is_finished() and not string_cursor.is_finished():
        node = tree_cursor.next()
        nesting_lvl = tree_cursor.get_nesting_lvl()

        exp_node = string_cursor.next()
        exp_nesting_lvl = string_cursor.get_nesting_lvl()

        # Comparing real and expected node
        if nesting_lvl < exp_nesting_lvl:
            return False, f'Expected node with type \'{exp_node}\' (l. {string_cursor.get_cur_line()} in {EXP_RES_EXT})!'

        elif nesting_lvl > exp_nesting_lvl:
            return False, f'Unexpected node with type \'{node}\' (l. {string_cursor.get_cur_line()} in {EXP_RES_EXT})!'

        elif node.type != exp_node.type or node.is_named != exp_node.is_named:
            return False, f'Expected node {exp_node} (l. {string_cursor.get_cur_line()} in {EXP_RES_EXT}) but found {node}!'

    # There are still non visited nodes
    if not tree_cursor.is_finished():
        return False, f'Unexpected node \'{tree_cursor.next().type}\'!'

    if not string_cursor.is_finished():
        return False, f'Expected node with type \'{string_cursor.next().type}\' (l. {string_cursor.get_cur_line()} in {EXP_RES_EXT})!'

    return True, 'Success'


def run(sut : Parser) -> tuple[int, int]:
    """
    Recursively walks through TEST_DIRECTORY and executes tests above the files
    with the .yar extension, returns tuple that consists of passed test count
    and total test count 

    Only the files, that have the file with the expected result (the same
    name, but the EXP_RES_EXT extension) in the same directory, are tested.
    Results are continuously printed to stdout.
    """

    total = 0
    passed = 0
    last_dir = None
    for root, _, files in os.walk(TEST_DIRECTORY):
        for file in files:

            # Printing the header when the first file of a directory is met
            if root != last_dir:
                if last_dir is not None:
                    print() # Empty line between two directories

                print(f'\033[4m{root}:\033[0m')
                last_dir = root

            bname, ext = os.path.splitext(file)

            # File with the expected result of the parsing
            exp_res_file = os.path.join(root, bname + EXP_RES_EXT)

            # Input file for parser
            inp_file = os.path.join(root, file)

            # Files, that are not complete test cases, are skipped
            if ext != TEST_CASE_EXT or not os.path.isfile(exp_res_file):
                continue

            test_result, reason = test(sut, inp_file, exp_res_file)
            total += 1

            if test_result:
                print('\033[92mPASSED\033[0m\t', end='')
                passed += 1
            else:
                print('\033[91mFAILED\033[0m\t', end='')

            print(file)

            if not test_result:
                print(reason)
                print()

    return passed, total


def dump_tree(cursor : InOrderCursor):
    """
    Writes the tree, that is traversed by the given cursor, to stdout - one
    node type per line, indented by two spaces for every nesting level
    """

    while True:
        if cursor.is_finished():
            break

        # The node must be taken first, nesting level belongs to the last returned node
        kind = cursor.next().type
        indent = '  ' * cursor.get_nesting_lvl()
        print(indent + kind)


def main():
    """
    Runs all depth tests with the YARA parser, prints the summary and ends
    the script with exit code 0 if all tests passed (otherwise with 1)
    """

    parser = Parser()
    parser.set_language(YARA_LANGUAGE)

    passed, total = run(parser)

    print()
    print(f'Summary: {passed}/{total}')

    # SystemExit is raised directly, because the 'exit' helper is provided
    # by the site module and it is meant for the interactive interpreter
    raise SystemExit(0 if passed == total else 1)


if __name__ == '__main__':
    main()

