traj_multi = {
btn_opt;
let t = [];
let x_init = 3, y_init = -3;
const grad = (x,y) => [x + y, 20*y + x];
// 1. Vanilla GD
let x1=x_init, y1=y_init;
for(let i=0; i<=steps_opt; i++){
t.push({x:x1, y:y1, algo:"GD", iter:i});
let [gx, gy] = grad(x1, y1);
x1 -= lr_opt * gx; y1 -= lr_opt * gy;
}
// 2. Momentum
let x2=x_init, y2=y_init, vx2=0, vy2=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x2, y:y2, algo:"Momentum", iter:i});
let [gx, gy] = grad(x2, y2);
vx2 = 0.9*vx2 - lr_opt * gx;
vy2 = 0.9*vy2 - lr_opt * gy;
x2 += vx2; y2 += vy2;
}
// 3. RMSProp
let x3=x_init, y3=y_init, egx=0, egy=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x3, y:y3, algo:"RMSProp", iter:i});
let [gx, gy] = grad(x3, y3);
egx = 0.9*egx + 0.1*(gx*gx);
egy = 0.9*egy + 0.1*(gy*gy);
x3 -= (lr_opt / (Math.sqrt(egx) + 1e-8)) * gx;
y3 -= (lr_opt / (Math.sqrt(egy) + 1e-8)) * gy;
}
// 4. ADAM
let x4=x_init, y4=y_init, m1x=0, m1y=0, v2x=0, v2y=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x4, y:y4, algo:"ADAM", iter:i});
let [gx, gy] = grad(x4, y4);
m1x = 0.9*m1x + 0.1*gx; m1y = 0.9*m1y + 0.1*gy;
v2x = 0.999*v2x + 0.001*(gx*gx); v2y = 0.999*v2y + 0.001*(gy*gy);
let m1xh = m1x / (1 - Math.pow(0.9, i+1));
let m1yh = m1y / (1 - Math.pow(0.9, i+1));
let v2xh = v2x / (1 - Math.pow(0.999, i+1));
let v2yh = v2y / (1 - Math.pow(0.999, i+1));
x4 -= (lr_adam / (Math.sqrt(v2xh) + 1e-8)) * m1xh;
y4 -= (lr_adam / (Math.sqrt(v2yh) + 1e-8)) * m1yh;
}
return t;
}
grid_multi = {
let pts = [];
for(let i=-4; i<=4; i+=0.2) {
for(let j=-4; j<=4; j+=0.2) {
pts.push({x: i, y: j, z: (i*i)/2 + 10*(j*j) + i*j});
}
}
return pts;
}
plot_multi = {
const plot = Plot.plot({
width: 1000, height: 600,
x: {domain: [-4, 4]},
y: {domain: [-4, 4]},
color: {scheme: "viridis"},
marks: [
Plot.contour(grid_multi, {x: "x", y: "y", fill: "z", thresholds: 20, stroke: "white", strokeOpacity: 0.2}),
Plot.line(traj_multi.filter(d=>d.algo==="GD"), {x: "x", y: "y", stroke: "gray", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="Momentum"), {x: "x", y: "y", stroke: "red", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="RMSProp"), {x: "x", y: "y", stroke: "blue", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="ADAM"), {x: "x", y: "y", stroke: "orange", strokeWidth: 2}),
Plot.dot(traj_multi.filter(d=>d.algo==="GD"), {x: "x", y: "y", fill: "gray", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="Momentum"), {x: "x", y: "y", fill: "red", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="RMSProp"), {x: "x", y: "y", fill: "blue", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="ADAM"), {x: "x", y: "y", fill: "orange", r: 3})
]
});
const legend = Plot.legend({color: {domain: ["GD", "Momentum", "RMSProp", "ADAM"], range: ["gray", "red", "blue", "orange"]}});
return html`<div>${legend}${plot}</div>`;
}