#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np
import bisect
import ast
import math
import argparse
import sys

#TODO: use YAML/ruamel.yaml for configuration file.
def read_definition(filename):
    """Parse a simple ``key: value`` definition file into a dict.

    Each line is split at the first ``': '``; the text after it is evaluated
    with ``ast.literal_eval``, so values may be Python literals (numbers,
    strings, lists, booleans, ...).  Lines that contain no ``': '`` are
    ignored, and when a key repeats the last occurrence wins.  Keys are used
    exactly as written (no whitespace stripping).
    """
    definitions = {}
    with open(filename, "r") as handle:
        for raw_line in handle:
            key, separator, value_text = raw_line.partition(': ')
            if not separator:
                continue
            definitions[key] = ast.literal_eval(value_text)
    return definitions

def conveyance(numH, n_co, xregion, zregion, zmin, zmax, slope=None):
    """Build a Manning rating curve for one boundary panel.

    The water level is raised in ``numH`` equal steps from
    ``zmin + (zmax-zmin)/numH`` up to ``zmax``.  For each level, the part of
    the section whose bed lies below the level is closed at the surface by
    linear intercepts, and its wetted perimeter, flow area (trapezium rule),
    hydraulic radius, conveyance and discharge are computed.

    Args:
        numH: number of water levels.
        n_co: Manning's roughness coefficient for this panel.
        xregion: 1-D numpy array of along-boundary coordinates.  The
            intercept insertion assumes they are in ascending order.
        zregion: 1-D numpy array of bed elevations at ``xregion``.
        zmin: lowest bed elevation of the panel.
        zmax: highest water level (the last entry of ``h_i``).
        slope: slope used in Manning's equation.  Defaults to the
            module-level ``slope`` so existing callers keep working.

    Returns:
        Tuple ``(p_i, A_i, r_h, h_i, K_i, Q_i)`` of lists with ``numH``
        entries: wetted perimeter, area, hydraulic radius (0 when the
        perimeter is 0), water level, conveyance ``A * R**(2/3) / n`` and
        discharge ``K * sqrt(slope)``.

    Raises:
        NameError: if ``slope`` is not given and no module-level ``slope``
            exists.
    """
    if slope is None:
        # Backward compatibility: the script defines ``slope`` globally.
        try:
            slope = globals()["slope"]
        except KeyError:
            raise NameError(
                "conveyance(): pass 'slope' or define a module-level 'slope'"
            ) from None

    eps = 1e-06  # tolerance for "this vertex lies on the water surface"
    p_i = []  # wetted perimeter
    A_i = []  # area
    r_h = []  # hydraulic radius
    h_i = []  # water levels
    K_i = []  # conveyance
    Q_i = []  # discharge
    for i in range(numH):
        h = zmin + (i+1)*(zmax-zmin)/numH
        h_i.append(h)
        wet = zregion < h
        x_sub = list(xregion[wet])  # wet x values (plus intercepts below)
        z_sub = list(zregion[wet])  # matching z values
        for k in range(len(xregion)-1):
            if wet[k+1] != wet[k]:
                # the bed crosses the water level between nodes k and k+1
                x_extra = xregion[k] \
                    + (h - zregion[k])\
                    *(xregion[k+1] - xregion[k])\
                    /(zregion[k+1] - zregion[k])
                # insert x and z at the same position so the lists stay
                # aligned even if x_extra equals an existing value
                pos = bisect.bisect_right(x_sub, x_extra)
                x_sub.insert(pos, x_extra)
                z_sub.insert(pos, h)

        dp = 0
        dA = 0
        for j in range(len(x_sub)-1):
            x0, x1 = x_sub[j], x_sub[j+1]
            z0, z1 = z_sub[j], z_sub[j+1]
            # a segment with both ends on the water surface (typically the
            # gap between two separate wet sub-sections) is not wetted bed
            if abs(z1 - h) > eps or abs(z0 - h) > eps:
                dp += np.hypot(x1 - x0, abs(z1 - z0))
            # area between water level and bed, trapezium rule
            dA += (h - (z1 + z0)/2)*(x1 - x0)

        p_i.append(dp)
        A_i.append(dA)
        # ratio of area and wetted perimeter; a zero perimeter (e.g. a
        # single wet node) would otherwise raise ZeroDivisionError
        r = dA/dp if dp > 0 else 0.0
        r_h.append(r)
        K = dA*(1/n_co)*r**(2/3)  # conveyance
        K_i.append(K)
        Q_i.append(K*slope**0.5)  # discharge

    return p_i, A_i, r_h, h_i, K_i, Q_i

