import sys, os
import copy as cp
from pmx import *
from pmx.ndx import *
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import MCS
from rdkit.Chem import FragmentMatcher, Crippen, rdmolops

def alignOnSubset(mol1,mol2,constrMap):
    """Align mol1 (probe) onto mol2 (reference) with a constrained Crippen O3A.

    Only pairs of heavy atoms are used as constraints: any pair in which
    either atom is a hydrogen is left out of the constraint map.

    Parameters
    ----------
    mol1, mol2 : RDKit molecules; mol1 is the probe, mol2 the reference.
    constrMap : iterable of (idx1, idx2) atom-index pairs. It is no longer
        modified; a filtered copy is handed to O3A.

    Returns
    -------
    The value of ``pyO3A.Score()`` after ``Align()``.
    """
    # only heavy atoms can be constraints; build a new list instead of
    # removing items from the caller's sequence
    heavyPairs = []
    for c in constrMap:
        id1 = mol1.GetAtomWithIdx(c[0]).GetAtomicNum()
        id2 = mol2.GetAtomWithIdx(c[1]).GetAtomicNum()
        if (id1 == 1) or (id2 == 1):
            continue
        heavyPairs.append(c)

    # align: mol1=probe, mol2=ref
    mol1_crippen = crippen(mol1)
    mol2_crippen = crippen(mol2)
    pyO3A = AllChem.GetCrippenO3A(mol1,mol2,mol1_crippen,mol2_crippen,constraintMap=heavyPairs,options=0)
    pyO3A.Align()
    return(pyO3A.Score())

def getAttr(n1,n2,iStart,iEnd):
    """Translate two indices of the first mapping list into the second.

    ``n1`` and ``n2`` are parallel lists (``n1[k]`` is paired with ``n2[k]``).
    Returns ``(partner of iStart, partner of iEnd)``. An index that does not
    occur in ``n1`` yields ``None``. If an index occurs several times in
    ``n1``, the partner of its last occurrence is returned.
    """
    pairs = list(zip(n1, n2))
    pairs.reverse()  # scan from the end so that the last occurrence wins

    def partner(idx):
        for key, value in pairs:
            if key == idx:
                return value
        return None

    return (partner(iStart), partner(iEnd))

def checkChiral(mol1,mol2,n1,n2):
    """Drop mapped pairs that may swap substituents around a stereocentre.

    MCS may map atoms around chiral centres improperly. A mapped pair
    (i1, i2) is examined when:
      * both atoms carry a ``_CIPCode`` property,
      * the two codes differ, and
      * each atom has at least two non-ring neighbours.
    For such a pair, every non-ring neighbour of i1 that is in n1, and every
    non-ring neighbour of i2 that is in n2, is removed together with its
    mapped partner.

    Parameters
    ----------
    mol1, mol2 : RDKit molecules.
    n1, n2 : parallel lists of mapped atom indices (n1[k] <-> n2[k]).

    Returns
    -------
    (n1_out, n2_out) : new filtered lists; the inputs are not modified.
    """
    mapped1 = set(n1)
    mapped2 = set(n2)
    del1 = set()
    del2 = set()
    for i1,i2 in zip(n1,n2):
        a1 = mol1.GetAtomWithIdx(i1)
        a2 = mol2.GetAtomWithIdx(i2)
        # atoms without a CIP label are skipped explicitly; this replaces a
        # bare "except:" that would also have hidden any other error
        if not (a1.HasProp("_CIPCode") and a2.HasProp("_CIPCode")):
            continue
        cip1 = a1.GetProp("_CIPCode")
        cip2 = a2.GetProp("_CIPCode")
        # if stereoisomers match, then let them be
        if cip1 == cip2:
            continue
        nonRing1 = [anb for anb in a1.GetNeighbors() if not anb.IsInRing()]
        nonRing2 = [anb for anb in a2.GetNeighbors() if not anb.IsInRing()]
        # both centres need at least two non-ring substituents
        if len(nonRing1) < 2 or len(nonRing2) < 2:
            continue
        del1.update(anb.GetIdx() for anb in nonRing1 if anb.GetIdx() in mapped1)
        del2.update(anb.GetIdx() for anb in nonRing2 if anb.GetIdx() in mapped2)

    ####### remove #######
    n1_out = []
    n2_out = []
    for i1,i2 in zip(n1,n2):
        if (i1 in del1) or (i2 in del2):
            continue
        n1_out.append(i1)
        n2_out.append(i2)
    return(n1_out,n2_out)

def checkTopRemove(mol1,mol2,n1,n2,startList,endList):
    """Break unmatched mol1 bonds by dropping one of their atoms from the mapping.

    For every bond (startList[k], endList[k]) of mol1, the atom with fewer
    mapped bonded neighbours in mol1 is chosen for removal; on a tie the end
    atom is chosen. Neighbour counts use the mapping as it was on entry. The
    chosen atom and its partner are removed from n1 and n2.

    An atom selected by several bonds is removed only once. Removing it
    twice used to raise ValueError from ``list.remove``.

    Parameters
    ----------
    mol1 : RDKit molecule whose atom indices appear in n1/startList/endList.
    mol2 : unused; kept for interface symmetry.
    n1, n2 : parallel mapping lists (n1[k] <-> n2[k]).
    startList, endList : parallel lists of mol1 atom indices.

    Returns
    -------
    (n1, n2) : the input lists, modified in place.
    """
    mappedBefore = set(n1)

    def countMappedNeighbours(idx):
        # number of bonded neighbours of idx that are mapped (0 if idx is unmapped)
        if idx not in mappedBefore:
            return 0
        return sum(1 for nb in mol1.GetAtomWithIdx(idx).GetNeighbors()
                   if nb.GetIdx() in mappedBefore)

    rem1 = []
    rem2 = []
    for iStart,iEnd in zip(startList,endList):
        jStart,jEnd = getAttr(n1,n2,iStart,iEnd)
        if countMappedNeighbours(iStart) < countMappedNeighbours(iEnd):
            i, j = iStart, jStart
        else:
            i, j = iEnd, jEnd
        # skip atoms that are unmapped or already scheduled for removal
        if (j is None) or (i in rem1):
            continue
        rem1.append(i)
        rem2.append(j)
    # remove
    for i,j in zip(rem1,rem2):
        n1.remove(i)
        n2.remove(j)
    return(n1,n2)

def getList12(mol,n):
    """Return the 1-2 (directly bonded) relations among the atoms listed in ``n``.

    Each atom index in ``n`` that has at least one bonded neighbour in ``n``
    maps to the list of those neighbours, in ``GetNeighbors()`` order.
    """
    pairs12 = {}
    for atom in mol.GetAtoms():
        start = atom.GetIdx()
        if start not in n:
            continue
        for nb in atom.GetNeighbors():
            end = nb.GetIdx()
            if end in n and end != start:
                pairs12.setdefault(start, []).append(end)
    return(pairs12)

def getList13(mol,n):
    """Return the 1-3 relations (atoms two bonds apart) for the atoms in ``n``.

    Both end atoms must be in ``n``; the middle atom may be unmapped. Each
    start atom in ``n`` with at least one such partner maps to a list of
    partners. A partner reachable through several middle atoms (e.g. across
    a four-membered ring) is listed once per path.
    """
    pairs13 = {}
    for first in mol.GetAtoms():
        start = first.GetIdx()
        if start not in n:
            continue
        for second in first.GetNeighbors():
            mid = second.GetIdx()
            if mid == start:
                continue
            for third in second.GetNeighbors():
                end = third.GetIdx()
                if end not in n or end in (start, mid):
                    continue
                pairs13.setdefault(start, []).append(end)
    return(pairs13)

