package ZscoresCryoemManager;

use strict;
use warnings;

use DBI;

use lib qw(Core);

use Database;

#use PDBUtils;

use Configuration;
use LocalConfiguration;
#use TargetsManager;
#use GroupsManager;
#use Statistics;

# Cached instance: new() stores the object it builds here and returns the same
# object on every later call.
my $zscores_manager = undef;
# Threshold shared by the scoring code: calcZScore() never stores a z-score
# below it (and stores it when no z-score can be computed),
# calcMeanStdWithoutOutliers() drops values lying more than this many standard
# deviations below the mean, and get_rows() adds it once per missing target.
my $OUTLIER_CUTOFF = -2.0;
# Names of the score fields that calcZScore() converts and get_rows() weights.
my @SCORES = qw/ccc cc_mask cc_peak smoc/;

# Constructor with singleton behaviour: the first call opens the database
# connection described by $CONFIG and caches the resulting object; every
# later call returns that cached object unchanged.
sub new {
    my ($class) = @_;

    unless (defined($zscores_manager)) {
        $zscores_manager = bless {
            _database => Database->new($CONFIG->{HOSTNAME}, $CONFIG->{PORT}, $CONFIG->{DATABASE}, $CONFIG->{USERNAME}, $CONFIG->{PASSWORD}),
        }, $class;
    }

    return $zscores_manager;
}

#######################################
#
# Methods to upload data to database
#
########################################

# Return a blank record as a key/value list (assign it to a hash): every
# known column name is present and maps to the empty string.
sub get_new_record{
    my ($self) = @_;
    my %record = map { $_ => '' } qw(id target gr_code ccc cc_mask cc_peak smoc);
    return %record;
}

# Tell whether $colname is one of the column names this module writes.
# The comparison is exact (case-sensitive). Returns 1 for a known column
# and 0 for anything else.
sub is_valid_columname {
    my ($self, $colname) = @_;
    my %is_known = map { $_ => 1 } qw(id gr_code target ccc cc_mask cc_peak smoc);
    return (exists $is_known{$colname}) ? 1 : 0;
}

# Fetch the best raw scores of every group for one target.
#
# $param is a hash reference:
#   target     - required; only rows of this target are read. When it is
#                missing no query is run and undef is returned.
#   model_type - optional; when it equals "first" only rows whose
#                model_name contains the literal text "_1" are used.
#
# Returns a reference to an array of hashes with the keys id (always ''),
# gr_code, target, ccc, cc_mask, cc_peak and smoc, where each score is the
# MAX() over the selected rows of one (code, target) pair. Returns undef
# when the target is missing, the query fails, or no row matches.
sub getRawScores{
    my ($self, $param) = @_;

    return undef unless defined($param->{target});

    # The value is spliced into the SQL text, so double every single quote to
    # keep it inside the string literal.
    # NOTE(review): bound parameters would be preferable if Database::query
    # supports them -- not visible from this file.
    (my $target_literal = $param->{target}) =~ s/'/''/g;

    my $where = " WHERE 1=1 AND re.code >= 0 ";
    $where .= sprintf(" AND target='%s' ", $target_literal);

    if (defined($param->{model_type}) && $param->{model_type} eq "first") {
        # "_" is a one-character wildcard in LIKE / SIMILAR TO patterns, so the
        # former pattern '%_1%' accepted every name with a "1" after its first
        # character. Escape it so that only a literal "_1" matches.
        $where .= q{ AND re.model_name LIKE '%!_1%' ESCAPE '!' };
    }
    $where .= "  GROUP BY re.code, re.target";

    my $query = "
        SELECT
        re.code, re.target,
        max(re.ccc) as ccc, max(re.cc_mask) as cc_mask, max(re.cc_peak) as cc_peak,
        max(smoc) as smoc

        FROM casp13.multimer_results_vs_emm re
" . $where;

    my $sth = $self->{_database}->query($query);
    return undef unless defined($sth) && ($sth->rows() > 0);

    my @data;
    while (my ($gr_code, $target, $ccc, $cc_mask, $cc_peak, $smoc) = $sth->fetchrow_array()) {
        push @data, {
            id      => '',
            gr_code => $gr_code,
            target  => $target,
            ccc     => $ccc,
            cc_mask => $cc_mask,
            cc_peak => $cc_peak,
            smoc    => $smoc,
        };
    }
    return \@data;
}

