Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions safetytooling/apis/inference/openai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def convert_top_logprobs(data):
async def _make_api_call(self, prompt: Prompt, model_id, start_time, **params) -> list[LLMResponse]:
LOGGER.debug(f"Making {model_id} call")

if prompt.contains_image():
assert model_id in VISION_MODELS, f"Model {model_id} does not support images"
# TODO(task): using prompt.contains_image() check that the model supports images by using VISION_MODELS

# convert completion logprobs api to chat logprobs api
if "logprobs" in params:
Expand Down
91 changes: 13 additions & 78 deletions safetytooling/data_models/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,8 @@ def openai_image_format(self):
# for images the format involves including images and user text in the same message
if self.role == MessageRole.image:
base64_image = image_to_base64(self.content)
return {
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
},
}
# TODO(task): return the format for the openai image message https://platform.openai.com/docs/guides/vision
raise NotImplementedError("Not implemented")
elif self.role == MessageRole.user:
return {"type": "text", "text": self.content}
else:
Expand Down Expand Up @@ -164,15 +160,8 @@ def is_last_message_assistant(self) -> bool:
return self.messages[-1].role == MessageRole.assistant

def contains_image(self) -> bool:
"""Enhanced validation for image-user message pairs"""
for i, msg in enumerate(self.messages):
if msg.role == MessageRole.image:
# Ensure image is followed by user message
if i + 1 >= len(self.messages):
raise ValueError("Each image must be followed by a user message")
if self.messages[i + 1].role != MessageRole.user:
raise ValueError(f"Image must be followed by user message, got {self.messages[i + 1].role}")
return any(msg.role == MessageRole.image for msg in self.messages)
# TODO(task): return true if any message is an image
raise NotImplementedError("Not implemented")

def add_assistant_message(self, message: str) -> "Prompt":
return self + Prompt(messages=[ChatMessage(role=MessageRole.assistant, content=message)])
Expand Down Expand Up @@ -247,8 +236,10 @@ def openai_format(
)
if self.is_none_in_messages():
raise ValueError(f"OpenAI chat prompts cannot have a None role. Got {self.messages}")

if self.contains_image():
return self.openai_image_format()
# TODO(task): return the format for the openai image messages
raise NotImplementedError("Not implemented")
return [msg.openai_format() for msg in self.messages]

def gemini_format(self, use_vertexai: bool = False) -> List[str]:
Expand Down Expand Up @@ -298,36 +289,12 @@ def openai_s2s_format(self) -> List[Any]:
return messages

def openai_image_format(self) -> List[Any]:
# TODO(task): return messages in the format for the openai api
# you should use the format methods of each ChatMessage
# also ensure that an image is followed by a user message
# remember to deal with system messages first
messages = []

# Handle system message if it exists first
if self.messages[0].role == MessageRole.system:
messages.append(self.messages[0].openai_format())
messages_to_process = self.messages[1:]
else:
messages_to_process = self.messages

i = 0
while i < len(messages_to_process):
msg = messages_to_process[i]

if msg.role == MessageRole.image:
# Ensure image is followed by a user message
assert i + 1 < len(messages_to_process), "Image must be followed by a user message as caption"
next_msg = messages_to_process[i + 1]
assert next_msg.role == MessageRole.user, f"Image must be followed by user message, got {next_msg.role}"

# Add image and user message as a combined message
content = [msg.openai_image_format(), {"type": "text", "text": next_msg.content}]
messages.append({"role": "user", "content": content})
i += 2 # Skip both image and user message

elif msg.role in (MessageRole.user, MessageRole.assistant):
messages.append(msg.openai_format())
i += 1

else:
raise ValueError(f"Invalid role {msg.role} in prompt")
raise NotImplementedError("Not implemented")

return messages

Expand All @@ -352,39 +319,7 @@ def anthropic_format(

def anthropic_image_format(self) -> Tuple[str | None, List[anthropic.types.MessageParam]]:
"""Returns (system_message (optional), chat_messages)"""
system_message = None
messages = []

# Handle system message if present
if self.messages[0].role == MessageRole.system:
system_message = self.messages[0].content
messages_to_process = self.messages[1:]
else:
messages_to_process = self.messages

i = 0
while i < len(messages_to_process):
msg = messages_to_process[i]

if msg.role == MessageRole.image:
# Ensure image is followed by a user message
assert i + 1 < len(messages_to_process), "Image must be followed by a user message as caption"
next_msg = messages_to_process[i + 1]
assert next_msg.role == MessageRole.user, f"Image must be followed by user message, got {next_msg.role}"

# Add image and user message as a combined message
content = [msg.anthropic_image_format(), {"type": "text", "text": next_msg.content}]
messages.append(anthropic.types.MessageParam(role="user", content=content))
i += 2 # Skip both image and user message

elif msg.role in (MessageRole.user, MessageRole.assistant):
messages.append(msg.anthropic_format())
i += 1

else:
raise ValueError(f"Invalid role {msg.role} in prompt")

return system_message, messages
raise NotImplementedError("Not implemented")

def pretty_print(self, responses: list[LLMResponse], print_fn: Callable | None = None) -> None:
if print_fn is None:
Expand Down