# Copyright (C) 2018 Libor Polčák <ipolcak@fit.vutbr.cz>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import os
from IPy import IP

def create_ip_filename(path, ipaddr, create_path=False):
    """Return the file name under which data for *ipaddr* is stored.

    Files are spread over a two-level directory tree below *path*, keyed by
    the first two groups of the textual address:

    * IPv4 ``a.b.c.d``   -> ``<path>/a/b/<ipaddr>``
    * IPv6 ``g0:g1:...`` -> ``<path>/g0/g1/<ipaddr>``; a group that is empty
      in the text (because of the ``::`` shorthand) is spelled ``0000``
    * anything else      -> ``<path>/unknown/<ipaddr>``

    Args:
        path: Root directory of the tree.
        ipaddr: Textual IP address; it is used verbatim as the file name.
        create_path: When true, the directory part is converted to an
            absolute path and created (with parents) if it is missing.

    Returns:
        The file name as a string. The directory part is absolute only when
        *create_path* is true; otherwise it is built on *path* as given.

    Raises:
        OSError: Propagated from ``os.makedirs`` when *create_path* is true
            and the directory cannot be created.
    """
    # Every textual IPv6 address contains at least two colons. Test that
    # first, because IPv6 forms with an embedded dotted quad (for example
    # "::ffff:192.0.2.1") also contain dots and must not be bucketed as IPv4.
    # Strings with a single colon keep their previous classification.
    if ipaddr.count(":") >= 2:
        separator = ":"
    elif "." in ipaddr:
        separator = "."
    elif ":" in ipaddr:
        separator = ":"
    else:
        separator = None

    if separator is None:
        created_path = "%s/unknown" % (path,)
    else:
        # The separator occurs at least once, so two groups always exist.
        groups = ipaddr.split(separator)[:2]
        if separator == ":":
            # "::" leaves empty groups behind; name them explicitly.
            groups = [g if g != "" else "0000" for g in groups]
        created_path = "%s/%s/%s" % (path, groups[0], groups[1])

    if create_path:
        created_path = os.path.abspath(created_path)
        os.makedirs(created_path, exist_ok=True)
    return "%s/%s" % (created_path, ipaddr)

def __find_addresses(input_net, res_plen, p, prefix, remainder):
    """List the networks named by the entries of directory *p* inside *input_net*.

    For every directory entry a candidate is built from the text
    ``prefix + <entry name> + remainder + "/" + res_plen`` and parsed with
    ``IP``. Candidates that are contained in *input_net* are kept.

    Args:
        input_net: ``IP`` object used for the containment test.
        res_plen: Prefix length appended to every candidate.
        p: Directory whose entry names are examined.
        prefix: Text placed in front of each entry name.
        remainder: Text placed between the entry name and the prefix length.

    Returns:
        The matching candidates converted with ``str``, in the order given
        by sorting the ``IP`` objects. An empty list when *p* does not exist.
    """
    try:
        names = [entry.name for entry in os.scandir(p)]
    except FileNotFoundError:
        # Nothing has been stored below this directory.
        names = []

    matches = []
    for name in names:
        try:
            candidate = IP(str(prefix + name) + "%s/%d" % (remainder, res_plen))
            inside = candidate in input_net
        except ValueError:
            # Not an expected address family range
            continue
        if inside:
            matches.append(candidate)

    return [str(net) for net in sorted(matches)]