def getList14(mol,n):
    """Return the 1-4 relations (atoms three bonds apart) for the atoms in ``n``.

    Both end atoms must be in ``n``; the two intermediate atoms may be
    unmapped. The walk never revisits an atom already on the path. Each
    start atom in ``n`` with at least one partner maps to a list of partners,
    with one entry per path.
    """
    pairs14 = {}
    for first in mol.GetAtoms():
        start = first.GetIdx()
        if start not in n:
            continue
        for second in first.GetNeighbors():
            i2 = second.GetIdx()
            if i2 == start:
                continue
            for third in second.GetNeighbors():
                i3 = third.GetIdx()
                if i3 in (start, i2):
                    continue
                for fourth in third.GetNeighbors():
                    end = fourth.GetIdx()
                    if end not in n or end in (start, i2, i3):
                        continue
                    pairs14.setdefault(start, []).append(end)
    return(pairs14)

def findProblemsExclusions(n1,n2,dict_mol1,dict_mol2):
    """Find mapped atom pairs related in molecule 1 but not in molecule 2.

    ``dict_mol1`` and ``dict_mol2`` are relation dictionaries (start index ->
    list of related indices), as built by getList12/13/14. For every relation
    iStart-iEnd of dict_mol1 whose two atoms are both mapped, the partners
    (jStart, jEnd) in molecule 2 are reported when:
      * jStart has relations in dict_mol2 but none with jEnd, or
      * neither jStart nor jEnd has any relation in dict_mol2; a warning is
        printed in this case.
    In the first case the pair is skipped if both atoms already appear among
    the reported indices.

    NOTE(review): when jStart has no relations but jEnd does, nothing is
    reported.

    Returns (rem_start, rem_end): parallel lists of molecule-2 indices.
    """
    rem_start = []
    rem_end = []
    for iStart, partners in dict_mol1.items():
        for iEnd in partners:
            jStart, jEnd = getAttr(n1, n2, iStart, iEnd)
            if jStart is None or jEnd is None:
                # mapped to a dummy, thus no worries
                continue
            if jStart in dict_mol2:
                if jEnd in dict_mol2[jStart]:
                    continue
                flagged = rem_start + rem_end
                # maybe entry already exists
                if jStart in flagged and jEnd in flagged:
                    continue
            elif jEnd in dict_mol2:
                continue
            else:
                # a weird situation that shouldn't happen
                print("Warning: something wrong in the 1-2, 1-3 or 1-4 lists. Trying to proceed with the warning...")
            rem_start.append(jStart)
            rem_end.append(jEnd)
    return(rem_start,rem_end)

def fixProblemsExclusions(mol1,mol2,n1,n2,startList,endList):
    """Resolve relation mismatches by dropping one atom of each flagged pair.

    For every flagged pair (startList[k], endList[k]) of mol1 atoms, the atom
    with fewer mapped bonded neighbours in mol1 is chosen for removal; on a
    tie the end atom is chosen. Neighbour counts use the mapping as it was
    on entry. The chosen atom and its partner are removed from n1 and n2.

    An atom selected by several pairs is removed only once. This can happen
    with pairs such as (A, B) and (A, C); removing A twice used to raise
    ValueError from ``list.remove``.

    Parameters
    ----------
    mol1 : RDKit molecule whose atom indices appear in n1/startList/endList.
    mol2 : unused; kept for interface symmetry.
    n1, n2 : parallel mapping lists (n1[k] <-> n2[k]).
    startList, endList : parallel lists of mol1 atom indices.

    Returns
    -------
    (n1, n2) : the input lists, modified in place.
    """
    mappedBefore = set(n1)

    def countMappedNeighbours(idx):
        # number of bonded neighbours of idx that are mapped (0 if idx is unmapped)
        if idx not in mappedBefore:
            return 0
        return sum(1 for nb in mol1.GetAtomWithIdx(idx).GetNeighbors()
                   if nb.GetIdx() in mappedBefore)

    rem1 = []
    rem2 = []
    for iStart,iEnd in zip(startList,endList):
        jStart,jEnd = getAttr(n1,n2,iStart,iEnd)
        if countMappedNeighbours(iStart) < countMappedNeighbours(iEnd):
            i, j = iStart, jStart
        else:
            i, j = iEnd, jEnd
        # skip atoms that are unmapped or already scheduled for removal
        if (j is None) or (i in rem1):
            continue
        rem1.append(i)
        rem2.append(j)
    # remove
    for i,j in zip(rem1,rem2):
        n1.remove(i)
        n2.remove(j)
    return(n1,n2)

def checkTop(mol1,mol2,n1,n2):
    """Prune a mapping so that 1-2, 1-3 and 1-4 relations agree in both molecules.

    For each relation order (1-2, then 1-3, then 1-4):
      1) build the relation lists of both molecules from the current mapping,
      2) find mapped pairs related in one molecule but not the other,
      3) fix them by discarding the atom with fewer mapped neighbours.
    Mismatches are fixed first from mol1's side (removing mol2 atoms) and then
    from mol2's side. Both passes use the relation lists built at the start of
    that stage. Finally the mapping is passed through ``disconnectedMCS``.

    Returns the pruned (n1, n2).
    """
    for buildRelations in (getList12, getList13, getList14):
        rel1 = buildRelations(mol1,n1)
        rel2 = buildRelations(mol2,n2)
        # relations of mol1 missing in mol2 -> indices of mol2
        badStart,badEnd = findProblemsExclusions(n1,n2,rel1,rel2)
        n2,n1 = fixProblemsExclusions(mol2,mol1,n2,n1,badStart,badEnd)
        # relations of mol2 missing in mol1 -> indices of mol1
        badStart,badEnd = findProblemsExclusions(n2,n1,rel2,rel1)
        n1,n2 = fixProblemsExclusions(mol1,mol2,n1,n2,badStart,badEnd)

    # treat disconnected
    n1,n2 = disconnectedMCS(mol1,mol2,n1,n2)
    return(n1,n2)

def disconnectedFragments(mol1,mol2,n1,n2):
    """Report whether the two mapped substructures are split into fragments.

    NOTE(review): this routine is unfinished. Selecting the largest matching
    fragment (the draft planned to match fragments by RMSD) is not
    implemented, so the mapping is always returned unchanged. The previous
    draft printed debugging output and called ``sys.exit(0)`` whenever both
    substructures were fragmented, which terminated the whole program with a
    success status. It now prints a warning instead.

    Returns (n1, n2) unchanged.
    """
    subMol1 = subMolByIndex(mol1,n1)
    subMol2 = subMolByIndex(mol2,n2)
    fragId1 = Chem.rdmolops.GetMolFrags(subMol1)
    fragId2 = Chem.rdmolops.GetMolFrags(subMol2)
    if len(fragId1) > 1 and len(fragId2) > 1:
        print("WARNING: the mapping contains disconnected fragments; fragment selection is not implemented.")
    return(n1,n2)


def checkMCSTop(mol1,mol2,n1,n2):
    """Remove mapped atoms that take part in a bond present in only one molecule.

    If two mapped atoms are bonded in one molecule but their partners are not
    bonded in the other, one atom of that bond is discarded from the mapping
    by ``checkTopRemove`` (the one with fewer mapped neighbours). mol1's bonds
    are checked first, then mol2's.

    n1 and n2 are pruned in place by ``checkTopRemove``. The function returns
    True; the pruned mapping is in the caller's n1/n2 lists.
    """
    # mol1 bonds whose mapped partners are not bonded in mol2
    startList = []
    endList = []
    for b1 in mol1.GetBonds():
        iStart = b1.GetBeginAtomIdx()
        iEnd = b1.GetEndAtomIdx()
        jStart,jEnd = getAttr(n1,n2,iStart,iEnd)
        if (jStart is None) or (jEnd is None): # dummies
            continue
        if mol2.GetBondBetweenAtoms(jStart,jEnd) is None:
            startList.append(iStart)
            endList.append(iEnd)
    n1,n2 = checkTopRemove(mol1,mol2,n1,n2,startList,endList)

    # mol2 bonds whose mapped partners are not bonded in mol1
    startList = []
    endList = []
    for b2 in mol2.GetBonds():
        iStart = b2.GetBeginAtomIdx()
        iEnd = b2.GetEndAtomIdx()
        jStart,jEnd = getAttr(n2,n1,iStart,iEnd)
        if (jStart is None) or (jEnd is None): # dummies
            continue
        if mol1.GetBondBetweenAtoms(jStart,jEnd) is None:
            startList.append(iStart)
            endList.append(iEnd)
    # checkTopRemove returns its lists in argument order, i.e. (n2, n1) here
    n2,n1 = checkTopRemove(mol2,mol1,n2,n1,startList,endList)

    return(True)

