chart3 = {
// Canvas size and panel margins (used as the SVG viewBox extent).
const width = 800;
const height = 500;
const margin = {top: 20, right: 30, bottom: 40, left: 50};
// Target function the polynomial models try to recover: sin(1.5·π·x),
// i.e. three quarters of a sine period over x in [0, 1].
const ground_truth = (x) => Math.sin(1.5 * Math.PI * x);
// Builds {x, y} samples of ground_truth plus deterministic pseudo-noise of
// amplitude 0.25 (a high-frequency sine instead of Math.random, so every render
// produces the same points). Train and test use different frequencies so the
// two sets receive different noise.
const makeNoisySamples = (xs, noiseFreq) =>
  xs.map((x) => ({
    x,
    y: ground_truth(x) + (Math.sin(x * noiseFreq) * 0.25),
  }));
const train_data = makeNoisySamples(
  [0.02, 0.05, 0.1, 0.15, 0.22, 0.28, 0.35, 0.40, 0.45, 0.52, 0.58, 0.65, 0.70, 0.78, 0.85, 0.90, 0.95, 0.98],
  1234,
);
const test_data = makeNoisySamples(
  [0.01, 0.08, 0.12, 0.2, 0.25, 0.3, 0.38, 0.42, 0.48, 0.55, 0.6, 0.68, 0.75, 0.8, 0.88, 0.92, 0.96, 0.99],
  4321,
);
// --- Canvas and left panel: scales, axes and the dashed true function ---
const svg = d3.create("svg")
  .attr("viewBox", [0, 0, width, height])
  .style("background", "none");
// Horizontal split: the left panel spans the first 55% of the width; the right
// panel's scales and axes are offset from leftWidth.
const leftWidth = width * 0.55;
const xLeft = d3.scaleLinear().domain([0, 1]).range([margin.left, leftWidth - margin.right / 2]);
const yLeft = d3.scaleLinear().domain([-1.8, 1.8]).range([height - margin.bottom, margin.top]);
// The x-axis is drawn through y = 0 rather than along the bottom edge.
svg.append("g")
  .attr("transform", `translate(0,${yLeft(0)})`)
  .call(d3.axisBottom(xLeft).ticks(5))
  .attr("color", "#7f8c8d");
svg.append("g")
  .attr("transform", `translate(${margin.left},0)`)
  .call(d3.axisLeft(yLeft).ticks(5))
  .attr("color", "#7f8c8d");
// Sample x = 0, 0.01, …, 1 from an integer index. Repeatedly adding 0.01
// accumulates rounding error that overshoots 1 and drops the x = 1 endpoint.
const truth_pts = Array.from({length: 101}, (_, i) => {
  const x = i / 100;
  return [x, ground_truth(x)];
});
svg.append("path")
  .datum(truth_pts)
  .attr("fill", "none")
  .attr("stroke", "#7f8c8d")
  .attr("stroke-width", 2)
  .attr("stroke-dasharray", "5,5")
  .attr("d", d3.line().x(d => xLeft(d[0])).y(d => yLeft(d[1])));
/**
 * Least-squares polynomial fit by solving the normal equations (XᵀX)·w = Xᵀy
 * with Gaussian elimination (partial pivoting) and back substitution.
 *
 * X is the monomial design matrix [1, x, x², …, x^degree], so high degrees can
 * be badly conditioned. Coefficients whose final pivot magnitude is below 1e-12
 * are left at 0 instead of being divided by a near-zero value.
 *
 * @param {{x: number, y: number}[]} data - sample points.
 * @param {number} degree - polynomial degree; degree + 1 coefficients are fitted.
 * @returns {number[]} coefficients w, lowest power first (w[j] multiplies x^j).
 */
const fit_poly = (data, degree) => {
  const n = degree + 1;
  // Design matrix: one row [x^0, x^1, …, x^degree] per sample.
  const X = data.map(d => Array.from({length: n}, (_, j) => Math.pow(d.x, j)));
  const Y = data.map(d => d.y);

  // Normal-equation system: XTX = XᵀX (n × n), XTY = Xᵀy (length n).
  const XTX = Array.from({length: n}, (_, i) =>
    Array.from({length: n}, (_, j) => {
      let sum = 0;
      for (let k = 0; k < data.length; k++) sum += X[k][i] * X[k][j];
      return sum;
    }));
  const XTY = Array.from({length: n}, (_, i) => {
    let sum = 0;
    for (let k = 0; k < data.length; k++) sum += X[k][i] * Y[k];
    return sum;
  });

  // Forward elimination to upper-triangular form.
  for (let i = 0; i < n; i++) {
    // Partial pivoting: bring the row with the largest |entry| in column i up.
    let maxRow = i;
    for (let k = i + 1; k < n; k++) {
      if (Math.abs(XTX[k][i]) > Math.abs(XTX[maxRow][i])) maxRow = k;
    }
    [XTX[i], XTX[maxRow]] = [XTX[maxRow], XTX[i]];
    [XTY[i], XTY[maxRow]] = [XTY[maxRow], XTY[i]];

    // Column i is already zero from row i down (e.g. empty data or identical
    // x values). There is nothing to eliminate, and dividing would produce
    // -0/0 = NaN that poisons every coefficient.
    if (XTX[i][i] === 0) continue;

    for (let k = i + 1; k < n; k++) {
      const c = -XTX[k][i] / XTX[i][i];
      XTX[k][i] = 0;
      for (let j = i + 1; j < n; j++) XTX[k][j] += c * XTX[i][j];
      XTY[k] += c * XTY[i];
    }
  }

  // Back substitution; near-singular pivots leave their coefficient at 0.
  const w = Array.from({length: n}, () => 0);
  for (let i = n - 1; i >= 0; i--) {
    if (Math.abs(XTX[i][i]) < 1e-12) continue;
    w[i] = XTY[i] / XTX[i][i];
    for (let k = i - 1; k >= 0; k--) XTY[k] -= XTX[k][i] * w[i];
  }
  return w;
};
/**
 * Evaluates the polynomial Σ weights[j]·x^j at x.
 * @param {number} x - evaluation point.
 * @param {number[]} weights - coefficients, lowest power first.
 * @returns {number} the polynomial value (0 for an empty weight list).
 */
const predict = (x, weights) => {
  let total = 0;
  for (const [power, coeff] of weights.entries()) {
    total += coeff * Math.pow(x, power);
  }
  return total;
};
/**
 * Mean squared error of the polynomial `weights` over `data`.
 * @param {{x: number, y: number}[]} data - points to score.
 * @param {number[]} weights - coefficients, lowest power first (as passed to predict).
 * @returns {number} mean of (y - prediction)²; NaN when data is empty (0 / 0).
 */
const calc_mse = (data, weights) => {
  let squaredErrorTotal = 0;
  for (const {x, y} of data) {
    const residual = y - predict(x, weights);
    squaredErrorTotal += residual * residual;
  }
  return squaredErrorTotal / data.length;
};
// --- Left panel: polynomial fit for the selected degree plus the training points ---
// poly_degree is defined outside this cell (presumably an input control — TODO confirm).
const weights = fit_poly(train_data, poly_degree);
// Evaluate on x = 0, 0.01, …, 1 from an integer index (accumulating x += 0.01
// overshoots 1 and drops the endpoint). Values are clamped to yLeft's domain so
// wild high-degree oscillations stay inside the plotted panel instead of
// spilling into the margins and past the top of the viewBox.
const [fitYMin, fitYMax] = yLeft.domain();
const fit_pts = Array.from({length: 101}, (_, i) => {
  const x = i / 100;
  return [x, Math.max(fitYMin, Math.min(fitYMax, predict(x, weights)))];
});
svg.append("path")
  .datum(fit_pts)
  .attr("fill", "none")
  .attr("stroke", "#e74c3c")
  .attr("stroke-width", 3)
  .attr("d", d3.line().x(d => xLeft(d[0])).y(d => yLeft(d[1])));
svg.append("g")
  .selectAll("circle")
  .data(train_data)
  .join("circle")
  .attr("cx", d => xLeft(d.x))
  .attr("cy", d => yLeft(d.y))
  .attr("r", 5)
  .attr("fill", "#3498db")
  .attr("stroke", "#2980b9");
// --- Right panel: train/test MSE versus polynomial degree (log-scaled y) ---
// The x-domain must cover every degree plotted below (1..poly_degree); [1, 10]
// is kept as the minimum so small degrees render exactly as before.
const xRight = d3.scaleLinear()
  .domain([1, Math.max(10, poly_degree)])
  .range([leftWidth + margin.left, width - margin.right]);
const yRight = d3.scaleLog().domain([0.01, 100]).range([height - margin.bottom, margin.top]);
svg.append("g")
  .attr("transform", `translate(0,${height - margin.bottom})`)
  .call(d3.axisBottom(xRight).ticks(5))
  .attr("color", "#7f8c8d");
svg.append("g")
  .attr("transform", `translate(${leftWidth + margin.left},0)`)
  .call(d3.axisLeft(yRight).ticks(4).tickFormat(d => d))
  .attr("color", "#7f8c8d");
// Clamp each MSE into yRight's domain: a log scale cannot place 0, and values
// above 100 would be drawn outside the panel.
const clampMse = (mse) => Math.max(0.01, Math.min(100, mse));
const train_errs = [];
const test_errs = [];
for (let d = 1; d <= poly_degree; d++) {
  const w = fit_poly(train_data, d);
  train_errs.push([d, clampMse(calc_mse(train_data, w))]);
  test_errs.push([d, clampMse(calc_mse(test_data, w))]);
}
// Draws one [degree, mse] series as a polyline followed by a dot per degree.
const drawErrorSeries = (points, color, cls) => {
  svg.append("path")
    .datum(points)
    .attr("fill", "none")
    .attr("stroke", color)
    .attr("stroke-width", 2)
    .attr("d", d3.line().x(p => xRight(p[0])).y(p => yRight(p[1])));
  svg.append("g")
    .selectAll(`circle.${cls}`)
    .data(points)
    .join("circle")
    .attr("class", cls)
    .attr("cx", p => xRight(p[0]))
    .attr("cy", p => yRight(p[1]))
    .attr("r", 4)
    .attr("fill", color);
};
drawErrorSeries(train_errs, "#3498db", "train");
drawErrorSeries(test_errs, "#e67e22", "test");
svg.append("text").attr("x", margin.left + 10).attr("y", margin.top).text("Training Data").attr("fill", "#3498db").attr("font-size", 14);
svg.append("text").attr("x", margin.left + 10).attr("y", margin.top + 20).text("True Function").attr("fill", "#7f8c8d").attr("font-size", 14);
svg.append("text").attr("x", margin.left + 10).attr("y", margin.top + 40).text("Model Fit").attr("fill", "#e74c3c").attr("font-size", 14);
svg.append("text").attr("x", leftWidth + margin.left + 10).attr("y", margin.top).text("Train Error").attr("fill", "#3498db").attr("font-size", 14);
svg.append("text").attr("x", leftWidth + margin.left + 10).attr("y", margin.top + 20).text("Test Error").attr("fill", "#e67e22").attr("font-size", 14);
return svg.node();
}