# self_org_maps2d.py
"""Self organizing maps"""
# -----------------------------------------------------------------------------
# Copyright 2019 (C) Nicolas P. Rougier
# Released under a BSD two-clauses license
#
# References: Kohonen, Teuvo. Self-Organization and Associative Memory.
#             Springer, Berlin, 1984.
# https://github.com/rougier/ML-Recipes/blob/master/recipes/ANN/som.py
# -----------------------------------------------------------------------------
  10. import numpy as np
  11. import scipy.spatial
  12. from vedo import Sphere, Grid, Plotter, progressbar
  13. class SOM:
  14. def __init__(self, shape, distance):
  15. self.codebook = np.random.uniform(0, 1, shape)
  16. self.distance = distance / distance.max()
  17. self.samples = []
  18. def learn(self, n_epoch=10000, sigma=(0.25,0.01), lrate=(0.5,0.01)):
  19. t = np.linspace(0, 1, n_epoch)
  20. lrate = lrate[0] * (lrate[1] / lrate[0]) ** t
  21. sigma = sigma[0] * (sigma[1] / sigma[0]) ** t
  22. I = np.random.randint(0, len(self.samples), n_epoch)
  23. self.samples = self.samples[I]
  24. for i in progressbar(n_epoch):
  25. # Get random sample
  26. data = self.samples[i]
  27. # Get index of nearest node (minimum distance)
  28. winner = np.argmin(((self.codebook - data)**2).sum(axis=-1))
  29. # Gaussian centered on winner
  30. G = np.exp(-self.distance[winner]**2 / sigma[i]**2)
  31. # Move nodes towards sample according to Gaussian
  32. self.codebook -= lrate[i] * G[..., np.newaxis] * (self.codebook-data)
  33. # Draw network
  34. if i>500 and not i%20 or i==n_epoch-1:
  35. x, y, z = [self.codebook[:,i].reshape(n,n) for i in range(3)]
  36. grd.wireframe(False).lw(0.5).bc('blue9').flat()
  37. grdpts = grd.points
  38. for i in range(n):
  39. for j in range(n):
  40. grdpts[i*n+j] = (x[i,j], y[i,j], z[i,j])
  41. grd.points = grdpts
  42. if plt: plt.azimuth(1.0).render()
  43. if plt: plt.interactive().close()
  44. return [self.codebook[:,i].reshape(n,n) for i in range(3)]
  45. # -------------------------------------------------------------------------------
  46. if __name__ == "__main__":
  47. n = 20
  48. X, Y = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n))
  49. P = np.c_[X.ravel(), Y.ravel()]
  50. D = scipy.spatial.distance.cdist(P, P)
  51. sphere = Sphere(res=90).cut_with_plane(origin=(0,-.3,0), normal='y')
  52. sphere.subsample(0.01).add_gaussian_noise(0.5).point_size(3)
  53. plt = Plotter(axes=6, interactive=False)
  54. grd = Grid(res=[n-1, n-1]).c('green2')
  55. plt.show(__doc__, sphere, grd)
  56. som = SOM((len(P), 3), D)
  57. som.samples = sphere.points.copy()
  58. som.learn(n_epoch=4000, sigma=(1, 0.01), lrate=(1, 0.01))