# Auteur : Clément de la Salle
# Agrégation de physique, ENS de Lyon, 2019-2020

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from time import time
from matplotlib.widgets import Slider, Button, RadioButtons
from sympy import *
from tkinter import *
from scipy.optimize import brentq
from scipy.integrate import odeint
from scipy.signal import argrelextrema

def index(txt, elt) :
    '''Return the position of the first occurrence of elt in txt, or np.inf when absent.

    np.inf (rather than -1) lets callers take a plain min() over several searches.
    '''
    return txt.index(elt) if elt in txt else np.inf

def super_split(chaines, *args) :
    '''Split one string (or several strings!) on several separators at once.

    chaines : a string, or a sequence of strings.
    args : the separators, applied one after the other.
    Returns the flat list of all the pieces. With no separator at all, the
    input string(s) are returned as a list (instead of raising IndexError).
    '''
    if isinstance(chaines, str) :
        chaines = [chaines]
    morceaux = list(chaines)
    for sep in args :
        # Split every current piece and flatten the result one level.
        morceaux = [piece for txt in morceaux for piece in txt.split(sep)]
    return morceaux

def une_liste(liste) :
    '''Flatten one level: put all the elements of the sub-lists of `liste` one after another in a new list.'''
    return [element for sous_liste in liste for element in sous_liste]

def remplacer(txt, avant, apres) :
    '''
    Replace every occurrence of avant[j] in txt by apres[j], all at once.

    Unlike chained calls to str.replace, the substitutions are simultaneous:
    txt is scanned once from left to right, and text inserted by a
    substitution is never scanned again. When several patterns match at the
    same position, the longest one wins (then the one with the largest index j).
    Empty patterns are ignored (they used to make the loop run forever), and an
    empty `avant` returns txt unchanged.

    txt : the string to transform.
    avant, apres : sequences of the same length; avant[j] is replaced by apres[j].
    Returns the new string.
    '''
    motifs = [(j, elt) for j, elt in enumerate(avant) if elt]
    morceaux = []
    i = 0
    while i < len(txt) :
        candidats = []
        for j, elt in motifs :
            pos = txt.find(elt, i)
            if pos >= 0 :
                # Ordering key: earliest match, then longest, then largest j.
                candidats.append((pos, -len(elt), -j))
        if not candidats :
            morceaux.append(txt[i:])
            break
        pos, moins_longueur, moins_j = min(candidats)
        morceaux.append(txt[i:pos])
        morceaux.append(apres[-moins_j])
        i = pos - moins_longueur
    return ''.join(morceaux)

