Skip to content

Commit 2e50d68

Browse files
authored
Change dynamic import in library section to static import (#850)
1 parent 754c032 commit 2e50d68

File tree

5 files changed

+121
-21
lines changed

5 files changed

+121
-21
lines changed

create_import_list.js

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,5 +205,18 @@ import Tensor from '../../../util/tensor.js'
205205
await fs.promises.writeFile(layerDir + '/index.js', '// This file is generated automatically.\n' + code + typeCode)
206206
}
207207

208-
await createEntrypoint()
208+
const createONNXOperatorlist = async () => {
209+
const operatorsDir = './lib/model/nns/onnx/operators'
210+
const files = await fs.promises.readdir(operatorsDir)
211+
let code = ''
212+
for (const file of files) {
213+
if (file !== 'index.js' && file.endsWith('.js')) {
214+
code += `export { default as ${file.slice(0, -3)} } from './${file}'\n`
215+
}
216+
}
217+
await fs.promises.writeFile(operatorsDir + '/index.js', '// This file is generated automatically.\n' + code)
218+
}
219+
209220
await createLayerlist()
221+
await createONNXOperatorlist()
222+
await createEntrypoint()

lib/model/nns/graph.js

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ import Matrix from '../../util/matrix.js'
22
import { NeuralnetworkException } from '../neuralnetwork.js'
33
import Layer from './layer/base.js'
44
import { InputLayer, OutputLayer } from './layer/index.js'
5-
6-
let ONNXImporter = null
5+
import ONNXImporter from './onnx/onnx_importer.js'
76

87
/**
98
* @typedef {import("./layer/index").PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
@@ -87,9 +86,6 @@ export default class ComputationalGraph {
8786
* @returns {Promise<ComputationalGraph>} Loaded graph
8887
*/
8988
static async fromONNX(buffer) {
90-
if (!ONNXImporter) {
91-
ONNXImporter = (await import('./onnx/onnx_importer.js')).default
92-
}
9389
return ComputationalGraph.fromObject(await ONNXImporter.load(buffer))
9490
}
9591

lib/model/nns/onnx/onnx_importer.js

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export { default as onnx } from './onnx_pb.js'
44
import input from './operators/input.js'
55
import output from './operators/output.js'
66

7-
const operators = {}
7+
import * as operators from './operators/index.js'
88

