From 5d83d0cf08ef2976528010b8ffc80cf80d639135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Hedenstr=C3=B6m?= <erik@hedenstroem.com> Date: Thu, 20 Jun 2024 09:18:34 +0000 Subject: [PATCH] Added OpenAI patch component --- .gitlab-ci.yml | 9 ++- README.md | 2 +- openai_patch/__init__.py | 26 +++++++ openai_patch/manifest.json | 16 +++++ openai_stt/__init__.py | 134 ++++++++++++++++++++++++++++++++++++ openai_stt/stt.py | 137 ------------------------------------- requirements.txt | 7 ++ 7 files changed, 190 insertions(+), 141 deletions(-) create mode 100644 openai_patch/__init__.py create mode 100644 openai_patch/manifest.json delete mode 100644 openai_stt/stt.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 56c2523..3259fb7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,7 +1,10 @@ deploy: script: - - export PATH=$PATH:/opt/host/bin:/opt/host/local/bin - - rm -rf /opt/host/homeassistant/custom_components/openai_stt - - cp -Rfv openai_stt /opt/host/homeassistant/custom_components + - | + export PATH=$PATH:/opt/host/bin:/opt/host/local/bin + for name in openai_patch openai_stt; do + rm -rf /opt/host/homeassistant/custom_components/$name + cp -Rfv $name /opt/host/homeassistant/custom_components + done tags: - shell diff --git a/README.md b/README.md index 917573f..1e746fb 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Run the following commands to create a virtual environment to install homeassist python3 -m venv .venv source .venv/bin/activate python3 -m pip install wheel -pip3 install homeassistant +pip3 install homeassistant openai voluptuous-openapi pip3 freeze > requirements.txt ``` diff --git a/openai_patch/__init__.py b/openai_patch/__init__.py new file mode 100644 index 0000000..d88ddcf --- /dev/null +++ b/openai_patch/__init__.py @@ -0,0 +1,26 @@ +import re +import json +import logging + +from homeassistant.helpers import intent +from homeassistant.exceptions import ServiceNotFound +from homeassistant.components import conversation +from 
homeassistant.components.openai_conversation import OpenAIConversationEntity + +_LOGGER = logging.getLogger(__name__) + +async def async_setup(hass, config): + + original = OpenAIConversationEntity.async_process + + async def async_process(self, user_input: conversation.ConversationInput) -> conversation.ConversationResult: + _LOGGER.debug("OpenAIConversationEntity.async_process") + client = self.entry.runtime_data + _LOGGER.debug(client.base_url) + result = await original(self, user_input) + return result + + OpenAIConversationEntity.async_process = async_process + _LOGGER.info("Patched OpenAIConversationEntity.async_process") + + return True diff --git a/openai_patch/manifest.json b/openai_patch/manifest.json new file mode 100644 index 0000000..5d00e56 --- /dev/null +++ b/openai_patch/manifest.json @@ -0,0 +1,16 @@ +{ + "domain": "openai_patch", + "name": "OpenAI Patch", + "codeowners": [ + "@erik" + ], + "dependencies": [ + "conversation" + ], + "documentation": "https://gitlab.hedenstroem.com/home-assistant/custom-components", + "issue_tracker": "https://gitlab.hedenstroem.com/home-assistant/custom-components/-/issues", + "requirements": [ + "aiohttp>=3.7.4" + ], + "version": "0.1.0" +} diff --git a/openai_stt/__init__.py b/openai_stt/__init__.py index 6a10f59..a62a8fb 100644 --- a/openai_stt/__init__.py +++ b/openai_stt/__init__.py @@ -1 +1,135 @@ """Custom integration for OpenAI Whisper API STT.""" + +from typing import AsyncIterable +import aiohttp +import logging +import io +import wave +import voluptuous as vol +from homeassistant.components.stt import ( + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + Provider, + SpeechMetadata, + SpeechResult, + SpeechResultState, +) +import homeassistant.helpers.config_validation as cv + +_LOGGER = logging.getLogger(__name__) + +CONF_API_KEY = "api_key" +CONF_LANG = "language" + +PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( + { + vol.Required(CONF_API_KEY): cv.string, + 
vol.Optional(CONF_LANG, default="en-US"): cv.string, + } +) + + +async def async_get_engine(hass, config, discovery_info=None): + """Set up Whisper API STT speech component.""" + api_key = config.get(CONF_API_KEY) + lang = config.get(CONF_LANG) + return OpenAISTTProvider(hass, api_key, lang) + + +class OpenAISTTProvider(Provider): + """The Whisper API STT provider.""" + + def __init__(self, hass, api_key, lang): + """Initialize Whisper API STT provider.""" + self.hass = hass + self._api_key = api_key + self._language = lang + + @property + def default_language(self) -> str: + """Return the default language.""" + return self._language.split(",")[0] + + @property + def supported_languages(self) -> list[str]: + """Return the list of supported languages.""" + return self._language.split(",") + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bitrates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported samplerates.""" + return [AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + data = b"" + async for chunk in stream: + data += chunk + + if not data: + return SpeechResult("", SpeechResultState.ERROR) + + try: + byte_io = io.BytesIO() + with wave.open(byte_io, "wb") as wav_file: + wav_file.setnchannels(metadata.channel) + wav_file.setsampwidth(2) # 2 bytes per sample + 
wav_file.setframerate(metadata.sample_rate) + wav_file.writeframes(data) + + headers = { + "Authorization": f"Bearer {self._api_key}", + } + + form = aiohttp.FormData() + form.add_field( + "file", + byte_io.getvalue(), + filename="audio.wav", + content_type="audio/wav", + ) + form.add_field("language", metadata.language) + form.add_field("model", "whisper-1") + + async with aiohttp.ClientSession() as session: + async with session.post( + "https://api.openai.com/v1/audio/transcriptions", + data=form, + headers=headers, + ) as response: + if response.status == 200: + json_response = await response.json() + return SpeechResult( + json_response["text"], SpeechResultState.SUCCESS + ) + else: + text = await response.text() + _LOGGER.warning("{}:{}".format(response.status, text)) + return SpeechResult("", SpeechResultState.ERROR) + except Exception as e: + _LOGGER.warning(e) + return SpeechResult("", SpeechResultState.ERROR) diff --git a/openai_stt/stt.py b/openai_stt/stt.py deleted file mode 100644 index 081c048..0000000 --- a/openai_stt/stt.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -Support for Whisper API STT. 
-""" - -from typing import AsyncIterable -import aiohttp -import logging -import io -import wave -import voluptuous as vol -from homeassistant.components.stt import ( - AudioBitRates, - AudioChannels, - AudioCodecs, - AudioFormats, - AudioSampleRates, - Provider, - SpeechMetadata, - SpeechResult, - SpeechResultState, -) -import homeassistant.helpers.config_validation as cv - -_LOGGER = logging.getLogger(__name__) - -CONF_API_KEY = "api_key" -CONF_LANG = "language" - -PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_API_KEY): cv.string, - vol.Optional(CONF_LANG, default="en-US"): cv.string, - } -) - - -async def async_get_engine(hass, config, discovery_info=None): - """Set up Whisper API STT speech component.""" - api_key = config.get(CONF_API_KEY) - lang = config.get(CONF_LANG) - return OpenAISTTProvider(hass, api_key, lang) - - -class OpenAISTTProvider(Provider): - """The Whisper API STT provider.""" - - def __init__(self, hass, api_key, lang): - """Initialize Whisper API STT provider.""" - self.hass = hass - self._api_key = api_key - self._language = lang - - @property - def default_language(self) -> str: - """Return the default language.""" - return self._language.split(",")[0] - - @property - def supported_languages(self) -> list[str]: - """Return the list of supported languages.""" - return self._language.split(",") - - @property - def supported_formats(self) -> list[AudioFormats]: - """Return a list of supported formats.""" - return [AudioFormats.WAV] - - @property - def supported_codecs(self) -> list[AudioCodecs]: - """Return a list of supported codecs.""" - return [AudioCodecs.PCM] - - @property - def supported_bit_rates(self) -> list[AudioBitRates]: - """Return a list of supported bitrates.""" - return [AudioBitRates.BITRATE_16] - - @property - def supported_sample_rates(self) -> list[AudioSampleRates]: - """Return a list of supported samplerates.""" - return [AudioSampleRates.SAMPLERATE_16000] - - @property - def 
supported_channels(self) -> list[AudioChannels]: - """Return a list of supported channels.""" - return [AudioChannels.CHANNEL_MONO] - - async def async_process_audio_stream( - self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] - ) -> SpeechResult: - data = b"" - async for chunk in stream: - data += chunk - - if not data: - return SpeechResult("", SpeechResultState.ERROR) - - try: - byte_io = io.BytesIO() - with wave.open(byte_io, "wb") as wav_file: - wav_file.setnchannels(metadata.channel) - wav_file.setsampwidth(2) # 2 bytes per sample - wav_file.setframerate(metadata.sample_rate) - wav_file.writeframes(data) - - headers = { - "Authorization": f"Bearer {self._api_key}", - } - - form = aiohttp.FormData() - form.add_field( - "file", - byte_io.getvalue(), - filename="audio.wav", - content_type="audio/wav", - ) - form.add_field("language", metadata.language) - form.add_field("model", "whisper-1") - - async with aiohttp.ClientSession() as session: - async with session.post( - "https://api.openai.com/v1/audio/transcriptions", - data=form, - headers=headers, - ) as response: - if response.status == 200: - json_response = await response.json() - return SpeechResult( - json_response["text"], SpeechResultState.SUCCESS - ) - else: - text = await response.text() - _LOGGER.warning("{}:{}".format(response.status, text)) - return SpeechResult("", SpeechResultState.ERROR) - except Exception as e: - _LOGGER.warning(e) - return SpeechResult("", SpeechResultState.ERROR) diff --git a/requirements.txt b/requirements.txt index d2ea61b..a271e84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ aiohttp-fast-zlib==0.1.0 aiooui==0.1.5 aiosignal==1.3.1 aiozoneinfo==0.1.0 +annotated-types==0.7.0 anyio==4.4.0 astral==2.2 async-interrupt==1.1.1 @@ -29,6 +30,7 @@ charset-normalizer==3.3.2 ciso8601==2.3.1 cryptography==42.0.5 dbus-fast==2.21.3 +distro==1.9.0 envs==1.4 fnv-hash-fast==0.5.0 fnvhash==0.1.0 @@ -49,6 +51,7 @@ josepy==1.14.0 lru-dict==1.3.0 
MarkupSafe==2.1.5 multidict==6.0.5 +openai==1.35.1 orjson==3.9.15 packaging==24.1 pillow==10.3.0 @@ -57,6 +60,8 @@ psutil-home-assistant==0.0.1 pycares==4.4.0 pycognito==2024.5.1 pycparser==2.22 +pydantic==2.7.4 +pydantic_core==2.18.4 PyJWT==2.8.0 pyOpenSSL==24.1.0 pyRFC3339==1.1 @@ -73,6 +78,7 @@ sniffio==1.3.1 snitun==0.39.1 SQLAlchemy==2.0.30 text-unidecode==1.3 +tqdm==4.66.4 typing_extensions==4.12.2 tzdata==2024.1 uart-devices==0.1.0 @@ -80,6 +86,7 @@ ulid-transform==0.9.0 urllib3==1.26.19 usb-devices==0.4.5 voluptuous==0.13.1 +voluptuous-openapi==0.0.4 voluptuous-serialize==2.6.0 wheel==0.43.0 yarl==1.9.4 -- GitLab