"""
Modified from https://git.ligo.org/publications/O3/o3a-cbc-tgr/-/blob/master/release/imr/imr_utils.py
"""

from pylab import *
import scipy
from scipy.interpolate import interp1d, interp2d
import scipy.ndimage.filters as filter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import utils

# Hex colour used for the combined (joint) posterior overlays in plot2d.
uniwiengrey = '#666666'

# O3b events that satisfy the IMRCT selection criteria; when
# kws['waveform'] == 'Phenom' these are analysed with PhenomXPHM.
o3b_imrct_events = [
    'S200129m',
    'S200208q',
    'S200224ca',
    'S200225q',
    'S200311bg',
]

# Helper class for credible-level calculations on a gridded density
class confidence(object):
    """Convert between density heights and enclosed probability (credible level).

    The bin values are sorted in descending order. The credible level of a
    height ``h`` is the fraction of the total mass held by all bins whose value
    is at least ``h``. Linear interpolators between (sorted-bin index) and
    height, and between (sorted-bin index) and cumulative level, allow
    conversion in both directions.

    Parameters
    ----------
    counts : numpy.ndarray
        Bin values of any shape (flattened internally).
    """

    def __init__(self, counts):
        # Sort all bin values from highest to lowest
        self.counts_sorted = np.sort(counts.flatten())[::-1]

        # Normalised cumulative mass, accumulated from the mode downwards
        self.norm_cumsum_counts_sorted = np.cumsum(self.counts_sorted) / np.sum(counts)

        # Build the index <-> height and index <-> level interpolators
        self._set_interp()

    def _set_interp(self):
        """Build the four index/height/level interpolators.

        Out-of-range queries are clamped to the appropriate END of each mapping
        via a two-element ``fill_value`` (value below range, value above range).
        A single scalar fill value sends both sides to the same value. For
        example, a level smaller than the mode's mass would then map to height 0
        instead of the peak height.
        """
        self._length = len(self.counts_sorted)
        idx = np.arange(self._length)

        # height from index: before the first bin -> peak height, past the last -> 0
        self._height_from_idx = interp1d(idx, self.counts_sorted, bounds_error=False,
                                         fill_value=(self.counts_sorted[0], 0.))

        # index from height (x must ascend, so both arrays are reversed):
        # below the smallest height -> past the last bin, above the peak -> first bin
        self._idx_from_height = interp1d(self.counts_sorted[::-1], idx[::-1], bounds_error=False,
                                         fill_value=(self._length, 0.))

        # level from index: before the first bin -> 0, past the last -> 1
        self._level_from_idx = interp1d(idx, self.norm_cumsum_counts_sorted, bounds_error=False,
                                        fill_value=(0., 1.))

        # index from level: below the mode's mass -> first bin, above 1 -> past the last bin
        self._idx_from_level = interp1d(self.norm_cumsum_counts_sorted, idx, bounds_error=False,
                                        fill_value=(0., self._length))

    def level_from_height(self, height):
        """Return the credible level enclosed by the iso-density contour at ``height``."""
        return self._level_from_idx(self._idx_from_height(height))

    def height_from_level(self, level):
        """Return the density height whose contour encloses credible level ``level``."""
        return self._height_from_idx(self._idx_from_level(level))

# Gaussian filter used to soften 2D posteriors before contouring
def gf(P, sigma=1.0):
    """Return ``P`` smoothed with a Gaussian kernel.

    Parameters
    ----------
    P : numpy.ndarray
        Array to smooth (e.g. a gridded 2D posterior).
    sigma : float, optional
        Kernel standard deviation in units of grid bins (default 1.0, the
        original 1-sigma softening).

    Uses the public ``scipy.ndimage.gaussian_filter``. The
    ``scipy.ndimage.filters`` namespace is deprecated.
    """
    from scipy.ndimage import gaussian_filter
    return gaussian_filter(P, sigma=sigma)

def load_event(path, event='', wf='PhenomPv2'):
    """Read the gridded IMRCT posterior of one event from disk.

    ``path`` is a format string with ``{param}``, ``{event}`` and ``{wf}``
    placeholders. It is filled once per quantity and each file is read with
    ``np.loadtxt``.

    Returns
    -------
    list of numpy.ndarray
        ``[L_dMfbyMf_dchifbychif, dMfbyMf_vec, dchifbychif_vec]`` in that order.
    """
    quantities = ('dMfbyMf_dchifbychif', 'dMfbyMf_vec', 'dchifbychif_vec')
    return [np.loadtxt(path.format(param=quantity, event=event, wf=wf))
            for quantity in quantities]