99
/**
1010
* ONNX importer
@@ -31,16 +31,11 @@ export default class ONNXImporter {
3131
}
3232
for (const node of graph.getNodeList()) {
3333
const opType = node.getOpType()
34-
if (!operators[opType]) {
35-
try {
36-
const module = await import(`./operators/${opType.toLowerCase()}.js`)
37-
operators[opType] = module.default
38-
} catch (e) {
39-
console.error(opType, e.name, e.message)
40-
continue
41-
}
34+
if (!operators[opType.toLowerCase()]) {
35+
console.error(`Unimplemented operator ${opType}.`)
36+
continue
4237
}
43-
const op = operators[opType]
38+
const op = operators[opType.toLowerCase()]
4439
nodes.push(...op.import(model, node))
4540
}
4641
for (const node of graph.getOutputList()) {
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// This file is generated automatically.
2+
export { default as abs } from './abs.js'
3+
export { default as acos } from './acos.js'
4+
export { default as acosh } from './acosh.js'
5+
export { default as add } from './add.js'
6+
export { default as and } from './and.js'
7+
export { default as argmax } from './argmax.js'
8+
export { default as argmin } from './argmin.js'
9+
export { default as asin } from './asin.js'
10+
export { default as asinh } from './asinh.js'
11+
export { default as atan } from './atan.js'
12+
export { default as atanh } from './atanh.js'
13+
export { default as averagepool } from './averagepool.js'
14+
export { default as batchnormalization } from './batchnormalization.js'
15+
export { default as bitshift } from './bitshift.js'
16+
export { default as bitwiseand } from './bitwiseand.js'
17+
export { default as bitwisenot } from './bitwisenot.js'
18+
export { default as bitwiseor } from './bitwiseor.js'
19+
export { default as bitwisexor } from './bitwisexor.js'
20+
export { default as ceil } from './ceil.js'
21+
export { default as celu } from './celu.js'
22+
export { default as clip } from './clip.js'
23+
export { default as concat } from './concat.js'
24+
export { default as constant } from './constant.js'
25+
export { default as conv } from './conv.js'
26+
export { default as cos } from './cos.js'
27+
export { default as cosh } from './cosh.js'
28+
export { default as div } from './div.js'
29+
export { default as dropout } from './dropout.js'
30+
export { default as elu } from './elu.js'
31+
export { default as equal } from './equal.js'
32+
export { default as erf } from './erf.js'
33+
export { default as exp } from './exp.js'
34+
export { default as flatten } from './flatten.js'
35+
export { default as floor } from './floor.js'
36+
export { default as gemm } from './gemm.js'
37+
export { default as globalaveragepool } from './globalaveragepool.js'
38+
export { default as globallppool } from './globallppool.js'
39+
export { default as globalmaxpool } from './globalmaxpool.js'
40+
export { default as greater } from './greater.js'
41+
export { default as greaterorequal } from './greaterorequal.js'
42+
export { default as hardsigmoid } from './hardsigmoid.js'
43+
export { default as hardswish } from './hardswish.js'
44+
export { default as identity } from './identity.js'
45+
export { default as input } from './input.js'
46+
export { default as isinf } from './isinf.js'
47+
export { default as isnan } from './isnan.js'
48+
export { default as layernormalization } from './layernormalization.js'
49+
export { default as leakyrelu } from './leakyrelu.js'
50+
export { default as less } from './less.js'
51+
export { default as lessorequal } from './lessorequal.js'
52+
export { default as log } from './log.js'
53+
export { default as logsoftmax } from './logsoftmax.js'
54+
export { default as lppool } from './lppool.js'
55+
export { default as lrn } from './lrn.js'
56+
export { default as matmul } from './matmul.js'
57+
export { default as max } from './max.js'
58+
export { default as maxpool } from './maxpool.js'
59+
export { default as mean } from './mean.js'
60+
export { default as min } from './min.js'
61+
export { default as mish } from './mish.js'
62+
export { default as mod } from './mod.js'
63+
export { default as mul } from './mul.js'
64+
export { default as neg } from './neg.js'
65+
export { default as not } from './not.js'
66+
export { default as or } from './or.js'
67+
export { default as output } from './output.js'
68+
export { default as pow } from './pow.js'
69+
export { default as prelu } from './prelu.js'
70+
export { default as reciprocal } from './reciprocal.js'
71+
export { default as reducel1 } from './reducel1.js'
72+
export { default as reducel2 } from './reducel2.js'
73+
export { default as reducelogsum } from './reducelogsum.js'
74+
export { default as reducelogsumexp } from './reducelogsumexp.js'
75+
export { default as reducemax } from './reducemax.js'
76+
export { default as reducemean } from './reducemean.js'
77+
export { default as reducemin } from './reducemin.js'
78+
export { default as reduceprod } from './reduceprod.js'
79+
export { default as reducesum } from './reducesum.js'
80+
export { default as reducesumsquare } from './reducesumsquare.js'
81+
export { default as relu } from './relu.js'
82+
export { default as reshape } from './reshape.js'
83+
export { default as round } from './round.js'
84+
export { default as selu } from './selu.js'
85+
export { default as shrink } from './shrink.js'
86+
export { default as sigmoid } from './sigmoid.js'
87+
export { default as sign } from './sign.js'
88+
export { default as sin } from './sin.js'
89+
export { default as sinh } from './sinh.js'
90+
export { default as softmax } from './softmax.js'
91+
export { default as softplus } from './softplus.js'
92+
export { default as softsign } from './softsign.js'
93+
export { default as sqrt } from './sqrt.js'
94+
export { default as sub } from './sub.js'
95+
export { default as sum } from './sum.js'
96+
export { default as tan } from './tan.js'
97+
export { default as tanh } from './tanh.js'
98+
export { default as thresholdedrelu } from './thresholdedrelu.js'
99+
export { default as transpose } from './transpose.js'
100+
export { default as xor } from './xor.js'

tests/lib/model/nns/onnx/onnx_importer.test.js

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,7 @@ describe('import', () => {
8080
expect(nodes).toHaveLength(2)
8181
expect(nodes.map(n => n.type)).toEqual(['input', 'output'])
8282
expect(console.error).toHaveBeenCalledTimes(1)
83-
expect(console.error).toHaveBeenCalledWith(
84-
'Cast',
85-
'Error',
86-
"Cannot find module './operators/cast.js' from 'lib/model/nns/onnx/onnx_importer.js'"
87-
)
83+
expect(console.error).toHaveBeenCalledWith('Unimplemented operator Cast.')
8884
})
8985
})
9086
})

0 commit comments

Comments
 (0)