package DomainsPlotsManager;

use strict;
use warnings;
use Statistics::R;
use DBI;
#use Digest::MD5 qw(md5 md5_hex md5_base64);
#use DateTime;

use lib qw(Core);

use Database;
use TargetsManager;
#use PDBUtils;
use DomainsManager;

use Configuration;

# Prefix interpolated in front of every table name in this module's SQL
# (e.g. "$casp_version.results"); selects which CASP round's tables are read.
my $casp_version = 'casp13';

# Cached instance returned by new(); stays undef until the first call to new().
my $domains_plots_manager = undef;

# Constructor with singleton behaviour: the first call opens the database
# connection described by $CONFIG and caches the blessed object in the
# file-level $domains_plots_manager; every later call returns that same
# object (the $class argument of later calls is ignored).
# returns: the shared DomainsPlotsManager instance
sub new {
    my ($class) = @_;

    unless (defined $domains_plots_manager) {
        my $database = Database->new(
            $CONFIG->{HOSTNAME},
            $CONFIG->{PORT},
            $CONFIG->{DATABASE},
            $CONFIG->{USERNAME},
            $CONFIG->{PASSWORD},
        );

        # Target fields start empty; set_target_name()/set_target_id() fill them.
        my %state = (
            _database    => $database,
            _target_id   => undef,
            _target_name => undef,
        );

        $domains_plots_manager = bless \%state, $class;
    }

    return $domains_plots_manager;
}

# Select the target by name and look up its id through TargetsManager.
# The first assignment wins: once a target name is stored, later calls
# return immediately and leave both the name and the id untouched.
# arguments: (target_name)
sub set_target_name {
    my ($self, $target_name) = @_;

    return if defined $self->{_target_name};

    $self->{_target_name} = $target_name;
    $self->{_target_id}   = TargetsManager->new()->get_id_by_name($target_name);
}

# Select the target by id and look up its name through TargetsManager.
# The first assignment wins: once a target id is stored, later calls
# return immediately and leave both the id and the name untouched.
# arguments: (target_id)
sub set_target_id {
    my ($self, $target_id) = @_;

    return if defined $self->{_target_id};

    $self->{_target_id}   = $target_id;
    $self->{_target_name} = TargetsManager->new()->name($target_id);
}


# Build the scatter-plot data for the current target: one (x, y) point per
# prediction that has a result row for at least one of the three domains.
#   x - GDT_TS (r.gdt_ts_4) of the prediction on the combined domain;
#       stays 0 when no prediction at all has a combined-domain row
#   y - length-weighted mean of the prediction's GDT_TS on the separate domains
# If a prediction has no row for a domain whose length is known from another
# prediction, random_gdt_ts() of that length is used in place of the score.
# arguments: (d1_index, d2_index, d12_index, flag_overall)
#   d1_index  - index of domain1, d2_index - index of domain2
#   d12_index - index of the domain which is the union of domain1 and domain2
#   flag_overall - if 1, use every 'TS' prediction for the target;
#       otherwise (default) only model 1 of groups with 0 < g.type < 3
#       (NOTE(review): presumably the server groups - confirm against the groups table)
# returns: (\@xs, \@ys) - parallel array refs; both are empty when the query
#   yields nothing usable or when no separate-domain length is known
sub get_data {
    my ($self, $d1_index, $d2_index, $d12_index, $flag_overall) = @_;
    my @xs;
    my @ys;
    my %hash_gdt;    # prediction id => { domain index => gdt_ts_4 }
    my %hash_len;    # domain index => n2_4 of the first row seen for that domain

    # NOTE(review): the target name is interpolated into the SQL text; switch to
    # bound parameters if Database::query supports them.
    my $query = sprintf(
	"select pr.id, r.domain, r.n2_4, r.gdt_ts_4 from $casp_version.results r join $casp_version.predictions pr on r.predictions_id=pr.id join $casp_version.groups g on g.id=pr.groups_id where pr.target = \'%s\' and pr.pfrmat = 'TS' and g.type>0 and g.type<3 and pr.model=1  order by (pr.id, r.domain)",$self->{_target_name}
);
    if (defined($flag_overall) && $flag_overall == 1) {
	$query = sprintf(
        "SELECT pr.id, r.domain, r.n2_4, r.gdt_ts_4 FROM $casp_version.results r 
	   JOIN $casp_version.predictions pr ON r.predictions_id=pr.id 
	   JOIN $casp_version.groups g ON g.id=pr.groups_id 
	   WHERE pr.target = \'%s\' AND pr.pfrmat = 'TS' 
           ORDER BY (pr.id, r.domain)", $self->{_target_name}
);
    }

    my $sth = $self->{_database}->query($query);
    if (defined($sth) && ($sth->rows() > 0)) {
        while (my ($pr_id, $r_domain_index, $r_len, $r_gdt_ts) = $sth->fetchrow_array()) {
            # Only the three requested domains are of interest.
            next unless ($r_domain_index == $d1_index
                      || $r_domain_index == $d2_index
                      || $r_domain_index == $d12_index);

            $hash_gdt{$pr_id}{$r_domain_index} = $r_gdt_ts;
            $hash_len{$r_domain_index} = $r_len if (!exists $hash_len{$r_domain_index});
        }
    }

    # The weight denominator is the same for every prediction, so compute it once.
    my @domain_indexes = sort keys %hash_len;
    my $sum_len = 0;
    foreach my $domain_index (@domain_indexes) {
        next if ($domain_index == $d12_index);
        $sum_len += $hash_len{$domain_index};
    }

    # Without any separate-domain length the weighted mean is undefined
    # (it used to die with "Illegal division by zero"); report "no data".
    if ($sum_len == 0) {
        return \@xs, \@ys;
    }

    foreach my $pr_id (sort keys %hash_gdt) {
        my $x = 0;    # gdt_ts on the combined domain
        my $y = 0;    # sum of length-weighted gdt_ts's on the separate domains
        foreach my $domain_index (@domain_indexes) {
            my $gdt_ts = exists $hash_gdt{$pr_id}{$domain_index}
                       ? $hash_gdt{$pr_id}{$domain_index}
                       : $self->random_gdt_ts($hash_len{$domain_index});
            if ($domain_index == $d12_index) {
                $x = $gdt_ts;
            }
            else {
                $y += ($gdt_ts * $hash_len{$domain_index});
            }
        }
        $y = $y / $sum_len;
        push(@xs, $x);
        push(@ys, $y);
    }

    return \@xs, \@ys;
}

