## Solution provided by Ivan Markin as part of the Numerical Relativity class at the University of Potsdam

import numpy as np
import scipy

import matplotlib.pyplot as plt
from cycler import cycler
import scipy
import os

#matplotlib inline

# Better-looking pictures and legends: higher resolution and serif fonts.
# usetex=True renders all text through LaTeX, so a LaTeX installation is required.
plt.rc('figure', dpi=150)
plt.rc('text', usetex=True)
plt.rc('font', family='serif', size=10)

# Named hex colours (presumably the Berlin U-Bahn line colours U1-U9).
berlin = {
    "U7": "#009AD9",
    "U3": "#00A192",
    "U1": "#62AD2D",
    "U2": "#E94D10",
    "U4": "#FFD401",
    "U5": "#815237",
    "U6": "#846DAA",
    "U8": "#005A99",
    "U9": "#F18800",
}

# NOTE(review): zoo_berlin is not referenced anywhere else in this file.
zoo_berlin = [berlin["U1"], berlin["U2"], berlin["U9"], ]
# Colour cycle for all line plots, in the dict's insertion order (U7 first).
palette = [berlin[key] for key in berlin]
# Alternative two-colour palette:
# palette = [berlin["U7"], berlin["U1"]]

plt.rc('axes', prop_cycle=cycler(color=palette))

import scipy.interpolate
import scipy.optimize


def just_like(x):
    """Return a new array with the shape and dtype of ``x``, filled with ones."""
    ones = np.ones_like(x)
    return ones

def has_nan(x):
    """Return True if ``x`` contains at least one NaN entry.

    The entries are tested directly instead of testing ``np.sum(x)``.
    A sum is also NaN when ``x`` holds both ``+inf`` and ``-inf``, which
    would report a NaN that is not actually present.
    """
    return bool(np.any(np.isnan(x)))


# Right-hand side for the evolution of lapse
def slicing_rhs(name, U):
    """Return the time derivative of the lapse for the chosen slicing.

    Parameters
    ----------
    name : str
        ``"geodesic"``: the lapse is frozen, d_t alpha = 0.
        ``"1+log"``: d_t alpha = -2 alpha K.
    U : dict
        State vector. Reads ``"alpha"``, and for 1+log also ``"K_A"`` and ``"K_B"``.

    Returns
    -------
    ndarray
        Same shape as ``U["alpha"]``.

    Raises
    ------
    ValueError
        If ``name`` is not a known slicing. ValueError subclasses
        Exception, so existing ``except Exception`` handlers still catch it.
    """
    if name == "geodesic":
        return np.zeros_like(U["alpha"])
    if name == "1+log":
        # Trace of the extrinsic curvature
        K = U["K_A"] + 2 * U["K_B"]
        return -2*U["alpha"]*K
    raise ValueError(f"unknown slicing: {name!r}")


## Set the parameters
# Black-hole mass parameter (enters the conformal factor psi below)
M = 1
# Final time of the evolution
t_final = 1.2*np.pi*M
# Lapse condition passed to slicing_rhs: "1+log" or "geodesic"
slicing = "1+log" # alternatively "geodesic"

# Plot every n-th iteration; set to None to disable the periodic field plots
plot_every_nth_iteration = 10
# Fields plotted by plot_variable ("K" is the derived trace K_A + 2*K_B)
plot_variables = ["alpha", "K_A", "K_B", "K", "A_bar", "B_bar", "D_A_bar"]


# Create radial grid: Nr cells covering [0, r_max)
Nr = 1000
r_max = 10
r, dr = np.linspace(0, r_max, Nr, endpoint=False, retstep=True)
# Stagger the grid: points sit at cell centres (i + 1/2)*dr, so r = 0
# (where the 1/r terms diverge) is never a grid point
r += 0.5*dr

# Courant factor; the time step is dt = CFL * dr
CFL = 0.4
dt=CFL*dr
print(f'grid: dr={dr}, dt={dt}, r0={r[0]}')

