From 5d83d0cf08ef2976528010b8ffc80cf80d639135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Hedenstr=C3=B6m?= <erik@hedenstroem.com> Date: Thu, 20 Jun 2024 09:18:34 +0000 Subject: [PATCH] Added OpenAI patch component --- .gitlab-ci.yml | 9 ++- README.md | 2 +- openai_patch/__init__.py | 26 +++++++ openai_patch/manifest.json | 16 +++++ openai_stt/__init__.py | 134 ++++++++++++++++++++++++++++++++++++ openai_stt/stt.py | 137 ------------------------------------- requirements.txt | 7 ++ 7 files changed, 190 insertions(+), 141 deletions(-) create mode 100644 openai_patch/__init__.py create mode 100644 openai_patch/manifest.json delete mode 100644 openai_stt/stt.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 56c2523..3259fb7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,7 +1,10 @@ deploy: script: - - export PATH=$PATH:/opt/host/bin:/opt/host/local/bin - - rm -rf /opt/host/homeassistant/custom_components/openai_stt - - cp -Rfv openai_stt /opt/host/homeassistant/custom_components + - | + export PATH=$PATH:/opt/host/bin:/opt/host/local/bin + for name in openai_patch openai_stt; do + rm -rf /opt/host/homeassistant/custom_components/$name + cp -Rfv $name /opt/host/homeassistant/custom_components + done tags: - shell diff --git a/README.md b/README.md index 917573f..1e746fb 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Run the following commands to create a virtual environment to install homeassist python3 -m venv .venv source .venv/bin/activate python3 -m pip install wheel -pip3 install homeassistant +pip3 install homeassistant openai voluptuous-openapi pip3 freeze > requirements.txt ``` diff --git a/openai_patch/__init__.py b/openai_patch/__init__.py new file mode 100644 index 0000000..d88ddcf --- /dev/null +++ b/openai_patch/__init__.py @@ -0,0 +1,26 @@ +import re +import json +import logging + +from homeassistant.helpers import intent +from homeassistant.exceptions import ServiceNotFound +from homeassistant.components import conversation +from 
homeassistant.components.openai_conversation import OpenAIConversationEntity + +_LOGGER = logging.getLogger(__name__) + +async def async_setup(hass, config): + + original = OpenAIConversationEntity.async_process + + async def async_process(self, user_input: conversation.ConversationInput) -> conversation.ConversationResult: + _LOGGER.debug("OpenAIConversationEntity.async_process") + client = self.entry.runtime_data + _LOGGER.debug(client.base_url) + result = await original(self, user_input) + return result + + OpenAIConversationEntity.async_process = async_process + _LOGGER.info("Patched OpenAIConversationEntity.async_process") + + return True diff --git a/openai_patch/manifest.json b/openai_patch/manifest.json new file mode 100644 index 0000000..5d00e56 --- /dev/null +++ b/openai_patch/manifest.json @@ -0,0 +1,16 @@ +{ + "domain": "openai_patch", + "name": "OpenAI Patch", + "codeowners": [ + "@erik" + ], + "dependencies": [ + "conversation" + ], + "documentation": "https://gitlab.hedenstroem.com/home-assistant/custom-components", + "issue_tracker": "https://gitlab.hedenstroem.com/home-assistant/custom-components/-/issues", + "requirements": [ + "aiohttp>=3.7.4" + ], + "version": "0.1.0" +} diff --git a/openai_stt/__init__.py b/openai_stt/__init__.py index 6a10f59..a62a8fb 100644 --- a/openai_stt/__init__.py +++ b/openai_stt/__init__.py @@ -1 +1,135 @@ """Custom integration for OpenAI Whisper API STT.""" + +from typing import AsyncIterable +import aiohttp +import logging +import io +import wave +import voluptuous as vol +from homeassistant.components.stt import ( + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + Provider, + SpeechMetadata, + SpeechResult, + SpeechResultState, +) +import homeassistant.helpers.config_validation as cv + +_LOGGER = logging.getLogger(__name__) + +CONF_API_KEY = "api_key" +CONF_LANG = "language" + +PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( + { + vol.Required(CONF_API_KEY): cv.string, + 
vol.Optional(CONF_LANG, default="en-US"): cv.string, + } +) + + +async def async_get_engine(hass, config, discovery_info=None): + """Set up Whisper API STT speech component.""" + api_key = config.get(CONF_API_KEY) + lang = config.get(CONF_LANG) + return OpenAISTTProvider(hass, api_key, lang) + + +class OpenAISTTProvider(Provider): + """The Whisper API STT provider.""" + + def __init__(self, hass, api_key, lang): + """Initialize Whisper API STT provider.""" + self.hass = hass + self._api_key = api_key + self._language = lang + + @property + def default_language(self) -> str: + """Return the default language.""" + return self._language.split(",")[0] + + @property + def supported_languages(self) -> list[str]: + """Return the list of supported languages.""" + return self._language.split(",") + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bitrates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported samplerates.""" + return [AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + data = b"" + async for chunk in stream: + data += chunk + + if not data: + return SpeechResult("", SpeechResultState.ERROR) + + try: + byte_io = io.BytesIO() + with wave.open(byte_io, "wb") as wav_file: + wav_file.setnchannels(metadata.channel) + wav_file.setsampwidth(2) # 2 bytes per sample + 
wav_file.setframerate(metadata.sample_rate) + wav_file.writeframes(data) + + headers = { + "Authorization": f"Bearer {self._api_key}", + } + + form = aiohttp.FormData() + form.add_field( + "file", + byte_io.getvalue(), + filename="audio.wav", + content_type="audio/wav", + ) + form.add_field("language", metadata.language) + form.add_field("model", "whisper-1") + + async with aiohttp.ClientSession() as session: + async with session.post( + "https://api.openai.com/v1/audio/transcriptions", + data=form, + headers=headers, + ) as response: + if response.status == 200: + json_response = await response.json() + return SpeechResult( + json_response["text"], SpeechResultState.SUCCESS + ) + else: + text = await response.text() + _LOGGER.warning("{}:{}".format(response.status, text)) + return SpeechResult("", SpeechResultState.ERROR) + except Exception as e: + _LOGGER.warning(e) + return SpeechResult("", SpeechResultState.ERROR) diff --git a/openai_stt/stt.py b/openai_stt/stt.py deleted file mode 100644 index 081c048..0000000 --- a/openai_stt/stt.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -Support for Whisper API STT. 
-""" - -from typing import AsyncIterable -import aiohttp -import logging -import io -import wave -import voluptuous as vol -from homeassistant.components.stt import ( - AudioBitRates, - AudioChannels, - AudioCodecs, - AudioFormats, - AudioSampleRates, - Provider, - SpeechMetadata, - SpeechResult, - SpeechResultState, -) -import homeassistant.helpers.config_validation as cv - -_LOGGER = logging.getLogger(__name__) - -CONF_API_KEY = "api_key" -CONF_LANG = "language" - -PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_API_KEY): cv.string, - vol.Optional(CONF_LANG, default="en-US"): cv.string, - } -) - - -async def async_get_engine(hass, config, discovery_info=None): - """Set up Whisper API STT speech component.""" - api_key = config.get(CONF_API_KEY) - lang = config.get(CONF_LANG) - return OpenAISTTProvider(hass, api_key, lang) - - -class OpenAISTTProvider(Provider): - """The Whisper API STT provider.""" - - def __init__(self, hass, api_key, lang): - """Initialize Whisper API STT provider.""" - self.hass = hass - self._api_key = api_key - self._language = lang - - @property - def default_language(self) -> str: - """Return the default language.""" - return self._language.split(",")[0] - - @property - def supported_languages(self) -> list[str]: - """Return the list of supported languages.""" - return self._language.split(",") - - @property - def supported_formats(self) -> list[AudioFormats]: - """Return a list of supported formats.""" - return [AudioFormats.WAV] - - @property - def supported_codecs(self) -> list[AudioCodecs]: - """Return a list of supported codecs.""" - return [AudioCodecs.PCM] - - @property - def supported_bit_rates(self) -> list[AudioBitRates]: - """Return a list of supported bitrates.""" - return [AudioBitRates.BITRATE_16] - - @property - def supported_sample_rates(self) -> list[AudioSampleRates]: - """Return a list of supported samplerates.""" - return [AudioSampleRates.SAMPLERATE_16000] - - @property - def 
supported_channels(self) -> list[AudioChannels]: - """Return a list of supported channels.""" - return [AudioChannels.CHANNEL_MONO] - - async def async_process_audio_stream( - self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] - ) -> SpeechResult: - data = b"" - async for chunk in stream: - data += chunk - - if not data: - return SpeechResult("", SpeechResultState.ERROR) - - try: - byte_io = io.BytesIO() - with wave.open(byte_io, "wb") as wav_file: - wav_file.setnchannels(metadata.channel) - wav_file.setsampwidth(2) # 2 bytes per sample - wav_file.setframerate(metadata.sample_rate) - wav_file.writeframes(data) - - headers = { - "Authorization": f"Bearer {self._api_key}", - } - - form = aiohttp.FormData() - form.add_field( - "file", - byte_io.getvalue(), - filename="audio.wav", - content_type="audio/wav", - ) - form.add_field("language", metadata.language) - form.add_field("model", "whisper-1") - - async with aiohttp.ClientSession() as session: - async with session.post( - "https://api.openai.com/v1/audio/transcriptions", - data=form, - headers=headers, - ) as response: - if response.status == 200: - json_response = await response.json() - return SpeechResult( - json_response["text"], SpeechResultState.SUCCESS - ) - else: - text = await response.text() - _LOGGER.warning("{}:{}".format(response.status, text)) - return SpeechResult("", SpeechResultState.ERROR) - except Exception as e: - _LOGGER.warning(e) - return SpeechResult("", SpeechResultState.ERROR) diff --git a/requirements.txt b/requirements.txt index d2ea61b..a271e84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ aiohttp-fast-zlib==0.1.0 aiooui==0.1.5 aiosignal==1.3.1 aiozoneinfo==0.1.0 +annotated-types==0.7.0 anyio==4.4.0 astral==2.2 async-interrupt==1.1.1 @@ -29,6 +30,7 @@ charset-normalizer==3.3.2 ciso8601==2.3.1 cryptography==42.0.5 dbus-fast==2.21.3 +distro==1.9.0 envs==1.4 fnv-hash-fast==0.5.0 fnvhash==0.1.0 @@ -49,6 +51,7 @@ josepy==1.14.0 lru-dict==1.3.0 
MarkupSafe==2.1.5 multidict==6.0.5 +openai==1.35.1 orjson==3.9.15 packaging==24.1 pillow==10.3.0 @@ -57,6 +60,8 @@ psutil-home-assistant==0.0.1 pycares==4.4.0 pycognito==2024.5.1 pycparser==2.22 +pydantic==2.7.4 +pydantic_core==2.18.4 PyJWT==2.8.0 pyOpenSSL==24.1.0 pyRFC3339==1.1 @@ -73,6 +78,7 @@ sniffio==1.3.1 snitun==0.39.1 SQLAlchemy==2.0.30 text-unidecode==1.3 +tqdm==4.66.4 typing_extensions==4.12.2 tzdata==2024.1 uart-devices==0.1.0 @@ -80,6 +86,7 @@ ulid-transform==0.9.0 urllib3==1.26.19 usb-devices==0.4.5 voluptuous==0.13.1 +voluptuous-openapi==0.0.4 voluptuous-serialize==2.6.0 wheel==0.43.0 yarl==1.9.4 -- GitLab