#!/usr/bin/env python
# -*- coding: utf-8 -*-
# run within virtual environment (uses Python 3 syntax)

import argparse
import os
import re
from array import array
from collections import Counter
from subprocess import check_output

def meshFile(param):
    """argparse ``type=`` callback that validates the mesh path.

    Accepts *param* only when its extension is ``.pyfrm`` (compared
    case-insensitively) and returns it unchanged.  As a side effect, the
    path without its extension is stored in the module-level global
    ``base``.  This happens before the check, so ``base`` is updated even
    when validation fails.

    Raises:
        argparse.ArgumentTypeError: if the extension is not ``.pyfrm``.
    """
    global base
    stem, suffix = os.path.splitext(param)
    base = stem
    if suffix.lower() == '.pyfrm':
        return param
    raise argparse.ArgumentTypeError('Mesh file must have a .pyfrm extension')

def writeHeader(xdmfFile):
    """Write the XML prolog and open the <Xdmf> root and its <Domain>.

    The root declares the XInclude namespace prefix ``xi`` so that later
    elements can pull in external files via ``xi:include``.  The matching
    closing tags are written by writeFooter.
    """
    prolog = (
        '<?xml version="1.0" ?>\n',
        '<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd" []>\n',
        '<Xdmf xmlns:xi="http://www.w3.org/2003/XInclude" Version="2.2">\n',
        ' <Domain>\n',
    )
    for text in prolog:
        xdmfFile.write(text)

def writeTopology(xdmfFile, nCells, connFile, topologyType="Quadrilateral"):
    """Write a Topology element whose connectivity is pulled in via XInclude.

    Args:
        xdmfFile: open text stream the XML is written to.
        nCells: number of cells (``NumberOfElements``).
        connFile: href of the XML file holding the connectivity DataItem.
        topologyType: XDMF topology name.  Defaults to ``"Quadrilateral"``,
            the value that used to be hard-coded.  Pass ``"Triangle"`` for
            triangular cells so their 3-vertex connectivity is not read as
            4-vertex quads.
    """
    xdmfFile.write('   <Topology TopologyType="{}" NumberOfElements="{}">\n'.format(topologyType, nCells))
    xdmfFile.write('    <xi:include href="{}"/>\n'.format(connFile))
    xdmfFile.write('   </Topology>\n')

def writeGeometry(xdmfFile, nDims, nCells, nVerts, pyfrm, dataset):
    """Write a Geometry element with one co-ordinate array per dimension.

    Uses GeometryType ``X_Y`` when *nDims* is 2 and ``X_Y_Z`` otherwise,
    with the co-ordinates held in separate arrays.  Each array is a
    Function DataItem that JOINs *nVerts* HyperSlab selections (one per
    vertex, emitted by writeHyperSlab) into a 1-D array of
    ``nVerts * nCells`` values.  Values are ordered vertex-major, so
    vertex k of cell i sits at index ``k * nCells + i``.
    """
    geometryType = 'X_Y' if nDims == 2 else 'X_Y_Z'
    xdmfFile.write('   <Geometry GeometryType="{}">\n'.format(geometryType))
    # The JOIN expression ($0 ; $1 ; ...) is the same for every co-ordinate.
    joinArgs = ' ; '.join('$' + str(slot) for slot in range(nVerts))
    for axis in range(nDims):
        xdmfFile.write('    <DataItem ItemType="Function" Dimensions="{}"\n'.format(nVerts * nCells))
        xdmfFile.write('      Function="JOIN({})">\n'.format(joinArgs))
        writeHyperSlab(xdmfFile, axis, nDims, nCells, nVerts, pyfrm, dataset)
        xdmfFile.write('    </DataItem>\n')
    xdmfFile.write('   </Geometry>\n')

def writeHyperSlab(xdmfFile, coord, nDims, nCells, nVerts, pyfrm, dataset):
    """Write one HyperSlab DataItem per vertex, selecting co-ordinate *coord*.

    Each slab reads the HDF5 dataset ``pyfrm:/dataset``, declared with
    Dimensions ``nVerts nCells nDims`` (vertex, cell, co-ordinate).  It
    uses start ``(vertex, 0, coord)``, stride ``(1, 1, 1)`` and count
    ``(1, nCells, 1)``, which picks that vertex's *coord* value for every
    cell.
    """
    # Only the 'start' row differs between vertices; build the rest once.
    head = ('      <DataItem ItemType="HyperSlab"\n'
            '        Dimensions="{} 1 1"\n'
            '        Type="HyperSlab">\n'
            '        <DataItem\n'
            '         Dimensions="3 3"\n'
            '         Format="XML">\n').format(nCells)
    tail = ('         1   1   1\n'
            '         1   {} 1\n'
            '         </DataItem>\n'
            '         <DataItem\n'
            '         Name="Points" \n'
            '         Dimensions="{} {} {}"\n'
            '         Format="HDF">\n'
            '         {}:/{}\n'
            '        </DataItem>\n'
            '      </DataItem>\n').format(nCells, nVerts, nCells, nDims, pyfrm, dataset)
    for vertex in range(nVerts):
        start = '         {:<3} 0   {}\n'.format(vertex, coord)
        xdmfFile.write(head + start + tail)

def writeAttribute(xdmfFile, tag):
    """Write a grid-centred ``Partition`` Attribute holding the single value *tag*.

    *tag* is formatted with ``str.format``, and the script passes the
    partition number here.
    """
    template = ('   <Attribute Name="Partition" Center="Grid">\n'
                '    <DataItem\n'
                '     Dimensions="1"\n'
                '     Format="XML">\n'
                '     {}\n'
                '    </DataItem>\n'
                '   </Attribute>\n')
    xdmfFile.write(template.format(tag))