def get_ranges_addresses_ipv4(path, network_addr, plen):
    """Look up stored IPv4 entries that fall inside ``network_addr/plen``.

    The directory that is searched, and the granularity of the result,
    depend on the prefix length:

    * ``plen < 8``  - entries of *path* are read as first octets and
      reported as /8 networks;
    * ``plen < 16`` - entries of ``<path>/<octet1>`` are read as second
      octets and reported as /16 networks;
    * otherwise     - entries of ``<path>/<octet1>/<octet2>`` are read as
      complete addresses (/32).

    Args:
        path: Root directory of the tree.
        network_addr: Textual network address.
        plen: Prefix length of the network.

    Returns:
        Sorted list of strings as produced by ``__find_addresses``.
    """
    network = IP("%s/%d" % (network_addr, plen))
    if plen < 8:
        return __find_addresses(network, 8, path, "", ".0.0.0")

    bits = network.strBin()
    first_octet = int(bits[:8], 2)
    if plen < 16:
        octet_dir = "%s/%d" % (path, first_octet)
        return __find_addresses(network, 16, octet_dir, "%d." % first_octet, ".0.0")

    second_octet = int(bits[8:16], 2)
    leaf_dir = "%s/%d/%d" % (path, first_octet, second_octet)
    return __find_addresses(network, 32, leaf_dir, "", "")

def get_ranges_addresses_ipv6(path, network_addr, plen):
    """Look up stored IPv6 entries that fall inside ``network_addr/plen``.

    The directory that is searched, and the granularity of the result,
    depend on the prefix length:

    * ``plen < 16`` - entries of *path* are read as first groups and
      reported as /16 networks;
    * ``plen < 32`` - entries of ``<path>/<group1>`` are read as second
      groups and reported as /32 networks;
    * otherwise     - entries of ``<path>/<group1>/<group2>`` are read as
      complete addresses (/128).

    A group whose value is zero can exist on disk under two names: ``0``
    and ``0000``. ``create_ip_filename`` in this module writes ``0000`` for
    a group elided by the ``::`` shorthand (e.g. ``::1`` or ``2001::1``).
    Both spellings are searched and the results merged; otherwise such
    addresses would never be found.

    Args:
        path: Root directory of the tree.
        network_addr: Textual network address.
        plen: Prefix length of the network.

    Returns:
        List of strings as produced by ``__find_addresses``, in ``IP`` sort
        order. When more than one directory is searched, duplicates are
        removed.
    """
    input_net = IP("%s/%d" % (network_addr, plen))
    if plen < 16:
        return __find_addresses(input_net, 16, path, "", "::")

    def dir_names(bits):
        # Directory spellings that may hold the 16-bit group given as bits.
        name = hex(int(bits, 2))[2:]
        return (name, "0000") if name == "0" else (name,)

    binary = input_net.strBin()
    found = []
    if plen < 32:
        for first in dir_names(binary[0:16]):
            found.append(__find_addresses(
                input_net, 32, "%s/%s" % (path, first), "%s:" % first, "::"))
    else:
        for first in dir_names(binary[0:16]):
            for second in dir_names(binary[16:32]):
                found.append(__find_addresses(
                    input_net, 128, "%s/%s/%s" % (path, first, second), "", ""))

    if len(found) == 1:
        # Only one candidate directory: hand its result back untouched.
        return found[0]
    # The same network may be reachable through both spellings; report it
    # once and restore the numeric ordering across directories.
    merged = {addr for part in found for addr in part}
    return sorted(merged, key=IP)

def get_ranges_addresses(path, network_addr, plen):
    """Look up stored entries inside ``network_addr/plen`` for either IP version.

    Dispatches to ``get_ranges_addresses_ipv4`` when *network_addr* is
    written in dotted notation and to ``get_ranges_addresses_ipv6``
    otherwise.

    Args:
        path: Root directory of the tree.
        network_addr: Textual network address.
        plen: Prefix length of the network.

    Returns:
        Whatever the selected version-specific lookup returns.
    """
    # A text with two or more colons is IPv6 even if it also contains dots
    # (embedded IPv4 notation such as "::ffff:192.0.2.0"); sending it to the
    # IPv4 lookup would search directories named after the wrong bytes.
    if "." in network_addr and network_addr.count(":") < 2:
        return get_ranges_addresses_ipv4(path, network_addr, plen)
    return get_ranges_addresses_ipv6(path, network_addr, plen)