def writeFormatPDB(fname,m,title="",nr=1):
    """Write the atoms of model ``m`` to ``fname`` with simplified atom names.

    Atom names are rewritten as follows:
      * names containing 'CL', 'Cl' or 'cl' become "CL  ";
      * other names containing 'BR', 'Br' or 'br' become "BR  ";
      * other names longer than four characters are truncated to four.
    Renamed atoms are written from deep copies, so ``m`` itself is not
    modified. Each atom is written as ``str(atom)`` on its own line, and the
    file ends with an 'ENDMDL' line. The file is closed even if writing fails.

    ``title`` and ``nr`` are unused (kept for interface compatibility).
    """
    with open(fname,'w') as fp:
        for atom in m.atoms:
            name = atom.name
            if 'CL' in name or 'Cl' in name or 'cl' in name: # chlorine
                newName = "CL"+"  "
            elif 'BR' in name or 'Br' in name or 'br' in name: # bromine
                newName = "BR"+"  "
            elif len(name) > 4: # too long atom name
                newName = name[:4]
            else:
                newName = None

            if newName is None:
                fp.write(str(atom) + "\n")
            else:
                # copy only the atoms that are renamed
                renamed = cp.deepcopy(atom)
                renamed.name = newName
                fp.write(str(renamed) + "\n")
        fp.write("ENDMDL\n")

def reformatPDB(filename,num):
    """Rewrite a PDB file to "tempFormat<num>.pdb" and record atom renamings.

    Atoms whose name contains 'EP' are renamed HSH1, HSH2, ... in file order.
    The model is then written with writeFormatPDB.

    Returns (new file name, dict mapping each atom's name after the EP
    renaming to its original name).
    """
    newname = "tempFormat"+str(num)+".pdb"
    model = Model().read(filename)

    atomNameDict = {}
    nextHole = 1
    for atom in model.atoms:
        originalName = atom.name
        if 'EP' in originalName:
            atom.name = 'HSH'+str(nextHole)
            nextHole += 1
        atomNameDict[atom.name] = originalName

    writeFormatPDB(newname,model)
    return(newname,atomNameDict)

def restoreAtomNames(mol,atomNameDict):
    """Rename atoms of ``mol`` in place using ``atomNameDict`` (current -> old).

    Presumably the dict returned by reformatPDB. Atoms whose residue-info
    name is not a key are left untouched.
    """
    for atom in mol.GetAtoms():
        info = atom.GetMonomerInfo()
        current = info.GetName()
        if current not in atomNameDict:
            continue
        info.SetName(atomNameDict[current])

def write_pairs(n1,n2,pairsFilename):
    """Write the atom mapping to ``pairsFilename``, one tab-separated pair per line.

    Indices are written 1-based (the RDKit indices are 0-based). The file is
    closed even if writing fails.
    """
    with open(pairsFilename,"w") as fp:
        for i1,i2 in zip(n1,n2):
            fp.write("%s\t%s\n" % (i1 + 1, i2 + 1))

def calcScore(mol1,mol2,n1,n2,bH2H,bH2heavy):
    """Return the mapping dissimilarity 1 - overlap / (total - overlap).

    overlap is the mean mapping size (len(n1)+len(n2))/2. total counts all
    atoms of both molecules when hydrogens take part in the mapping (bH2H or
    bH2heavy equals True), otherwise only heavy atoms. 0.0 means every
    counted atom is mapped.
    """
    overlap = 0.5*(len(n1)+len(n2))
    withHydrogens = (bH2H == True) or (bH2heavy == True)
    if withHydrogens: # consider hydrogens
        total = mol1.GetNumAtoms() + mol2.GetNumAtoms()
    else: # no hydrogens
        total = mol1.GetNumHeavyAtoms() + mol2.GetNumHeavyAtoms()
    return(1.0 - overlap/(total - overlap))

def distance_based(mol1, mol2, d, id1=None, id2=None, calcOnly=False):
    """Pair atoms of two aligned molecules by spatial proximity.

    Coordinates come from each molecule's conformer and are treated as
    Angstroms (PDB input). ``d`` is multiplied by 10 before comparing, so it
    is presumably given in nm.

    Modes
    -----
    calcOnly=True : return the summed distance (x0.1, i.e. nm) between the
        already paired atoms id1[k] of mol1 and id2[k] of mol2.
    id1 or id2 is None : pair every atom of mol1 with the closest atom of
        mol2 that lies closer than the cutoff.
    otherwise : pair every atom in id1 with the closest atom in id2 that
        lies closer than the cutoff.

    The pairing is greedy per mol1 atom, so one mol2 atom may be chosen for
    several mol1 atoms. The pairing modes return (pairs1, pairs2).
    """
    c1 = mol1.GetConformer()
    c2 = mol2.GetConformer()

    # to choose one MCS out of many
    if calcOnly == True:
        dist = 0.0
        for ind1,ind2 in zip(id1,id2):
            pos1 = c1.GetAtomPosition(ind1)
            # mol2 indices must be looked up in mol2's conformer
            # (previously mol1's conformer was used by mistake)
            pos2 = c2.GetAtomPosition(ind2)
            dist = dist + 0.1*pos1.Distance(pos2) # Angstroms in pdb files
        return(dist)

    if (id1 is None) or (id2 is None):
        # o3a: consider every atom of both molecules
        cand1 = [a.GetIdx() for a in mol1.GetAtoms()]
        cand2 = [a.GetIdx() for a in mol2.GetAtoms()]
    else:
        # mcs: consider only the supplied atoms
        cand1 = id1
        cand2 = id2

    cutoff = d*10.0 # Angstroms in pdb files
    pairs1 = []
    pairs2 = []
    for ind1 in cand1:
        pos1 = c1.GetAtomPosition(ind1)
        dd = cutoff
        keep2 = None
        for ind2 in cand2:
            dist = pos1.Distance(c2.GetAtomPosition(ind2))
            if dist < dd:
                dd = dist
                keep2 = ind2
        if keep2 is not None:
            pairs1.append(ind1)
            pairs2.append(keep2)
    return(pairs1,pairs2)

def chargesTypesMMFF(mol):
    """Return RDKit's MMFF94 molecule-properties object for ``mol``."""
    return AllChem.MMFFGetMoleculeProperties(mol, mmffVariant='MMFF94')

def crippen(mol):
    """Return per-atom Wildman-Crippen contributions of ``mol``.

    Thin wrapper around RDKit's underscore-prefixed
    ``Crippen._GetAtomContribs``. The result is passed to
    ``GetCrippenO3A`` elsewhere in this file.
    """
    return Crippen._GetAtomContribs(mol)

