diff --git a/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/__init__.py b/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.metadata.json b/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.metadata.json
new file mode 100644
index 0000000000..fbdd78f289
--- /dev/null
+++ b/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.metadata.json
@@ -0,0 +1,40 @@
+{
+  "Provider": "aws",
+  "CheckID": "sagemaker_models_monitor_enabled",
+  "CheckTitle": "Amazon SageMaker has a monitoring schedule scheduled",
+  "CheckType": [
+    "Software and Configuration Checks/AWS Security Best Practices",
+    "Software and Configuration Checks/Industry and Regulatory Standards/AWS Foundational Security Best Practices"
+  ],
+  "ServiceName": "sagemaker",
+  "SubServiceName": "",
+  "ResourceIdTemplate": "",
+  "Severity": "low",
+  "ResourceType": "Other",
+  "ResourceGroup": "ai_ml",
+  "Description": "**SageMaker Models Monitor** detects data drift, model quality issues, and bias drift in production.",
+  "Risk": "Without active monitoring, model degradation goes undetected and downstream decisions (fraud, access, pricing) silently degrade.",
+  "RelatedUrl": "",
+  "AdditionalURLs": [
+    "https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor.html",
+    "https://github.com/aws-samples/sample-aiml-security-assessment"
+  ],
+  "Remediation": {
+    "Code": {
+      "CLI": "",
+      "NativeIaC": "",
+      "Other": "",
+      "Terraform": ""
+    },
+    "Recommendation": {
+      "Text": "Ensure that model monitoring is active and you have at least one monitor scheduled to check.",
+      "Url": "https://hub.prowler.com/check/sagemaker_models_monitor_enabled"
+    }
+  },
+  "Categories": [
+    "gen-ai"
+  ],
+  "DependsOn": [],
+  "RelatedTo": [],
+  "Notes": ""
+}
diff --git a/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.py b/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.py
new file mode 100644
index 0000000000..9e92df7f60
--- /dev/null
+++ b/prowler/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled.py
@@ -0,0 +1,41 @@
+from prowler.lib.check.models import Check, Check_Report_AWS
+from prowler.providers.aws.services.sagemaker.sagemaker_client import sagemaker_client
+
+
+class sagemaker_models_monitor_enabled(Check):
+    """Check that at least one SageMaker monitoring schedule is "Scheduled"."""
+
+    def execute(self):
+        """Evaluate the monitoring schedules collected by the SageMaker service.
+
+        Returns:
+            list: A single finding, or an empty list when no monitoring
+            schedule data was collected.
+        """
+        schedules = sagemaker_client.sagemaker_monitoring_schedules
+        # Without any collected entry there is no resource to report on.
+        if not schedules:
+            return []
+
+        # The service appends a "NOT_AVAILABLE" placeholder for each region
+        # without schedules, so the verdict must not depend on list order.
+        scheduled = [s for s in schedules if s.schedule_status == "Scheduled"]
+        existing = [s for s in schedules if s.schedule_status != "NOT_AVAILABLE"]
+
+        if scheduled:
+            resource = scheduled[0]
+            status = "PASS"
+            status_extended = f"SageMaker monitoring schedule {resource.name} is enabled."
+        elif existing:
+            resource = existing[-1]
+            status = "FAIL"
+            status_extended = f"SageMaker monitoring schedule {resource.name} is not active."
+        else:
+            resource = schedules[-1]
+            status = "FAIL"
+            status_extended = f"SageMaker monitoring schedules in account {sagemaker_client.audited_account} do not exist."
+
+        report = Check_Report_AWS(metadata=self.metadata(), resource=resource)
+        report.status = status
+        report.status_extended = status_extended
+        return [report]
diff --git a/prowler/providers/aws/services/sagemaker/sagemaker_service.py b/prowler/providers/aws/services/sagemaker/sagemaker_service.py
index 0f73062452..12b2e8b522 100644
--- a/prowler/providers/aws/services/sagemaker/sagemaker_service.py
+++ b/prowler/providers/aws/services/sagemaker/sagemaker_service.py
@@ -18,6 +18,7 @@ def __init__(self, provider):
         self.sagemaker_domains = []
         self.endpoint_configs = {}
         self.sagemaker_model_registries = []
+        self.sagemaker_monitoring_schedules = []
 
         # Retrieve resources concurrently
         self.__threading_call__(self._list_notebook_instances)
@@ -26,6 +27,7 @@ def __init__(self, provider):
         self.__threading_call__(self._list_endpoint_configs)
         self.__threading_call__(self._list_domains)
         self.__threading_call__(self._list_model_package_groups)