# Build the caption line listing the best GDT_TS found for each of the three
# domains of the current target.
# arguments: (d1_index, d2_index, d12_index, flag_overall)
#   flag_overall - if 1, consider every 'TS' prediction for the target;
#       otherwise only model 1 of groups with 0 < g.type < 3
# returns: string "domain(max GDT_TS): d1(max) d2(max) d12(max)", where each
#   max is formatted with "%-5.2f" or is the text "no results"
sub get_max_gdt_ts {
    my ($self, $d1_index, $d2_index, $d12_index, $flag_overall) = @_;

    # Queries run in the same order as before: domain1, domain2, combined.
    my ($max1, $max2, $max12) =
        map { $self->_max_gdt_ts_for_domain($_, $flag_overall) }
            ($d1_index, $d2_index, $d12_index);

    my $result = sprintf("domain(max GDT_TS): %d(%s) %d(%s) %d(%s)",$d1_index, $max1, $d2_index, $max2, $d12_index, $max12);
    return $result;
}

# Look up max(gdt_ts_4) for one domain of the current target and format it.
# arguments: (domain_index, flag_overall) - same meaning as in get_max_gdt_ts
# returns: the maximum formatted with "%-5.2f", or "no results"
sub _max_gdt_ts_for_domain {
    my ($self, $domain_index, $flag_overall) = @_;

    # NOTE(review): the target name is interpolated into the SQL text; switch to
    # bound parameters if Database::query supports them.
    my $query = sprintf(
        "select max(r.gdt_ts_4) from $casp_version.results r join $casp_version.predictions pr on r.predictions_id=pr.id join $casp_version.groups g on g.id=pr.groups_id where pr.target = \'%s\' and pr.pfrmat = 'TS' and g.type>0 and g.type<3 and pr.model=1 and r.domain=%d",$self->{_target_name}, $domain_index
);
    if (defined($flag_overall) && $flag_overall == 1) {
	$query = sprintf(
        "SELECT max(r.gdt_ts_4) FROM $casp_version.results r 
	  JOIN $casp_version.predictions pr ON r.predictions_id=pr.id 
	  JOIN $casp_version.groups g ON g.id=pr.groups_id 
	  WHERE pr.target = \'%s\' AND pr.pfrmat = 'TS' AND r.domain=%d", $self->{_target_name}, $domain_index
	);
    }

    my $sth = $self->{_database}->query($query);
    return "no results" unless (defined($sth) && ($sth->rows() > 0));

    # An aggregate over zero matching rows still yields one row holding NULL,
    # so the fetched value itself has to be checked; formatting undef would
    # print a misleading "0.00".
    my ($max) = $sth->fetchrow_array();
    return "no results" unless defined($max);

    return sprintf("%-5.2f", $max);
}

# Stand-in GDT_TS for a domain of the given length. It is used when LGA
# results are missing for a particular group on a particular domain while
# results for the other domains are present.
# arguments: (len) - domain length
# returns: 102.8 * exp(-0.089 * len^0.729) + 11.3
sub random_gdt_ts {
    my ($self, $len) = @_;

    my $scaled_len = $len ** 0.729;
    my $decay      = exp(-0.089 * $scaled_len);

    return 102.8 * $decay + 11.3;
}