def get_marginalized_posteriors(L_dMfbyMf_dchifbychif, dMfbyMf_vec, dchifbychif_vec):
    """Normalise a gridded 2D posterior and compute its two 1D marginals.

    The 2D array is indexed ``[dchifbychif, dMfbyMf]``: rows follow
    ``dchifbychif_vec`` and columns follow ``dMfbyMf_vec``. Summing over
    axis 0 therefore marginalises out dchif/chif.

    Bin widths are the mean spacing of each grid vector.

    The input array is NOT modified. A new normalised array is returned, so
    integer-valued input is accepted too.

    Returns
    -------
    tuple
        ``(L_2d, L_dMfbyMf, L_dchifbychif, dx, dy)``. ``L_2d`` integrates to 1
        over the grid, and each 1D marginal integrates to 1 over its own axis.
    """
    # Bin widths of the (assumed uniform) grids
    dx = np.mean(np.diff(dMfbyMf_vec))
    dy = np.mean(np.diff(dchifbychif_vec))

    # Normalise the 2D posterior out of place, leaving the caller's array intact
    L_2d = np.asarray(L_dMfbyMf_dchifbychif) / (np.sum(L_dMfbyMf_dchifbychif) * dx * dy)

    # Marginalise over the other axis
    L_dMfbyMf     = np.sum(L_2d, axis=0) * dy
    L_dchifbychif = np.sum(L_2d, axis=1) * dx

    # Re-normalise the 1D marginals to guard against rounding
    L_dMfbyMf     = L_dMfbyMf / (np.sum(L_dMfbyMf) * dx)
    L_dchifbychif = L_dchifbychif / (np.sum(L_dchifbychif) * dy)

    return L_2d, L_dMfbyMf, L_dchifbychif, dx, dy

