volterra.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
"""The Lotka-Volterra model, where:
x is the number of prey
y is the number of predators"""
  4. #Credits:
  5. #http://visual.icse.us.edu.pl/NPB/notebooks/Lotka_Volterra_with_SAGE.html
  6. #as implemented in K3D_Animations/Lotka-Volterra.ipynb
  7. #https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations
  8. import numpy as np
  9. from scipy.integrate import odeint
  10. def rhs(y0, t, a):
  11. x, y = y0[0], y0[1]
  12. return [x-x*y, a*(x*y-y)]
  13. a_1 = 1.2
  14. x0_1, x0_2, x0_3 = 2.0, 1.2, 1.0
  15. y0_1, y0_2, y0_3 = 4.2, 3.7, 2.4
  16. T = np.arange(0, 8, 0.02)
  17. sol1 = odeint(rhs, [x0_1, y0_1], T, args=(a_1,))
  18. sol2 = odeint(rhs, [x0_2, y0_2], T, args=(a_1,))
  19. sol3 = odeint(rhs, [x0_3, y0_3], T, args=(a_1,))
  20. limx = np.linspace(np.min(sol1[:,0]), np.max(sol1[:,0]), 20)
  21. limy = np.linspace(np.min(sol1[:,1]), np.max(sol1[:,1]), 20)
  22. vx, vy = np.meshgrid(limx, limy)
  23. vx, vy = np.ravel(vx), np.ravel(vy)
  24. vec = rhs([vx, vy], t=0.01, a=a_1)
  25. origins = np.stack([np.zeros(np.shape(vx)), vx, vy]).T
  26. vectors = np.stack([np.zeros(np.shape(vec[0])), vec[0], vec[1]]).T
  27. vectors /= np.stack([np.linalg.norm(vectors, axis=1)]).T * 5
  28. curve_points1 = np.vstack([np.zeros(sol1[:,0].shape), sol1[:,0], sol1[:,1]]).T
  29. curve_points2 = np.vstack([np.zeros(sol2[:,0].shape), sol2[:,0], sol2[:,1]]).T
  30. curve_points3 = np.vstack([np.zeros(sol3[:,0].shape), sol3[:,0], sol3[:,1]]).T
  31. ########################################################################
  32. from vedo import Plotter, Arrows, Points, Line
  33. plt = Plotter(bg="blackboard")
  34. plt += Arrows(origins, origins+vectors, c='lr')
  35. plt += Points(curve_points1, c='y')
  36. plt += Line(curve_points1, c='y')
  37. plt += Line(np.vstack([T, sol1[:,0], sol1[:,1]]).T, c='y')
  38. plt += Points(curve_points2, c='g')
  39. plt += Line(curve_points2, c='g')
  40. plt += Line(np.vstack([T, sol2[:,0], sol2[:,1]]).T, c='g')
  41. plt += Points(curve_points3, c='lb')
  42. plt += Line(curve_points3, c='lb')
  43. plt += Line(np.vstack([T, sol3[:,0], sol3[:,1]]).T, c='lb')
  44. plt += __doc__
  45. plt.show(axes={'xtitle':'time',
  46. 'ytitle':'x',
  47. 'ztitle':'y',
  48. 'zxgrid':True,
  49. 'yzgrid':False},
  50. viewup='x',
  51. )
  52. plt.close()