def plot_region(ydata, labely,
                xdata1, labelx1,
                xdata2, labelx2,
                xdata3, labelx3, titlep):
    """Show three line plots that share the same vertical data.

    Each of ``xdata1``, ``xdata2`` and ``xdata3`` is plotted against
    ``ydata`` (x on the horizontal axis), labelled with the matching
    ``labelx*`` and with ``labely``/``titlep``, and displayed with its own
    ``plt.show()`` call, in that order.
    """
    for xdata, labelx in ((xdata1, labelx1),
                          (xdata2, labelx2),
                          (xdata3, labelx3)):
        plt.xlabel(labelx)
        plt.ylabel(labely)
        plt.title(titlep)
        plt.plot(xdata, ydata)
        plt.show()

def save_bc(outputfile):
    """Write the FullSWOF boundary-condition file for the current boundary.

    Uses module-level results of the script: xin, zin, marker_ind, csa,
    csa_p, panel, ind_p, btype, panel_target_flow, h_extra, Q_i and zmax.
    One line is written per boundary element:

    * elements with zero cross-sectional area get type 2 (wall);
    * elements of the part-filled panel share ``panel_target_flow`` in
      proportion to their area, with water depth ``h_extra - z``;
    * elements of completely filled panels share that panel's full
      discharge ``Q_i[p][-1]`` in proportion to their area, with water
      depth ``zmax - z``.

    Args:
        outputfile: path of the file to (over)write.
    """
    # Exclusive end index of each panel: panel p covers elements
    # marker_ind[p] .. marker_ind[p+1]-1.  Deciding membership by element
    # index uses the same partition that built csa/csa_p, so an element's
    # area is always divided by the area of the panel it was counted in.
    # Comparing x against the raw marker distances disagrees with the
    # truncated marker_ind near panel edges, and cannot work where every
    # element shares one x (left/right boundaries).
    panel_ends = marker_ind[1:]
    with open(outputfile, 'w') as f:
        f.write('{:6} {:>2} {:>2} {:>10}\n'.format('#x', 'c', 'q', 'h'))
        for ind_z, (xitem, zitem) in enumerate(zip(xin, zin)):
            panel_x = bisect.bisect_right(panel_ends, ind_z)
            if csa[ind_z] == 0:
                # wall boundary condition
                f.write('{:6.2f} {:>2}\n'.format(xitem, 2))
            elif panel_x == panel[ind_p]:
                # imposed discharge within part-filled panel
                f.write('{:6.2f} {:>2} {:10.6f} {:9.6f}\n'.format(
                    xitem, btype,
                    -csa[ind_z]*panel_target_flow/csa_p[panel_x],
                    h_extra-zitem))
            else:
                # imposed discharge within filled panels
                f.write('{:6.2f} {:>2} {:10.6f} {:9.6f}\n'.format(
                    xitem, btype,
                    -csa[ind_z]*Q_i[panel_x][-1]/csa_p[panel_x],
                    zmax-zitem))

def interp(extra2, max1, min1, max2, min2):
    """Linearly interpolate quantity 1 at the value ``extra2`` of quantity 2.

    Evaluates the straight line through ``(min2, min1)`` and
    ``(max2, max1)`` at ``extra2`` (similar triangles).  ``max2`` must
    differ from ``min2``.
    """
    rise = max1 - min1
    run_offset = extra2 - min2
    run = max2 - min2
    return min1 + rise*run_offset/run

# read command line argument:
parser = argparse.ArgumentParser(
    description="generate FullSWOF boundary files")
# Restricting the choices makes argparse reject an unknown location with a
# usage message, instead of the script failing later with a NameError on
# inputFilename.
parser.add_argument("location",
                    choices=("top", "bottom", "left", "right"),
                    help="boundary location")
args = parser.parse_args()

# (definition file read, boundary file written) for each location
_boundary_files = {
    'top':    ("boundaryTop.txt",    "BCTop.txt"),
    'bottom': ("boundaryBottom.txt", "BCBottom.txt"),
    'left':   ("boundaryLeft.txt",   "BCLeft.txt"),
    'right':  ("boundaryRight.txt",  "BCRight.txt"),
}
inputFilename, outputFilename = _boundary_files[args.location]

