import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import os
import scipy as sp
import scipy.stats as stats
import math
import dask
import dask.array as da
import dask.dataframe as dd
from PIL import Image

from crpropa import *
from HadronicInteraction import *

# Log-spaced grids covering 1e-2 .. 1e10 (12 decades), end points included.
Tp = np.logspace(-2, 10, num=12 * 25 + 1)   # 25 bins per decade -> 301 points
eps = np.logspace(-2, 10, num=12 * 15 + 1)  # 15 bins per decade -> 181 points


def Data_loader(file):
    """Parse a simulation file name and open the matching candidate file.

    The name is split on "_" and must have either four fields
    (``Model_Particles_PriEnergy_Count``) or five fields
    (``Model_Cut_Particles_PriEnergy_Count``).

    Parameters
    ----------
    file : str
        Base name of the output file, without directory and without ".txt".

    Returns
    -------
    tuple
        ``(Data, s, CountPower, Particles, PriEnergy)`` where ``Data`` is the
        dask dataframe built from the tab-separated text file, ``s`` is the
        integer derived from the ``PriEnergy`` tag ("FS" -> 1,
        "FS<d>..." -> 2 - d), ``CountPower`` is the exponent read from the
        count field (characters after its first two, e.g. "1e7C" -> 7) and
        ``Particles`` / ``PriEnergy`` are the raw name fields.

    Raises
    ------
    ValueError
        If the name does not have four or five fields, if the count field
        carries no exponent digits, or if ``PriEnergy`` is not an "FS" tag.
    """
    import re  # local import: only needed for parsing the count field

    fields = file.split("_")
    if len(fields) not in (4, 5):
        raise ValueError(
            f"expected 4 or 5 '_'-separated fields in {file!r}, got {len(fields)}"
        )
    # The leading field(s) (model name and optional cut) are not used here.
    Particles, PriEnergy, CountString = fields[-3:]

    # Read the whole run of digits after the two-character prefix so that
    # exponents with more than one digit ("1e10") are not truncated.
    exponent = re.match(r"[0-9]+", CountString[2:])
    if exponent is None:
        raise ValueError(f"cannot read a count exponent from {CountString!r}")
    CountPower = int(exponent.group())

    # Fail here, before the dataframe is built, instead of hitting an unbound
    # `s` at the return statement.
    if PriEnergy == "FS":
        s = 1
    elif PriEnergy.startswith("FS"):
        s = 2 - int(PriEnergy[2])
    else:
        raise ValueError(f"unsupported primary-energy tag {PriEnergy!r} in {file!r}")

    # columns = ["D", "SN", "ID", "E", "SN0", "E0", "SN1", "W"]
    columns = ["D", "SN", "ID", "E", "X", "Y", "Z", "Px", "Py", "Pz", "SN0", "ID0", "E0", "X0", "Y0", "Z0", "SN1", "W", "tag"]
    Data = dd.read_csv(f"/home/scratch/aka/AO_AGNC/1e{CountPower}_Candidates/{file}.txt", comment='#', delimiter='\t', names=columns)

    return Data, s, CountPower, Particles, PriEnergy


def cal_J_dask(data, s, CountPower, bins=None):
    """Build the weighted, width-normalised energy histogram of *data*.

    Every row contributes ``W * E0**(-s)`` to the bin that contains its ``E``.
    The bin sums are then divided by the bin widths and by the normalisation
    ``N0 = 10**CountPower / I(s)``, where ``I(s)`` is the integral of
    ``E**(-s)`` between 10 and 1e5.

    Parameters
    ----------
    data : dask dataframe
        Must provide the columns "E", "E0" and "W".
    s : int
        Exponent used in the weights; only 0, 1 and 2 are supported.
    CountPower : int
        Base-10 exponent entering the normalisation ``N0``.
    bins : array-like, optional
        Bin edges for ``data["E"]``. Defaults to 151 log-spaced edges,
        ``np.logspace(-4, 6, 151) * GeV / EeV`` (15 bins per decade).

    Returns
    -------
    numpy.ndarray
        One value per bin (``len(bins) - 1`` entries).

    Raises
    ------
    ValueError
        If ``s`` is not 0, 1 or 2. Raised before any dask work is started.
    """
    # Integral of E**(-s) over [10, 1e5] for each supported exponent.
    # Checked first so an unsupported `s` does not cost a full compute.
    spectrum_integrals = {
        0: 1e5 - 10,
        1: np.log(1e5/10),
        2: 1/(10) - 1/(1e5),
    }
    if s not in spectrum_integrals:
        raise ValueError(f"unsupported spectral exponent s={s!r}; expected 0, 1 or 2")
    N0 = 10**CountPower / spectrum_integrals[s]

    if bins is None:
        bins = np.logspace(-4, 6, 150 + 1) * GeV / EeV  # 15 bins per logE

    E = data["E"].to_dask_array(lengths=True)
    weights = (data["W"] * data["E0"]**(-s)).to_dask_array(lengths=True)

    n, bin_lim = da.histogram(E, bins=bins, weights=weights)
    n, bin_lim = dask.compute(n, bin_lim)

    J = n / np.diff(bin_lim) / N0
    return J


