import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import HTML

# SETUP POSITION AND TIME GRID

axis_size = 100   # number of grid points along each axis
side_length = 1   # length of each side of the square domain
axis_points = np.linspace(0, side_length, axis_size)
# linspace places axis_size points including both endpoints, i.e. axis_size - 1
# intervals, so the true grid spacing is side_length / (axis_size - 1).
# The finite-difference stencils must use this spacing.
dx = dy = side_length / (axis_size - 1)
c = 1 / np.sqrt(2)  # wave propagation speed

# SETUP TIME GRID

T = 20  # total simulated time
# Time step: half of the 2-D CFL stability limit dt <= 1 / (c * sqrt(dx^-2 + dy^-2)).
dt = 0.5 * (1 / c) * (1 / np.sqrt(dx**(-2) + dy**(-2)))
n = int(T / dt)  # number of stored time slices

#initial condition


def initial_cond(x, y):
    """Return the initial wave amplitude sin(2*pi*x + 2*pi*y).

    Evaluated elementwise with NumPy, so ``x`` and ``y`` may be scalars or
    broadcast-compatible arrays (the script passes the X, Y meshgrid).
    """
    phase = 2*np.pi*x + 2*np.pi*y
    return np.sin(phase)

#create meshgrid

X, Y = np.meshgrid(axis_points, axis_points)
U = initial_cond(X, Y)  # wave field at t = 0; shape (axis_size, axis_size)

# Dirichlet boundary values: the edges of the initial field, held fixed for all time.
# .copy() detaches them from U, so later writes to a field buffer cannot change them.
B1 = U[:, 0].copy()   # first column (x = 0)
B2 = U[:, -1].copy()  # last column  (x = side_length)
B3 = U[0, :].copy()   # first row    (y = 0)
B4 = U[-1, :].copy()  # last row     (y = side_length)

# Squared Courant numbers for each axis. With meshgrid's default 'xy' indexing,
# columns (second index) run along x and rows (first index) run along y.
rx = (c * dt / dx) ** 2
ry = (c * dt / dy) ** 2


def _apply_dirichlet(field):
    """Overwrite the four edges of ``field`` in place with the fixed boundary values."""
    field[:, 0] = B1
    field[:, -1] = B2
    field[0, :] = B3
    field[-1, :] = B4


def _courant_laplacian(field):
    """Return rx * (second x-difference) + ry * (second y-difference) on interior points."""
    d2x = field[1:-1, :-2] - 2 * field[1:-1, 1:-1] + field[1:-1, 2:]
    d2y = field[:-2, 1:-1] - 2 * field[1:-1, 1:-1] + field[2:, 1:-1]
    return rx * d2x + ry * d2y


# FIRST TIME STEP
# With zero initial velocity, a second-order Taylor expansion in time gives
#   U(dt) = U(0) + (c^2 dt^2 / 2) * (Uxx + Uyy).
U1 = np.zeros((axis_size, axis_size))
U1[1:-1, 1:-1] = U[1:-1, 1:-1] + 0.5 * _courant_laplacian(U)
# The second slice needs the boundary values too. Otherwise the first leapfrog
# step reads zeros on the edges.
_apply_dirichlet(U1)

# STEP 4. Solve the PDE for all spatial positions up to time T.
# Buffer for the field at the next time step.
U2 = np.zeros((axis_size, axis_size))
_apply_dirichlet(U2)

# One wave-amplitude map per time step (float64, so about 8 * axis_size**2 * n bytes).
map_array = np.zeros((axis_size, axis_size, n))
map_array[:, :, 0] = U
map_array[:, :, 1] = U1

for i in range(2, n):
    # Leapfrog update: U(t+dt) = 2 U(t) - U(t-dt) + c^2 dt^2 (Uxx + Uyy).
    U2[1:-1, 1:-1] = 2 * U1[1:-1, 1:-1] - U[1:-1, 1:-1] + _courant_laplacian(U1)
    _apply_dirichlet(U2)  # edges are not written above; re-pin them defensively
    map_array[:, :, i] = U2

    # Rotate the three buffers so each name keeps referring to a distinct array.
    # Sequential rebinding (U = U1; U1 = U2) would make U, U1 and U2 alias one array.
    U, U1, U2 = U1, U2, U




# 3D PLOTTING

# STEP 5. Animate the wave amplitudes as a movie to see how the wave amplitude
# changes over time. The movie does not need every stored wave map, so NumPy
# slicing takes every 20th map from the full array.
movie_frames = map_array[:, :, 0::20]

# Set up the plot template for the animation. This is an initial figure with the
# surface plot, its colorbar, and the associated labels and axes.
fig = plt.figure()

# Create a 3D projection for the surface plot. Recent Matplotlib versions no
# longer accept keyword arguments such as projection= in Figure.gca(), so the
# 3D axes are created explicitly.
ax = fig.add_subplot(projection='3d')

# Initial surface plot of the initial wave condition (frame 0). This also sets
# the grid stride, colormap, color range and mesh linewidth.
surf = ax.plot_surface(X, Y, movie_frames[:, :, 0], rstride=2, cstride=2,
                       cmap='coolwarm', vmax=1, vmin=-1, linewidth=1)

# Title of plot.
ax.set_title('2D Sine Wave')

# Add a colorbar to the plot.
fig.colorbar(surf)
ax.view_init(elev=45, azim=60)  # initial elevation and azimuth of the view
# Viewing distance. NOTE(review): Axes3D.dist is deprecated in newer Matplotlib
# (ax.set_box_aspect(None, zoom=...) is the replacement). Confirm against the
# Matplotlib version in use.
ax.dist = 9

# Axis limits and labels.
ax.set_xlim3d([0.0, 1.0])
ax.set_xlabel('X')

ax.set_ylim3d([0.0, 1.0])
ax.set_ylabel('Y')

ax.set_zlim3d([-1.0, 1.0])
ax.set_zlabel('Z')

def animate(i):
    """Redraw the 3D surface for movie frame ``i``.

    Used as the FuncAnimation frame callback. It clears the shared axes ``ax``,
    plots ``movie_frames[:, :, i]`` over the (X, Y) grid, then re-applies the
    camera, axis limits, labels and title.

    Returns the newly created surface artist.
    """
    ax.clear()
    frame_surface = ax.plot_surface(
        X, Y, movie_frames[:, :, i],
        rstride=2, cstride=2, cmap='coolwarm', vmax=1, vmin=-1, linewidth=1,
    )

    # Camera placement for the animated frames: elevation, azimuth, distance.
    ax.view_init(elev=30, azim=70)
    ax.dist = 8

    # Fixed limits keep the scale identical from frame to frame.
    axis_settings = (
        (ax.set_xlim3d, [0.0, 1.0], ax.set_xlabel, 'X'),
        (ax.set_ylim3d, [0.0, 1.0], ax.set_ylabel, 'Y'),
        (ax.set_zlim3d, [-1.0, 1.0], ax.set_zlabel, 'Z'),
    )
    for set_limits, limits, set_label, label in axis_settings:
        set_limits(limits)
        set_label(label)

    ax.set_title('2D Sine Wave')
    return frame_surface

# Build the animation with one frame per slice along the last axis of movie_frames.
frame_count = movie_frames.shape[2]
anim = animation.FuncAnimation(fig, animate, frames=frame_count)

# Render the animation as an HTML5 video and wrap it for notebook display.
# Presumably this needs a video writer such as ffmpeg available to Matplotlib;
# confirm in the target environment.
HTML(anim.to_html5_video())