def multiply_likelihoods(events, **kws):
    """Combine the IMRCT posteriors of several events into a joint posterior.

    Each event's gridded 2D posterior is processed as follows:
    - loaded with ``load_event``;
    - interpolated onto a common ``N_bins`` x ``N_bins`` grid spanning
      [-2, 2] in both dMf/Mf and dchif/chif;
    - normalised and multiplied into a running joint posterior.

    The joint posterior is renormalised after every event. Events whose data
    cannot be read are reported and skipped.

    Parameters
    ----------
    events : iterable of str
        Event names to combine.
    **kws
        Must provide ``'N_bins'``, ``'waveform'`` and ``'data_path'``.
        For ``'waveform'``, ``'Phenom'`` selects PhenomPv3HM for S190814bv
        and PhenomXPHM for ``o3b_imrct_events``; anything else uses PhenomPv2.
        ``'data_path'`` is the format string passed to ``load_event``.

    Returns
    -------
    tuple
        ``(L_dMfbyMf_dchifbychif_joint, L_dMfbyMf_joint, L_dchifbychif_joint,
        dMfbyMf_vec, dchifbychif_vec)``. The 2D array is indexed
        ``[dchifbychif, dMfbyMf]``. If no event could be loaded, a flat
        normalised posterior is returned.
    """
    n_bins = kws['N_bins']

    # Common grid for all events. It is defined before the loop so the grid and
    # the bin widths exist even when no event is loaded successfully.
    dMfbyMf_vec     = np.linspace(-2.0, 2.0, n_bins)
    dchifbychif_vec = np.linspace(-2.0, 2.0, n_bins)
    dx = np.mean(np.diff(dMfbyMf_vec))
    dy = np.mean(np.diff(dchifbychif_vec))

    # Start from a flat joint posterior
    L_dMfbyMf_dchifbychif_joint = np.ones((n_bins, n_bins))
    n_combined = 0

    for event in events:

        # set waveform
        if kws['waveform'] == 'Phenom' and event in ['S190814bv']:
            wf = 'PhenomPv3HM'
        elif kws['waveform'] == 'Phenom' and event in o3b_imrct_events:
            wf = 'PhenomXPHM'
        else:
            wf = 'PhenomPv2'

        # Read the likelihood data for this event on its own grid. np.loadtxt
        # raises OSError for missing files and ValueError for malformed ones.
        try:
            L_event, event_dMfbyMf_vec, event_dchifbychif_vec = load_event(
                kws['data_path'], event=event, wf=wf)
        except (OSError, ValueError):
            print('Data not found for event: %s' % event)
            continue

        # Interpolate the event onto the common grid (zero outside its support)
        L_event_interp_obj = interp2d(event_dMfbyMf_vec,
                                      event_dchifbychif_vec,
                                      L_event,
                                      fill_value=0.0,
                                      bounds_error=False)
        L_dMfbyMf_dchifbychif = L_event_interp_obj(dMfbyMf_vec, dchifbychif_vec)

        # Normalised 2D posterior of this event
        L_dMfbyMf_dchifbychif, _, _, _, _ = get_marginalized_posteriors(
            L_dMfbyMf_dchifbychif, dMfbyMf_vec, dchifbychif_vec)

        # Joint likelihood for *all* events
        L_dMfbyMf_dchifbychif_joint *= L_dMfbyMf_dchifbychif

        # Remove nans/infs and renormalise to keep the product well scaled
        L_dMfbyMf_dchifbychif_joint[~np.isfinite(L_dMfbyMf_dchifbychif_joint)] = 0.
        L_dMfbyMf_dchifbychif_joint /= np.sum(L_dMfbyMf_dchifbychif_joint) * dx * dy
        n_combined += 1

    if n_combined == 0:
        # Nothing was multiplied in: return a normalised flat posterior
        print('No event data could be loaded; returning a flat joint posterior')
        L_dMfbyMf_dchifbychif_joint /= np.sum(L_dMfbyMf_dchifbychif_joint) * dx * dy

    # ~~~~~~~~~~~~ Marginalization to one-dimensional joint_posteriors
    L_dMfbyMf_joint      = np.sum(L_dMfbyMf_dchifbychif_joint, axis=0) * dy
    L_dchifbychif_joint  = np.sum(L_dMfbyMf_dchifbychif_joint, axis=1) * dx

    # ~~~~~~~~~~~~ Normalisation of marginalized posteriors
    L_dMfbyMf_joint     /= np.sum(L_dMfbyMf_joint) * dx
    L_dchifbychif_joint /= np.sum(L_dchifbychif_joint) * dy

    return L_dMfbyMf_dchifbychif_joint, L_dMfbyMf_joint, L_dchifbychif_joint, \
           dMfbyMf_vec, dchifbychif_vec