# NOTE(review): presumably the radius outside which outer-boundary effects
# may have propagated in by t_final; it does not feed any computed result.
r_sane = r_max - 0.7*t_final

# Conformal factor psi = 1 + M/(2r); metric functions are stored divided by psi**4
psi = 1 + M/(2*r)
# E = d(ln psi**4)/dr, the singular (~ -4/r near the origin) part of D_A and D_B
E = - 4*M / (2*r**2 + M*r)
# dE = dE/dr, so the singular part of d_r D_B can be differentiated analytically
dE = 4*M*(M+4*r)/( r**2 * (M+2*r)**2 )

# Go back to non-barred variables
def deregularize(u, derivative=False):
    """Convert a barred (regularised) variable back to its physical value.

    Metric functions are stored divided by psi**4 (A = psi**4 * A_bar).
    Their logarithmic radial derivatives therefore differ by
    E = d(ln psi**4)/dr (D_A = D_A_bar + E).

    Parameters
    ----------
    u : ndarray
        Barred metric function, or barred log-derivative if ``derivative``.
    derivative : bool
        Select the log-derivative rule (add E) instead of the metric rule
        (multiply by psi**4).
    """
    if derivative:
        return u + E
    return u * psi**4

# Symmetry conditions around the origin: parity of each variable under r -> -r.
# deriv() fills the ghost point at r = -dr/2 with origin_sign[var] * u[0]
# (+1: even function of r, -1: odd function of r).
origin_sign = {
    "alpha":   +1,
    "D_alpha": -1,
    "A_bar":   +1,
    "B_bar":   +1,
    "D_A":  -1,
    "D_B":  -1,
    "D_A_bar":  -1,
    "D_B_bar":  -1,
    "K_A":     +1,
    "K_B":     +1,
}


# Asymptotic values at infinity, used by deriv()'s Robin outer boundary
# condition u' = (u_inf - u) / r.
u_inf = {
    "alpha":   1,
    "D_alpha": 0,
    "A_bar":   1,
    "B_bar":   1,
    "D_A": 0,
    "D_B": 0,
    "D_A_bar": 0,
    "D_B_bar": 0,
    "K_A":     0,
    "K_B":     0,
}


# Spatial derivative with embedded boundary conditions
def deriv(u, varname):
    """First radial derivative of ``u`` on the staggered grid, with boundary conditions.

    Interior points use the second-order centred stencil
    ``(u[i+1] - u[i-1]) / (2*dr)``.

    Origin: the grid is staggered (``r[0] = dr/2``), so the ghost value at
    ``r = -dr/2`` is ``origin_sign[varname] * u[0]``.

    Outer boundary: Robin condition for a ``u_inf + c/r`` fall-off, which
    gives ``u' = (u_inf[varname] - u) / r``.

    Parameters
    ----------
    u : ndarray
        Values on the radial grid ``r``.
    varname : str
        Key into ``origin_sign`` and ``u_inf`` selecting the boundary data.

    Returns
    -------
    ndarray
        Derivative, same shape as ``u``.
    """
    h = dr
    u_prime = np.zeros_like(u)
    # Vectorised interior stencil: the same arithmetic as the former
    # per-point Python loop, but this runs many times per RK stage.
    u_prime[1:-1] = (u[2:] - u[:-2]) / (2*h)
    # Origin, using the parity ghost point
    u_prime[0] = (u[1] - origin_sign[varname]*u[0]) / (2*h)
    # Outer boundary (Robin)
    u_prime[-1] = (1/(r[-1])) * (u_inf[varname] - u[-1])
    return u_prime