def o3a_alignment(mol1, mol2, bH2H, bH2heavy, bRingsOnly, bCrippen=True, d=0.05):
    """Map atoms of mol1 onto mol2 after a Crippen O3A alignment.

    mol1 (probe) is aligned onto mol2 (reference). With bRingsOnly, the
    ring-only sub-molecules are aligned first and the full mol1 is then
    aligned onto its aligned ring-only copy. Atoms are paired by distance
    (cutoff ``d``), and the pairs are then filtered by the hydrogen,
    triple-bond, ring, connectivity and 1-2/1-3/1-4 topology rules.

    Returns (n1, n2, pyO3A): mapped indices and the O3A object of the last
    alignment.

    Raises ValueError if bCrippen is not True. No other alignment is
    implemented; previously this case failed with UnboundLocalError only
    after all the mapping work was done.
    """
    if bCrippen != True:
        raise ValueError("o3a_alignment: only Crippen-based O3A alignment is implemented")
####################################
# prepare molecules and parameters #
####################################
    if bRingsOnly == True:
        submol1 = subMolRing(mol1)
        submol2 = subMolRing(mol2)
###################
#### now align ####
###################
    print("Trying Wildman-Crippen based O3A alignment")
    print("S. A. Wildman and G. M. Crippen JCICS _39_ 868-873 (1999)\n")
    mol1_crippen = crippen(mol1)
    if bRingsOnly == True:
        submol1_crippen = crippen(submol1)
        submol2_crippen = crippen(submol2)
        # align the ring-only probe onto the ring-only reference ...
        pyO3A = AllChem.GetCrippenO3A(submol1,submol2,submol1_crippen,submol2_crippen,options=0)
        pyO3A.Align()
        # ... then the full probe onto the aligned ring-only probe
        pyO3A = AllChem.GetCrippenO3A(mol1,submol1,mol1_crippen,submol1_crippen,options=0)
        pyO3A.Align()
    else:
        mol2_crippen = crippen(mol2)
        pyO3A = AllChem.GetCrippenO3A(mol1,mol2,mol1_crippen,mol2_crippen,options=0) # mol1=probe, mol2=ref
        pyO3A.Align()
    # distances
    n1,n2 = distance_based(mol1,mol2,d)
    # hydrogen rule
    n1,n2 = removeH(mol1,mol2,n1,n2,bH2H,bH2heavy)
    # triple bond rule: for simulation stability do not allow morphing an atom that is involved in a triple bond
    # into an atom that is involved in a non-triple bond (and vice versa)
    n1,n2 = tripleBond(mol1,mol2,n1,n2)
    # rings
    n1,n2 = matchRings(mol1,mol2,n1,n2)
    # treat disconnected
    n1,n2 = disconnectedMCS(mol1,mol2,n1,n2)
    # n1 and n2 are not sorted by pairs at this point
    n1,n2 = sortInd(mol1,mol2,n1,n2)
    # checking possible issues with the 1-2, 1-3 and 1-4 interactions
    n1,n2 = checkTop(mol1,mol2,n1,n2)

    return(n1,n2,pyO3A)

# checking that a non-ring atom would not be morphed into a ring atom
def matchRings(mol1,mol2,nfoo,nbar):
    """Drop pairs that would morph a ring atom into a non-ring atom or back.

    The remaining pairs are filtered further by ``oneAtomInRing``.
    Returns new lists (n1, n2).
    """
    sameRingState = [(i, j) for i, j in zip(nfoo, nbar)
                     if mol1.GetAtomWithIdx(i).IsInRing() == mol2.GetAtomWithIdx(j).IsInRing()]
    keep1 = [i for i, _ in sameRingState]
    keep2 = [j for _, j in sameRingState]

    # only one atom morphed in a ring will not work, check for that
    n1,n2 = oneAtomInRing(mol1,mol2,keep1,keep2)
    return(n1,n2)

def oneAtomInRing(mol1,mol2,n1,n2):
    """Drop ring-atom pairs that have no mapped ring neighbours.

    A pair (i, j) in which both atoms are ring atoms is kept only if i has a
    bonded ring-atom neighbour listed in n1 and j has a bonded ring-atom
    neighbour listed in n2. The two neighbours need not be mapped onto each
    other. Pairs in which at least one atom is not a ring atom are kept.
    Returns new lists (newn1, newn2).
    """
    inN1 = set(n1)
    inN2 = set(n2)

    def hasMappedRingNeighbour(atom, mapped):
        return any(nb.IsInRing() and nb.GetIdx() in mapped for nb in atom.GetNeighbors())

    newn1 = []
    newn2 = []
    for i,j in zip(n1,n2):
        a1 = mol1.GetAtomWithIdx(i)
        a2 = mol2.GetAtomWithIdx(j)
        bothInRing = (a1.IsInRing() == True) and (a2.IsInRing() == True)
        if bothInRing and not (hasMappedRingNeighbour(a1, inN1) and hasMappedRingNeighbour(a2, inN2)):
            continue
        newn1.append(i)
        newn2.append(j)
    return(newn1,newn2)

def carbonize(mol, bH2heavy=False, bRingsOnly=None):
    """Change atomic numbers of ``mol`` in place for element-blind matching.

    Heavy atoms become carbon (6). Hydrogens become carbon only when
    bH2heavy == True. If bRingsOnly is not None, its value is used as the
    atomic number of every non-ring atom, overriding the above. mcs passes
    different values (42, 43) for the two molecules, presumably so that
    non-ring atoms never match each other.
    """
    for atom in mol.GetAtoms():
        isHydrogen = atom.GetAtomicNum() == 1
        if (not isHydrogen) or (bH2heavy == True):
            atom.SetAtomicNum(6)
        if (bRingsOnly is not None) and (atom.IsInRing() == False):
            atom.SetAtomicNum(bRingsOnly)

def carbonizeCrippen(mol, bH2heavy=False, bRingsOnly=None):
    """Build Crippen-style contributions mirroring the atom types set by ``carbonize``.

    For each atom, in order:
      * non-ring atom while bRingsOnly is not None -> (-bRingsOnly, bRingsOnly);
      * otherwise heavy atom, or hydrogen with bH2heavy == True ->
        (0.1441, 2.503) (presumably a carbon contribution).
    The marker entry now replaces the carbon entry instead of being appended
    after it, so a heavy non-ring atom no longer produces two entries.

    NOTE(review): when bRingsOnly is None and bH2heavy is not True,
    hydrogens get no entry, so the list is shorter than the atom count.
    """
    contribs = []
    for atom in mol.GetAtoms():
        if (bRingsOnly is not None) and (atom.IsInRing() == False):
            contribs.append((bRingsOnly*(-1), bRingsOnly))
        elif (atom.GetAtomicNum() != 1) or (bH2heavy == True):
            contribs.append((0.1441, 2.503))
    return(contribs)

def getBondLength(mol,id1,id2):
    """Return the distance between atoms id1 and id2 in the conformer of ``mol``.

    The distance is in the conformer's coordinate units.
    """
    positions = mol.GetConformer()
    return positions.GetAtomPosition(id1).Distance(positions.GetAtomPosition(id2))

def isTriple(mol,a):
    """Guess from connectivity and geometry whether atom ``a`` is in a triple bond.

    Only C and N are analysed (S is not considered, as somewhat exotic):
      * N with exactly one neighbour -> True;
      * C with exactly two neighbours -> True if one neighbour is an N with
        one neighbour, or a C with two neighbours whose distance to ``a`` in
        the conformer of ``mol`` is below 1.25 (Angstroms).
    Everything else -> False.
    """
    element = a.GetAtomicNum()
    if element == 7:
        return len(a.GetNeighbors()) == 1
    if element != 6 or len(a.GetNeighbors()) != 2:
        return False
    for nb in a.GetNeighbors():
        nbDegree = len(nb.GetNeighbors())
        if nb.GetAtomicNum() == 7 and nbDegree == 1:
            return True
        # bond order is not trusted here; a short C-C distance is used instead
        if nb.GetAtomicNum() == 6 and nbDegree == 2 and getBondLength(mol,a.GetIdx(),nb.GetIdx()) < 1.25:
            return True
    return False

