import numpy as np
import pylab as plt
from scipy.optimize import curve_fit

# Création des listes de données
X_tab =    [0,1,2] # ordre de diffraction m
errX_tab = [0,0,0] # erreur nulle
Y_tab =    [124.00,116.10,107.55] # angle theta lu [degré]
#Y_tab =    [124.00,115.50,107.28] # angle theta lu [degré]
#Y_tab =    [124.00,115.20,106.28] # angle theta lu [degré]
#Y_tab =    [124.00,113.00,101.31] # angle theta lu [degré]
errY_tab = [0.04,0.04,0.04] # erreur sur angle lu [degré]

# Conversion de ces listes en tableaux (pour traiter les données)
X_tab = np.array(X_tab)
Y_tab = np.array(Y_tab)
errX_tab = np.array(errX_tab)
errY_tab = np.array(errY_tab)

# Création des bonnes variables à tracer
X_tab = X_tab
errX_tab = errX_tab
Y_origine = 124.00 # [degré]
Y_tab = Y_origine - Y_tab # calcul de l'angle theta
Y_tab = Y_tab*np.pi/180. # conversion degré en radian
errY_tab = errY_tab*2 # erreur doublée due à Y_origine
errY_tab = errY_tab*np.pi/180.
errY_tab = np.cos(Y_tab)*errY_tab
Y_tab = np.sin(Y_tab)

# Définition de la fonction pour le calcul du chi^2 et du chi_réduit^2
# Bally, Berroir 2008, Incertitudes expérimentales
# chi_reduit^2 ~ 1 : bon accord avec le modèle
# chi_reduit^2 >> 1 : modèle non validé
# chi_reduit^2 << 1 : incertitudes surestimées ?
def chi2(sigma_exp,y_th,y_exp,nb_param):
    chi2 = sum(((y_th - y_exp)/sigma_exp)**2)
    nDOF = len(y_exp) - nb_param
    chi_red_2 = chi2/(nDOF)
    return chi2,chi_red_2

# Régression linéaire sur Y=f(X)
def flin(x,a,b):
    return a*x+b
p, covm = curve_fit(flin,X_tab,Y_tab)
a,b = p

# Compute chi2
chi2_lin,chi_red_2_lin = chi2(errY_tab,flin(X_tab,a,b),Y_tab,2)

# Correction des barres d'erreur (kesako? à voir...)
erra, errb = np.sqrt(np.diag(covm)/chi_red_2_lin)

# Figure
fig, ax = plt.subplots(1)
ax.errorbar(X_tab,Y_tab,errY_tab,errX_tab,fmt='.',label="données")
Y_fit = flin(X_tab,a,b)
ax.plot(X_tab,Y_fit,label="ajustement")
textstr = "y(x) = a*x + b\n\
a = %.1e +/- %.1e \n\
b = %.1e +/- %.1e \n\
chi_red_2 = %.2f"  %(a,erra,b,errb,chi_red_2_lin)
ax.text(0.45, 0.95, textstr,transform=ax.transAxes, fontsize=12, verticalalignment='top')
ax.legend(loc=8)
plt.ylabel("$sinus\ angle \, ()$", fontsize=16)
plt.xlabel("$ordre\ m \, ()$", fontsize=16)
#plt.axis([-1,3,-0.05,0.35])
plt.show()

# Print de quelques informations
# pas = 3.379e-6
pas = 1e-3/300
delta_pas = 8e-9
delta_lambd = np.sqrt(a**2*delta_pas**2+pas**2*erra**2)
print("lambda = ",a*pas)
print("delta_lamdba = ",delta_lambd)