Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ tasks = [
"pyvips==3.1.1.8.18.1",
"uvdat-flood-sim[large-image-writer]==1.0.4",
"xdg-base-dirs==6.0.2",
"huggingface-hub==1.14.0",
]

[dependency-groups]
Expand Down
11 changes: 8 additions & 3 deletions terraform/django.tf
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@ module "django" {
ec2_worker_ssh_public_key = file("${path.module}/ssh-key.pub")

additional_django_vars = {
DJANGO_UVDAT_WEB_URL = "https://www.geodatalytics.kitware.com/"
DJANGO_DATABASE_POOL_MAX_SIZE = "12"
DJANGO_SENTRY_DSN = "https://5302701c88f1fa6ec056e0c269071191@o267860.ingest.us.sentry.io/4510620385804288"
DJANGO_UVDAT_WEB_URL = "https://www.geodatalytics.kitware.com/"
DJANGO_DATABASE_POOL_MAX_SIZE = "12"
DJANGO_SENTRY_DSN = "https://5302701c88f1fa6ec056e0c269071191@o267860.ingest.us.sentry.io/4510620385804288"
DJANGO_UVDAT_HF_NAMESPACE = "Kitware"
DJANGO_UVDAT_HF_ENDPOINT_NAMES = "qwen=qwen3-5-9b-gguf-ulh,"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you end up with any more of these, you can use some Terraform syntax like:

locals {
  huggingface_endpoint_names = {
    qwen = "qwen3-5-9b-gguf-ulh"
  }
}

module "django" {
  ...
  additional_django_vars = {
    ...
    DJANGO_UVDAT_HF_ENDPOINT_NAMES = join(",", [for k, v in local.huggingface_endpoint_names : "${k}=${v}"])
  }
}

Maybe it's overkill for now, but worth remembering in the future.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks. I agree that we can use this in the future when we add more endpoints to the mapping.

}
additional_sensitive_django_vars = {
DJANGO_UVDAT_HF_TOKEN = var.DJANGO_UVDAT_HF_TOKEN
}
django_cors_allowed_origins = [
# Can't make this use "aws_route53_record.www.fqdn" because of a circular dependency
Expand Down
6 changes: 6 additions & 0 deletions terraform/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ variable "SENTRY_AUTH_TOKEN" {
nullable = false
sensitive = true
}

variable "DJANGO_UVDAT_HF_TOKEN" {
type = string
nullable = true
sensitive = true
}
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions uvdat/core/tasks/analytics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .flood_network_failure import FloodNetworkFailure
from .flood_simulation import FloodSimulation
from .geoai_segmentation import GeoAISegmentation
from .imagery_ask_qwen import ImageryAskQwen
from .network_recovery import NetworkRecovery
from .uncertainty_quantification import UncertaintyQuantification

Expand All @@ -15,6 +16,7 @@
analysis_types: list[type[AnalysisType]] = [
FloodSimulation,
FloodNetworkFailure,
ImageryAskQwen,
NetworkRecovery,
UncertaintyQuantification,
GeoAISegmentation,
Expand Down
161 changes: 161 additions & 0 deletions uvdat/core/tasks/analytics/imagery_ask_qwen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

import base64

from celery import shared_task
from django.conf import settings
from django_large_image import utilities
import large_image

from uvdat.core.models import RasterData, TaskResult

from .analysis_type import AnalysisInputError, AnalysisTask, AnalysisType

MODEL_CARD_URL = "https://huggingface.co/unsloth/Qwen3.5-9B-GGUF"
SYSTEM_PROMPT = (
"You are an urban planning and geospatial analysis expert specializing in "
"land use patterns, hydrology, transportation networks, and municipal policy. "
"Analyze the provided imagery to answer the user's question. In your answer, "
"assume that the user is also a geospatial analyst with the same expertise."
)
TOKEN_RANGE = {"min": 1000, "max": 10000, "step": 1000}
MAX_PROMPT_LENGTH = 4000
THUMBNAIL_SIZE = 4000
MAX_STARTUP_WAIT = 300


class ImageryAskQwen(AnalysisType):
def __init__(self):
super().__init__()
self.name = "Imagery: Ask Qwen"
self.description = "Select an imagery layer and ask Qwen 3.5 about it."
self.details = (
"Inferencing with unsloth/Qwen3.5-9B-GGUF provided by a "
"Kitware-hosted Huggingface Inference Endpoint. "
f"See the model card at {MODEL_CARD_URL}. "
"Responses may cut off mid-sentence if max_tokens is reached."
)
self.db_value = "imagery_ask_qwen"
self.input_types = {
"imagery": "RasterData",
"text_prompt": "string",
"max_tokens": "number",
}
self.output_types = {
"response": "markdown",
}
self.attribution = "Unsloth AI, Kitware Inc."

@classmethod
def is_enabled(cls) -> bool:
return (
settings.UVDAT_ENABLE_IMAGERY_ASK_QWEN
and settings.UVDAT_HF_TOKEN is not None
and settings.UVDAT_HF_NAMESPACE is not None
and settings.UVDAT_HF_ENDPOINT_NAMES.get("qwen") is not None
)

def get_input_options(self):
return {
"imagery": RasterData.objects.filter(dataset__category="imagery"),
"text_prompt": [],
"max_tokens": [TOKEN_RANGE],
}

def validate_inputs(self, inputs):
super().validate_inputs(inputs)
try:
imagery = RasterData.objects.get(id=inputs.get("imagery"))
except RasterData.DoesNotExist as e:
err_msg = "Imagery raster does not exist."
raise AnalysisInputError(err_msg) from e
if imagery.dataset.category != "imagery":
err_msg = 'Selected raster is not categorized as "imagery".'
raise AnalysisInputError(err_msg)
text_prompt = str(inputs.get("text_prompt"))
if len(text_prompt) > MAX_PROMPT_LENGTH:
err_msg = f"Prompt too long. Provide a prompt with <{MAX_PROMPT_LENGTH} characters."
raise AnalysisInputError(err_msg)
max_tokens = int(inputs.get("max_tokens"))
if max_tokens < TOKEN_RANGE["min"] or max_tokens > TOKEN_RANGE["max"]:
err_msg = f"max_tokens must be between {TOKEN_RANGE['min']} and {TOKEN_RANGE['max']}."
raise AnalysisInputError(err_msg)

def run_task(self, *, project, **inputs):
text_prompt = inputs.get("text_prompt")
result = TaskResult.objects.create(
name=text_prompt[:250],
task_type=self.db_value,
inputs=inputs,
project=project,
status="Initializing Task...",
)
imagery_ask_qwen.delay(result.id)
return result

def finalize(self, result):
pass


@shared_task(base=AnalysisTask)
def imagery_ask_qwen(result_id):
# Only available with [tasks] extra
from huggingface_hub import ( # noqa: PLC0415
InferenceEndpointTimeoutError,
get_inference_endpoint,
)

result = TaskResult.objects.get(id=result_id)
imagery = RasterData.objects.get(id=result.inputs.get("imagery"))
text_prompt = result.inputs.get("text_prompt")
max_tokens = int(result.inputs.get("max_tokens"))
Comment thread
brianhelba marked this conversation as resolved.

result.write_status("Encoding imagery...")
imagery_path = utilities.field_file_to_local_path(imagery.cloud_optimized_geotiff)
src = large_image.open(imagery_path)
thumbnail_bytes, _ = src.getThumbnail(THUMBNAIL_SIZE, THUMBNAIL_SIZE, encoding="PNG")
thumbnail_b64 = base64.b64encode(thumbnail_bytes).decode("utf-8")
thumbnail_uri = f"data:image/jpeg;base64,{thumbnail_b64}"

result.write_status("Starting inference endpoint...")
endpoint = get_inference_endpoint(
name=settings.UVDAT_HF_ENDPOINT_NAMES.get("qwen"),
Comment thread
annehaley marked this conversation as resolved.
Outdated
namespace=settings.UVDAT_HF_NAMESPACE,
token=settings.UVDAT_HF_TOKEN,
)
endpoint.resume()
try:
endpoint.wait(timeout=MAX_STARTUP_WAIT)
except InferenceEndpointTimeoutError:
result.write_error("Endpoint failed to start in 5 minutes. Try again later.")
return

result.write_status("Sending question to Qwen...")
messages = [
{
"role": "system",
"content": SYSTEM_PROMPT,
},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": thumbnail_uri}},
{"type": "text", "text": text_prompt},
],
},
]

result.write_status("Awaiting Qwen's response...")
chat = endpoint.client.chat_completion(
model="unsloth/Qwen3.5-9B-GGUF",
messages=messages,
max_tokens=max_tokens,
Comment thread
annehaley marked this conversation as resolved.
)
response = ""
for choice in chat.choices:
if choice.finish_reason == "length":
# max tokens exceeded, use reasoning content
response += choice.message.reasoning_content
else:
response += choice.message.content
result.write_outputs({"response": response})
3 changes: 2 additions & 1 deletion uvdat/core/tests/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def test_rest_list_analysis_types(user, authenticated_api_client, project):
user.is_superuser = True
user.save()

analysis_type_instances = [at() for at in analysis_types]
analysis_type_instances = [at() for at in analysis_types if at.is_enabled()]
resp = authenticated_api_client.get(f"/api/v1/analytics/project/{project.id}/types/")
data = resp.json()

assert len(data) == len(analysis_type_instances)
assert {type_info.get("name") for type_info in data} == {
i.name for i in analysis_type_instances
Expand Down
5 changes: 5 additions & 0 deletions uvdat/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@
}

UVDAT_WEB_URL: str = env.url("DJANGO_UVDAT_WEB_URL").geturl()
UVDAT_HF_TOKEN: str | None = env.str("DJANGO_UVDAT_HF_TOKEN", default=None)
UVDAT_HF_NAMESPACE: str | None = env.str("DJANGO_UVDAT_HF_NAMESPACE", default=None)
UVDAT_HF_ENDPOINT_NAMES: dict = env.dict("DJANGO_UVDAT_HF_ENDPOINT_NAMES", default={})

UVDAT_ENABLE_FLOOD_SIMULATION: bool = env.bool("DJANGO_UVDAT_ENABLE_FLOOD_SIMULATION", default=True)
UVDAT_ENABLE_FLOOD_NETWORK_FAILURE: bool = env.bool(
"DJANGO_UVDAT_ENABLE_FLOOD_NETWORK_FAILURE", default=True
Expand All @@ -172,6 +176,7 @@
UVDAT_ENABLE_UNCERTAINTY_QUANTIFICATION: bool = env.bool(
"DJANGO_UVDAT_ENABLE_UNCERTAINTY_QUANTIFICATION", default=True
)
UVDAT_ENABLE_IMAGERY_ASK_QWEN: bool = env.bool("DJANGO_UVDAT_ENABLE_IMAGERY_ASK_QWEN", default=True)
Comment thread
brianhelba marked this conversation as resolved.

logging.getLogger("pyvips").setLevel(logging.ERROR)
logging.getLogger("rasterio").setLevel(logging.ERROR)
Expand Down
Loading