# Initial state vector. Barred metric functions are 1 (so A = B = psi**4),
# barred log-derivatives and extrinsic curvature vanish, and the lapse is 1.
# The key order here fixes the order of the evolved `variables` later on.
U0 = {
    "alpha":     1 * just_like(r),
    "A_bar":     1 * just_like(r),
    "B_bar":     1 * just_like(r),
    "D_A_bar":   0 * just_like(r),
    "D_B_bar":   0 * just_like(r),
    "K_A":       0 * just_like(r),
    "K_B":       0 * just_like(r),
}

# Function to calculate the RHS for a variable using the state vector
def rhs(U, varname):
    """Return the time derivative of a single evolved variable.

    Parameters
    ----------
    U : dict
        State vector with keys ``alpha``, ``A_bar``, ``B_bar``, ``D_A_bar``,
        ``D_B_bar``, ``K_A`` and ``K_B``.
    varname : str
        Which right-hand side to evaluate (one of the keys above).

    Raises
    ------
    KeyError
        If ``varname`` is not an evolved variable.

    The physical metric functions, their log-derivatives and D_alpha are
    computed on every call. Only the requested right-hand side is evaluated.
    """
    alpha, A_bar, B_bar = U["alpha"], U["A_bar"], U["B_bar"]
    D_A_bar, D_B_bar = U["D_A_bar"], U["D_B_bar"]
    K_A, K_B = U["K_A"], U["K_B"]

    # Physical metric functions and their logarithmic radial derivatives
    A = deregularize(A_bar)
    B = deregularize(B_bar)
    D_A = deregularize(D_A_bar, derivative=True)
    D_B = deregularize(D_B_bar, derivative=True)

    # Log-derivative of the lapse
    D_alpha = deriv(alpha, 'alpha')/alpha

    if varname == "alpha":
        return slicing_rhs(slicing, U)
    if varname == "A_bar":
        return -2 * alpha * A_bar * K_A
    if varname == "B_bar":
        return -2 * alpha * B_bar * K_B
    if varname == "D_A_bar":
        return -2 * alpha * (K_A * D_alpha + deriv(K_A, "K_A"))
    if varname == "D_B_bar":
        return -2 * alpha * (K_B * D_alpha + deriv(K_B, "K_B"))
    if varname == "K_A":
        # d_r D_B = d_r D_B_bar + dE: only the smooth barred part is finite-differenced
        bracket = (deriv(D_alpha, "D_alpha") + deriv(D_B_bar, "D_B_bar") + dE
                   + D_alpha**2 - 0.5 * D_alpha*D_A + 0.5*D_B**2 - 0.5*D_A*D_B
                   - (1/r) * (D_A-2*D_B))
        return -alpha/A * bracket + alpha * K_A * (K_A + 2*K_B)
    if varname == "K_B":
        bracket = (deriv(D_B_bar, "D_B_bar") + dE + D_alpha*D_B + D_B**2
                   - 0.5*D_A*D_B - (1/r)*(D_A-2*D_alpha-4*D_B)
                   - 2*(A-B)/(r**2*B))
        return -alpha/(2*A) * bracket + alpha * K_B * (K_A + 2*K_B)
    raise KeyError(varname)

# Hamiltonian constraint
def hamiltonian(U):
    """Evaluate the Hamiltonian-constraint residual on the grid.

    Returns the pointwise residual

        H = -d_r D_B + (A - B)/(r^2 B) + A K_B (2 K_A + K_B)
            + (D_A - 3 D_B)/r + D_A D_B / 2 - 3 D_B^2 / 4,

    which vanishes for an exact solution. Its magnitude measures the
    numerical error.

    d_r D_B is evaluated as d_r D_B_bar + dE. Only the smooth barred part is
    finite-differenced; the singular part E (~ -4/r near the origin) is
    differentiated analytically. This is the same split the K_A and K_B
    right-hand sides use. Finite-differencing D_B itself differences the
    1/r singularity across the origin ghost point. That gives a wrong-signed
    O(1/dr^2) error that swamps the residual near r = 0.
    """
    A_bar = U["A_bar"]
    B_bar = U["B_bar"]
    D_A_bar = U["D_A_bar"]
    D_B_bar = U["D_B_bar"]
    K_A = U["K_A"]
    K_B = U["K_B"]

    # Recover non-barred variables
    A = deregularize(A_bar)
    B = deregularize(B_bar)
    D_A = deregularize(D_A_bar, derivative=True)
    D_B = deregularize(D_B_bar, derivative=True)

    # Radial derivative of D_B: numerical smooth part + analytic singular part
    D_B_prime = deriv(D_B_bar, 'D_B_bar') + dE

    H = -D_B_prime + (1/(r**2*B))*(A-B) + A*K_B*(2*K_A+K_B) + (1/r)*(D_A-3*D_B) + 0.5*D_A*D_B - (3/4)*D_B**2
    return H


