Skip to content

Commit 30116ef

Browse files
committed
feat: add ElasticBands
1 parent 1564830 commit 30116ef

File tree

2 files changed

+262
-0
lines changed

2 files changed

+262
-0
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
"""
2+
Elastic Bands
3+
4+
author: Wang Zheng (@Aglargil)
5+
6+
Ref:
7+
8+
- [Elastic Bands: Connecting Path Planning and Control]
9+
(http://www8.cs.umu.se/research/ifor/dl/Control/elastic%20bands.pdf)
10+
"""
11+
12+
import numpy as np
13+
import sys
14+
import pathlib
15+
import matplotlib.pyplot as plt
16+
from matplotlib.patches import Circle
17+
18+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
19+
20+
from Mapping.DistanceMap.distance_map import compute_sdf
21+
22+
# Elastic Bands Params
23+
MAX_BUBBLE_RADIUS = 100
24+
MIN_BUBBLE_RADIUS = 10
25+
RHO0 = 20.0 # Maximum distance for applying repulsive force
26+
KC = 0.05 # Contraction force gain
27+
KR = -0.1 # Repulsive force gain
28+
LAMBDA = 0.7 # Overlap constraint factor
29+
STEP_SIZE = 3.0 # Step size for calculating gradient
30+
31+
# Visualization Params
32+
ENABLE_PLOT = True
33+
ENABLE_INTERACTIVE = False
34+
MAX_ITER = 50
35+
36+
37+
class Bubble:
38+
def __init__(self, position, radius):
39+
self.pos = np.array(position) # Bubble center coordinates [x, y]
40+
self.radius = radius # Safety distance radius ρ(b)
41+
if self.radius > MAX_BUBBLE_RADIUS:
42+
self.radius = MAX_BUBBLE_RADIUS
43+
if self.radius < MIN_BUBBLE_RADIUS:
44+
self.radius = MIN_BUBBLE_RADIUS
45+
46+
47+
class ElasticBands:
48+
def __init__(self, initial_path, obstacles, rho0=RHO0, kc=0.05, kr=-0.1):
49+
self.distance_map = compute_sdf(obstacles)
50+
self.bubbles = [
51+
Bubble(p, self.compute_rho(p)) for p in initial_path
52+
] # Initialize bubble chain
53+
self.kc = kc # Contraction force gain
54+
self.kr = kr # Repulsive force gain
55+
self.rho0 = rho0 # Maximum distance for applying repulsive force
56+
57+
def compute_rho(self, position):
58+
"""Compute the distance field value at the position"""
59+
return self.distance_map[int(position[0]), int(position[1])]
60+
61+
def contraction_force(self, i):
62+
"""Calculate internal contraction force for the i-th bubble"""
63+
if i == 0 or i == len(self.bubbles) - 1:
64+
return np.zeros(2)
65+
66+
prev = self.bubbles[i - 1].pos
67+
next_ = self.bubbles[i + 1].pos
68+
current = self.bubbles[i].pos
69+
70+
# f_c = kc * ( (prev-current)/|prev-current| + (next-current)/|next-current| )
71+
dir_prev = (prev - current) / (np.linalg.norm(prev - current) + 1e-6)
72+
dir_next = (next_ - current) / (np.linalg.norm(next_ - current) + 1e-6)
73+
return self.kc * (dir_prev + dir_next)
74+
75+
def external_force(self, i):
76+
"""Calculate external repulsive force for the i-th bubble"""
77+
h = STEP_SIZE # Step size
78+
b = self.bubbles[i].pos
79+
rho = self.bubbles[i].radius
80+
81+
if rho >= self.rho0:
82+
return np.zeros(2)
83+
84+
# Finite difference approximation of the gradient ∂ρ/∂b
85+
dx = np.array([h, 0])
86+
dy = np.array([0, h])
87+
grad_x = (self.compute_rho(b - dx) - self.compute_rho(b + dx)) / (2 * h)
88+
grad_y = (self.compute_rho(b - dy) - self.compute_rho(b + dy)) / (2 * h)
89+
grad = np.array([grad_x, grad_y])
90+
91+
return self.kr * (self.rho0 - rho) * grad
92+
93+
def update_bubbles(self):
94+
"""Update bubble positions"""
95+
new_bubbles = []
96+
for i in range(len(self.bubbles)):
97+
if i == 0 or i == len(self.bubbles) - 1:
98+
new_bubbles.append(self.bubbles[i]) # Fixed start and end points
99+
continue
100+
101+
f_total = self.contraction_force(i) + self.external_force(i)
102+
alpha = self.bubbles[i].radius # Adaptive step size
103+
new_pos = self.bubbles[i].pos + alpha * f_total
104+
new_pos = np.clip(new_pos, 0, 499)
105+
new_radius = self.compute_rho(new_pos)
106+
107+
# Update bubble and maintain overlap constraint
108+
new_bubble = Bubble(new_pos, new_radius)
109+
new_bubbles.append(new_bubble)
110+
111+
self.bubbles = new_bubbles
112+
self._maintain_overlap()
113+
114+
def _maintain_overlap(self):
115+
"""Maintain bubble chain continuity (simplified insertion/deletion mechanism)"""
116+
# Insert bubbles
117+
i = 0
118+
while i < len(self.bubbles) - 1:
119+
bi, bj = self.bubbles[i], self.bubbles[i + 1]
120+
dist = np.linalg.norm(bi.pos - bj.pos)
121+
if dist > LAMBDA * (bi.radius + bj.radius):
122+
new_pos = (bi.pos + bj.pos) / 2
123+
rho = self.compute_rho(
124+
new_pos
125+
) # Calculate new radius using environment model
126+
self.bubbles.insert(i + 1, Bubble(new_pos, rho))
127+
i += 2 # Skip the processed region
128+
else:
129+
i += 1
130+
131+
# Delete redundant bubbles
132+
i = 1
133+
while i < len(self.bubbles) - 1:
134+
prev = self.bubbles[i - 1]
135+
next_ = self.bubbles[i + 1]
136+
dist = np.linalg.norm(prev.pos - next_.pos)
137+
if dist <= LAMBDA * (prev.radius + next_.radius):
138+
del self.bubbles[i] # Delete if redundant
139+
else:
140+
i += 1
141+
142+
143+
class ElasticBandsVisualizer:
144+
def __init__(self):
145+
self.obstacles = np.zeros((500, 500))
146+
self.start_point = None
147+
self.end_point = None
148+
self.elastic_band = None
149+
150+
if ENABLE_PLOT:
151+
self.fig, self.ax = plt.subplots(figsize=(8, 8))
152+
# Set the display range of the graph
153+
self.ax.set_xlim(0, 500)
154+
self.ax.set_ylim(0, 500)
155+
156+
if ENABLE_INTERACTIVE:
157+
self.path_points = [] # Add a list to store path points
158+
# Connect mouse events
159+
self.fig.canvas.mpl_connect("button_press_event", self.on_click)
160+
else:
161+
self.path_points = [
162+
[30, 136],
163+
[61, 214],
164+
[77, 256],
165+
[77, 309],
166+
[53, 366],
167+
[41, 422],
168+
[51, 453],
169+
[110, 471],
170+
[184, 437],
171+
[257, 388],
172+
[343, 353],
173+
[402, 331],
174+
[476, 273],
175+
[456, 206],
176+
[430, 160],
177+
[402, 107],
178+
]
179+
self.obstacles = np.load(pathlib.Path(__file__).parent / "obstacles.npy")
180+
self.plan_path()
181+
182+
self.plot_background()
183+
184+
def plot_background(self):
185+
"""Plot the background grid"""
186+
if not ENABLE_PLOT:
187+
return
188+
189+
self.ax.cla()
190+
self.ax.set_xlim(0, 500)
191+
self.ax.set_ylim(0, 500)
192+
self.ax.grid(True)
193+
if self.path_points:
194+
self.ax.plot(
195+
[p[0] for p in self.path_points],
196+
[p[1] for p in self.path_points],
197+
"yo",
198+
markersize=8,
199+
)
200+
201+
self.ax.imshow(self.obstacles.T, origin="lower", cmap="binary", alpha=0.3)
202+
if self.elastic_band is not None:
203+
path = [b.pos.tolist() for b in self.elastic_band.bubbles]
204+
path = np.array(path)
205+
self.ax.plot(path[:, 0], path[:, 1], "b-", linewidth=2, label="path")
206+
207+
for bubble in self.elastic_band.bubbles:
208+
circle = Circle(
209+
bubble.pos, bubble.radius, fill=False, color="g", alpha=0.3
210+
)
211+
self.ax.add_patch(circle)
212+
self.ax.plot(bubble.pos[0], bubble.pos[1], "bo", markersize=10)
213+
214+
self.ax.legend()
215+
plt.draw()
216+
plt.pause(0.01)
217+
218+
def on_click(self, event):
219+
"""Handle mouse click events"""
220+
if event.inaxes != self.ax:
221+
return
222+
223+
x, y = int(event.xdata), int(event.ydata)
224+
225+
if event.button == 1: # Left click to add obstacles
226+
size = 30 # Side length of the square
227+
half_size = size // 2
228+
229+
# Ensure not out of the map boundary
230+
x_start = max(0, x - half_size)
231+
x_end = min(self.obstacles.shape[0], x + half_size)
232+
y_start = max(0, y - half_size)
233+
y_end = min(self.obstacles.shape[1], y + half_size)
234+
235+
# Set the square area as obstacles (value set to 1)
236+
self.obstacles[x_start:x_end, y_start:y_end] = 1
237+
238+
elif event.button == 3: # Right click to add path points
239+
self.path_points.append([x, y])
240+
241+
elif event.button == 2: # Middle click to end path input and start planning
242+
if len(self.path_points) >= 2:
243+
self.plan_path()
244+
245+
self.plot_background()
246+
247+
def plan_path(self):
248+
"""Plan the path"""
249+
250+
initial_path = self.path_points
251+
# Create an elastic band object and optimize
252+
self.elastic_band = ElasticBands(initial_path, self.obstacles)
253+
for _ in range(MAX_ITER):
254+
self.elastic_band.update_bubbles()
255+
self.path_points = [b.pos for b in self.elastic_band.bubbles]
256+
self.plot_background()
257+
258+
259+
if __name__ == "__main__":
260+
ElasticBandsVisualizer()
261+
if ENABLE_PLOT:
262+
plt.show()
1.91 MB
Binary file not shown.

0 commit comments

Comments
 (0)