Mathematical Foundations of AI & ML
Unit 6: Loss Landscapes and Optimization Behavior
FAU Erlangen-Nürnberg
By the end of this lecture, students can:
\[ H_{ij} = \frac{\partial^2 J}{\partial \theta_i \partial \theta_j} \]
//| echo: false
viewof lr_saddle = Inputs.range([0.01, 0.5], {value: 0.1, step: 0.01, label: "Learning Rate (η)"});
viewof noise_saddle = Inputs.range([0, 0.5], {value: 0.0, step: 0.01, label: "Gradient Noise StdDev"});
viewof steps_saddle = Inputs.range([1, 100], {value: 20, step: 1, label: "Steps"});
viewof btn_saddle = Inputs.button("Reset / Re-run");//| echo: false
// Saddle function f(x,y) = x^2 - y^2
gd_traj_saddle = {
btn_saddle;
let x = 0; // Exactly zero gradient in x direction initially!
let y = 0.1;
let traj = [{x: x, y: y, iter: 0}];
const randn = () => Math.sqrt(-2.0 * Math.log(Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
for(let i=1; i<=steps_saddle; i++) {
let gx = 2 * x + noise_saddle * randn();
let gy = -2 * y + noise_saddle * randn();
x = x - lr_saddle * gx;
y = y - lr_saddle * gy;
traj.push({x: x, y: y, iter: i});
}
return traj;
}
grid_saddle = {
let pts = [];
for(let i=-2; i<=2; i+=0.1) {
for(let j=-2; j<=2; j+=0.1) {
pts.push({x: i, y: j, z: (i*i - j*j)});
}
}
return pts;
}
// Filled contours of the saddle surface with the gradient-descent path drawn on top.
plot_saddle = Plot.plot({
  width: 800,
  height: 400,
  x: {domain: [-2, 2], label: "Parameter x (Positive Curvature)"},
  y: {domain: [-2, 2], label: "Parameter y (Negative Curvature)"},
  color: {scheme: "RdBu", legend: false},
  marks: [
    Plot.contour(grid_saddle, {x: "x", y: "y", fill: "z", thresholds: 20, stroke: "black", strokeOpacity: 0.2}),
    // clip: y grows by a factor (1 + 2η) per step, so the path soon leaves the fixed
    // [-2, 2] window; without clipping it is drawn across the axes and margins.
    Plot.line(gd_traj_saddle, {x: "x", y: "y", stroke: "lime", strokeWidth: 2, clip: true}),
    Plot.dot(gd_traj_saddle, {x: "x", y: "y", fill: "black", r: 3, clip: true})
  ]
})
//| echo: false
viewof a_cond = Inputs.range([1, 20], {value: 10, step: 1, label: "Condition Number (a)"});
viewof lr_cond = Inputs.range([0.01, 0.5], {value: 0.1, step: 0.01, label: "Learning Rate (η)"});
viewof steps_cond = Inputs.range([1, 50], {value: 20, step: 1, label: "Steps"});
viewof btn_cond = Inputs.button("Reset / Re-run");//| echo: false
gd_traj_cond = {
btn_cond;
let x = -8;
let y = -4;
let traj = [{x: x, y: y, iter: 0}];
for(let i=1; i<=steps_cond; i++) {
x = x - lr_cond * x;
y = y - lr_cond * a_cond * y;
traj.push({x: x, y: y, iter: i});
}
return traj;
}
grid_cond = {
let pts = [];
for(let i=-10; i<=10; i+=0.5) {
for(let j=-5; j<=5; j+=0.25) {
pts.push({x: i, y: j, z: (i*i)/2 + a_cond * (j*j)/2});
}
}
return pts;
}
// Contours of the quadratic bowl with the gradient-descent path on top; each dot carries a "Step n" tooltip.
plot_cond = Plot.plot({
  width: 800,
  height: 400,
  x: {domain: [-10, 10], label: "Parameter x_1 (Flat)"},
  y: {domain: [-5, 5], label: "Parameter x_2 (Steep)"},
  color: {scheme: "viridis", legend: false},
  marks: [
    Plot.contour(grid_cond, {x: "x", y: "y", fill: "z", thresholds: 15, stroke: "white", strokeOpacity: 0.3}),
    // clip: once lr_cond * a_cond exceeds 2 the y-iterates grow without bound; clipping
    // keeps a diverging path inside the frame instead of drawing over axes and margins.
    Plot.line(gd_traj_cond, {x: "x", y: "y", stroke: "red", strokeWidth: 2, clip: true}),
    Plot.dot(gd_traj_cond, {x: "x", y: "y", fill: "white", r: 3, title: d => "Step " + d.iter, clip: true})
  ]
})
//| echo: false
viewof lr_mom = Inputs.range([0.01, 0.2], {value: 0.05, step: 0.01, label: "Learning Rate (η)"});
viewof alpha_mom = Inputs.range([0.0, 0.99], {value: 0.9, step: 0.01, label: "Momentum (α)"});
viewof steps_mom = Inputs.range([1, 100], {value: 40, step: 1, label: "Steps"});
viewof btn_mom = Inputs.button("Reset / Re-run");//| echo: false
grid_mom_fixed = {
let pts = [];
for(let i=-10; i<=10; i+=0.5) {
for(let j=-5; j<=5; j+=0.25) {
pts.push({x: i, y: j, z: (i*i)/2 + 15 * (j*j)/2});
}
}
return pts;
}
traj_comparison = {
btn_mom;
const a = 15; // fixed ill-conditioning
let x1 = -8, y1 = -4; // GD
let x2 = -8, y2 = -4; // Momentum
let vx2 = 0, vy2 = 0;
let traj = [];
traj.push({x: x1, y: y1, algo: "Vanilla GD", iter: 0});
traj.push({x: x2, y: y2, algo: "Momentum", iter: 0});
for(let i=1; i<=steps_mom; i++) {
// GD step
x1 = x1 - lr_mom * x1;
y1 = y1 - lr_mom * a * y1;
traj.push({x: x1, y: y1, algo: "Vanilla GD", iter: i});
// Momentum step
let gx2 = x2;
let gy2 = a * y2;
vx2 = alpha_mom * vx2 - lr_mom * gx2;
vy2 = alpha_mom * vy2 - lr_mom * gy2;
x2 = x2 + vx2;
y2 = y2 + vy2;
traj.push({x: x2, y: y2, algo: "Momentum", iter: i});
}
return traj;
}
plot_mom = {
const plot = Plot.plot({
width: 800, height: 400,
x: {domain: [-10, 5], label: "Parameter x_1"},
y: {domain: [-5, 5], label: "Parameter x_2"},
color: {scheme: "viridis"},
marks: [
Plot.contour(grid_mom_fixed, {x: "x", y: "y", fill: "z", thresholds: 15, stroke: "white", strokeOpacity: 0.2}),
Plot.line(traj_comparison.filter(d => d.algo === "Vanilla GD"), {x: "x", y: "y", stroke: "gray", strokeWidth: 2}),
Plot.line(traj_comparison.filter(d => d.algo === "Momentum"), {x: "x", y: "y", stroke: "red", strokeWidth: 2}),
Plot.dot(traj_comparison.filter(d => d.algo === "Vanilla GD"), {x: "x", y: "y", fill: "gray", r: 3}),
Plot.dot(traj_comparison.filter(d => d.algo === "Momentum"), {x: "x", y: "y", fill: "red", r: 3})
]
});
const legend = Plot.legend({color: {domain: ["Vanilla GD", "Momentum"], range: ["gray", "red"]}});
return html`<div>${legend}${plot}</div>`;
}//| echo: false
gd_trajectory = {
reset_btn_1d; // react to button
let x = 0; // starting point
let traj = [{x: x, iter: 0}];
for(let i=1; i<=epochs_1d; i++) {
let grad = 4 * Math.pow(x, 3) - 4 * x + 0.5;
x = x - lr_1d * grad;
if(x > 2 || x < -2 || isNaN(x)) { x = (x>0? 2: -2); traj.push({x: x, iter: i}); break; }
traj.push({x: x, iter: i});
}
return traj;
}
// Sample positions for the 1-D loss curve: -1.5 to 1.5 inclusive, spacing 0.05.
// d3.range excludes `stop`, so it is placed half a step past 1.5; a stop of 1.6
// would also emit 1.55, which lies outside the plotted x-domain [-1.5, 1.5].
x_range = d3.range(-1.5, 1.5 + 0.05 / 2, 0.05)
plot_func_1d = Plot.plot({
width: 800, height: 400,
x: {domain: [-1.5, 1.5], label: "Parameter x"},
y: {domain: [-1.5, 3], label: "Loss f(x)"},
marks: [
Plot.line(x_range, {x: d => d, y: d => Math.pow(d, 4) - 2 * Math.pow(d, 2) + 0.5*d, stroke: "steelblue", strokeWidth: 3}),
Plot.line(gd_trajectory, {x: "x", y: d => Math.pow(d.x, 4) - 2 * Math.pow(d.x, 2) + 0.5*d.x, stroke: "orange", strokeWidth: 2}),
Plot.dot(gd_trajectory, {x: "x", y: d => Math.pow(d.x, 4) - 2 * Math.pow(d.x, 2) + 0.5*d.x, fill: "red", r: d => d.iter == 0 ? 6 : 4, title: d => "Step " + d.iter})
]
})\[ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \frac{\eta}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon}\,\hat{\mathbf{m}}_t \]
//| echo: false
viewof lr_opt = Inputs.range([0.01, 1.0], {value: 0.1, step: 0.01, label: "GD/Mom/RMS LR"});
viewof lr_adam = Inputs.range([0.01, 1.0], {value: 0.5, step: 0.01, label: "ADAM LR"});
viewof steps_opt = Inputs.range([1, 200], {value: 50, step: 10, label: "Steps"});
viewof btn_opt = Inputs.button("Reset / Re-run");//| echo: false
// Skewed bowl: f(x,y) = x^2/2 + 10y^2 + xy
traj_multi = {
btn_opt;
let t = [];
let x_init = 3, y_init = -3;
const grad = (x,y) => [x + y, 20*y + x];
// 1. Vanilla GD
let x1=x_init, y1=y_init;
for(let i=0; i<=steps_opt; i++){
t.push({x:x1, y:y1, algo:"GD", iter:i});
let [gx, gy] = grad(x1, y1);
x1 -= lr_opt * gx; y1 -= lr_opt * gy;
}
// 2. Momentum
let x2=x_init, y2=y_init, vx2=0, vy2=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x2, y:y2, algo:"Momentum", iter:i});
let [gx, gy] = grad(x2, y2);
vx2 = 0.9*vx2 - lr_opt * gx;
vy2 = 0.9*vy2 - lr_opt * gy;
x2 += vx2; y2 += vy2;
}
// 3. RMSProp
let x3=x_init, y3=y_init, egx=0, egy=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x3, y:y3, algo:"RMSProp", iter:i});
let [gx, gy] = grad(x3, y3);
egx = 0.9*egx + 0.1*(gx*gx);
egy = 0.9*egy + 0.1*(gy*gy);
x3 -= (lr_opt / (Math.sqrt(egx) + 1e-8)) * gx;
y3 -= (lr_opt / (Math.sqrt(egy) + 1e-8)) * gy;
}
// 4. ADAM
let x4=x_init, y4=y_init, m1x=0, m1y=0, v2x=0, v2y=0;
for(let i=0; i<=steps_opt; i++){
t.push({x:x4, y:y4, algo:"ADAM", iter:i});
let [gx, gy] = grad(x4, y4);
m1x = 0.9*m1x + 0.1*gx; m1y = 0.9*m1y + 0.1*gy;
v2x = 0.999*v2x + 0.001*(gx*gx); v2y = 0.999*v2y + 0.001*(gy*gy);
let m1xh = m1x / (1 - Math.pow(0.9, i+1));
let m1yh = m1y / (1 - Math.pow(0.9, i+1));
let v2xh = v2x / (1 - Math.pow(0.999, i+1));
let v2yh = v2y / (1 - Math.pow(0.999, i+1));
x4 -= (lr_adam / (Math.sqrt(v2xh) + 1e-8)) * m1xh;
y4 -= (lr_adam / (Math.sqrt(v2yh) + 1e-8)) * m1yh;
}
return t;
}
grid_multi = {
let pts = [];
for(let i=-4; i<=4; i+=0.2) {
for(let j=-4; j<=4; j+=0.2) {
pts.push({x: i, y: j, z: (i*i)/2 + 10*(j*j) + i*j});
}
}
return pts;
}
plot_multi = {
const plot = Plot.plot({
width: 800, height: 400,
x: {domain: [-4, 4]},
y: {domain: [-4, 4]},
color: {scheme: "viridis"},
marks: [
Plot.contour(grid_multi, {x: "x", y: "y", fill: "z", thresholds: 20, stroke: "white", strokeOpacity: 0.2}),
Plot.line(traj_multi.filter(d=>d.algo==="GD"), {x: "x", y: "y", stroke: "gray", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="Momentum"), {x: "x", y: "y", stroke: "red", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="RMSProp"), {x: "x", y: "y", stroke: "blue", strokeWidth: 2}),
Plot.line(traj_multi.filter(d=>d.algo==="ADAM"), {x: "x", y: "y", stroke: "orange", strokeWidth: 2}),
Plot.dot(traj_multi.filter(d=>d.algo==="GD"), {x: "x", y: "y", fill: "gray", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="Momentum"), {x: "x", y: "y", fill: "red", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="RMSProp"), {x: "x", y: "y", fill: "blue", r: 3}),
Plot.dot(traj_multi.filter(d=>d.algo==="ADAM"), {x: "x", y: "y", fill: "orange", r: 3})
]
});
const legend = Plot.legend({color: {domain: ["GD", "Momentum", "RMSProp", "ADAM"], range: ["gray", "red", "blue", "orange"]}});
return html`<div>${legend}${plot}</div>`;
}
[INSERT: Selection flowchart or table tailored to materials science models (GNNs vs. simple MLPs).]
Generalization Error Analysis
//| echo: false
// Flat vs Sharp
// Sample positions for both loss curves: -1 to 1 inclusive, spacing 0.01.
// d3.range excludes `stop`, so it is placed half a step past 1; with a stop of
// exactly 1 the curves would end at 0.99, short of the right edge of the x-domain.
x_grid_fs = d3.range(-1, 1 + 0.01 / 2, 0.01);
// Compares a sharp (curvature factor 50) and a flat (factor 2) quadratic minimum.
// Both curves are shifted by `dshift`; the large dots show the loss each one
// assigns to the unchanged parameter value 0.
plot_fs = Plot.plot({
  width: 800,
  height: 400,
  x: {domain: [-1, 1], label: "Parameter Deviation"},
  y: {domain: [0, 5], label: "Loss"},
  color: {domain: ["Sharp Minimum", "Flat Minimum"], range: ["red", "blue"], legend: true},
  marks: [
    // True test losses (shifted)
    // Colours are given as category names (channel functions) so they go through the
    // colour scale declared above; literal "red"/"blue" bypass the scale and leave
    // the requested legend with nothing to describe.
    // clip: the sharp curve rises far above the y-domain away from its minimum.
    Plot.line(x_grid_fs, {x: d => d, y: d => Math.pow(d - dshift, 2) * 50, stroke: () => "Sharp Minimum", strokeWidth: 2, strokeDasharray: dshift === 0 ? "none" : "5,5", clip: true}),
    Plot.line(x_grid_fs, {x: d => d, y: d => Math.pow(d - dshift, 2) * 2, stroke: () => "Flat Minimum", strokeWidth: 2, strokeDasharray: dshift === 0 ? "none" : "5,5", clip: true}),
    // Evaluation points (trained model doesn't change parameter, it evaluates at 0)
    Plot.dot([{x: 0}], {x: "x", y: d => Math.pow(0 - dshift, 2) * 50, fill: () => "Sharp Minimum", r: 8}),
    Plot.dot([{x: 0}], {x: "x", y: d => Math.pow(0 - dshift, 2) * 2, fill: () => "Flat Minimum", r: 8})
  ]
})