diff --git a/openai_stt/stt.py b/openai_stt/stt.py index 146055ae4d586d0f6eba398f1585711470266560..2c8e44e4fa5480e1dc9d16c6efc2e2eab4e109c3 100644 --- a/openai_stt/stt.py +++ b/openai_stt/stt.py @@ -25,13 +25,15 @@ _LOGGER = logging.getLogger(__name__) CONF_API_KEY = "api_key" CONF_BASE_URL = "base_url" +CONF_MODEL = "model" CONF_LANG = "language" PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( { vol.Required(CONF_API_KEY): cv.string, vol.Optional(CONF_BASE_URL, default="https://api.openai.com/v1/audio/transcriptions"): cv.string, - vol.Optional(CONF_LANG, default="en-US"): cv.string, + vol.Optional(CONF_MODEL, default="whisper-1"): cv.string, + vol.Optional(CONF_LANG, default="en"): cv.string, } ) @@ -40,18 +42,20 @@ async def async_get_engine(hass, config, discovery_info=None): """Set up Whisper API STT speech component.""" api_key = config.get(CONF_API_KEY) base_url = config.get(CONF_BASE_URL) + model = config.get(CONF_MODEL) lang = config.get(CONF_LANG) - return OpenAISTTProvider(hass, api_key, base_url, lang) + return OpenAISTTProvider(hass, api_key, base_url, model, lang) class OpenAISTTProvider(Provider): """The Whisper API STT provider.""" - def __init__(self, hass, api_key, base_url, lang): + def __init__(self, hass, api_key, base_url, model, lang): """Initialize Whisper API STT provider.""" self.hass = hass self._api_key = api_key self._base_url = base_url + self._model = model self._language = lang @property @@ -119,7 +123,7 @@ class OpenAISTTProvider(Provider): content_type="audio/wav", ) form.add_field("language", metadata.language) - form.add_field("model", "whisper-1") + form.add_field("model", self._model) async with aiohttp.ClientSession() as session: async with session.post(