# interpolate_field.py (2.1 KB)

  1. """Interpolate a vectorial field using
  2. Thin Plate Spline or Radial Basis Function"""
  3. from scipy.interpolate import Rbf
  4. from vedo import Plotter, Points, Arrows, show
  5. import numpy as np
  6. ls = np.linspace(0, 10, 8)
  7. X, Y, Z = np.meshgrid(ls, ls, ls)
  8. xr, yr, zr = X.ravel(), Y.ravel(), Z.ravel()
  9. positions = np.vstack([xr, yr, zr]).T
  10. sources = [(5, 8, 5), (8, 5, 5), (5, 2, 5)]
  11. deltas = [(1, 1, 0.2), (1, 0, -0.8), (1, -1, 0.2)]
  12. apos = Points(positions, r=2)
  13. # for p in apos.vertices: ####### Uncomment to fix some points.
  14. # if abs(p[2]-5) > 4.999: # differences btw RBF and thinplate
  15. # sources.append(p) # will become much smaller.
  16. # deltas.append(np.zeros(3))
  17. sources = np.array(sources)
  18. deltas = np.array(deltas)
  19. src = Points(sources).color("r").ps(12)
  20. trs = Points(sources + deltas).color("v").ps(12)
  21. arr = Arrows(sources, sources + deltas).color("k8")
  22. ################################################# warp using Thin Plate Splines
  23. warped = apos.clone().warp(sources, sources+deltas)
  24. warped.alpha(0.4).color("lg").point_size(10)
  25. allarr = Arrows(apos.vertices, warped.vertices).color("k8")
  26. set1 = [apos, warped, src, trs, arr, __doc__]
  27. plt1 = Plotter(N=2, bg='bb')
  28. plt1.at(0).show(apos, warped, src, trs, arr, __doc__)
  29. plt1.at(1).show(allarr)
  30. ################################################# RBF
  31. x, y, z = sources[:, 0], sources[:, 1], sources[:, 2]
  32. dx, dy, dz = deltas[:, 0], deltas[:, 1], deltas[:, 2]
  33. itrx = Rbf(x, y, z, dx) # Radial Basis Function interpolator:
  34. itry = Rbf(x, y, z, dy) # interoplate the deltas in each separate
  35. itrz = Rbf(x, y, z, dz) # cartesian dimension
  36. positions_x = itrx(xr, yr, zr) + xr
  37. positions_y = itry(xr, yr, zr) + yr
  38. positions_z = itrz(xr, yr, zr) + zr
  39. positions_rbf = np.vstack([positions_x, positions_y, positions_z]).T
  40. warped_rbf = Points(positions_rbf).color("lg",0.4).point_size(10)
  41. allarr_rbf = Arrows(apos.vertices, warped_rbf.vertices).color("k8")
  42. arr = Arrows(sources, sources + deltas).color("k8")
  43. plt2 = Plotter(N=2, pos=(200, 300), bg='bb')
  44. plt2.at(0).show("Radial Basis Function", apos, warped_rbf, src, trs, arr)
  45. plt2.at(1).show(allarr_rbf)
  46. plt2.interactive()
  47. plt2.close()
  48. plt1.close()