'''
    SASSIE: Copyright (C) 2011 Joseph E. Curtis, Ph.D. 

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
'''
import sys
import os
import string
import numpy
import time     
import math
#import cPickle as pickle
import sassie.sasmol.sasmol as sasmol

from construct import construct_lmn

from collections import deque as Queue
import pprint,copy

#    DOCKING
#
#    1/5/14    --   fft implementation                :    hz
#
#LC     1         2         3         4         5         6         7
#LC4567890123456789012345678901234567890123456789012345678901234567890123456789
#                                                                      *      **
'''

        DOCKING performs an FFT-based docking search between two molecules and finds the best-scoring docking configurations.
    
        INPUT:
    
         pdbmol1:    PDB file for molecule 1 (held fixed)
    
         pdbmol2:    PDB file for molecule 2 (the molecule to be moved)

         outfile:    Name of file with the new coordinates for molecule 2
    
         txtOutput:    Object to send textual information back to the GUI


        OUTPUT:
    
         outfile:    A file with the aligned coordinates

         txtOutput: Object text sent to GUI
    
'''

def print_failure(message,txtOutput):
    '''
    Report a failed run to the GUI.

    Three ">>>> RUN FAILURE <<<<" banner lines are pushed to txtOutput via
    put(), followed by message itself.  Nothing is returned.
    '''
    banner = ">>>> RUN FAILURE <<<<\n"
    pieces = ("\n\n" + banner, banner, banner + "\n", message)
    for piece in pieces:
        txtOutput.put(piece)
    return

def unpack_variables(variables):
    '''
    Extract the docking run parameters from the GUI variables dictionary.

    Every entry of variables is indexed with [0] to obtain its value.  The
    values are returned as a tuple in this order:
    runname, path, pdbmol1, pdbmol2, r, d, rou, delta, eta, eta_fine,
    N, N_fine, angle_step, nMax.
    A missing key raises KeyError.
    '''
    ordered_keys = ('runname', 'path', 'pdbmol1', 'pdbmol2',
                    'r', 'd', 'rou', 'delta', 'eta', 'eta_fine',
                    'N', 'N_fine', 'angle_step', 'nMax')
    return tuple(variables[key][0] for key in ordered_keys)



def merge_Q(Q, Qtmp, mMax):
    '''
    Merge two candidate queues and keep the mMax highest-scoring entries.

    Q, Qtmp: structured arrays with a 'value' field (flattened and
             concatenated by numpy.append).
    mMax:    maximum number of entries to keep; a value <= 0 keeps nothing.

    Returns a new array sorted by ascending 'value', so the best candidate
    is the last element.  Ties on 'value' are broken by the remaining fields
    in dtype order (numpy.sort semantics for structured arrays).
    '''
    merged = numpy.sort(numpy.append(Q, Qtmp), order=['value'], kind='quicksort')
    if mMax <= 0:
        # merged[-0:] would be the whole array, not an empty one.
        return merged[:0]
    # Slicing already returns everything when fewer than mMax entries exist.
    return merged[-mMax:]



# Structured dtype shared by the coarse and fine candidate queues: the Euler
# angles (radians) of a trial orientation, the flat index of its correlation
# peak, and the correlation score at that peak.
_Q_DTYPE = [('phi', float), ('theta', float), ('psi', float),
            ('index', int), ('value', float)]


def _empty_queue(n):
    '''Return an n-entry candidate queue filled with zero-score placeholders.'''
    # NOTE(review): these placeholders (orientation 0,0,0, score 0.0) compete
    # in merge_Q; they survive whenever fewer than n real candidates score
    # above 0.0.
    return numpy.array([(0, 0, 0, 0, 0.0)] * n, dtype=_Q_DTYPE)


def _place_rotated(mol, coor, phi, theta, psi):
    '''Load a deep copy of coor into mol and Euler-rotate frame 0 by phi/theta/psi.'''
    mol.setCoor(copy.deepcopy(coor))
    mol.euler_rotate(0, phi, theta, psi)


def _correlate(mol, coor, phi, theta, psi, AopqC, delta, n, r, d, eta):
    '''
    Rotate coor (via mol) by phi/theta/psi, grid it with construct_lmn and
    return the flattened real part of its FFT correlation with AopqC (the
    already-conjugated transform of the fixed molecule's grid).
    '''
    _place_rotated(mol, coor, phi, theta, psi)
    blmn = construct_lmn(mol.coor()[0], delta, n, r, d, eta, 0)
    Bopq = numpy.fft.fftn(blmn, (n, n, n))
    return numpy.fft.ifftn(AopqC * Bopq, (n, n, n)).real.ravel()


def _index_to_shift(idx, shape, n, eta):
    '''
    Convert a flat correlation-grid index into the translation vector that is
    reported for (and applied to) the moving molecule.

    Per-axis indices above n/2 wrap to negative offsets; the offsets are then
    scaled by eta and negated.
    '''
    idx3 = numpy.array(numpy.unravel_index(idx, shape), dtype=float)
    idx3 -= n * numpy.around(idx3 / float(n))
    return -idx3 * eta