+        self.__threading_call__(self._list_monitoring_schedules)
 
         # Describe resources concurrently
         self.__threading_call__(self._describe_model, self.sagemaker_models)
@@ -377,6 +379,52 @@ def _describe_endpoint_config(self, endpoint_config):
                 f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
             )
 
+    def _list_monitoring_schedules(self, regional_client):
+        """List the monitoring schedules of a region.
+
+        Appends one MonitoringSchedule per schedule that passes the
+        audit_resources filter; when none is collected, appends a single
+        placeholder with status "NOT_AVAILABLE" for that region.
+        """
+        logger.info("SageMaker - listing monitoring schedules...")
+        try:
+            list_monitoring_schedules_paginator = regional_client.get_paginator(
+                "list_monitoring_schedules"
+            )
+            schedule_counter = 0
+            for page in list_monitoring_schedules_paginator.paginate():
+                for schedule in page["MonitoringScheduleSummaries"]:
+                    if not self.audit_resources or (
+                        is_resource_filtered(
+                            schedule["MonitoringScheduleArn"], self.audit_resources
+                        )
+                    ):
+                        schedule_counter += 1
+                        self.sagemaker_monitoring_schedules.append(
+                            MonitoringSchedule(
+                                name=schedule["MonitoringScheduleName"],
+                                region=regional_client.region,
+                                arn=schedule["MonitoringScheduleArn"],
+                                schedule_status=schedule["MonitoringScheduleStatus"],
+                            )
+                        )
+            if schedule_counter == 0:
+                self.sagemaker_monitoring_schedules.append(
+                    MonitoringSchedule(
+                        name="monitoring_schedule/unknown",
+                        region=regional_client.region,
+                        arn=self.get_unknown_arn(
+                            region=regional_client.region,
+                            resource_type="monitoring_schedule",
+                        ),
+                        schedule_status="NOT_AVAILABLE",
+                    )
+                )
+        except Exception as error:
+            logger.error(
+                f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
+            )
+
 
 class NotebookInstance(BaseModel):
     name: str
@@ -441,3 +489,12 @@ class ModelRegistry(BaseModel):
     region: str
     has_groups: bool = False
     has_approved_packages: bool = False
