# -*- coding: utf-8 -*-
#
#       Copyright 2020
#       Maximiliano Isi <max.isi@ligo.org>
#
#       This program is free software; you can redistribute it and/or modify
#       it under the terms of the GNU General Public License as published by
#       the Free Software Foundation; either version 2 of the License, or
#       (at your option) any later version.
#
#       This program is distributed in the hope that it will be useful,
#       but WITHOUT ANY WARRANTY; without even the implied warranty of
#       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#       GNU General Public License for more details.
#
#       You should have received a copy of the GNU General Public License
#       along with this program; if not, write to the Free Software
#       Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#       MA 02110-1301, USA.


from pylab import *
import seaborn as sns
import utils
from matplotlib import gridspec
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
from collections import defaultdict

# Global figure style: seaborn theme first, then the project's rcParams.
sns.set(context='notebook', style='ticks', palette='colorblind',
        font='serif', font_scale=1.5)
plt.rcParams.update(utils.rcParams)

# Shrink y tick labels and legend text relative to the axis-label size.
plt.rcParams.update({
    'legend.fontsize': utils.fs_label*0.8,
    'ytick.labelsize': utils.fs_label*0.6,
})

# inspiral (PN) phase-deviation coefficients, in increasing PN order
params_pre = ['dphi-2', *('dphi%i' % n for n in range(5)),
              'dphi5l', 'dphi6', 'dphi6l', 'dphi7']

# post-inspiral deviation coefficients
params_post = [*('db%i' % n for n in (1, 2, 3, 4)),
               *('dc%i' % n for n in (1, 2, 4)),
               'dcl']

# all coefficients: inspiral ones first, then post-inspiral
params = [*params_pre, *params_post]

def get_label(param):
    r"""Return the LaTeX tick label for coefficient name `param`.

    'dphi<n>' becomes '${\varphi}_{<n>}$' (e.g. 'dphi-2' -> '${\varphi}_{-2}$'),
    'db<n>' becomes '$db_{<n>}$', and any other name is assumed to start
    with 'dc' and becomes '$dc_{<n>}$'.
    """
    # no special case is needed for 'dphi-2': the generic rule already
    # yields '${\varphi}_{-2}$'
    if 'phi' in param:
        return param.replace('dphi', r'${\varphi}_{') + '}$'
    if 'db' in param:
        return param.replace('db', r'$db_{') + '}$'
    return param.replace('dc', r'$dc_{') + '}$'
                                       
def get_ci(key, df, ci=0.9):
    """Return the median and symmetric credible interval of one parameter.

    Arguments
    ---------
    key: str
        parameter name, matched against df['param']
    df: pd.DataFrame
        data frame with 'param' and 'value' columns
    ci: float
        credible level of the central interval (def. 0.9, i.e. 5%-95%)

    Returns
    -------
    med, lo, hi: float
        50th percentile of the samples, and lower/upper bounds of the
        central `ci` interval.
    """
    samples = df.loc[df['param'] == key, 'value']
    # symmetric interval: remove (1 - ci)/2 of the probability from each tail
    tail = 50*(1 - ci)
    # one percentile call instead of three passes over the same samples
    med, lo, hi = np.percentile(samples, [50, tail, 100 - tail])
    return med, lo, hi

# monkey patch seaborn's violin plotter so that inner='quartile' lines show the
# median and the symmetric 90%-credible interval instead of quartiles
QCOLOR = '0.9'


class _ViolinPlotter(sns.categorical._ViolinPlotter):
    """Seaborn violin plotter whose 'quartile' lines mark a 90% interval.

    NOTE(review): subclasses a private seaborn class; this presumably requires
    a seaborn version that still ships `categorical._ViolinPlotter`
    (pre-0.13) — verify against the installed version.
    """

    def draw_quartiles(self, ax, data, support, density, center, split=False):
        """Draw 5th, 50th and 95th percentile lines at width of density."""
        # patch here to replace quartiles with symmetric 90%-credible interval
        lo, med, hi = np.percentile(data, [5, 50, 95])
        # presumably read by seaborn's draw_to_density as the line color
        self.gray = QCOLOR
        # fixed line width; dash lengths still scale with the outline width:
        # short dashes for the interval bounds, long dashes for the median
        for value, dash_scale in ((lo, 0.5), (med, 2), (hi, 0.5)):
            self.draw_to_density(ax, center, value, support, density, split,
                                 linewidth=1,
                                 dashes=[self.linewidth * dash_scale] * 2)


