Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"Provider": "aws",
"CheckID": "sagemaker_models_monitor_enabled",
"CheckTitle": "Amazon SageMaker has a monitoring schedule scheduled",
"CheckType": [
"Software and Configuration Checks/AWS Security Best Practices",
"Software and Configuration Checks/Industry and Regulatory Standards/AWS Foundational Security Best Practices"
],
"ServiceName": "sagemaker",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "low",
"ResourceType": "Other",
"ResourceGroup": "ai_ml",
"Description": "**SageMaker Models Monitor** detects data drift, model quality issues, and bias drift in production.",
"Risk": "Without active monitoring, model degradation goes undetected and downstream decisions (fraud, access, pricing) silently degrade.",
"RelatedUrl": "",
"AdditionalURLs": [
"https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor.html",
"https://github.com/aws-samples/sample-aiml-security-assessment"
],
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Ensure that model monitoring is active and you have at least one monitor scheduled to check.",
"Url": "https://hub.prowler.com/check/sagemaker_models_monitor_enabled"
}
},
"Categories": [
"gen-ai"
],
"DependsOn": [],
"RelatedTo": [],
"Notes": ""
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.providers.aws.services.sagemaker.sagemaker_client import sagemaker_client


class sagemaker_models_monitor_enabled(Check):
def execute(self):
findings = []
monitoring_schedule_exists = True
monitoring_schedule_is_scheduled = False
for monitoring_schedule in sagemaker_client.sagemaker_monitoring_schedules:
report = Check_Report_AWS(
metadata=self.metadata(), resource=monitoring_schedule
)
if monitoring_schedule.schedule_status == "Scheduled":
monitoring_schedule_is_scheduled = True
break

else:
if monitoring_schedule.schedule_status == "NOT_AVAILABLE":
monitoring_schedule_exists = False

if not monitoring_schedule_exists:
report.status = "FAIL"
report.status_extended = f"SageMaker monitoring schedules in account {sagemaker_client.audited_account} do not exist."
findings.append(report)
else:
if monitoring_schedule_is_scheduled:
report.status = "PASS"
report.status_extended = f"SageMaker monitoring schedule {monitoring_schedule.name} is enabled."
findings.append(report)
else:
report.status = "FAIL"
report.status_extended = f"SageMaker monitoring schedule {monitoring_schedule.name} is not active."
findings.append(report)
return findings
49 changes: 48 additions & 1 deletion prowler/providers/aws/services/sagemaker/sagemaker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, provider):
self.sagemaker_domains = []
self.endpoint_configs = {}
self.sagemaker_model_registries = []
self.sagemaker_monitoring_schedules = []

# Retrieve resources concurrently
self.__threading_call__(self._list_notebook_instances)
Expand All @@ -26,6 +27,7 @@ def __init__(self, provider):
self.__threading_call__(self._list_endpoint_configs)
self.__threading_call__(self._list_domains)
self.__threading_call__(self._list_model_package_groups)
self.__threading_call__(self._list_monitoring_schedules)

# Describe resources concurrently
self.__threading_call__(self._describe_model, self.sagemaker_models)
Expand Down Expand Up @@ -377,6 +379,46 @@ def _describe_endpoint_config(self, endpoint_config):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

def _list_monitoring_schedules(self, regional_client):
logger.info("SageMaker - listing monitoring schedules...")
try:
list_monitoring_schedules_paginator = regional_client.get_paginator(
"list_monitoring_schedules"
)
schedule_counter = 0
for page in list_monitoring_schedules_paginator.paginate():
for schedule in page["MonitoringScheduleSummaries"]:
if not self.audit_resources or (
is_resource_filtered(
schedule["MonitoringScheduleArn"], self.audit_resources
)
):
schedule_counter += 1
self.sagemaker_monitoring_schedules.append(
MonitoringSchedule(
name=schedule["MonitoringScheduleName"],
region=regional_client.region,
arn=schedule["MonitoringScheduleArn"],
schedule_status=schedule["MonitoringScheduleStatus"],
)
)
if schedule_counter == 0:
self.sagemaker_monitoring_schedules.append(
MonitoringSchedule(
name="monitoring_schedule/unknown",
region=regional_client.region,
arn=self.get_unknown_arn(
region=regional_client.region,
resource_type="monitoring_schedule",
),
schedule_status="NOT_AVAILABLE",
)
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)


class NotebookInstance(BaseModel):
name: str
Expand Down Expand Up @@ -432,7 +474,6 @@ class EndpointConfig(BaseModel):
production_variants: list[ProductionVariant] = []
tags: Optional[list] = []


class ModelRegistry(BaseModel):
"""Represents the SageMaker Model Registry state for a specific region."""

Expand All @@ -441,3 +482,9 @@ class ModelRegistry(BaseModel):
region: str
has_groups: bool = False
has_approved_packages: bool = False

class MonitoringSchedule(BaseModel):
name: str
region: str
arn: str
schedule_status: str
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from unittest import mock

from prowler.providers.aws.services.sagemaker.sagemaker_service import (
MonitoringSchedule,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
AWS_REGION_EU_WEST_1,
set_mocked_aws_provider,
)

test_monitoring_schedule = "test-monitoring-schedule"
monitoring_schedule_arn = f"arn:aws:sagemaker:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:monitoring-schedule/{test_monitoring_schedule}"


class Test_sagemaker_models_monitor_enabled:
def test_no_models_monitoring_schedules_exist(self):
sagemaker_client = mock.MagicMock
sagemaker_client.audited_account = AWS_ACCOUNT_NUMBER
sagemaker_client.sagemaker_monitoring_schedules = []
sagemaker_client.sagemaker_monitoring_schedules.append(
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="NOT_AVAILABLE",
)
)

aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])

with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
sagemaker_client,
),
):

from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
sagemaker_models_monitor_enabled,
)

check = sagemaker_models_monitor_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"SageMaker monitoring schedules in account {sagemaker_client.audited_account} do not exist."
)
assert result[0].resource_id == test_monitoring_schedule
assert result[0].resource_arn == monitoring_schedule_arn

def test_no_scheduled_models_monitoring_schedule(self):
sagemaker_client = mock.MagicMock
sagemaker_client.sagemaker_monitoring_schedules = []
sagemaker_client.sagemaker_monitoring_schedules.extend(
[
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Pending",
),
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Stopped",
),
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Failed",
),
]
)

aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])

with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
sagemaker_client,
),
):

from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
sagemaker_models_monitor_enabled,
)

check = sagemaker_models_monitor_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"SageMaker monitoring schedule {test_monitoring_schedule} is not active."
)
assert result[0].resource_id == test_monitoring_schedule
assert result[0].resource_arn == monitoring_schedule_arn

def test_models_monitor_scheduled(self):
sagemaker_client = mock.MagicMock
sagemaker_client.sagemaker_monitoring_schedules = []
sagemaker_client.sagemaker_monitoring_schedules.extend(
[
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Pending",
),
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Scheduled",
),
MonitoringSchedule(
name=test_monitoring_schedule,
region=AWS_REGION_EU_WEST_1,
arn=monitoring_schedule_arn,
schedule_status="Failed",
),
]
)

aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])

with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
sagemaker_client,
),
):

from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
sagemaker_models_monitor_enabled,
)

check = sagemaker_models_monitor_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"SageMaker monitoring schedule {test_monitoring_schedule} is enabled."
)
assert result[0].resource_id == test_monitoring_schedule
assert result[0].resource_arn == monitoring_schedule_arn