'''
Clément de la Salle
Préparation à l'agrégation de physique (ENS de Lyon)
2020

Version finale avec possibilité de modifier tous les paramètres directement pendant l'animation. Pour plus de confort, agrandir la fenêtre en plein écran, ou bien si vous utilisez le backend 'TkAgg', décommenter les deux lignes situées juste avant plt.show().

En cas de problème avec l'affichage, il existe une "Version simplifiée" à décommenter pour un affichage plus sobre. Pour modifier les paramètres (ajouter une onde, changer un angle...) il faudra alors quitter l'animation et modifier les listes dans la section "Paramètres d'utilisateur", normalement tout y est bien détaillé.

Détail des sliders
- V :               Vitesse de l'animation
- Lx, Ly :          Longeur de la matrice : pour un meilleur résultat, changer simultanément les deux et les laisser proches l'un de l'autre
- Nx, Ny :          [Mode "Positions"] Changer le nombre de points suivant chacun des axes
                    [Mode "Vitesses"] Changer le nombre de vecteurs suivant chacun des axes
- On/Off :          Pour afficher ou cacher les points ou les vecteurs (suivant le mode coché)
- 4 sliders gris :  Change les paramètres (angle, amplitude, fréquence ou phase) de l'onde sélectionnée
- - et + :          Pour retirer l'onde sélectionnée (bouton -) ou bien en ajouter une autre (bouton +)
- chiffres :        Permet de sélectionner une onde pour changer ses paramètres ou bien la supprmier
- Solo / Tutti :    Sert à chacher toutes les ondes sauf celle sélectionnée
- Start / Pause :   Ordonne au PC d'enclencher l'autodestruction immédiate... À utiliser en dernier recours !
'''


## Import des modules

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Button, RadioButtons, Slider

## Définition des fonctions

def index(txt, elt) :
    if elt in txt : return txt.index(elt)
    else : return np.inf


def creer_matrice(Lx, Ly, Nx, Ny) :

    '''Creer la matrice des coordonnées de 0 à Lx avec Nx colonnes (et pareil avec les lignes en y)'''

    ligne = np.linspace(0, Lx, Nx)
    colonne = np.linspace(0, Ly, Ny)
    [a, b] = np.meshgrid(ligne, colonne)
    return np.concatenate((a.reshape(*a.shape, 1), b.reshape(*a.shape, 1)), axis = 2)

def start(event) :

    '''Active le bouton Start / Pause'''

    global play

    if play :
        ani.event_source.stop()
        ax.cla()
        ax.plot(*(Ecarts_ligne(t) + Matrice_ligne), 'bo', ms = 10 * min(Lx, Ly) / max(Nx, Ny))
    else : ani.event_source.start()
    play ^= True

def it_anim(k) :

    '''Renvoie le objets à tracer en fonction des choix de l'utilisateur (uniquement les points, uniquement les vecteurs ou les deux)'''

    global t

    t += dt

    if tracer_points and tracer_vecteurs :
        line.set_data(*(Ecarts_ligne(t) + Matrice_ligne))
        Q.set_UVC(*Vitesses_ligne(t))
        return line, Q,
    elif tracer_points and not tracer_vecteurs :
        line.set_data(*(Ecarts_ligne(t) + Matrice_ligne))
        return line,
    else :
        Q.set_UVC(*Vitesses_ligne(t))
        return Q,

def change_angle(val) :

    if solo_bool :
        angles[0] = sangle.val
        angles_tutti[onde] = sangle.val
    else :
        angles[onde] = sangle.val

def change_amp(val) :

    if solo_bool :
        amplitudes[0] = samp.val
        amplitudes_tutti[onde] = samp.val
    else :
        amplitudes[onde] = samp.val

def change_freq(val) :

    if solo_bool :
        frequences[0] = sfreq.val
        frequences_tutti[onde] = sfreq.val
    else :
        frequences[onde] = sfreq.val

def change_phase(val) :

    if solo_bool :
        phases[0] = sphase.val
        phases_tutti[onde] = sphase.val
    else :
        phases[onde] = sphase.val

def change_cisail(label) :

    if solo_bool :
        if label == 'Longitudinale' :
            cisail[0] = 0
            cisail_tutti[onde] = 0
        else :
            cisail[0] = 1
            cisail_tutti[onde] = 1
    else :
        if label == 'Longitudinale' : cisail[onde] = 0
        else : cisail[onde] = 1

