#!/usr/bin/env python
from __future__ import division

import os, sys

__author__ = 'Vinicius Wilian D Cruzeiro'
__email__ = 'vwcruzeiro@ufl.edu'
__version__ = '1.0'

# Import the command-line parser, turning a missing argparse on an old
# interpreter into an explicit "requires Python 2.7" message.
try:
    from argparse import ArgumentParser
except ImportError:
    if sys.version_info < (2, 7):
        raise ImportError('%s requires Python 2.7 or later' %
                          (os.path.split(sys.argv[0])[1]))
    # On a recent interpreter the failure has some other cause. Re-raise it
    # rather than swallowing it, otherwise the script would only die later
    # with a confusing NameError when ArgumentParser is first used.
    raise

# Command-line interface. Everything registered here is what the user sees
# when running the program with --help.

parser = ArgumentParser(
    usage='%(prog)s [Options]',
    epilog='''This program will generate a TCL script to be used in VMD for
                        better visualization of C(pH,E)MD trajectories.''')

# Informational flags: both rely on the 'version' action, which prints the
# given text and exits.
parser.add_argument(
    '-v', '--version', action='version',
    help='''show the program's version and exit''',
    version='%s: %s' % (parser.prog, __version__))
parser.add_argument(
    '--author', action='version',
    help='''show the program's author name and exit''',
    version='%s author: %s (E-mail: %s )' % (parser.prog, __author__, __email__))

# Boolean switch: False unless -O/--overwrite is present on the command line.
parser.add_argument(
    '-O', '--overwrite', dest='overwrite', action='store_true', default=False,
    help='''Allow existing outputs to be overwritten. Default: False''')

# Input files. argparse itself does not enforce any of these; which ones are
# needed depends on whether a CPIN and/or a CEIN file is supplied.
group = parser.add_argument_group('Required Arguments')
group.add_argument(
    '-cpin', metavar='FILE', dest='cpin', default=None,
    help='''CPIN file used as input in the C(pH,E)MD simulation. Required if CEIN file not
                   stated.''')
group.add_argument(
    '-cpout', metavar='FILE', dest='cpouts', nargs='*', default=None,
    help='''CPOUT file generated as output in the C(pH,E)MD simulation. If the simulation
                   was restarted, please provide all CPOUT files in the same order as the trajectory will be
                   entered into VMD. Required if CEOUT file(s) not stated.''')
group.add_argument(
    '-cein', metavar='FILE', dest='cein', default=None,
    help='''CEIN file used as input in the C(pH,E)MD simulation. Required if CPIN file not
                   stated.''')
group.add_argument(
    '-ceout', metavar='FILE', dest='ceouts', nargs='*', default=None,
    help='''CEOUT file generated as output in the C(pH,E)MD simulation. If the simulation
                   was restarted, please provide all CEOUT files in the same order as the trajectory will be
                   entered into VMD. Required if CPOUT file(s) not stated.''')

# Optional settings.
group = parser.add_argument_group('Non-required Arguments')
group.add_argument(
    '-file-out', metavar='FILE', dest='fileout', type=str,
    default="cphemd_script.tcl",
    help='''Name of the TCL script output file. Default: cphemd_script.tcl''')

def ErrorExit(text):
    """
    Report a fatal error and terminate the program.

    The message goes to standard error and the process ends with a non-zero
    exit status, so shells and calling scripts can tell that the run failed.

    Args:
        text: description of the problem; printed after an "ERROR:" prefix.

    Raises:
        SystemExit: always, with exit status 1.
    """

    sys.stderr.write('\nERROR: %s\n' % text)
    sys.stderr.write('       The execution of vmd_cphemd.py stopped\n')
    sys.exit(1)

# Offsets, relative to a residue's first atom, of the atoms that GetAtoms
# reports for each supported residue name and state index. The order inside
# each tuple is the order in which the atoms are returned.
_HIDDEN_ATOM_OFFSETS = {
    'AS4': {0: (9, 12, 13, 14),
            1: (12, 13, 14),
            2: (9, 13, 14),
            3: (9, 12, 14),
            4: (9, 12, 13)},
    'GL4': {0: (12, 15, 16, 17),
            1: (15, 16, 17),
            2: (12, 16, 17),
            3: (12, 15, 17),
            4: (12, 15, 16)},
    'PRN': {0: (9, 8, 10, 11),
            1: (9, 10, 11),
            2: (8, 10, 11),
            3: (9, 8, 11),
            4: (9, 8, 10)},
    'HIP': {0: (),
            1: (12,),
            2: (8,)},
    'CYS': {0: (),
            1: (7,)},
    'LYS': {0: (),
            1: (16,)},
    'TYR': {0: (),
            1: (13,)},
}