def writeConnectivities(connFile, nCells, nVerts, orderDict):
    """Write an XML DataItem of cell connectivities to *connFile*.

    The DataItem has shape ``nCells x nVerts``.  Row i lists, for vertex
    positions ``j = 1..nVerts``, the point index
    ``orderDict[j] * nCells + i``.  This matches the vertex-major point
    layout produced by JOINing one per-vertex slab of ``nCells`` values
    after another.

    Args:
        connFile: path of the XML file to (over)write.
        nCells: number of cells (rows).
        nVerts: vertices per cell (columns).
        orderDict: maps 1-based output vertex position to the source vertex
            slot used in the JOINed point array.
    """
    # `with` guarantees the file is closed even if a write fails.
    with open(connFile, 'w') as cf:
        cf.write('<DataItem DataType="Int"\n')
        cf.write('  Dimensions="{} {}"\n'.format(nCells, nVerts))
        cf.write('  Format="XML">\n')

        # Start of each vertex's slab in the JOINed array; loop-invariant.
        offsets = [orderDict[j] * nCells for j in range(1, nVerts + 1)]
        for cell in range(nCells):
            cf.write(' ' + ''.join(' ' + str(offset + cell) for offset in offsets) + '\n')

        cf.write('</DataItem>\n')
    print('connectivities written to ' + connFile)

def writeFooter(xdmfFile):
    """Close the <Domain> and <Xdmf> elements opened by writeHeader."""
    for closingTag in (' </Domain>\n', '</Xdmf>\n'):
        xdmfFile.write(closingTag)

parser = argparse.ArgumentParser(description="extract connectivities from mesh file")
parser.add_argument("mesh", help="mesh file (.pyfrm)", type=meshFile)
args = parser.parse_args()  # meshFile also sets the global `base` (mesh path minus extension)

# use 'h5ls' command to provide array dimensions
h5ls_output = check_output(["h5ls", args.mesh])

nquads = {}  # partition number -> cell count, quadrilateral datasets
ntris = {}   # partition number -> cell count, triangular datasets
for line in h5ls_output.decode().splitlines():  # decode once, not per test
    if 'spt' not in line:  # restrict to 'spt' arrays
        continue
    name = line.split()[0]
    npart = int(re.search(r'\d+', name).group())  # first digit run in the dataset name
    # presumably the cell count: the first number preceded by a space and
    # followed by a comma in the shape "{nverts, ncells, ndims}" — TODO confirm
    ncells = int(re.search(r' (\d+),', line).group(1))
    if 'quad' in line:    # quadrilateral cells
        nquads[npart] = ncells
    elif 'tri' in line:   # triangular cells
        ntris[npart] = ncells
    else:
        # Skip unsupported element types but keep scanning, so datasets
        # listed after this one are not silently dropped.
        print("unknown cell type: " + name)

# cell types
cellTypes = ['quad', 'tri']
numCellTypes = len(cellTypes)
nverts = {'quad': 4, 'tri': 3}  # vertices per cell
ndims = {'quad': 2, 'tri': 2}   # spatial dimensions

# XDMF:PyFR vertex numbering, indexed like cellTypes
order = [
    {1: 0, 2: 1, 3: 3, 4: 2},  # quad order
    {1: 0, 2: 1, 3: 2},        # tri order
]

# sort datasets
quadKeys = list(nquads)        # keys are partition numbers
triKeys = list(ntris)          # keys are partition numbers
allKeys = quadKeys + triKeys   # concatenate keys
numTypes = Counter(allKeys)    # number of cell types present in each partition
partKeys = sorted(numTypes)    # partition numbers, ascending for stable output
partitions = [nquads, ntris]   # cell counts per partition, indexed like cellTypes

# The .xdmf is written next to the mesh (base keeps the mesh's directory).
# Write the connectivity files beside it, because the relative XInclude hrefs
# below are resolved against the including document's location.
outDir = os.path.dirname(base)
with open(base + '.xdmf', 'w') as g:
    writeHeader(g)

    for part in partKeys:
        multiType = numTypes[part] > 1  # partition contains several cell types
        gridType = 'Collection' if multiType else 'Uniform'
        g.write('  <Grid Name="Partition{}" GridType="{}">\n'.format(part, gridType))

        for cellType, cellName in enumerate(cellTypes):
            if part not in partitions[cellType]:
                continue  # no cells of this type in this partition
            nCells = partitions[cellType][part]
            xfname = 'con_' + cellName + '_p' + str(part) + '.xml'  # href, relative to the .xdmf
            dname = 'spt_' + cellName + '_p' + str(part)
            if multiType:
                g.write('  <Grid Name="cellType{}" GridType="Uniform">\n'.format(cellType))
            # NOTE(review): triangle cells need TopologyType="Triangle"; this
            # call passes only the count and href, so confirm writeTopology
            # emits the right type for cellName.
            writeTopology(g, nCells, xfname)
            # NOTE(review): args.mesh is written verbatim as the HDF5 path. If
            # the reader resolves it relative to the .xdmf rather than the CWD,
            # meshes outside the CWD need an adjusted path — confirm.
            writeGeometry(g, ndims[cellName], nCells, nverts[cellName], args.mesh, dname)
            if multiType:
                g.write('  </Grid>\n')
            # connectivities file
            writeConnectivities(os.path.join(outDir, xfname), nCells,
                                nverts[cellName], order[cellType])

        writeAttribute(g, part)
        g.write('  </Grid>\n')

    writeFooter(g)