#!/usr/bin/python3
# copyright 2021 Peter Dordal 
# licensed under the Apache 2.0 license

# Helper functions for Mininet BGP experiments
# type hinting passes mypy, except for the imported modules

import os
import socket
from pathlib import Path
from mininet.net import Mininet
from mininet.node import Node
from mininet.link import Link, Intf
from typing import List, Dict, Tuple, Union

# Directory under which each router's configuration directory (DIRPREFIX/<routername>) is created.
# It can be changed from another module by calling setdirprefix().
DIRPREFIX='.'

# If True, configuration files are opened with mode 'w', replacing any existing file.
# If False, they are opened with mode 'x', which fails if the file already exists.
OVERWRITE=True

# neighbordict returns a dictionary d where d[ (node1,intf) ] = node2 if the link from
# node1 using intf leads directly to node2, AND both node1 and node2 are BGP routers.
# contains (node,intf) entries only for those interfaces that lead to another BGP router. Does NOT include links to ordinary hosts!
# nodes and intfs are string names, not Mininet objects; that is, the type is Dict[Tuple[str,str], str]
def neighbordict(BGPnodelist : List[ Node ]) -> Dict[ Tuple[str,str] , str ] :
    """Map each BGP-facing interface to the name of the BGP router at its far end.

    BGPnodelist is the list of Mininet nodes acting as BGP routers.
    Returns a dictionary d with d[(nodename, intfname)] = remotenodename for every
    interface of a node in BGPnodelist whose link leads to another node that is
    also in BGPnodelist. Links to nodes outside the list (ordinary hosts) are omitted.
    Keys and values are string names, not Mininet objects.
    """
    d = {}
    for n in BGPnodelist:
        for intf in n.intfList():
            link = intf.link
            # Mininet leaves intf.link as None for an interface that is not part of a Link;
            # such an interface cannot lead to another BGP router, so skip it rather than crash.
            if link is None: continue
            n1 = link.intf1.node
            n2 = link.intf2.node
            # the remote node is whichever end of the link is not n itself
            remoten = n2 if n1 == n else n1
            if remoten not in BGPnodelist: continue
            d[str(n), str(intf)] = str(remoten)
    return d

# addressdict takes the same BGPnodelist, and returns a dictionary mapping (nodename, ifname) -> (ip4addr, addrlen). 
# Again, nodes and interfaces are represented as string names; type is Dict[ Tuple[str,str] , Tuple[str,int] ]

def addressdict(BGPnodelist : List [ Node ]) -> Dict[ Tuple[str,str] , Tuple[Union[str,None],int] ]:
    """Collect the IPv4 address of every interface of every node in BGPnodelist.

    Returns a dictionary mapping (nodename, intfname) -> (ip4addr, prefixlen), where the
    address information comes from ipv4addr(). Interfaces for which ipv4addr() reports
    no address are left out. Node and interface names are strings.
    """
    table = {}
    for router in BGPnodelist:
        for iface in router.intfList():
            address, prefixlen = ipv4addr(router, iface)
            if address is None:
                # no IPv4 address on this interface
                continue
            table[str(router), str(iface)] = (address, prefixlen)
    return table


# for diagnostics
def dumpdict(d):
    """Print one diagnostic line per entry of a dictionary keyed by (node, interface) pairs."""
    for (node, intf), target in d.items():
        print('node {} interface {} goes to {}'.format(node, intf, target))
        
# Passwords written into the generated zebra.conf and bgpd.conf files:
# PASSWORD goes on the 'password' line, EPASSWORD on the 'enable password' line.
PASSWORD='zpassword'
EPASSWORD='epassword'

# creates the zebra.conf file. r is a mininet object; keys (rname, intfname) in nd are pairs of strings
def create_zebra_conf(r : Node, nd):
    """Write the zebra.conf file for router r into DIRPREFIX/<routername>/.

    r is a Mininet node object. nd is a dictionary whose keys are (routername, intfname)
    pairs of strings; every key whose router name matches r contributes an
    'interface <intfname>' / 'multicast' stanza to the file.

    The router's directory is created first via mkdir_if_necessary(); errors raised
    there are not caught. If OVERWRITE is false, an existing file is not replaced.

    Returns the path of the configuration file, or None if it could not be written.
    """
    rname = str(r)
    print('working on {}/zebra.conf'.format(rname))
    DIR = DIRPREFIX + '/' + rname
    basename = DIR + '/zebra'
    conffilename = basename + '.conf'
    logfilename  = basename + '.log'
    zstr=('! zebra configuration file\n'
          '! This file is automatically generated\n'
          '\n'
          'hostname {routername}\n'
          'password {password}\n'
          'enable password {epassword}\n'
          '\n'
          'log file {logfilename}\n'
          '\n').format(routername=rname, password=PASSWORD, epassword=EPASSWORD, logfilename=logfilename)

    # one stanza per interface of this router that appears in nd
    zstr += ''.join(
        ('interface {}\n' 'multicast\n\n').format(intf)
        for (r2, intf) in nd
        if r2 == rname
    )

    print('trying to make directory {}'.format(DIR))

    mkdir_if_necessary(DIR)

    # 'w' replaces an existing file; 'x' is an exclusive open that fails if the file exists
    mode = 'w' if OVERWRITE else 'x'

    try:
        # the with-statement closes the file even if write() raises
        with open(conffilename, mode) as f:
            f.write(zstr)
    except OSError:
        print('creation of zebra.conf file {} failed; maybe it already exists?'.format(conffilename))
        return None

    return conffilename	# if the log file doesn't get deleted, it's not a problem.