# Replace, in place, every raw score of every record by its z-score.
#
# $refData is a reference to an array of hashes keyed by the names in
# @SCORES. For each score the mean and standard deviation are taken from
# calcMeanStdWithoutOutliers() over all defined raw values. A record gets
#   (raw - mean) / std   when its raw value is defined and >= 0 and a
#                        standard deviation is available (0.0 if std is 0),
# never lower than $OUTLIER_CUTOFF; in every other case it gets
# $OUTLIER_CUTOFF itself. Returns the same reference it was given.
sub calcZScore{
    my ($self, $refData) = @_;

    foreach my $score (@SCORES) {
        # Gather the defined raw values used for the statistics.
        my @raw_values;
        foreach my $entry (@{$refData}) {
            next unless defined($entry->{$score});
            push @raw_values, $entry->{$score};
        }

        my ($average, $deviation) = $self->calcMeanStdWithoutOutliers(\@raw_values);

        foreach my $entry (@{$refData}) {
            my $raw = $entry->{$score};

            # Start from the fallback and overwrite it only when a real
            # z-score can be computed.
            my $zscore = $OUTLIER_CUTOFF;
            if (defined($raw) && $raw >= 0.0 && defined($deviation)) {
                $zscore = ($deviation != 0) ? ($raw - $average) / $deviation : 0.0;
                $zscore = $OUTLIER_CUTOFF if $zscore < $OUTLIER_CUTOFF;
            }
            $entry->{$score} = $zscore;
        }
    }

    return $refData;
}


# Store every record of $refData through upload().
#
# $param->{model_type} selects the table: "first" writes to
# zscores_cryoem_1m, anything else (including a missing value) writes to
# zscores_cryoem_bm. $refData is a reference to an array of record hashes;
# when it is undef (getRawScores() returns undef when nothing was found)
# there is nothing to store and the call is a no-op.
sub upload_all{
    my ($self, $param, $refData) = @_;

    return unless defined($refData);

    my $is_first = defined($param->{model_type}) && $param->{model_type} eq 'first';
    my $table = $is_first ? 'zscores_cryoem_1m' : 'zscores_cryoem_bm';

    foreach my $record (@{$refData}) {
        $self->upload($record, $table);
    }
    return;
}

# Insert or update a single record in $table: when get_id_record() finds
# an existing row its id is written into $record->{id} and the row is
# updated, otherwise the record is added. Returns whatever update() or
# add() returned.
sub upload{
    my ($self, $record, $table) = @_;

    my $existing_id = $self->get_id_record($record, $table);
    return $self->add($record, $table) unless $existing_id;

    $record->{id} = $existing_id;
    return $self->update($record, $table);
}

# Look up the row of $table that belongs to the record's gr_code and target.
#
# Returns the id of the first matching row, or 0 when the query fails, no
# row matches, or the stored id is NULL, empty or 0.
sub get_id_record{
    my ($self, $record, $table) = @_;

    # The target is spliced into the SQL text, so double every single quote
    # to keep it inside the string literal.
    my $target = defined($record->{target}) ? $record->{target} : '';
    $target =~ s/'/''/g;

    my $query = sprintf("SELECT id FROM casp13.%s WHERE gr_code=%d AND target='%s'", $table, $record->{gr_code}, $target);
    my $sth = $self->{_database}->query($query);
    return 0 unless defined($sth) && ($sth->rows() > 0);

    # Only the first row matters; the former loop returned on its first pass.
    my ($id) = $sth->fetchrow_array();
    return 0 if !defined($id) || $id eq '' || $id == 0;
    return $id;
}

# Insert one record into casp13.$table.
#
# Only keys accepted by is_valid_columname() whose value is defined and not
# the empty string become columns; the id key is never written. Columns are
# emitted in sorted key order so the generated statement is reproducible.
#
# Returns the first column of the row produced by RETURNING id, or 0 when
# the record has no usable column (no statement is sent) or the query fails.
sub add{
    my ($self, $record, $table) = @_;

    my (@columns, @values);
    foreach my $key (sort keys %{$record}) {
        next if $key eq 'id';
        my $value = $record->{$key};
        next unless $self->is_valid_columname($key) && defined($value) && $value ne '';

        # Values are spliced into the SQL text: double single quotes so a
        # value can never terminate its own string literal.
        (my $escaped = $value) =~ s/'/''/g;
        push @columns, $key;
        push @values, "'$escaped'";
    }

    # "INSERT ... ( ) VALUES ( )" is not valid SQL, so skip empty records.
    return 0 unless @columns;

    my $query = sprintf("INSERT INTO casp13.%s ( %s ) VALUES ( %s ) RETURNING id",
        $table, join(', ', @columns), join(', ', @values));

    my $result = 0;
    my $sth = $self->{_database}->query($query);
    if (defined($sth)) {
        ($result) = $sth->fetchrow_array();
    }
    return $result;
}