alphabet = 'azertyuiopqsdfghjklmwxcvbn'
chiffres = ['.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
symboles = [' ', '=', '+', '-', '*', '/', '**', '(', ')', '{', '}', '^', '.']
operateurs = ['cos', 'sin', 'tan', 'exp', 'ln', 'log', 'cosh', 'sinh', 'tanh', 'acos', 'asin', 'atan']
# Tokens of a formula that are not parameters (every other character becomes a slider).
Caracteres_reconnus = chiffres + symboles + operateurs + ['x']
# NumPy spelling of each name in `operateurs`, in the same order. NumPy has no
# `ln`, and `acos`/`asin`/`atan` only exist as aliases since NumPy 2.0, so
# those are mapped to `log`/`arccos`/`arcsin`/`arctan`.
_noms_numpy = {'ln' : 'log', 'acos' : 'arccos', 'asin' : 'arcsin', 'atan' : 'arctan'}
operateurs_numpy = ['np.' + _noms_numpy.get(op, op) for op in operateurs]

class Controler(Frame) :
    '''Tk dialog asking for the formula of the potential y(x).

    Validating (OK button or the Return key) parses the formula with OK and
    closes the window.
    '''

    def __init__(self, fenetre) :

        Frame.__init__(self, fenetre)
        self.fenetre = fenetre
        self.autre = Frame(fenetre)
        self.formule = Entry(self)
        self.formule.grid(row = 1, column = 1)
        Label(self, text = 'Forme du potentiel').grid(row = 0, column = 1)
        Label(self, text = 'y(x) = ').grid(row = 1, column = 0)
        # Button is tkinter's here: `from tkinter import *` comes after, and
        # shadows, `from matplotlib.widgets import ... Button`.
        Button(self, text = 'OK', command = self.OK).grid(row = 1, column = 2)
        fenetre.bind('<Return>', self.OK_bis)
        self.pack()

    def OK(self) :
        '''Parse the formula typed by the user, then close the window.

        Every character that is not part of a recognised token
        (`Caracteres_reconnus`: digits, symbols, function names, x) is a
        one-letter parameter. Each distinct parameter gets one slider index j,
        and every occurrence of it is replaced by S[j].val. '^' is accepted as
        a synonym of '**'.

        Sets the globals:
        - Params_label : the distinct parameter letters, in order of appearance;
        - Exp : the formula in sympy syntax (variable x);
        - Y : the same formula in NumPy syntax (variable X, np.* functions).
        '''

        global Y, Params_label, Exp

        Exp = self.formule.get()
        # Remove spaces; in Python '^' is a xor, not a power.
        Exp = remplacer(Exp, [' ', '^'], ['', '**'])
        # Mask each recognised token with as many '_' as it has characters:
        # the characters left unmasked (same positions as in Exp) are the parameters.
        Exp1 = remplacer(Exp, Caracteres_reconnus, ['_' * len(c) for c in Caracteres_reconnus])
        Params_label = []
        Y = ''
        i = 0
        for k, elt in enumerate(Exp1) :
            if elt != '_' :
                # A parameter used several times shares a single slider.
                if elt not in Params_label :
                    Params_label.append(elt)
                j = Params_label.index(elt)
                Y = Y + Exp[i:k] + 'S[' + str(j) + '].val'
                i = k + 1
        Y = Y + Exp[i:]
        Exp = Y
        Y = remplacer(Y, operateurs + ['x'], operateurs_numpy + ['X'])
        self.fenetre.destroy()

    def OK_bis(self, event) :
        '''Key-binding wrapper around OK (Tk passes an event argument).'''
        self.OK()

def step(X, t) :
    '''Right-hand side of the equation of motion, in the form expected by odeint.

    X = [position, velocity]; t is unused. Returns [dx/dt, dv/dt] with
    dv/dt = -F(x) - f * v, where F is the derivative of the potential (global,
    set from the formula) and f the friction coefficient of the `stau` slider.
    '''
    position, vitesse = X[0], X[1]
    acceleration = -F(position) - vitesse * stau.val
    return [vitesse, acceleration]

def tracer(event) :
    '''Slider callback: redraw the potential curve on ax1.

    Connected to the parameter sliders and to the L_x / s_x sliders. Recomputes
    F (derivative of the potential, used by step), plots the NumPy formula Y on
    [s_x - L_x, s_x + L_x], clears the energy line and the turning points, and
    forgets the last click by deleting the global xmin (tracer2 then draws
    nothing until the next click).
    '''
    
    global F, xmin, xmax

    F = lambdify(x, eval(Exp).diff(x))
    # The local name X is required: the formula string Y refers to it.
    X = np.linspace(ssx.val - sLx.val, ssx.val + sLx.val, 10000)
    Y_array = eval(Y)
    line1.set_data(X, Y_array)
    ax1.set_xlim(ssx.val - sLx.val, ssx.val + sLx.val)
    # 10 % vertical margin around the curve.
    D = Y_array.max() - Y_array.min()
    ax1.set_ylim(Y_array.min() - .1 * D, Y_array.max() + .1 * D)
    line_em.set_data([], [])
    line_xmin_xmax.set_data([], [])
    if 'xmin' in globals() :
        del(xmin)
    
    plt.draw()
    
def click(event) :
    '''Mouse-click callback: choose the mechanical energy on the potential plot.

    A click at (x0, E) on ax1, above the potential curve, sets the energy
    em = E and looks on [-L, L] for the turning points xmin < x0 < xmax where
    the potential equals em. The energy line and the turning points are drawn
    on ax1, then tracer2 recomputes the trajectory (it is called on every click).

    Globals set:
    - em : the chosen energy;
    - xmin, xmax : the turning points, -inf / inf when there is none on that
      side; xmin = 0 is used as a flag for "no turning point at all" (read by
      tracer2).
    '''

    global em, xmin, xmax

    # Only clicks on the potential plot count. (Testing
    # str(event.inaxes)[:11] == 'AxesSubplot' also accepted ax2, and never
    # matches on recent matplotlib, where subplots are plain 'Axes'.)
    if event.inaxes is ax1 :
        f = lambdify(x, eval(Exp))

        if event.ydata > f(event.xdata) :
            em = event.ydata

            g = lambda t : f(t) - event.ydata
            # Potential sampled on a fine grid of [-L, L]; the substituted
            # formula YL refers to the local name XL.
            XL = np.linspace(-L, L, L * 1000)
            YL = remplacer(Y, ['X'], ['XL'])
            YL_array = eval(YL)
            xmin, xmax = -np.inf, np.inf

            # Indices k such that potential - em changes sign between XL[k] and XL[k + 1].
            idx = np.argwhere(np.diff(np.sign(YL_array - em)) != 0).reshape(-1)
            a_droite = XL[idx] > event.xdata
            croisement_gauche = (~a_droite).any()
            croisement_droite = a_droite.any()

            if croisement_gauche and croisement_droite :
                # Bounded motion: turning points on both sides of the click.
                premier = a_droite.argmax()   # first crossing right of the click
                idx_max = idx[premier]
                idx_min = idx[premier - 1]    # last crossing left of the click
                # max(..., 0) avoids wrapping around to XL[-1] = L when idx_min == 0.
                xmin = brentq(g, event.xdata, XL[max(idx_min - 1, 0)])
                xmax = brentq(g, event.xdata, XL[idx_max + 1])
                line_xmin_xmax.set_data([xmin, xmax], [em, em])
                line_em.set_data([xmin, xmax], [em, em])

            elif croisement_gauche :
                # Only a turning point on the left: the particle escapes to the right.
                xmin = brentq(g, event.xdata, XL[max(idx[-1] - 1, 0)])
                xmax = np.inf
                line_em.set_data([xmin, ssx.val + sLx.val], [em, em])
                line_xmin_xmax.set_data([xmin], [em])

            elif croisement_droite :
                # Only a turning point on the right: the particle escapes to the left.
                xmin = -np.inf
                xmax = brentq(g, event.xdata, XL[idx[0] + 1])
                line_em.set_data([ssx.val - sLx.val, xmax], [em, em])
                line_xmin_xmax.set_data([xmax], [em])

            else :
                # No turning point at all.
                xmin = 0
                line_em.set_data([ssx.val - sLx.val, ssx.val + sLx.val], [em, em])
                line_xmin_xmax.set_data([], [])

    tracer2(1)
    
def _spectre(X, N, T, pas = 10) :
    '''Amplitude spectrum of positions X sampled on np.linspace(0, T, N), cropped after its last peak.

    X is first subsampled by `pas`. Returns (f, C): non-negative frequencies
    (in inverse time units of T) and the matching amplitudes, of the same
    length. The spectrum is cut a little after the last local maximum above
    0.05; when there is no such maximum, the whole non-negative half is kept.
    '''
    X = X[::pas]
    n = X.size
    # Real sampling interval of the subsampled signal: the time step of
    # np.linspace(0, T, N) is T / (N - 1), times the subsampling factor.
    dt = pas * T / (N - 1)

    a = np.fft.ifftshift(X)
    A = np.fft.fft(a)
    B = np.abs(np.real(pas / N * np.fft.fftshift(A)))
    C = B[int(len(B) / 2):]
    f = np.fft.fftshift(np.fft.fftfreq(n, d = dt))[int(n / 2):]

    extrema = argrelextrema(C, np.greater)[0]
    extrema = extrema[C[extrema] > .05]
    if extrema.size > 0 :
        # Keep the last significant peak itself plus a 10 % margin.
        fin = extrema[-1] + 1 + int(.1 * extrema[-1])
        C = C[:fin]
        f = f[:fin]
    return f, C

def tracer2(event) :
    '''Recompute and redraw the trajectory in ax2 for the energy chosen by click.

    Does nothing until a click has set the global xmin. Initial condition:
    - xmin == 0 (no turning point): start at x = 0 with speed sqrt(2 (em - V(0)));
    - finite xmin: start at rest just right of xmin;
    - otherwise: start at rest just left of xmax.
    The motion (step) is integrated on [0, T_max] with N points, and ax2 shows
    x(t), the phase portrait (x, v) or the spectrum of x(t), according to the
    global `type` set by the radio buttons.
    '''

    N = 1000

    if 'xmin' in globals() :

        t = np.linspace(0, sT.val, N)
        if xmin == 0 :
            # The local name X is required: the formula string Y refers to it.
            X = 0
            sol = odeint(step, [0, np.sqrt(2 *(em - eval(Y)))], t)
        elif xmin != -np.inf : sol = odeint(step, [xmin + 0.001, 0], t)
        else : sol = odeint(step, [xmax - 0.001, 0.0], t)
        X, V = sol[:, 0], sol[:, 1]

        if type == 'Temporel' :

            line2.set_data(t, X)
            D = X.max() - X.min()
            ax2.set_aspect('auto')
            ax2.set_xlim(0, sT.val)
            ax2.set_ylim(X.min() - .1 * D, X.max() + .1 * D)

        elif type == 'Phase' :

            line2.set_data(X, V)
            Dx = X.max() - X.min()
            Dv = V.max() - V.min()
            ax2.set_aspect('equal')
            ax2.set_xlim(X.min() - .1 * Dx, X.max() + .1 * Dx)
            ax2.set_ylim(V.min() - .1 * Dv, V.max() + .1 * Dv)

        elif type == 'Fourier' and sT.val > 0 :
            # (T_max = 0 would give a zero time step, hence no frequency axis.)
            f, C = _spectre(X, N, sT.val)

            line2.set_data(f, C)
            ax2.set_aspect('auto')
            ax2.set_xlim(f.min() - .1 * (f.max() - f.min()), f.max())
            D = C.max() - C.min()
            ax2.set_ylim(C.min() - .1 * D, C.max() + .1 * D)

        plt.draw()

def change_ax2(label) :
    '''RadioButtons callback: switch what ax2 displays, then redraw it.

    label ('Temporel', 'Phase' or 'Fourier') is stored in the global `type`,
    which tracer2 reads.
    '''
    global type
    type = label
    tracer2(event = 1)


x = Symbol('x')
L = 100
if 'xmin' in globals() : del(xmin)

# Ask for the potential: Controler.OK sets the globals Exp, Y and Params_label.
fenetre = Tk()
interface = Controler(fenetre)
interface.mainloop()

if 'Params_label' not in globals() :
    # The window was closed without validating a formula: nothing to plot.
    raise SystemExit('Aucune formule de potentiel validée.')

fig = plt.figure()
ax1 = fig.add_subplot(121)
position = list(ax1.get_position().bounds)
position[1] += .16
position[3] -= .15
ax1.set_position(position)

# One slider per parameter; the formulas read them only as S[k].val. They are
# deliberately not bound as global variables: a parameter named e.g. S or L
# would overwrite the script's own globals, and a non-identifier character
# would not even be a valid name.
A = []
S = []
for k, label in enumerate(Params_label) :
    A.append(fig.add_axes([.15, .18 - .05 * k, .3, .04]))
    S.append(Slider(A[k], label, 0, 10, valinit = 1))
    S[k].on_changed(tracer)

axe_Lx = fig.add_axes([.13, .91, .14, .04])
sLx = Slider(axe_Lx, r'$L_x$', 0, 20, valinit = 5)
sLx.on_changed(tracer)

axe_sx = fig.add_axes([.32, .91, .14, .04])
ssx = Slider(axe_sx, r'$s_x$', -10, 10, valinit = 0)
ssx.on_changed(tracer)

X = np.linspace(ssx.val-sLx.val, ssx.val+sLx.val, 10000)
# Evaluate the NumPy formula once (it reads X and S).
Y_array = eval(Y)
line1, = ax1.plot(X, Y_array)
line_em, = ax1.plot([], [], color = (1, .5, 0))
line_xmin_xmax, = ax1.plot([], [], 'ro')
ax1.set_xlim(-sLx.val, sLx.val)
D = Y_array.max() - Y_array.min()
ax1.set_ylim(Y_array.min() - .1 * D, Y_array.max() + .1 * D)

fig.canvas.mpl_connect('button_press_event', click)

axe_T = fig.add_axes([.58, .91, .14, .04])
sT = Slider(axe_T, r'$T_{max}$', 0, 300, 10)
sT.on_changed(tracer2)

axe_tau = fig.add_axes([.76, .91, .14, .04])
stau = Slider(axe_tau, r'$f$', 0, .2, 0)
stau.on_changed(tracer2)

F = lambdify(x, eval(Exp).diff(x))

axe_radio = fig.add_axes([.78, .1, 0.1, .1])
bouton_radio = RadioButtons(axe_radio, ['Temporel', 'Phase', 'Fourier'], active = 1)
bouton_radio.on_clicked(change_ax2)

# Must match the radio button selected above (active = 1).
type = 'Phase'


ax2 = fig.add_subplot(122, aspect = 'equal')
position = list(ax2.get_position().bounds)
ax2.set_position(position)

line2, = ax2.plot([], [])

mng = plt.get_current_fig_manager()
# Maximise the window. This needs the TkAgg backend (mng.window is then a Tk
# window) and a Tk platform accepting the 'zoomed' state, so it is skipped
# instead of crashing when unavailable.
try :
    mng.window.state('zoomed')
except (AttributeError, TclError) :
    pass
plt.show()
