# Auteur : Clément de la Salle
# Agrégation de physique, ENS de Lyon, 2019-2020

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

def _creneaux(positions, bas, haut):
    """Build the polyline of a comb drawn as one continuous zig-zag.

    Each position is visited twice, the second coordinate alternating
    bas -> haut, haut -> bas, ... so that a single line artist draws one
    segment per position, joined alternately at ``haut`` and at ``bas``.

    Returns a pair of 1-D arrays of length ``2 * len(positions)``:
    the duplicated positions and the matching bas/haut levels.
    """
    n = len(positions)
    doubles = np.repeat(positions, 2)
    # One extra period is generated and then truncated, which covers both
    # the even and the odd number of positions with the same expression.
    niveaux = np.tile([bas, haut, haut, bas], n // 2 + 1)[:2 * n]
    return doubles, niveaux


def tracer(event):
    """Redraw the time-domain plot from the current widget values.

    Used as the callback of every slider and called directly by the radio
    button handlers; ``event`` is ignored.

    Reads the sliders ``sf0``, ``sTmax``, ``sfe``, ``sQ`` and the flags
    ``ech_bool`` / ``quant_bool``; updates the line artists ``modele``,
    ``echantillonnage``, ``quantification`` and ``signal``; then refreshes
    the spectrum through ``tracer_TF``.
    """
    f0, t_max = sf0.val, sTmax.val

    # Reference curve: the cosine on a dense 1000-point grid.
    temps = np.linspace(0, t_max, 1000)
    ax1.set_xlim(0, t_max)
    modele.set_data(temps, np.cos(2*np.pi*f0 * temps))

    if ech_bool:
        # Sampling instants every 1/fe, plus dotted vertical guides
        # spanning -2..2 (beyond the visible y-range) at each instant.
        pas = 1 / sfe.val
        x = np.arange(0, t_max + pas, pas)
        a, b = _creneaux(x, -2, 2)
    else:
        x = temps
        a, b = [], []

    if quant_bool:
        # Q evenly spaced levels between -1 and 1.
        Q = int(2 ** sQ.val)
        inc = 2 / (Q - 1)
        echelle = np.linspace(-1, 1, Q)
        # Index of the nearest level.  The builtin ``int`` is used because
        # the ``np.int`` alias was removed from NumPy (1.24).
        indices = ((np.cos(2*np.pi*f0 * x) + 1 + inc / 2) // inc).astype(int)
        y = echelle[indices]
        # Dotted horizontal guides, one per level, spanning -1..Tmax+1.
        d, c = _creneaux(echelle, -1, t_max + 1)
    else:
        y = np.cos(2*np.pi*f0 * x)
        c, d = [], []

    echantillonnage.set_data(a, b)
    quantification.set_data(c, d)
    signal.set_data(x, y)

    tracer_TF(1)

def onoff_ech(label):
    """Radio-button callback switching sampling on or off.

    ``label`` is the text of the selected button. Sampling is enabled only
    for ' On'; the axes holding the f_e slider are shown or hidden to
    match, and the plot is redrawn.
    """
    global ech_bool

    ech_bool = (label == ' On')
    ax_fe.set_visible(ech_bool)
    tracer(1)

def onoff_quant(label):
    """Radio-button callback switching quantisation on or off.

    ``label`` is the text of the selected button. Quantisation is enabled
    only for ' On'; the axes holding the bit-depth slider are shown or
    hidden to match, and the plot is redrawn.
    """
    global quant_bool

    quant_bool = (label == ' On')
    ax_Q.set_visible(quant_bool)
    tracer(1)

def tracer_TF(event):
    """Recompute and redraw the spectrum of the signal currently plotted.

    ``event`` is ignored. The data is read back from the ``signal`` line
    artist, its FFT is drawn as a stem plot in ``ax2`` through the ``TF``
    line artist, and the figure is redrawn.
    """
    t = np.asarray(signal.get_xdata(), dtype=float)
    n = len(t)
    if n == 0:
        # Nothing has been plotted yet: there is no spectrum to compute.
        return

    # Use the real step of the plotted abscissae.  When sampling is enabled
    # this equals 1/fe; when it is disabled the signal lives on the dense
    # display grid, whose step has nothing to do with the f_e slider, so
    # using 1/fe there would put the peaks at the wrong frequencies.
    pas = (t[-1] - t[0]) / (n - 1) if n > 1 else 1 / sfe.val

    a = np.fft.ifftshift(signal.get_ydata())
    A = np.fft.fft(a)
    # NOTE(review): only the real part of the transform is displayed (its
    # absolute value), as in the original code - confirm this is intended
    # rather than the complex modulus.
    B = np.abs(np.real(pas * np.fft.fftshift(A)))
    f = np.fft.fftshift(np.fft.fftfreq(n, d=pas))

    # Turn every bin into a vertical stem 0 -> B -> 0 at its frequency so
    # that a single line artist renders the whole spectrum.
    zeros = np.zeros(len(B))
    B = np.column_stack((zeros, B, zeros)).ravel()
    f = np.repeat(f, 3)

    # Only the non-negative frequencies are shown.
    ax2.set_xlim(0, f.max())
    ax2.set_ylim(-.1 * B.max(), 1.1 * B.max())

    TF.set_data(f, B)
    plt.draw()
    
    

# --- Figure, axes and curves --------------------------------------------------
fig = plt.figure()
ax1 = fig.add_subplot(121, ylim=(-1.1, 1.1))   # time-domain view
ax2 = fig.add_subplot(222)                     # spectrum view
for axe, titre in ((ax1, 'Signal temporel'), (ax2, 'FFT')):
    axe.set_title(titre, size=22, pad=10)

TF = ax2.plot([], [])[0]

# The creation order on ax1 is kept: it fixes both the default colours and
# the stacking order of the four curves.
gris = (.8, .8, .8)
modele = ax1.plot([], [])[0]
echantillonnage = ax1.plot([], [], ':', color=gris)[0]
quantification = ax1.plot([], [], ':', color=gris)[0]
signal = ax1.plot([], [], marker='.', color='red')[0]

# --- Widgets ------------------------------------------------------------------
# Signal frequency.
ax_f0 = fig.add_axes([.6, .1, .25, .04])
sf0 = Slider(ax_f0, r'$f_0$', 0.01, 5, valinit=1)

# Duration of the displayed window.
ax_Tmax = fig.add_axes([.6, .15, .25, .04])
sTmax = Slider(ax_Tmax, r'$T_{max}$', 0.01, 30, valinit=5)

# Sampling on/off (initially on) and sampling frequency.
ax_ech = fig.add_axes([.58, .3, .05, .1])
ronoff_ech = RadioButtons(ax_ech, [' On', ' Off'], active=0)
ronoff_ech.on_clicked(onoff_ech)
ax_ech.set_title('Échantillonnage', size='20', pad=10)

ax_fe = fig.add_axes([.52, .23, .15, .04])
sfe = Slider(ax_fe, r'$f_e$', 0.01, 50, valinit=10)

# Quantisation on/off (initially off) and number of bits.
ax_quant = fig.add_axes([.82, .3, .05, .1])
ronoff_quant = RadioButtons(ax_quant, [' On', ' Off'], active=1)
ronoff_quant.on_clicked(onoff_quant)
ax_quant.set_title('Quantification', size='20', pad=10)

ax_Q = fig.add_axes([.78, .23, .15, .04])
sQ = Slider(ax_Q, r'$Bits$', 1, 10, valinit=4, valstep=1)

# Every slider triggers a full redraw and shares the same text size.
for curseur in (sf0, sTmax, sfe, sQ):
    curseur.on_changed(tracer)
    for texte in (curseur.label, curseur.valtext):
        texte.set_size(20)

# Same text size for the labels of both radio-button groups.
for lab in [*ronoff_ech.labels, *ronoff_quant.labels]:
    lab.set_size(16)

# Disabled: button that would refresh the spectrum on demand.
# ax_voir_TF = fig.add_axes([.82, .5, .1, .08])
# bouton_voir_TF = Button(ax_voir_TF, 'Voir TF')
# bouton_voir_TF.label.set_size(18)
# bouton_voir_TF.on_clicked(tracer_TF)

# --- Initial state: sampling on, quantisation off -----------------------------
ech_bool = True
quant_bool = False
ax_Q.set_visible(False)


# First draw with the initial widget values.
tracer(1)

# Maximise the window.  This relies on the Tk window object, so it is only
# attempted with a Tk backend; other backends keep their default window size
# instead of failing before the figure is shown.
mng = plt.get_current_fig_manager()
if plt.get_backend().lower().startswith('tk'):
    import tkinter

    try:
        mng.window.state('zoomed')
    except tkinter.TclError:
        # The 'zoomed' state is not accepted by Tk on X11, where the
        # equivalent request is the '-zoomed' window attribute.
        mng.window.attributes('-zoomed', True)

plt.show()