# Draw the "separated domains vs combined domain" scatter plot in R and save
# it as a PNG file.
# arguments: (png_file, d1_index, d2_index, d12_index, flag_overall)
#   png_file - path of the image to create
#   d1_index, d2_index, d12_index, flag_overall - passed on to get_data()
#       and get_max_gdt_ts(); flag_overall == 1 also switches the colour of
#       the fitted line from green to blue
# returns: 1 if png_file already exists or the plot commands were run,
#          0 if get_data() produced no points
sub draw_plot {
    my ($self, $png_file, $d1_index, $d2_index, $d12_index, $flag_overall) = @_;

    # An existing file is reused as is; it is never redrawn.
    if(-f $png_file){
        return 1;
    }

    my ($refX,$refY) = $self->get_data($d1_index, $d2_index, $d12_index, $flag_overall);
    if(scalar(@{$refX})==0 || scalar(@{$refY})==0){
        return 0;
    }
    my $max_gdt_string = $self->get_max_gdt_ts($d1_index, $d2_index, $d12_index, $flag_overall);
    my $range_string = $self->toStringRange($d1_index, $d2_index);

    # draw plot in R
    my $R = Statistics::R->new();
    $R->set('x',$refX);
    $R->set('y',$refY);

    my $command1 = sprintf("png(filename=\"%s\", width=480, height=480, bg=\"white\",units=\"px\",  pointsize=12)", $png_file);
    my $command2 = sprintf("plot(x, y, pch=20, xlim=c(0,100), ylim=c(0,100), main='%s: (%d and %d vs %d)', xlab='', ylab='separated domains: weighted sum of GDT_TS')",$self->{_target_name}, $d1_index, $d2_index, $d12_index);
    my $command3 = sprintf("par(mar=c(6,4,4,2))\nmtext(\"%s\", side=1, line=3, cex=1, outer=FALSE)\nmtext(\"%s\", side=1, line=4, cex=1, outer=FALSE)\nmtext(\"%s\", side=1, line=5, cex=1, outer=FALSE)","combined domain: GDT_TS",$max_gdt_string,$range_string);
#    my $command4 = sprintf("par(oma=c(1,0,0,0))\nmtext(\"%s\", side=1, line=0, cex=1, outer=TRUE)",$range_string);
    my $command5 = sprintf("abline(fit, col=\"%s\")", ( (defined($flag_overall)&&($flag_overall==1)) ? 'blue' : 'green') );
    $R->run($command1,
        $command2,
        $command3,
        q`abline(v=(seq(0,100,5)), col="lightgray", lty="dotted")`,
        q`abline(h=(seq(0,100,5)), col="lightgray", lty="dotted")`,
        q`lines(c(-5,105),c(-5,105), col='red')`,
        q`fit <- lm(y ~ 0 + x)`,    # regression through the origin
        $command5
    );
    $R->run(q`dev.off()`);
    $R->stop();

    # Hand the file to group "users" with mode 664. Perl builtins are used
    # instead of shelling out, so a path containing spaces or shell
    # metacharacters can neither break nor inject a command. Both steps are
    # best-effort: a failure is reported but does not change the return value.
    my $gid = getgrnam('users');
    if (defined $gid) {
        chown(-1, $gid, $png_file)
            or warn "cannot change group of $png_file to users: $!\n";
    }
    else {
        warn "group 'users' not found; group of $png_file left unchanged\n";
    }
    chmod(0664, $png_file)
        or warn "cannot chmod 664 $png_file: $!\n";

    return 1;
}

# Build the caption line listing the residue range of the two separate domains.
# For each index the RANGE field of the first record returned by
# DomainsManager::get_domains() for the current target id is used.
# arguments: (d1_index, d2_index)
# returns: string "domain(range): d1(range) d2(range) " (with the trailing
#   space); the parentheses are left empty when no record or no RANGE is found
sub toStringRange {
    my ($self, $d1_index, $d2_index) = @_;
    my $result = "domain(range): ";
    my $domainsManager = DomainsManager->new();

    foreach my $domain_index ($d1_index, $d2_index) {
        my @domains = $domainsManager->get_domains($self->{_target_id}, $domain_index);

        # Guard the lookup: dereferencing a missing first record would
        # autovivify it and trigger an "uninitialized value" warning.
        my $range = '';
        if (@domains && defined($domains[0]) && defined($domains[0]->{RANGE})) {
            $range = $domains[0]->{RANGE};
        }

        $result .= sprintf("%s(%s) ", $domain_index, $range);
    }

    return $result;
}


1;