+
+
+class MonitoringSchedule(BaseModel):
+    """Represents a SageMaker monitoring schedule (or a per-region placeholder)."""
+
+    name: str
+    region: str
+    arn: str
+    schedule_status: str
diff --git a/tests/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled_test.py b/tests/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled_test.py
new file mode 100644
index 0000000000..4f82a673bb
--- /dev/null
+++ b/tests/providers/aws/services/sagemaker/sagemaker_models_monitor_enabled/sagemaker_models_monitor_enabled_test.py
@@ -0,0 +1,239 @@
+from unittest import mock
+
+from prowler.providers.aws.services.sagemaker.sagemaker_service import (
+    MonitoringSchedule,
+)
+from tests.providers.aws.utils import (
+    AWS_ACCOUNT_NUMBER,
+    AWS_REGION_EU_WEST_1,
+    set_mocked_aws_provider,
+)
+
+test_monitoring_schedule = "test-monitoring-schedule"
+monitoring_schedule_arn = f"arn:aws:sagemaker:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:monitoring-schedule/{test_monitoring_schedule}"
+
+
+class Test_sagemaker_models_monitor_enabled:
+    def test_no_models_monitoring_schedules_exist(self):
+        sagemaker_client = mock.MagicMock
+        sagemaker_client.audited_account = AWS_ACCOUNT_NUMBER
+        sagemaker_client.sagemaker_monitoring_schedules = []
+        sagemaker_client.sagemaker_monitoring_schedules.append(
+            MonitoringSchedule(
+                name=test_monitoring_schedule,
+                region=AWS_REGION_EU_WEST_1,
+                arn=monitoring_schedule_arn,
+                schedule_status="NOT_AVAILABLE",
+            )
+        )
+
+        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
+
+        with (
+            mock.patch(
+                "prowler.providers.common.provider.Provider.get_global_provider",
+                return_value=aws_provider,
+            ),
+            mock.patch(
+                "prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
+                sagemaker_client,
+            ),
+        ):
+
+            from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
+                sagemaker_models_monitor_enabled,
+            )
+
+            check = sagemaker_models_monitor_enabled()
+            result = check.execute()
+            assert len(result) == 1
+            assert result[0].status == "FAIL"
+            assert (
+                result[0].status_extended
+                == f"SageMaker monitoring schedules in account {sagemaker_client.audited_account} do not exist."
+            )
+            assert result[0].resource_id == test_monitoring_schedule
+            assert result[0].resource_arn == monitoring_schedule_arn
+
+    def test_no_scheduled_models_monitoring_schedule(self):
+        sagemaker_client = mock.MagicMock
+        sagemaker_client.sagemaker_monitoring_schedules = []
+        sagemaker_client.sagemaker_monitoring_schedules.extend(
+            [
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Pending",
+                ),
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Stopped",
+                ),
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Failed",
+                ),
+            ]
+        )
+
+        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
+
+        with (
+            mock.patch(
+                "prowler.providers.common.provider.Provider.get_global_provider",
+                return_value=aws_provider,
+            ),
+            mock.patch(
+                "prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
+                sagemaker_client,
+            ),
+        ):
+
+            from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
+                sagemaker_models_monitor_enabled,
+            )
+
+            check = sagemaker_models_monitor_enabled()
+            result = check.execute()
+            assert len(result) == 1
+            assert result[0].status == "FAIL"
+            assert (
+                result[0].status_extended
+                == f"SageMaker monitoring schedule {test_monitoring_schedule} is not active."
+            )
+            assert result[0].resource_id == test_monitoring_schedule
+            assert result[0].resource_arn == monitoring_schedule_arn
+
+    def test_models_monitor_scheduled(self):
+        sagemaker_client = mock.MagicMock
+        sagemaker_client.sagemaker_monitoring_schedules = []
+        sagemaker_client.sagemaker_monitoring_schedules.extend(
+            [
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Pending",
+                ),
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Scheduled",
+                ),
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Failed",
+                ),
+            ]
+        )
+
+        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
+
+        with (
+            mock.patch(
+                "prowler.providers.common.provider.Provider.get_global_provider",
+                return_value=aws_provider,
+            ),
+            mock.patch(
+                "prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
+                sagemaker_client,
+            ),
+        ):
+
+            from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
+                sagemaker_models_monitor_enabled,
+            )
+
+            check = sagemaker_models_monitor_enabled()
+            result = check.execute()
+            assert len(result) == 1
+            assert result[0].status == "PASS"
+            assert (
+                result[0].status_extended
+                == f"SageMaker monitoring schedule {test_monitoring_schedule} is enabled."
+            )
+            assert result[0].resource_id == test_monitoring_schedule
+            assert result[0].resource_arn == monitoring_schedule_arn
+
+    def test_no_monitoring_schedules_collected(self):
+        # No entry at all, e.g. every regional listing failed.
+        sagemaker_client = mock.MagicMock
+        sagemaker_client.sagemaker_monitoring_schedules = []
+
+        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
+
+        with (
+            mock.patch(
+                "prowler.providers.common.provider.Provider.get_global_provider",
+                return_value=aws_provider,
+            ),
+            mock.patch(
+                "prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
+                sagemaker_client,
+            ),
+        ):
+
+            from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
+                sagemaker_models_monitor_enabled,
+            )
+
+            check = sagemaker_models_monitor_enabled()
+            result = check.execute()
+            assert result == []
+
+    def test_models_monitor_scheduled_after_placeholder(self):
+        # A placeholder listed first must not mask a scheduled monitor.
+        sagemaker_client = mock.MagicMock
+        sagemaker_client.sagemaker_monitoring_schedules = []
+        sagemaker_client.sagemaker_monitoring_schedules.extend(
+            [
+                MonitoringSchedule(
+                    name="monitoring_schedule/unknown",
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=f"arn:aws:sagemaker:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:monitoring_schedule/unknown",
+                    schedule_status="NOT_AVAILABLE",
+                ),
+                MonitoringSchedule(
+                    name=test_monitoring_schedule,
+                    region=AWS_REGION_EU_WEST_1,
+                    arn=monitoring_schedule_arn,
+                    schedule_status="Scheduled",
+                ),
+            ]
+        )
+
+        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
+
+        with (
+            mock.patch(
+                "prowler.providers.common.provider.Provider.get_global_provider",
+                return_value=aws_provider,
+            ),
+            mock.patch(
+                "prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled.sagemaker_client",
+                sagemaker_client,
+            ),
+        ):
+
+            from prowler.providers.aws.services.sagemaker.sagemaker_models_monitor_enabled.sagemaker_models_monitor_enabled import (
+                sagemaker_models_monitor_enabled,
+            )
+
+            check = sagemaker_models_monitor_enabled()
+            result = check.execute()
+            assert len(result) == 1
+            assert result[0].status == "PASS"
+            assert (
+                result[0].status_extended
+                == f"SageMaker monitoring schedule {test_monitoring_schedule} is enabled."
+            )
+            assert result[0].resource_id == test_monitoring_schedule
+            assert result[0].resource_arn == monitoring_schedule_arn