#-----------------------------------------------------------
# Lecture 22 -- Example 2 -- Systems of ODEs
#-----------------------------------------------------------

import numpy as np
import matplotlib.pyplot as plt

# link to more info about scipy.integrate module
#   https://docs.scipy.org/doc/scipy/reference/integrate.html

#%%
#-----------------------------------------------------------
# C. System of equations
#-----------------------------------------------------------
# 
# If we have a system of rate equations in multiple unknowns, we solve the 
# problem in a similar way, but we use arrays. This is similar to what we 
# did with fsolve when solving nonlinear equations.
#
# Steps to solve a system of ODEs with solve_ivp:
#
#       dy/dt = f(t, y_array)
#       y(0) = y_array_0
#
#   (1) Import solve_ivp from the scipy.integrate module
#   (2) Define the function, f, for the right-hand side of the ODE
#       - Make sure ODE is in standard form
#       - function arguments: f(t, y_array, [parameters])
#       - extra parameters are supplied with solve_ivp(..., args=(<p1>, ...))
#   (3) Create an array of the initial and final time: t_span = [<t_beg>, <t_end>]
#       - The solver will choose its own internal time points! To get the
#         solution at specific times, also pass t_eval=<array of times>.
#   (4) Define an array with the value of the initial condition: 
#               y_array_0 = [<val0>, <val1>, ..., <valN>]
#       - Must be a 1-D array (or list) with one value per equation!
#   (5) Find the solution using solve_ivp
#         soln = solve_ivp(<fn name>, <t_span>, <y_array_0>)
#
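#   Returns an OdeResult object (a dictionary whose entries can also be read
#   as attributes, e.g. soln.t) with:
#       t: time points
#       y: 2-D array of the solution values, shape (n_equations, n_times);
#          row i holds y_i at each time in t
#       success, message: whether the solver succeeded (and why, if it failed)
#       other values: (see documentation)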
#
#  link to more info about solve_ivp:
#   https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
#
# **Example** 
#
#   Solve the system
#
#   dA/dt = -k A B, A(0) = 2 M
#   dB/dt = -k A B, B(0) = 1 M
#   dC/dt = k A B,  C(0) = 0 M
#
#   where k = 0.1 for t = [0, 50]
#   
#   Convert to standard form with y0 = A, y1 = B, y2 = C:
#
#   dy0/dt = - k y0 y1
#   dy1/dt = - k y0 y1
#   dy2/dt = k y0 y1

# (1) Import solve_ivp from scipy.integrate
from scipy.integrate import solve_ivp

# (2) Define function
def f(t, y, k=0.1):
    """Right-hand side of the rate equations
        dA/dt = -k A B,   dB/dt = -k A B,   dC/dt = k A B.

    Parameters
    ----------
    t : float
        Time. Not used in the rates, but solve_ivp always passes it.
    y : array_like
        Current concentrations [A, B, C].
    k : float, optional
        Rate constant (default 0.1). A different value can be supplied with
        ``solve_ivp(f, t_span, y0, args=(k,))``.

    Returns
    -------
    numpy.ndarray
        The derivatives [dA/dt, dB/dt, dC/dt].
    """
    A = y[0]
    B = y[1]
    # C = y[2] does not appear in any of the rates, so it is not unpacked.

    rate = k*A*B    # common rate: A and B are consumed, C is produced

    dA_dt = -rate
    dB_dt = -rate
    dC_dt = rate

    return np.array([dA_dt, dB_dt, dC_dt])

# (3) Set times
t_span = np.array([0, 50])
times = np.linspace(t_span[0], t_span[1], 101)   # times at which to report y

# (4) Set initial conditions [A(0), B(0), C(0)] in M
y0 = np.array([2, 1, 0])

# (5) Solve IVP
soln = solve_ivp(f, t_span, y0, t_eval=times)
# solve_ivp does not raise on failure; it reports it through soln.success,
# so check before using (possibly partial) results.
if not soln.success:
    raise RuntimeError(f"solve_ivp failed: {soln.message}")
t = soln.t
A, B, C = soln.y    # soln.y has one row per equation

# plot the solution
plt.rc("font", size=14)
plt.figure()
plt.plot(t, A, '-', label='A')
plt.plot(t, B, '-', label='B')
plt.plot(t, C, '-', label='C')
plt.xlabel("time")
plt.ylabel("concentration")
plt.legend()
plt.show()

#-----------------------------------------------------------
# D. Higher-Order ODEs
#-----------------------------------------------------------
# 
# The process is the same as systems of ODEs, but you must first turn your
# higher-order ODE into a system of first-order ODEs. See the handwritten notes for details
# about how to do this.
#
# **Example** 
#
#   [Question]: Solve for the position of a falling object where
#
#   d^2 x/dt^2 - c (dx/dt)^2 + g = 0
#   
#   x is position 
#   g = 9.81 m/s^2 (gravitational acceleration) 
#   c = 0.01 m^-1 (drag coefficient/mass).
#   t_end = 10 (s)
#   dx/dt (0) = 0 (m/s)
#   x(0) = 250 (m)
#   
#   [Answer]: Re-writing as a system:
#
#   Let v = dx/dt (v = velocity)
# 
#   dv/dt = -g + c*v^2      (c already includes the 1/mass factor)
#   dx/dt = v
# 
#   In standard form
#   dy_0/dt = f_0(y_0,y_1,t)
#   dy_1/dt = f_1(y_0,y_1,t) 
# 
#   y_0 = v
#   f_0 = -g + c*y_0^2
#   y_0(0) = 0
#   y_1 = x
#   f_1 = y_0
#   y_1(0) = 250

# (1) Import solve_ivp from scipy.integrate module
from scipy.integrate import solve_ivp

# (2) define the function
def f(t, y, g=9.81, c=0.01):
    """Right-hand side of the falling-object ODE written as a first-order system.

    The second-order equation  d^2x/dt^2 - c (dx/dt)^2 + g = 0  becomes
        dv/dt = -g + c v^2
        dx/dt = v

    Parameters
    ----------
    t : float
        Time (s). Not used in the rates, but solve_ivp always passes it.
    y : array_like
        Current state [v, x]: velocity (m/s) and position (m).
    g : float, optional
        Gravitational acceleration in m/s^2 (default 9.81).
    c : float, optional
        Drag coefficient divided by mass, in 1/m (default 0.01).

    Returns
    -------
    numpy.ndarray
        The derivatives [dv/dt, dx/dt].
    """
    v = y[0]        # velocity; position y[1] does not appear in either rate

    # c*v**2 is never negative, so it opposes the motion only while the object
    # falls (v < 0), which is the situation modeled here (v(0) = 0, then
    # gravity pulls v negative). A rising object would need e.g. -c*v*abs(v).
    dvdt = -g + c*v**2
    dxdt = v

    return np.array([dvdt, dxdt])

# (3) set times
t_span = np.array([0, 10])
times = np.linspace(t_span[0], t_span[1], 101)   # times at which to report y

# (4) set initial condition [v(0), x(0)] = [0 m/s, 250 m]
y0 = np.array([0, 250])

# (5) Find solution
soln = solve_ivp(f, t_span, y0, t_eval=times)
# solve_ivp does not raise on failure; it reports it through soln.success,
# so check before using (possibly partial) results.
if not soln.success:
    raise RuntimeError(f"solve_ivp failed: {soln.message}")
t = soln.t
v, x = soln.y       # row 0 is velocity, row 1 is position

# plot of answer: velocity on top, position below (shared time axis)
fig, ax = plt.subplots(2, 1, sharex=True)
ax[0].plot(t, v, 'b-', label='velocity')
ax[0].set_ylabel('v (m/s)')
ax[0].legend()
ax[1].plot(t, x, 'g-', label='position')
ax[1].plot([t[0], t[-1]], [0, 0], 'k--')   # dashed reference line at x = 0
ax[1].set_xlabel('time (s)')
ax[1].set_ylabel('x (m)')
ax[1].legend()
plt.show()