# replace the module attribute so seaborn's violinplot picks up the class above
sns.categorical._ViolinPlotter = _ViolinPlotter

# default split of the ten inspiral coefficients across three panels
DEF_BREAK = [params[:1], params[1:5], params[5:10]]

# PN-order label of each inspiral coefficient, in the order of `params`
pn_labels = [
    '$-1$ PN', '$0$ PN', '$0.5$ PN', '$1$ PN', '$1.5$ PN', '$2$ PN',
    '$2.5$ PN$^{(l)}$', '$3$ PN', '$3$ PN$^{(l)}$', '$3.5$ PN',
]
# only the first len(pn_labels) parameters get an entry
pn_labels_dict = dict(zip(params[:len(pn_labels)], pn_labels))

def plot_violin(df, breakdown=DEF_BREAK, palette='Set2', aspect_ratio=4.57,
                nolegend=False, singles=params_post, scale_factors=None,
                fig=None, mark_ci=True, tick_length=14, hue='approx',
                top_ticks=False, labels=None, inspiral_only=False,
                legend_panel=3, **kwargs):
    """Produce violin plots for parameterized tests.

    Each entry of `breakdown` is drawn in its own panel. Panels containing
    any parameter from `singles` get unsplit violins in a single color; all
    other panels get violins split by the `hue` column.

    Arguments
    ---------
    df: pd.DataFrame
        data frame containing samples for all coefficients, events and approxs,
        with 'param' and 'value' columns (plus the `hue` column for split
        panels); it is not modified
    breakdown: list
        list of list with breakdown of parameters across panels (opt.)
    palette: str,list
        color palette, palette name, or list of colors (def. Set2)
    aspect_ratio: float
        figure aspect ratio (def. 4.57)
    nolegend: bool
        do not add legend (def. False)
    singles: list
        list of parameters for which to plot unsplit violins (def. postinspiral)
    scale_factors: dict
        scaling factors for individual parameters (def. None)
    fig: figure
        matplotlib figure object whose axes are used as panels (opt.)
    mark_ci: bool
        draw 90% CLs (def. True)
    tick_length: float
        length of CL ticks (def. 14; halved for split violins).
    hue: str
        name of df column based on which to split violins (def. 'approx')
    top_ticks: bool
        add PN labels up top (def. False)
    labels: list
        legend labels (def. values in df[hue])
    inspiral_only: bool
        label the y-axis as phase coefficients (def. False)
    legend_panel: int
        index of the split panel that keeps its legend (def. 3)
    kwargs:
        additional arguments are passed to `sns.violinplot`, overriding the
        default violin style.

    Returns
    -------
    fig: figure
        figure containing the violin panels.
    """
    if fig is None:
        # create figure of right proportions
        width = utils.fig_width_page
        height = width/aspect_ratio
        fig = plt.figure(figsize=(width, height))

        # create axis grid, panel widths proportional to number of parameters
        gs = gridspec.GridSpec(1, len(breakdown),
                               width_ratios=[len(keys) for keys in breakdown])
        axs = [fig.add_subplot(g) for g in gs]
    else:
        axs = fig.axes

    # seaborn kwargs for violin style (user kwargs take precedence)
    kws = {
        'inner': 'quartile',
        'gridsize': 500,
        'linewidth': 1.0,
        'scale': 'width',
    }
    kws.update(kwargs)

    # rescale coefficients if needed (on a copy, leaving the caller's df intact)
    scale_factors = scale_factors or {}
    if scale_factors:
        df = df.copy()
        for k, scale in scale_factors.items():
            df.loc[df['param'] == k, 'value'] *= scale

    # plot coefficients
    for i, (keys, ax) in enumerate(zip(breakdown, axs)):
        panel_df = df[df['param'].isin(keys)]
        # pin the category order to `keys`: tick labels and CI marks below are
        # placed by position in `keys`, not by order of appearance in df
        panel_kws = dict({'order': list(keys)}, **kws)
        if any(k in singles for k in keys):
            ax = sns.violinplot(x="param", y="value", data=panel_df, ax=ax,
                                color=sns.color_palette(palette)[0],
                                **panel_kws)
            if mark_ci:
                ci_color = sns.color_palette(palette, desat=0.5)[0]
                for j, k in enumerate(keys):
                    _, lo, hi = get_ci(k, df)
                    for y in [lo, hi]:
                        ax.plot([j], [y], marker='_', markersize=tick_length,
                                c=ci_color)
        else:
            ax = sns.violinplot(x="param", y="value", hue=hue, data=panel_df,
                                split=True, ax=ax, palette=palette,
                                **panel_kws)
            hue_levels = df[hue].unique()
            if mark_ci:
                tick_colors = sns.color_palette(palette, desat=1)
                # matplotlib markers 0/1 are TICKLEFT/TICKRIGHT: one tick
                # direction per violin half (only two hue levels are marked)
                for j, k in enumerate(keys):
                    for c, (level, marker) in enumerate(zip(hue_levels, [0, 1])):
                        _, lo, hi = get_ci(k, df[df[hue] == level])
                        for y in [lo, hi]:
                            ax.plot([j], [y], marker=marker,
                                    markersize=tick_length/2,
                                    c=tick_colors[c])
            # keep the legend only on panel `legend_panel`; NOTE: the default
            # three-panel breakdown has no panel 3, so it then shows no legend
            legend = ax.get_legend()
            if nolegend or i != legend_panel:
                if legend is not None:
                    legend.remove()
            elif legend is not None:
                # matplotlib >= 3.7 renamed `legendHandles` to `legend_handles`
                handles = getattr(legend, 'legend_handles', None)
                if handles is None:
                    handles = legend.legendHandles
                ax.legend(handles, labels or hue_levels,
                          loc='lower left', edgecolor='w')
        scaled = [k for k in keys if k in scale_factors]
        if scaled:
            # a panel shows a single factor: that of its first scaled key;
            # integer factors print as before, fractional ones are not truncated
            factor = scale_factors[scaled[0]]
            factor_str = ('%i' if float(factor).is_integer() else '%g') % factor
            ax.text(0.01, 0.99, r'$\times %s$' % factor_str,
                    horizontalalignment='left', verticalalignment='top',
                    transform=ax.transAxes, fontsize=0.8*utils.fs_label)
        ax.tick_params(axis='y', length=2, pad=1)
        ax.set_xlabel('')
        ax.set_ylabel('')
        if i == 0:
            ylabel = (r'$\delta \hat{\varphi}_i$' if inspiral_only
                      else r'$\delta \hat{p}_i$')
            ax.set_ylabel(ylabel, fontsize=utils.fs_label, labelpad=0)
        ax.set_xticklabels([get_label(k) for k in keys])
        ax.axhline(0, ls='--', lw=3, alpha=0.1, c='k')

    if top_ticks:
        # add twin abscissa on top with PN labels, for every panel whose
        # parameters all have one (pairing panels with their own axes, so
        # skipped panels do not shift labels onto the wrong axes)
        for keys, ax in zip(breakdown, axs):
            if not all(k in pn_labels_dict for k in keys):
                continue
            twin_ax = ax.twiny()
            twin_ax.set_xticks(ax.get_xticks())
            twin_ax.set_xticklabels([pn_labels_dict[k] for k in keys],
                                    fontsize=10)
            twin_ax.set_xlim(ax.get_xlim())
    return fig