# creates the bgpd.conf file
def create_bgpd_conf(r : Node, ndict : Dict[ Tuple[str,str] , str ], addrdict : Dict[ Tuple[str,str] , Tuple[str,int] ], ASdict : Dict[str, int], neighbors=None):
    """Write the bgpd.conf file for router r into DIRPREFIX/<routername>/.

    r         Mininet node object for the router
    ndict     (routername, intfname) -> name of the BGP router at the other end of that interface
    addrdict  (routername, intfname) -> (IPv4 address, prefix length)
    ASdict    routername -> AS number; must contain r and every neighbor
    neighbors names of the routers to establish BGP sessions with; if None, every
              BGP neighbor of r listed in ndict is used

    The file announces the network of every entry addrdict holds for r, and has one
    'neighbor ... remote-as ...' line per neighbor. The router-id is the address
    chosen by maxIPaddr().

    The router's directory is created first via mkdir_if_necessary(); errors raised
    there are not caught. If OVERWRITE is false, an existing file is not replaced.

    Returns the path of the configuration file, or None if it could not be written.
    """
    rname = str(r)
    DIR = DIRPREFIX + '/' + rname
    basename = DIR + '/bgpd'
    conffilename = basename + '.conf'
    logfilename  = basename + '.log'
    ASnum = ASdict[rname]
    router_id = maxIPaddr(rname, addrdict)

    bstr=('! bgpd configuration file for router {routername}\n'
          '! This file is automatically generated\n'
          '\n'
          'hostname {routername}\n'
          'password {password}\n'
          'enable password {epassword}\n'
          '\n'
          'log file {logfilename}\n'
          '\n'
          'router bgp {ASnum}\n'
          'bgp router-id {router_id}\n'
          '\n').format(routername=rname, password=PASSWORD, epassword=EPASSWORD, logfilename=logfilename, ASnum=ASnum, router_id=router_id)

    # create entries for networks we announce
    bstr += '! These are the networks we announce; configured here to be ALL directly connected networks\n'
    for (key, value) in addrdict.items():
        (r2, intf) = key
        if r2 != rname: continue
        (ipaddr, iplen) = value
        bstr += 'network {}/{}\n'.format(createprefix(ipaddr,iplen), iplen)

    # the neighbors with which we establish BGP sessions are those in the neighbors list.
    # If this is None, then we use all BGP neighbors.
    if neighbors is None: neighbors = allBGPneighbors(rname, ndict)	# the neighbors we have BGP sessions with, by name
    bstr +='\n!These are the neighbors with which we estabish BGP sessions:\n'
    for neighbor in neighbors:
        neighborAS = ASdict[neighbor]
        neighborIP = '0.0.0.0'		# written as-is if no interface of neighbor leading to rname is found
        for (key, remote) in ndict.items():
            (r2name, intf) = key
            if r2name != neighbor: continue
            if rname != remote: continue		# find interface by which neighbor connects to rname
            neighborIP = addrdict[ (r2name,intf) ][0]		# addrdict values are pairs (IPaddr, len), hence the [0]
            break
        bstr += 'neighbor {} remote-as {}\n'.format(neighborIP, neighborAS)

    mkdir_if_necessary(DIR)

    # 'w' replaces an existing file; 'x' is an exclusive open that fails if the file exists
    mode = 'w' if OVERWRITE else 'x'

    try:
        # the with-statement closes the file even if write() raises
        with open(conffilename, mode) as f:
            f.write(bstr)
    except OSError:
        print('creation of bgpd.conf file {} failed; maybe it already exists?'.format(conffilename))
        return None

    return conffilename	# if the log file doesn't get deleted, it's not a problem.
    
# Notes on create_bgpd_conf (above): the caller needs to create the rname -> ASnum dictionary
# (the ASdict argument) and make it available.
# create_bgpd_conf creates BGP entries to *announce* all directly connected networks,
# but sets up BGP peer connections only with the routers named in its neighbors argument.
    
