Skip to content

Commit 0177ca3

Browse files
authored
Integrate workers (#978)
1 parent 0132991 commit 0177ca3

15 files changed

+214
-298
lines changed

js/view/autoencoder.js

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,35 @@ import { BaseWorker } from '../utils.js'
55

66
class AutoencoderWorker extends BaseWorker {
77
constructor() {
8-
super('js/view/worker/autoencoder_worker.js', { type: 'module' })
8+
super('js/view/worker/model_worker.js', { type: 'module' })
99
}
1010

1111
initialize(input_size, reduce_size, enc_layers, dec_layers, optimizer) {
12-
return this._postMessage({ mode: 'init', input_size, reduce_size, enc_layers, dec_layers, optimizer })
12+
return this._postMessage({
13+
name: 'autoencoder',
14+
method: 'constructor',
15+
arguments: [input_size, reduce_size, enc_layers, dec_layers, optimizer],
16+
})
17+
}
18+
19+
epoch() {
20+
return this._postMessage({ name: 'autoencoder', method: 'epoch' }).then(r => r.data)
1321
}
1422

1523
fit(train_x, iteration, rate, batch, rho) {
16-
return this._postMessage({ mode: 'fit', x: train_x, iteration, rate, batch, rho }).then(r => r.data)
24+
return this._postMessage({
25+
name: 'autoencoder',
26+
method: 'fit',
27+
arguments: [train_x, iteration, rate, batch, rho],
28+
}).then(r => r.data)
1729
}
1830

1931
predict(x) {
20-
return this._postMessage({ mode: 'predict', x: x })
32+
return this._postMessage({ name: 'autoencoder', method: 'predict', arguments: [x] }).then(r => r.data)
2133
}
2234

2335
reduce(x) {
24-
return this._postMessage({ mode: 'reduce', x: x }).then(r => r.data)
36+
return this._postMessage({ name: 'autoencoder', method: 'reduce', arguments: [x] }).then(r => r.data)
2537
}
2638
}
2739

@@ -35,8 +47,8 @@ export default function (platform) {
3547
const fitModel = async () => {
3648
if (mode === 'AD') {
3749
const tx = platform.trainInput
38-
const fite = await model.fit(tx, +iteration.value, rate.value, batch.value, rho.value)
39-
platform.plotLoss(fite.loss)
50+
const loss = await model.fit(tx, +iteration.value, rate.value, batch.value, rho.value)
51+
platform.plotLoss(loss)
4052
const px = platform.testInput(4)
4153
let pd = [].concat(tx, px)
4254
const e = await model.predict(pd)
@@ -63,11 +75,11 @@ export default function (platform) {
6375
platform.trainResult = outliers
6476
platform.testResult(outlier_tiles)
6577

66-
epoch = fite.epoch
78+
epoch = await model.epoch()
6779
} else if (mode === 'CT') {
6880
const step = 8
69-
const fite = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
70-
platform.plotLoss(fite.loss)
81+
const loss = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
82+
platform.plotLoss(loss)
7183
let p_mat = Matrix.fromArray(await model.reduce(platform.trainInput))
7284

7385
const t_mat = p_mat.argmax(1).value.map(v => v + 1)
@@ -77,12 +89,12 @@ export default function (platform) {
7789
platform.trainResult = t_mat
7890
platform.testResult(categories.value)
7991

80-
epoch = fite.epoch
92+
epoch = await model.epoch()
8193
} else {
82-
const fite = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
83-
platform.plotLoss(fite.loss)
94+
const loss = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
95+
platform.plotLoss(loss)
8496
platform.trainResult = await model.reduce(platform.trainInput)
85-
epoch = fite.epoch
97+
epoch = await model.epoch()
8698
}
8799
}
88100

js/view/gan.js

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,36 @@ import { BaseWorker, specialCategory } from '../utils.js'
44

55
class GANWorker extends BaseWorker {
66
constructor() {
7-
super('js/view/worker/gan_worker.js', { type: 'module' })
7+
super('js/view/worker/model_worker.js', { type: 'module' })
88
}
99

1010
initialize(noise_dim, g_hidden, d_hidden, g_opt, d_opt, class_size, type) {
1111
this._type = type
12-
return this._postMessage({ mode: 'init', noise_dim, g_hidden, d_hidden, g_opt, d_opt, class_size, type })
12+
return this._postMessage({
13+
name: 'gan',
14+
method: 'constructor',
15+
arguments: [noise_dim, g_hidden, d_hidden, g_opt, d_opt, class_size, type],
16+
})
17+
}
18+
19+
epoch() {
20+
return this._postMessage({ name: 'gan', method: 'epoch' }).then(r => r.data)
1321
}
1422

1523
fit(train_x, train_y, iteration, gen_rate, dis_rate, batch) {
16-
return this._postMessage({ mode: 'fit', x: train_x, y: train_y, iteration, gen_rate, dis_rate, batch }).then(
17-
r => r.data
18-
)
24+
return this._postMessage({
25+
name: 'gan',
26+
method: 'fit',
27+
arguments: [train_x, train_y, iteration, gen_rate, dis_rate, batch],
28+
}).then(r => r.data)
1929
}
2030

2131
prob(x, y) {
22-
return this._postMessage({ mode: 'prob', x: x, y: y }).then(r => r.data)
32+
return this._postMessage({ name: 'gan', method: 'prob', arguments: [x, y] }).then(r => r.data)
2333
}
2434

2535
generate(n, y) {
26-
return this._postMessage({ mode: 'generate', n: n, y: y }).then(r => r.data)
36+
return this._postMessage({ name: 'gan', method: 'generate', arguments: [n, y] }).then(r => r.data)
2737
}
2838
}
2939

@@ -46,9 +56,9 @@ export default function (platform) {
4656

4757
const tx = platform.trainInput
4858
const ty = platform.trainOutput
49-
const fit_data = await model.fit(tx, ty, +iteration.value, gen_rate, dis_rate, batch.value)
50-
epoch = fit_data.epoch
51-
platform.plotLoss({ generator: fit_data.generatorLoss, discriminator: fit_data.discriminatorLoss })
59+
const loss = await model.fit(tx, ty, +iteration.value, gen_rate, dis_rate, batch.value)
60+
epoch = await model.epoch()
61+
platform.plotLoss({ generator: loss.generatorLoss, discriminator: loss.discriminatorLoss })
5262
if (platform.task === 'GR') {
5363
const gen_data = await model.generate(tx.length, ty)
5464
if (model._type === 'conditional') {

js/view/ladder_network.js

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,31 @@ import * as opt from '../../lib/model/nns/optimizer.js'
44

55
class LadderNetworkWorker extends BaseWorker {
66
constructor() {
7-
super('js/view/worker/ladder_network_worker.js', { type: 'module' })
7+
super('js/view/worker/model_worker.js', { type: 'module' })
88
}
99

1010
initialize(hidden_sizes, lambdas, activation, optimizer) {
11-
return this._postMessage({ mode: 'init', hidden_sizes, lambdas, activation, optimizer })
11+
return this._postMessage({
12+
name: 'ladder_network',
13+
method: 'constructor',
14+
arguments: [hidden_sizes, lambdas, activation, optimizer],
15+
})
16+
}
17+
18+
epoch() {
19+
return this._postMessage({ name: 'ladder_network', method: 'epoch' }).then(r => r.data)
1220
}
1321

1422
fit(train_x, train_y, iteration, rate, batch) {
15-
return this._postMessage({ mode: 'fit', x: train_x, y: train_y, iteration, rate, batch })
23+
return this._postMessage({
24+
name: 'ladder_network',
25+
method: 'fit',
26+
arguments: [train_x, train_y, iteration, rate, batch],
27+
}).then(r => r.data)
1628
}
1729

1830
predict(x) {
19-
return this._postMessage({ mode: 'predict', x: x })
31+
return this._postMessage({ name: 'ladder_network', method: 'predict', arguments: [x] }).then(r => r.data)
2032
}
2133
}
2234

@@ -37,11 +49,10 @@ export default function (platform) {
3749
const dim = platform.datas.dimension
3850

3951
const ty = platform.trainOutput.map(v => v[0])
40-
const e = await model.fit(platform.trainInput, ty, +iteration.value, rate.value, batch.value)
41-
epoch = e.data.epoch
42-
platform.plotLoss({ labeled: e.data.labeledLoss, unlabeled: e.data.unlabeledLoss })
43-
const pred_e = await model.predict(platform.testInput(dim === 1 ? 2 : 4))
44-
const data = pred_e.data
52+
const loss = await model.fit(platform.trainInput, ty, +iteration.value, rate.value, batch.value)
53+
epoch = await model.epoch()
54+
platform.plotLoss({ labeled: loss.labeledLoss, unlabeled: loss.unlabeledLoss })
55+
const data = await model.predict(platform.testInput(dim === 1 ? 2 : 4))
4556
platform.testResult(data)
4657
}
4758

js/view/neuralnetwork.js

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,29 @@ import { BaseWorker } from '../utils.js'
55

66
class NNWorker extends BaseWorker {
77
constructor() {
8-
super('js/view/worker/neuralnetwork_worker.js', { type: 'module' })
8+
super('js/view/worker/model_worker.js', { type: 'module' })
99
}
1010

1111
initialize(layers, loss, optimizer) {
12-
return this._postMessage({ mode: 'init', layers, loss, optimizer })
12+
return this._postMessage({
13+
name: 'neuralnetwork',
14+
method: 'fromObject',
15+
static: true,
16+
initialize: true,
17+
arguments: [layers, loss, optimizer],
18+
})
1319
}
1420

1521
fit(train_x, train_y, iteration, rate, batch) {
16-
return this._postMessage({ mode: 'fit', x: train_x, y: train_y, iteration, rate, batch })
22+
return this._postMessage({
23+
name: 'neuralnetwork',
24+
method: 'fit',
25+
arguments: [train_x, train_y, iteration, rate, batch],
26+
})
1727
}
1828

1929
predict(x) {
20-
return this._postMessage({ mode: 'predict', x: x })
30+
return this._postMessage({ name: 'neuralnetwork', method: 'calc', arguments: [x] })
2131
}
2232
}
2333

js/view/rnn.js

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,31 @@ import { BaseWorker } from '../utils.js'
33

44
class RNNWorker extends BaseWorker {
55
constructor() {
6-
super('js/view/worker/rnn_worker.js', { type: 'module' })
6+
super('js/view/worker/model_worker.js', { type: 'module' })
77
}
88

99
initialize(method, window, unit, out_size, optimizer) {
10-
return this._postMessage({ mode: 'init', method, window, unit, out_size, optimizer })
10+
return this._postMessage({
11+
name: 'rnn',
12+
method: 'constructor',
13+
arguments: [method, window, unit, out_size, optimizer],
14+
})
15+
}
16+
17+
epoch() {
18+
return this._postMessage({ name: 'rnn', method: 'epoch' }).then(r => r.data)
1119
}
1220

1321
fit(train_x, train_y, iteration, rate, batch) {
14-
return this._postMessage({ mode: 'fit', x: train_x, y: train_y, iteration, rate, batch })
22+
return this._postMessage({
23+
name: 'rnn',
24+
method: 'fit',
25+
arguments: [train_x, train_y, iteration, rate, batch],
26+
}).then(r => r.data)
1527
}
1628

1729
predict(x, k) {
18-
return this._postMessage({ mode: 'predict', x, k })
30+
return this._postMessage({ name: 'rnn', method: 'predict', arguments: [x, k] }).then(r => r.data)
1931
}
2032
}
2133

@@ -27,11 +39,16 @@ export default function (platform) {
2739
let epoch = 0
2840

2941
const fitModel = async () => {
30-
const e = await model.fit(platform.trainInput, platform.trainInput, +iteration.value, rate.value, batch.value)
31-
epoch = e.data.epoch
32-
platform.plotLoss(e.data.loss)
33-
const pred_e = await model.predict(platform.trainInput, predCount.value)
34-
platform.trainResult = pred_e.data
42+
const loss = await model.fit(
43+
platform.trainInput,
44+
platform.trainInput,
45+
+iteration.value,
46+
rate.value,
47+
batch.value
48+
)
49+
epoch = await model.epoch()
50+
platform.plotLoss(loss)
51+
platform.trainResult = await model.predict(platform.trainInput, predCount.value)
3552
}
3653

3754
const method = controller.select(['rnn', 'LSTM', 'GRU'])

js/view/vae.js

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,34 @@ import { BaseWorker } from '../utils.js'
44

55
class VAEWorker extends BaseWorker {
66
constructor() {
7-
super('js/view/worker/vae_worker.js', { type: 'module' })
7+
super('js/view/worker/model_worker.js', { type: 'module' })
88
}
99

1010
initialize(in_size, noise_dim, enc_layers, dec_layers, optimizer, class_size, type) {
1111
this._type = type
1212
return this._postMessage({
13-
mode: 'init',
14-
in_size,
15-
noise_dim,
16-
enc_layers,
17-
dec_layers,
18-
optimizer,
19-
class_size,
20-
type,
13+
name: 'vae',
14+
method: 'constructor',
15+
arguments: [in_size, noise_dim, enc_layers, dec_layers, optimizer, class_size, type],
2116
})
2217
}
2318

19+
epoch() {
20+
return this._postMessage({ name: 'vae', method: 'epoch' }).then(r => r.data)
21+
}
22+
2423
fit(x, y, iteration, rate, batch) {
25-
return this._postMessage({ mode: 'fit', x, y, iteration, rate, batch })
24+
return this._postMessage({ name: 'vae', method: 'fit', arguments: [x, y, iteration, rate, batch] }).then(
25+
r => r.data
26+
)
2627
}
2728

2829
predict(x, y) {
29-
return this._postMessage({ mode: 'predict', x, y })
30+
return this._postMessage({ name: 'vae', method: 'predict', arguments: [x, y] }).then(r => r.data)
3031
}
3132

3233
reduce(x, y) {
33-
return this._postMessage({ mode: 'reduce', x, y })
34+
return this._postMessage({ name: 'vae', method: 'reduce', arguments: [x, y] }).then(r => r.data)
3435
}
3536
}
3637

@@ -48,15 +49,19 @@ export default function (platform) {
4849
return
4950
}
5051

51-
const e = await model.fit(platform.trainInput, platform.trainOutput, +iteration.value, rate.value, batch.value)
52-
epoch = e.data.epoch
53-
platform.plotLoss(e.data.loss)
52+
const loss = await model.fit(
53+
platform.trainInput,
54+
platform.trainOutput,
55+
+iteration.value,
56+
rate.value,
57+
batch.value
58+
)
59+
epoch = await model.epoch()
60+
platform.plotLoss(loss)
5461
if (mode === 'DR') {
55-
const e = await model.reduce(platform.trainInput, platform.trainOutput)
56-
platform.trainResult = e.data
62+
platform.trainResult = await model.reduce(platform.trainInput, platform.trainOutput)
5763
} else if (mode === 'GR') {
58-
const e = await model.predict(platform.trainInput, platform.trainOutput)
59-
const data = e.data
64+
const data = await model.predict(platform.trainInput, platform.trainOutput)
6065
if (model._type === 'conditional') {
6166
platform.trainResult = [data, platform.trainOutput]
6267
} else {
@@ -66,8 +71,7 @@ export default function (platform) {
6671
}
6772

6873
const genValues = async () => {
69-
const e = await model.predict(platform.trainInput, platform.trainOutput)
70-
const data = e.data
74+
const data = await model.predict(platform.trainInput, platform.trainOutput)
7175
if (type.value === 'conditional') {
7276
platform.trainResult = [data, platform.trainOutput]
7377
} else {

0 commit comments

Comments
 (0)