diff --git a/README.md b/README.md index e105d12..3520930 100644 --- a/README.md +++ b/README.md @@ -63,8 +63,8 @@ Using `all` as wanted_asr parameter to main.py will attempt to start all ASRs fo | ✅ sp | CMU Sphinx | Open Source | Offline - docker | | ✅ vs | Alphacep Vosk | Open Source | Offline - docker | | ✅ cq | Coqui | Open Source | Offline - docker | +| ✅ nm | Nvidia NeMo | Open Source | Offline - docker | | ❌ sb | Speech Brain | Open Source | Offline - docker | -| ❌ nm | Nvidia NeMo | Open Source | Offline - docker | | ✅ gg | Google | Proprietary | API set env:`GOOGLE_APPLICATION_CREDENTIALS` | | ✅ az | Microsoft Azure | Proprietary | API set env:`AZURE_KEY` | | ✅ aw | Amazon | Proprietary | API set env:`AWS_ACCESS_KEY_ID`
+`AWS_SECRET_ACCESS_KEY` or aws configure| diff --git a/changelog.md b/changelog.md index 66beeec..2c65cf0 100644 --- a/changelog.md +++ b/changelog.md @@ -4,3 +4,4 @@ - 0.0.1 - initial release - 0.0.2 - fixed common corrections pack error with requirements. Fixed issue with wizard where no ASR can be selected - 0.0.3 - refactored ASRs so each is own file. Added arg switches for columns/enable_wer/text_normalization/hashing +- 0.0.4 - added NeMo as alpha diff --git a/models/dev_prune_delete_all.sh b/models/dev_prune_delete_all.sh new file mode 100755 index 0000000..7fbd2c2 --- /dev/null +++ b/models/dev_prune_delete_all.sh @@ -0,0 +1,2 @@ + +docker system prune -a --volumes diff --git a/models/sl-nemo/Dockerfile b/models/sl-nemo/Dockerfile new file mode 100644 index 0000000..1bf605a --- /dev/null +++ b/models/sl-nemo/Dockerfile @@ -0,0 +1,31 @@ +FROM python:3.9-slim + +ENV TZ=Europe/London +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +ENV MODELNAME='stt_en_contextnet_1024.nemo' +ENV MODELTYPE='EncDecRNNTBPEModel' +ENV MODELURL='https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_1024/versions/1.9.0/files/stt_en_contextnet_1024.nemo' + +EXPOSE 3500 +COPY app /app +WORKDIR /app + +RUN apt update && apt-get install -y gcc curl python3-dev python3-pip ffmpeg \ + && pip install numpy==1.22.4 fastapi uvicorn Cython torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html \ + && pip install nemo_toolkit[asr] \ + && curl -L -o /app/$MODELNAME $MODELURL \ + && rm -rf /var/lib/apt/lists/* \ + && apt remove -y gcc curl \ + && apt autoremove -y + +HEALTHCHECK --interval=30s --timeout=5s --start-period=15s \ + CMD curl --fail http://localhost:3500/healthcheck || exit 1 + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "3500"] + +# docker build -f Dockerfile . 
-t robmsmt/sl-nemo +# docker run -d --restart unless-stopped -p 3500:3500 robmsmt/sl-nemo-en-16k:latest +# docker run -it -p 3500:3500 robmsmt/sl-nemo +#docker run -it -p 3500:3500 robmsmt/sl-nemo-en-16k:latest +#docker commit my-broken-container && docker run -it my-broken-container /bin/bash diff --git a/models/sl-nemo/README.md b/models/sl-nemo/README.md new file mode 100644 index 0000000..ab8c01f --- /dev/null +++ b/models/sl-nemo/README.md @@ -0,0 +1,16 @@ +# Nemo + +## CONFIG +- Shortcode: ` nm ` +- Docker: ` robmsmt/sl-nemo-en-16k:latest ` +- InternalPort: ` 3500 ` +- ExternalPort: ` 3500 ` +- SampleRate: ` 16000 ` +- InterfaceType: ` docker-fastapi ` + +## CHANGES + - tbc + +## Notes +- Contextnet 1024 version +- Not used onnx yet diff --git a/models/sl-nemo/app/__init__.py b/models/sl-nemo/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/sl-nemo/app/main.py b/models/sl-nemo/app/main.py new file mode 100644 index 0000000..33342e5 --- /dev/null +++ b/models/sl-nemo/app/main.py @@ -0,0 +1,83 @@ +import os +from fastapi import FastAPI +from pydantic import BaseModel +import tempfile +from io import BytesIO +from base64 import b64decode +import argparse + +# import soundfile +# import numpy as np +# import onnxruntime as rt +# import nemo +import nemo.collections.asr as nemo_asr + +model = os.environ["MODELNAME"] + +app = FastAPI() +# nm = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En") +# print(nemo_asr.models.EncDecRNNTBPEModel.list_available_models()) +# nm = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name=args.model) +nm = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model) +# +# +# enc_dec_ctc_models = [(x.pretrained_model_name, nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=x.pretrained_model_name)) for x in nemo_asr.models.EncDecCTCModel.list_available_models() if "en" in x.pretrained_model_name] +# enc_dec_ctc_bpe_models = [(x.pretrained_model_name,
nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=x.pretrained_model_name)) for x in nemo_asr.models.EncDecCTCModelBPE.list_available_models() if "en" in x.pretrained_model_name] +# enc_dec_rnn_t_bpe_models = [(x.pretrained_model_name, nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name=x.pretrained_model_name)) for x in nemo_asr.models.EncDecRNNTBPEModel.list_available_models() if "en" in x.pretrained_model_name] +# enc_dec_rnn_t_models = [(x.pretrained_model_name, nemo_asr.models.EncDecRNNTModel.from_pretrained(model_name=x.pretrained_model_name)) for x in nemo_asr.models.EncDecRNNTModel.list_available_models() if "en" in x.pretrained_model_name] +# +# all_models = enc_dec_ctc_models + enc_dec_ctc_bpe_models + enc_dec_rnn_t_bpe_models + enc_dec_rnn_t_models +# print(all_models) + + +def disk_in_memory(wav_bytes): + """ + this spooled wav was chosen because it's much more efficient than writing to disk, + it effectively is writing to memory only and can still be read (by some applications) as a file + """ + with tempfile.SpooledTemporaryFile() as spooled_wav: + spooled_wav.write(wav_bytes) + spooled_wav.seek(0) + return BytesIO(spooled_wav.read()) + + +class Audio(BaseModel): + b64_wav: str + sr: int = 16000 + + +@app.get("/healthcheck") +async def healthcheck(): + return {"ok": "true"} + + +# Next, we instantiate all the necessary models directly from NVIDIA NGC +# Speech Recognition model + + +@app.post("/transcribe") +async def transcribe(audio: Audio): + + try: + wav_bytes = b64decode(audio.b64_wav.encode("utf-8")) + + # dm = disk_in_memory(wav_bytes) + # pcm, sample_rate = soundfile.read(dm, dtype="int16") + # todo cannot use disk memory since nemo lib needs file - in future replace with onnx: https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/ASR_with_NeMo.ipynb + + with tempfile.NamedTemporaryFile(mode="wb", delete=True, suffix=".wav") as f: + f.write(wav_bytes) + files_list = [f.name] + print(files_list) + transcript = 
nm.transcribe(paths2audio_files=files_list) + + return {"transcript": transcript[0][0]} + except: + raise + + +if __name__ == "__main__": + import uvicorn + + print("starting...") + uvicorn.run("main:app", host="0.0.0.0", port=3500) diff --git a/models/sl-nemo/autoinspect.sh b/models/sl-nemo/autoinspect.sh new file mode 100755 index 0000000..38c1b97 --- /dev/null +++ b/models/sl-nemo/autoinspect.sh @@ -0,0 +1,13 @@ + + +# try docker log first + +# This should work for a running container +# +IMG=$(cat ./README.md | grep "Docker:" | awk '{print $4}') +ID=$(docker ps | grep $IMG | awk '{ print $1 }') + +echo "INSPECTING: $IMG with ID: $ID" +$(docker stop $(docker ps -a -q --filter ancestor="$IMG" --format="{{.ID}}")) +docker commit "$ID" broken-container1 && docker run -p 3500:3500 -it broken-container1 /bin/bash +# run with: uvicorn main:app --host 0.0.0.0 --port 3500 +# then hit test endpoint diff --git a/models/sl-nemo/build.sh b/models/sl-nemo/build.sh new file mode 100755 index 0000000..cb1ce58 --- /dev/null +++ b/models/sl-nemo/build.sh @@ -0,0 +1,18 @@ +set -e +#CWD=${PWD##*/} +DIR_PATH="$(dirname "${0}")" +IMG_REPO=$(cat ./README.md | grep "Docker:" | awk '{print $4}') +EXTPORT=$(cat ./README.md | grep "ExternalPort:" | awk '{print $4}') + +echo $DIR_PATH $IMG_REPO $EXTPORT +docker build -t $IMG_REPO "$DIR_PATH" + +set +e +# this is empty if the container crashes +echo $(docker ps -q -a --filter ancestor="$IMG_REPO" --format="{{.ID}}") +docker stop $(docker ps -q -a --filter ancestor="$IMG_REPO" --format="{{.ID}}") +set -e +docker run -p "$EXTPORT":"$EXTPORT" -d "$IMG_REPO" + +## to debug - kill container and start with: +#docker run --restart unless-stopped -p "$EXTPORT":"$EXTPORT" "$IMG_REPO" diff --git a/models/sl-nemo/push.sh b/models/sl-nemo/push.sh new file mode 100755 index 0000000..00e9e89 --- /dev/null +++ b/models/sl-nemo/push.sh @@ -0,0 +1,2 @@ +#to upload +docker push robmsmt/sl-nemo-en-16k diff --git a/models/sl-nemo/te_st_endpoint.py
b/models/sl-nemo/te_st_endpoint.py new file mode 100755 index 0000000..e42adb4 --- /dev/null +++ b/models/sl-nemo/te_st_endpoint.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +import json +import base64 +import requests +import pprint as pp + + +def main(endpoint, wav_location): + + b64audio = base64.b64encode(open(wav_location, "rb").read()).decode("utf-8") + print(f"Length of b64 data is:{len(b64audio)}") + + json_message = {"b64_wav": b64audio, "sr": 16000} + + r = requests.post(endpoint, json=json_message) + print(f"Status code: {r.status_code}") + try: + response = r.json() + pp.pprint(response, indent=2) + except: + print("err") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="This file reads in a wav file and prints a CURL best to be piped to a file") + parser.add_argument("--endpoint", default="/transcribe", type=str) + parser.add_argument("--host", default="http://localhost:3500", type=str) + parser.add_argument("--wav", default="../../speechloop/data/simple_test/wavs/109938_zebra_ch0_16k.wav", type=str) + args = parser.parse_args() + url = args.host + args.endpoint + main(url, args.wav) diff --git a/models/sl-nemo/uninstall.sh b/models/sl-nemo/uninstall.sh new file mode 100755 index 0000000..692edc3 --- /dev/null +++ b/models/sl-nemo/uninstall.sh @@ -0,0 +1,8 @@ + +#rm .INSTALLED +IMG=$(cat ./README.md | grep "Docker:" | awk '{print $4}') +echo "Killing: $IMG" +docker rm $(docker stop $(docker ps -a -q --filter ancestor="$IMG" --format="{{.ID}}")) +echo "Deleting: $IMG" +docker image rm "$IMG" +echo "Finished removing: $IMG" diff --git a/setup.py b/setup.py index 6e4fd77..7d11a43 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ def read_file(fname): # python3 -m pip install --upgrade setuptools wheel setup( name="speechloop", - version="0.0.3", + version="0.0.4", author="robmsmt", author_email="robmsmt@gmail.com", description='A "keep it simple" collection of many speech recognition engines... 
Designed to help answer - what is the best ASR?', diff --git a/speechloop/asr.py b/speechloop/asr.py new file mode 100644 index 0000000..d4aee4a --- /dev/null +++ b/speechloop/asr.py @@ -0,0 +1,469 @@ +import abc +import asyncio +import base64 +import datetime +import json +import time +import warnings +import atexit +from os import environ +from time import monotonic +from urllib.parse import urlencode + +import docker +import requests +import websockets + +try: + from amazon_transcribe.client import TranscribeStreamingClient + from amazon_transcribe.handlers import TranscriptResultStreamHandler + from amazon_transcribe.model import TranscriptEvent +except ImportError as e: + print(f"Amazon not imported, for reason:{e}") + +try: + from google.cloud import speech_v1 as speech +except ImportError as e: + print(f"Google not imported, for reason:{e}") + +try: + import azure.cognitiveservices.speech as speechsdk +except ImportError as e: + print(f"Azure not imported, for reason:{e}") + +from speechloop.file_utils import valid_readable_file, disk_in_memory + +try: + DOCKER_CLIENT = docker.from_env() +except Exception as e: + warnings.warn("Either docker is not installed OR the docker client cannot be connected to. 
" "This might be ok if using just APIs") + + +class AsrNotRecognized(Exception): + pass + + +class InvalidConfigPath(Exception): + pass + + +class APIKeyError(Exception): + pass + + +class ASR(metaclass=abc.ABCMeta): + """ + Args: + name: The name of the model + asrtype: The type of the speech rec + """ + + def __init__(self, name, asrtype, sr=16000): + self.init_starttime = datetime.datetime.now() + self.name = name + self.asrtype = asrtype + self.init_finishtime = None + self.total_inf_time = datetime.datetime.now() - datetime.datetime.now() + self.verbose = True + self.kill_containers_on_quit = True + self.sr = sr + + if self.asrtype == "docker-local" and self.kill_containers_on_quit: + atexit.register(self.kill) + + def finish_init(self): + self.init_finishtime = datetime.datetime.now() + + def add_time(self, time_to_add): + self.total_inf_time += time_to_add + + def return_error(self, error_msg=""): + return error_msg + + def kill(self): + kill_container(self.dockerhub_url, verbose=self.verbose) + + @abc.abstractmethod + def execute_with_audio(self, audio): + pass + + def read_audio_file(self, path_to_audio): + if valid_readable_file(path_to_audio): + audio = open(path_to_audio, "rb").read() + return self.execute_with_audio(audio) + + +class Coqui(ASR): + """ + Coqui + """ + + def __init__(self): + super().__init__("cq", "docker-local") + self.uri = "http://localhost:3200/transcribe" + self.dockerhub_url = "robmsmt/sl-coqui-en-16k:latest" + self.shortname = self.dockerhub_url.rsplit("/")[-1].rsplit(":")[0] + self.longname = "coqui" + launch_container(self.dockerhub_url, {"3200/tcp": 3200}, verbose=self.verbose, delay=3) + self.finish_init() + + def execute_with_audio(self, audio): + b64 = base64.b64encode(audio).decode("utf-8") + json_message = {"b64_wav": b64, "sr": 16000} + r = requests.post(self.uri, json=json_message) + if r.status_code == 200: + try: + response = r.json()["transcript"] + return response + except KeyError: + return self.return_error() + 
else: + return self.return_error() + + +class Nemo(ASR): + """ + Nemo + """ + + def __init__(self): + super().__init__("nm", "docker-local") + self.uri = "http://localhost:3500/transcribe" + self.dockerhub_url = "robmsmt/sl-nemo-en-16k:latest" + self.shortname = self.dockerhub_url.rsplit("/")[-1].rsplit(":")[0] + self.longname = "nemo" + launch_container(self.dockerhub_url, {"3500/tcp": 3500}, verbose=self.verbose, delay=30) + self.finish_init() + + def execute_with_audio(self, audio): + b64 = base64.b64encode(audio).decode("utf-8") + json_message = {"b64_wav": b64, "sr": 16000} + r = requests.post(self.uri, json=json_message) + if r.status_code == 200: + try: + response = r.json()["transcript"] + return response + except KeyError: + return self.return_error() + else: + return self.return_error() + + +class Sphinx(ASR): + """ + Sphinx + """ + + def __init__(self): + super().__init__("sp", "docker-local") + self.uri = "http://localhost:3000/transcribe" + self.dockerhub_url = "robmsmt/sl-sphinx-en-16k:latest" + self.shortname = self.dockerhub_url.rsplit("/")[-1].rsplit(":")[0] + self.longname = "sphinx" + launch_container(self.dockerhub_url, {"3000/tcp": 3000}, verbose=self.verbose, delay=2) + self.finish_init() + + def execute_with_audio(self, audio): + b64 = base64.b64encode(audio).decode("utf-8") + json_message = {"b64_wav": b64, "sr": 16000} + r = requests.post(self.uri, json=json_message) + if r.status_code == 200: + try: + response = r.json()["transcript"] + return response + except Exception as e: + warnings.warn(f"Engine did not return transcript: {e}") + return self.return_error() + else: + return self.return_error() + + +class Vosk(ASR): + """ + Vosk + """ + + def __init__(self): + super().__init__("vs", "docker-local") + self.uri = "ws://localhost:2800" + self.dockerhub_url = "robmsmt/sl-vosk-en-16k:latest" + self.shortname = self.dockerhub_url.rsplit("/")[-1].rsplit(":")[0] + self.longname = "vosk" + self.container_found = False +
launch_container(self.dockerhub_url, {"2700/tcp": 2800}, verbose=self.verbose, delay=5) + self.finish_init() + + def execute_with_audio(self, audio): + audio_file = disk_in_memory(audio) + return asyncio.get_event_loop().run_until_complete(self.send_websocket(audio_file)) + + async def send_websocket(self, audio_file): + async with websockets.connect(self.uri) as websocket: + all_finals = "" + all_partials = [] + while True: + partial = None + data = audio_file.read(1024 * 16) + if len(data) == 0: + break + await websocket.send(data) + try: + partial_json = json.loads(await websocket.recv()) + partial = partial_json["partial"] + if partial: + all_partials.append(partial) + except KeyError: + all_finals += partial_json["text"] + " " + + await websocket.send('{"eof" : 1}') + final_result = json.loads(await websocket.recv())["text"] + + if len(all_finals) > 0 and len(final_result) == 0: + return all_finals + elif len(all_finals) > 0 and len(final_result) > 0: + return all_finals + f" {final_result}" + elif len(final_result) == 0: + return all_partials[-1] + else: + return final_result + + +class Azure(ASR): + """ + Sign up to Speech service at: https://portal.azure.com + create project and set one of the 2 keys to be passed in through OS ENV var: AZURE_KEY + """ + + def __init__(self, apikey=None): + + super().__init__("az", "cloud-api") + self.longname = "azure" + self.shortname = "az" + self.key = apikey if apikey is not None else environ.get("AZURE_KEY") # APIKEY param takes priority over ENV + self.location = "eastus" + self.language = "en-US" + self.profanity = "False" + self.start_time = monotonic() + self.credential_url = f"https://{self.location}.api.cognitive.microsoft.com/sts/v1.0/issueToken" + settings = urlencode({"language": self.language, "format": "simple", "profanity": self.profanity}) + self.url = f"https://{self.location}.stt.speech.microsoft.com/speech/recognition/conversation/cognitiveservices/v1?{settings}" + + self.renew_token() + + if 
self.verbose: + print(f"Using {self.longname}") + + def now(self): + return monotonic() + + def renew_token(self): + try: + cred_req = requests.post( + self.credential_url, + data=b"", + headers={ + "Content-type": "application/x-www-form-urlencoded", + "Content-Length": "0", + "Ocp-Apim-Subscription-Key": self.key, + }, + ) + if cred_req.status_code == 200: + self.access_token = cred_req.text + self.azure_cached_access_token = self.access_token + self.start_time = monotonic() + self.azure_cached_access_token_expiry = ( + self.start_time + 600 + ) # according to https://docs.microsoft.com/en-us/azure/cognitive-services/Speech-Service/rest-apis#authentication, the token expires in exactly 10 minutes + else: + raise APIKeyError("Cannot renew token") + except APIKeyError as e: + raise APIKeyError(f"Error renewing token: {e}") + + def execute_with_audio(self, audio): + + if self.now() > self.azure_cached_access_token_expiry: + self.renew_token() + + req = requests.post(self.url, data=audio, headers={"Authorization": f"Bearer {self.access_token}", "Content-type": 'audio/wav; codec="audio/pcm"; samplerate=16000'}) + + if req.status_code == 200: + result = json.loads(req.text) + else: + return self.return_error() + + if "RecognitionStatus" not in result or result["RecognitionStatus"] != "Success" or "DisplayText" not in result: + return self.return_error() + + res = result["DisplayText"].strip() + final_result = res[:-1] if res.endswith(".") else res + return final_result + + +class Aws(ASR): + def __init__(self): + + super().__init__("aw", "cloud-api") + # credentials will be auto retrieved from ~/.aws/credentials however in future should be overriden by param?
+ + self.longname = "aws" + self.shortname = "aw" + self.handler = None + self.stream = None + self.client = None + + if self.verbose: + print(f"Using {self.longname}") + + def execute_with_audio(self, audio): + audio_file = disk_in_memory(audio) + return asyncio.get_event_loop().run_until_complete(self.write_chunks(audio_file)) + + async def write_chunks(self, audio_file): + + self.client = TranscribeStreamingClient(region="us-east-1") + self.stream = await self.client.start_stream_transcription( + language_code="en-US", + media_sample_rate_hz=16000, + media_encoding="pcm", + ) + + while True: + data = audio_file.read(1024 * 16) + if len(data) == 0: + await self.stream.input_stream.end_stream() + break + await self.stream.input_stream.send_audio_event(audio_chunk=data) + + async for event in self.stream.output_stream: + if isinstance(event, TranscriptEvent): + result = await self.handle_transcript_event(event) + else: + print(event) + + # todo this is not a very good implementation but quick first attempt + while True: + # wait until generator has been iterated over + await asyncio.sleep(0.1) + if result[-1].is_partial == False: + break + + # await asyncio.sleep(0.1) + transcript = result[-1].alternatives[-1].transcript + + if transcript.endswith("."): + # aws ends always with a period, let's kill it. + transcript = transcript[:-1] + + return transcript + + async def handle_transcript_event(self, transcript_event: TranscriptEvent): + # This handler can be implemented to handle transcriptions as needed. + # Here's an example to get started. 
+ results = transcript_event.transcript.results + return results + + +class Google(ASR): + def __init__(self, apikey=None): + + super().__init__("gg", "cloud-api") + # Check GOOGLE_APPLICATION_CREDENTIALS + + if apikey and valid_readable_file(apikey, quiet=True): + environ["GOOGLE_APPLICATION_CREDENTIALS"] = apikey + else: + if valid_readable_file("../models/google/google.json", quiet=True) and environ.get("GOOGLE_APPLICATION_CREDENTIALS") is None: + environ["GOOGLE_APPLICATION_CREDENTIALS"] = "../models/google/google.json" + if environ.get("GOOGLE_APPLICATION_CREDENTIALS") is None or not valid_readable_file(environ["GOOGLE_APPLICATION_CREDENTIALS"]): + warnings.warn( + "INVALID CONFIG/PATH, please update env variable to where your GoogleCloud speech conf exists: \n " + "export GOOGLE_APPLICATION_CREDENTIALS=path/to/google.json \n" + ) + raise InvalidConfigPath + + self.client = speech.SpeechClient() + self.longname = "google" + self.shortname = "gg" + self.configpath = environ.get("GOOGLE_APPLICATION_CREDENTIALS") + self.recognition_config = speech.RecognitionConfig( + dict( + encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=self.sr, + language_code="en-US", + ) + ) + if self.verbose: + print(f"Using {self.longname} with config: {self.configpath}") + + def execute_with_audio(self, audio): + + rec_audio = speech.RecognitionAudio(content=audio) + response = self.client.recognize(config=self.recognition_config, audio=rec_audio) + + transcript_list = [] + for result in response.results: + transcript_list.append(result.alternatives[0].transcript) + + if len(transcript_list) == 0: + transcript = self.return_error() + else: + transcript = " ".join(transcript_list) + + return transcript + + +def kill_container(dockerhub_url, verbose=True): + for container in DOCKER_CLIENT.containers.list(): + if len(container.image.tags) > 0 and container.image.tags[-1] == dockerhub_url: + if verbose: + print(f"Docker container: {dockerhub_url} found.
Killing...") + container.stop() + + +def launch_container(dockerhub_url, ports_dict, verbose=True, delay=5): + container_running = False + for container in DOCKER_CLIENT.containers.list(): + if len(container.image.tags) > 0 and container.image.tags[-1] == dockerhub_url: + if verbose: + print(f"Docker container: {dockerhub_url} found running") + container_running = True + + if not container_running: + if verbose: + print(f"Docker container: {dockerhub_url} NOT found... downloading and/or running...") + DOCKER_CLIENT.containers.run( + dockerhub_url, + detach=True, + ports=ports_dict, + restart_policy={"Name": "on-failure", "MaximumRetryCount": 5}, + ) + if verbose: + print(f"{dockerhub_url} Downloaded. Starting container...") + time.sleep(delay) + + +def create_model_objects(wanted_asr: list) -> list: + list_of_asr = [] + + print(wanted_asr) + for asr in wanted_asr: + if asr == "all": + list_of_asr = [Vosk(), Sphinx(), Coqui(), Google(), Aws(), Azure(), Nemo()] + elif asr == "vs": + list_of_asr.append(Vosk()) + elif asr == "sp": + list_of_asr.append(Sphinx()) + elif asr == "cq": + list_of_asr.append(Coqui()) + elif asr == "gg": + list_of_asr.append(Google()) + elif asr == "aw": + list_of_asr.append(Aws()) + elif asr == "az": + list_of_asr.append(Azure()) + elif asr == "nm": + list_of_asr.append(Nemo()) + else: + raise AsrNotRecognized("ASR not recognised") + + return list_of_asr diff --git a/speechloop/asr/nemo.py b/speechloop/asr/nemo.py new file mode 100644 index 0000000..e928759 --- /dev/null +++ b/speechloop/asr/nemo.py @@ -0,0 +1,33 @@ +from speechloop.asr.base_asr import ASR +from speechloop.asr.container_utils import launch_container + +import base64 +import requests + + +class Nemo(ASR): + """ + Nemo + """ + + def __init__(self): + super().__init__("nm", "docker-local") + self.uri = "http://localhost:3500/transcribe" + self.dockerhub_url = "robmsmt/sl-nemo-en-16k:latest" + self.shortname = self.dockerhub_url.rsplit("/")[-1].rsplit(":")[0] + self.longname = 
"nemo" + launch_container(self.dockerhub_url, {"3500/tcp": 3500}, verbose=self.verbose, delay=8) + self.finish_init() + + def execute_with_audio(self, audio): + b64 = base64.b64encode(audio).decode("utf-8") + json_message = {"b64_wav": b64, "sr": 16000} + r = requests.post(self.uri, json=json_message) + if r.status_code == 200: + try: + response = r.json()["transcript"] + return response + except KeyError: + return self.return_error() + else: + return self.return_error() diff --git a/speechloop/asr/registry.py b/speechloop/asr/registry.py index aebab3e..7ecb131 100644 --- a/speechloop/asr/registry.py +++ b/speechloop/asr/registry.py @@ -5,6 +5,7 @@ from speechloop.asr.google import Google from speechloop.asr.aws import Aws from speechloop.asr.azure import Azure +from speechloop.asr.nemo import Nemo def create_model_objects(wanted_asr: list) -> list: @@ -26,6 +27,8 @@ def create_model_objects(wanted_asr: list) -> list: list_of_asr.append(Aws()) elif asr == "az": list_of_asr.append(Azure()) + elif asr == "nm": + list_of_asr.append(Nemo()) else: raise AsrNotRecognized("ASR not recognised") diff --git a/speechloop/wizard.py b/speechloop/wizard.py index 006e1fe..e0455b0 100644 --- a/speechloop/wizard.py +++ b/speechloop/wizard.py @@ -118,8 +118,11 @@ def wizard_main(): "vs - Alphacep Vosk", "sp - CMU Sphinx", "cq - Coqui stt", + "nm - NeMo stt", Separator("---Cloud ASRs---"), "gg - Google Cloud - (requires api key)", # todo maybe ask for this or grey it out if not provided? + "az - Azure - (requires api key)", # todo maybe ask for this or grey it out if not provided? + "aw - AWS - (requires api key)", # todo maybe ask for this or grey it out if not provided? ], validate=lambda a: (True if len(a) > 0 else "You must select at least one ASR"), ).ask()