def cal_J_error_dask(data, s, CountPower, bins=None):
    """Build the width-normalised histogram of squared weights of *data*.

    Every row contributes ``(W * E0**(-s))**2`` to the bin that contains its
    ``E``. The bin sums are then divided by the bin widths and by the
    normalisation ``N0 = 10**CountPower / I(s)``, where ``I(s)`` is the
    integral of ``E**(-s)`` between 10 and 1e5.

    Parameters
    ----------
    data : dask dataframe
        Must provide the columns "E", "E0" and "W".
    s : int
        Exponent used in the weights; only 0, 1 and 2 are supported.
    CountPower : int
        Base-10 exponent entering the normalisation ``N0``.
    bins : array-like, optional
        Bin edges for ``data["E"]``. Defaults to 76 log-spaced edges,
        ``np.logspace(1, 6, 76)`` (15 bins per decade).

    Returns
    -------
    numpy.ndarray
        One value per bin (``len(bins) - 1`` entries).

    Raises
    ------
    ValueError
        If ``s`` is not 0, 1 or 2. Raised before any dask work is started.
    """
    # Integral of E**(-s) over [10, 1e5] for each supported exponent.
    # Checked first so an unsupported `s` does not cost a full compute.
    spectrum_integrals = {
        0: 1e5 - 10,
        1: np.log(1e5/10),
        2: 1/(10) - 1/(1e5),
    }
    if s not in spectrum_integrals:
        raise ValueError(f"unsupported spectral exponent s={s!r}; expected 0, 1 or 2")
    N0 = 10**CountPower / spectrum_integrals[s]

    if bins is None:
        # NOTE(review): these default edges differ from cal_J_dask (other
        # range, half as many bins, no GeV/EeV factor), so the two outputs do
        # not line up bin by bin unless matching `bins` are passed - confirm
        # whether that is intended.
        bins = np.logspace(1, 6, 75 + 1)  # 15 bins per logE

    E = data["E"].to_dask_array(lengths=True)
    weights = (data["W"] * data["E0"]**(-s)).to_dask_array(lengths=True)

    n_err, bin_lim = da.histogram(E, bins=bins, weights=weights**2)
    n_err, bin_lim = dask.compute(n_err, bin_lim)

    # NOTE(review): this is sum(w^2) / width / N0, with no square root and a
    # single power of the normalisation - presumably converted to an
    # uncertainty downstream; verify against the consumer of these files.
    J_err = n_err / np.diff(bin_lim) / N0
    return J_err


def Save_J(file, Particles, PriEnergy, J1, J2=None, J3=None):
    """Write the histogram arrays for one input file as text files.

    Output goes to the ``Histograms/AGNC/<PriEnergy>/`` folder, which is
    created if it does not exist yet.

    * ``Particles == "elph"``: ``J1`` and ``J2`` are written to the "el" and
      "ph" files.
    * ``Particles == "all"``: as above, and ``J3`` is written to the "nu" file.
    * anything else: only ``J1`` is written, to a file without a species tag.
    """
    folder = f"/home/home1/aka/SOWAS/Analysis/Histograms/AGNC/{PriEnergy}/"
    os.makedirs(folder, exist_ok=True)

    # Single-species case: one untagged file, nothing else to do.
    if Particles not in ("elph", "all"):
        np.savetxt(f"{folder}Histogram_of_{file}.txt", J1)
        return

    # "elph" and "all" both start with the electron and photon files ...
    np.savetxt(f"{folder}Histogram_of_el_{file}.txt", J1)
    np.savetxt(f"{folder}Histogram_of_ph_{file}.txt", J2)
    # ... and only "all" adds the neutrino file.
    if Particles == "all":
        np.savetxt(f"{folder}Histogram_of_nu_{file}.txt", J3)


def Save_J_error(file, Particles, PriEnergy, J1, J2=None, J3=None):
    """Write the error-histogram arrays for one input file as text files.

    Output goes to the ``Histograms/AGNC/<PriEnergy>_err/`` folder, which is
    created if it does not exist yet. Every file name ends in ``_err.txt``.

    * ``Particles == "elph"``: ``J1`` and ``J2`` are written to the "el" and
      "ph" files.
    * ``Particles == "all"``: as above, and ``J3`` is written to the "nu" file.
    * anything else: only ``J1`` is written, to a file without a species tag.
    """
    folder = f"/home/home1/aka/SOWAS/Analysis/Histograms/AGNC/{PriEnergy}_err/"
    os.makedirs(folder, exist_ok=True)

    # Single-species case: one untagged file, nothing else to do.
    if Particles not in ("elph", "all"):
        np.savetxt(f"{folder}Histogram_of_{file}_err.txt", J1)
        return

    # "elph" and "all" both start with the electron and photon files ...
    np.savetxt(f"{folder}Histogram_of_el_{file}_err.txt", J1)
    np.savetxt(f"{folder}Histogram_of_ph_{file}_err.txt", J2)
    # ... and only "all" adds the neutrino file.
    if Particles == "all":
        np.savetxt(f"{folder}Histogram_of_nu_{file}_err.txt", J3)


