#!/usr/bin/env python
# -*- coding: utf-8 -*-
# run within virtual environment (uses Python 3 syntax)
import argparse
import os
import re
from array import array
from collections import Counter
from subprocess import check_output
def meshFile(param):
    """argparse ``type`` callback that validates the mesh file argument.

    As a side effect the path with its extension removed is stored in the
    module-level global ``base`` (used later to name the ``.xdmf`` output);
    this happens before any validation, so it is set even when validation
    fails.

    Raises argparse.ArgumentTypeError when the extension is not ``.pyfrm``
    (case-insensitive) or when the path does not exist.  Returns *param*
    unchanged on success.
    """
    global base
    stem, suffix = os.path.splitext(param)
    base = stem
    # extension check first, then existence
    if suffix.lower() == '.pyfrm':
        if os.path.exists(param):
            return param
        raise argparse.ArgumentTypeError("{} not found".format(param))
    raise argparse.ArgumentTypeError('Mesh file must have a .pyfrm extension')
def writeHeader(xdmfFile):
    """Write the opening lines of the XDMF document to *xdmfFile*.

    NOTE(review): these literals contain no XML markup (only whitespace and
    newlines); the XML declaration / root tags appear to be missing - confirm.
    """
    for text in ('\n', '\n', '\n', ' \n'):
        xdmfFile.write(text)
    return
def writeTopology(xdmfFile, tType, nCells, connFile):
    """Write the Topology element for one group of cells.

    *tType* is the XDMF topology name, *nCells* the number of cells and
    *connFile* the connectivity file the topology refers to.

    NOTE(review): neither format string below has a ``{}`` placeholder, so
    *nCells* and *connFile* are currently dropped, and *tType* is never used;
    the XML markup seems to be missing from these literals - confirm.
    """
    opening = '    \n'.format(nCells)
    data_ref = '     \n'.format(connFile)
    closing = '    \n'
    for text in (opening, data_ref, closing):
        xdmfFile.write(text)
    return
def writeGeometry(xdmfFile, nDims, nCells, nodeArray, pyfrm, dataset):
    """Write the Geometry element for one group of cells.

    Co-ordinates are stored as separate arrays: for each of the *nDims*
    co-ordinate axes an item is written that wraps one hyperslab per vertex
    in *nodeArray* (see writeHyperSlab), all read from ``pyfrm:/dataset``.

    NOTE(review): the per-axis line is formatted with a ``$0 ; $1 ; ...``
    reference string, but the literal has no placeholder so that string is
    not written; both opening-line branches are also identical.  The XML
    markup appears to be missing from these literals - confirm.
    """
    nVerts = len(nodeArray)
    # the 2-D / 3-D opening lines are presumably meant to differ
    if nDims != 2:
        opening = '    \n'
    else:
        opening = '    \n'
    xdmfFile.write(opening)
    refs = ' ; '.join('$' + str(k) for k in range(nVerts))
    for axis in range(nDims):
        xdmfFile.write('     \n'.format(refs))
        writeHyperSlab(xdmfFile, axis, nDims, nCells,
                       nodeArray, pyfrm, dataset)
        xdmfFile.write('     \n')
    xdmfFile.write('    \n')
    return
def writeHyperSlab(xdmfFile, coord, nDims, nCells, nodeArray, pyfrm, dataset):
    """Write one HyperSlab item per vertex in *nodeArray*.

    Each hyperslab selects, from the HDF5 dataset ``pyfrm:/dataset`` (indexed
    as vertex, cell, co-ordinate), the values of a single vertex ``vert`` and
    a single co-ordinate *coord* for every one of the *nCells* cells: start
    ``(vert, 0, coord)``, stride ``(1, 1, 1)``, count ``(1, nCells, 1)``.

    *nodeArray* may be any iterable of vertex indices.  *nDims* is accepted
    for signature symmetry with writeGeometry but is not used.
    """
    # the dataset path is identical for every vertex
    source = '          {}:/{}\n'.format(pyfrm, dataset)
    for vert in nodeArray:
        xdmfFile.write(''.join((
            '       \n',
            '         \n',
            # start of the hyperslab: this vertex, first cell, this co-ordinate
            '          {:<3} 0   {}\n'.format(vert, coord),
            # stride: take every cell
            '          1   1   1\n',
            # count: one vertex, all cells, one co-ordinate
            '          1   {} 1\n'.format(nCells),
            '          \n',
            '          \n',
            source,
            '         \n',
            '       \n',
        )))
    return
def writeAttribute(xdmfFile, tag):
    """Write the Attribute element whose single value is *tag*.

    The caller passes the partition number, so each grid is tagged with the
    partition it came from.

    NOTE(review): apart from the tag line, these literals contain only
    whitespace; the XML markup appears to be missing - confirm.
    """
    pieces = [
        '    \n',
        '     \n',
        '      {}\n'.format(tag),
        '     \n',
        '    \n',
    ]
    for piece in pieces:
        xdmfFile.write(piece)
    return
