Skip to content

Commit 85f16f4

Browse files
authored
Estimate bias parameter in LTS (#766)
1 parent 93760ee commit 85f16f4

File tree

3 files changed

+73
-8
lines changed

3 files changed

+73
-8
lines changed

lib/model/lts.js

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@ export default class LeastTrimmedSquaresRegression {
1414
}
1515

1616
_ls(x, y) {
17+
const m = x.mean(0)
18+
x = Matrix.sub(x, m)
1719
const xtx = x.tDot(x)
18-
return xtx.solve(x.tDot(y))
20+
const w = xtx.solve(x.tDot(y))
21+
y = Matrix.sub(y, x.dot(w))
22+
const b = y.mean(0)
23+
return [w, m, b]
1924
}
2025

2126
/**
@@ -27,10 +32,10 @@ export default class LeastTrimmedSquaresRegression {
2732
fit(x, y) {
2833
x = Matrix.fromArray(x)
2934
y = Matrix.fromArray(y)
30-
const xh = Matrix.resize(x, x.rows, x.cols + 1, 1)
3135

32-
const wls = this._ls(xh, y)
33-
const yls = xh.dot(wls)
36+
const [wls, mls, bls] = this._ls(x, y)
37+
const yls = Matrix.sub(x, mls).dot(wls)
38+
yls.add(bls)
3439
yls.sub(y)
3540
yls.mult(yls)
3641

@@ -39,10 +44,10 @@ export default class LeastTrimmedSquaresRegression {
3944

4045
const h = Math.max(1, Math.floor(r.length * this._h))
4146

42-
const xlts = xh.row(r.slice(0, h).map(v => v[1]))
47+
const xlts = x.row(r.slice(0, h).map(v => v[1]))
4348
const ylts = y.row(r.slice(0, h).map(v => v[1]))
4449

45-
this._w = this._ls(xlts, ylts)
50+
;[this._w, this._m, this._b] = this._ls(xlts, ylts)
4651
}
4752

4853
/**
@@ -53,7 +58,9 @@ export default class LeastTrimmedSquaresRegression {
5358
*/
5459
predict(x) {
5560
x = Matrix.fromArray(x)
56-
const xh = Matrix.resize(x, x.rows, x.cols + 1, 1)
57-
return xh.dot(this._w).toArray()
61+
x.sub(this._m)
62+
const y = x.dot(this._w)
63+
y.add(this._b)
64+
return y.toArray()
5865
}
5966
}

tests/lib/model/least_square.test.js

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,28 @@ test('fit', () => {
1515
const err = rmse(y, t)[0]
1616
expect(err).toBeLessThan(0.5)
1717
})
18+
19+
test('same bias column', () => {
20+
const model = new LeastSquare()
21+
const x = Matrix.randn(50, 2, 0, 5).toArray()
22+
const t = []
23+
for (let i = 0; i < x.length; i++) {
24+
t[i] = [x[i][0] + x[i][1] + (Math.random() - 0.5) / 10 + 5]
25+
}
26+
model.fit(x, t)
27+
const y = model.predict(x)
28+
29+
const x2 = Matrix.resize(Matrix.fromArray(x), x.length, x[0].length + 1, 1)
30+
const t2 = Matrix.fromArray(t)
31+
const w = x2.tDot(x2).solve(x2.tDot(t2))
32+
const y2 = x2.dot(w)
33+
34+
for (let i = 0; i < y.length; i++) {
35+
for (let j = 0; j < y[i].length; j++) {
36+
expect(y[i][j]).toBeCloseTo(y2.at(i, j))
37+
}
38+
}
39+
for (let i = 0; i < 2; i++) {
40+
expect(model._w.at(i, 0)).toBeCloseTo(w.at(i, 0))
41+
}
42+
})

tests/lib/model/lts.test.js

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,36 @@ test('fit', () => {
1515
const err = rmse(y, t)[0]
1616
expect(err).toBeLessThan(0.5)
1717
})
18+
19+
test('same bias column', () => {
20+
const model = new LeastTrimmedSquaresRegression()
21+
const x = Matrix.randn(50, 2, 0, 5).toArray()
22+
const t = []
23+
for (let i = 0; i < x.length; i++) {
24+
t[i] = [x[i][0] + x[i][1] + (Math.random() - 0.5) / 10]
25+
}
26+
model.fit(x, t)
27+
const y = model.predict(x)
28+
29+
const x2 = Matrix.resize(Matrix.fromArray(x), x.length, x[0].length + 1, 1)
30+
const t2 = Matrix.fromArray(t)
31+
const d = Matrix.sub(t2, x2.dot(x2.tDot(x2).solve(x2.tDot(t2))))
32+
d.map(v => v ** 2)
33+
const r = d.sum(1).value.map((v, i) => [v, i])
34+
r.sort((a, b) => a[0] - b[0])
35+
36+
const h = 45
37+
const xlts = x2.row(r.slice(0, h).map(v => v[1]))
38+
const ylts = t2.row(r.slice(0, h).map(v => v[1]))
39+
const w = xlts.tDot(xlts).solve(xlts.tDot(ylts))
40+
const y2 = x2.dot(w)
41+
42+
for (let i = 0; i < y.length; i++) {
43+
for (let j = 0; j < y[i].length; j++) {
44+
expect(y[i][j]).toBeCloseTo(y2.at(i, j))
45+
}
46+
}
47+
for (let i = 0; i < 2; i++) {
48+
expect(model._w.at(i, 0)).toBeCloseTo(w.at(i, 0))
49+
}
50+
})

0 commit comments

Comments
 (0)