def GetAtoms(resname, startatom, state):
    """
    Return the atom indices to hide for one titratable residue in one state.

    Args:
        resname: three-letter residue name (AS4, GL4, PRN, HIP, CYS, LYS or TYR).
        startatom: index of the first atom of the residue; the returned
            indices are offsets added to this value.
        state: state index of the residue.

    Returns:
        A new list of atom indices; empty when nothing has to be hidden.

    Raises:
        ValueError: if the residue name or the state is not supported. The
            function used to fall through and return None in that case, which
            made the caller fail with an unrelated-looking TypeError.
    """
    # TYR has always treated every non-zero state alike; keep that leniency.
    if resname == 'TYR' and state != 0:
        state = 1

    try:
        offsets = _HIDDEN_ATOM_OFFSETS[resname][state]
    except (KeyError, TypeError):
        raise ValueError('Residue %s in state %s is not supported by GetAtoms'
                         % (resname, state))
    return [startatom + offset for offset in offsets]
    
def _find_all(line, keyword):
    """Yield every index of *line* at which *keyword* starts."""
    position = line.find(keyword)
    while position != -1:
        yield position
        position = line.find(keyword, position + 1)


def _collect_until(line, start, stop_char, skip_chars):
    """Return line[start:] up to (excluding) stop_char, without skip_chars."""
    end = line.find(stop_char, start)
    if end == -1:
        end = len(line)
    return ''.join(char for char in line[start:end] if char not in skip_chars)


def _parse_titration_input(path, read_first_atoms):
    """
    Extract the residue information from a CPIN or CEIN file.

    Returns a tuple (names, numbers, first_atoms). first_atoms is only filled
    when read_first_atoms is true.
    """
    names = []
    numbers = []
    first_atoms = []
    with open(path, 'r') as handle:
        for line in handle:
            for position in _find_all(line, 'Residue'):
                # The residue name is the three characters found nine columns
                # after the keyword; the residue number is whatever follows,
                # up to the next single quote.
                names.append(line[position + 9:position + 12])
                numbers.append(int(_collect_until(line, position + 12, "'", ' :\n')))
            if read_first_atoms:
                for position in _find_all(line, 'FIRST_ATOM'):
                    value = _collect_until(line, position + 10, ',', ' =')
                    # Entries equal to 0 are ignored.
                    if value != '0':
                        first_atoms.append(int(value))
    return names, numbers, first_atoms


def _read_states(paths, record_header, residue_count):
    """
    Collect the residue states stored in the full records of output files.

    A full record starts with a line beginning with record_header, followed
    by three lines without residue data and then one line per residue. The
    first full record of every file is skipped because it describes the
    initial state rather than a trajectory frame.
    """
    states = []
    # 0 while outside a full record; otherwise the number of record lines
    # consumed so far. Deliberately shared across files, as before.
    record_line = 0
    for path in paths:
        with open(path, 'r') as handle:
            initial_record = True
            for line in handle:
                if line.startswith(record_header):
                    record_line = 1
                    continue
                if record_line == 0:
                    continue
                if record_line < 4:
                    record_line += 1
                elif record_line < 4 + residue_count:
                    if not initial_record:
                        # The state is read from the single character in column 21.
                        states.append(int(line[21]))
                    record_line += 1
                else:
                    # First line after the residue lines: the record is over.
                    record_line = 0
                    initial_record = False
    return states