def change_onde(label) :

    global onde, angles, frequences, amplitudes, phases, cisail

    onde = int(label) - 1
    if solo_bool :
        angles = [angles_tutti[onde]]
        amplitudes = [amplitudes_tutti[onde]]
        frequences = [frequences_tutti[onde]]
        phases = [phases_tutti[onde]]
        cisail = [cisail_tutti[onde]]
        sangle.set_val(angles[0])
        samp.set_val(amplitudes[0])
        sfreq.set_val(frequences[0])
        sphase.set_val(phases[0])
        boutton_cisail.set_active(cisail[0])
    else :
        sangle.set_val(angles[onde])
        samp.set_val(amplitudes[onde])
        sfreq.set_val(frequences[onde])
        sphase.set_val(phases[onde])
        boutton_cisail.set_active(cisail[onde])

def ajout_onde(event) :

    global boutton_change_onde,  boutton_ajout_onde, boutton_retrait_onde, angles, onde, angles, frequences, amplitudes, phases, cisail

    if solo_bool :
        angles_tutti.append(0)
        amplitudes_tutti.append(.1)
        frequences_tutti.append(1)
        phases_tutti.append(0)
        cisail_tutti.append(0)
        onde = len(angles_tutti) - 1
        angles = [angles_tutti[onde]]
        amplitudes = [amplitudes_tutti[onde]]
        frequences = [frequences_tutti[onde]]
        phases = [phases_tutti[onde]]
        cisail = [cisail_tutti[onde]]
    else :
        angles.append(0)
        amplitudes.append(.1)
        frequences.append(1)
        phases.append(0)
        cisail.append(0)
        onde = len(angles) - 1

    axonde.cla()
    positions = list(axonde.get_position().bounds)
    positions[3] += taille_onde
    axonde.set_position(positions)
    if solo_bool : boutton_change_onde = RadioButtons(axonde, [str(k+1) for k in range(len(angles_tutti))], active = onde)
    else : boutton_change_onde = RadioButtons(axonde, [str(k+1) for k in range(len(angles))], active = onde)
    boutton_change_onde.on_clicked(change_onde)

    positions = list(case_ajout_onde.get_position().bounds)
    positions[3] += taille_onde
    case_ajout_onde.set_position(positions)
    boutton_ajout_onde = Button(case_ajout_onde, '+')
    boutton_ajout_onde.on_clicked(ajout_onde)

    positions = list(case_retrait_onde.get_position().bounds)
    positions[3] += taille_onde
    case_retrait_onde.set_position(positions)
    boutton_retrait_onde = Button(case_retrait_onde, '-')
    boutton_retrait_onde.on_clicked(retrait_onde)

    change_onde(str(onde + 1))

def retrait_onde(event) :

    global boutton_change_onde, boutton_ajout_onde, boutton_retrait_onde, angles, amplitudes, frequences, phases, cisail, onde, angles_tutti, amplitudes_tutti, frequences_tutti, phases_tutti, cisail_tutti

    if solo_bool :
        angles_tutti = angles_tutti[:onde] + angles_tutti[onde+1:]
        amplitudes_tutti = amplitudes_tutti[:onde] + amplitudes_tutti[onde+1:]
        frequences_tutti = frequences_tutti[:onde] + frequences_tutti[onde+1:]
        phases_tutti = phases_tutti[:onde] + phases_tutti[onde+1:]
        cisail_tutti = cisail_tutti[:onde] + cisail_tutti[onde+1:]
        if onde == len(angles_tutti) : onde -= 1
        angles = [angles_tutti[onde]]
        amplitudes = [amplitudes_tutti[onde]]
        frequences = [frequences_tutti[onde]]
        phases = [phases_tutti[onde]]
        cisail = [cisail_tutti[onde]]

    else :
        angles = angles[:onde] + angles[onde+1:]
        amplitudes = amplitudes[:onde] + amplitudes[onde+1:]
        frequences = frequences[:onde] + frequences[onde+1:]
        phases = phases[:onde] + phases[onde+1:]
        cisail = cisail[:onde] + cisail[onde+1:]
        if onde == len(angles) : onde -= 1

    axonde.cla()
    positions = list(axonde.get_position().bounds)
    positions[3] -= taille_onde
    axonde.set_position(positions)
    if solo_bool : boutton_change_onde = RadioButtons(axonde, [str(k+1) for k in range(len(angles_tutti))], active = onde)
    else : boutton_change_onde = RadioButtons(axonde, [str(k+1) for k in range(len(angles))], active = onde)
    boutton_change_onde.on_clicked(change_onde)

    positions = list(case_ajout_onde.get_position().bounds)
    positions[3] -= taille_onde
    case_ajout_onde.set_position(positions)
    boutton_ajout_onde = Button(case_ajout_onde, '+')
    boutton_ajout_onde.on_clicked(ajout_onde)

    positions = list(case_retrait_onde.get_position().bounds)
    positions[3] -= taille_onde
    case_retrait_onde.set_position(positions)
    boutton_retrait_onde = Button(case_retrait_onde, '-')
    boutton_retrait_onde.on_clicked(retrait_onde)

    change_onde(str(onde + 1))