def plot2d(events, combined=None, cmap=None, norm=None, color_ticks=None,
           fig=None, **kws):
    r'''
        Plot the 2D posteriors (\Delta M_f / M_f, \Delta \chi_f / \chi_f) for all input events and the 1D
        marginalised posteriors for \Delta M_f / M_f and \Delta \chi_f / \chi_f.

        Every event is interpolated onto the common N_bins x N_bins grid over [-2, 2] x [-2, 2]. This is the
        same grid used by multiply_likelihoods and get_event_stats. The interpolation happens before the
        event's marginals and 90% credible contour are computed.

        Input:
        ----------
        events      : list of events to include in plot. (Required)
        combined    : None, or a sequence of two entries of combined data, each structured as
                      [L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, dMfbyMf_vec, dchifbychif_vec]
                      (the output format of multiply_likelihoods). L_dMfbyMf_dchifbychif is the 2D posterior,
                      L_dMfbyMf and L_dchifbychif the 1D marginalised posteriors, and the two vectors the grid.
                      combined[0] is drawn as the filled grey joint posterior, combined[1] as the black one.
                      (Default = None)
        cmap        : Specify a color map. (Default = None).
        norm        : Specify the norm for the color map. (Default = None).
        color_ticks : Specify colorbar ticks. (Default = None, i.e. no ticks)
        fig         : Pass an existing figure; its first three axes are reused. If None, a new figure is
                      created. (Default = None)
        **kws       : Keyword arguments: {
                      'N_bins' : number of bins of the common grid (Required),
                      'waveform' : 'Phenom' selects PhenomPv3HM for S190814bv and PhenomXPHM for the events in
                                   o3b_imrct_events; any other value (or any other event) uses PhenomPv2 (Required),
                      'data_path' : format string passed to load_event (Required),
                      'lw' : linewidth (Default = 2),
                      'ls' : linestyle (Default = '-' for events whose run is 'O3', '-.' otherwise),
                      'alpha' : opacity for contours (Default = 1),
                      'c' : contour color when no cmap/norm is given (Default = 'gray'),
                      'l_kws' : dictionary of {'ls', 'lw', 'c', 'alpha'}; replaces the per-event defaults above,
                      'catalog_name' : 'gwtc2' gives [-2, 2] axes; anything else (or missing) gives zoomed axes }

        Output:
        ----------
        fig         : figure showing the 2D and 1D marginalized posteriors
        cbaxes      : colorbar axes (None if cmap or norm is not given)

    '''
    if color_ticks is None:
        color_ticks = []

    if fig is None:
        fig = plt.figure(figsize=(12,12))
        ax1 = plt.subplot2grid((3,3), (0,0), colspan=2)
        ax2 = plt.subplot2grid((3,3), (1,2), rowspan=2)
        ax3 = plt.subplot2grid((3,3), (1,0), colspan=2, rowspan=2)
    else:
        ax1, ax2, ax3 = fig.axes[:3]

    # Common parameter grid for the 2D plot. Each event's own grid is kept in
    # separate variables below so it does not overwrite this one.
    dMfbyMf_vec           = np.linspace(-2.0, 2.0, kws['N_bins'])
    dchifbychif_vec       = np.linspace(-2.0, 2.0, kws['N_bins'])

    using_color = not (cmap is None or norm is None)
    for event in sorted(events, reverse=True):
        # get event properties
        e = utils.Event(event)
        m = e.get_param('M')

        # set waveform
        if kws['waveform'] == 'Phenom' and event in ['S190814bv']:
            wf = 'PhenomPv3HM'
        elif kws['waveform'] == 'Phenom' and event in o3b_imrct_events:
            wf = 'PhenomXPHM'
        else:
            wf = 'PhenomPv2'

        # set formatting
        def_l_kws = {
            'ls': kws.get('ls', '-' if e.run == 'O3' else '-.'),
            'lw': kws.get('lw', 2),
            'c': cmap(norm(m)) if using_color else kws.get('c', 'gray'),
            'alpha': kws.get('alpha', 1),
        }
        l_kws = kws.get('l_kws', def_l_kws)

        # Read the likelihood data for this event on its own grid. np.loadtxt
        # raises OSError for missing files and ValueError for malformed ones.
        try:
            L_event, event_dMfbyMf_vec, event_dchifbychif_vec = load_event(
                kws['data_path'], event=event, wf=wf)
        except (OSError, ValueError):
            print('Data not found for event: %s\n----------\n' % event)
            continue

        # Interpolate the event onto the common grid (zero outside its support)
        L_event_interp_obj = interp2d(event_dMfbyMf_vec,
                                      event_dchifbychif_vec,
                                      L_event,
                                      fill_value=0.0,
                                      bounds_error=False)
        L_dMfbyMf_dchifbychif = L_event_interp_obj(dMfbyMf_vec, dchifbychif_vec)

        # Get marginalized 1D posteriors and normalized 2D posteriors
        L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, _, _ = get_marginalized_posteriors(
            L_dMfbyMf_dchifbychif, dMfbyMf_vec, dchifbychif_vec)

        # Height of the 90% credible contour of this posterior
        conf_v1v2 = confidence(L_dMfbyMf_dchifbychif)
        s2_v1v2   = conf_v1v2.height_from_level(0.9)

        # Plot the 1D marginals of the event
        ax1.plot(dMfbyMf_vec, L_dMfbyMf, **l_kws)
        ax2.plot(L_dchifbychif, dchifbychif_vec, **l_kws)

        # Plot the (smoothed) 90% credible contour of the event
        ax3.contour(dMfbyMf_vec, dchifbychif_vec,
                    gf(L_dMfbyMf_dchifbychif),
                    levels     = (s2_v1v2,),
                    linewidths = (l_kws['lw'],),
                    colors     = [l_kws['c']],
                    linestyles = l_kws['ls']
                   )

    if combined is not None:
        # combined should be a list of arrays with two sets of combined data
        L_dMfbyMf_dchifbychif_joint, L_dMfbyMf_joint, L_dchifbychif_joint, dMfbyMf_vec, dchifbychif_vec = combined[0]
        L_dMfbyMf_dchifbychif_joint_all, L_dMfbyMf_joint_all, L_dchifbychif_joint_all, dMfbyMf_vec_all, dchifbychif_vec_all = combined[1]

        # ~~~~~~ Make the main plot for joint likelihood for all events
        conf_v1v2_joint      = confidence(L_dMfbyMf_dchifbychif_joint_all)
        s2_v1v2_joint        = conf_v1v2_joint.height_from_level(0.9)

        # Make 1D dchi/chi plot joint for all events
        ax2.plot(L_dchifbychif_joint_all, dchifbychif_vec_all, color='k', alpha=0.7)

        # Make 1D dM/M plot
        ax1.plot(dMfbyMf_vec_all, L_dMfbyMf_joint_all, color='k', alpha=0.7)

        # Plot 2D joint likelihood
        ax3.contour(dMfbyMf_vec_all, dchifbychif_vec_all,
                    gf(L_dMfbyMf_dchifbychif_joint_all),
                    levels=(s2_v1v2_joint, np.inf), linewidths=(1.15,1.15),
                    colors=['k'], linestyles=('-'), alpha=0.7)

        # ~~~~~~ Make the main plot for the joint likelihood of combined[0]
        conf_v1v2_joint      = confidence(L_dMfbyMf_dchifbychif_joint)
        s2_v1v2_joint        = conf_v1v2_joint.height_from_level(0.9)

        # Make 1D dM/M plot
        ax1.fill(dMfbyMf_vec, L_dMfbyMf_joint, uniwiengrey, alpha=0.15)
        ax1.plot(dMfbyMf_vec, L_dMfbyMf_joint, color=uniwiengrey, alpha=0.5, lw=3)

        # Make 1D dchi/chi plot
        ax2.fill(L_dchifbychif_joint, dchifbychif_vec, uniwiengrey, alpha=0.15)
        ax2.plot(L_dchifbychif_joint, dchifbychif_vec, color=uniwiengrey, alpha=0.5, lw=3)

        # Plot 2D joint likelihood
        ax3.contourf(dMfbyMf_vec, dchifbychif_vec,
                     gf(L_dMfbyMf_dchifbychif_joint),
                     levels=(s2_v1v2_joint, np.inf), colors=[uniwiengrey],
                     linestyles=('--'), alpha=0.15)
        ax3.contour(dMfbyMf_vec, dchifbychif_vec,
                    gf(L_dMfbyMf_dchifbychif_joint),
                    levels=(s2_v1v2_joint, np.inf), linewidths=(3,3),
                    colors=[uniwiengrey], linestyles=('-'), alpha=0.5)

    # Mark the GR value (0, 0)
    ax3.plot(0, 0, 'k+', mew=1)

    ax1.xaxis.tick_top()
    ax2.yaxis.tick_right()

    if kws.get('catalog_name') == 'gwtc2':
        # Tick labels
        ax1.set_xticks(np.arange(-2.0, 2.01, 0.5))
        ax1.set_yticks(np.arange(2, 11, 2.))
        ax2.set_xticks(np.arange(2, 10, 2.))
        ax2.set_yticks(np.arange(-2.0, 2.01, 0.5))
        ax3.set_xticks(np.arange(-2.0, 2.01, 0.5))
        ax3.set_yticks(np.arange(-2.0, 2.01, 0.5))

        # axis limits
        ax1.set_xlim(-2.0,2.0)
        ax1.set_ylim(0,10)
        ax2.set_ylim(-2.0,2.0)
        ax2.set_xlim(0,9)
        ax3.set_xlim(-2.0,2.0)
        ax3.set_ylim(-2.0,2.0)

    else:
        # Tick labels
        ax1.set_xticks(np.arange(-1.5, 1.51, 0.5))
        ax1.set_yticks(np.arange(2, 11, 2.))
        ax2.set_xticks(np.arange(2, 10, 2.))
        ax2.set_yticks(np.arange(-1.5, 1.1, 0.5))
        ax3.set_xticks(np.arange(-1.5, 1.51, 0.5))
        ax3.set_yticks(np.arange(-1.5, 1.1, 0.5))

        # axis limits
        ax1.set_xlim(-1.5,1.5)
        ax1.set_ylim(0,10)
        ax2.set_ylim(-1.5,1)
        ax2.set_xlim(0,9)
        ax3.set_xlim(-1.5,1.5)
        ax3.set_ylim(-1.5,1)

        # Zoom further and draw the GR reference lines
        for ax in [ax3, ax1]:
            ax.set_xlim(-0.75, 1.5)
            ax.axvline(0, ls='--', lw=2, alpha=0.5, c='k')

        for ax in [ax3, ax2]:
            ax.set_ylim(-1.2, 1)
            ax.axhline(0, ls='--', lw=2, alpha=0.5, c='k')

    ax1.tick_params(axis="x", labelsize=20)
    ax1.tick_params(axis="y", labelsize=20)

    ax2.tick_params(axis="x", labelsize=20)
    ax2.tick_params(axis="y", labelsize=20)

    ax3.tick_params(axis="x", labelsize=28)
    ax3.tick_params(axis="y", labelsize=28)

    # Add axis labels and legends
    ax2.set_xlabel(r'$P(\Delta \chi_{\mathrm{f}}/ \bar{\chi}_{\mathrm{f}})$', labelpad =  8,  fontsize=30)
    ax1.set_ylabel(r'$P(\Delta M_{\mathrm{f}}/ \bar{M}_{\mathrm{f}})$',       labelpad =  10, fontsize=30)
    ax3.set_xlabel(r'$\Delta M_{\mathrm{f}}/ \bar{M}_{\mathrm{f}}$',          labelpad =  8,  fontsize=34)
    ax3.set_ylabel(r'$\Delta \chi_{\mathrm{f}}/ \bar{\chi}_{\mathrm{f}}$',    labelpad = -3,  fontsize=34)
    plt.subplots_adjust(wspace=0., hspace=0.)

    if using_color:
        # create mass colorbar
        cm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        cm.set_array([])

        ax4 = plt.subplot2grid((3,3), (0,2))
        ax4.axis('off')

        cbaxes = inset_axes(ax4, width="70%", height="15%", loc="lower center")
        cb = plt.colorbar(cm, orientation='horizontal', cax=cbaxes, ticks=color_ticks)
        cb.ax.set_xticklabels([str(v) for v in color_ticks], fontsize=20)
        cbaxes.set_xlabel(r'$%s$' % utils.Parameter('m').latex, fontdict={'size': 25}, labelpad=15)
        cbaxes.xaxis.set_ticks_position('top')
        cbaxes.xaxis.set_label_position('top')
    else:
        cbaxes = None
    return fig, cbaxes