def fft_docking(path,pdbmol1,pdbmol2,outpath,outputfile,r,d,rou,delta,eta,eta_fine,N,N_fine,angle_step,nMax,txtOutput):
    '''
    Run a two-stage FFT docking search of pdbmol2 (moved) against pdbmol1
    (fixed) and write the results.

    Coarse stage: scans Euler angles phi in [0, 2*pi) and theta, psi in
    [0, pi) with step angle_step (given in degrees) on an N^3 grid.  For each
    orientation the correlation maximum is taken, and the nMax best
    orientations are kept.
    Fine stage: rescans those orientations on an N_fine^3 grid and keeps the
    nMax best (orientation, translation) pairs.

    path:             unused; NOTE(review) input PDB names are read as given
    pdbmol1, pdbmol2: input PDB file names; their last 4 characters
                      (presumably '.pdb') are stripped to build output names
    outpath:          directory receiving all output files
    outputfile:       open, writable text file receiving the run report
    r, d:             passed through to construct_lmn for both molecules
    rou, delta:       second construct_lmn argument for molecule 1 / 2
    eta, N:           coarse-grid construct_lmn eta (also the per-index
                      translation step) and grid size per axis
    eta_fine, N_fine: the same for the fine grid
    nMax:             number of candidates kept in each stage
    txtOutput:        object with put() receiving GUI text and
                      'STATUS\t<fraction>' progress strings

    Writes to outpath: copies of both input PDBs;
    complex_<mol1>_<mol2>.pdb containing the best complex; and
    complex_<mol1>_<mol2>.dcd containing all nMax fine-stage complexes,
    best first.
    '''
    import datetime

    complex_name = 'complex_'+pdbmol1[:-4]+'_'+pdbmol2[:-4]
    header = ('#Best docking score, relative orientation of '+pdbmol2+' w.r.t. '+pdbmol1+
              ' (euler rotation angle: phi, theta, psi), and relative position of center of mass of '+
              pdbmol2+' w.r.t. that of '+pdbmol1+'\n')
    score_fmt = '%.1f, (%.2f, %.2f, %.2f), (%.2f, %.2f, %.2f)\n'

    outputfile.write('#Input fixed PDB file name: '+pdbmol1+'\n')
    outputfile.write('#Input non-fixed PDB file name: '+pdbmol2+'\n')
    outputfile.write('#Output complex PDB file name: '+complex_name+'.pdb\n')
    outputfile.write('#Output complex DCD file name: '+complex_name+'.dcd\n\n')

    m1 = sasmol.SasMol(0)
    m2 = sasmol.SasMol(1)
    m1.read_pdb(pdbmol1)
    m2.read_pdb(pdbmol2)

    # Keep copies of both inputs in the output directory.  The second copy
    # previously wrote m1 under pdbmol2's name.
    m1.write_pdb(os.path.join(outpath,pdbmol1),0,"w")
    m2.write_pdb(os.path.join(outpath,pdbmol2),0,"w")

    angle_step *= numpy.pi/180.
    phis = numpy.arange(0, numpy.pi*2.0, angle_step)
    thetas = numpy.arange(0, numpy.pi*1.0, angle_step)
    psis = numpy.arange(0, numpy.pi*1.0, angle_step)

    Q_coarse = _empty_queue(nMax)
    Q_fine = _empty_queue(nMax)

    # Working copy of molecule 2; its coordinates are replaced before each use.
    m2tmp = copy.deepcopy(m2)

    # Center both molecules.  com1 is kept so the final complexes can be
    # translated back to molecule 1's original position.
    m1.calccom(0) ; com1 = m1.com()
    m2.calccom(0)
    m1.center(0)
    m2.center(0)
    m3 = sasmol.SasMol(2)
    m3.merge_two_molecules(m1,m2)
    dcdoutfile = m3.open_dcd_write(os.path.join(outpath,complex_name+'.dcd'))
    try:
        # Conjugated FFTs of molecule 1's grids, reused for every orientation.
        almn = construct_lmn(m1.coor()[0], rou, N, r, d, eta, 0)
        AopqC = numpy.fft.fftn(almn,(N,N,N)).conj()
        almn_fine = construct_lmn(m1.coor()[0], rou, N_fine, r, d, eta_fine, 0)
        AopqC_fine = numpy.fft.fftn(almn_fine,(N_fine,N_fine,N_fine)).conj()

        # Coarse Scan Stage
        msg = "\n\nCoarse scanning for the best %d docking orientations...\n"%nMax
        txtOutput.put(msg)
        outputfile.write(msg)
        # Exact number of progress steps: all coarse orientations plus one
        # fine rescan per coarse candidate.  The old n**3/4 estimate could
        # push the reported fraction above 1.
        tot_count = float(len(phis)*len(thetas)*len(psis) + len(Q_coarse))
        count = 0
        for phi in phis:
            print("\nEuler angle Phi: %6.2f\n"%phi)
            start = datetime.datetime.now()
            for theta in thetas:
                for psi in psis:
                    clmn = _correlate(m2tmp, m2.coor(), phi, theta, psi, AopqC, delta, N, r, d, eta)
                    # Only the single best translation per orientation is kept.
                    best = int(clmn.argmax())
                    Qtmp = numpy.array([(phi, theta, psi, best, clmn[best])], dtype=_Q_DTYPE)
                    Q_coarse = merge_Q(Q_coarse,Qtmp,nMax)
                    count += 1
                    txtOutput.put('STATUS\t'+str(count/tot_count))
            end = datetime.datetime.now()
            print('time used %d' % (end-start).seconds)

        # Coarse scanning results, best first
        outputfile.write(header)
        for entry in Q_coarse[::-1]:
            shift = _index_to_shift(entry['index'], almn.shape, N, eta)
            outputfile.write(score_fmt%(entry['value'],entry['phi'],entry['theta'],entry['psi'],shift[0],shift[1],shift[2]))

        # Fine Scanning Stage
        msg = "\n\nFine scanning for the best %d docking positions from the best docking orientations found by coarse scanning...\n"%nMax
        txtOutput.put(msg)
        outputfile.write(msg)
        for entry in Q_coarse:
            phi, theta, psi = entry['phi'], entry['theta'], entry['psi']
            clmn_fine = _correlate(m2tmp, m2.coor(), phi, theta, psi, AopqC_fine, delta, N_fine, r, d, eta_fine)
            # Keep the nMax best translations of this orientation.
            indices = clmn_fine.argsort()[-nMax:]
            Qtmp = numpy.array([(phi, theta, psi, idx, clmn_fine[idx]) for idx in indices], dtype=_Q_DTYPE)
            Q_fine = merge_Q(Q_fine,Qtmp,nMax)
            count += 1
            txtOutput.put('STATUS\t'+str(count/tot_count))

        # Fine scanning results, best first; each complex also goes to the DCD.
        outputfile.write(header)
        pdbfile = os.path.join(outpath,complex_name+'.pdb')
        for rank, entry in enumerate(Q_fine[::-1]):
            phi, theta, psi = entry['phi'], entry['theta'], entry['psi']
            _place_rotated(m2tmp, m2.coor(), phi, theta, psi)
            shift = _index_to_shift(entry['index'], almn_fine.shape, N_fine, eta_fine)
            m2tmp.translate(0,shift)
            m3.setCoor(numpy.concatenate((m1.coor(),m2tmp.coor()),axis=1))
            m3.translate(0,com1)
            if rank == 0:
                # Only the best-scoring complex is written as a PDB.
                m3.write_pdb(pdbfile,0,"w")
            m3.write_dcd_step(dcdoutfile,0,rank+1)
            outputfile.write(score_fmt%(entry['value'],phi,theta,psi,shift[0],shift[1],shift[2]))
    finally:
        # Close the DCD even if the scan fails part-way.
        m3.close_dcd_write(dcdoutfile)

