#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 18:08:19 2022

@author: aaron
"""
from rdkit import Chem
from rdkit.Chem import rdFMCS
from rdkit import RDLogger
import sys
from collections import Counter
from warnings import warn 






def get_ligand_input(file, sanitize = True, removeHs = True, cleanupSubstructures=False, strictParsing = True):
    '''
    Read the ligand(s) to be checked from a .mol2 or .sdf file.

    Parameters
    ----------
    file : path to the input file; the extension (case-insensitive) selects the reader.
    sanitize, removeHs : passed to the RDKit reader.
    cleanupSubstructures : passed to Chem.MolFromMol2File (mol2 only).
    strictParsing : passed to Chem.SDMolSupplier (sdf only).

    Returns
    -------
    list of RDKit mols. A .mol2 file yields exactly one mol (one ligand per
    file is assumed). A .sdf file yields one mol per record.

    Prints an error and exits with status 1 in these cases:
    - the extension is unknown;
    - no molecules were read;
    - any record failed to parse.
    A None entry would otherwise crash the comparison functions later.
    '''
    lowered = file.lower()
    if lowered.endswith('.mol2'):
        # assumes one ligand per mol2 file
        mol = [Chem.MolFromMol2File(file, sanitize=sanitize, removeHs=removeHs,
                                    cleanupSubstructures=cleanupSubstructures)]
    elif lowered.endswith('.sdf'):
        suppl = Chem.SDMolSupplier(file, sanitize=sanitize, removeHs=removeHs,
                                   strictParsing=strictParsing)
        mol = list(suppl)
    else:
        print('# ERROR! Unknown file imput')
        sys.exit(1)

    # An empty list would make every later check pass vacuously.
    if not mol:
        print('# ERROR! No molecules read from {}'.format(file))
        sys.exit(1)

    for num, m in enumerate(mol):
        if m is None:
            print('# ERROR! Mol parsed as None type (ligand {})'.format(num))
            sys.exit(1)

    return mol


def compare_atoms(ground_truth_mol ,comparison_mol ):
    '''
    Check that every mol in comparison_mol has the same element composition
    (element symbol -> count) as ground_truth_mol.

    Parameters
    ----------
    ground_truth_mol : reference RDKit mol.
    comparison_mol : iterable of RDKit mols to check.

    Returns
    -------
    bool : True if all compositions match. On the first mismatch it prints
    an error naming the 0-based ligand index and returns False.
    '''
    gt_atoms = Counter(atom.GetSymbol() for atom in ground_truth_mol.GetAtoms())
    for num, mol in enumerate(comparison_mol):
        mol_atoms = Counter(atom.GetSymbol() for atom in mol.GetAtoms())
        # Counter equality compares the symbol->count maps regardless of order.
        if gt_atoms != mol_atoms:
            print('# ERROR! Incorrect number of atoms in ligand {}'.format(num))
            return False

    return True
        

def get_bonds_dic(mol ) :
    '''
    Return a Counter of bond keys for ``mol``.

    Each bond's key is built from three strings concatenated in order:
    - the two end-atom element symbols, sorted so bond direction does not matter;
    - ``str()`` of the bond's type.
    '''
    def _bond_key(bond):
        # NOTE: keys on element pair + bond type only; may be too restrictive
        first, second = sorted((bond.GetBeginAtom().GetSymbol(),
                                bond.GetEndAtom().GetSymbol()))
        return first + second + str(bond.GetBondType())

    return Counter(_bond_key(bond) for bond in mol.GetBonds())
        

def compare_bonds(ground_truth_mol ,comparison_mol ) :
    '''
    Check that every mol in comparison_mol has the same bond counts as
    ground_truth_mol. Bonds are keyed by sorted end-atom symbols plus bond
    type, as built by get_bonds_dic.

    Parameters
    ----------
    ground_truth_mol : reference RDKit mol.
    comparison_mol : iterable of RDKit mols to check.

    Returns
    -------
    bool : True if all bond counts match. On the first mismatch it prints
    an error naming the 0-based ligand index and returns False.
    '''
    gt_bonds = get_bonds_dic(ground_truth_mol)

    for num, mol in enumerate(comparison_mol):
        # Counter equality compares the key->count maps regardless of order.
        if get_bonds_dic(mol) != gt_bonds:
            print('# ERROR! Incorrect number of bonds in ligand {}'.format(num))
            return False
    return True
    
def compare_smiles(ground_truth_mol ,comparison_mol ) :
    '''
    Compare the canonical SMILES of ground_truth_mol with that of each mol
    in comparison_mol, as plain strings.

    Returns
    -------
    bool : True if every SMILES string matches. On the first mismatch it
    prints an error naming the 0-based ligand index and returns False.
    '''
    gt_smiles = Chem.CanonSmiles(Chem.MolToSmiles(ground_truth_mol))

    for num, mol in enumerate(comparison_mol):
        mol_smiles = Chem.CanonSmiles(Chem.MolToSmiles(mol))
        if gt_smiles != mol_smiles:
            print('# ERROR! Incorrect smiles comparison in ligand {}'.format(num))
            return False
    return True

def compare_inchi(ground_truth_mol ,comparison_mol) :
    '''
    Compare the InChI string of ground_truth_mol with that of each mol in
    comparison_mol, as plain strings.

    Returns
    -------
    bool : True if every InChI string matches. On the first mismatch it
    prints an error naming the 0-based ligand index and returns False.
    '''
    gt_inchi = Chem.MolToInchi(ground_truth_mol)

    for num, mol in enumerate(comparison_mol):
        mol_inchi = Chem.MolToInchi(mol)
        if gt_inchi != mol_inchi:
            print('# ERROR! Incorrect inchi comparison in ligand {}'.format(num))
            return False
    return True

def maximum_common_substructure(ground_truth_mol ,comparison_mol ) :
    '''
    Return True when the maximum common substructure (MCS) of all given
    molecules has as many atoms as ground_truth_mol, otherwise False.

    The MCS is taken over the mols in comparison_mol together with
    ground_truth_mol. comparison_mol must be a list, because it is
    concatenated with a one-element list.
    '''
    all_mols = comparison_mol + [ground_truth_mol]
    mcs_result = rdFMCS.FindMCS(all_mols)
    return mcs_result.numAtoms == ground_truth_mol.GetNumAtoms()

def convert(option ) :
    '''
    Convert a command-line 'true'/'false' string (any letter case) to a bool.

    Any other value prints an error and exits with status 1.
    '''
    normalized = option.upper()
    if normalized == 'TRUE':
        return True
    if normalized == 'FALSE':
        return False

    print('# ERROR! Unknown option input {}'.format(option))
    sys.exit(1)

def parse_options(args):
    '''
    Return a dict of options to use for the assessment.

    ``args`` is the argument list to scan (normally sys.argv). For each known
    flag found in ``args``, the token after it is converted to the type of
    that flag's default value. Boolean flags go through convert(), so only
    'true'/'false' in any letter case is accepted.

    A flag with no following value, or a value of the wrong type, prints an
    error and exits with status 1.
    '''
    options = {'-d': '\t',
               '-col': 1,
               '-smiNum': 0,
               '-t': True,
               '-s': True,
               '-h': True,
               '-c': True,
               '-p': True}

    for key, default in list(options.items()):
        if key not in args:
            continue

        # Read from ``args`` (not sys.argv) so the parameter is honoured.
        value_index = list(args).index(key) + 1
        if value_index >= len(args):
            print('# ERROR! No value given for option {}'.format(key))
            sys.exit(1)
        value = args[value_index]

        if isinstance(default, bool):
            options[key] = convert(value)
        else:
            try:
                options[key] = type(default)(value)
            except ValueError:
                print('# ERROR! Invalid value {} for option {}'.format(value, key))
                sys.exit(1)

    return options
            
            
    
    
if __name__ == '__main__':

    # Silence RDKit's own parse warnings. The MolFromSmiles attempt below is
    # expected to fail noisily whenever argv[1] is a file path.
    RDLogger.DisableLog('rdApp.*')

    if len(sys.argv) < 3:

        print('\nUsage <smiles string/file> <input file (.mol2/.sdf)> <Options>')
        print('Options: \n\t-s : Sanitise molecule on reading, default = True')
        print('\t-h : remove Hs upon reading molecule, default = True')
        print('\t-c : clean substructure upon reading molecule, default = True')
        print('\t-p : stict parsing for SDmolSupplier, default = True')
        print('\t-t : smiles file title line, default = True')
        print('\t-col : smiles colum, default 1')
        print('\t-smiNum : number of the smiles string to use in the file, default 0')
        print('\tsmiles file delimiter is \\t \n')

    else:
        options = parse_options(sys.argv)

        print(options)

        smi = sys.argv[1]

        # argv[1] is tried first as a literal SMILES string, then as a SMILES file.
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            try:
                suppl = Chem.SmilesMolSupplier(smi, delimiter=options['-d'],
                                               titleLine=options['-t'],
                                               smilesColumn=options['-col'])
                mol = suppl[options['-smiNum']]
            except (OSError, IndexError):
                # presumably a missing/unreadable file or an out-of-range -smiNum;
                # report it through the script's own error message below
                mol = None

            if mol is None:
                print('# ERROR! Unable to parse SMILES')
                sys.exit(1)

        file = sys.argv[2]

        input_mol = get_ligand_input(file, sanitize=options['-s'], removeHs=options['-h'],
                                     cleanupSubstructures=options['-c'],
                                     strictParsing=options['-p'])

        # Make sure the mol has the correct atoms
        correct_atoms = compare_atoms(mol, input_mol)
        if correct_atoms:
            print('Atoms accepted')
        else:
            print('# ERROR! Atoms not accepted')

        # Make sure the mol has the correct bonds
        correct_bonds = compare_bonds(mol, input_mol)
        if correct_bonds:
            print('Bonds accepted')
        else:
            print('# ERROR! Bonds not accepted')

        # Optional stricter checks, currently disabled:
        #correct_smiles = compare_smiles(mol, input_mol)
        #if correct_smiles:
        #    print('SMILES accepted')
        #else:
        #    print('# ERROR! SMILES not accepted')

        #correct_inchi = compare_inchi(mol, input_mol)
        #if correct_inchi:
        #    print('Inchi accepted')
        #else:
        #    print('# ERROR! Inchi not accepted')

        correct_common_substructure = maximum_common_substructure(mol, input_mol)
        if correct_common_substructure:
            print('Maximum common substructure accepted')
        else:
            print('# ERROR! Maximum common substructure not accepted')
    
        





