Skip to content

Commit 6a5a73d

Browse files
Add example with multiple viapoints for ProMP
1 parent 5028976 commit 6a5a73d

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
===================================
3+
Plot ProMP with Multiple Via Points
4+
===================================
5+
6+
This example shows how to use ProMP with multiple via points.
7+
"""
8+
print(__doc__)
9+
10+
import numpy as np
11+
from movement_primitives.promp import ProMP, via_points
12+
from movement_primitives.data import generate_1d_trajectory_distribution
13+
import matplotlib.pyplot as plt
14+
15+
16+
n_demos = 100
17+
n_steps = 101
18+
T, Y = generate_1d_trajectory_distribution(n_demos, n_steps)
19+
promp = ProMP(n_dims=1, n_weights_per_dim=50)
20+
promp.imitate([T] * n_demos, Y)
21+
Y_mean = promp.mean_trajectory(T)
22+
Y_conf = 1.96 * np.sqrt(promp.var_trajectory(T))
23+
24+
y_cond = np.array([0.5, -0.5, 0.0, 1.0])
25+
y_conditional_cov = np.zeros(4)
26+
ts = np.array([0.2, 0.5, 0.7, 1.0])
27+
cpromp = via_points(
28+
promp=promp,
29+
y_cond=y_cond,
30+
y_conditional_cov=y_conditional_cov,
31+
ts=ts,
32+
)
33+
Y_cmean = cpromp.mean_trajectory(T)
34+
Y_cconf = 1.96 * np.sqrt(cpromp.var_trajectory(T))
35+
36+
plt.figure(figsize=(10, 5))
37+
38+
ax1 = plt.subplot(121)
39+
ax1.set_title("Training set and ProMP")
40+
ax1.fill_between(T, (Y_mean - Y_conf).ravel(), (Y_mean + Y_conf).ravel(), color="r", alpha=0.3)
41+
ax1.plot(T, Y_mean, c="r", lw=2, label="ProMP")
42+
ax1.plot(T, Y[:, :, 0].T, c="k", alpha=0.1)
43+
ax1.set_xlim((-0.05, 1.05))
44+
ax1.set_ylim((-2.5, 3))
45+
ax1.legend(loc="best")
46+
47+
ax2 = plt.subplot(122)
48+
ax2.set_title("Conditioned ProMP")
49+
ax2.scatter(ts, y_cond, marker="*", s=100, c="b", label="Viapoints")
50+
ax2.fill_between(T, (Y_cmean - Y_cconf).ravel(), (Y_cmean + Y_cconf).ravel(), color="b", alpha=0.3)
51+
ax2.plot(T, Y_cmean, c="b", lw=2, label="Conditioned ProMP")
52+
ax2.set_xlim((-0.05, 1.05))
53+
ax2.set_ylim((-2.5, 3))
54+
ax2.legend(loc="best")
55+
56+
ax1.set_xlabel("Time $t$ [s]")
57+
ax1.set_ylabel("Position $y$ [m]")
58+
ax2.set_xlabel("Time $t$ [s]")
59+
plt.tight_layout()
60+
plt.show()

0 commit comments

Comments
 (0)