# warp5.py
"""
Take two shapes, source and target, and morph source onto target.
This is obtained by fitting the 18 parameters of a nonlinear
(quadratic) transformation, defined in Morpher.transform(), plus an
optional overall scale factor.
The fit minimizes the mean squared distance to the target surface
using algorithms available in the scipy.optimize package.
"""
import numpy as np
import scipy.optimize as opt
from vedo import dataurl, vector, mag2, mag
from vedo import Plotter, Sphere, Point, Text3D, Arrows, Mesh
  11. print(__doc__)
  12. class Morpher:
  13. def __init__(self):
  14. self.source = None
  15. self.target = None
  16. self.bound = 0.1
  17. self.method = "SLSQP" # 'SLSQP', 'L-BFGS-B', 'TNC' ...
  18. self.tolerance = 0.0001
  19. self.subsample = 200 # pick only subsample pts
  20. self.allow_scaling = False
  21. self.params = []
  22. self.msource = None
  23. self.s_size = ([0, 0, 0], 1) # ave position and ave size
  24. self.fitResult = None
  25. self.chi2 = 1.0e10
  26. self.plt = None
  27. # -------------------------------------------------------- fit function
  28. def transform(self, p):
  29. a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, b6, c1, c2, c3, c4, c5, c6, s = self.params
  30. pos, sz = self.s_size[0], self.s_size[1]
  31. x, y, z = (p - pos) / sz * s # bring to origin, norm and scale
  32. xx, yy, zz, xy, yz, xz = x * x, y * y, z * z, x * y, y * z, x * z
  33. xp = x + 2 * a1 * xy + a4 * xx + 2 * a2 * yz + a5 * yy + 2 * a3 * xz + a6 * zz
  34. yp = +2 * b1 * xy + b4 * xx + y + 2 * b2 * yz + b5 * yy + 2 * b3 * xz + b6 * zz
  35. zp = +2 * c1 * xy + c4 * xx + 2 * c2 * yz + c5 * yy + z + 2 * c3 * xz + c6 * zz
  36. p2 = vector(xp, yp, zp)
  37. p2 = (p2 * sz) + pos # take back to original size and position
  38. return p2
  39. def _func(self, pars):
  40. self.params = pars
  41. #calculate chi2
  42. d2sum, n = 0.0, self.source.npoints
  43. srcpts = self.source.vertices
  44. rng = range(0, n, int(n / self.subsample))
  45. for i in rng:
  46. p1 = srcpts[i]
  47. p2 = self.transform(p1)
  48. tp = self.target.closest_point(p2)
  49. d2sum += mag2(p2 - tp)
  50. d2sum /= len(rng)
  51. if d2sum < self.chi2:
  52. if d2sum < self.chi2 * 0.99:
  53. print("Emin ->", d2sum)
  54. self.chi2 = d2sum
  55. return d2sum
  56. # ------------------------------------------------------- Fit
  57. def morph(self):
  58. def avesize(pts): # helper fnc
  59. s, amean = 0, vector(0, 0, 0)
  60. for p in pts:
  61. amean = amean + p
  62. amean /= len(pts)
  63. for p in pts:
  64. s += mag(p - amean)
  65. return amean, s / len(pts)
  66. print("\n..minimizing with " + self.method)
  67. self.msource = self.source.clone()
  68. self.s_size = avesize(self.source.vertices)
  69. bnds = [(-self.bound, self.bound)] * 18
  70. x0 = [0.0] * 18 # initial guess
  71. x0 += [1.0] # the optional scale
  72. if self.allow_scaling:
  73. bnds += [(1.0 - self.bound, 1.0 + self.bound)]
  74. else:
  75. bnds += [(1.0, 1.0)] # fix scale to 1
  76. res = opt.minimize(self._func, x0,
  77. bounds=bnds, method=self.method, tol=self.tolerance)
  78. # recalc for all pts:
  79. self.subsample = self.source.npoints
  80. self._func(res["x"])
  81. print("\nFinal fit score", res["fun"])
  82. self.fitResult = res
  83. # ------------------------------------------------------- Visualization
  84. def draw_shapes(self):
  85. newpts = []
  86. for p in self.msource.vertices:
  87. newp = self.transform(p)
  88. newpts.append(newp)
  89. self.msource.vertices = newpts
  90. arrs = []
  91. pos, sz = self.s_size[0], self.s_size[1]
  92. sphere0 = Sphere(pos, r=sz, res=10, quads=True).wireframe().c("gray")
  93. for p in sphere0.vertices:
  94. newp = self.transform(p)
  95. arrs.append([p, newp])
  96. hair = Arrows(arrs, s=0.3, c='jet').add_scalarbar()
  97. zero = Point(pos).c("black")
  98. x1, x2, y1, y2, z1, z2 = self.target.bounds()
  99. tpos = [x1, y2, z1]
  100. text1 = Text3D("source vs target", tpos, s=sz/10).color("dg")
  101. text2 = Text3D("morphed vs target", tpos, s=sz/10).color("db")
  102. text3 = Text3D("deformation", tpos, s=sz/10).color("dr")
  103. self.plt = Plotter(shape=[1, 3], axes=1)
  104. self.plt.at(2).show(sphere0, zero, text3, hair)
  105. self.plt.at(1).show(self.msource, self.target, text2)
  106. self.plt.at(0).show(self.source, self.target, text1, zoom=1.2)
  107. self.plt.interactive().close()
  108. #################################
  109. if __name__ == "__main__":
  110. mr = Morpher()
  111. mr.source = Mesh(dataurl+"270.vtk").color("g",0.4)
  112. mr.target = Mesh(dataurl+"290.vtk").color("b",0.3)
  113. mr.target.wireframe()
  114. mr.allow_scaling = True
  115. mr.bound = 0.4 # limits the parameter value
  116. mr.morph()
  117. print("Result of parameter fit:\n", mr.params)
  118. # now mr.msource contains the modified/morphed source.
  119. mr.draw_shapes()