"""Convert a posterior distribution of dchi in the generic FD parameterization to a posterior distribution of dchi in the TIGER parameterization """

#With bugfix by Ethan Payne and Max Isi, 2023

import numpy as np
import argparse
import re
from numpy import random
from scipy.special import lambertw
from scipy import interpolate
from scipy.stats import gaussian_kde

# Euler-Mascheroni constant (gamma_E ~ 0.5772), used in the 3PN coefficient phi6 below.
# The name suggests it mirrors LAL's LAL_GAMMA constant.
lal_gamma = 0.577215664901532860606512090082402431

# The functions phi<N> return the coefficient of the N/2-PN term in the inspiral phase
# (as in Eq. A4 of https://arxiv.org/abs/1005.3306). They all share one signature:
# component masses m1, m2 followed by the spin quantities a1L, a2L (spin components along
# L), a1sq, a2sq (squared spins) and a1dota2 (spin-spin product), as their names suggest.

def phi0(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Leading-order (0PN) coefficient: the normalization, 1.0 for any masses and spins."""
  return 1.0

def phi1(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """0.5PN coefficient: a fixed placeholder value of 1.0, independent of all inputs."""
  return 1.0

def phi2(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """1PN (N=2) coefficient; depends only on the symmetric mass ratio (spins are ignored)."""
  total_mass = m1 + m2
  sym_ratio = (m1 * m2) / total_mass**2.
  return 5. * (743. / 84. + 11. * sym_ratio) / 9.

def phi3(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """1.5PN (N=3) coefficient: a -16*pi constant plus terms linear in the aligned spins a1L, a2L."""
  mtot = m1 + m2
  x1 = m1 / mtot
  x2 = m2 / mtot
  delta = (m1 - m2) / mtot
  # Mass-weighted sum and difference combinations of the spin components along L.
  spin_L = x1 * x1 * a1L + x2 * x2 * a2L
  sigma_L = delta * (x2 * a2L - x1 * a1L)
  constant = -16. * np.pi
  return constant + 188. * spin_L / 3. + 25. * sigma_L

def phi4(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """2PN (N=4) coefficient: a mass-ratio polynomial minus 10 times the spin-spin term pnsigma."""
  # Spin susceptibilities (spin-induced quadrupole parameters) of the two bodies,
  # fixed to the black-hole value of 1.
  qm1 = 1
  qm2 = 1
  mtot = m1 + m2
  x1 = m1 / mtot
  x2 = m2 / mtot
  eta = (m1 * m2) / mtot**2.

  cross = eta * (721. / 48. * a1L * a2L - 247. / 48. * a1dota2)
  aligned1 = (720. * qm1 - 1.) / 96.0 * x1 * x1 * a1L * a1L
  aligned2 = (720. * qm2 - 1.) / 96.0 * x2 * x2 * a2L * a2L
  squared1 = (240. * qm1 - 7.) / 96.0 * x1 * x1 * a1sq
  squared2 = (240. * qm2 - 7.) / 96.0 * x2 * x2 * a2sq
  pnsigma = cross + aligned1 + aligned2 - squared1 - squared2

  nonspin = 5. * (3058.673 / 7.056 + 5429. / 7. * eta + 617. * eta * eta) / 72.
  return nonspin - 10. * pnsigma

def phi5l(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """2.5PN (N=5) coefficient used for dchi5l (the 'l' suffix presumably denotes the log(v) term).

  Spin enters only linearly through a1L and a2L.
  """
  mtot = m1 + m2
  x1 = m1 / mtot
  x2 = m2 / mtot
  delta = (m1 - m2) / mtot
  eta = (m1 * m2) / mtot**2.
  spin_L = x1 * x1 * a1L + x2 * x2 * a2L
  sigma_L = delta * (x2 * a2L - x1 * a1L)

  sum_coeff = 554345. / 1134. + 110. * eta / 9.
  diff_coeff = 13915. / 84. - 10. * eta / 3.
  pngamma = sum_coeff * spin_L + diff_coeff * sigma_L

  nonspin = 5. / 3. * (7729. / 84. - 13. * eta) * np.pi
  return nonspin - 3. * pngamma

def phi6(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """3PN (N=6) coefficient: mass-ratio polynomial, log(4) and gamma_E constants,
  terms linear in a1L/a2L, and spin-spin terms (pnss3)."""
  # Spin susceptibilities (spin-induced quadrupole parameters) of the two bodies,
  # fixed to the black-hole value of 1.
  qm1 = 1
  qm2 = 1
  mtot = m1 + m2
  x1 = m1 / mtot
  x2 = m2 / mtot
  delta = (m1 - m2) / mtot
  eta = (m1 * m2) / mtot**2.
  spin_L = x1 * x1 * a1L + x2 * x2 * a2L
  sigma_L = delta * (x2 * a2L - x1 * a1L)

  def _self_spin(x, qm):
    # Bracketed factor multiplying x**2 * a_sq for a single body.
    return (4703.5 / 8.4 + 2935. / 6. * x - 120. * x * x) * qm + (-4108.25 / 6.72 - 108.5 / 1.2 * x + 125.5 / 3.6 * x * x)

  ss_cross = (326.75 / 1.12 + 557.5 / 1.8 * eta) * eta * a1L * a2L
  ss_self1 = _self_spin(x1, qm1) * x1 * x1 * a1sq
  ss_self2 = _self_spin(x2, qm2) * x2 * x2 * a2sq
  pnss3 = ss_cross + ss_self1 + ss_self2

  constant = 11583.231236531 / 4.694215680 - 640. / 3. * np.pi * np.pi - 6848. / 21. * lal_gamma
  eta_1 = eta * (-15737.765635 / 3.048192 + 2255. / 12. * np.pi * np.pi)
  eta_2 = eta * eta * 76055. / 1728.
  eta_3 = eta * eta * eta * 127825. / 1296.
  log_term = (-6848. / 21.) * np.log(4.)
  linear_spin = np.pi * (3760. * spin_L + 1490 * sigma_L) / 3.
  return constant + eta_1 + eta_2 - eta_3 + log_term + linear_spin + pnss3

def phi6l(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """3PN coefficient used for dchi6l (presumably the log(v) term): the constant -6848/21."""
  numerator = -6848.
  return numerator / 21.

def phi7(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """3.5PN (N=7) coefficient: a pi * mass-ratio polynomial plus terms linear in a1L and a2L."""
  mtot = m1 + m2
  x1 = m1 / mtot
  x2 = m2 / mtot
  delta = (m1 - m2) / mtot
  eta = (m1 * m2) / mtot**2.
  spin_L = x1 * x1 * a1L + x2 * x2 * a2L
  sigma_L = delta * (x2 * a2L - x1 * a1L)

  nonspin = np.pi * (77096675. / 254016. + 378515. / 1512. * eta - 74045. / 756. * eta * eta)
  sum_coeff = -8980424995. / 762048. + 6586595. * eta / 756. - 305. * eta * eta / 36.
  diff_coeff = 170978035. / 48384. - 2876425. * eta / 672. - 4735. * eta * eta / 144.
  return nonspin + sum_coeff * spin_L - diff_coeff * sigma_L

def phiMinus2(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """-1PN coefficient (also used for 'dipolecoeff'): a fixed value of 1.0 for any input."""
  return 1.0


# The functions phi<N>NS return the spin-independent component of the coefficient of the
# N/2-PN term in the inspiral, i.e. phi<N> evaluated with every spin argument set to zero.

def phi0NS(m1, m2):
  """phi0 with all spin arguments zeroed."""
  return phi0(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi1NS(m1, m2):
  """phi1 with all spin arguments zeroed."""
  return phi1(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi2NS(m1, m2):
  """phi2 with all spin arguments zeroed."""
  return phi2(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi3NS(m1, m2):
  """phi3 with all spin arguments zeroed."""
  return phi3(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi4NS(m1, m2):
  """phi4 with all spin arguments zeroed."""
  return phi4(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi5lNS(m1, m2):
  """phi5l with all spin arguments zeroed."""
  return phi5l(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi6NS(m1, m2):
  """phi6 with all spin arguments zeroed."""
  return phi6(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi6lNS(m1, m2):
  """phi6l with all spin arguments zeroed."""
  return phi6l(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phi7NS(m1, m2):
  """phi7 with all spin arguments zeroed."""
  return phi7(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

def phiMinus2NS(m1, m2):
  """phiMinus2 with all spin arguments zeroed."""
  return phiMinus2(m1, m2, a1L=0., a2L=0., a1sq=0., a2sq=0., a1dota2=0.)

# The functions phi<N>S return the spin-dependent component of the coefficient of the
# N/2-PN term in the inspiral: the full coefficient minus its zero-spin value phi<N>NS.

def phi0S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi0."""
  full = phi0(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi0NS(m1, m2)

def phi1S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi1."""
  full = phi1(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi1NS(m1, m2)

def phi2S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi2."""
  full = phi2(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi2NS(m1, m2)

def phi3S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi3."""
  full = phi3(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi3NS(m1, m2)

def phi4S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi4."""
  full = phi4(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi4NS(m1, m2)

def phi5lS(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi5l."""
  full = phi5l(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi5lNS(m1, m2)

def phi6S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi6."""
  full = phi6(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi6NS(m1, m2)

def phi6lS(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi6l."""
  full = phi6l(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi6lNS(m1, m2)

def phi7S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phi7."""
  full = phi7(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phi7NS(m1, m2)

def phiMinus2S(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2):
  """Spin-dependent part of phiMinus2."""
  full = phiMinus2(m1, m2, a1L, a2L, a1sq, a2sq, a1dota2)
  return full - phiMinus2NS(m1, m2)

# Dictionaries that map the testing-GR parameter of each run to the corresponding function
# above: phiDict -> full coefficient, phiNSDict -> spin-independent part,
# phiSDict -> spin-dependent part. 'dipolecoeff' reuses the phiMinus2 family.
phiDict = {
  'dchi0': phi0,
  'dchi1': phi1,
  'dchi2': phi2,
  'dchi3': phi3,
  'dchi4': phi4,
  'dchi5l': phi5l,
  'dchi6': phi6,
  'dchi6l': phi6l,
  'dchi7': phi7,
  'dchiminus2': phiMinus2,
  'dipolecoeff': phiMinus2,
}
phiNSDict = {
  'dchi0': phi0NS,
  'dchi1': phi1NS,
  'dchi2': phi2NS,
  'dchi3': phi3NS,
  'dchi4': phi4NS,
  'dchi5l': phi5lNS,
  'dchi6': phi6NS,
  'dchi6l': phi6lNS,
  'dchi7': phi7NS,
  'dchiminus2': phiMinus2NS,
  'dipolecoeff': phiMinus2NS,
}
phiSDict = {
  'dchi0': phi0S,
  'dchi1': phi1S,
  'dchi2': phi2S,
  'dchi3': phi3S,
  'dchi4': phi4S,
  'dchi5l': phi5lS,
  'dchi6': phi6S,
  'dchi6l': phi6lS,
  'dchi7': phi7S,
  'dchiminus2': phiMinus2S,
  'dipolecoeff': phiMinus2S,
}

def convert_genericFD_to_TIGER(data, param, bins_arg=25, nsamples=1000000, resample_factor=10, intrinsic_names=('mass_1', 'mass_2', 'spin_1z', 'spin_2z')):
  """Convert generic-FD posterior samples of a dchi parameter into TIGER-parameterized samples.

  The two parameterizations are related sample-by-sample by

      dchi_TIGER = dchi_genericFD * (1 + phi_S / phi_NS),

  where phi_S and phi_NS are the spin-dependent and spin-independent parts of the PN
  coefficient associated with ``param`` (looked up in ``phiSDict`` / ``phiNSDict``).
  The converted samples are weighted by |factor|. A weighted Gaussian KDE is then evaluated
  on a grid, and new samples are drawn from it by inverse-transform sampling.

  Parameters
  ----------
  data : mapping
      Posterior samples indexed by column name. It must contain ``param`` and every name
      in ``intrinsic_names``; the columns are assumed to be equal-length 1-D array-likes.
  param : str
      Testing-GR parameter; a key of ``phiDict`` (e.g. 'dchi3').
  bins_arg : int
      Number of intervals of the KDE evaluation grid. The grid has ``bins_arg + 1`` points
      spanning the range of the converted samples.
  nsamples, resample_factor : number
      ``int(nsamples / resample_factor)`` samples are drawn from the converted distribution.
  intrinsic_names : sequence of str
      Column names of (mass 1, mass 2, spin-1 z component, spin-2 z component).

  Returns
  -------
  For 'dchi0', 'dchi1', 'dchi2', 'dchi6l', 'dchiminus2' and 'dipolecoeff', ``data[param]``
  is returned unchanged (all samples); the conversion factor for these is identically 1.
  Otherwise, a numpy array of ``int(nsamples / resample_factor)`` resampled values is
  returned. Randomness comes from numpy's global random state.
  """
  # These coefficients have no spin dependence (phi<N>S is identically 0), so the
  # conversion factor is exactly 1.
  if param in ('dchi0', 'dchi1', 'dchi2', 'dchi6l', 'dchiminus2', 'dipolecoeff'):
    return data[param]

  # Convert the posterior distribution of dchi_i (as parameterized with generic FD) into a
  # distribution of dchi_i (parameterized with TIGER).
  m1, m2, a1z, a2z = (np.asarray(data[name]) for name in intrinsic_names[:4])
  # Only the z (aligned) spin components are used, so the squared spins and the spin-spin
  # product are built from them.
  a1sq = a1z * a1z
  a2sq = a2z * a2z
  a1dota2 = a1z * a2z

  factor = 1. + phiSDict[param](m1, m2, a1z, a2z, a1sq, a2sq, a1dota2) / phiNSDict[param](m1, m2)
  dchidata_TIGER = np.asarray(data[param]) * factor
  # Weight each sample by |d dchi_TIGER / d dchi_genericFD| = |factor|. This presumably turns
  # a prior flat in the generic-FD dchi into one flat in the TIGER dchi (the 2023 bugfix
  # credited in the file header). NOTE(review): confirm against the sampler's prior.
  weights = np.abs(np.broadcast_to(factor, dchidata_TIGER.shape))

  dchi_grid = np.linspace(np.min(dchidata_TIGER), np.max(dchidata_TIGER), num=bins_arg + 1)
  pdf_dchi = gaussian_kde(dchidata_TIGER, weights=weights)(dchi_grid)

  # Trapezoid-rule CDF that is exactly 0 at the first grid point and 1 at the last, so its
  # inverse covers all of [0, 1]. A plain cumsum starts above 0 and shifts mass by about
  # half a grid step.
  increments = 0.5 * (pdf_dchi[1:] + pdf_dchi[:-1]) * np.diff(dchi_grid)
  cdf_dchi = np.concatenate(([0.], np.cumsum(increments)))
  cdf_dchi /= cdf_dchi[-1]

  # Invert the tabulated CDF. Queries outside [cdf[0], cdf[-1]] clamp to the grid ends
  # instead of becoming the literal values 0 / 1.
  inv_cdf = interpolate.interp1d(cdf_dchi, dchi_grid, bounds_error=False,
                                 fill_value=(dchi_grid[0], dchi_grid[-1]))

  # Draw the reweighted posterior samples by inverse-transform sampling.
  r = random.uniform(0., 1., int(nsamples / resample_factor))

  return inv_cdf(r)