def maxIPaddr(rname : str, addrdict : Dict[ Tuple[str,str] , Tuple[str,int] ]):
    """Return the numerically largest IPv4 address that addrdict lists for router rname.

    addrdict maps (routername, intfname) -> (IPv4 address, prefix length), with addresses
    in dotted-decimal string form. Addresses are compared octet by octet as integers,
    so '10.0.10.1' ranks above '10.0.9.1' (a plain string comparison would get this wrong).

    Raises IndexError if addrdict holds no address for rname.
    """
    addrlist = [addr for ((r, _intf), (addr, _length)) in addrdict.items() if r == rname]
    if not addrlist:
        raise IndexError('no IPv4 address found for router {}'.format(rname))
    return max(addrlist, key=lambda a: tuple(int(octet) for octet in a.split('.')))

# rname is a string 
def allBGPneighbors(rname : str, ndict : Dict[ Tuple[str,str] , str ]):
    """Return the names of all routers that ndict lists at the far end of an interface of rname.

    rname is a router name (a string); ndict maps (routername, intfname) -> remote router name.
    The result has one entry per matching interface, in ndict's iteration order.
    """
    return [remote for ((router, _intf), remote) in ndict.items() if router == rname]
             
# in the following, ip(4,2) returns 10.0.4.2. Not used here
def ip(subnet : int ,host : int ,prefix=None):
    """Return the address string '10.0.<subnet>.<host>'; e.g. ip(4,2) returns '10.0.4.2'.

    If prefix is given, '/<prefix>' is appended: ip(4,2,24) returns '10.0.4.2/24'.
    """
    # A stray bare 'return' used to sit here, making the function always return None.
    addr = '10.0.{}.{}'.format(subnet, host)
    if prefix is not None:
        addr = '{}/{}'.format(addr, prefix)
    return addr
 
# creates directory if it does not exist. Fails if it is a nondirectory file
def mkdir_if_necessary(dir : str):
    """Create directory dir (and any missing parents) if needed, then set its mode to 0o777.

    An already existing directory is accepted; only its mode is (re)set.
    """
    target = Path(dir)
    target.mkdir(parents=True, exist_ok=True)
    os.chmod(dir, 0o777)

# workaround for the fact that Python modules do not share global variables
def setdirprefix(dp : str) -> None:
    """Set this module's DIRPREFIX to dp.

    Other modules call this because assigning to their own imported copy of
    DIRPREFIX would not change the value this module uses.
    """
    global DIRPREFIX
    DIRPREFIX = dp

# Given 10.0.37.61 and 24, returns '10.0.37.0' (the network address only; the caller appends '/24'). Bit-masking is tricky with Python's variable-length integers.
def createprefix(ipaddr : str, length : int) -> str:
    """Return the network address of ipaddr under a prefix of the given length.

    ipaddr is a dotted-decimal IPv4 address and length the prefix length in bits;
    the host bits are cleared and the result is returned in dotted-decimal form
    without the '/length' suffix. For example, ('10.0.37.61', 24) gives '10.0.37.0'.
    """
    packed = socket.inet_aton(ipaddr)
    address = int.from_bytes(packed, 'big')
    hostbits = (1 << (32 - length)) - 1
    # XOR against 32 one-bits keeps the netmask within 32 bits
    netmask = 0xffffffff ^ hostbits
    network = address & netmask
    return socket.inet_ntoa(network.to_bytes(4, 'big'))

# This finds the first (IPv4_address, prefixlen) listed by the "ip addr list" command for the interface.
# If there are none, it returns (None,0)
# It may fail if the output of "ip addr" cannot be parsed.
# If you want *all* IPv4 addresses, modify it to build a list. 
# Each IPv4 address should be on its own 'inet ' line ('inet6 ' for IPv6).
def ipv4addr(node : Node, intf : Intf) -> Tuple[ Union[str,None], int] :
    """Return (IPv4 address, prefix length) for the first 'inet ' line that
    'ip addr list dev <intf>', run on node, prints for the interface.

    Returns (None, 0) if the output contains no 'inet ' line. The word following
    'inet ' must have the form address/length; an assertion fails if the '/' is missing.
    """
    marker = 'inet '		# to find IPv6 addresses, change this to 'inet6 '
    output = node.cmd('ip addr list dev {}'.format(intf))
    for line in output.split('\n'):
        found = line.find(marker)
        if found < 0:
            continue
        # drop blanks and tabs after the marker, then keep everything up to the next blank or tab
        rest = line[found + len(marker):].lstrip(' \t')
        word = rest.replace('\t', ' ').split(' ', 1)[0]
        slash = word.find('/')
        assert slash >= 0	# there better *be* a '/'
        return (word[:slash], int(word[slash+1:]))
    return (None, 0)
    
     
