# springs_fem.py (1.8 KB)
  1. """Solving a system of springs using the finite element method."""
  2. # https://www.youtube.com/watch?v=YqpIEDWJCwc
  3. import numpy as np
  4. from vedo import *
  5. # np.random.seed(0)
  6. num_springs = 7
  7. k = 1.0 # Stiffness of the springs
  8. # Define applied forces at each node
  9. num_nodes = num_springs + 1 # One more node than springs
  10. F = np.random.randn(num_nodes) /5
  11. # Discretize the system
  12. nodes = np.arange(num_nodes)
  13. elements = list(zip(nodes[:-1], nodes[1:]))
  14. # Assemble global stiffness matrix and force vector
  15. K = np.zeros((num_nodes, num_nodes))
  16. for element in elements:
  17. i, j = element
  18. K[i, i] += k
  19. K[j, j] += k
  20. K[i, j] -= k
  21. K[j, i] -= k
  22. # Apply boundary conditions (fixed nodes at both ends)
  23. fixed_nodes = [0, num_nodes - 1]
  24. for node in fixed_nodes:
  25. K[node, :] = 0
  26. K[:, node] = 0
  27. K[node, node] = 1
  28. F[node] = 0
  29. # Solve for displacements
  30. u = np.linalg.solve(K, F)
  31. yvals = np.zeros(num_nodes)
  32. nodes = np.c_[nodes, yvals]
  33. u = np.c_[u, yvals]
  34. F = np.c_[F, yvals]
  35. nodes_displaced = nodes + u
  36. # Visualize the solution
  37. vnodes1 = Points(nodes).color("k", 0.25).ps(20)
  38. vline1 = Line(nodes).color("k", 0.25)
  39. arr_disp = Arrows2D(nodes, nodes_displaced).y(0.4)
  40. arr_force= Arrows2D(nodes, nodes + F).y(-0.25)
  41. arr_disp.c("red4",0.8).legend('Displacements')
  42. arr_force.c("blue4",0.8).legend('Forces')
  43. vnodes2 = Points(nodes_displaced).color("k").ps(20).y(0.1)
  44. vline2 = Lines(vnodes1, vnodes2).color("k", 0.25)
  45. springs = []
  46. for i in range(num_springs):
  47. s = Spring(nodes_displaced[i], nodes_displaced[i+1], r1=0.04).y(0.1)
  48. s.lighting("metallic")
  49. springs.append(s)
  50. lbox = LegendBox([arr_disp, arr_force], width=0.2, height=0.25, markers='s')
  51. lbox.font("Calco")
  52. show(
  53. __doc__, lbox,
  54. vnodes1, vnodes2, vline1, vline2, arr_disp, arr_force, springs,
  55. axes=8, size=(1900, 490), zoom=3.6,
  56. ).close()