def tripleBond(mol1,mol2,nfoo,nbar):
    """Keep only pairs whose atoms agree on being in a triple bond (per ``isTriple``).

    For simulation stability, an atom involved in a triple bond is not morphed
    into one that is not (and vice versa). Returns new lists (n1, n2).
    """
    keep1 = []
    keep2 = []
    for i, j in zip(nfoo, nbar):
        triple1 = isTriple(mol1, mol1.GetAtomWithIdx(i))
        triple2 = isTriple(mol2, mol2.GetAtomWithIdx(j))
        if triple1 != triple2:
            continue
        keep1.append(i)
        keep2.append(j)
    return(keep1,keep2)

def removeH(mol1,mol2,nfoo,nbar,bH2H,bH2heavy):
    """Filter hydrogen pairs out of a mapping.

    An H->H pair is dropped when bH2H == False. An H<->heavy pair (exactly
    one hydrogen) is dropped when bH2heavy == False. Returns new lists
    (n1, n2).
    """
    def isHydrogen(mol, idx):
        return mol.GetAtomWithIdx(idx).GetAtomicNum() == 1

    kept = []
    for i, j in zip(nfoo, nbar):
        h1 = isHydrogen(mol1, i)
        h2 = isHydrogen(mol2, j)
        if h1 and h2:
            if bH2H == False:
                continue
        elif h1 != h2:
            if bH2heavy == False:
                continue
        kept.append((i, j))
    return([i for i, _ in kept], [j for _, j in kept])

def subMolByIndex(mol,ind):
    """Return a new molecule that keeps only the atoms of ``mol`` listed in ``ind``.

    ``mol`` itself is not modified. The copy is made with ``cp.deepcopy``
    (the module is imported as ``cp``; the bare name ``copy`` is not bound
    by this file's imports).
    """
    keep = set(ind)
    editMol = Chem.EditableMol(cp.deepcopy(mol))
    # remove from the highest index down so that pending indices stay valid
    indRm = sorted((a.GetIdx() for a in mol.GetAtoms() if a.GetIdx() not in keep), reverse=True)
    for i in indRm:
        editMol.RemoveAtom(i)
    return(editMol.GetMol())

def subMolRing(mol):
    """Return a new molecule that keeps only the ring atoms of ``mol``.

    ``mol`` itself is not modified. The copy is made with ``cp.deepcopy``
    (the module is imported as ``cp``; the bare name ``copy`` is not bound
    by this file's imports).
    """
    editMol = Chem.EditableMol(cp.deepcopy(mol))
    # remove from the highest index down so that pending indices stay valid
    indRm = sorted((a.GetIdx() for a in mol.GetAtoms() if a.IsInRing() == False), reverse=True)
    for i in indRm:
        editMol.RemoveAtom(i)
    return(editMol.GetMol())

def calcRMSD(mol1,mol2,ind1,ind2):
    """Return sqrt of the summed squared nearest-neighbour distances.

    For each atom in ind1 (mol1), the closest atom among ind2 (mol2) is
    found, and its squared distance is accumulated. If ind2 is empty, each
    term uses 999.999.

    NOTE(review): the sum is not divided by the number of atoms, so this is
    a root-sum-square rather than a true RMSD. It only ranks candidates
    correctly when they have the same size.
    """
    # explicit import: sqrt was previously only available via a star import
    from math import sqrt
    c1 = mol1.GetConformer()
    c2 = mol2.GetConformer()
    total = 0.0
    for id1 in ind1:
        pos1 = c1.GetAtomPosition(id1)
        nearest = 999.999
        for id2 in ind2:
            dist = pos1.Distance(c2.GetAtomPosition(id2))
            if dist < nearest:
                nearest = dist
        total += nearest**2
    return(sqrt(total))

def matchIDbyRMSD(subMol,ind,mol):
    """Translate atom indices of ``subMol`` into indices of ``mol`` by position.

    For every index in ``ind``, the atom of ``mol`` closest in space to that
    subMol atom is returned (the first one on ties). The result is -1 if no
    atom lies closer than 999.999.
    """
    subConf = subMol.GetConformer()
    conf = mol.GetConformer()
    allIdx = [atom.GetIdx() for atom in mol.GetAtoms()]

    def closest(probe):
        bestIdx, bestDist = -1, 999.999
        for idx in allIdx:
            dist = probe.Distance(conf.GetAtomPosition(idx))
            if dist < bestDist:
                bestIdx, bestDist = idx, dist
        return bestIdx

    return([closest(subConf.GetAtomPosition(i)) for i in ind])


def disconnectedMCS(mol1,mol2,ind1,ind2):
    """Restrict a mapping to one connected common substructure.

    The mapped atoms of each molecule are cut out as sub-molecules, and their
    MCS is searched (any atom/bond type, complete rings only). Among all
    matches of the MCS pattern, the pair with the smallest ``calcRMSD`` is
    chosen, and its atoms are translated back to mol1/mol2 indices by
    position.

    The input mapping is returned unchanged, with a warning when it has more
    than one atom, if:
      * the MCS pattern cannot initialise the FragmentMatcher, or
      * the pattern matches nothing. Previously this left the result
        variables unbound.
    """
    subMol1 = subMolByIndex(mol1,ind1)
    subMol2 = subMolByIndex(mol2,ind2)
    res = MCS.FindMCS([subMol1,subMol2],ringMatchesRingOnly=True, completeRingsOnly=True, atomCompare='any', bondCompare='any')
    pp = Chem.FragmentMatcher.FragmentMatcher()
    warning = "WARNING: the mapping may (but not necessarily) contain disconnected fragments. Proceed with caution."
    try:
        pp.Init(res.smarts)
    except Exception:
        # presumably no usable common substructure pattern was found
        if len(ind1)>1:
            print(warning)
        return(ind1,ind2)
    n1_list = pp.GetMatches(subMol1)
    n2_list = pp.GetMatches(subMol2)
    if not n1_list or not n2_list:
        if len(ind1)>1:
            print(warning)
        return(ind1,ind2)

    # find out which of the generated list pairs has the smallest rmsd
    best = None
    minRMSD = 999999.99
    for nl1 in n1_list:
        for nl2 in n2_list:
            rmsd = calcRMSD(subMol1,subMol2,nl1,nl2)
            if best is None or rmsd < minRMSD:
                minRMSD = rmsd
                best = (nl1, nl2)
    n1, n2 = best

    # match indices n1,n2 to the original molecule ind1,ind2
    n1_orig = matchIDbyRMSD(subMol1,n1,mol1)
    n2_orig = matchIDbyRMSD(subMol2,n2,mol2)
    return(n1_orig,n2_orig)

def sortInd(mol1,mol2,ind1,ind2):
    """Re-pair mapped atoms by proximity.

    For every atom in ind1 (order kept), the spatially closest atom among
    ind2 becomes its partner. The partner is -1 if no atom is closer than
    9999.999. Several mol1 atoms may receive the same partner.

    Returns (copy of ind1 as a list, partners).
    """
    conf1 = mol1.GetConformer()
    conf2 = mol2.GetConformer()
    partners = []
    for id1 in ind1:
        here = conf1.GetAtomPosition(id1)
        best, bestDist = -1, 9999.999
        for id2 in ind2:
            dist = here.Distance(conf2.GetAtomPosition(id2))
            if dist < bestDist:
                best, bestDist = id2, dist
        partners.append(best)
    return(list(ind1), partners)