def points_vecteurs(label) :

    global ptsvec, axNx, sNx, axNy, sNy, boutton_onoff

    if label == 'Positions' :
        ptsvec = 'pts'
        axNx.cla()
        sNx = Slider(axNx, r'$N_x$', 0, 500, valinit = Nx, facecolor = 'blue')
        sNx.on_changed(change_Nx)
        axNy.cla()
        sNy = Slider(axNy, r'$N_y$', 0, 500, valinit = Ny, facecolor = 'blue')
        sNy.on_changed(change_Ny)
        col = 'blue'
    else :
        ptsvec = 'vec'
        axNx.cla()
        sNx = Slider(axNx, r'$N_x$', 0, Nx, valinit = min(nx, Nx), facecolor = 'red')
        sNx.on_changed(change_Nx)
        axNy.cla()
        sNy = Slider(axNy, r'$N_x$', 0, Ny, valinit = min(ny, Ny), facecolor = 'red')
        sNy.on_changed(change_Ny)
        col = 'red'

    boutton_onoff = Button(case_onoff, 'On / Off', color = col)
    boutton_onoff.on_clicked(onoff)

def change_Nx(val) :

    global Nx, nx, Q

    if ptsvec == 'pts' :
        Nx = int(sNx.val)
        matrice_positions()
        line.set_ms(10 * min(Lx, Ly) / max(Nx, Ny))

    else :
        nx = int(sNx.val)
        matrice_vitesses()
        truc = Vitesses_ligne(0)
        print('coucou')
        Q = ax.quiver(*matrice_ligne, *truc, color = 'red')


def change_Ny(val) :

    global Ny, ny, Q

    if ptsvec == 'pts' :
        Ny = int(sNy.val)
        matrice_positions()
        line.set_ms(10 * min(Lx, Ly) / max(Nx, Ny))

    else :
        ny = int(sNy.val)
        matrice_vitesses()
        Q = ax.quiver(*matrice_ligne, *Vitesses_ligne(0), color = 'red')

def change_Lx(val) :

    global Lx, Q

    Lx = sLx.val
    matrice_positions()
    matrice_vitesses()
    Q = ax.quiver(*matrice_ligne, *Vitesses_ligne(0), color = 'red')
    ax.set_xlim(-max(amplitudes), Lx + max(amplitudes))


def change_Ly(val) :

    global Ly, Q

    Ly = sLy.val
    matrice_positions()
    matrice_vitesses()
    Q = ax.quiver(*matrice_ligne, *Vitesses_ligne(0), color = 'red')
    ax.set_ylim(-max(amplitudes), Ly + max(amplitudes))

def onoff(event) :

    global tracer_points, tracer_vecteurs

    if ptsvec == 'pts' : tracer_points ^= True
    else : tracer_vecteurs ^= True

def change_V(val) :

    global dt
    dt = sV.val

def solo(event) :

    global solo_bool, angles_tutti, amplitudes_tutti, frequences_tutti, phases_tutti, angles, amplitudes, frequences, phases, cisail, cisail_tutti

    solo_bool ^= True
    if solo_bool :
        angles_tutti = angles.copy()
        amplitudes_tutti = amplitudes.copy()
        frequences_tutti = frequences.copy()
        phases_tutti = phases.copy()
        cisail_tutti = cisail.copy()
        angles = [angles[onde]]
        amplitudes = [amplitudes[onde]]
        frequences = [frequences[onde]]
        phases = [phases[onde]]
        cisail = [cisail[onde]]

    else :
        angles = angles_tutti
        amplitudes = amplitudes_tutti
        frequences = frequences_tutti
        phases = phases_tutti
        cisail = cisail_tutti


## Paramètres d'utilisateur