def get_stats(data, verbose=True, return_dict=False):
    """Summarise a (joint or single-event) IMRCT posterior.

    Parameters
    ----------
    data : sequence
        ``[L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, dMfbyMf_vec,
        dchifbychif_vec]``. The 2D array is indexed ``[dchifbychif, dMfbyMf]``
        (rows follow ``dchifbychif_vec``), as produced by
        ``get_marginalized_posteriors`` and ``multiply_likelihoods``.
    verbose : bool
        Print the GR credible level and the median +/- 90% symmetric intervals.
    return_dict : bool
        If True, return the statistics as a dict; otherwise return None.

    Returns
    -------
    dict or None
        ``{'GR_credible_level': percent,
        'dMf': {'med', 'hi', 'lo'}, 'dchif': {'med', 'hi', 'lo'}}``.
    """
    L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, dMfbyMf_vec, dchifbychif_vec = data

    # ~~~~~~~~~~~~ Calculate the credible levels and intervals in the marginalized 1d posteriors
    M90left, M90right = utils.get_sym_interval_from_pdf(L_dMfbyMf, dMfbyMf_vec, p=0.9)
    M50 = utils.get_ul_from_pdf(L_dMfbyMf, dMfbyMf_vec, p=0.5)

    chi90left, chi90right = utils.get_sym_interval_from_pdf(L_dchifbychif, dchifbychif_vec, p=0.9)
    chi50 = utils.get_ul_from_pdf(L_dchifbychif, dchifbychif_vec, p=0.5)

    # ~~~~~~~~~~~~ Calculate the 2D GR quantile from the posterior
    conf_v1v2 = confidence(L_dMfbyMf_dchifbychif)
    # Grid point closest to GR (0, 0). Rows index dchif/chif and columns index
    # dMf/Mf, so the row index comes from dchifbychif_vec.
    i_chi_gr = np.argmin(np.abs(dchifbychif_vec))
    i_M_gr   = np.argmin(np.abs(dMfbyMf_vec))
    gr_height = L_dMfbyMf_dchifbychif[i_chi_gr, i_M_gr]
    gr_credib_level = float(conf_v1v2.level_from_height(gr_height))

    if verbose:
        print('    Credible level of the GR value (combined):\t %.1f%% ' % (100.*gr_credib_level))
        print('    Median +/- 90%% sym-CI on dMfbyMf:\t\t %2.3f +%2.3f -%2.3f ' % (M50, M90right-M50, M50-M90left))
        print('    Median +/- 90%% sym-CI on dchifbychif:\t %2.3f +%2.3f -%2.3f '% (chi50, chi90right-chi50, chi50-chi90left))

    if return_dict:
        stats_dict = {
            'GR_credible_level' : 100.*gr_credib_level,
            'dMf'   : {'med' : float(M50),   'hi' : M90right-M50,     'lo' : M50-M90left},
            'dchif' : {'med' : float(chi50), 'hi' : chi90right-chi50, 'lo' : chi50-chi90left}
        }
        return stats_dict

