diff --git a/example/index.ts b/example/index.ts
index 69f7928..bf8b576 100644
--- a/example/index.ts
+++ b/example/index.ts
@@ -1,11 +1,13 @@
import * as tts from '../src';
import Worker from './worker.ts?worker';
+import Worker2 from './worker2.ts?worker';
// required for e2e
Object.assign(window, { tts });
document.querySelector('#app')!.innerHTML = `
+  <button id="btn2">Start Session Demo</button>
`
document.getElementById('btn')?.addEventListener('click', async () => {
@@ -26,3 +28,21 @@ document.getElementById('btn')?.addEventListener('click', async () => {
worker.terminate();
});
});
+
+const mainWorker = new Worker2();
+
+document.getElementById('btn2')?.addEventListener('click', async () => {
+ mainWorker.postMessage({
+ type: 'init',
+ text: "As the waves crashed against the shore, they carried tales of distant lands and adventures untold.",
+ voiceId: 'en_US-hfc_female-medium',
+ });
+
+ mainWorker.addEventListener('message', (event: MessageEvent<{ type: 'result', audio: Blob }>) => {
+ if (event.data.type != 'result') return;
+
+ const audio = new Audio();
+ audio.src = URL.createObjectURL(event.data.audio);
+ audio.play();
+ }, { once: true });
+});
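
Note on the example wiring above: the `init`/`result` message shapes are written out inline in both `example/index.ts` and the workers. A shared declaration could keep the two sides in sync; a minimal sketch, with hypothetical names and a hypothetical `example/messages.ts` file that is not part of this diff:

```ts
// example/messages.ts (hypothetical, for illustration only)

// Request sent from the page to a TTS worker.
export interface TtsInitMessage {
  type: 'init';
  text: string;
  voiceId: string;
}

// Response posted back once synthesis has finished.
export interface TtsResultMessage {
  type: 'result';
  audio: Blob;
}

export type TtsWorkerMessage = TtsInitMessage | TtsResultMessage;
```

Both the `btn2` click handler and `worker2.ts` could then narrow on `type` against the same union instead of repeating the inline `MessageEvent<...>` annotation.
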
diff --git a/example/worker.ts b/example/worker.ts
index 5775553..b95f187 100644
--- a/example/worker.ts
+++ b/example/worker.ts
@@ -3,10 +3,12 @@ import * as tts from '../src/index';
async function main(event: MessageEvent) {
if (event.data?.type != 'init') return;
+ const start = performance.now();
const blob = await tts.predict({
text: event.data.text,
voiceId: event.data.voiceId,
});
+ console.log('Time taken:', performance.now() - start);
self.postMessage({ type: 'result', audio: blob })
}
diff --git a/example/worker2.ts b/example/worker2.ts
new file mode 100644
index 0000000..407ef4c
--- /dev/null
+++ b/example/worker2.ts
@@ -0,0 +1,19 @@
+import * as tts from '../src/index';
+
+const start = performance.now();
+const session = await tts.TtsSession.create({
+ voiceId: 'en_US-hfc_female-medium',
+});
+console.log('Time taken to init session:', performance.now() - start);
+
+async function main(event: MessageEvent) {
+ if (event.data?.type != 'init') return;
+
+ const start = performance.now();
+  const blob = await session.predict(event.data.text);
+ console.log('Time taken:', performance.now() - start);
+
+ self.postMessage({ type: 'result', audio: blob })
+}
+
+self.addEventListener('message', main);
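
Because `TtsSession.create` also accepts a `progress` callback (see the `src/inference.ts` changes below), the worker above could relay model-download progress to the page. A rough sketch under two assumptions that are not part of this diff: that the value handed to `ProgressCallback` is structured-cloneable, and that the page adds a handler for a `progress` message type:

```ts
// Hypothetical variant of example/worker2.ts that forwards download progress.
import * as tts from '../src/index';

const session = await tts.TtsSession.create({
  voiceId: 'en_US-hfc_female-medium',
  // Relay whatever the library reports; the exact payload shape is defined
  // by ProgressCallback in src/types and is assumed cloneable here.
  progress: (report) => self.postMessage({ type: 'progress', report }),
});

self.addEventListener('message', async (event: MessageEvent) => {
  if (event.data?.type !== 'init') return;
  const blob = await session.predict(event.data.text);
  self.postMessage({ type: 'result', audio: blob });
});
```
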
diff --git a/src/inference.ts b/src/inference.ts
index bc39141..e8fd7fd 100644
--- a/src/inference.ts
+++ b/src/inference.ts
@@ -1,65 +1,133 @@
-import { InferenceConfg, ProgressCallback } from "./types";
+import { InferenceConfg, ProgressCallback, VoiceId } from "./types";
import { HF_BASE, ONNX_BASE, PATH_MAP, WASM_BASE } from './fixtures';
import { readBlob, writeBlob } from './opfs';
import { fetchBlob } from './http.js';
import { pcm2wav } from './audio';
-/**
- * Run text to speech inference in new worker thread. Fetches the model
- * first, if it has not yet been saved to opfs yet.
- */
-export async function predict(config: InferenceConfg, callback?: ProgressCallback): Promise<Blob> {
- const { createPiperPhonemize } = await import('./piper.js');
- const ort = await import('onnxruntime-web');
-
- const path = PATH_MAP[config.voiceId];
- const input = JSON.stringify([{ text: config.text.trim() }]);
-
- ort.env.allowLocalModels = false;
- ort.env.wasm.numThreads = navigator.hardwareConcurrency;
- ort.env.wasm.wasmPaths = ONNX_BASE;
-
- const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`);
- const modelConfig = JSON.parse(await modelConfigBlob.text());
-
- const phonemeIds: string[] = await new Promise(async resolve => {
- const module = await createPiperPhonemize({
- print: (data: any) => {
- resolve(JSON.parse(data).phoneme_ids);
- },
- printErr: (message: any) => {
- throw new Error(message);
- },
- locateFile: (url: string) => {
- if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`;
- if (url.endsWith(".data")) return `${WASM_BASE}.data`;
- return url;
- }
- });
+interface TtsSessionOptions {
+ voiceId: VoiceId;
+ progress?: ProgressCallback;
+}
- module.callMain(["-l", modelConfig.espeak.voice, "--input", input, "--espeak_data", "/espeak-ng-data"]);
- });
+export class TtsSession {
+ ready = false;
+ voiceId: VoiceId;
+  waitReady: Promise<void>;
+ #createPiperPhonemize?: (moduleArg?: {}) => any;
+ #modelConfig?: any;
+ #ort?: typeof import("onnxruntime-web");
+ #ortSession?: import("onnxruntime-web").InferenceSession
+ #progressCallback?: ProgressCallback;
- const speakerId = 0;
- const sampleRate = modelConfig.audio.sample_rate;
- const noiseScale = modelConfig.inference.noise_scale;
- const lengthScale = modelConfig.inference.length_scale;
- const noiseW = modelConfig.inference.noise_w;
-
- const modelBlob = await getBlob(`${HF_BASE}/${path}`, callback);
- const session = await ort.InferenceSession.create(await modelBlob.arrayBuffer());
- const feeds = {
- input: new ort.Tensor("int64", phonemeIds, [1, phonemeIds.length]),
- input_lengths: new ort.Tensor("int64", [phonemeIds.length]),
- scales: new ort.Tensor("float32", [noiseScale, lengthScale, noiseW])
+ constructor({ voiceId, progress }: TtsSessionOptions) {
+ this.voiceId = voiceId;
+ this.#progressCallback = progress;
+ this.waitReady = this.init();
}
- if (Object.keys(modelConfig.speaker_id_map).length) {
- Object.assign(feeds, { sid: new ort.Tensor("int64", [speakerId]) })
+
+ static async create(options: TtsSessionOptions) {
+ const session = new TtsSession(options);
+ await session.waitReady;
+ return session;
}
- const { output: { data: pcm } } = await session.run(feeds);
+ async init() {
+ const { createPiperPhonemize } = await import("./piper.js");
+ this.#createPiperPhonemize = createPiperPhonemize;
+ this.#ort = await import("onnxruntime-web");
+
+ this.#ort.env.allowLocalModels = false;
+ this.#ort.env.wasm.numThreads = navigator.hardwareConcurrency;
+ this.#ort.env.wasm.wasmPaths = ONNX_BASE;
+
+ const path = PATH_MAP[this.voiceId];
+ const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`);
+ this.#modelConfig = JSON.parse(await modelConfigBlob.text());
+
+ const modelBlob = await getBlob(
+ `${HF_BASE}/${path}`,
+ this.#progressCallback
+ );
+ this.#ortSession = await this.#ort.InferenceSession.create(
+ await modelBlob.arrayBuffer()
+ );
+ }
+
+  async predict(text: string): Promise<Blob> {
+ await this.waitReady; // wait for the session to be ready
+
+ const input = JSON.stringify([{ text: text.trim() }]);
+
+ const phonemeIds: string[] = await new Promise(async (resolve) => {
+ const module = await this.#createPiperPhonemize!({
+ print: (data: any) => {
+ resolve(JSON.parse(data).phoneme_ids);
+ },
+ printErr: (message: any) => {
+ throw new Error(message);
+ },
+ locateFile: (url: string) => {
+ if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`;
+ if (url.endsWith(".data")) return `${WASM_BASE}.data`;
+ return url;
+ },
+ });
- return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], { type: "audio/x-wav" });
+ module.callMain([
+ "-l",
+ this.#modelConfig.espeak.voice,
+ "--input",
+ input,
+ "--espeak_data",
+ "/espeak-ng-data",
+ ]);
+ });
+
+ const speakerId = 0;
+ const sampleRate = this.#modelConfig.audio.sample_rate;
+ const noiseScale = this.#modelConfig.inference.noise_scale;
+ const lengthScale = this.#modelConfig.inference.length_scale;
+ const noiseW = this.#modelConfig.inference.noise_w;
+
+ const session = this.#ortSession!;
+ const feeds = {
+ input: new this.#ort!.Tensor("int64", phonemeIds, [1, phonemeIds.length]),
+ input_lengths: new this.#ort!.Tensor("int64", [phonemeIds.length]),
+ scales: new this.#ort!.Tensor("float32", [
+ noiseScale,
+ lengthScale,
+ noiseW,
+ ]),
+ };
+ if (Object.keys(this.#modelConfig.speaker_id_map).length) {
+ Object.assign(feeds, {
+ sid: new this.#ort!.Tensor("int64", [speakerId]),
+ });
+ }
+
+ const {
+ output: { data: pcm },
+ } = await session.run(feeds);
+
+ return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], {
+ type: "audio/x-wav",
+ });
+ }
+}
+
+/**
+ * Run text-to-speech inference in a new worker thread. Fetches the model
+ * first if it has not yet been saved to OPFS.
+ */
+export async function predict(
+ config: InferenceConfg,
+ callback?: ProgressCallback
+): Promise<Blob> {
+ const session = new TtsSession({
+ voiceId: config.voiceId,
+ progress: callback,
+ });
+ return session.predict(config.text);
}
/**
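
The practical difference between the refactored `predict` export and the new `TtsSession` class is reuse: `predict` still constructs a fresh session (piper phonemizer, onnxruntime-web, model fetch) on every call, while a long-lived `TtsSession` pays that initialization cost once, which is what `example/worker2.ts` measures above. A minimal main-thread usage sketch (no worker; the `progress` log is illustrative only):

```ts
import { TtsSession } from '../src';

// One-time setup: loads onnxruntime-web and downloads/caches the voice model.
const session = await TtsSession.create({
  voiceId: 'en_US-hfc_female-medium',
  progress: (p) => console.log('model download progress:', p),
});

// Later calls reuse the loaded model, so only inference time remains.
const wav: Blob = await session.predict('Hello from a reused session.');
new Audio(URL.createObjectURL(wav)).play();
```
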
diff --git a/tsconfig.json b/tsconfig.json
index 2057dee..1058710 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -3,7 +3,7 @@
"target": "ES2020",
"useDefineForClassFields": true,
"module": "ESNext",
- "lib": ["ES2020", "DOM", "DOM.Iterable","WebWorker"],
+ "lib": ["ES2022", "DOM", "DOM.Iterable","WebWorker"],
"skipLibCheck": true,
"allowJs": true,