#-----------------------------------------------------------
# Lecture 22 -- Example 1 -- Scipy's solve_ivp function
#-----------------------------------------------------------

import numpy as np
import matplotlib.pyplot as plt

# link to more info about scipy.integrate module
#   https://docs.scipy.org/doc/scipy/reference/integrate.html

#%%
#-----------------------------------------------------------
# A. Solving single rate equations
#-----------------------------------------------------------
#
# The ODE solver in scipy is the scipy.integrate.solve_ivp function. 
#
# Steps to solve an ODE with solve_ivp:
#
#       dy/dt = f(t, y)
#       y(0) = y_0
#
#   (1) Import solve_ivp from the scipy.integrate module
#   (2) Define the function, f, for the right-hand side of the ODE
#       - Make sure ODE is in standard form
#       - function arguments: f(t, y, [parameters])
#   (3) Create an array of the initial and final time: t_span = [<t_beg>, <t_end>]
#       - The solver will choose its own internal time points!
#   (4) Define an array with the value of the initial condition: y_0 = [<val>]
#       - Must be an array!
#   (5) Find the solution using solve_ivp
#         <(soln)> = solve_ivp(<fn name>, <t_span>, <y_0>)
#
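#   Returns an OdeResult object (a dictionary whose entries can also be
#   accessed as attributes, e.g. soln.t) with:
#       t: time points
#       y: 2D array of the solution values; row i holds the i-th unknown
#          evaluated at each time in t
#       other values: (see documentation)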
#
#  link to more info about solve_ivp: 
#   https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
#
# ** Example **
#    
# Solve:
#   
#   dC/dt = -C/tau
# 
# with tau=1, C(0)=1, and t_end=5.

# (1) Import solve_ivp from the scipy.integrate module
from scipy.integrate import solve_ivp

# (2) define the function
def fA(t, C):
    """Return dC/dt = -C/tau for exponential decay with a fixed tau of 1.

    t is the current time (unused here, but solve_ivp always passes it
    first); C is the current value(s) of the concentration.
    """
    decay_time = 1.0
    rate = -C / decay_time
    return rate

# (3) create array of initial and final times
t_span = np.array([0, 5])

# (4) set initial condition
C0 = np.array([1.0])

# (5) solve the problem
soln = solve_ivp(fA, t_span, C0)
t = soln.t
# soln.y is a 2D array: one row per unknown (here only C) and one column per
# time in soln.t. Row 0 is therefore C evaluated at every time point.
C = soln.y[0]

# The exact solution for comparison: C(t) = C0*exp(-t/tau) with tau = 1
t_exact = np.linspace(t_span[0], t_span[1], 101)
C_exact = C0*np.exp(-t_exact/1.0)

# plot the answer
plt.rc("font", size=16)
plt.figure(figsize=(8,6))
plt.plot(t_exact, C_exact, '-', label='exact')
plt.plot(t, C, 'o', label='solve_ivp')
plt.xlabel('time')
plt.ylabel('C')
plt.legend()
plt.show()

#%% 
#-----------------------------------------------------------
# B. solve_ivp with explicit output time points
#-----------------------------------------------------------
#
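# Solve the previous example, but report the solution at 50 evenly spaced
# time points.
#
#   - To do so, add an extra argument "t_eval = <array of times>" to the
#     solve_ivp call. The solver still picks its own internal steps; t_eval
#     only sets the times at which the solution is returned.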
#

# (2) define the function, with tau as an extra (optional) parameter
def fB(t, C, tau=1.0):
    """Right-hand side of the decay ODE dC/dt = -C/tau.

    Parameters
    ----------
    t : float
        Current time. Unused, but solve_ivp always passes it first.
    C : float or numpy.ndarray
        Current value(s) of the concentration.
    tau : float, optional
        Time constant of the decay. The default of 1.0 keeps
        ``solve_ivp(fB, t_span, C0)`` working unchanged. To use another value,
        pass it with ``solve_ivp(fB, t_span, C0, args=(tau,))``.

    Returns
    -------
    float or numpy.ndarray
        The rate of change -C/tau.
    """
    return -C/tau

# (3) same time span, plus the output times we want the solution reported at
t_span = np.array([0, 5])
n_out = 50  # number of reported time points
times = np.linspace(t_span[0], t_span[1], n_out)

# (4) same initial condition
C0 = np.array([1.0])

# (5) solve, asking solve_ivp to report the solution only at `times`
soln = solve_ivp(fB, t_span, C0, t_eval=times)
t, C = soln.t, soln.y[0]

# exact solution C0*exp(-t/tau) with tau = 1, on a fine grid for comparison
t_exact = np.linspace(t_span[0], t_span[1], 101)
C_exact = C0*np.exp(-t_exact/1.0)

# overlay the numerical points on the exact curve
plt.rc("font", size=16)
plt.figure(figsize=(8,6))
plt.plot(t_exact, C_exact, '-', label='exact')
plt.plot(t, C, 'o', label='solve_ivp (50 pts)')
plt.xlabel('time')
plt.ylabel('C')
plt.legend()
plt.show()

#%%
#-----------------------------------------------------------
# Practice (Key Below)
#-----------------------------------------------------------
# 
# Solve the first order IVP:
#
#   dy/dt = cos(pi*t**2) - y**2 * sin(pi*t), y(0) = 10
# 
# in the range t = [0, 5] and plot the solution as a function of t.


















































#%% 
# Solution to Practice Problem

# (solve_ivp was already imported above: from scipy.integrate import solve_ivp)

def rhs(t, y):
    """Right-hand side of the practice ODE dy/dt = cos(pi*t**2) - y**2*sin(pi*t)."""
    forcing = np.cos(np.pi*t**2)
    nonlinear_term = y**2*np.sin(np.pi*t)
    return forcing - nonlinear_term

t_span = np.array([0, 5])
y0 = np.array([10.0])  # y(0) = 10 as a float, matching the other examples

# t_eval is optional. It only chooses the reported times (101 evenly spaced
# points), which makes the plotted curve smoother. The solver still picks its
# own internal steps.
soln = solve_ivp(rhs, t_span, y0,
                 t_eval=np.linspace(t_span[0], t_span[1], 101))
t = soln.t
y = soln.y[0]  # row 0 of the 2D solution array is y(t)

plt.figure(figsize=(8,6))
plt.plot(t, y, '.-')
plt.xlabel('t')
plt.ylabel('y')
plt.show()  # must be called; a bare `plt.show` never displays the figure

