diff --git a/openai_stt/stt.py b/openai_stt/stt.py index 081c04800c8302a3939966ecd548435248b7e4c3..146055ae4d586d0f6eba398f1585711470266560 100644 --- a/openai_stt/stt.py +++ b/openai_stt/stt.py @@ -24,11 +24,13 @@ import homeassistant.helpers.config_validation as cv _LOGGER = logging.getLogger(__name__) CONF_API_KEY = "api_key" +CONF_BASE_URL = "base_url" CONF_LANG = "language" PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( { vol.Required(CONF_API_KEY): cv.string, + vol.Optional(CONF_BASE_URL, default="https://api.openai.com/v1/audio/transcriptions"): cv.string, vol.Optional(CONF_LANG, default="en-US"): cv.string, } ) @@ -37,17 +39,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( async def async_get_engine(hass, config, discovery_info=None): """Set up Whisper API STT speech component.""" api_key = config.get(CONF_API_KEY) + base_url = config.get(CONF_BASE_URL) lang = config.get(CONF_LANG) - return OpenAISTTProvider(hass, api_key, lang) + return OpenAISTTProvider(hass, api_key, base_url, lang) class OpenAISTTProvider(Provider): """The Whisper API STT provider.""" - def __init__(self, hass, api_key, lang): + def __init__(self, hass, api_key, base_url, lang): """Initialize Whisper API STT provider.""" self.hass = hass self._api_key = api_key + self._base_url = base_url self._language = lang @property @@ -119,7 +123,7 @@ class OpenAISTTProvider(Provider): async with aiohttp.ClientSession() as session: async with session.post( - "https://api.openai.com/v1/audio/transcriptions", + self._base_url, data=form, headers=headers, ) as response: