#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 28 15:29:23 2022

@author: Aaron Sweeney
"""

from rdkit import Chem
from io import BytesIO
from rdkit.Chem import rdFMCS
import sys
from collections import Counter
    
class LG():
    """
    Parsed ligand submission file.

    The constructor reads the file, checks the three header records that
    must occupy the first three lines (PFRMAT, TARGET, AUTHOR, in that
    order), collects the METHOD lines and builds one ``Model`` per
    MODEL/END pair.  On any format problem a ``# ERROR!`` line is printed
    and the process is terminated through ``sys.exit()``.

    Attributes set by the constructor:
        PFRMAT, TARGET, AUTHOR -- second whitespace-separated field of the
            corresponding header line.
        METHOD -- text of all METHOD lines, one per line.
        MODELS -- list of ``Model`` objects, in file order.
        VALIDATION_REPORT -- ``None`` until a report is attached.
        OUT_FILE -- path of the log file written by ``write_report``.
    """

    # Directory prefix used for the log file when ``log_dir`` is not given.
    DEFAULT_LOG_DIR = '/local/Projects/Perl/casp15/src/logs/LigLog/'

    def __init__(self, file, log_dir=None):
        """
        Parse the submission stored at ``file``.

        ``log_dir`` is an optional prefix for the log file name; it is
        concatenated as-is, so it must end with a path separator (or be
        the empty string for the current directory).  When omitted the
        original hard-coded directory is used.
        """
        with open(file, 'r') as f:
            data = f.read().splitlines()

        self.PFRMAT = self._header_value(
            data, 0, 'PFRMAT',
            '# ERROR! PFRMAT record not found !!',
            '# ERROR! PFRMAT record not found !!')
        self.TARGET = self._header_value(
            data, 1, 'TARGET',
            '# ERROR! TARGET record not found !!',
            '# ERROR! TARGET not specified !!')
        self.AUTHOR = self._header_value(
            data, 2, 'AUTHOR',
            '# ERROR! Author record not found !!',
            '# ERROR! Author record not found !!')

        self.METHOD = ''
        self.MODELS = []

        mdl_idx = []
        for index, line in enumerate(data):
            if line.startswith('METHOD'):
                # Drop only the leading keyword; a later "METHOD" inside
                # the description text must survive.
                self.METHOD += line[len('METHOD'):].strip() + '\n'
            elif line.startswith('MODEL') or line.startswith('END'):
                mdl_idx.append(index)

        # Records must come as MODEL, END, MODEL, END, ...  Checking the
        # count alone would accept e.g. two MODEL lines with no END.
        paired = len(mdl_idx) % 2 == 0
        for position, index in enumerate(mdl_idx):
            expected = 'MODEL' if position % 2 == 0 else 'END'
            if not data[index].startswith(expected):
                paired = False
                break
        if not paired:
            print('# ERROR! Missing MODEL/END records !!')
            sys.exit()

        for start, stop in zip(mdl_idx[0::2], mdl_idx[1::2]):
            # The slice includes both the MODEL line and the END line.
            self.MODELS.append(Model(data[start:stop + 1]))

        self.VALIDATION_REPORT = None
        self._path = self.DEFAULT_LOG_DIR if log_dir is None else log_dir
        self.OUT_FILE = self._path + '{}_{}.log'.format(self.AUTHOR, self.TARGET)

    @staticmethod
    def _header_value(data, index, keyword, missing_msg, empty_msg):
        """
        Return the value of the header record expected on line ``index``.

        Prints ``missing_msg`` and exits when the line is absent or does
        not start with ``keyword``; prints ``empty_msg`` and exits when the
        record carries no value after the keyword.
        """
        if index >= len(data) or not data[index].startswith(keyword):
            print(missing_msg)
            sys.exit()
        fields = data[index].split()
        if len(fields) < 2:
            print(empty_msg)
            sys.exit()
        return fields[1]

    def write_report(self):
        """
        Write ``VALIDATION_REPORT`` (a list of lines) to ``OUT_FILE``.

        Only prints an error, without writing or exiting, when no report
        has been attached yet.
        """
        if self.VALIDATION_REPORT is None:
            print('# ERROR! Validation not run !!')
        else:
            with open(self.OUT_FILE, 'w') as f:
                f.writelines(self.VALIDATION_REPORT)

class Model():
    """
    One MODEL ... END section of a submission, split into its ligands.

    ``LIGANDS`` holds one ``Ligand`` per LIGAND record, in file order.
    """

    def __init__(self, mdl_data):
        """
        Split ``mdl_data`` (the lines of one model) at every line starting
        with ``LIGAND`` and build a ``Ligand`` from each chunk.

        Prints an error and exits via ``sys.exit()`` when the model has no
        LIGAND record at all.
        """
        ligand_starts = [index for index, line in enumerate(mdl_data)
                         if line.startswith('LIGAND')]
        if not ligand_starts:
            print('# ERROR! No LIGAND record found !!')
            sys.exit()

        # Each ligand runs from its LIGAND line up to the next LIGAND line;
        # the last one runs to the end of the model data.
        bounds = ligand_starts + [len(mdl_data)]
        self.LIGANDS = []
        for start, stop in zip(bounds, bounds[1:]):
            self.LIGANDS.append(Ligand(mdl_data[start:stop]))

class Ligand():
    """
    One LIGAND record of a model together with its poses.

    Attributes:
        LIG_ID_NUM -- second whitespace-separated field of the LIGAND line.
        LIG_ID_CODE -- third whitespace-separated field of the LIGAND line.
        CONFORMATIONS -- one RDKit molecule per POSE ... M END section.
        NUM_CONFORMATIONS -- ``len(CONFORMATIONS)``.
        VALIDATION -- one initially empty list per conformation.
    """

    def __init__(self, ligand_data):
        """
        Parse ``ligand_data``, the lines of one ligand starting at its
        LIGAND record.

        Prints an error and exits via ``sys.exit()`` when the LIGAND line
        is missing or incomplete, when POSE / M END lines do not pair up,
        or when a pose cannot be read as a molecule.
        """
        self.CONFORMATIONS = []
        self.NUM_CONFORMATIONS = 0
        self.LIG_ID_NUM = None
        self.LIG_ID_CODE = None

        if not ligand_data or not ligand_data[0].startswith('LIGAND'):
            print('# ERROR! Missing LIGAND records !!')
            sys.exit()

        # split() without an argument collapses runs of whitespace, so
        # "LIGAND  1 ABC" no longer produces an empty ID field.
        fields = ligand_data[0].split()
        if len(fields) < 3:
            print('# ERROR! Incorrect LIGAND record format !!')
            sys.exit()
        self.LIG_ID_NUM = fields[1]
        self.LIG_ID_CODE = fields[2]

        # Indices of the POSE lines and of the mol-block terminators
        # ("M  END" with any amount of inner spacing).
        pose_idx = [index for index, line in enumerate(ligand_data)
                    if line.startswith('POSE') or line.replace(' ', '') == 'MEND']

        if len(pose_idx) % 2 != 0:
            print('# ERROR! Missing POSE/M END records !!')
            sys.exit()

        for start, stop in zip(pose_idx[0::2], pose_idx[1::2]):
            # Everything after the POSE line up to and including M END.
            mol_block = ligand_data[start + 1: stop + 1]
            self.CONFORMATIONS.append(self.get_mol_from_block(mol_block))

        self.NUM_CONFORMATIONS = len(self.CONFORMATIONS)
        self.VALIDATION = [list() for _ in range(self.NUM_CONFORMATIONS)]

    def get_mol_from_block(self, block):
        """
        Turn the lines of one pose into an RDKit molecule.

        REMARK lines are dropped, an SD record terminator is appended and
        the text is read with ``Chem.ForwardSDMolSupplier``.  Prints an
        error and exits via ``sys.exit()`` when reading fails or yields no
        molecule.
        """
        lines = [line for line in block if not line.startswith('REMARK')]
        lines.append('$$$$')
        text = '\n'.join(lines)

        mol = None
        try:
            with BytesIO(text.encode("UTF-8")) as hnd:
                # Exhaust the supplier; the last record read is kept.
                for mol in Chem.ForwardSDMolSupplier(hnd):
                    pass
        except Exception:
            # Any reader failure is reported like an unreadable molecule.
            # Unlike a bare except this lets KeyboardInterrupt through.
            mol = None

        if mol is None:
            print('# ERROR! reading ligand coords: {} {} !!'.format(self.LIG_ID_NUM, self.LIG_ID_CODE))
            sys.exit()
        return mol
        

def validate_ligands(LG_obj):
    """
    Check every pose of every ligand in ``LG_obj`` against the reference
    SMILES of the target.

    The reference table comes from ``find_smiles(LG_obj.TARGET)`` and is
    looked up by each ligand's ``LIG_ID_NUM``.  For each pose three
    booleans are appended to ``lig.VALIDATION[pose_index]``, in this order:
    atom check, bond check, topology check.

    Prints an error and exits via ``sys.exit()`` when a ligand ID has no
    entry in the table or when its SMILES cannot be parsed.
    """
    smiles = find_smiles(LG_obj.TARGET)
    print(smiles)

    for mdl in LG_obj.MODELS:
        for lig in mdl.LIGANDS:
            lig_id = lig.LIG_ID_NUM
            if lig_id not in smiles:
                # Report an unknown ligand ID instead of a raw KeyError.
                print('# ERROR! Ligand {} not found in target SMILES !!'.format(lig_id))
                sys.exit()

            validation_mol = Chem.MolFromSmiles(smiles[lig_id])
            if validation_mol is None:
                print('# ERROR! Validation mol is None !!')
                sys.exit()

            for num, pose in enumerate(lig.CONFORMATIONS):
                outcome = lig.VALIDATION[num]
                outcome.append(compare_atoms(validation_mol, pose))
                outcome.append(compare_bonds(validation_mol, pose))
                outcome.append(maximum_common_substructure(validation_mol, pose))
                     

def maximum_common_substructure(ground_truth_mol, comparison_mol):
    '''
    Check that two molecules share their whole structure.

    Returns True when both molecules consist of a single atom, or when the
    maximum common substructure found by ``rdFMCS.FindMCS`` has as many
    atoms as each of the two molecules.  Otherwise prints an error line
    and returns False.
    '''
    truth_atoms = ground_truth_mol.GetNumAtoms()
    other_atoms = comparison_mol.GetNumAtoms()

    # Single-atom pairs are accepted outright, so the substructure search
    # is skipped for them instead of being run and discarded.
    if truth_atoms == 1 and other_atoms == 1:
        return True

    res = rdFMCS.FindMCS([comparison_mol, ground_truth_mol])
    if res.numAtoms == truth_atoms and res.numAtoms == other_atoms:
        return True

    print('# ERROR! Topology not valid !!')
    return False
            
            
def get_bonds_dic(mol):
    '''
    Tally the bonds of ``mol`` by kind.

    Each bond is keyed by its two atom symbols in sorted order followed by
    ``str()`` of its bond type; the returned Counter maps every key to the
    number of bonds of that kind.
    '''
    def bond_key(bond):
        # Sorting the symbols makes the key independent of bond direction.
        first, second = sorted((bond.GetBeginAtom().GetSymbol(),
                                bond.GetEndAtom().GetSymbol()))
        return first + second + str(bond.GetBondType())

    # NOTE: keying on the exact bond type may be too restrictive.
    return Counter(bond_key(bond) for bond in mol.GetBonds())
        

def compare_bonds(ground_truth_mol, comparison_mol):
    '''
    Compare the bond tallies of two molecules.

    Both molecules are summarised with ``get_bonds_dic``.  Returns True
    when the two tallies are identical; otherwise prints an error line and
    returns False.
    '''
    gt_bonds = get_bonds_dic(ground_truth_mol)
    mol_bonds = get_bonds_dic(comparison_mol)
    # Counter equality already compares keys and counts; sorting the items
    # first added nothing.
    if gt_bonds != mol_bonds:
        print('# ERROR! Bonds not valid !!')
        return False
    return True

def compare_atoms(validation_mol, mol):
    '''
    Compare the element composition of two molecules.

    Counts the atom symbols returned by ``GetAtoms()`` on each molecule.
    Returns True when both molecules have the same number of atoms of
    every symbol; otherwise prints an error line and returns False.
    '''
    gt_atoms = Counter(atom.GetSymbol() for atom in validation_mol.GetAtoms())
    mol_atoms = Counter(atom.GetSymbol() for atom in mol.GetAtoms())
    if gt_atoms != mol_atoms:
        print('# ERROR! Atoms not valid !!')
        return False
    return True


        

def find_smiles(target, targets_dir='./TARGETS/'):
    '''
    Load the reference SMILES table of a target.

    Reads ``<targets_dir>/<target>.smiles.txt``.  Every non-blank line
    must have at least three whitespace-separated fields: the first is
    used as the ligand ID key and the third as the SMILES string.  The
    resulting dict is printed and returned.

    Blank lines are skipped.  A line with fewer than three fields prints
    an error and exits via ``sys.exit()``.  Errors from opening the file
    are not caught.
    '''
    import os  # local import: only this function needs it

    smiles_fn = os.path.join(targets_dir, target + '.smiles.txt')
    with open(smiles_fn, 'r') as f:
        records = f.read().splitlines()

    smiles = {}
    for record in records:
        fields = record.split()
        if not fields:
            # e.g. a trailing empty line at the end of the file
            continue
        if len(fields) < 3:
            print('# ERROR! Malformed SMILES record: {} !!'.format(record))
            sys.exit()
        smiles[fields[0]] = fields[2]
    print(smiles)

    return smiles
        

def make_validation_report(LG_obj):
    '''
    Build the text report and store it on ``LG_obj.VALIDATION_REPORT``.

    The report is a list of lines: a header with author, target, method
    and model count, then for each model its ligands, and for each ligand
    its poses with the outcome of the three checks stored in
    ``lig.VALIDATION``.  More than five models, or more than five poses
    for a ligand, add an error line.  Nothing is returned.
    '''
    # (position in a pose's VALIDATION entry, line if valid, line if not)
    checks = (
        (0, '\t\tATOMS VALID\n', '\t\t# ERROR! Invalid number of atoms !!\n'),
        (1, '\t\tBONDS VALID\n', '\t\t# ERROR! Bonds invalid !!\n'),
        (2, '\t\tTOPOLOGY VALID\n', '\t\t# ERROR! Invalid topology !!\n'),
    )

    lines = [
        '----VALIDATION REPORT----\n',
        'AUTHOR : {}\n'.format(LG_obj.AUTHOR),
        'TARGET : {}\n'.format(LG_obj.TARGET),
        'METHOD : {}\n'.format(LG_obj.METHOD),
        '-------------------------\n',
        'NUM_MODELS : {}\n'.format(len(LG_obj.MODELS)),
    ]
    if len(LG_obj.MODELS) > 5:
        lines.append('# ERROR! NUM_MODELS greater than allowed !!\n')

    for model_no, mdl in enumerate(LG_obj.MODELS, 1):
        lines.append('MODEL :  {}\n'.format(model_no))
        lines.append('NUM_LIGANDS :  {}\n'.format(len(mdl.LIGANDS)))

        for lig in mdl.LIGANDS:
            pose_count = len(lig.CONFORMATIONS)
            lines.append('\n\tLIGAND : {} {}\n'.format(lig.LIG_ID_CODE, lig.LIG_ID_NUM))
            lines.append('\tNUM_POSES : {}\n'.format(pose_count))
            if pose_count > 5:
                lines.append('\t# ERROR! NUM_POSES greater than allowed !!\n')

            outcomes = lig.VALIDATION
            for pose_no in range(pose_count):
                lines.append('\n\t\tPOSE : {}\n'.format(pose_no + 1))
                for position, ok_line, bad_line in checks:
                    flag = outcomes[pose_no][position]
                    # Equality (not truthiness) on purpose: a value equal
                    # to neither True nor False produces no line.
                    if flag == True:  # noqa: E712
                        lines.append(ok_line)
                    elif flag == False:  # noqa: E712
                        lines.append(bad_line)

    LG_obj.VALIDATION_REPORT = lines
                


if __name__ == '__main__':

    # Script entry point: parse the submission named on the command line,
    # validate its ligands, build the report and write it to the log file.
    if len(sys.argv) <= 1:
        print('Usage python LG_validation.py <LG file>')
        # Stop here; falling through would fail on the sys.argv[1] lookup.
        sys.exit(1)
    test_txt = sys.argv[1]
    a = LG(test_txt)
    validate_ligands(a)
    make_validation_report(a)
    a.write_report()
    
    
    
            
        

        


            
        
        