# Apparent horizon finder
def find_AH(U):
    """Locate the apparent horizon and return its radius, area and mass.

    The expansion is

        H = (2/r + D_B) / sqrt(A) - 2 K_B,

    with D_B = d_r ln B_bar + E, taken from B_bar directly. The horizon is
    the root of H, found with a bracketing root finder on a linear
    interpolant over the whole grid. This avoids the grid noise of picking
    the grid point with the smallest |H|.

    Side effect: plots H(r) to ``frames/H/H.{n:04d}.png``, using the
    module-level ``t`` (title) and ``n`` (frame number).

    Returns
    -------
    dict
        ``"r"``: coordinate radius of the horizon.
        ``"S"``: horizon area 4 pi B r^2.
        ``"M"``: mass sqrt(S / (16 pi)).

    Raises
    ------
    ValueError
        From ``scipy.optimize.root_scalar`` if H does not change sign
        between ``r[0]`` and ``r[-1]`` (no horizon on the grid).
    """
    A_bar = U["A_bar"]
    B_bar = U["B_bar"]
    K_B = U["K_B"]

    # Recover irregular variables
    A = deregularize(A_bar)
    B = deregularize(B_bar)

    # Expansion parameter; d_r ln B = d_r ln B_bar + E
    H = (1/np.sqrt(A)) * (2/r + deriv(B_bar, 'B_bar')/B_bar + E) - 2*K_B

    # Root of the interpolated expansion
    iH = scipy.interpolate.interp1d(r, H)
    root = scipy.optimize.root_scalar(iH, bracket=[r[0], r[-1]])
    r_AH = root.root

    # Area from B interpolated to the horizon radius
    iB = scipy.interpolate.interp1d(r, B)
    S_AH = 4*np.pi*iB(r_AH)*r_AH**2

    _plot_expansion(H)

    return {
        "r": r_AH,
        "S": S_AH,
        "M": np.sqrt(S_AH/(16*np.pi)),
    }


def _plot_expansion(H):
    """Save a plot of the expansion H(r) as frame ``n`` under ``frames/H/``."""
    varname = 'H'
    ax = plt.gca()
    ax.set_xlabel('r')
    ax.set_ylabel(varname)
    ax.plot(r, H)
    ax.set_title(f't={t}')
    ax.axhline(0, linestyle='dashed', color='gray')
    os.makedirs(f'frames/{varname}', exist_ok=True)
    plt.savefig(f'frames/{varname}/{varname}.{n:04d}.png')
    plt.close('all')

## Beginning of the evolution

# Simulation time and iteration counter; plot_variable and find_AH read both as globals
t=0
n=0
# Current state, starting from the initial data. The loop rebinds U each step rather than mutating it.
U = U0