# Update the row of casp13.$table whose id equals $record->{id}.
#
# Only keys accepted by is_valid_columname() whose value is defined and not
# the empty string are written; the id key itself is never changed.
# Assignments are emitted in sorted key order so the generated statement is
# reproducible.
#
# Returns 1 when the statement was executed, 0 when the record has nothing
# to write (no statement is sent) or the query fails.
sub update{
    my ($self, $record, $table) = @_;

    my @assignments;
    foreach my $key (sort keys %{$record}) {
        next if $key eq 'id';
        my $value = $record->{$key};
        next unless $self->is_valid_columname($key) && defined($value) && $value ne '';

        # Values are spliced into the SQL text: double single quotes so a
        # value can never terminate its own string literal.
        (my $escaped = $value) =~ s/'/''/g;
        push @assignments, "$key = '$escaped'";
    }

    # "UPDATE ... SET  WHERE ..." is not valid SQL, so skip empty records.
    return 0 unless @assignments;

    my $query = sprintf("UPDATE casp13.%s SET %s WHERE id=%d ",
        $table, join(', ', @assignments), $record->{id});

    my $sth = $self->{_database}->query($query);
    return defined($sth) ? 1 : 0;
}

# Arithmetic mean of the numbers in the array referenced by $refArr.
# Returns undef for an empty array.
sub mean{
    my ($self, $refArr) = @_;
    my @values = @{$refArr};
    my $count = scalar(@values);
    return undef if $count == 0;

    my $total = 0.0;
    $total += $_ for @values;
    return $total / $count;
}

# Sample standard deviation (divisor n - 1) of the numbers in the array
# referenced by $refArr, around the value returned by mean().
# Returns undef for an empty array and 0.0 for a single element.
sub std{
    my ($self, $refArr) = @_;
    my @values = @{$refArr};
    my $count = scalar(@values);
    return undef if $count == 0;
    return 0.0 if $count == 1;

    my $center = $self->mean($refArr);
    my $squares = 0.0;
    foreach my $value (@values) {
        my $delta = $value - $center;
        $squares += $delta * $delta;
    }
    return sqrt($squares / ($count - 1));
}

# Mean and standard deviation of the array referenced by $refArr after
# removing low outliers: values lying more than |$OUTLIER_CUTOFF| standard
# deviations below the mean of the full array are dropped, then both
# statistics are recomputed on what is left. High values are never removed.
# Returns the list (mean, std); both are undef when nothing is left.
sub calcMeanStdWithoutOutliers{
    my ($self, $refArr) = @_;
    my @values = @{$refArr};

    my $center = $self->mean($refArr);
    my $spread = $self->std($refArr);
    my @kept = grep { $_ - $center >= $OUTLIER_CUTOFF * $spread } @values;

    my $trimmed_mean = $self->mean(\@kept);
    my $trimmed_std = $self->std(\@kept);
    return ($trimmed_mean, $trimmed_std);
}

##################################################
#
# methods for retrieving data from database
#
###################################################

