#Standard python imports
import numpy as np, os, traceback, warnings
from scipy.special import logsumexp

#LVC imports
from pesummary.gw.file.read import read as GWread

def extract_logB(event, approx, time, runtag, chain_str, directory, PESummary_format=True):
    """Return the natural-log Bayes factor (signal vs noise) of a single ringdown run.

    Parameters
    ----------
    event : str
        Event name. When ``PESummary_format`` is True it must be a key of the
        module-level ``Events_dict`` (mapped to the superevent ID used in file names).
    approx : str
        Waveform approximant label, e.g. ``'Kerr_221'``.
    time : int
        Ringdown start time in units of M. Formatted with ``{:d}`` in the
        Evidence.txt path, so it must be an integer there.
    runtag, chain_str : str
        Run label and posterior suffix; only used when ``PESummary_format`` is False.
    directory : str
        Root results directory.
    PESummary_format : bool
        If True, read ``<directory>/<S_event>/rin_<S_event>_pyring_<approx>_<time>M.h5``
        through PESummary. Otherwise read the run's ``Nested_sampler/Evidence.txt`` table.

    Returns
    -------
    float or numpy.ndarray
        ``lnZ_signal - lnZ_noise`` (PESummary branch) as a float. Otherwise the
        ``lnB`` (or ``logB``) column of ``Evidence.txt`` exactly as returned by
        ``np.genfromtxt``.

    Raises
    ------
    KeyError
        If the PESummary file stores neither ``lnZ_*`` nor ``logZ_*`` evidences.
    ValueError
        If ``Evidence.txt`` has neither an ``lnB`` nor a ``logB`` column.
    """
    if PESummary_format:
        S_event = Events_dict[event]
        f = GWread(os.path.join(directory, '{S_event}'.format(S_event = S_event), 'rin_{S_event}_pyring_{wf_approx}_{start_time}M.h5'.format(S_event = S_event, wf_approx=approx, start_time=time)))
        other = f.extra_kwargs[0]['other']
        # The evidences may be stored under either lnZ_* or logZ_* keys; try both.
        try:
            return float(other['lnZ_signal']) - float(other['lnZ_noise'])
        except KeyError:
            return float(other['logZ_signal']) - float(other['logZ_noise'])

    evidence_file = os.path.join(directory, event, '{}_{}_{}_{:d}M{}/Nested_sampler/Evidence.txt'.format(event, runtag, approx, time, chain_str))
    # Read the table once and choose the Bayes-factor column by name, instead of
    # re-reading the file inside a catch-all exception handler.
    evidence = np.genfromtxt(evidence_file, names=True)
    available = evidence.dtype.names or ()
    for column in ('lnB', 'logB'):
        if column in available:
            return evidence[column]
    raise ValueError('{} has neither an lnB nor a logB column (found: {})'.format(evidence_file, available))

# --- Input configuration -----------------------------------------------------
# When True, results are read from PESummary HDF5 files; otherwise from the
# Nested_sampler/Evidence.txt tables of each run.
PESummary_flag = True
ringdown_posteriors_dir = '/home/tgr.o3/o3b-tgr/results/rin/'

# Run label and posterior suffix; they only enter the Evidence.txt run-directory
# names (i.e. when PESummary_flag is False).
runtag = 'GWTC3_rerun_PROD1'
chain_str = '_weighted_posterior'

# GWTC-3 event name -> superevent ID, the latter being used to build the
# PESummary result-file paths. Insertion order fixes the order of `events`.
Events_dict_GWTC3 = {
    'GW150914': 'S150914',
    'GW170104': 'S170104',
    'GW170814': 'S170814',
    'GW170823': 'S170823',
    'GW190408A': 'S190408an',
    'GW190512A': 'S190512at',
    'GW190513A': 'S190513bm',
    'GW190519A': 'S190519bj',
    'GW190521A': 'S190521g',
    'GW190521B': 'S190521r',
    'GW190602A': 'S190602aq',
    'GW190706A': 'S190706ai',
    'GW190708A': 'S190708ap',
    'GW190727A': 'S190727h',
    'GW190828A': 'S190828j',
    'GW190910A': 'S190910s',
    'GW190915A': 'S190915ak',
    # 'GW191109A': 'S191109d',  # excluded: not part of the combined results
    'GW191222A': 'S191222n',
    'GW200129A': 'S200129m',
    'GW200224B': 'S200224ca',
    'GW200311B': 'S200311bg',
}

# Catalogue used by this analysis, and its event names in insertion order.
Events_dict = Events_dict_GWTC3
events = list(Events_dict)