def writeConnectivities(connFile, nCells, nVerts):
    """Write the connectivity table for one group of cells to *connFile*.

    Each of the *nCells* cells gets one line listing *nVerts* point indices;
    vertex ``j`` of cell ``i`` is point ``j*nCells + i``.  This presumably
    matches the vertex-major point layout built by writeGeometry (one
    hyperslab of *nCells* points per vertex) - confirm if that changes.

    The file is closed even if a write fails.  A confirmation message is
    printed on success.
    """
    with open(connFile, 'w') as cf:
        cf.write('\n')
        for cell in range(nCells):
            ids = ''.join(' ' + str(vert * nCells + cell)
                          for vert in range(nVerts))
            cf.write(' ' + ids + '\n')
        cf.write('\n')
    print('connectivities written to ' + connFile)
    return
def writeFooter(xdmfFile):
    """Write the closing lines of the XDMF document to *xdmfFile*.

    NOTE(review): these literals contain only whitespace; the closing XML
    tags appear to be missing - confirm.
    """
    for text in (' \n', '\n'):
        xdmfFile.write(text)
    return
def readH5lsOutput(os_output):
    """Count the cells of each supported type in every mesh partition.

    Parses the listing printed by ``h5ls <mesh>``.  Only lines containing
    ``spt`` are considered.  From each such line:
      - the partition number is the first integer in the first word;
      - the nodes per cell is the first integer followed by a comma;
      - the cell count is the first integer preceded by a space and followed
        by a comma.
    The cell family is identified by substring, testing 'quad', 'tri', 'hex',
    'pyr', 'pri', 'tet' in that order.

    Parameters
    ----------
    os_output : bytes or str
        Raw ``h5ls`` output as returned by ``subprocess.check_output``
        (bytes, decoded line by line); already-decoded text is accepted too.

    Returns
    -------
    list of dict
        33 dictionaries mapping partition number -> number of cells, ordered
        tet4..tet84, pri6..pri196, pyr5..pyr140, hex8..hex216, tri3..tri21,
        quad4..quad36.  This order must match the module-level
        ``firstOrderCellType`` and ``nodeIDs`` tables.

    An unsupported node count prints "unknown cell order" and that line is
    skipped.  An unrecognised cell family prints "unknown cell type" and
    stops parsing.
    """
    # supported node counts per family, first order upwards, in return order
    nodeCounts = (
        ('tet',  (4, 10, 20, 35, 56, 84)),
        ('pri',  (6, 18, 40, 75, 126, 196)),
        ('pyr',  (5, 14, 30, 55, 91, 140)),
        ('hex',  (8, 27, 64, 125, 216)),
        ('tri',  (3, 6, 10, 15, 21)),
        ('quad', (4, 9, 16, 25, 36)),
    )
    # precedence used when matching a family name against a line
    matchOrder = ('quad', 'tri', 'hex', 'pyr', 'pri', 'tet')
    cells = {(name, n): {} for name, counts in nodeCounts for n in counts}
    # raw strings: '\d' in a plain literal is an invalid escape sequence
    partRe = re.compile(r'\d+')
    nnodesRe = re.compile(r'(\d+),')
    ncellsRe = re.compile(r' (\d+),')
    for line in os_output.splitlines():
        text = line if isinstance(line, str) else line.decode()
        # restrict to 'spt' arrays
        if 'spt' not in text:
            continue
        partno = int(partRe.search(text.split()[0]).group())
        nnodes = int(nnodesRe.search(text).group(1))
        ncells = int(ncellsRe.search(text).group(1))
        family = next((name for name in matchOrder if name in text), None)
        if family is None:
            print("unknown cell type")
            break
        byPart = cells.get((family, nnodes))
        if byPart is None:
            print("unknown cell order")
        else:
            byPart[partno] = ncells
    # list of cell dictionaries
    return [cells[(name, n)] for name, counts in nodeCounts for n in counts]
# Command line: one positional argument, the .pyfrm mesh; meshFile validates
# it and records its extension-less path in the global 'base'.
parser = argparse.ArgumentParser(
    description="extract connectivities from mesh file: write xdmf file")
parser.add_argument("mesh", type=meshFile, help="mesh file (.pyfrm)")
args = parser.parse_args()
# 'h5ls' lists every dataset with its dimensions; those sizes drive the
# per-partition cell counts
h5lsOutput = check_output(["h5ls", args.mesh])
partitions = readH5lsOutput(h5lsOutput)
# First-order family of each entry in the list returned by readH5lsOutput,
# in the same order: six tet, six pri, six pyr, five hex, five tri, five quad.
firstOrderCellType = (['tet'] * 6 + ['pri'] * 6 + ['pyr'] * 6 +
                      ['hex'] * 5 + ['tri'] * 5 + ['quad'] * 5)
