diff --git a/openai_stt/stt.py b/openai_stt/stt.py
index f760e7913552577c9dc177a2509efe94513fd53c..081c04800c8302a3939966ecd548435248b7e4c3 100644
--- a/openai_stt/stt.py
+++ b/openai_stt/stt.py
@@ -1,10 +1,12 @@
 """
 Support for Whisper API STT.
 """
+
 from typing import AsyncIterable
 import aiohttp
 import logging
 import io
+import wave
 import voluptuous as vol
 from homeassistant.components.stt import (
     AudioBitRates,
@@ -18,17 +20,18 @@ from homeassistant.components.stt import (
     SpeechResultState,
 )
 import homeassistant.helpers.config_validation as cv
-import wave

 _LOGGER = logging.getLogger(__name__)

-CONF_API_KEY = 'api_key'
-CONF_LANG = 'language'
+CONF_API_KEY = "api_key"
+CONF_LANG = "language"

-PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({
-    vol.Required(CONF_API_KEY): cv.string,
-    vol.Optional(CONF_LANG, default='en-US'): cv.string,
-})
+PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
+    {
+        vol.Required(CONF_API_KEY): cv.string,
+        vol.Optional(CONF_LANG, default="en-US"): cv.string,
+    }
+)


 async def async_get_engine(hass, config, discovery_info=None):
@@ -50,12 +53,12 @@ class OpenAISTTProvider(Provider):
     @property
     def default_language(self) -> str:
         """Return the default language."""
-        return self._language.split(',')[0]
+        return self._language.split(",")[0]

     @property
     def supported_languages(self) -> list[str]:
         """Return the list of supported languages."""
-        return self._language.split(',')
+        return self._language.split(",")

     @property
     def supported_formats(self) -> list[AudioFormats]:
@@ -82,8 +85,10 @@ class OpenAISTTProvider(Provider):
         """Return a list of supported channels."""
         return [AudioChannels.CHANNEL_MONO]

-    async def async_process_audio_stream(self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]) -> SpeechResult:
-        data = b''
+    async def async_process_audio_stream(
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
+    ) -> SpeechResult:
+        data = b""
         async for chunk in stream:
             data += chunk

@@ -92,26 +97,37 @@ class OpenAISTTProvider(Provider):

         try:
             byte_io = io.BytesIO()
-            with wave.open(byte_io, 'wb') as wav_file:
+            with wave.open(byte_io, "wb") as wav_file:
                 wav_file.setnchannels(metadata.channel)
                 wav_file.setsampwidth(2)  # 2 bytes per sample
                 wav_file.setframerate(metadata.sample_rate)
                 wav_file.writeframes(data)

             headers = {
-                'Authorization': f'Bearer {self._api_key}',
+                "Authorization": f"Bearer {self._api_key}",
             }

             form = aiohttp.FormData()
-            form.add_field('file', byte_io.getvalue(), filename='audio.wav', content_type='audio/wav')
-            form.add_field('language', metadata.language)
-            form.add_field('model', 'whisper-1')
+            form.add_field(
+                "file",
+                byte_io.getvalue(),
+                filename="audio.wav",
+                content_type="audio/wav",
+            )
+            form.add_field("language", metadata.language)
+            form.add_field("model", "whisper-1")

             async with aiohttp.ClientSession() as session:
-                async with session.post('https://api.openai.com/v1/audio/transcriptions', data=form, headers=headers) as response:
+                async with session.post(
+                    "https://api.openai.com/v1/audio/transcriptions",
+                    data=form,
+                    headers=headers,
+                ) as response:
                     if response.status == 200:
                         json_response = await response.json()
-                        return SpeechResult(json_response["text"], SpeechResultState.SUCCESS)
+                        return SpeechResult(
+                            json_response["text"], SpeechResultState.SUCCESS
+                        )
                     else:
                         text = await response.text()
                         _LOGGER.warning("{}:{}".format(response.status, text))