From ee1f159e4d26dc0fa4cc2fded61821640b223ec3 Mon Sep 17 00:00:00 2001
From: Erik Hedenström <erik@hedenstroem.com>
Date: Thu, 20 Jun 2024 09:46:29 +0000
Subject: [PATCH] Revert file structure: move the STT provider back to stt.py

---
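Note: Home Assistant's legacy STT platform setup loads a module named after
the domain (stt.py) from the integration directory, which is presumably why
the provider code moves back out of __init__.py here. With this layout, a
configuration sketch along the following lines should load the provider; the
secret name is illustrative, and the keys follow PLATFORM_SCHEMA in stt.py:

    # configuration.yaml -- illustrative sketch, not part of this patch
    stt:
      - platform: openai_stt
        api_key: !secret openai_api_key  # hypothetical secret name
        language: en-US  # optional; comma-separated to offer more, e.g. "en-US,sv-SE"
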
 openai_stt/__init__.py | 134 ----------------------------------------
 openai_stt/stt.py      | 137 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 137 insertions(+), 134 deletions(-)
 create mode 100644 openai_stt/stt.py
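For a quick sanity check of the endpoint used below, outside Home Assistant,
here is a minimal standalone sketch of the same request (the sample file name
and the OPENAI_API_KEY environment variable are assumptions, not part of this
patch):

    # check_whisper.py -- illustrative sketch of the request stt.py makes
    import asyncio
    import os

    import aiohttp


    async def main() -> None:
        # Build the same multipart form as OpenAISTTProvider does
        form = aiohttp.FormData()
        with open("sample.wav", "rb") as audio:  # assumed test file
            form.add_field(
                "file", audio.read(), filename="audio.wav", content_type="audio/wav"
            )
        form.add_field("model", "whisper-1")
        headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.openai.com/v1/audio/transcriptions",
                data=form,
                headers=headers,
            ) as response:
                print(response.status, await response.text())


    asyncio.run(main())

A 200 response whose JSON body contains a "text" field indicates the key and
endpoint are working.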

diff --git a/openai_stt/__init__.py b/openai_stt/__init__.py
index a62a8fb..6a10f59 100644
--- a/openai_stt/__init__.py
+++ b/openai_stt/__init__.py
@@ -1,135 +1 @@
 """Custom integration for OpenAI Whisper API STT."""
-
-from typing import AsyncIterable
-import aiohttp
-import logging
-import io
-import wave
-import voluptuous as vol
-from homeassistant.components.stt import (
-    AudioBitRates,
-    AudioChannels,
-    AudioCodecs,
-    AudioFormats,
-    AudioSampleRates,
-    Provider,
-    SpeechMetadata,
-    SpeechResult,
-    SpeechResultState,
-)
-import homeassistant.helpers.config_validation as cv
-
-_LOGGER = logging.getLogger(__name__)
-
-CONF_API_KEY = "api_key"
-CONF_LANG = "language"
-
-PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
-    {
-        vol.Required(CONF_API_KEY): cv.string,
-        vol.Optional(CONF_LANG, default="en-US"): cv.string,
-    }
-)
-
-
-async def async_get_engine(hass, config, discovery_info=None):
-    """Set up Whisper API STT speech component."""
-    api_key = config.get(CONF_API_KEY)
-    lang = config.get(CONF_LANG)
-    return OpenAISTTProvider(hass, api_key, lang)
-
-
-class OpenAISTTProvider(Provider):
-    """The Whisper API STT provider."""
-
-    def __init__(self, hass, api_key, lang):
-        """Initialize Whisper API STT provider."""
-        self.hass = hass
-        self._api_key = api_key
-        self._language = lang
-
-    @property
-    def default_language(self) -> str:
-        """Return the default language."""
-        return self._language.split(",")[0]
-
-    @property
-    def supported_languages(self) -> list[str]:
-        """Return the list of supported languages."""
-        return self._language.split(",")
-
-    @property
-    def supported_formats(self) -> list[AudioFormats]:
-        """Return a list of supported formats."""
-        return [AudioFormats.WAV]
-
-    @property
-    def supported_codecs(self) -> list[AudioCodecs]:
-        """Return a list of supported codecs."""
-        return [AudioCodecs.PCM]
-
-    @property
-    def supported_bit_rates(self) -> list[AudioBitRates]:
-        """Return a list of supported bitrates."""
-        return [AudioBitRates.BITRATE_16]
-
-    @property
-    def supported_sample_rates(self) -> list[AudioSampleRates]:
-        """Return a list of supported samplerates."""
-        return [AudioSampleRates.SAMPLERATE_16000]
-
-    @property
-    def supported_channels(self) -> list[AudioChannels]:
-        """Return a list of supported channels."""
-        return [AudioChannels.CHANNEL_MONO]
-
-    async def async_process_audio_stream(
-        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
-    ) -> SpeechResult:
-        data = b""
-        async for chunk in stream:
-            data += chunk
-
-        if not data:
-            return SpeechResult("", SpeechResultState.ERROR)
-
-        try:
-            byte_io = io.BytesIO()
-            with wave.open(byte_io, "wb") as wav_file:
-                wav_file.setnchannels(metadata.channel)
-                wav_file.setsampwidth(2)  # 2 bytes per sample
-                wav_file.setframerate(metadata.sample_rate)
-                wav_file.writeframes(data)
-
-            headers = {
-                "Authorization": f"Bearer {self._api_key}",
-            }
-
-            form = aiohttp.FormData()
-            form.add_field(
-                "file",
-                byte_io.getvalue(),
-                filename="audio.wav",
-                content_type="audio/wav",
-            )
-            form.add_field("language", metadata.language)
-            form.add_field("model", "whisper-1")
-
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    "https://api.openai.com/v1/audio/transcriptions",
-                    data=form,
-                    headers=headers,
-                ) as response:
-                    if response.status == 200:
-                        json_response = await response.json()
-                        return SpeechResult(
-                            json_response["text"], SpeechResultState.SUCCESS
-                        )
-                    else:
-                        text = await response.text()
-                        _LOGGER.warning("{}:{}".format(response.status, text))
-                        return SpeechResult("", SpeechResultState.ERROR)
-        except Exception as e:
-            _LOGGER.warning(e)
-            return SpeechResult("", SpeechResultState.ERROR)
diff --git a/openai_stt/stt.py b/openai_stt/stt.py
new file mode 100644
index 0000000..081c048
--- /dev/null
+++ b/openai_stt/stt.py
@@ -0,0 +1,137 @@
+"""
+Support for Whisper API STT.
+"""
+
+import io
+import logging
+import wave
+from typing import AsyncIterable
+import aiohttp
+import voluptuous as vol
+from homeassistant.components.stt import (
+    AudioBitRates,
+    AudioChannels,
+    AudioCodecs,
+    AudioFormats,
+    AudioSampleRates,
+    Provider,
+    SpeechMetadata,
+    SpeechResult,
+    SpeechResultState,
+)
+import homeassistant.helpers.config_validation as cv
+
+_LOGGER = logging.getLogger(__name__)
+
+CONF_API_KEY = "api_key"
+CONF_LANG = "language"
+
+PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
+    {
+        vol.Required(CONF_API_KEY): cv.string,
+        vol.Optional(CONF_LANG, default="en-US"): cv.string,
+    }
+)
+
+
+async def async_get_engine(hass, config, discovery_info=None):
+    """Set up Whisper API STT speech component."""
+    api_key = config.get(CONF_API_KEY)
+    lang = config.get(CONF_LANG)
+    return OpenAISTTProvider(hass, api_key, lang)
+
+
+class OpenAISTTProvider(Provider):
+    """The Whisper API STT provider."""
+
+    def __init__(self, hass, api_key, lang):
+        """Initialize Whisper API STT provider."""
+        self.hass = hass
+        self._api_key = api_key
+        self._language = lang
+
+    @property
+    def default_language(self) -> str:
+        """Return the default language."""
+        return self._language.split(",")[0]
+
+    @property
+    def supported_languages(self) -> list[str]:
+        """Return the list of supported languages."""
+        return self._language.split(",")
+
+    @property
+    def supported_formats(self) -> list[AudioFormats]:
+        """Return a list of supported formats."""
+        return [AudioFormats.WAV]
+
+    @property
+    def supported_codecs(self) -> list[AudioCodecs]:
+        """Return a list of supported codecs."""
+        return [AudioCodecs.PCM]
+
+    @property
+    def supported_bit_rates(self) -> list[AudioBitRates]:
+        """Return a list of supported bitrates."""
+        return [AudioBitRates.BITRATE_16]
+
+    @property
+    def supported_sample_rates(self) -> list[AudioSampleRates]:
+        """Return a list of supported samplerates."""
+        return [AudioSampleRates.SAMPLERATE_16000]
+
+    @property
+    def supported_channels(self) -> list[AudioChannels]:
+        """Return a list of supported channels."""
+        return [AudioChannels.CHANNEL_MONO]
+
+    async def async_process_audio_stream(
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
+    ) -> SpeechResult:
+        data = b""
+        async for chunk in stream:
+            data += chunk
+
+        if not data:
+            return SpeechResult("", SpeechResultState.ERROR)
+
+        try:
+            byte_io = io.BytesIO()
+            with wave.open(byte_io, "wb") as wav_file:
+                wav_file.setnchannels(metadata.channel)
+                wav_file.setsampwidth(2)  # 16-bit samples = 2 bytes, per BITRATE_16
+                wav_file.setframerate(metadata.sample_rate)
+                wav_file.writeframes(data)
+
+            headers = {
+                "Authorization": f"Bearer {self._api_key}",
+            }
+
+            form = aiohttp.FormData()
+            form.add_field(
+                "file",
+                byte_io.getvalue(),
+                filename="audio.wav",
+                content_type="audio/wav",
+            )
+            form.add_field("language", metadata.language)
+            form.add_field("model", "whisper-1")
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    "https://api.openai.com/v1/audio/transcriptions",
+                    data=form,
+                    headers=headers,
+                ) as response:
+                    if response.status == 200:
+                        json_response = await response.json()
+                        return SpeechResult(
+                            json_response["text"], SpeechResultState.SUCCESS
+                        )
+                    else:
+                        text = await response.text()
+                        _LOGGER.warning("{}:{}".format(response.status, text))
+                        return SpeechResult("", SpeechResultState.ERROR)
+        except Exception:
+            _LOGGER.exception("Unexpected error while transcribing audio")
+            return SpeechResult("", SpeechResultState.ERROR)
-- 
GitLab