#!/opt/miniconda3/bin/python3
from os import error
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio import SeqIO
from Bio.PDB.Polypeptide import protein_letters_3to1
#from mmcif.io.PdbxReader import PdbxReader
#from mmcif.io.IoAdapterCore import IoAdapterCore

# Records every CASP prediction file must contain, and the message returned
# by check_mandatory_entries when none of them is missing.
mandatory_entries = ("_casp.author", "_casp.target", "_casp.pfrmat")
ok_mesg = "OK: all mandatory _casp_* records are present."


def check_mandatory_entries(mmcif_dict: dict, mandatory_entries=mandatory_entries, ok_mesg=ok_mesg):
  """Verify that all mandatory CASP records are keys of *mmcif_dict*.

  Args:
    mmcif_dict: the extended mmCIF file as a flat dict, as produced by
      Bio.PDB.MMCIF2Dict.
    mandatory_entries: record keys that must be present.
    ok_mesg: message returned when nothing is missing.

  Returns:
    ok_mesg if every entry is present; otherwise an "# ERROR! ..." string
    listing the missing records, in the order of mandatory_entries.
  """
  absent = [key for key in mandatory_entries if key not in mmcif_dict]
  if not absent:
    return ok_mesg
  return "# ERROR! The following mandatory records are missing: " + ", ".join(absent)

# Read FASTA records into a dict keyed by record id.
def read_fasta_dict(fasta):
  """Map each record id in *fasta* to that record's sequence.

  fasta is an iterable of FASTA records, e.g. SeqIO.parse(path, "fasta").
  If an id occurs more than once, the last record with that id wins.
  """
  return {rec.id: rec.seq for rec in fasta}

# Check whether the residues of each chain in a structure match the FASTA sequences.
def check_sequence_correspondence(structure, fasta_dict, *, letters_3to1=None):
  """Determine, for every (model, chain), which FASTA sequences it is consistent with.

  A chain is consistent with FASTA sequence k only when EVERY residue of the
  chain satisfies 1 <= resnum <= len(seq) and its one-letter code equals
  seq[resnum - 1]; i.e. residue numbers must be positions in the sequence.
  Residues are not filtered by record type, so HETATM residues (ligands,
  waters) are compared like any other and fail unless letters_3to1 maps them.

  Args:
    structure: parsed structure (e.g. MMCIFParser().get_structure(...)),
      iterated as models -> chains -> residues.
    fasta_dict: mapping FASTA record id -> sequence (see read_fasta_dict).
    letters_3to1: optional mapping three-letter residue name -> one-letter
      code. Defaults to Bio.PDB.Polypeptide.protein_letters_3to1. Residue
      names missing from the mapping never match.

  Returns:
    dict mapping (model serial_num, chain id, fasta id) -> bool. A chain with
    no residues contributes no entries.
  """
  if letters_3to1 is None:
    letters_3to1 = protein_letters_3to1
  matches = {}
  # Iterate models directly rather than structure[i] for i in range(len(...)),
  # which only works when model ids happen to be 0..n-1.
  for model in structure:
    model_num = model.serial_num  # model number from the file
    for chain in model:
      chain_id = chain.get_id()
      for residue in chain:
        resnum = residue.get_id()[1]
        letter = letters_3to1.get(residue.get_resname())
        for seq_id, seq in fasta_dict.items():
          # resnum >= 1 is required: 0 or a negative number would otherwise
          # index from the end of the sequence and could match by accident.
          ok = (letter is not None
                and 1 <= resnum <= len(seq)
                and seq[resnum - 1] == letter)
          key = (model_num, chain_id, seq_id)
          # One mismatching residue disqualifies the chain for this sequence;
          # earlier results must not be overwritten by later residues.
          matches[key] = matches.get(key, True) and ok
  return matches
def check_if_all_matches(matches, error_list=None):
  """Report whether every (model, chain) corresponds to at least one FASTA sequence.

  Args:
    matches: dict mapping (model, chain_id, fasta_id) -> bool, as returned by
      check_sequence_correspondence.
    error_list: optional list. Each (model, chain_id) that matched no FASTA
      sequence is appended to it, in first-seen order, so the caller can
      report them. Existing contents are kept.

  Returns:
    True if every (model, chain) matched some sequence (also True for an
    empty matches dict), False otherwise.
  """
  if error_list is None:
    error_list = []
  # (model, chain_id) -> True once any FASTA sequence matched that chain.
  chain_ok = {}
  for key, matched in matches.items():
    chain_key = key[:2]
    chain_ok[chain_key] = chain_ok.get(chain_key, False) or bool(matched)
  unmatched = [chain_key for chain_key, ok in chain_ok.items() if not ok]
  error_list.extend(unmatched)
  return not unmatched
#Main
def main():
  """Verify a CASP prediction in extended mmCIF format against the target FASTA.

  Usage (command line):
    check_mmcif.py target_fasta.seq.txt your_casp_prediction_file.cif

  Checks, in order:
    1. all mandatory _casp.* records are present (check_mandatory_entries);
    2. the FASTA file yields at least one sequence and the structure has
       residues;
    3. every (model, chain) corresponds to at least one FASTA sequence, with
       residue numbers being positions in that sequence
       (check_sequence_correspondence / check_if_all_matches).
  Results are printed to stdout.

  Returns:
    0 on success; -1 for wrong arguments, missing mandatory records, a FASTA
    file with no sequences, or a structure with no residues; -2 when some
    (model, chain) corresponds to no FASTA sequence.
  """
  import sys
  if len(sys.argv) != 3:
    print("# ERROR! Missing arguments. Please provide two arguments: fasta file and mmCIF file.")
    return -1
  fasta_file, mmcif_file = sys.argv[1], sys.argv[2]

  # 1. Mandatory records, checked on the flat key/value view of the file so a
  #    file missing them is rejected before the structure is built.
  mmcif_dict = MMCIF2Dict(mmcif_file)
  result = check_mandatory_entries(mmcif_dict)
  if result != ok_mesg:
    print(result)
    return -1

  # 2. Load the structure and the target sequences.
  structure = MMCIFParser().get_structure("models", mmcif_file)
  fasta_dict = read_fasta_dict(SeqIO.parse(fasta_file, "fasta"))
  if not fasta_dict:
    # With no sequences there is nothing to compare against, and the checks
    # below would pass vacuously.
    print(f"# ERROR! No sequences could be read from the fasta file {fasta_file}.")
    return -1

  # 3. Sequence from the ATOM records must correspond to a FASTA sequence.
  matches = check_sequence_correspondence(structure, fasta_dict)
  if not matches:
    # No residues at all: an empty result would otherwise count as success.
    print(f"# ERROR! No residues were found in the mmCIF file {mmcif_file}.")
    return -1
  error_list = []
  if not check_if_all_matches(matches, error_list):
    print(f"# ERROR! The folowing model and chain {error_list} doesn't correspond to "+
          f"any fasta sequence for the target. Check if the numbering of residues corresponds to their position in the sequence template.")
    return -2
  print("OK: verification suceeded")
  return 0


if __name__ == "__main__":
  # Exit with main()'s status so shells and pipelines can detect a failed
  # check; importing this module no longer runs the command-line check.
  raise SystemExit(main())

