# morphomatics_riemann.py

  1. """Compute the mean of two Bézier splines on a sphere using the Riemannian mean"""
  2. # https://morphomatics.github.io/tutorials/tutorial_bezierfold/
  3. import jax
  4. import jax.numpy as jnp
  5. import numpy as np
  6. from morphomatics.geom import BezierSpline
  7. from morphomatics.manifold import Bezierfold
  8. from morphomatics.manifold import Sphere
  9. import vedo
  10. M = Sphere()
  11. B = Bezierfold(M, 2, 2)
  12. North = jnp.array([0.0, 0.0, 1.0])
  13. South = jnp.array([0.0, 0.0, -1.0])
  14. p1 = jnp.array([1.0, 0.0, 0.0])
  15. o1 = jnp.array([1 / jnp.sqrt(2), 1 / jnp.sqrt(2), 0.0])
  16. om1 = M.connec.exp(o1, jnp.array([0, 0, -0.25]))
  17. op1 = M.connec.exp(o1, jnp.array([0, 0, 0.25]))
  18. q1 = jnp.array([0, 1, 0.0])
  19. B1 = BezierSpline(M, [jnp.stack((p1, om1, o1)), jnp.stack((o1, op1, q1))])
  20. z = M.connec.geopoint(o1, North, 0.5)
  21. p2 = jnp.array([1.0, 0.0, 0.0])
  22. o2 = M.connec.geopoint(p1, z, 0.5)
  23. om2 = M.connec.geopoint(p1, z, 0.4)
  24. op2 = M.connec.geopoint(p1, z, 0.6)
  25. q2 = z
  26. B2 = BezierSpline(M, [jnp.stack((p2, om2, o2)), jnp.stack((o2, op2, q2))])
  27. data = jnp.array([B.to_coords(B1), B.to_coords(B2)])
  28. mean = Bezierfold.FunctionalBasedStructure.mean(B, data)[0]
  29. mean = B.from_coords(mean)
  30. time = jnp.linspace(0.0, 2.0, num=100)
  31. pts1 = np.asarray(jax.vmap(B1.eval)(time))
  32. pts2 = np.asarray(jax.vmap(B2.eval)(time))
  33. mean_pts = np.asarray(jax.vmap(mean.eval)(time))
  34. sphere = vedo.Sphere().c("yellow9")
  35. line1 = vedo.Line(pts1, lw=3).cmap("Blues", time)
  36. line2 = vedo.Line(pts2, lw=3).cmap("Blues", time).add_scalarbar("Time")
  37. line_mean = vedo.Line(mean_pts, c="red5", lw=4)
  38. vedo.show(sphere, line1, line2, line_mean, __doc__, axes=1).close()