def disconnected(mol,ind):
    """Split the atoms of ``mol`` listed in ``ind`` into connected fragments.

    Returns a list of fragments, each a list of atom indices of ``mol``.
    ``mol`` itself is not modified (a ``cp.deepcopy`` is edited).
    """
    keep = set(ind)
    # atoms surviving the deletion, in increasing original index; RDKit
    # renumbers remaining atoms consecutively in this order
    kept = [a.GetIdx() for a in mol.GetAtoms() if a.GetIdx() in keep]
    editMol = Chem.EditableMol(cp.deepcopy(mol))
    # remove from the highest index down so that pending indices stay valid
    indRm = sorted((a.GetIdx() for a in mol.GetAtoms() if a.GetIdx() not in keep), reverse=True)
    for i in indRm:
        editMol.RemoveAtom(i)
    copyMol = editMol.GetMol()
    # get the disconnected fragments (indices of the reduced molecule)
    indLists = Chem.GetMolFrags(copyMol,asMols=False,sanitizeFrags=False)
    # map back through ``kept``: going through ``ind`` directly was only
    # correct when ``ind`` happened to be sorted and free of duplicates
    return([[kept[i] for i in frag] for frag in indLists])

def getLargestList(lists):
    """Return the length of the longest sequence in ``lists`` (0 if there is none)."""
    return max([0] + [len(seq) for seq in lists])

def genFilename(filename,counter):
    """Return ``filename`` for counter 0, else insert "_<counter>" before the extension."""
    if counter == 0:
        return filename
    stem, ext = os.path.splitext(filename)
    return stem + "_" + str(counter) + ext

def incrementByOne(foo):
    """Return a new list with every element of ``foo`` increased by 1."""
    return [value + 1 for value in foo]

def mcsHremove(mol1,mol2,n1_list,n2_list,bH2H,bH2heavy):
    """Apply ``removeH`` to every combination of a mol1 match and a mol2 match.

    The result pairs every element of n1_list with every element of n2_list
    (n1_list-major order). The two returned lists therefore correspond
    one-to-one.
    """
    filtered = [removeH(mol1,mol2,nfoo,nbar,bH2H,bH2heavy)
                for nfoo in n1_list for nbar in n2_list]
    n1 = [pair[0] for pair in filtered]
    n2 = [pair[1] for pair in filtered]
    return(n1,n2)

def mcsDist(mol1,mol2,n1_list,n2_list,d):
    """Refine candidate MCS mappings by alignment and distance; keep the largest.

    For each candidate (nfoo, nbar):
      * mol1 is aligned onto mol2 via ``alignOnSubset`` using the candidate's
        pairs as O3A constraints;
      * the pairs are re-derived by ``distance_based`` with cutoff ``d``.
    Each result is then passed through ``matchRings`` and ``disconnectedMCS``.
    Only candidates whose two lists both reach the largest length among all
    results are returned.
    """
    aligned1 = []
    aligned2 = []
    for nfoo,nbar in zip(n1_list,n2_list):
        # o3a alignment may work better than rdMolAlign.AlignMol here;
        # the returned score is not used
        alignOnSubset(mol1,mol2,list(zip(nfoo,nbar)))
        x,y = distance_based(mol1,mol2,d,nfoo,nbar)
        aligned1.append(x)
        aligned2.append(y)

    # match rings and remove disconnected
    nn1 = []
    nn2 = []
    for nfoo,nbar in zip(aligned1,aligned2):
        x,y = matchRings(mol1,mol2,nfoo,nbar)
        x,y = disconnectedMCS(mol1,mol2,x,y)
        nn1.append(x)
        nn2.append(y)
    maxMCS = max([0] + [len(mapped) for mapped in nn1])
    print("maxMCS after distance treatment: %d" % maxMCS)

    # only keep the largest MCSs
    maxSize = getLargestList(nn1+nn2)
    n1 = []
    n2 = []
    for nfoo,nbar in zip(nn1,nn2):
        if len(nfoo)==maxSize and len(nbar)==maxSize:
            n1.append(nfoo)
            n2.append(nbar)
    return(n1,n2)
 
def selectOneMCS(n1_list,n2_list,mol1,mol2):
    """Pick the candidate mapping with the lowest RMSD after alignment.

    Each candidate (nfoo, nbar) is used as the atom map to align mol1
    (probe) onto mol2 (reference) with ``rdMolAlign.AlignMol``, so mol1's
    coordinates change. A candidate whose alignment raises is never
    selected. If no candidate beats 9999.999, the first candidate is
    returned.
    """
    n1 = n1_list[0]
    n2 = n2_list[0]
    rmsdMin = 9999.999
    for nfoo,nbar in zip(n1_list,n2_list):
        # materialise the pairs: zip is lazy on Python 3
        alignID = list(zip(nfoo,nbar))
        try:
            rmsd = Chem.rdMolAlign.AlignMol(mol1,mol2,atomMap=alignID)
        except Exception:
            # alignment failed: rank the candidate worse than any so far
            rmsd = rmsdMin*10.0
        if rmsd < rmsdMin:
            rmsdMin = rmsd
            n1 = nfoo
            n2 = nbar
    return(n1,n2)

def matchFullRings(mol1,mol2,n1,n2):
    """Keep only ring atoms that belong to a completely mapped ring in both molecules.

    Pair handling:
      * a ring atom paired with a non-ring atom is removed;
      * a ring-ring pair (i, j) is removed unless i lies in some ring of mol1
        whose atoms are all in n1, and j lies in some ring of mol2 whose atoms
        are all in n2;
      * non-ring pairs are kept.
    The decisions use the mapping as it was on entry. n1 and n2 are modified
    in place and returned.
    """
    rings1 = mol1.GetRingInfo().AtomRings()
    rings2 = mol2.GetRingInfo().AtomRings()
    drop = []
    for i,j in zip(n1,n2):
        inRing1 = mol1.GetAtomWithIdx(i).IsInRing()
        inRing2 = mol2.GetAtomWithIdx(j).IsInRing()
        if inRing1 != inRing2:
            drop.append((i,j))
        elif inRing1 == True:
            full1 = any(i in ring and isMapped(ring,n1) for ring in rings1)
            full2 = any(j in ring and isMapped(ring,n2) for ring in rings2)
            if not (full1 and full2):
                drop.append((i,j))
    # remove
    for i,j in drop:
        n1.remove(i)
        n2.remove(j)
    return(n1,n2)

def isMapped(ring,ind):
    """Return True when every atom index of ``ring`` also occurs in ``ind``.

    An empty ring is considered mapped.
    """
    return all(atomIdx in ind for atomIdx in ring)

def mcs(mol1, mol2, bH2H, bH2heavy, bdMCS, bRingsOnly, d, bChiral, t=None):
    # make all atoms into carbon
    foo = copy.deepcopy(mol1)
    bar = copy.deepcopy(mol2)
    if( bRingsOnly==True ):
        carbonize(foo,bH2heavy,42)
        carbonize(bar,bH2heavy,43)
    else:
	carbonize(foo,bH2heavy)
	carbonize(bar,bH2heavy)
    mols = [foo,bar]
    print "Searching..."
    res = MCS.FindMCS(mols,ringMatchesRingOnly=True, completeRingsOnly=True, atomCompare='elements', bondCompare='any', timeout=t, maximize='bonds')
    p = Chem.FragmentMatcher
    pp = p.FragmentMatcher()
    n1_list = []
    n2_list = []
    try:
        pp.Init(res.smarts)
    except:
        return(n1_list,n2_list)
    n1_list = pp.GetMatches(foo)
    n2_list = pp.GetMatches(bar)
    print 'Found %d MCSs in total (mol1: %d, mol2: %d), each with %d atoms and %d bonds' % (len(n1_list)*len(n2_list),len(n1_list),len(n2_list),res.numAtoms,res.numBonds)

    # if hydrogens to be removed
    if(bH2H==False or bH2heavy==False):
	n1_list,n2_list = mcsHremove(mol1,mol2,n1_list,n2_list,bH2H,bH2heavy)
    # from this point n1_list and n2_list elements must match 1to1, i.e. the number of elements in the lists is the same