def _write_ph_section(outputfile, cpin_path, cpout_paths):
    """Write the TCL commands that move away the atoms returned by GetAtoms."""
    names, numbers, first_atoms = _parse_titration_input(cpin_path, True)
    if len(names) == 0:
        ErrorExit('Invalid CPIN file! No residues found.')

    residue_count = len(numbers)
    states = _read_states(cpout_paths, 'Solvent pH', residue_count)
    numberframes = int(round(len(states)/len(names)))

    # Work out every frame before writing anything, as the original did.
    frame_atoms = []
    for framenum in range(numberframes):
        atoms = []
        for resnum in range(residue_count):
            state = states[framenum * residue_count + resnum]
            atoms.extend(GetAtoms(names[resnum], first_atoms[resnum], state))
        frame_atoms.append(atoms)

    for framenum, atoms in enumerate(frame_atoms):
        if not atoms:
            continue
        selection = ''.join(str(atom) + ' ' for atom in atoms)
        outputfile.write('set asel [atomselect top "index ' + selection +
                         '" frame ' + str(framenum) + ']\n')
        # move them far away to disappear
        outputfile.write('$asel moveby {99 99 99}\n')


def _write_redox_section(outputfile, cein_path, ceout_paths):
    """Write the TCL commands that set the 'user' value of redox-active residues."""
    names, numbers, _ = _parse_titration_input(cein_path, False)
    if len(names) == 0:
        ErrorExit('Invalid CEIN file! No residues found.')

    residue_count = len(numbers)
    states = _read_states(ceout_paths, 'Redox potential', residue_count)
    numberframes = int(round(len(states)/len(names)))

    for framenum in range(numberframes):
        for resnum in range(residue_count):
            outputfile.write('set rsel [atomselect top "resid ' + str(numbers[resnum]) +
                             '" frame ' + str(framenum) + ']\n')
            # State 0 is written as 0.0, every other state as 100.0.
            if states[framenum * residue_count + resnum] == 0:
                value = '0.0'
            else:
                value = '100.0'
            outputfile.write('$rsel set user ' + value + '\n')


def main(opt):
    """
    This is the main function to execute the program.

    It validates the command-line options, then writes the TCL script: first
    the commands for the pH-active residues (when a CPIN file is given), then
    the commands for the redox-active residues (when a CEIN file is given).

    Args:
        opt: parsed options; the attributes cpin, cpouts, cein, ceouts,
            fileout and overwrite are read.

    Any invalid input is reported through ErrorExit, which ends the program.
    """

    # Checking required arguments
    if not opt.cpin and not opt.cein:
        ErrorExit('Please provide the CPIN and/or CEIN files (using the flags -cpin and -cein). For help type: vmd_cphemd.py --help')
    if opt.cpin:
        if not os.path.exists(opt.cpin):
            ErrorExit('CPIN file %s does not exist!' % (opt.cpin))
        if not opt.cpouts:
            ErrorExit('Please provide the CPOUT file(s) (using the flag -cpout). For help type: vmd_cphemd.py --help')
        else:
            for path in opt.cpouts:
                if not os.path.exists(path):
                    ErrorExit('CPOUT file %s does not exist!' % (path))
    if opt.cein:
        if not os.path.exists(opt.cein):
            ErrorExit('CEIN file %s does not exist!' % (opt.cein))
        if not opt.ceouts:
            ErrorExit('Please provide the CEOUT file(s) (using the flag -ceout). For help type: vmd_cphemd.py --help')
        else:
            for path in opt.ceouts:
                if not os.path.exists(path):
                    ErrorExit('CEOUT file %s does not exist!' % (path))
    if opt.cpin and opt.cein:
        if len(opt.cpouts) != len(opt.ceouts):
            ErrorExit('The number of CPOUT files provided do not match the number of CEOUT files provided.')

    # Opening the TCL script to be written
    if os.path.exists(opt.fileout) and not opt.overwrite:
        ErrorExit('The TCL script output file %s already exists. Use the flag -O if you want it to be overwritten' % (opt.fileout))

    # The with-block guarantees the script is flushed and closed, even when
    # one of the sections stops the program.
    with open(opt.fileout, 'w') as outputfile:
        # Writing the portion for pH-active residues of the script
        if opt.cpin:
            _write_ph_section(outputfile, opt.cpin, opt.cpouts)

        # Writing the portion for redox-active residues of the script
        if opt.cein:
            _write_redox_section(outputfile, opt.cein, opt.ceouts)

if __name__ == '__main__':
    # Script entry point: parse the command line, generate the TCL script,
    # then report success to the shell.
    main(parser.parse_args())
    sys.exit(0)