if __name__=='__main__':

    # Non-GR hypotheses of the overtone analysis: deviations in the 221 frequency,
    # in the 221 damping time, or in both.
    approxs_TIGER_overtones = ['Kerr_221_domega_221'     ,
                               'Kerr_221_dtau_221'       ,
                               'Kerr_221_domega_dtau_221']

    # Two deviation coefficients (domega_221, dtau_221) may vary, giving
    # 2**N_TIGER - 1 = 3 non-GR sub-hypotheses (one per approximant above).
    # They are combined with equal weights, hence the log(2**N_TIGER - 1) normalisation.
    N_TIGER = 2
    N_TIGER_NORM = np.log(2**N_TIGER - 1)
    # We use natural logs for all the computations and convert to log10 at the end.
    # Prior odds on O_nGR_GR = 1.

    def _logB(event, approx, time):
        """ln Bayes factor (signal vs noise) of one run, using the module-level run configuration."""
        return extract_logB(event = event, approx = approx, time = time, runtag = runtag, chain_str = chain_str, directory = ringdown_posteriors_dir, PESummary_format = PESummary_flag)

    # TIGER logB per single event (overtones)
    try:
        GR_approx = 'Kerr_221'
        # Per-event LaTeX-style entries; built here but not printed by this script.
        TIGER_string_ovs  = ''
        time = 0
        print('\nOvertones single event TIGER\n')
        print('Event \t logB_GR_noise \t logB_nGR_GR\n')
        for event in events:
            logB_GR = _logB(event, GR_approx, time)
            logBs = [_logB(event, approx, time) for approx in approxs_TIGER_overtones]
            # log10 of the equal-weight average of the nGR_i/GR Bayes factors.
            logB_TIGER_ovs = (logsumexp(logBs) - logB_GR - N_TIGER_NORM)/np.log(10)
            print('{} : {:.2f}\t{:.2f} '.format(event, logB_GR/np.log(10), logB_TIGER_ovs))
            TIGER_string_ovs += '{%s}{%.1f}'%(event, logB_TIGER_ovs)
    except Exception:
        # format_exc() returns the traceback text; print_exc() would print it and return None.
        warnings.warn("Skipping logB_nGR_GR overtones single events computation due to the following exception: \n{}".format(traceback.format_exc()))

    # TIGER logB overtones combined events (eq. 47 of Li et al. arxiv:1110.0530):
    # per hypothesis, sum the single-event ln B_nGR/GR over events, then combine
    # the hypotheses with equal weights.
    combined_ok = False
    try:
        GR_approx = 'Kerr_221'
        time = 0
        # The GR evidence depends only on the event: read it once per event.
        logB_GR_per_event = {event: _logB(event, GR_approx, time) for event in events}
        logBs = []
        # Single-event, single-hypothesis point estimates of ln B_nGR/GR, kept for the error estimate below.
        median_logBs_single_hyp_single_event_dict = {}
        for approx in approxs_TIGER_overtones:
            per_event = [_logB(event, approx, time) - logB_GR_per_event[event] for event in events]
            median_logBs_single_hyp_single_event_dict[approx] = per_event
            logBs.append(sum(per_event))
        logB_TIGER_ovs_overall = (logsumexp(logBs) - N_TIGER_NORM)/np.log(10)
        print('\nlogB_nGR_GR combined events (overtones): {:.2f}\n'.format(logB_TIGER_ovs_overall))
        combined_ok = True
    except Exception:
        warnings.warn("Skipping logB_nGR_GR overtones overall computation due to the following exception: \n{}".format(traceback.format_exc()))

    # To compute the uncertainty on the combined logB, we draw samples for this quantity.
    # Each single-event, single-hypothesis logB is drawn from a gaussian centred on its
    # point estimate, with the single-run uncertainty. A uniform distribution of half-width
    # sigma was also tried and gives a comparable result.
    # 0.14 is the conservative sampler statistical uncertainty on each run. The sqrt(2)
    # accounts for differencing two s/n logBs to obtain each nGR/GR one.
    if not combined_ok:
        warnings.warn("Skipping the error on the combined logBF since the combined logB_nGR_GR computation failed.")
    else:
        # Fix the seed to allow for reproducibility of the logB error
        np.random.seed(2)
        sigma = 0.14*np.sqrt(2)
        n_draws = 100
        combined_logBF_list = []
        for _ in range(n_draws):
            # Reset for every draw: one sample of the combined logBF must combine exactly
            # one draw per hypothesis, not draws accumulated over previous iterations.
            logB_draw = []
            for approx in approxs_TIGER_overtones:
                # It is VERY important to sample the SINGLE-EVENT SINGLE-HYPOTHESIS logBs and
                # only then sum them (the per-run error estimate applies to those only).
                single_event_draws = np.random.normal(loc = median_logBs_single_hyp_single_event_dict[approx], scale = sigma)
                logB_draw.append(np.sum(single_event_draws))
            combined_logBF_list.append((logsumexp(logB_draw) - N_TIGER_NORM)/np.log(10))
        print('Error on combined logBF: {:.2f}'.format(np.std(combined_logBF_list, ddof=1)))