# read boundary definition file:
definition_dict = read_definition(inputFilename)
# Fail early with a readable message if the definition file lacks a key
# that the rest of the script reads.
_required_keys = ("type", "slope", "target_flow", "plotting", "printing",
                  "n_co", "markers", "panel", "ztol", "numH")
_missing_keys = [key for key in _required_keys if key not in definition_dict]
if _missing_keys:
    sys.exit('Error: {} is missing required key(s): {}'.format(
        inputFilename, ', '.join(_missing_keys)))
btype       = definition_dict["type"]        # boundary type (1--5)
slope       = abs(definition_dict["slope"])  # slope for Manning's equation (sign ignored)
target_flow = definition_dict["target_flow"] # imposed discharge
plotting    = definition_dict["plotting"]    # enable or disable plotting
printing    = definition_dict["printing"]    # enable or disable printing
n_co        = definition_dict["n_co"]        # Manning's 'n' coefficient, one per panel
# TODO: use weighted mean 'n' values.  See
# http://help.floodmodeller.com/isis/ISIS/River_Section.htm (Eq. 4)
# Note: weighted mean calculation requires roughness map.
markers     = definition_dict["markers"]     # distances from corner point
panel       = definition_dict["panel"]       # panel fill order
ztol        = definition_dict["ztol"]        # tolerance in overtopping height
numH        = definition_dict["numH"]        # number of height intervals


# print(len(markers))

# with open('./1D_top.txt', "r") as data:
#     xch, ych, zch = np.loadtxt(data, delimiter=' ', unpack=True)

# Fit with polyfit
# m, c = np.polyfit(ych, zch, 1)
# print('gradient =', m, 'intercept =', c)

# read topography:
with open("./topography.txt", "r") as topo:
    xtp, ytp, ztp = np.loadtxt(topo, delimiter=' ', unpack=True)

# Domain extents.  Assumes cell-centred coordinates running from dX/2 to
# extent - dX/2, so that first + last coordinate equals the extent.
xmax  = (xtp[0]+xtp[-1])                   # domain extent in x-direction
ymax  = (ytp[0]+ytp[-1])                   # domain extent in y-direction
# For square cells ncols/nrows == xmax/ymax and ncols*nrows == len(xtp), so
# ncols = sqrt(len(xtp)*xmax/ymax).  Round rather than truncate: floating
# point error can leave the root just below the true integer (e.g.
# 399.9999999), and int() would then silently drop a whole column.
ncols = round(math.sqrt(len(xtp)*xmax/ymax))  # number of cells in x-direction
nrows = int(len(xtp)/ncols)                # number of cells in y-direction
if ncols*nrows != len(xtp):
    sys.exit('Error: could not infer a regular grid from topography.txt '
             '({} points, ncols = {}, nrows = {}).'.format(
                 len(xtp), ncols, nrows))
dX    = xmax/ncols                         # cell size
print('dX =', dX)

#print(ncols, nrows)

# extract slices from height data array.  Note: xyz format uses ncols
# blocks, with nrows lines per block.
# For each boundary, take the outermost line of cells and linearly
# extrapolate one cell outward (ghost = 2*outer - next_inner), giving the
# coordinates (xin, yin) and bed elevation zin of the ghost cells along the
# chosen boundary.
# NOTE(review): for 'left'/'right' the extrapolated xin should be the same
# for every element (one column block shares one x), yet xin is what the
# rest of the script uses as the along-boundary coordinate (widths in
# conveyance via xregion, first column of the output file).  The
# along-boundary coordinate there is presumably yin -- confirm before
# relying on left/right output.
if args.location == 'top':
    # last line of each column block (index nrows-1), extrapolated from the
    # second-to-last line
    xin = xtp[nrows-1:len(xtp):nrows]
    yin = 2*ytp[nrows-1:len(xtp):nrows] - ytp[nrows-2:len(xtp):nrows]
    zin = 2*ztp[nrows-1:len(xtp):nrows] - ztp[nrows-2:len(xtp):nrows]
elif args.location == 'bottom':
    # first line of each column block, extrapolated from the second line
    xin = xtp[0:len(xtp):nrows]
    yin = 2*ytp[0:len(xtp):nrows] - ytp[1:len(xtp):nrows]
    zin = 2*ztp[0:len(xtp):nrows] - ztp[1:len(xtp):nrows]
elif args.location == 'left':
    # first column block, extrapolated from the second block
    xin = 2*xtp[:nrows] - xtp[nrows:2*nrows]
    yin = ytp[:nrows]
    zin = 2*ztp[:nrows] - ztp[nrows:2*nrows]
