diff --git a/openai_stt/__init__.py b/openai_stt/__init__.py
index a62a8fba111cce53d8cb56873bb01a36b32fadec..6a10f59502b8b51974249aadcb04272422709c2b 100644
--- a/openai_stt/__init__.py
+++ b/openai_stt/__init__.py
@@ -1,135 +1 @@
 """Custom integration for OpenAI Whisper API STT."""
-
-from typing import AsyncIterable
-import aiohttp
-import logging
-import io
-import wave
-import voluptuous as vol
-from homeassistant.components.stt import (
-    AudioBitRates,
-    AudioChannels,
-    AudioCodecs,
-    AudioFormats,
-    AudioSampleRates,
-    Provider,
-    SpeechMetadata,
-    SpeechResult,
-    SpeechResultState,
-)
-import homeassistant.helpers.config_validation as cv
-
-_LOGGER = logging.getLogger(__name__)
-
-CONF_API_KEY = "api_key"
-CONF_LANG = "language"
-
-PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
-    {
-        vol.Required(CONF_API_KEY): cv.string,
-        vol.Optional(CONF_LANG, default="en-US"): cv.string,
-    }
-)
-
-
-async def async_get_engine(hass, config, discovery_info=None):
-    """Set up Whisper API STT speech component."""
-    api_key = config.get(CONF_API_KEY)
-    lang = config.get(CONF_LANG)
-    return OpenAISTTProvider(hass, api_key, lang)
-
-
-class OpenAISTTProvider(Provider):
-    """The Whisper API STT provider."""
-
-    def __init__(self, hass, api_key, lang):
-        """Initialize Whisper API STT provider."""
-        self.hass = hass
-        self._api_key = api_key
-        self._language = lang
-
-    @property
-    def default_language(self) -> str:
-        """Return the default language."""
-        return self._language.split(",")[0]
-
-    @property
-    def supported_languages(self) -> list[str]:
-        """Return the list of supported languages."""
-        return self._language.split(",")
-
-    @property
-    def supported_formats(self) -> list[AudioFormats]:
-        """Return a list of supported formats."""
-        return [AudioFormats.WAV]
-
-    @property
-    def supported_codecs(self) -> list[AudioCodecs]:
-        """Return a list of supported codecs."""
-        return [AudioCodecs.PCM]
-
-    @property
-    def supported_bit_rates(self) -> list[AudioBitRates]:
-        """Return a list of supported bitrates."""
-        return [AudioBitRates.BITRATE_16]
-
-    @property
-    def supported_sample_rates(self) -> list[AudioSampleRates]:
-        """Return a list of supported samplerates."""
-        return [AudioSampleRates.SAMPLERATE_16000]
-
-    @property
-    def supported_channels(self) -> list[AudioChannels]:
-        """Return a list of supported channels."""
-        return [AudioChannels.CHANNEL_MONO]
-
-    async def async_process_audio_stream(
-        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
-    ) -> SpeechResult:
-        data = b""
-        async for chunk in stream:
-            data += chunk
-
-        if not data:
-            return SpeechResult("", SpeechResultState.ERROR)
-
-        try:
-            byte_io = io.BytesIO()
-            with wave.open(byte_io, "wb") as wav_file:
-                wav_file.setnchannels(metadata.channel)
-                wav_file.setsampwidth(2)  # 2 bytes per sample
-                wav_file.setframerate(metadata.sample_rate)
-                wav_file.writeframes(data)
-
-            headers = {
-                "Authorization": f"Bearer {self._api_key}",
-            }
-
-            form = aiohttp.FormData()
-            form.add_field(
-                "file",
-                byte_io.getvalue(),
-                filename="audio.wav",
-                content_type="audio/wav",
-            )
-            form.add_field("language", metadata.language)
-            form.add_field("model", "whisper-1")
-
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    "https://api.openai.com/v1/audio/transcriptions",
-                    data=form,
-                    headers=headers,
-                ) as response:
-                    if response.status == 200:
-                        json_response = await response.json()
-                        return SpeechResult(
-                            json_response["text"], SpeechResultState.SUCCESS
-                        )
-                    else:
-                        text = await response.text()
-                        _LOGGER.warning("{}:{}".format(response.status, text))
-                        return SpeechResult("", SpeechResultState.ERROR)
-        except Exception as e:
-            _LOGGER.warning(e)
-            return SpeechResult("", SpeechResultState.ERROR)
diff --git a/openai_stt/stt.py b/openai_stt/stt.py
new file mode 100644
index 0000000000000000000000000000000000000000..081c04800c8302a3939966ecd548435248b7e4c3
--- /dev/null
+++ b/openai_stt/stt.py
@@ -0,0 +1,137 @@
+"""
+Support for Whisper API STT.
+"""
+
+from typing import AsyncIterable
+import aiohttp
+import logging
+import io
+import wave
+import voluptuous as vol
+from homeassistant.components.stt import (
+    AudioBitRates,
+    AudioChannels,
+    AudioCodecs,
+    AudioFormats,
+    AudioSampleRates,
+    Provider,
+    SpeechMetadata,
+    SpeechResult,
+    SpeechResultState,
+)
+import homeassistant.helpers.config_validation as cv
+
+_LOGGER = logging.getLogger(__name__)
+
+CONF_API_KEY = "api_key"
+CONF_LANG = "language"
+
+PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
+    {
+        vol.Required(CONF_API_KEY): cv.string,
+        vol.Optional(CONF_LANG, default="en-US"): cv.string,
+    }
+)
+
+
+async def async_get_engine(hass, config, discovery_info=None):
+    """Set up Whisper API STT speech component."""
+    api_key = config.get(CONF_API_KEY)
+    lang = config.get(CONF_LANG)
+    return OpenAISTTProvider(hass, api_key, lang)
+
+
+class OpenAISTTProvider(Provider):
+    """The Whisper API STT provider."""
+
+    def __init__(self, hass, api_key, lang):
+        """Initialize Whisper API STT provider."""
+        self.hass = hass
+        self._api_key = api_key
+        self._language = lang
+
+    @property
+    def default_language(self) -> str:
+        """Return the default language."""
+        return self._language.split(",")[0]
+
+    @property
+    def supported_languages(self) -> list[str]:
+        """Return the list of supported languages."""
+        return self._language.split(",")
+
+    @property
+    def supported_formats(self) -> list[AudioFormats]:
+        """Return a list of supported formats."""
+        return [AudioFormats.WAV]
+
+    @property
+    def supported_codecs(self) -> list[AudioCodecs]:
+        """Return a list of supported codecs."""
+        return [AudioCodecs.PCM]
+
+    @property
+    def supported_bit_rates(self) -> list[AudioBitRates]:
+        """Return a list of supported bitrates."""
+        return [AudioBitRates.BITRATE_16]
+
+    @property
+    def supported_sample_rates(self) -> list[AudioSampleRates]:
+        """Return a list of supported samplerates."""
+        return [AudioSampleRates.SAMPLERATE_16000]
+
+    @property
+    def supported_channels(self) -> list[AudioChannels]:
+        """Return a list of supported channels."""
+        return [AudioChannels.CHANNEL_MONO]
+
+    async def async_process_audio_stream(
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
+    ) -> SpeechResult:
+        data = b""
+        async for chunk in stream:
+            data += chunk
+
+        if not data:
+            return SpeechResult("", SpeechResultState.ERROR)
+
+        try:
+            byte_io = io.BytesIO()
+            with wave.open(byte_io, "wb") as wav_file:
+                wav_file.setnchannels(metadata.channel)
+                wav_file.setsampwidth(2)  # 2 bytes per sample
+                wav_file.setframerate(metadata.sample_rate)
+                wav_file.writeframes(data)
+
+            headers = {
+                "Authorization": f"Bearer {self._api_key}",
+            }
+
+            form = aiohttp.FormData()
+            form.add_field(
+                "file",
+                byte_io.getvalue(),
+                filename="audio.wav",
+                content_type="audio/wav",
+            )
+            form.add_field("language", metadata.language)
+            form.add_field("model", "whisper-1")
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    "https://api.openai.com/v1/audio/transcriptions",
+                    data=form,
+                    headers=headers,
+                ) as response:
+                    if response.status == 200:
+                        json_response = await response.json()
+                        return SpeechResult(
+                            json_response["text"], SpeechResultState.SUCCESS
+                        )
+                    else:
+                        text = await response.text()
+                        _LOGGER.warning("{}:{}".format(response.status, text))
+                        return SpeechResult("", SpeechResultState.ERROR)
+        except Exception as e:
+            _LOGGER.warning(e)
+            return SpeechResult("", SpeechResultState.ERROR)
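For reference, the network call that async_process_audio_stream issues can be exercised outside Home Assistant. The standalone sketch below mirrors the same multipart request against the OpenAI transcription endpoint; it is only an illustration, not part of the diff above, and it assumes a local 16 kHz, 16-bit mono audio.wav file and an API key in the OPENAI_API_KEY environment variable (both placeholders).

# Standalone sketch (not part of the integration): post a WAV file to the
# OpenAI transcription endpoint the same way the provider does.
# Assumes OPENAI_API_KEY is set and audio.wav is 16 kHz, 16-bit mono PCM.
import asyncio
import os

import aiohttp

API_URL = "https://api.openai.com/v1/audio/transcriptions"


async def transcribe(path: str) -> str:
    headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}

    # Build the same multipart form the provider builds: file, model, language.
    form = aiohttp.FormData()
    with open(path, "rb") as wav:
        form.add_field(
            "file", wav.read(), filename="audio.wav", content_type="audio/wav"
        )
    form.add_field("model", "whisper-1")
    form.add_field("language", "en")  # optional hint; the provider passes metadata.language

    async with aiohttp.ClientSession() as session:
        async with session.post(API_URL, data=form, headers=headers) as response:
            response.raise_for_status()
            return (await response.json())["text"]


if __name__ == "__main__":
    print(asyncio.run(transcribe("audio.wav")))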