From a6f65e0b8680172d4bb38014d120154418166ea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Hedenstro=CC=88m?= <erik@hedenstroem.com> Date: Wed, 26 Jun 2024 13:51:53 +0200 Subject: [PATCH] Added model to STT config --- openai_stt/stt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/openai_stt/stt.py b/openai_stt/stt.py index 146055a..2c8e44e 100644 --- a/openai_stt/stt.py +++ b/openai_stt/stt.py @@ -25,13 +25,15 @@ _LOGGER = logging.getLogger(__name__) CONF_API_KEY = "api_key" CONF_BASE_URL = "base_url" +CONF_MODEL = "model" CONF_LANG = "language" PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( { vol.Required(CONF_API_KEY): cv.string, vol.Optional(CONF_BASE_URL, default="https://api.openai.com/v1/audio/transcriptions"): cv.string, - vol.Optional(CONF_LANG, default="en-US"): cv.string, + vol.Optional(CONF_MODEL, default="whisper-1"): cv.string, + vol.Optional(CONF_LANG, default="en"): cv.string, } ) @@ -40,18 +42,20 @@ async def async_get_engine(hass, config, discovery_info=None): """Set up Whisper API STT speech component.""" api_key = config.get(CONF_API_KEY) base_url = config.get(CONF_BASE_URL) + model = config.get(CONF_MODEL) lang = config.get(CONF_LANG) - return OpenAISTTProvider(hass, api_key, base_url, lang) + return OpenAISTTProvider(hass, api_key, base_url, model, lang) class OpenAISTTProvider(Provider): """The Whisper API STT provider.""" - def __init__(self, hass, api_key, base_url, lang): + def __init__(self, hass, api_key, base_url, model, lang): """Initialize Whisper API STT provider.""" self.hass = hass self._api_key = api_key self._base_url = base_url + self._model = model self._language = lang @property @@ -119,7 +123,7 @@ class OpenAISTTProvider(Provider): content_type="audio/wav", ) form.add_field("language", metadata.language) - form.add_field("model", "whisper-1") + form.add_field("model", self._model) async with aiohttp.ClientSession() as session: async with session.post( -- GitLab