elif args.location == 'right':
    # last column block, extrapolated from the second-to-last block
    xin = 2*xtp[nrows*(ncols-1):] - xtp[nrows*(ncols-2):nrows*(ncols-1)]
    yin = ytp[nrows*(ncols-1):]
    zin = 2*ztp[nrows*(ncols-1):] - ztp[nrows*(ncols-2):nrows*(ncols-1)]
    

# print(xin)

num_panels = len(panel)  # number of panels across boundary

# convert marker co-ordinates to array indices; panel p then covers the
# boundary elements marker_ind[p] .. marker_ind[p+1]-1, and the final entry
# closes the last panel at the end of the boundary.
marker_ind = [0] + [int(distance/dX) for distance in markers]
if args.location in ('left', 'right'):
    marker_ind.append(nrows)
elif args.location in ('top', 'bottom'):
    marker_ind.append(ncols)


# print(marker_ind)

# Split the boundary into panels and record each panel's lowest bed level.
xregion, zregion, zmin = [], [], []
for p in range(num_panels):
    start_ind, end_ind = marker_ind[p], marker_ind[p+1]
    xregion.append(xin[start_ind:end_ind])
    zregion.append(zin[start_ind:end_ind])
    zmin.append(zregion[-1].min())

# xregion_west = xin[100:281]
# zregion_west = zin[100:281]

# xregion_east = xin[300:408]
# zregion_east = zin[300:408]

# print(zregion)
# print(xin[12:20])

print('zmin =', zmin)

# channel overtopping height: the lower of the two bank heights (first and
# last node) of the first panel to fill, lowered by the tolerance ztol.
channel_z = zregion[panel[0]]
zmax = min(channel_z[0], channel_z[-1]) - ztol

print('zmax =', zmax)

#print(h_i)



# Rating-curve results per panel, indexed by panel number (not fill order);
# each entry holds one value per water level.
p_i = [[] for _ in range(num_panels)]
A_i = [[] for _ in range(num_panels)]
r_h = [[] for _ in range(num_panels)]
h_i = [[] for _ in range(num_panels)]
K_i = [[] for _ in range(num_panels)]
Q_i = [[] for _ in range(num_panels)]
for p in range(num_panels):
    if p == panel[0]-1 and zregion[p][-1] < zmax:
        # ensure end node in region to the left of channel is dry: append
        # the channel's first node (index marker_ind[p+1], its left bank).
        # marker_ind[p] would be this region's own first node, which breaks
        # the ordering along the boundary.
        xregion[p] = np.append(xregion[p], xin[marker_ind[p+1]])
        zregion[p] = np.append(zregion[p], zin[marker_ind[p+1]])
    if p == panel[0]+1 and zregion[p][0] < zmax:
        # ensure start node in region to the right of channel is dry:
        # prepend the channel's last node (index marker_ind[p]-1).
        xregion[p] = np.insert(xregion[p], 0, xin[marker_ind[p]-1])
        zregion[p] = np.insert(zregion[p], 0, zin[marker_ind[p]-1])
    if zmax > zmin[p]:
        p_i[p], A_i[p], r_h[p], h_i[p], K_i[p], Q_i[p] = conveyance(
            numH,
            n_co[p],
            xregion[p],
            zregion[p],
            zmin[p],
            zmax)
        if plotting:
            plot_region(
                np.asarray(h_i[p])-zmin[p], 'water level / m',
                r_h[p], 'hydraulic radius / m',
                K_i[p], r'conveyance / $m^3/s$',
                Q_i[p], r'discharge / $m^3/s$',
                'Panel {}'.format(p))
        if printing:
            ratingCurveFileName = 'panel{}_{}.dat'.format(p,args.location)
            with open(ratingCurveFileName, 'w') as f:
                f.write('{:14} {:18} {:12} {:10}\n'.format(
                    '#water level', 'hydraulic radius',
                    'conveyance', 'discharge'))
                f.write('{:14} {:18} {:12} {:10}\n'.format(
                    '#/ m', '/ m', '/ m^3/s', '/ m^3/s'))
                for h in range(numH):
                    f.write('{:7.6f} {:16.6f} {:19.6f} {:11.6f}\n'.format(
                        h_i[p][h]-zmin[p],r_h[p][h],K_i[p][h],Q_i[p][h]))
    else:
        # lowest bed is at or above the overtopping height: panel stays dry
        p_i[p], A_i[p], r_h[p], h_i[p], K_i[p], Q_i[p] = [
            [0] * numH for _ in range(6)]