def Pipeline(file):
    """Load one simulation output, histogram it and save the result.

    For ``Particles == "elph"`` the rows are split by the "ID" column into
    electrons/positrons (11, -11) and photons (22). For ``"all"`` a third
    selection with IDs 12, -12, 14, -14 is added. Each selection is
    histogrammed with ``cal_J_dask`` and the results are handed to ``Save_J``
    in that order. Any other ``Particles`` value histograms the full table.
    """
    Data, s, CountPower, Particles, PriEnergy = Data_loader(file)

    # Single-species file: no splitting by particle ID.
    if Particles not in ("elph", "all"):
        Save_J(file, Particles, PriEnergy, cal_J_dask(Data, s, CountPower))
        return

    # ID selections in the order Save_J expects them (el, ph[, nu]).
    id_groups = [[11, -11], [22]]
    if Particles == "all":
        id_groups.append([12, -12, 14, -14])

    spectra = [
        cal_J_dask(Data[Data["ID"].isin(ids)], s, CountPower)
        for ids in id_groups
    ]
    Save_J(file, Particles, PriEnergy, *spectra)


def Error_Pipeline(file):
    """Load one simulation output, build its error histograms and save them.

    For ``Particles == "elph"`` the rows are split by the "ID" column into
    electrons/positrons (11, -11) and photons (22). For ``"all"`` a third
    selection with IDs 12, -12, 14, -14 is added. Each selection is passed to
    ``cal_J_error_dask`` and the results are handed to ``Save_J_error`` in
    that order. Any other ``Particles`` value processes the full table.
    """
    Data, s, CountPower, Particles, PriEnergy = Data_loader(file)

    # Single-species file: no splitting by particle ID.
    if Particles not in ("elph", "all"):
        Save_J_error(file, Particles, PriEnergy, cal_J_error_dask(Data, s, CountPower))
        return

    # ID selections in the order Save_J_error expects them (el, ph[, nu]).
    id_groups = [[11, -11], [22]]
    if Particles == "all":
        id_groups.append([12, -12, 14, -14])

    errors = [
        cal_J_error_dask(Data[Data["ID"].isin(ids)], s, CountPower)
        for ids in id_groups
    ]
    Save_J_error(file, Particles, PriEnergy, *errors)


# Files0 = ["AO_Cut=50GeV_neutrinos_FS0_1e6C", "AO_Cut=100GeV_neutrinos_FS0_1e6C",
#          "AO_Cut=200GeV_neutrinos_FS0_1e6C", "AO_Cut=500GeV_neutrinos_FS0_1e6C",
#          "AO_Cut=1000GeV_neutrinos_FS0_1e6C", "AO_Cut=10000GeV_neutrinos_FS0_1e6C"]

# Files = ["AO_Cut=50GeV_neutrinos_FS2_1e6C", "AO_Cut=100GeV_neutrinos_FS2_1e6C",
#          "AO_Cut=200GeV_neutrinos_FS2_1e6C", "AO_Cut=500GeV_neutrinos_FS2_1e6C",
#          "AO_Cut=1000GeV_neutrinos_FS2_1e6C", "AO_Cut=10000GeV_neutrinos_FS2_1e6C"]

# Files = []
# FSs = ["FS1", "FS2"]
# for FS in FSs:
#     Files.append([f"AO_Cut=50GeV_all_{FS}_1e7C", f"AO_Cut=100GeV_all_{FS}_1e7C", 
#                   f"AO_Cut=200GeV_all_{FS}_1e7C", f"AO_Cut=500GeV_all_{FS}_1e7C", 
#                   f"AO_Cut=1000GeV_all_{FS}_1e7C", f"AO_Cut=2000GeV_all_{FS}_1e7C", 
#                   f"AO_Cut=5000GeV_all_{FS}_1e7C", f"AO_Cut=10000GeV_all_{FS}_1e7C", 
#                   f"AAfrag_all_{FS}_1e7C", f"ODDK_elph_{FS}_1e7C"])

# Files = np.array(Files).flatten()

# for F in Files:
#     Pipeline(F)
#     print(f"{F} done")

# Process the selected runs one after another and report each as it finishes.
for F in ("ODDK_all_FS1_1e5", "AAfrag_all_FS1_1e5"):
    Pipeline(F)
    print(f"{F} done")