# triple bond rule: for simulation stability do not allow morphing an atom that is involved in a triple bond
# into an atom that is involved in a non-triple bond (and vice versa)
# also checking possible issues with the 1-2, 1-3 and 1-4 interactions
    n1_foo = []
    n2_foo = []
    for n1,n2 in zip(n1_list,n2_list):
        n1,n2 = tripleBond(mol1,mol2,n1,n2)
        n1,n2 = checkTop(mol1,mol2,n1,n2) 
	n1_foo.append(n1)
	n2_foo.append(n2)
    n1_list = cp.copy(n1_foo)
    n2_list = cp.copy(n2_foo)

#############################################
######### chirality check ###################
    bChiral = False
    if( bChiral==True ):
        print "Chirality check.\n"
        n1_foo = []
        n2_foo = []
        for n1,n2 in zip(n1_list,n2_list):
            n1,n2 = checkChiral(mol1,mol2,n1,n2)
            n1_foo.append(n1)
            n2_foo.append(n2)
        n1_list = cp.copy(n1_foo)
        n2_list = cp.copy(n2_foo)

# if any triple bonds or chiral atoms have been removed need to treat disconnected fragments
    n1_foo = []
    n2_foo = []
    for n1,n2 in zip(n1_list,n2_list):
        n1,n2 = disconnectedMCS(mol1,mol2,n1,n2)
        n1_foo.append(n1)
        n2_foo.append(n2)
    n1_list = cp.copy(n1_foo)
    n2_list = cp.copy(n2_foo)

    # if distances to be compared
    if(bdMCS==True):
	n1_list,n2_list = mcsDist(mol1,mol2,n1_list,n2_list,d)
	# mcs matches complete rings #
	# however if distance to be compared
	# full rings need to be matched again
	# ... maybe don't match full rings, because there may not be full rings in the mappings ...
    	bMatchFullRings = False#True
	if( bMatchFullRings==True ):
            print "Matching full rings.\n"
            n1_foo = []
            n2_foo = []
            for n1,n2 in zip(n1_list,n2_list):
                n1,n2 = matchFullRings(mol1,mol2,n1,n2)
                n1_foo.append(n1)
                n2_foo.append(n2)
            n1_list = cp.copy(n1_foo)
            n2_list = cp.copy(n2_foo)
	# again check for disconnectedMCS
        n1_foo = []
        n2_foo = []
        for n1,n2 in zip(n1_list,n2_list):
            n1,n2 = disconnectedMCS(mol1,mol2,n1,n2)
            n1_foo.append(n1)
            n2_foo.append(n2)
        n1_list = cp.copy(n1_foo)
        n2_list = cp.copy(n2_foo)


    # if there are several MCSs, select the one yielding the smallest RMSD
    n1,n2 = selectOneMCS(n1_list,n2_list,mol1,mol2)

    # one more final check for possible issues with the 1-2, 1-3 and 1-4 interactions
    n1,n2 = checkTop(mol1,mol2,n1,n2) 

    print 'Final MCS that survived after pruning: %d atoms' % (len(n1))

    return n1,n2

def checkRingsOnlyFlag(mol1,mol2):
    """Return True only if both molecules contain at least one ring atom.

    Used to decide whether a rings-only MCS/alignment is possible at all.
    """
    def _hasRingAtom(mol):
        return any(atom.IsInRing() == True for atom in mol.GetAtoms())

    if _hasRingAtom(mol1) and _hasRingAtom(mol2):
        return True
    return False

def main(argv):

    desc=("Provided two structures find atoms to be morphed.")

# define input/output files

    files= [
        FileOption("-i1", "r",["pdb"],"input1.pdb", "input"),
        FileOption("-i2", "r",["pdb"],"input2.pdb", "input"),
        FileOption("-o", "w",["dat"],"pairs.dat", "output"),
        FileOption("-opdb1", "w/o",["pdb"],"out_pdb1.pdb", "optional output of superimposed structure 1"),
        FileOption("-opdb2", "w/o",["pdb"],"out_pdb2.pdb", "optional output of superimposed structure 2"),
        FileOption("-opdbm1", "w/o",["pdb"],"out_pdb_morphe1.pdb", "optional output of the morphable atoms str1"),
        FileOption("-opdbm2", "w/o",["pdb"],"out_pdb_morphe2.pdb", "optional output of the morphable atoms str2"),
        FileOption("-score", "w/o",["dat"],"out_score.dat", "optional output: score of the morphe"),
        FileOption("-n1", "r/o",["ndx"],"scaffold1" ,"optionally read index of atoms to consider: mol1"),
        FileOption("-n2", "r/o",["ndx"],"scaffold2","optionally read index of atoms to consider: mol2" ),
        ]

# define options

    options=[
        Option( "-alignment", "bool", "False", "method 1: 3D alignment"),
        Option( "-mcs", "bool", "False", "method 2: maximum common substructure"),
        Option( "-H2H", "bool", "False", "should hydrogen be morphed into hydrogen?"),
        Option( "-H2heavy", "bool", "False", "should hydrogen be morphed into a heavy atom (also sets -H2H to true)"),
        Option( "-RingsOnly", "bool", "False", "should rings only be used in the MCS search or alignment"),
        Option( "-dMCS", "bool", "False", "find MCS, superimpose the structures based on the MCS, apply distance in\
		Cartesian coordinate space to define the morphes"),
#        Option( "-chirality", "bool", "True", "perform chirality check for MCS mapping"),
        Option( "-d", "float", "0.05", "distance (nm) between atoms to consider them morphable"),
	Option( "-timeout", "int", "None", "maximum time (s) for an MCS search"),
        ]

    help_text = ()

# pass options, files and the command line to pymacs

    cmdl = Commandline( argv, options = options, fileoptions = files, program_desc = help_text, check_for_existing_files = False, version = "0.0" )
    
    # deal with the flags
    bH2H = False
    bChiral = False
    d = 0.05
    timeout = None
    if(cmdl['-H2H']==True):
	bH2H = True
    bH2heavy = False
    if(cmdl['-H2heavy']==True):
        bH2heavy = True
	bH2H = True
    if(cmdl.opt['-d'].is_set):
	d = cmdl['-d']
    if(cmdl.opt['-timeout'].is_set):
	timeout = cmdl['-timeout']
#    if(cmdl['-chirality']==False):
#        bChiral = False

    # read index
    read_from_idx1 = False
    read_from_idx2 = False
    if cmdl.opt['-n1'].is_set:
        read_from_idx1 = True
        idx1 = IndexFile(cmdl['-n1']).dic['scaffold']
    if cmdl.opt['-n2'].is_set:
        read_from_idx2 = True
        idx2 = IndexFile(cmdl['-n2']).dic['scaffold']

#args=parse_common_args(sys.argv,files,options, desc)

    # reformat PDB to read two letter atoms properly (why does it have to be so painful?)
    pdbName1,atomNameDict1 = reformatPDB(cmdl['-i1'],1)
    pdbName2,atomNameDict2 = reformatPDB(cmdl['-i2'],2)

    mol1 = Chem.MolFromPDBFile(pdbName1,removeHs=False,sanitize=False)
    mol2 = Chem.MolFromPDBFile(pdbName2,removeHs=False,sanitize=False)

    os.remove(pdbName1)
    os.remove(pdbName2)

    try:
	rdmolops.AssignAtomChiralTagsFromStructure(mol1)
	rdmolops.AssignAtomChiralTagsFromStructure(mol2)
	rdmolops.AssignStereochemistry(mol1)
	rdmolops.AssignStereochemistry(mol2)
    except:
	print "Chirality not assigned"

