#! /cvmfs/oasis.opensciencegrid.org/ligo/sw/conda/envs/igwn-py39/bin/python

"""
Modified from https://git.ligo.org/publications/O3/o3a-cbc-tgr/-/blob/master/release/imr/imr_plot_likelihoods.ipynb
and https://git.ligo.org/publications/O3/o3a-cbc-tgr/-/blob/master/release/imr/norelease/generate_macros.py 

"""
import argparse
import os

import seaborn as sns
import pylab as plt
import matplotlib

import utils
import utils_imrct

# Plot styling: apply the O3a TGR seaborn theme and the shared rcParams from
# `utils`, then shrink y-tick labels and legend text relative to the common
# label size `utils.fs_label`.
sns.set(
    style='ticks',
    context='notebook',
    font='serif',
    font_scale=1.5,
    palette='colorblind',
)
plt.rcParams.update(utils.rcParams)

# Applied after the shared rcParams so these two values take precedence.
plt.rcParams.update(
    {
        'ytick.labelsize': utils.fs_label * 0.6,
        'legend.fontsize': utils.fs_label * 0.8,
    }
)

# Command-line interface. Only the two catalogs handled below are accepted:
# without `choices`, an unrecognised value (e.g. a typo) would silently fall
# through to the GWTC-3 plotting branch while skipping the GWTC-3-only
# per-event macro step, producing an inconsistent set of outputs.
parser = argparse.ArgumentParser(
    description="Create IMRCT contour plot and macros"
)

parser.add_argument(
    "--catalog_name",
    type=str,
    choices=['gwtc2', 'gwtc3'],
    default='gwtc3',
    help="Name of catalog, gwtc2 or gwtc3 (default: gwtc3)",
)

# Parsed at import time: this file is meant to be run as a script.
options = parser.parse_args()

# Events that meet the IMR consistency-test criteria. Kept as a mutable
# `set` so it can be filtered in place.
imrct_events = set([
    'GW150914', 'GW170104', 'GW170809', 'GW170814', 'GW170818', 'GW170823',
    'S190408an', 'S190503bf', 'S190513bm', 'S190521r', 'S190630ag',
    'S190814bv', 'S190828j',
    'S200129m', 'S200208q', 'S200224ca', 'S200225q', 'S200311bg',
])

# O3b events are picked out by their 'S20' superevent-name prefix.
# NOTE(review): run membership could instead be derived from
# utils.Event(name).run, as an earlier (removed) draft of this script did.
o3b_events = {name for name in imrct_events if name.startswith('S20')}

# Keyword arguments forwarded unchanged to every utils_imrct helper call.
settings = dict(
    waveform='Phenom',
    N_bins=401,
    data_path='./likelihoods/imrct_likelihood_{param}_{event}_{wf}.dat.gz',
    catalog_name=options.catalog_name,
)

# The GWTC-2 reanalysis drops the O3b ('S20...') events from the event set.
# This mutates `imrct_events` in place, so later uses see the reduced set.
if options.catalog_name == 'gwtc2':
    imrct_events.difference_update(o3b_events)

# Joint likelihood obtained by multiplying the per-event likelihoods.
combined_data = utils_imrct.multiply_likelihoods(imrct_events, **settings)

print('Joint likelihood for all events')
utils_imrct.get_stats(combined_data)

# Diverging colour map for the per-event contours. The colour-bar ticks double
# as the norm definition: the first/last ticks are the colour-scale limits and
# the middle tick is the centre of the two-slope norm.
cmap = plt.cm.coolwarm
cticks = [25, 70., 115]
norm = matplotlib.colors.TwoSlopeNorm(cticks[1], vmin=cticks[0], vmax=cticks[-1])

# Joint likelihood handed to plot2d via `combined`. It is passed twice;
# NOTE(review): the meaning of the two tuple slots is defined in
# utils_imrct.plot2d (not visible here) -- confirm.
comb = (combined_data, combined_data)

# GWTC-2: draw every (already O3b-filtered) event.
# GWTC-3: draw only the O3b events individually; the joint contour still comes
# from all events in `imrct_events`.
if options.catalog_name == 'gwtc2':
    plot_events = imrct_events
    fig_path = './fig/imrct_L_dMfbyMf_dchifbychi_joint_Phenom_GWTC-2_reanalysis.pdf'
else:
    plot_events = o3b_events
    fig_path = './fig/imrct_L_dMfbyMf_dchifbychi_joint_PhenomXPHM_GWTC-3.pdf'

# sorted() instead of list(set(...)): set iteration order over strings depends
# on hash randomisation, so the draw order (and thus the figure) would
# otherwise change from run to run.
fig, cbaxes = utils_imrct.plot2d(sorted(plot_events),
                                 cmap=cmap, norm=norm, color_ticks=cticks,
                                 combined=comb, **settings)

# Make sure the output directory exists so a missing ./fig does not abort the
# run after the likelihoods have already been computed.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

## Optional overlay of GW190814 (S190814bv) as a dotted contour; disabled.
## NOTE(review): the original snippet called `imr_utils.plot2d`, a module this
## script does not import; `utils_imrct.plot2d` is presumably intended.
#e = 'S190814bv'
#c = cmap(norm(utils.Event(e).get_param('M')))
#fig, cbaxes = utils_imrct.plot2d([e], fig=fig, ls='dotted', c=c, **settings)

# --- Macro generation --------------------------------------------------------
# Per-event macros are produced only for the GWTC-3 catalog.
if options.catalog_name == 'gwtc3':
    utils_imrct.generate_event_macros(
        imrct_events,
        **settings,
    )

# Macros for the joint (combined) result are produced for every catalog.
utils_imrct.generate_joint_macros(
    combined_data,
    **settings,
)