# Veillez à ce que ces quatres listes aient toujours autant d'éléments (autant que d'ondes à superposer)

angles = [0]            # Rentrer dans cette liste autant de valeur que l'on veut : le programme superpose toutes les ondes
frequences = [1]        # Renseigner les caractéristiques des ondes à superposer : fréquence...
amplitudes = [.1]       # ... amplitudes...
phases = [0]            # ... phase (utile à partir de 3 ondes seulement)...
cisail = [0]            # ... et même si c'est une onde P ou S !
cosin = ['np.cos(', '(-1) ** c * np.sin(']

Lx, Ly = 5, 5                       # Dimensions (réelle en distance)
Nx, Ny = 100, 100                   # Nombre de points dans les directions x et y
nx, ny = 5, 50                      # Nombre de vecteurs dans les directions x et y

t = 0       # Temps initial (aucune importance)
dt = 0.01   # Pas de temps pour l'animation

# Création de la matrice de repos et de la fonction donnant les écarts avec le temps

def matrice_positions() :

    global Matrice, Matrice_ligne, Champs

    Matrice = creer_matrice(Lx, Ly, Nx, Ny)
    Matrice_ligne = np.array([Matrice[:, :, 0].reshape(Nx * Ny), Matrice[:, :, 1].reshape(Nx * Ny)])
    Champs = lambda t, a, f, theta, phi, c : a * np.kron(np.cos(2*np.pi*f * (t - Matrice[:, :, 0] * np.cos(theta) - Matrice[:, :, 1] * np.sin(theta)) + phi), [eval(cosin[c] + 'theta)'), eval(cosin[c-1] + 'theta)')]).reshape(Ny, Nx, 2)

def Ecarts_ligne(t) :
    ecarts = sum(Champs(t, a, f, theta, phi, c) for a, f, theta, phi, c in zip(amplitudes, frequences, angles, phases, cisail))
    return np.array([ecarts[:, :, 0].reshape(Nx * Ny), ecarts[:, :, 1].reshape(Nx * Ny)])

# Pareil pour les vitesses qu'on représentera pas des vecteurs

def matrice_vitesses() :

    global matrice, matrice_ligne, Vitesses

    matrice = creer_matrice(Lx, Ly, nx, ny)
    matrice_ligne = np.array([matrice[:, :, 0].reshape(nx * ny), matrice[:, :, 1].reshape(nx * ny)])
    Vitesses = lambda t, a, f, theta, phi, c : min(Lx, Ly) / (2 * max(nx, ny)) * f * a * np.kron(-np.sin(2*np.pi*f * (t - matrice[:, :, 0] * np.cos(theta) - matrice[:, :, 1] * np.sin(theta)) + phi), [eval(cosin[c] + 'theta)'), eval(cosin[c-1] + 'theta)')]).reshape(ny, nx, 2)

def Vitesses_ligne(t) :
    v = sum(Vitesses(t, a, f, theta, phi, c) for a, f, theta, phi, c in zip(amplitudes, frequences, angles, phases, cisail))
    return np.array([v[:, :, 0].reshape(nx * ny), v[:, :, 1].reshape(nx * ny)])

matrice_positions()
matrice_vitesses()


## Version définitive

onde = 0
solo_bool = False

fig = plt.figure()

taille_onde = 0.03

case_start = fig.add_axes([0.82, 0.1, 0.1, 0.075])
boutton_start = Button(case_start, 'Play / Pause')
boutton_start.on_clicked(start)

case_solo = fig.add_axes([0.82, 0.2, 0.1, 0.075])
boutton_solo = Button(case_solo, 'Solo / Tutti')
boutton_solo.on_clicked(solo)

case_ajout_onde = fig.add_axes([0.72, 0.1, 0.04, taille_onde * len(angles)])
boutton_ajout_onde = Button(case_ajout_onde, '+')
boutton_ajout_onde.on_clicked(ajout_onde)

case_retrait_onde = fig.add_axes([0.6, 0.1, 0.04, taille_onde * len(angles)])
boutton_retrait_onde = Button(case_retrait_onde, '-')
boutton_retrait_onde.on_clicked(retrait_onde)

axonde = fig.add_axes([.64, .1, 0.08, taille_onde * len(angles)])
boutton_change_onde = RadioButtons(axonde, [str(k+1) for k in range(len(angles))])
boutton_change_onde.on_clicked(change_onde)