# Plotting function
def plot_variable(U, varname):
    """Plot one field against r and save it as frame ``n`` of ``frames/<varname>/``.

    Parameters
    ----------
    U : dict or array_like
        Either the state vector, in which case ``varname`` selects the field
        (the derived trace ``"K" = K_A + 2*K_B`` is also accepted), or the
        values to plot directly, in which case ``varname`` only sets the
        axis label and output path.
    varname : str
        Name of the field.

    Raises
    ------
    KeyError
        If ``U`` is a dict without ``varname`` and ``varname`` is not ``"K"``.
        Previously the value silently stayed ``None`` and failed inside ``ax.plot``.

    Uses the module-level ``t`` (title) and ``n`` (frame number).
    """
    if isinstance(U, dict):
        try:
            v = U[varname]
        except KeyError:
            if varname != "K":
                raise
            # Trace of the extrinsic curvature, not stored in the state vector
            v = U["K_A"] + 2 * U["K_B"]
    else:
        v = U

    ax = plt.gca()
    ax.set_xlabel('r')
    ax.set_ylabel(varname)
    ax.plot(r, v)
    ax.set_title(f't={t}')
    if varname == "alpha":
        ax.set_ylim((0, 1.2))
    os.makedirs(f'frames/{varname}', exist_ok=True)
    plt.savefig(f'frames/{varname}/{varname}.{n:04d}.png')
    plt.close('all')

# Evolve
# Apparent-horizon time series, one entry per time step (extended with np.append in the loop)
AHs = {
    "time": np.array([]),
    "r": np.array([]),
    "S": np.array([]),
    "M": np.array([]),
}

# Names of the evolved variables, in U0's insertion order
variables = U.keys()

## Do the evolution
while t < t_final:
    # Periodic output of the fields and of log10 |Hamiltonian constraint|
    if plot_every_nth_iteration is not None and n % plot_every_nth_iteration == 0:
        for varname in plot_variables:
            plot_variable(U, varname)
        ham = np.log10(np.abs(hamiltonian(U)))
        plot_variable(ham, "Hamiltonian Constraint")

    # Stop the execution if any variable has a NaN
    for var in variables:
        if has_nan(U[var]):
            raise Exception(f'Variable {var} has NaN, terminating at t={t-dt:.3f}')

    # Classical RK4. Each intermediate state is built once per stage and
    # shared by all variables' right-hand sides.
    k1 = {var: rhs(U, var) for var in variables}
    U2 = {var: U[var] + dt/2*k1[var] for var in variables}
    k2 = {var: rhs(U2, var) for var in variables}
    U3 = {var: U[var] + dt/2*k2[var] for var in variables}
    k3 = {var: rhs(U3, var) for var in variables}
    U4 = {var: U[var] + dt*k3[var] for var in variables}
    k4 = {var: rhs(U4, var) for var in variables}

    U_next = {var: U[var] + (k1[var]/6 + k2[var]/3 + k3[var]/3 + k4[var]/6)*dt for var in variables}

    # U_next is the state at t + dt, so the horizon found in it belongs to that time
    t_next = t + dt
    AH = find_AH(U_next)
    print(f'AH (t={t_next:.3f}): r={AH["r"]:.3f}, S={AH["S"]:.3f}, M={AH["M"]:.3f}')

    # Save the AH parameters
    AHs["time"] = np.append(AHs["time"], t_next)
    AHs["r"] = np.append(AHs["r"], AH["r"])
    AHs["S"] = np.append(AHs["S"], AH["S"])
    AHs["M"] = np.append(AHs["M"], AH["M"])

    U = U_next

    # Increase time and iteration
    t = t_next
    n += 1


# Time series of the apparent-horizon radius, area and mass, one PNG each
os.makedirs('frames', exist_ok=True)
for AHparameter in ["r", "S", "M"]:
    ax = plt.gca()
    ax.plot(AHs["time"], AHs[AHparameter])
    ax.set_xlabel('t')
    ax.set_ylabel(AHparameter)
    if AHparameter == "S":
        # Headroom above the Schwarzschild horizon area 16 pi M^2
        ax.set_ylim((0, 1.2*16*np.pi*M**2))
    if AHparameter == "M":
        ax.set_ylim((0, 1.2*M))
    plt.savefig(f'frames/AH_{AHparameter}.png')
    plt.close('all')
