#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 18:08:19 2022

@author: aaron
"""
from rdkit import Chem
from rdkit.Chem import rdFMCS
import sys
from collections import Counter
from warnings import warn 






def get_ligand_input(file : str, sanitize = True, removeHs = True, cleanupSubstructures=False, strictParsing = True) -> 'list of RDKit mol objects':
    '''
    reads the ligand(s) stored in a .mol2 or .sdf file

    file : path of the input file, the reader is chosen from the file
           extension (compared case-insensitively)
    sanitize, removeHs : passed to the RDKit reader for both formats
    cleanupSubstructures : passed to Chem.MolFromMol2File (.mol2 only)
    strictParsing : passed to Chem.SDMolSupplier (.sdf only)

    returns a list of mols, a single entry for a .mol2 file and one entry
    per record for a .sdf file

    exits with status 1 (SystemExit) after a warning when the extension is
    neither .mol2 nor .sdf

    NOTE(review): the RDKit readers are documented to give None for input
    they cannot parse, such entries are returned unchanged - confirm the
    callers can cope with them
    '''
    # compare on a lower-cased copy so that e.g. 'LIG.SDF' is accepted too
    lowered = file.lower()

    if lowered.endswith('.mol2'):
        #assumes one ligand per mol2 file
        return [Chem.MolFromMol2File(file, sanitize=sanitize, removeHs=removeHs, cleanupSubstructures=cleanupSubstructures)]

    if lowered.endswith('.sdf'):
        suppl = Chem.SDMolSupplier(file, sanitize=sanitize, removeHs=removeHs, strictParsing=strictParsing)
        return list(suppl)

    warn('Unknown file imput')
    # a bare sys.exit() would report success (status 0) to the calling shell
    sys.exit(1)


def compare_atoms(ground_truth_mol : 'RDKit mol object',comparison_mol : 'RDKit mol object') -> bool:
    '''
    checks that every mol in comparison_mol has the same element counts as
    ground_truth_mol

    comparison_mol is iterated, so it has to be an iterable of mols
    a warning naming the index of the first mismatching mol is emitted and
    False is returned, True is returned when no mol differs
    '''
    reference = Counter(atom.GetSymbol() for atom in ground_truth_mol.GetAtoms())

    for index, candidate in enumerate(comparison_mol):
        found = Counter(atom.GetSymbol() for atom in candidate.GetAtoms())
        # every stored count is positive, so comparing the Counters directly
        # is the same as comparing their sorted (symbol, count) pairs
        if found != reference:
            warn(f'Incorrect number of atoms in ligand { index }')
            return False

    return True
        

def get_bonds_dic(mol : 'RDKit mol object') -> 'python Counter object':
    '''
    counts the bonds of a mol grouped by a text label

    the label is the two atom symbols of the bond in alphabetical order
    followed by str() of the bond type, the returned Counter maps each
    label to the number of bonds carrying it
    '''

    def label(bond):
        # sorting the symbols makes C-O and O-C share one label
        #may be too restrictive
        first, second = sorted((bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()))
        return first + second + str(bond.GetBondType())

    return Counter(label(bond) for bond in mol.GetBonds())
        

def compare_bonds(ground_truth_mol : 'RDKit mol object',comparison_mol : 'RDKit mol object') -> bool:
    '''
    checks that every mol in comparison_mol has the same bond counts as
    ground_truth_mol, using the labels built by get_bonds_dic

    comparison_mol is iterated, so it has to be an iterable of mols
    the index of the first mismatching mol is printed and False is
    returned, True is returned when no mol differs
    '''
    expected = get_bonds_dic(ground_truth_mol)

    for index, candidate in enumerate(comparison_mol):
        # the Counters only hold positive counts, so direct comparison
        # matches a comparison of their sorted (label, count) pairs
        if get_bonds_dic(candidate) != expected:
            print(f'Incorrect number of bonds in ligand { index }')
            return False

    return True
    
def compare_smiles(ground_truth_mol : 'RDKit mol object',comparison_mol : 'RDKit mol object') -> bool:
    '''
    compares the canonical smiles string of ground_truth_mol with that of
    every mol in comparison_mol (an iterable of mols)

    the index of the first mol whose string differs is printed and False
    is returned, True is returned when all strings match
    '''

    def canonical(molecule):
        # write the mol out as smiles, then canonicalise that string
        return Chem.CanonSmiles(Chem.MolToSmiles(molecule))

    reference = canonical(ground_truth_mol)

    for index, candidate in enumerate(comparison_mol):
        if canonical(candidate) == reference:
            continue
        print(f'Incorrect smiles comparison in ligand { index }')
        return False

    return True

def compare_inchi(ground_truth_mol : 'RDKit mol object',comparison_mol : 'RDKit mol object') -> bool:
    '''
    compares the Inchi string of ground_truth_mol with that of every mol
    in comparison_mol (an iterable of mols)

    the index of the first mol whose string differs is printed and False
    is returned, True is returned when all strings match
    '''
    reference = Chem.MolToInchi(ground_truth_mol)

    for index, candidate in enumerate(comparison_mol):
        if Chem.MolToInchi(candidate) == reference:
            continue
        print(f'Incorrect inchi comparison in ligand { index }')
        return False

    return True

def maximum_common_substructure(ground_truth_mol : 'RDKit mol object',comparison_mol : 'RDKit mol object') -> bool:
    '''
    runs rdFMCS.FindMCS over the mols in comparison_mol together with
    ground_truth_mol

    comparison_mol has to be a list, ground_truth_mol is appended to it
    returns True when the number of atoms in the result equals the number
    of atoms in ground_truth_mol, otherwise False
    '''
    candidates = comparison_mol + [ground_truth_mol]
    common = rdFMCS.FindMCS(candidates)

    return common.numAtoms == ground_truth_mol.GetNumAtoms()

def convert(option : str) -> bool:
    '''
    turns the text of a command line option into a bool

    option : 'true' or 'false' in any mix of upper and lower case

    returns True or False accordingly
    any other text is reported on stdout and the program exits with
    status 1 (SystemExit)
    '''
    normalised = option.upper()

    if normalised == 'TRUE':
        return True
    if normalised == 'FALSE':
        return False

    print(f'Unknown option input { option }')
    # a bare sys.exit() would report success (status 0) for a bad option
    sys.exit(1)

def _option_value(argv, flag, default = True):
    '''
    returns the bool that follows flag in argv, or default when the flag
    is not present

    exits with status 1 when the flag is the last argument, so there is no
    value to read (this used to raise an IndexError)
    '''
    if flag not in argv:
        return default

    value_index = argv.index(flag) + 1
    if value_index >= len(argv):
        print(f'Missing True/False value for option { flag }')
        sys.exit(1)

    return convert(argv[value_index].upper())


def main(argv):
    '''
    command line entry point

    argv : argument list in the layout of sys.argv, i.e.
           <program> <smiles string> <input file (.mol2/.sdf)> <Options>

    prints one "accepted" / "not accepted" line per check (atoms, bonds,
    SMILES, Inchi, maximum common substructure) and returns the exit
    status: 0 once all checks have been run (whatever their outcome),
    1 when the arguments are missing or the smiles string cannot be parsed
    '''
    if len(argv) < 3:
        print('Usage <smiles string> <input file (.mol2/.sdf)> <Options>')
        print('Options: \n\t-s : Sanitise molecule on reading, default = True')
        print('\t-h : remove Hs upon reading molecule, default = True')
        print('\t-c : clean substructure upon reading molecule, default = True')
        print('\t-p : stict parsing for SDmolSupplier, default = True')
        # stop here: without this the checks below ran with smi / file
        # undefined and the script died with a NameError
        return 1

    smi = argv[1]
    file = argv[2]

    sanitize = _option_value(argv, '-s')
    removeHs = _option_value(argv, '-h')
    cleanupSubstructures = _option_value(argv, '-c')
    strictParsing = _option_value(argv, '-p')

    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # RDKit gives None for a smiles string it cannot parse, the checks
        # below would otherwise fail with an AttributeError
        print(f'Could not parse smiles string { smi }')
        return 1

    input_mol = get_ligand_input(file, sanitize = sanitize, removeHs = removeHs, cleanupSubstructures = cleanupSubstructures, strictParsing = strictParsing)

    # each check is run and reported in turn, in the order listed here
    checks = (
        ('Atoms', compare_atoms),
        ('Bonds', compare_bonds),
        ('SMILES', compare_smiles),
        ('Inchi', compare_inchi),
        ('Maximum common substructure', maximum_common_substructure),
    )
    for label, check in checks:
        if check(mol, input_mol):
            print(f'{ label } accepted')
        else:
            print(f'{ label } not accepted')

    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
    
        