def docking(variables,txtOutput):
    '''
    Entry point for a docking run.

    variables: GUI parameter dictionary, unpacked by unpack_variables.
    txtOutput: object with a put() method that receives text for the GUI.

    Creates <runname>/docking/ and writes the run report to output.txt
    inside it.  The search itself is delegated to fft_docking.
    '''
    runname,path,pdbmol1,pdbmol2,r,d,rou,delta,eta,eta_fine,N,N_fine,angle_step,nMax = unpack_variables(variables)

    # Drop a single trailing slash.  endswith() also copes with an empty
    # name, where runname[-1] would raise IndexError.
    if runname.endswith('/'):
        runname = runname[:-1]

    # os.makedirs creates runname and runname/docking in one call without
    # passing the user-supplied name through a shell (os.system + mkdir -p
    # was injection-prone and ignored failures).
    outpath = os.path.join(runname,'docking')
    try:
        os.makedirs(outpath)
    except OSError:
        if not os.path.isdir(outpath):
            raise

    lineintxtOutput = '=' * 60
    ttxt = time.ctime()

    # 'with' guarantees the report file is closed, even if docking fails.
    with open(os.path.join(outpath,'output.txt'),'w') as outputfile:
        txtOutput.put("\n%s \n" %(lineintxtOutput))
        txtOutput.put("DATA FROM RUN: %s \n\n" %(ttxt))
        outputfile.write("\n%s \n" %(lineintxtOutput))
        outputfile.write("DATA FROM RUN: %s \n\n" %(ttxt))

        fft_docking(path,pdbmol1,pdbmol2,outpath,outputfile,r,d,rou,delta,eta,eta_fine,N,N_fine,angle_step,nMax,txtOutput)

        txtOutput.put("\n%s \n" %(lineintxtOutput))
        outputfile.write("\n%s \n" %(lineintxtOutput))

    time.sleep(1.0)

    print('DOCKING IS DONE')
    txtOutput.put('DOCKING IS DONE')

    return