def get_event_stats(event, **kws):
    """Compute IMRCT summary statistics for a single event.

    The event's gridded posterior goes through these steps:
    - loaded with ``load_event``;
    - interpolated onto the common ``N_bins`` x ``N_bins`` grid over
      [-2, 2] x [-2, 2];
    - normalised and summarised with ``get_stats``.

    Parameters
    ----------
    event : str
        Event name.
    **kws
        Must provide ``'N_bins'``, ``'waveform'`` and ``'data_path'``.
        For ``'waveform'``, ``'Phenom'`` selects PhenomPv3HM for S190814bv
        and PhenomXPHM for ``o3b_imrct_events``; anything else uses PhenomPv2.

    Returns
    -------
    dict or int
        ``{event: stats_dict}`` (see ``get_stats`` with ``return_dict=True``),
        or ``-1`` if the event's data files could not be read.
    """
    # set waveform
    if kws['waveform'] == 'Phenom' and event in ['S190814bv']:
        wf = 'PhenomPv3HM'
    elif kws['waveform'] == 'Phenom' and event in o3b_imrct_events:
        wf = 'PhenomXPHM'
    else:
        wf = 'PhenomPv2'

    # Read the likelihood data for this event on its own grid. np.loadtxt
    # raises OSError for missing files and ValueError for malformed ones.
    try:
        L_event, event_dMfbyMf_vec, event_dchifbychif_vec = load_event(
            kws['data_path'], event=event, wf=wf)
    except (OSError, ValueError):
        print('Data not found for event: %s' % event)
        return -1

    # Interpolate the event onto the common grid (zero outside its support)
    L_event_interp_obj = interp2d(event_dMfbyMf_vec,
                                  event_dchifbychif_vec,
                                  L_event,
                                  fill_value=0.0,
                                  bounds_error=False)
    dMfbyMf_vec           = np.linspace(-2.0, 2.0, kws['N_bins'])
    dchifbychif_vec       = np.linspace(-2.0, 2.0, kws['N_bins'])
    L_dMfbyMf_dchifbychif = L_event_interp_obj(dMfbyMf_vec, dchifbychif_vec)

    # Get marginalized 1D posteriors and normalized 2D posteriors
    L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, _, _ = get_marginalized_posteriors(
        L_dMfbyMf_dchifbychif, dMfbyMf_vec, dchifbychif_vec)

    # repack the 2D and 1D data to pass to get_stats
    data  = [L_dMfbyMf_dchifbychif, L_dMfbyMf, L_dchifbychif, dMfbyMf_vec, dchifbychif_vec]
    stats = get_stats(data, verbose=False, return_dict=True)

    return {event: stats}