#    mol1 = Chem.SDMolSupplier(cmdl['-i1'],removeHs=False,sanitize=True)
#    mol2 = Chem.SDMolSupplier(cmdl['-i2'],removeHs=False,sanitize=True)
#    mol1 = Chem.SDMolSupplier(cmdl['-i1'])
#    mol2 = Chem.SDMolSupplier(cmdl['-i2'])
#    mol1 = mol1[0]
#    mol2 = mol2[0]

    # deal with the -RingsOnly flag
    bRingsOnly = False
    bYesRings = checkRingsOnlyFlag(mol1,mol2)
    if(cmdl['-RingsOnly']==True):
	if( bYesRings==True ):
	    bRingsOnly=True
	else:
	    print "-RingsOnly flag is unset, because one (or both) molecule has no rings\n"
    
    n1 = []
    n2 = []

#########################
## make molecule copies #
    molcp1 = cp.deepcopy(mol1)
    molcp2 = cp.deepcopy(mol2)

########################################
####### alignment ######################
    if(cmdl['-alignment']==True):
	print "The alignment approach will be used"
	print "Tosco, P., Balle, T. & Shiri, F. Open3DALIGN: an open-source software aimed at unsupervised ligand alignment. J Comput Aided Mol Des 25:777-83 (2011)\n"
	# only use rings
	if(bRingsOnly==True):
            n1,n2,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,bRingsOnly,True,d)
	# else try both options and choose better
	else:
	    print "Trying to align all atoms..."
            n1,n2,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,False,True,d)
	    print "Size of mapping: ",len(n1)
	    if( bYesRings==True ):
                print "Trying to align rings only..."
                mol1 = cp.deepcopy(molcp1)
                mol2 = cp.deepcopy(molcp2)
                n1B,n2B,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,True,True,d)
	    	print "Size of mapping: ",len(n1B)
	        if(len(n1)<=len(n1B)):
		    print "Using ring only alignment result."
		    n1 = n1B
		    n2 = n2B
		else:
		    print "Using all atom alignment result."
#	print n1,n2
###########################################
########### mcs ###########################
    elif(cmdl['-mcs']==True):
	print "The topology matching approach will be used"
	print "fmcs module: Copyright (c) 2012 Andrew Dalke Scientific AB\n"
	if(bRingsOnly==True):
	    n1,n2 = mcs(mol1,mol2,bH2H,bH2heavy,cmdl['-dMCS'],bRingsOnly,d,bChiral,timeout)
	else:
            print "Trying to run an MCS using all atoms..."
            n1A,n2A = mcs(mol1,mol2,bH2H,bH2heavy,cmdl['-dMCS'],False,d,bChiral,timeout)
            n1 = n1A
            n2 = n2A
            if( bYesRings==True ):
                print "Trying to run an MCS using rings only..."
                n1B,n2B = mcs(mol1,mol2,bH2H,bH2heavy,cmdl['-dMCS'],True,d,bChiral,timeout)
                if(len(n1A)<=len(n1B)):
                    print "Using ring only MCS result."
                    n1 = n1B
                    n2 = n2B
                else:
                    print "Using all atom MCS result."

#	print "MCS mapping may not result in a proper hybrid topology.\nPerforming a check..."
############################################
######### topology check ###################
#	if( 1==1 ):
#	    print "Topology check went fine, proceed with the MCS mapping.\n"
#	    n1,n2 = checkTop(mol1,mol2,n1,n2) # checking possible issues with the 1-2, 1-3 and 1-4 interactions
#	else:
#	    print "Problem in bond assignment, running alignment based mapping...\n"
#	    print "Tosco, P., Balle, T. & Shiri, F. Open3DALIGN: an open-source software aimed at unsupervised ligand alignment. J Comput Aided Mol Des 25:777-83 (2011)\n"
#  	    # only use rings
#	    if(bRingsOnly==True):
#	        print "MMFF parameters..."
#	        n1A,n2A,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,bRingsOnly,False,d)
#                print "Crippen parameters..."
#                n1B,n2B,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,bRingsOnly,True,d)
#                if(len(n1A)<=len(n1B)):
#                    n1 = n1B
#                    n2 = n2B
#                else:
#                    n1 = n1A
#		    n2 = n2A
	    # else try both options and choose better
#	    else:
#	        print "Trying to align all atoms..."
#	        print "MMFF parameters..."
#	        n1AA,n2AA,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,False,False,d)
#                print "Crippen parameters..."
#                n1AB,n2AB,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,False,True,d)
#                if(len(n1AA)<=len(n1AB)):
#                    n1 = n1AB
#                    n2 = n2AB
#                else:
#                    n1 = n1AA
#                    n2 = n2AA
#	        if( bYesRings==True ):
#                    print "Trying to align rings only..."
#		    print "MMFF parameters..."
#                    n1BA,n2BA,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,True,False,d)
#		    print "Crippen parameters..."
#                    n1BB,n2BB,pyO3A = o3a_alignment(mol1,mol2,bH2H,bH2heavy,True,True,d)
#                    if(len(n1BA)<=len(n1BB)):
#                        n1B = n1BB
#                        n2B = n2BB
#                    else:
#                        n1B = n1BA
#                        n2B = n2BA
#	            if(len(n1)<=len(n1B)):
#		        print "Using ring only alignment result."
#		        n1 = n1B
#		        n2 = n2B
#		    else:
#		        print "Using all atom alignment result."
    else:
	print "Select -alignment or -mcs\n"
	sys.exit(0)

    # a check
    if( len(n1) != len(n2) ):
	print "Warning: something went wrong."
	print "Number of the morphable atoms in the ligands does not match.\n"
    
    # calculate score
    score = calcScore(mol1,mol2,n1,n2,bH2H,bH2heavy)

    # print some output
    if( bH2H==True or bH2heavy==True ):
        print "Atoms considered in mol1: ",mol1.GetNumAtoms()
        print "Atoms considered in mol2: ",mol2.GetNumAtoms()
    else:
        print "Atoms considered in mol1: ",mol1.GetNumHeavyAtoms()
        print "Atoms considered in mol2: ",mol2.GetNumHeavyAtoms()
    print "Morphable atoms in both molecules: ",len(n1),len(n2)
    print "Dissimilarity (distance) score: %.4f\n" % score
    if(cmdl['-score']):
	fp = open(cmdl['-score'],'w')
	fp.write("Score: %.4f\n" % score)
	fp.close()	

    restoreAtomNames(mol1,atomNameDict1)
    restoreAtomNames(mol2,atomNameDict2)

    # output
    if cmdl.opt['-opdb1'].is_set:
        Chem.MolToPDBFile(mol1,cmdl['-opdb1'])
    if cmdl.opt['-opdb2'].is_set:
#	if( cmdl['-mcs']==True ):
        try:
            Chem.rdMolAlign.AlignMol(mol2,mol1,atomMap=zip(n2,n1))
        except:
  	    print "Cannot superimpose -opdb2 structure. Maybe no morphable atoms have been found\n"
        Chem.MolToPDBFile(mol2,cmdl['-opdb2'])
    if cmdl.opt['-opdbm1'].is_set:
	mol = subMolByIndex(mol1,n1)
        Chem.MolToPDBFile(mol,cmdl['-opdbm1'])
    if cmdl.opt['-opdbm2'].is_set:
        mol = subMolByIndex(mol2,n2)
        Chem.MolToPDBFile(mol,cmdl['-opdbm2'])

    # write out pairs
    pairsFile = cmdl['-o']
    write_pairs(n1,n2,pairsFile)

# run only when executed as a script, so importing this module has no side effects
if __name__ == "__main__":
    main( sys.argv )


