@@ -5,23 +5,35 @@ import { BaseWorker } from '../utils.js'
55
66class AutoencoderWorker extends BaseWorker {
77 constructor ( ) {
8- super ( 'js/view/worker/autoencoder_worker .js' , { type : 'module' } )
8+ super ( 'js/view/worker/model_worker .js' , { type : 'module' } )
99 }
1010
1111 initialize ( input_size , reduce_size , enc_layers , dec_layers , optimizer ) {
12- return this . _postMessage ( { mode : 'init' , input_size, reduce_size, enc_layers, dec_layers, optimizer } )
12+ return this . _postMessage ( {
13+ name : 'autoencoder' ,
14+ method : 'constructor' ,
15+ arguments : [ input_size , reduce_size , enc_layers , dec_layers , optimizer ] ,
16+ } )
17+ }
18+
19+ epoch ( ) {
20+ return this . _postMessage ( { name : 'autoencoder' , method : 'epoch' } ) . then ( r => r . data )
1321 }
1422
1523 fit ( train_x , iteration , rate , batch , rho ) {
16- return this . _postMessage ( { mode : 'fit' , x : train_x , iteration, rate, batch, rho } ) . then ( r => r . data )
24+ return this . _postMessage ( {
25+ name : 'autoencoder' ,
26+ method : 'fit' ,
27+ arguments : [ train_x , iteration , rate , batch , rho ] ,
28+ } ) . then ( r => r . data )
1729 }
1830
1931 predict ( x ) {
20- return this . _postMessage ( { mode : 'predict' , x : x } )
32+ return this . _postMessage ( { name : 'autoencoder' , method : ' predict', arguments : [ x ] } ) . then ( r => r . data )
2133 }
2234
2335 reduce ( x ) {
24- return this . _postMessage ( { mode : 'reduce' , x : x } ) . then ( r => r . data )
36+ return this . _postMessage ( { name : 'autoencoder' , method : ' reduce', arguments : [ x ] } ) . then ( r => r . data )
2537 }
2638}
2739
@@ -35,8 +47,8 @@ export default function (platform) {
3547 const fitModel = async ( ) => {
3648 if ( mode === 'AD' ) {
3749 const tx = platform . trainInput
38- const fite = await model . fit ( tx , + iteration . value , rate . value , batch . value , rho . value )
39- platform . plotLoss ( fite . loss )
50+ const loss = await model . fit ( tx , + iteration . value , rate . value , batch . value , rho . value )
51+ platform . plotLoss ( loss )
4052 const px = platform . testInput ( 4 )
4153 let pd = [ ] . concat ( tx , px )
4254 const e = await model . predict ( pd )
@@ -63,11 +75,11 @@ export default function (platform) {
6375 platform . trainResult = outliers
6476 platform . testResult ( outlier_tiles )
6577
66- epoch = fite . epoch
78+ epoch = await model . epoch ( )
6779 } else if ( mode === 'CT' ) {
6880 const step = 8
69- const fite = await model . fit ( platform . trainInput , + iteration . value , rate . value , batch . value , rho . value )
70- platform . plotLoss ( fite . loss )
81+ const loss = await model . fit ( platform . trainInput , + iteration . value , rate . value , batch . value , rho . value )
82+ platform . plotLoss ( loss )
7183 let p_mat = Matrix . fromArray ( await model . reduce ( platform . trainInput ) )
7284
7385 const t_mat = p_mat . argmax ( 1 ) . value . map ( v => v + 1 )
@@ -77,12 +89,12 @@ export default function (platform) {
7789 platform . trainResult = t_mat
7890 platform . testResult ( categories . value )
7991
80- epoch = fite . epoch
92+ epoch = await model . epoch ( )
8193 } else {
82- const fite = await model . fit ( platform . trainInput , + iteration . value , rate . value , batch . value , rho . value )
83- platform . plotLoss ( fite . loss )
94+ const loss = await model . fit ( platform . trainInput , + iteration . value , rate . value , batch . value , rho . value )
95+ platform . plotLoss ( loss )
8496 platform . trainResult = await model . reduce ( platform . trainInput )
85- epoch = fite . epoch
97+ epoch = await model . epoch ( )
8698 }
8799 }
88100
0 commit comments