def generate_event_macros(imrct_events, **kws):
    """Write per-event IMRCT LaTeX macros to ``<save_dir>/imrct_event_macros.tex``.

    For every event, ``get_event_stats`` supplies the median and 90% interval of
    dMf/Mf and dchif/chif and the GR credible level. These are emitted as cases
    of a single ``\\ImrEVENTSTATS`` command built with ``\\IfEqCase``.

    Events for which ``get_event_stats`` reports missing data (returns -1) are
    skipped.

    Parameters
    ----------
    imrct_events : iterable of str
        Event names.
    **kws
        Forwarded to ``get_event_stats``. The optional ``'save_dir'`` sets the
        output directory (default ``'./macros/'``).
    """
    import os

    # Collect statistics, skipping events whose data could not be loaded
    # (get_event_stats returns -1 rather than a dict in that case).
    imrct_event_stats = {}
    for ev in imrct_events:
        stats = get_event_stats(ev, **kws)
        if isinstance(stats, dict):
            imrct_event_stats.update(stats)

    params = ['dMf', 'dchif']
    labels = {'dMf' : 'DMF', 'dchif' : 'DCHIF'}

    case_list = []
    for ev in imrct_events:
        if ev not in imrct_event_stats:
            continue
        # Phenom median and 90% CI
        phenom_stats = imrct_event_stats[ev]

        for par in params:
            label = labels[par]
            med   = phenom_stats[par]['med']
            lo    = phenom_stats[par]['lo']
            hi    = phenom_stats[par]['hi']
            s     = r'{%s%s}{\ensuremath{%.2f^{+%.2f}_{-%.2f}}}' % (ev, label+'GWTC3PHENOM', med, hi, lo)
            case_list.append(s)

        quant = phenom_stats['GR_credible_level']
        s     = r'{%s%s}{\ensuremath{%.1f}}' % (ev, 'GRQUANTGWTC3', quant)
        case_list.append(s)

    # Joined after the loop so it is defined even when no event contributed
    cases  = ''.join(case_list)
    macros = [r'\newcommand{\ImrEVENTSTATS}[1]{\IfEqCase{#1}{%s}}' % (cases)]

    save_dir = kws.get('save_dir', './macros/')
    # os.path.join also works when save_dir lacks a trailing separator
    out_path = os.path.join(save_dir, 'imrct_event_macros.tex')

    with open(out_path, 'w') as f:
        f.write('\n'.join(macros))
    print('Exported macros to : ' + out_path)

