import prody as pd
from pathlib import Path
from rdkit import Chem


def _read_target_smiles(target_smiles_file):
    """Load the PDB-file-name -> SMILES table used for non-PDB ligands.

    The file is tab-separated with two columns (file name, SMILES). The first
    line is always skipped as a header, and blank lines are ignored.

    Args:
        target_smiles_file: Path of the tab-separated table.

    Returns:
        dict mapping the file name column to the stripped SMILES column.

    Raises:
        ValueError: If a non-blank data line does not consist of exactly two
            tab-separated fields.
    """
    target_smiles = {}
    with open(target_smiles_file, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            if line_number == 1:
                # header row
                continue
            record = line.strip()
            if not record:
                # tolerate blank lines, e.g. a trailing newline at end of file
                continue
            fields = record.split("\t")
            if len(fields) != 2:
                raise ValueError(
                    f"{target_smiles_file}, line {line_number}: expected "
                    f"2 tab-separated fields, got {len(fields)}"
                )
            filename, smiles = fields
            target_smiles[filename] = smiles.strip()
    return target_smiles


def extract_ligands(pdb_folder, output_folder, target_smiles_file):
    """Write one ligand table per PDB file found in a folder.

    For every ``*.pdb`` file in ``pdb_folder`` each non-water hetero residue
    is listed together with a canonical SMILES string. Residues named ``LIG``
    take their SMILES from ``target_smiles_file`` (looked up by the PDB file
    name); every other residue name is resolved through
    ``prody.parsePDBLigand`` and cached, so each chemical component is looked
    up at most once per call.

    The result is written to ``<output_folder>/<pdb stem>_ligands.txt`` as a
    tab-separated table with the columns ``ID``, ``Name`` and ``SMILES``, rows
    sorted by (name, SMILES). A structure without non-water hetero residues
    yields a table containing only the header.

    Args:
        pdb_folder: Folder that is searched (non-recursively) for ``*.pdb``.
        output_folder: Folder receiving the tables; created if missing.
        target_smiles_file: Tab-separated file (header line, then
            ``file name<TAB>SMILES``) for ligands named ``LIG``.

    Raises:
        ValueError: If the SMILES table is malformed or a PDB file yields no
            structure.
        KeyError: If a PDB file contains a ``LIG`` residue but has no entry in
            the SMILES table.
    """
    target_smiles_dict = _read_target_smiles(target_smiles_file)

    output_dir = Path(output_folder)
    # Create the destination up front so writing does not fail on a fresh path.
    output_dir.mkdir(parents=True, exist_ok=True)

    # chemical component identifier -> canonical SMILES, shared across files
    pdb_smiles_dict = {}
    # sorted() only makes the processing order reproducible between runs
    for filename in sorted(Path(pdb_folder).glob("*.pdb")):
        # Pass a plain string path rather than a Path object to ProDy.
        pdb = pd.parsePDB(str(filename))
        # NOTE(review): parsePDB appears to return None when no atoms could be
        # parsed -- confirm against the ProDy version in use.
        if pdb is None:
            raise ValueError(f"Could not parse a structure from {filename}")

        # ProDy selections evaluate to None when nothing matches, which is the
        # case for structures without any non-water hetero residue.
        hetero = pdb.select("hetero and not water")
        resname_by_resindex = {}
        if hetero is not None:
            # One entry per residue; the first atom seen decides the name.
            for resindex, resname in zip(hetero.getResindices(), hetero.getResnames()):
                resname_by_resindex.setdefault(int(resindex), str(resname))

        ligands = []
        # cci: chemical component identifier (the residue name)
        for cci in resname_by_resindex.values():
            if cci == "LIG":
                # not a PDB ligand, get SMILES from target SMILES file
                try:
                    target_smiles = target_smiles_dict[filename.name]
                except KeyError:
                    raise KeyError(
                        f"No SMILES listed for {filename.name} in {target_smiles_file}"
                    ) from None
                smiles = Chem.CanonSmiles(target_smiles)
            else:
                if cci not in pdb_smiles_dict:
                    # get SMILES from PDB
                    pdb_smiles_dict[cci] = Chem.CanonSmiles(pd.parsePDBLigand(cci).getCanonicalSMILES())
                smiles = pdb_smiles_dict[cci]
            ligands.append((cci, smiles))

        with open(output_dir / f"{filename.stem}_ligands.txt", "w") as f:
            f.write("ID\tName\tSMILES\n")
            for i, (resname, smiles) in enumerate(sorted(ligands), start=1):
                f.write(f"{i:03}\t{resname}\t{smiles}\n")


if __name__ == "__main__":
    # Command-line entry point: three positional arguments, forwarded
    # unchanged to extract_ligands().
    import argparse

    cli = argparse.ArgumentParser(description="Extract ligands from PDB files.")
    positional_arguments = (
        ("pdb_folder", "Folder with PDB files."),
        ("output_folder", "Folder to save extracted ligands."),
        ("target_smiles_file", "File with SMILES for non-PDB ligands."),
    )
    for argument_name, argument_help in positional_arguments:
        cli.add_argument(argument_name, type=str, help=argument_help)

    options = cli.parse_args()
    extract_ligands(
        options.pdb_folder,
        options.output_folder,
        options.target_smiles_file,
    )