# XDMF topology name for each family
xdmfTopologyType = dict(quad='Quadrilateral', tri='Triangle',
                        hex='Hexahedron', pyr='Pyramid',
                        pri='Wedge', tet='Tetrahedron')
# spatial dimension of each family: surface cells are 2-D, volume cells 3-D
ndims = {family: (2 if family in ('quad', 'tri') else 3)
         for family in ('quad', 'tri', 'hex', 'pyr', 'pri', 'tet')}
# Positions of the cell vertices within each cell's PyFR solution-point
# ordering; keeping only these nodes reduces a high-order cell to a
# first-order one (see pyfr/readers/nodemaps.py).  Keys are the indices of
# the cell types in the list returned by readH5lsOutput.
# Example, for second-order pyramids having 14 solution points and 5 vertices:
# >>> from pyfr.readers.nodemaps import GmshNodeMaps
# >>> [GmshNodeMaps.to_pyfr['pyr', 14][i] for i in range(5)]
nodeIDs = dict(enumerate([
    [0, 1, 2, 3],                        # tet4
    [0, 2, 5, 9],                        # tet10
    [0, 3, 9, 19],                       # tet20
    [0, 4, 14, 34],                      # tet35
    [0, 5, 20, 55],                      # tet56
    [0, 6, 27, 83],                      # tet84
    [0, 1, 2, 3, 4, 5],                  # pri6
    [0, 2, 5, 12, 14, 17],               # pri18
    [0, 3, 9, 30, 33, 39],               # pri40
    [0, 4, 14, 60, 64, 74],              # pri75
    [0, 5, 20, 105, 110, 125],           # pri126
    [0, 6, 27, 168, 174, 195],           # pri196
    [0, 1, 3, 2, 4],                     # pyr5
    [0, 2, 8, 6, 13],                    # pyr14
    [0, 3, 15, 12, 29],                  # pyr30
    [0, 4, 24, 20, 54],                  # pyr55
    [0, 5, 35, 30, 90],                  # pyr91
    [0, 6, 48, 42, 139],                 # pyr140
    [0, 1, 3, 2, 4, 5, 7, 6],            # hex8
    [0, 2, 8, 6, 18, 20, 26, 24],        # hex27
    [0, 3, 15, 12, 48, 51, 63, 60],      # hex64
    [0, 4, 24, 20, 100, 104, 124, 120],  # hex125
    [0, 5, 35, 30, 180, 185, 215, 210],  # hex216
    [0, 1, 2],                           # tri3
    [0, 2, 5],                           # tri6
    [0, 3, 9],                           # tri10
    [0, 4, 14],                          # tri15
    [0, 5, 20],                          # tri21
    [0, 1, 3, 2],                        # quad4
    [0, 2, 8, 6],                        # quad9
    [0, 3, 15, 12],                      # quad16
    [0, 4, 24, 20],                      # quad25
    [0, 5, 35, 30],                      # quad36
]))
# number of supported cell types (one dictionary per type in 'partitions')
numCellTypes = len(partitions)
# partition numbers, in order of first appearance across the cell dictionaries
allKeys = []
for cd in partitions:
    allKeys.extend(cd)
# number of cell types present in each partition
numTypes  = Counter(allKeys)
# list of partition keys
partKeys  = list(numTypes.keys())
if partKeys:
    # Write the XDMF file next to the mesh ('base' is set by meshFile); the
    # 'with' block closes it even if a write fails.  Connectivity files are
    # written (relative to the current directory) while it is open.
    with open(base + '.xdmf', 'w') as g:
        writeHeader(g)
        for part in partKeys:
            # NOTE(review): no placeholder, so 'part' is not written - confirm
            g.write('  \n'.format(part))
            # step through all supported cell types
            for cellType, cellDict in enumerate(partitions):
                # skip cell types absent from this partition
                if part not in cellDict:
                    continue
                cellName = firstOrderCellType[cellType]
                topology = xdmfTopologyType[cellName]
                nCells = cellDict[part]
                idname = cellName + '_p' + str(part)
                xfname = 'con_' + idname + '.xml'
                dname  = 'spt_' + idname
                g.write('   \n'.format(topology, part))
                writeTopology(g, topology, nCells, xfname)
                writeGeometry(g, ndims[cellName], nCells, nodeIDs[cellType],
                              args.mesh, dname)
                writeAttribute(g, part)
                g.write('   \n')
                # connectivities file
                writeConnectivities(xfname, nCells, len(nodeIDs[cellType]))
            g.write('  \n')
        writeFooter(g)
else:
    print("no supported cell types")