def generate_joint_macros(combined_data, **kws):
    """Write the joint (combined) IMRCT LaTeX macros to ``<save_dir>/imrct_macros_<catalog>.tex``.

    The median and 90% interval of dMf/Mf and dchif/chif, plus the GR credible
    level of ``combined_data``, are emitted as cases of a single ``\\IfEqCase``
    command.
    - When ``kws['catalog_name'] == 'gwtc2'``, the command is ``\\ImrGWTCTWO``
      and the file suffix is ``_gwtc2``.
    - Otherwise (including when the key is missing), the command is
      ``\\ImrGWTCTHREE`` and the suffix is ``_gwtc3``.

    Parameters
    ----------
    combined_data : sequence
        Output of ``multiply_likelihoods`` (the input format of ``get_stats``).
    **kws
        ``'catalog_name'`` (optional) and ``'save_dir'`` (optional, default
        ``'./macros/'``).
    """
    import os

    case_list = []
    params    = ['dMf', 'dchif']
    labels    = {'dMf' : 'DMF', 'dchif' : 'DCHIF'}

    # Compute Phenom median and 90% CI
    phenom_stats = get_stats(combined_data, verbose=False, return_dict=True)

    # Raw strings: '\I' is not a valid Python escape sequence
    if kws.get('catalog_name') == 'gwtc2':
        gwtc_name = 'GWTC2'
        command_name = r'\ImrGWTCTWO'
        file_name_suffix = '_gwtc2'
    else:
        gwtc_name = 'GWTC3'
        command_name = r'\ImrGWTCTHREE'
        file_name_suffix = '_gwtc3'

    for par in params:
        label = labels[par]

        med   = phenom_stats[par]['med']
        lo    = phenom_stats[par]['lo']
        hi    = phenom_stats[par]['hi']
        s     = r'{%s}{\ensuremath{%.2f^{+%.2f}_{-%.2f}}}' % (label+gwtc_name+'PHENOM', med, hi, lo)
        case_list.append(s)

    quant = phenom_stats['GR_credible_level']
    s     = r'{%s}{\ensuremath{%.1f}}' % ('GRQUANT'+gwtc_name, quant)
    case_list.append(s)

    cases  = ''.join(case_list)
    macros = [r'\newcommand{%s}[1]{\IfEqCase{#1}{%s}}' % (command_name, cases)]

    save_dir = kws.get('save_dir', './macros/')
    # os.path.join also works when save_dir lacks a trailing separator
    out_path = os.path.join(save_dir, 'imrct_macros' + file_name_suffix + '.tex')

    with open(out_path, 'w') as f:
        f.write('\n'.join(macros))
    print('Exported joint macros to : ' + out_path)