axangle = fig.add_axes([.57, .52, .25, .03])
sangle = Slider(axangle, r'$\theta$', -np.pi, np.pi, valinit = 0, facecolor = 'grey')
sangle.on_changed(change_angle)

axamp = fig.add_axes([.57, .47, .25, .03])
samp = Slider(axamp, r'$a$', 0, .5, valinit = 0.1, facecolor = 'grey')
samp.on_changed(change_amp)

axfreq = fig.add_axes([.57, .42, .25, .03])
sfreq = Slider(axfreq, r'f', 0, 10, valinit = 1, facecolor = 'grey')
sfreq.on_changed(change_freq)

axphase = fig.add_axes([.57, .37, .25, .03])
sphase = Slider(axphase, r'$\varphi$', 0, 2*np.pi, valinit = 0, facecolor = 'grey')
sphase.on_changed(change_phase)

axcisail = fig.add_axes([.87, .47, 0.09, .08])
boutton_cisail = RadioButtons(axcisail, ['Longitudinale', 'Cisaillement'], active = 0)
boutton_cisail.on_clicked(change_cisail)

ptsvec = 'vec'
axtracer = fig.add_axes([.87, .67, 0.07, .08])
boutton_tracer = RadioButtons(axtracer, ['Positions', 'Vitesses'], active = 1)
boutton_tracer.on_clicked(points_vecteurs)

axNx = fig.add_axes([.57, .72, .25, .03])
sNx = Slider(axNx, r'$N_x$', 0, Nx, valinit = nx, facecolor = 'red')
sNx.on_changed(change_Nx)

axNy = fig.add_axes([.57, .67, .25, .03])
sNy = Slider(axNy, r'$N_y$', 0, Ny, valinit = ny, facecolor = 'red')
sNy.on_changed(change_Ny)

axLx = fig.add_axes([.57, .85, .25, .03])
sLx = Slider(axLx, r'$L_x$', 0, 10, valinit = Lx, facecolor = 'black')
sLx.on_changed(change_Lx)

axLy = fig.add_axes([.57, .8, .25, .03])
sLy = Slider(axLy, r'$L_y$', 0, 10, valinit = Ly, facecolor = 'black')
sLy.on_changed(change_Ly)

case_onoff = fig.add_axes([0.68, 0.58, 0.05, 0.06])
boutton_onoff = Button(case_onoff, 'On / Off', color = 'red')
boutton_onoff.on_clicked(onoff)

axV = fig.add_axes([.57, .94, .25, .03])
sV = Slider(axV, r'$V$', 0, 0.05, valinit = dt, facecolor = 'black')
sV.on_changed(change_V)

ax = fig.add_subplot(121, aspect = 'equal')
position = [.05, .05, .5, .9]
ax.set_position(position)
ax.set_xlim(-max(amplitudes), Lx + max(amplitudes))
ax.set_ylim(-max(amplitudes), Ly + max(amplitudes))

play = False
tracer_points = True
tracer_vecteurs = True

if tracer_points : line, = ax.plot([], [], 'bo', ms = 10 * min(Lx, Ly) / max(Nx, Ny))
if tracer_vecteurs : Q = ax.quiver(*matrice_ligne, *Vitesses_ligne(0), color = 'red')

ani = FuncAnimation(fig, it_anim, interval = 0.01, blit = True)
ani.event_source.stop()

# mng = plt.get_current_fig_manager()       # Vous pouvez décommenter ça si vous utilisez 'TkAgg'
# mng.window.state('zoomed')
plt.show()


## Version simplifiée

# fig = plt.figure()
#
# case_start = fig.add_axes([0.81, 0.25, 0.17, 0.075])
# boutton_start = Button(case_start, 'Play / Pause')
# boutton_start.on_clicked(start)
#
# ax = fig.add_subplot(111, aspect = 'equal')
# ax.set_xlim(-max(amplitudes), Lx + max(amplitudes))
# ax.set_ylim(-max(amplitudes), Ly + max(amplitudes))
#
# play = False
# tracer_points = True
# tracer_vecteurs = True
#
# if tracer_points : line, = ax.plot([], [], 'bo', ms = 10 * min(Lx, Ly) / max(Nx, Ny))
# if tracer_vecteurs : Q = ax.quiver(*matrice_ligne, *Vitesses_ligne(0), color = 'red')
#
# ani = FuncAnimation(fig, it_anim, interval = 0.01, blit = True)
# ani.event_source.stop()
#
# plt.show()