Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion safetytooling/apis/inference/openai/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import time
from pathlib import Path
from traceback import format_exc
Expand Down Expand Up @@ -69,7 +70,14 @@ def __init__(
if openai_api_key:
self.aclient = openai.AsyncClient(api_key=openai_api_key, base_url=self.base_url)
else:
self.aclient = openai.AsyncClient(base_url=self.base_url)
# The openai v1+ client needs an API key at construction time: when no api_key
# argument is passed it reads OPENAI_API_KEY and raises OpenAIError if that is
# unset, even when only non-OpenAI models are used. Fall back to a sentinel so
# the client initialises cleanly; actual OpenAI calls will then fail with a 401
# unless a real key is set in the environment.
self.aclient = openai.AsyncClient(
api_key=os.environ.get("OPENAI_API_KEY", "not-configured"),
base_url=self.base_url,
)
self.openai_api_key = openai_api_key

self.token_capacity = dict()
Expand Down
5 changes: 4 additions & 1 deletion safetytooling/apis/inference/openai/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import time
import traceback

Expand All @@ -20,7 +21,9 @@ def __init__(self, batch_size: int = 2048):
self.num_threads = 1
self.batch_size = batch_size # Max batch size for embedding endpoint

self.aclient = openai.AsyncClient()
# The openai v1+ client requires an api_key at construction; fall back to a
# sentinel so the client initialises even when OPENAI_API_KEY is not set in the
# environment. Requests made with the sentinel will fail authentication (401).
self.aclient = openai.AsyncClient(api_key=os.environ.get("OPENAI_API_KEY", "not-configured"))
self.available_requests = asyncio.BoundedSemaphore(self.num_threads)

async def embed(
Expand Down
5 changes: 4 additions & 1 deletion safetytooling/apis/inference/openai/moderation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import time
from traceback import format_exc
from typing import Awaitable
Expand Down Expand Up @@ -27,7 +28,9 @@ def __init__(
self.num_threads = num_threads
self._batch_size = 32 # Max batch size for moderation endpoint

self.aclient = openai.AsyncClient()
# The openai v1+ client requires an api_key at construction; fall back to a
# sentinel so the client initialises even when OPENAI_API_KEY is not set in the
# environment. Requests made with the sentinel will fail authentication (401).
self.aclient = openai.AsyncClient(api_key=os.environ.get("OPENAI_API_KEY", "not-configured"))
self.available_requests = asyncio.BoundedSemaphore(self.num_threads)

async def _single_moderation_request(
Expand Down
5 changes: 4 additions & 1 deletion safetytooling/apis/inference/openai/s2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ async def acquire(self):

class OpenAIS2SModel(InferenceAPIModel):
def __init__(self):
self.api_key = os.environ["OPENAI_API_KEY"]
# Use .get() with a sentinel so the model initialises even when OPENAI_API_KEY
# is not set; actual S2S WebSocket connections will fail at call time if no
# real key is provided.
self.api_key = os.environ.get("OPENAI_API_KEY", "not-configured")
self.base_url = "wss://api.openai.com/v1/realtime"
self.model = "gpt-4o-realtime-preview-2024-10-01"
self.max_size = 10 * 1024 * 1024 # 10MB
Expand Down