# sort list of discharge lists according to panel fill order:
sortedQ = [Q_i[i] for i in panel]
# cumulative discharge in fill order: total_flow[k] is the discharge when
# panels panel[0..k] are all filled to zmax (last rating-curve entry).
total_flow = np.cumsum([item[-1] for item in sortedQ])
print('total_flow = ', total_flow)
# mean velocity in the first-filled panel when full; a dry panel has zero
# area, so avoid dividing by it.
A_channel = A_i[panel[0]][-1]
velocity_channel = Q_i[panel[0]][-1]/A_channel if A_channel else 0.0

# find part-filled panel: the first panel, in fill order, whose cumulative
# capacity exceeds the target discharge.
if total_flow[-1] > target_flow:
    ind_p = bisect.bisect(total_flow, target_flow)
else:
    # sys.exit(str) prints to stderr and exits with status 1, so callers
    # can detect the failure (a bare sys.exit() reported success).
    sys.exit('Error: imposed discharge is higher than total capacity of panels.')

print('index of part-filled panel:', ind_p)

# calculate target flow in part-filled panel:
if ind_p == 0:
    panel_target_flow = target_flow
else:
    panel_target_flow = target_flow - total_flow[ind_p-1]

# rating curve of the part-filled panel
part_p = panel[ind_p]
h_part = h_i[part_p]
Q_part = Q_i[part_p]
A_part = A_i[part_p]
r_part = r_h[part_p]

# find insertion point for target flow value (bisect assumes the panel's
# discharge increases with water level -- NOTE(review): confirm for panels
# with floodplain steps):
ind_q = bisect.bisect(Q_part, panel_target_flow)
# rounding in the cumulative sums could push the target up to the panel's
# full discharge; keep ind_q a valid index
ind_q = min(ind_q, len(Q_part)-1)

print('insertion point =', ind_q)

# Lower bracketing point.  Below the first tabulated level the curve starts
# from the dry bed (level zmin, Q = 0, A = 0, r_h = 0); using index ind_q-1
# there would wrap around to the highest level.
if ind_q == 0:
    h_lo, Q_lo, A_lo, r_lo = zmin[part_p], 0.0, 0.0, 0.0
else:
    h_lo = h_part[ind_q-1]
    Q_lo = Q_part[ind_q-1]
    A_lo = A_part[ind_q-1]
    r_lo = r_part[ind_q-1]

# find height at target flow by linear interpolation
h_extra = interp(
    panel_target_flow,
    h_part[ind_q],
    h_lo,
    Q_part[ind_q],
    Q_lo)

print('heights:', h_lo, h_extra, h_part[ind_q])

# find area at target flow by linear interpolation
A_extra = interp(
    h_extra,
    A_part[ind_q],
    A_lo,
    h_part[ind_q],
    h_lo)

print('hydraulic radii:', r_lo, r_part[ind_q])

# a zero target (exactly at a panel boundary) gives zero area
velocity_panel    = panel_target_flow/A_extra if A_extra else 0.0
print('velocities:', velocity_channel, velocity_panel)

csa = np.zeros(len(xin))             # cross-sectional area of element
csa_p = np.zeros(num_panels)         # cross-sectional area of panel
# Walk the panels in fill order: those before the part-filled one are full
# to zmax, the part-filled one is filled to h_extra, the rest stay empty.
for rank, p in enumerate(panel):
    first_ind, stop_ind = marker_ind[p], marker_ind[p+1]
    if rank > ind_p:
        # panel is empty
        for m in range(first_ind, stop_ind):
            csa[m] = 0
        csa_p[p] = 0
        continue
    level = zmax if rank < ind_p else h_extra
    running_area = 0
    for m in range(first_ind, stop_ind):
        # water column above the bed times the cell width, never negative
        csa[m] = max(0, (level - zin[m])*dX)
        running_area += csa[m]
    csa_p[p] = running_area


#print('csa_p[0] = {} csa_p[1] = {}'.format(csa_p[0], csa_p[1]))
#print('A_i_west = {} A_i = {} A_i_east = {}'.format(A_extra, A_i[-1], A_i_east[-1]))


# Write the boundary file (BC<Location>.txt, chosen from the command-line
# location) using the module-level results computed above.
save_bc(outputFilename)