# Build the per-group summary of weighted z-scores.
#
# $param is a hash reference:
#   model_type - "first" reads zscores_cryoem_1m, anything else (including a
#                missing value) reads zscores_cryoem_bm.
#   weights    - required hash reference holding one weight per entry of
#                @SCORES under the key "w_<score>".
#
# Returns a list of hashes, one per group, with the keys GR_CODE (zero
# padded to three digits), EVAL_CAPRI, GR_NAME, COUNT_TARGETS, SUM_MINUS2,
# SUM_0, AVG_MINUS2 and AVG_0, after passing them through calcFinalRanks().
# SUM_MINUS2 additionally receives $OUTLIER_CUTOFF once for every target the
# group has no row for. Returns an empty list when the main query fails or
# yields no rows.
sub get_rows{
    my ($self, $param) = @_;

    my $is_first = defined($param->{model_type}) && $param->{model_type} eq 'first';
    my $table = $is_first ? 'zscores_cryoem_1m' : 'zscores_cryoem_bm';

    my $subque = ' WHERE 1=1 ';
    # groups_only
    # targets classes (xray, nmr, cryoem)

    # query to get the number of distinct targets present in the table
    my $query = "SELECT COUNT(DISTINCT z.target) FROM casp13.$table z
        JOIN casp13.targets t ON substring(z.target from 1 for 5)::text=t.name::text
        $subque ";
    my $sth = $self->{_database}->query($query);
    my $no_targets;
    if (defined($sth) && ($sth->rows() > 0)){
        ($no_targets) = $sth->fetchrow_array();
    }

    # main query: weighted sums over all scores, once with the raw z-scores
    # and once with negative z-scores replaced by 0
    my %weights = %{$param->{weights}};

    my (@minus2_terms, @zero_terms);
    foreach my $score (@SCORES){
        my $weight = $weights{"w_$score"};
        push @minus2_terms, sprintf("%.3f*z.%s", $weight, $score);
        push @zero_terms, sprintf("%.3f*(CASE WHEN z.%s > 0 THEN z.%s ELSE 0.0 END)", $weight, $score, $score);
    }
    my $sum_minus2_query = join('+', @minus2_terms);
    my $sum_0_query = join('+', @zero_terms);
    my $avrg_minus2_query = $sum_minus2_query;
    my $avrg_0_query = $sum_0_query;

    $query = "SELECT gr.name, gr.eval_capri, gr_code, count(z.*) as count_targets,
        SUM($sum_minus2_query) as sum_minus2,
        SUM($sum_0_query) as sum_0,
        AVG($avrg_minus2_query) as avg_minus2,
        AVG($avrg_0_query) as avg_0
        FROM casp13.$table z
        JOIN CASP13.groups gr ON gr.code=z.gr_code
        $subque
        GROUP BY gr_code, gr.name, gr.eval_capri  ORDER BY sum_minus2 ";

    $sth = $self->{_database}->query($query);
    my @data;
    if (defined($sth) && ($sth->rows() > 0)){
        while(my ($gr_name, $eval_capri, $gr_code, $count_targets, $sum_minus2, $sum_0, $avg_minus2, $avg_0) = $sth->fetchrow_array()){
            # Number of targets this group has no row for. When the target
            # count is unknown, or smaller than the group's own row count,
            # there is nothing to penalise: without this guard the
            # "penalty" silently turned into a bonus.
            my $missing_targets = defined($no_targets) ? $no_targets - $count_targets : 0;
            $missing_targets = 0 if $missing_targets < 0;

            push @data, {
                GR_CODE       => sprintf("%03d", $gr_code),
                EVAL_CAPRI    => $eval_capri,
                GR_NAME       => $gr_name,
                COUNT_TARGETS => $count_targets,
                SUM_MINUS2    => $sum_minus2 + $missing_targets * $OUTLIER_CUTOFF,
                SUM_0         => $sum_0,
                AVG_MINUS2    => $avg_minus2,
                AVG_0         => $avg_0
            };
        }
        return $self->calcFinalRanks(\@data);
    }
    return @data;
}


# Add ranking information to the rows referenced by $refData.
#
# For each of SUM_MINUS2, AVG_MINUS2, AVG_0 and SUM_0 (in that order) the
# rows are sorted by that key, highest first, and every row receives
# "RANK_<key>": equal values share a rank and the next different value gets
# its 1-based position in the sorted list (1, 1, 3, ...). Finally every row
# gets INDEX, its 1-based position in the order left by the last sort
# (SUM_0). The row hashes are modified in place; the sorted list is returned.
sub calcFinalRanks{
    my ($self, $refData) = @_;
    my @rows = @{$refData};

    foreach my $key (qw(SUM_MINUS2 AVG_MINUS2 AVG_0 SUM_0)) {
        @rows = sort { $b->{$key} <=> $a->{$key} } @rows;

        my ($rank, $last_seen);
        foreach my $pos (0 .. $#rows) {
            my $value = $rows[$pos]->{$key};
            # A new rank starts at the top row and whenever the value changes.
            if ($pos == 0 || $last_seen != $value) {
                $rank = $pos + 1;
                $last_seen = $value;
            }
            $rows[$pos]->{"RANK_$key"} = $rank;
        }
    }

    my $position = 0;
    foreach my $row (@rows) {
        $row->{"INDEX"} = ++$position;
    }
    return @rows;
}

1;
