value_iteration.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """Solve a random maze with
  2. Markovian Decision Process"""
  3. # -----------------------------------------------------------------------------
  4. # Copyright 2019 (C) Nicolas P. Rougier & Anthony Strock
  5. # Released under a BSD two-clauses license
  6. #
  7. # References: Bellman, Richard (1957), A Markovian Decision Process.
  8. # Journal of Mathematics and Mechanics. Vol. 6, No. 5.
  9. # -----------------------------------------------------------------------------
  10. #https://github.com/rougier/ML-Recipes/blob/master/recipes/MDP/value-iteration.py
  11. #https://en.wikipedia.org/wiki/Markov_decision_process
  12. import numpy as np
  13. from scipy.ndimage import generic_filter
  14. def maze(shape=(30, 50), complexity=0.8, density=0.8):
  15. shape = (np.array(shape)//2)*2 + 1
  16. n_complexity = int(complexity*(shape[0]+shape[1]))
  17. n_density = int(density*(shape[0]*shape[1]))
  18. Z = np.ones(shape, dtype=bool)
  19. Z[1:-1, 1:-1] = 0
  20. P = (np.dstack([np.random.randint(0, shape[0]+1, n_density),
  21. np.random.randint(0, shape[1]+1, n_density)])//2)*2
  22. for (y,x) in P.squeeze():
  23. Z[y, x] = 1
  24. for j in range(n_complexity):
  25. neighbours = []
  26. if x > 1: neighbours.append([(y, x-1), (y, x-2)])
  27. if x < shape[1]-2: neighbours.append([(y, x+1), (y, x+2)])
  28. if y > 1: neighbours.append([(y-1, x), (y-2, x)])
  29. if y < shape[0]-2: neighbours.append([(y+1, x), (y+2, x)])
  30. if len(neighbours):
  31. next_1, next_2 = neighbours[np.random.randint(len(neighbours))]
  32. if Z[next_2] == 0:
  33. Z[next_1] = Z[next_2] = 1
  34. y, x = next_2
  35. else:
  36. break
  37. return Z
  38. def solve(Z, start, goal):
  39. Z = 1 - Z
  40. G = np.zeros(Z.shape)
  41. G[start] = 1
  42. # We iterate until value at exit is > 0. This requires the maze
  43. # to have a solution or it will be stuck in the loop.
  44. def diffuse(Z, gamma=0.99):
  45. return max(gamma*Z[0], gamma*Z[1], Z[2], gamma*Z[3], gamma*Z[4])
  46. while G[goal] == 0.0:
  47. G = Z * generic_filter(G, diffuse, footprint=[[0, 1, 0],
  48. [1, 1, 1],
  49. [0, 1, 0]])
  50. # Descent gradient to find shortest path from entrance to exit
  51. y, x = goal
  52. dirs = (0,-1), (0,+1), (-1,0), (+1,0)
  53. P = []
  54. while (x, y) != start:
  55. P.append((y,x))
  56. neighbours = [-1, -1, -1, -1]
  57. if x > 0: neighbours[0] = G[y, x-1]
  58. if x < G.shape[1]-1: neighbours[1] = G[y, x+1]
  59. if y > 0: neighbours[2] = G[y-1, x]
  60. if y < G.shape[0]-1: neighbours[3] = G[y+1, x]
  61. a = np.argmax(neighbours)
  62. x, y = x + dirs[a][1], y + dirs[a][0]
  63. P.append((y,x))
  64. return P, G
  65. def show_solution3d(S, start, goal):
  66. from vedo import Text3D, Cube, Line, Grid, merge, show
  67. pts, cubes, txts = [], [], []
  68. pts = [(x,-y) for y,x in S[0]]
  69. for y,line in enumerate(Z):
  70. for x,c in enumerate(line):
  71. if c: cubes.append(Cube([x,-y,0]))
  72. path = Line(pts).lw(6).c('red5')
  73. walls = merge(cubes).flat().c('orange1')
  74. sy, sx = S[1].shape
  75. gradient = np.flip(S[1], axis=0).ravel()
  76. grd = Grid(pos=((sx-1)/2, -(sy-1)/2, -0.49), s=[sx,sy], res=[sx,sy])
  77. grd.lw(0).wireframe(False).cmap('gist_earth_r', gradient, on='cells')
  78. grd.add_scalarbar('Gradient', horizontal=True, c='k', nlabels=2)
  79. txts.append(__doc__)
  80. txts.append(Text3D('Start', pos=[start[1]-1,-start[0]+1.5,1], c='k'))
  81. txts.append(Text3D('Goal!', pos=[goal[1] -2,-goal[0] -2.7,1], c='k'))
  82. return show(path, walls, grd, txts, axes=0, zoom=1.2)
  83. ##########################################################################
  84. if __name__ == '__main__':
  85. np.random.seed(4)
  86. Z = maze(shape=(50, 70))
  87. start, goal = (1,1), (Z.shape[0]-2, Z.shape[1]-2)
  88. S = solve(Z, start, goal)
  89. show_solution3d(S, start, goal).close()