From d5061d9722bcab2cab11695455907253b53fd854 Mon Sep 17 00:00:00 2001 From: Nishar Date: Wed, 29 Apr 2026 14:14:39 -0700 Subject: [PATCH] feat(adapters): add native hooks for Anthropic, SK, smolagents, PydanticAI Complete the native-hooks migration for all four remaining adapters: Anthropic: - Add GovernanceMessageHook + as_message_hook() factory - Pre-execution: content scanning, tool allowlist, token limits - Post-execution: tool_use validation, token tracking, audit - Deprecate wrap() and wrap_client() with migration guidance Semantic Kernel: - Add GovernanceFunctionFilter + as_filter() factory - Uses SK's native add_filter('auto_function_invocation', ...) system - Validates function names, blocked patterns, call counts - Deprecate wrap() and wrap_kernel() with migration guidance Smolagents: - Add GovernanceStepCallback + as_step_callback() factory - Implements step_callbacks protocol: __call__(step, agent) - Validates tool names, blocked patterns, observations - Deprecate wrap() with migration guidance PydanticAI: - Add GovernanceCapability + as_capability() factory - Lifecycle hooks: before/after_run, before/after_tool_execute - Pre-execution policy gating, post-execution drift detection - Deprecate wrap() with migration guidance Package exports: - Export AnthropicGovernanceHook, SKGovernanceFilter, SmolagentsGovernanceCallback, PydanticAIGovernanceCapability from integrations __init__.py Tests: - test_anthropic_hooks.py: 12 tests - test_semantic_kernel_hooks.py: 10 tests - test_smolagents_hooks.py: 14 tests - test_pydantic_ai_hooks.py: 16 tests Part of: microsoft/agent-governance-toolkit#1571 --- .../src/agent_os/integrations/__init__.py | 23 +- .../integrations/anthropic_adapter.py | 206 ++++++++++++++- .../integrations/pydantic_ai_adapter.py | 228 +++++++++++++++- .../integrations/semantic_kernel_adapter.py | 196 +++++++++++++- .../integrations/smolagents_adapter.py | 200 +++++++++++++- .../agent-os/tests/test_anthropic_hooks.py | 246 ++++++++++++++++++ .../agent-os/tests/test_pydantic_ai_hooks.py | 218 ++++++++++++++++ .../tests/test_semantic_kernel_hooks.py | 218 ++++++++++++++++ .../agent-os/tests/test_smolagents_hooks.py | 226 ++++++++++++++++ 9 files changed, 1744 insertions(+), 17 deletions(-) create mode 100644 agent-governance-python/agent-os/tests/test_anthropic_hooks.py create mode 100644 agent-governance-python/agent-os/tests/test_pydantic_ai_hooks.py create mode 100644 agent-governance-python/agent-os/tests/test_semantic_kernel_hooks.py create mode 100644 agent-governance-python/agent-os/tests/test_smolagents_hooks.py diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py b/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py index 92b4b1926..dd5d9b725 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py @@ -59,7 +59,11 @@ RateLimitError, ) from agent_os.integrations.a2a_adapter import A2AEvaluation, A2AGovernanceAdapter, A2APolicy -from agent_os.integrations.anthropic_adapter import AnthropicKernel, GovernedAnthropicClient +from agent_os.integrations.anthropic_adapter import ( + AnthropicKernel, + GovernedAnthropicClient, + GovernanceMessageHook as AnthropicGovernanceHook, +) from agent_os.integrations.autogen_adapter import AutoGenKernel from agent_os.integrations.crewai_adapter import CrewAIKernel from agent_os.integrations.gemini_adapter import GeminiKernel, GovernedGeminiModel @@ -91,11 +95,19 @@ from agent_os.integrations.llamaindex_adapter import LlamaIndexKernel from agent_os.integrations.mistral_adapter import GovernedMistralClient, MistralKernel from agent_os.integrations.openai_adapter import GovernedAssistant, OpenAIKernel -from agent_os.integrations.pydantic_ai_adapter import PydanticAIKernel +from agent_os.integrations.pydantic_ai_adapter import ( + GovernanceCapability as PydanticAIGovernanceCapability, + PydanticAIKernel, +) from agent_os.integrations.semantic_kernel_adapter import ( + GovernanceFunctionFilter as SKGovernanceFilter, GovernedSemanticKernel, SemanticKernelWrapper, ) +from agent_os.integrations.smolagents_adapter import ( + GovernanceStepCallback as SmolagentsGovernanceCallback, + SmolagentsKernel, +) from .base import ( AsyncGovernedWrapper, @@ -172,7 +184,7 @@ # Anthropic Claude "AnthropicKernel", "GovernedAnthropicClient", - # Google Gemini + "AnthropicGovernanceHook", "GeminiKernel", "GovernedGeminiModel", # Mistral AI @@ -181,6 +193,7 @@ # Semantic Kernel "SemanticKernelWrapper", "GovernedSemanticKernel", + "SKGovernanceFilter", # Guardrails "GuardrailsKernel", # Google ADK @@ -204,6 +217,10 @@ "OffensiveIntentDetector", # PydanticAI "PydanticAIKernel", + "PydanticAIGovernanceCapability", + # Smolagents + "SmolagentsKernel", + "SmolagentsGovernanceCallback", # Microsoft Agent Framework (MAF) "MAFGovernancePolicyMiddleware", "MAFCapabilityGuardMiddleware", diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/anthropic_adapter.py b/agent-governance-python/agent-os/src/agent_os/integrations/anthropic_adapter.py index 1529c4397..5b61a4932 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/anthropic_adapter.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/anthropic_adapter.py @@ -132,9 +132,34 @@ def __init__( self._start_time = time.monotonic() self._last_error: str | None = None - def wrap(self, client: Any) -> GovernedAnthropicClient: + def as_message_hook(self, *, name: str = "anthropic-governance") -> "GovernanceMessageHook": + """Create a ``GovernanceMessageHook`` for non-invasive integration. + + The hook governs ``messages.create()`` calls without wrapping or + proxying the Anthropic client. This is the **recommended** + integration pattern. + + Args: + name: Human-readable identifier for audit logging. + + Returns: + A ``GovernanceMessageHook`` instance. + + Example:: + + kernel = AnthropicKernel(policy=policy) + hook = kernel.as_message_hook() + response = hook.create(client, model="claude-sonnet-4-20250514", ...) + """ + return GovernanceMessageHook(self, name=name) + + def wrap(self, client: Any) -> "GovernedAnthropicClient": """Wrap an Anthropic client with governance. + .. deprecated:: + Use :meth:`as_message_hook` instead for a non-invasive + integration that does not proxy the client object. + Args: client: An ``anthropic.Anthropic`` client instance. @@ -142,6 +167,13 @@ def wrap(self, client: Any) -> GovernedAnthropicClient: A ``GovernedAnthropicClient`` that enforces policy on all ``messages.create()`` calls. """ + import warnings + warnings.warn( + "AnthropicKernel.wrap() is deprecated. Use as_message_hook() " + "for a non-invasive governance pattern that doesn't proxy the client.", + DeprecationWarning, + stacklevel=2, + ) _check_anthropic_available() client_id = id(client) ctx = AnthropicContext( @@ -402,12 +434,176 @@ def __getattr__(self, name: str) -> Any: return getattr(self._client, name) +# ═══════════════════════════════════════════════════════════════════ +# Native Hook: GovernanceMessageHook +# ═══════════════════════════════════════════════════════════════════ +# +# Anthropic's Python SDK does not expose a formal middleware/plugin +# system. However, the recommended integration pattern is a +# composable "message hook" that wraps messages.create() calls +# with governance checks — without creating a proxy client object. +# +# Usage: +# kernel = AnthropicKernel(policy=policy) +# hook = kernel.as_message_hook() +# +# # Use the hook to govern individual calls +# response = hook.create(client, model="claude-sonnet-4-20250514", ...) +# ═══════════════════════════════════════════════════════════════════ + + +class GovernanceMessageHook: + """Stateless governance hook for Anthropic ``messages.create()`` calls. + + Unlike ``GovernedAnthropicClient``, this does **not** wrap or proxy the + client object. Instead, it provides a ``create()`` method that governs + a single ``messages.create()`` invocation on any client you pass in. + + This is the recommended integration pattern for Anthropic because the + SDK does not expose a native plugin/middleware system. + + Example:: + + kernel = AnthropicKernel(policy=GovernancePolicy( + blocked_patterns=["password"], + allowed_tools=["web_search"], + )) + hook = kernel.as_message_hook() + + response = hook.create(client, model="claude-sonnet-4-20250514", + max_tokens=1024, messages=[...]) + """ + + def __init__(self, kernel: AnthropicKernel, *, name: str = "anthropic-governance") -> None: + self._kernel = kernel + self._name = name + self._ctx = AnthropicContext( + agent_id=name, + session_id=f"ant-hook-{int(time.time())}", + policy=kernel.policy, + ) + kernel.contexts[name] = self._ctx + + @property + def kernel(self) -> AnthropicKernel: + """Return the governing kernel.""" + return self._kernel + + @property + def context(self) -> AnthropicContext: + """Return the execution context.""" + return self._ctx + + def create(self, client: Any, **kwargs: Any) -> Any: + """Govern a single ``messages.create()`` call. + + Validates message content against blocked patterns, enforces + tool-call allowlists, checks token limits after completion, + and records an audit trail — all without mutating the client. + + Args: + client: An ``anthropic.Anthropic`` client instance. + **kwargs: Forwarded to ``client.messages.create()``. + + Returns: + The Anthropic message response. + + Raises: + PolicyViolationError: If a governance policy is violated. + """ + # --- pre-execution checks --- + messages = kwargs.get("messages", []) + for msg in messages: + content = msg.get("content", "") if isinstance(msg, dict) else str(msg) + allowed, reason = self._kernel.pre_execute(self._ctx, content) + if not allowed: + raise PolicyViolationError(f"Message blocked: {reason}") + + # Validate requested tools against policy + tools = kwargs.get("tools") + if tools and self._kernel.policy.allowed_tools: + for tool in tools: + name = tool.get("name") if isinstance(tool, dict) else getattr(tool, "name", None) + if name and name not in self._kernel.policy.allowed_tools: + raise PolicyViolationError(f"Tool not allowed: {name}") + + # Enforce max_tokens cap from policy + requested_max = kwargs.get("max_tokens", 0) + if requested_max > self._kernel.policy.max_tokens: + raise PolicyViolationError( + f"Requested max_tokens ({requested_max}) exceeds policy limit " + f"({self._kernel.policy.max_tokens})" + ) + + # Audit log + logger.info( + "Anthropic hook.create | agent=%s model=%s", + self._name, + kwargs.get("model", "unknown"), + ) + + # --- execute --- + response = client.messages.create(**kwargs) + + # --- post-execution checks --- + response_id = getattr(response, "id", f"msg-{int(time.time())}") + self._ctx.message_ids.append(response_id) + + # Track tokens + usage = getattr(response, "usage", None) + if usage: + self._ctx.prompt_tokens += getattr(usage, "input_tokens", 0) + self._ctx.completion_tokens += getattr(usage, "output_tokens", 0) + + total = self._ctx.prompt_tokens + self._ctx.completion_tokens + if total > self._kernel.policy.max_tokens: + raise PolicyViolationError( + f"Token limit exceeded: {total} > {self._kernel.policy.max_tokens}" + ) + + # Validate tool_use blocks in response + content_blocks = getattr(response, "content", []) + for block in content_blocks: + if getattr(block, "type", None) == "tool_use": + tool_name = getattr(block, "name", "") + self._ctx.tool_use_calls.append({ + "id": getattr(block, "id", ""), + "name": tool_name, + "input": getattr(block, "input", {}), + "timestamp": datetime.now().isoformat(), + }) + self._ctx.tool_calls.append({"name": tool_name}) + + if len(self._ctx.tool_calls) > self._kernel.policy.max_tool_calls: + raise PolicyViolationError( + f"Tool call limit exceeded: " + f"{len(self._ctx.tool_calls)} > " + f"{self._kernel.policy.max_tool_calls}" + ) + + if self._kernel.policy.allowed_tools: + if tool_name not in self._kernel.policy.allowed_tools: + raise PolicyViolationError(f"Tool not allowed: {tool_name}") + + # Post-execute bookkeeping + self._kernel.post_execute(self._ctx, response) + + return response + + def __repr__(self) -> str: + return f"GovernanceMessageHook(name={self._name!r})" + + def wrap_client( client: Any, policy: GovernancePolicy | None = None, ) -> GovernedAnthropicClient: """Quick wrapper for Anthropic clients. + .. deprecated:: + Use ``AnthropicKernel.as_message_hook()`` instead for a + non-invasive integration that does not proxy the client. + Args: client: An ``anthropic.Anthropic`` client instance. policy: Optional governance policy. @@ -420,4 +616,12 @@ def wrap_client( >>> governed = wrap_client(my_client) >>> response = governed.messages.create(model="claude-sonnet-4-20250514", ...) """ + import warnings + warnings.warn( + "wrap_client() is deprecated. Use AnthropicKernel(policy=...).as_message_hook() " + "for a non-invasive governance pattern that doesn't proxy the client.", + DeprecationWarning, + stacklevel=2, + ) return AnthropicKernel(policy=policy).wrap(client) + diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/pydantic_ai_adapter.py b/agent-governance-python/agent-os/src/agent_os/integrations/pydantic_ai_adapter.py index b8a9fa47f..3faab879f 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/pydantic_ai_adapter.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/pydantic_ai_adapter.py @@ -126,9 +126,35 @@ def _record_audit( self._audit_log.append(entry) return entry - def wrap(self, agent: Any) -> Any: + def as_capability(self) -> "GovernanceCapability": + """Create a ``GovernanceCapability`` for PydanticAI's native hook system. + + Returns a capability that can be passed to the ``Agent`` constructor's + ``capabilities=`` parameter:: + + kernel = PydanticAIKernel(policy=policy) + capability = kernel.as_capability() + + agent = Agent( + "openai:gpt-4o", + capabilities=[capability], + ) + + This is the **recommended** integration pattern for PydanticAI + because it uses the framework's native ``Hooks``/``Capability`` + system instead of monkey-patching tool functions. + + Returns: + A ``GovernanceCapability`` instance. """ - Wrap a PydanticAI Agent with governance. + return GovernanceCapability(self) + + def wrap(self, agent: Any) -> Any: + """Wrap a PydanticAI Agent with governance. + + .. deprecated:: + Use :meth:`as_capability` with ``capabilities=`` instead + for a non-invasive integration. Intercepts: - agent.run() / agent.run_sync() @@ -141,6 +167,14 @@ def wrap(self, agent: Any) -> Any: Returns: A governed wrapper around the agent. """ + import warnings + warnings.warn( + "PydanticAIKernel.wrap() is deprecated. Use as_capability() " + "with Agent(capabilities=[kernel.as_capability()]) " + "for a non-invasive integration.", + DeprecationWarning, + stacklevel=2, + ) agent_id = getattr(agent, "name", None) or f"agent-{id(agent)}" ctx = self.create_context(agent_id) self._wrapped_agents[id(agent)] = agent @@ -409,13 +443,199 @@ def governed_fn(*args: Any, **kwargs: Any) -> Any: # Convenience function def wrap(agent: Any, policy: GovernancePolicy | None = None, **kwargs) -> Any: - """Quick wrapper for PydanticAI agents.""" - return PydanticAIKernel(policy, **kwargs).wrap(agent) + """Quick wrapper for PydanticAI agents. + + .. deprecated:: + Use ``PydanticAIKernel.as_capability()`` with + ``Agent(capabilities=[...])`` instead. + """ + import warnings + warnings.warn( + "wrap() is deprecated. Use PydanticAIKernel(policy=...).as_capability() " + "with Agent(capabilities=[...]) instead.", + DeprecationWarning, + stacklevel=2, + ) + kernel = PydanticAIKernel(policy, **kwargs) + # Suppress nested deprecation from kernel.wrap() + import contextlib + with contextlib.suppress(Exception), warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return kernel.wrap(agent) + + +# ═══════════════════════════════════════════════════════════════════ +# Native Hook: GovernanceCapability +# ═══════════════════════════════════════════════════════════════════ +# +# PydanticAI provides a Hooks/Capability system for composable, +# non-invasive lifecycle hooks. GovernanceCapability implements +# the key hooks: +# - before_tool_execute: tool allowlist / blocklist / pattern check +# - after_tool_execute: post-execution audit +# - before_run: pre-execution content scanning +# - after_run: post-execution drift detection +# +# Usage: +# kernel = PydanticAIKernel(policy=policy) +# agent = Agent("openai:gpt-4o", capabilities=[kernel.as_capability()]) +# ═══════════════════════════════════════════════════════════════════ + + +class GovernanceCapability: + """Governance capability for PydanticAI's native hook system. + + Implements the PydanticAI capability/hooks protocol, providing + governance checks at key lifecycle points: + + - ``before_tool_execute``: Validates tool name against + ``allowed_tools``, scans arguments for ``blocked_patterns``, + enforces ``max_tool_calls``. + - ``after_tool_execute``: Records audit entries. + - ``before_run``: Scans prompt for blocked patterns. + - ``after_run``: Runs post-execute drift detection. + + Example:: + + kernel = PydanticAIKernel(policy=GovernancePolicy( + allowed_tools=["search", "read_file"], + blocked_patterns=["rm -rf"], + max_tool_calls=10, + )) + capability = kernel.as_capability() + + agent = Agent( + "openai:gpt-4o", + capabilities=[capability], + ) + """ + + def __init__(self, kernel: PydanticAIKernel) -> None: + self._kernel = kernel + self._ctx = kernel.create_context("pydantic-ai-hooks") + self._tool_call_count: int = 0 + self._audit: list[dict[str, Any]] = [] + + @property + def kernel(self) -> PydanticAIKernel: + """Return the governing kernel.""" + return self._kernel + + @property + def context(self) -> ExecutionContext: + """Return the execution context.""" + return self._ctx + + @property + def audit_log(self) -> list[dict[str, Any]]: + """Return the audit log.""" + return list(self._audit) + + def before_run(self, prompt: str, **kwargs: Any) -> str: + """Pre-run hook: scan prompt for governance violations. + + Args: + prompt: The user prompt to validate. + **kwargs: Additional run context. + + Returns: + The prompt (unmodified). + + Raises: + PolicyViolationError: If the prompt violates policy. + """ + allowed, reason = self._kernel.pre_execute(self._ctx, prompt) + if not allowed: + self._audit.append({ + "event": "run_blocked", + "reason": reason, + }) + raise PolicyViolationError(reason or "Pre-execution check failed") + self._audit.append({"event": "run_start", "prompt_length": len(prompt)}) + return prompt + + def after_run(self, result: Any, **kwargs: Any) -> Any: + """Post-run hook: drift detection on result. + + Args: + result: The agent run result. + **kwargs: Additional run context. + + Returns: + The result (unmodified). + """ + self._audit.append({"event": "run_complete"}) + return result + + def before_tool_execute( + self, + tool_name: str, + arguments: dict[str, Any], + **kwargs: Any, + ) -> dict[str, Any]: + """Pre-tool hook: validate tool call against governance policy. + + Args: + tool_name: Name of the tool being called. + arguments: Tool call arguments. + **kwargs: Additional context. + + Returns: + The arguments (unmodified). + + Raises: + PolicyViolationError: If the tool call violates policy. + """ + result = self._kernel.intercept_tool_call(self._ctx, tool_name, arguments) + if not result.allowed: + self._audit.append({ + "event": "tool_blocked", + "tool": tool_name, + "reason": result.reason, + }) + raise PolicyViolationError( + result.reason or f"Tool '{tool_name}' blocked by policy" + ) + + self._tool_call_count += 1 + self._ctx.call_count += 1 + self._audit.append({ + "event": "tool_allowed", + "tool": tool_name, + "call_number": self._tool_call_count, + }) + return arguments + + def after_tool_execute( + self, + tool_name: str, + result: Any, + **kwargs: Any, + ) -> Any: + """Post-tool hook: audit the tool execution result. + + Args: + tool_name: Name of the tool that was called. + result: The tool's return value. + **kwargs: Additional context. + + Returns: + The result (unmodified). + """ + self._audit.append({ + "event": "tool_executed", + "tool": tool_name, + }) + return result + + def __repr__(self) -> str: + return f"GovernanceCapability(calls={self._tool_call_count})" __all__ = [ "PydanticAIKernel", "HumanApprovalRequired", + "GovernanceCapability", "HAS_PYDANTIC_AI", "wrap", ] diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/semantic_kernel_adapter.py b/agent-governance-python/agent-os/src/agent_os/integrations/semantic_kernel_adapter.py index 5b0c7f8b3..02acbb9b5 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/semantic_kernel_adapter.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/semantic_kernel_adapter.py @@ -113,9 +113,29 @@ def __init__( self._start_time = time.monotonic() self._last_error: Optional[str] = None - def wrap(self, kernel: Any) -> "GovernedSemanticKernel": + def as_filter(self) -> "GovernanceFunctionFilter": + """Create a governance filter for Semantic Kernel's native filter system. + + Returns a ``GovernanceFunctionFilter`` that can be registered with:: + + kernel.add_filter("auto_function_invocation", wrapper.as_filter()) + kernel.add_filter("function_invocation", wrapper.as_filter()) + + This is the **recommended** integration pattern for Semantic Kernel + as it uses the framework's native ``add_filter()`` API instead of + proxying the kernel object. + + Returns: + A ``GovernanceFunctionFilter`` instance. """ - Wrap a Semantic Kernel with governance. + return GovernanceFunctionFilter(self) + + def wrap(self, kernel: Any) -> "GovernedSemanticKernel": + """Wrap a Semantic Kernel with governance. + + .. deprecated:: + Use :meth:`as_filter` with ``kernel.add_filter()`` instead + for a non-invasive integration. Args: kernel: Semantic Kernel instance @@ -123,6 +143,14 @@ def wrap(self, kernel: Any) -> "GovernedSemanticKernel": Returns: GovernedSemanticKernel with full governance """ + import warnings + warnings.warn( + "SemanticKernelWrapper.wrap() is deprecated. Use as_filter() with " + "kernel.add_filter('auto_function_invocation', wrapper.as_filter()) " + "for a non-invasive integration.", + DeprecationWarning, + stacklevel=2, + ) kernel_id = f"sk-{id(kernel)}" ctx = SKContext( agent_id=kernel_id, @@ -758,8 +786,11 @@ def wrap_kernel( policy: Optional[GovernancePolicy] = None, timeout_seconds: float = 300.0, ) -> GovernedSemanticKernel: - """ - Quick wrapper for Semantic Kernel. + """Quick wrapper for Semantic Kernel. + + .. deprecated:: + Use ``SemanticKernelWrapper.as_filter()`` with + ``kernel.add_filter()`` instead. Example: from agent_os.integrations.semantic_kernel_adapter import wrap_kernel @@ -767,6 +798,157 @@ def wrap_kernel( governed = wrap_kernel(my_kernel) result = await governed.invoke("plugin", "function") """ - return SemanticKernelWrapper( - policy=policy, timeout_seconds=timeout_seconds - ).wrap(kernel) + import warnings + warnings.warn( + "wrap_kernel() is deprecated. Use SemanticKernelWrapper(policy=...).as_filter() " + "with kernel.add_filter('auto_function_invocation', ...) instead.", + DeprecationWarning, + stacklevel=2, + ) + wrapper = SemanticKernelWrapper(policy=policy, timeout_seconds=timeout_seconds) + # Suppress the deprecation from wrap() since we already emitted one + import contextlib + with contextlib.suppress(Exception), warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return wrapper.wrap(kernel) + + +# ═══════════════════════════════════════════════════════════════════ +# Native Hook: GovernanceFunctionFilter +# ═══════════════════════════════════════════════════════════════════ +# +# Semantic Kernel provides kernel.add_filter() for registering +# function invocation and auto-function-invocation filters. +# GovernanceFunctionFilter implements the filter protocol: +# +# async def __call__(self, context, next): +# ... +# await next(context) +# ... +# +# Usage: +# wrapper = SemanticKernelWrapper(policy=policy) +# sk_kernel.add_filter("auto_function_invocation", wrapper.as_filter()) +# sk_kernel.add_filter("function_invocation", wrapper.as_filter()) +# ═══════════════════════════════════════════════════════════════════ + + +class GovernanceFunctionFilter: + """Governance filter for Semantic Kernel's native ``add_filter()`` system. + + Implements the SK filter protocol (``async __call__(context, next)``) + and intercepts function invocations for policy enforcement. + + The filter: + - Validates function names against ``allowed_tools`` + - Scans function arguments for ``blocked_patterns`` + - Enforces ``max_tool_calls`` limits + - Runs Cedar/OPA ``pre_execute`` checks + - Runs ``post_execute`` drift detection on results + + Example:: + + wrapper = SemanticKernelWrapper(policy=GovernancePolicy( + allowed_tools=["MyPlugin.safe_func"], + blocked_patterns=["DROP TABLE"], + )) + governance_filter = wrapper.as_filter() + + sk_kernel.add_filter("auto_function_invocation", governance_filter) + sk_kernel.add_filter("function_invocation", governance_filter) + """ + + def __init__(self, wrapper: SemanticKernelWrapper) -> None: + self._wrapper = wrapper + self._ctx = SKContext( + agent_id="sk-filter", + session_id=f"sk-filter-{int(datetime.now().timestamp())}", + policy=wrapper.policy, + kernel_id="sk-filter", + ) + wrapper._contexts["sk-filter"] = self._ctx + + @property + def wrapper(self) -> SemanticKernelWrapper: + """Return the parent ``SemanticKernelWrapper``.""" + return self._wrapper + + @property + def context(self) -> SKContext: + """Return the execution context.""" + return self._ctx + + async def __call__(self, context: Any, next: Any) -> None: + """Filter protocol implementation for Semantic Kernel. + + Called by the SK runtime before/after each function invocation. + Validates the function against governance policy, calls ``next()`` + to proceed, then runs post-execution checks. + + Args: + context: SK's ``FunctionInvocationContext`` or + ``AutoFunctionInvocationContext``. + next: Async callable to continue the filter chain or execute + the function. + + Raises: + PolicyViolationError: If the function violates governance policy. + """ + # Extract function identity + func = getattr(context, "function", None) + func_name = getattr(func, "name", None) or "unknown" + plugin_name = getattr(func, "plugin_name", None) or "" + full_name = f"{plugin_name}.{func_name}" if plugin_name else func_name + + # Record invocation + self._ctx.functions_invoked.append({ + "function": full_name, + "timestamp": datetime.now().isoformat(), + }) + + # Check allowed_tools + if self._wrapper.policy.allowed_tools: + if full_name not in self._wrapper.policy.allowed_tools: + wildcard = f"{plugin_name}.*" if plugin_name else None + if not wildcard or wildcard not in self._wrapper.policy.allowed_tools: + raise PolicyViolationError(f"Function not allowed: {full_name}") + + # Check blocked patterns in arguments + args = getattr(context, "arguments", None) + if args: + args_str = str(args) + for pattern in self._wrapper.policy.blocked_patterns: + pat = pattern if isinstance(pattern, str) else pattern[0] + if pat.lower() in args_str.lower(): + raise PolicyViolationError( + f"Blocked pattern '{pat}' in arguments for {full_name}" + ) + + # Check call count + self._ctx.call_count += 1 + if self._ctx.call_count > self._wrapper.policy.max_tool_calls: + raise PolicyViolationError( + f"Tool call limit exceeded: " + f"{self._ctx.call_count} > {self._wrapper.policy.max_tool_calls}" + ) + + # Pre-execute (Cedar/OPA) + allowed, reason = self._wrapper.pre_execute(self._ctx, { + "function": full_name, + "arguments": str(args) if args else "", + }) + if not allowed: + raise PolicyViolationError(f"Invocation blocked: {reason}") + + # Proceed with execution + await next(context) + + # Post-execute drift detection + result = getattr(context, "result", None) + valid, reason = self._wrapper.post_execute(self._ctx, result) + if not valid: + raise PolicyViolationError(f"Result blocked: {reason}") + + def __repr__(self) -> str: + return "GovernanceFunctionFilter(wrapper=SemanticKernelWrapper)" + diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/smolagents_adapter.py b/agent-governance-python/agent-os/src/agent_os/integrations/smolagents_adapter.py index 9125117fe..e93b34bf2 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/smolagents_adapter.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/smolagents_adapter.py @@ -186,9 +186,36 @@ def __init__( # BaseIntegration abstract methods # ------------------------------------------------------------------ - def wrap(self, agent: Any) -> Any: + def as_step_callback(self) -> "GovernanceStepCallback": + """Create a governance callback for smolagents' native ``step_callbacks``. + + Returns a ``GovernanceStepCallback`` that can be passed directly to + a smolagents agent's ``step_callbacks`` list:: + + kernel = SmolagentsKernel(policy=config) + callback = kernel.as_step_callback() + + agent = CodeAgent( + tools=[...], + model=model, + step_callbacks=[callback], + ) + + This is the **recommended** integration pattern for smolagents, + as it uses the framework's native callback system instead of + monkey-patching tool ``forward`` methods. + + Returns: + A ``GovernanceStepCallback`` instance. """ - Wrap a smolagents agent with governance. + return GovernanceStepCallback(self) + + def wrap(self, agent: Any) -> Any: + """Wrap a smolagents agent with governance. + + .. deprecated:: + Use :meth:`as_step_callback` with ``step_callbacks=`` instead + for a non-invasive integration. Intercepts each tool's ``forward`` method so that every tool call passes through policy checks before execution. The agent's @@ -197,6 +224,14 @@ def wrap(self, agent: Any) -> Any: Works without smolagents installed (for testing with mocks). """ + import warnings + warnings.warn( + "SmolagentsKernel.wrap() is deprecated. Use as_step_callback() " + "with agent = Agent(step_callbacks=[kernel.as_step_callback()]) " + "for a non-invasive integration.", + DeprecationWarning, + stacklevel=2, + ) agent_name = getattr(agent, "name", None) or str(id(agent)) # smolagents stores tools in agent.toolbox (dict-like or has .tools) @@ -620,11 +655,172 @@ def health_check(self) -> dict[str, Any]: } +# ═══════════════════════════════════════════════════════════════════ +# Native Hook: GovernanceStepCallback +# ═══════════════════════════════════════════════════════════════════ +# +# smolagents provides ``step_callbacks`` — a list of callables +# invoked after each agent step with (step, agent) signature. +# GovernanceStepCallback implements this protocol. +# +# Usage: +# kernel = SmolagentsKernel(policy=config) +# agent = CodeAgent( +# tools=[...], model=model, +# step_callbacks=[kernel.as_step_callback()], +# ) +# ═══════════════════════════════════════════════════════════════════ + + +class GovernanceStepCallback: + """Governance callback for smolagents' native ``step_callbacks`` system. + + Implements the smolagents step-callback protocol + (``__call__(step, agent)``) and inspects each completed step for + governance violations. + + The callback: + - Validates tool names in ``step.tool_calls`` against ``allowed_tools`` + and ``blocked_tools`` + - Scans tool arguments and observations for ``blocked_patterns`` + - Enforces ``max_tool_calls`` limits + - Records an audit trail for every step + + Example:: + + kernel = SmolagentsKernel( + allowed_tools=["web_search"], + blocked_patterns=["DROP TABLE"], + ) + callback = kernel.as_step_callback() + + agent = CodeAgent( + tools=[web_search_tool], + model=model, + step_callbacks=[callback], + ) + """ + + def __init__(self, kernel: SmolagentsKernel) -> None: + self._kernel = kernel + self._step_count: int = 0 + + @property + def kernel(self) -> SmolagentsKernel: + """Return the governing kernel.""" + return self._kernel + + @property + def step_count(self) -> int: + """Return the number of steps processed.""" + return self._step_count + + def __call__(self, step: Any, agent: Any) -> None: + """Step-callback protocol implementation for smolagents. + + Called by the smolagents runtime after each agent step completes. + Inspects the step for tool calls and validates them against the + governance policy. + + Args: + step: A ``smolagents.MemoryStep`` (or similar) containing + step details such as ``tool_calls`` or ``action``. + agent: The smolagents agent instance. + + Raises: + PolicyViolationError: If the step violates governance policy. + """ + self._step_count += 1 + agent_name = getattr(agent, "name", None) or str(id(agent)) + config = self._kernel._sm_config + + # Extract tool calls from the step + tool_calls = getattr(step, "tool_calls", None) or [] + action = getattr(step, "action", None) + observation = getattr(step, "observation", None) + + # If the step has an action with a tool call + if action and hasattr(action, "tool_name"): + tool_calls = [action] + + for tc in tool_calls: + tool_name = getattr(tc, "tool_name", None) or getattr(tc, "name", str(tc)) + tool_args = getattr(tc, "tool_arguments", None) or getattr(tc, "arguments", {}) + + # Blocked tools + if tool_name in config.blocked_tools: + self._kernel._record( + "tool_blocked", agent_name, + {"tool": tool_name, "reason": "blocked_tool"}, + ) + raise PolicyViolationError( + "blocked_tool", + f"Tool '{tool_name}' is explicitly blocked", + ) + + # Allowed tools check + if config.allowed_tools and tool_name not in config.allowed_tools: + self._kernel._record( + "tool_blocked", agent_name, + {"tool": tool_name, "reason": "not_in_allowlist"}, + ) + raise PolicyViolationError( + "tool_not_allowed", + f"Tool '{tool_name}' is not in the allowed list", + ) + + # Blocked patterns in arguments + args_str = str(tool_args) + for pattern in config.blocked_patterns: + if pattern.lower() in args_str.lower(): + self._kernel._record( + "pattern_blocked", agent_name, + {"tool": tool_name, "pattern": pattern}, + ) + raise PolicyViolationError( + "blocked_pattern", + f"Blocked pattern '{pattern}' in arguments for '{tool_name}'", + ) + + # Increment and check call count + self._kernel._tool_call_count += 1 + if self._kernel._tool_call_count > config.max_tool_calls: + raise PolicyViolationError( + "max_tool_calls_exceeded", + f"Tool call limit exceeded: " + f"{self._kernel._tool_call_count} > {config.max_tool_calls}", + ) + + # Audit + self._kernel._record( + "tool_executed", agent_name, + {"tool": tool_name, "step": self._step_count}, + ) + + # Scan observation for blocked patterns + if observation: + obs_str = str(observation) + for pattern in config.blocked_patterns: + if pattern.lower() in obs_str.lower(): + self._kernel._record( + "observation_blocked", agent_name, + {"pattern": pattern, "step": self._step_count}, + ) + raise PolicyViolationError( + "blocked_pattern_in_observation", + f"Blocked pattern '{pattern}' in step observation", + ) + + def __repr__(self) -> str: + return f"GovernanceStepCallback(steps={self._step_count})" + + __all__ = [ "SmolagentsKernel", "PolicyConfig", "PolicyViolationError", "AuditEvent", + "GovernanceStepCallback", "_HAS_SMOLAGENTS", "_check_smolagents_available", ] diff --git a/agent-governance-python/agent-os/tests/test_anthropic_hooks.py b/agent-governance-python/agent-os/tests/test_anthropic_hooks.py new file mode 100644 index 000000000..b9081554e --- /dev/null +++ b/agent-governance-python/agent-os/tests/test_anthropic_hooks.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Tests for Anthropic native governance hooks (GovernanceMessageHook). + +Validates: +- GovernanceMessageHook creation via as_message_hook() +- Message content scanning against blocked_patterns +- Tool allowlist enforcement (pre-call and response) +- Token limit enforcement +- Tool call count limits +- Audit trail recording +- Deprecation warnings on wrap() and wrap_client() +""" + +import warnings +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from agent_os.integrations.anthropic_adapter import ( + AnthropicKernel, + GovernanceMessageHook, + wrap_client, +) +from agent_os.integrations.base import GovernancePolicy + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def policy(): + """Create a governance policy for testing.""" + return GovernancePolicy( + max_tool_calls=5, + max_tokens=1000, + allowed_tools=["web_search", "read_file"], + blocked_patterns=["password", "secret_key"], + ) + + +@pytest.fixture +def kernel(policy): + """Create an AnthropicKernel with test policy.""" + return AnthropicKernel(policy=policy) + + +@pytest.fixture +def hook(kernel): + """Create a GovernanceMessageHook from the kernel.""" + return kernel.as_message_hook() + + +@pytest.fixture +def mock_client(): + """Create a mock Anthropic client.""" + client = MagicMock() + response = SimpleNamespace( + id="msg-test-123", + content=[], + usage=SimpleNamespace(input_tokens=50, output_tokens=100), + ) + client.messages.create.return_value = response + return client + + +# ── as_message_hook() factory ───────────────────────────────────── + + +class TestAsMessageHook: + """Tests for the as_message_hook() factory method.""" + + def test_returns_governance_message_hook(self, kernel): + hook = kernel.as_message_hook() + assert isinstance(hook, GovernanceMessageHook) + + def test_custom_name(self, kernel): + hook = kernel.as_message_hook(name="my-hook") + assert hook._name == "my-hook" + assert "my-hook" in repr(hook) + + def test_context_registered(self, kernel): + hook = kernel.as_message_hook(name="test-ctx") + assert "test-ctx" in kernel.contexts + + def test_hook_has_kernel_reference(self, kernel): + hook = kernel.as_message_hook() + assert hook.kernel is kernel + + +# ── Pre-execution checks ───────────────────────────────────────── + + +class TestPreExecutionChecks: + """Tests for message content and tool validation before execution.""" + + def test_blocks_blocked_pattern_in_messages(self, hook, mock_client): + with pytest.raises(Exception, match="Message blocked"): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Tell me the password"}], + ) + + def test_blocks_disallowed_tool(self, hook, mock_client): + with pytest.raises(Exception, match="Tool not allowed"): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Hello"}], + tools=[{"name": "dangerous_exec", "description": "..."}], + ) + + def test_allows_approved_tools(self, hook, mock_client): + result = hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Hello"}], + tools=[{"name": "web_search", "description": "Search the web"}], + ) + assert result.id == "msg-test-123" + + def test_blocks_max_tokens_exceeding_policy(self, hook, mock_client): + with pytest.raises(Exception, match="max_tokens.*exceeds policy"): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=5000, + messages=[{"role": "user", "content": "Hello"}], + ) + + +# ── Post-execution checks ──────────────────────────────────────── + + +class TestPostExecutionChecks: + """Tests for token tracking and tool_use block validation.""" + + def test_tracks_tokens(self, hook, mock_client): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Hello"}], + ) + ctx = hook.context + assert ctx.prompt_tokens == 50 + assert ctx.completion_tokens == 100 + + def test_records_message_id(self, hook, mock_client): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Hello"}], + ) + assert "msg-test-123" in hook.context.message_ids + + def test_blocks_disallowed_tool_in_response(self, hook, mock_client): + """Tool_use blocks in the response are validated against allowed_tools.""" + response = SimpleNamespace( + id="msg-test-456", + content=[ + SimpleNamespace( + type="tool_use", + id="call-1", + name="dangerous_exec", + input={"cmd": "rm -rf /"}, + ), + ], + usage=SimpleNamespace(input_tokens=10, output_tokens=20), + ) + mock_client.messages.create.return_value = response + + with pytest.raises(Exception, match="Tool not allowed.*dangerous_exec"): + hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Run command"}], + ) + + def test_enforces_token_limit_after_response(self, kernel): + """Cumulative token usage is checked after each response.""" + low_policy = GovernancePolicy(max_tokens=100) + k = AnthropicKernel(policy=low_policy) + hook = k.as_message_hook() + + client = MagicMock() + client.messages.create.return_value = SimpleNamespace( + id="msg-over", + content=[], + usage=SimpleNamespace(input_tokens=60, output_tokens=50), + ) + + with pytest.raises(Exception, match="Token limit exceeded"): + hook.create( + client, + model="claude-sonnet-4-20250514", + max_tokens=90, + messages=[{"role": "user", "content": "Hello"}], + ) + + +# ── Deprecation warnings ───────────────────────────────────────── + + +class TestDeprecationWarnings: + """Tests that legacy methods emit DeprecationWarning.""" + + def test_wrap_emits_deprecation(self, kernel, mock_client): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kernel.wrap(mock_client) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_message_hook" in str(deprecations[0].message) + + def test_wrap_client_emits_deprecation(self, mock_client): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + wrap_client(mock_client) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_message_hook" in str(deprecations[0].message) + + +# ── Clean messages pass through ─────────────────────────────────── + + +class TestCleanPassthrough: + """Tests that clean, valid messages pass through governance.""" + + def test_clean_message_succeeds(self, hook, mock_client): + result = hook.create( + mock_client, + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[{"role": "user", "content": "Hello, how are you?"}], + ) + assert result.id == "msg-test-123" + mock_client.messages.create.assert_called_once() diff --git a/agent-governance-python/agent-os/tests/test_pydantic_ai_hooks.py b/agent-governance-python/agent-os/tests/test_pydantic_ai_hooks.py new file mode 100644 index 000000000..0d7b9eb54 --- /dev/null +++ b/agent-governance-python/agent-os/tests/test_pydantic_ai_hooks.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Tests for PydanticAI native governance capability (GovernanceCapability). + +Validates: +- GovernanceCapability creation via as_capability() +- before_run: prompt content scanning +- before_tool_execute: tool allowlist, blocked patterns, call limits +- after_tool_execute: audit recording +- after_run: completion recording +- Deprecation warnings on wrap() +""" + +import warnings +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from agent_os.integrations.pydantic_ai_adapter import ( + GovernanceCapability, + PydanticAIKernel, + wrap, +) +from agent_os.integrations.base import GovernancePolicy, PolicyViolationError + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def policy(): + """Create a governance policy for testing.""" + return GovernancePolicy( + max_tool_calls=5, + allowed_tools=["search", "read_file"], + blocked_patterns=["DROP TABLE", "rm -rf"], + ) + + +@pytest.fixture +def kernel(policy): + """Create a PydanticAIKernel with test policy.""" + return PydanticAIKernel(policy=policy) + + +@pytest.fixture +def capability(kernel): + """Create a GovernanceCapability from the kernel.""" + return kernel.as_capability() + + +# ── as_capability() factory ────────────────────────────────────── + + +class TestAsCapability: + """Tests for the as_capability() factory method.""" + + def test_returns_governance_capability(self, kernel): + cap = kernel.as_capability() + assert isinstance(cap, GovernanceCapability) + + def test_capability_has_kernel_reference(self, kernel): + cap = kernel.as_capability() + assert cap.kernel is kernel + + def test_context_created(self, kernel): + cap = kernel.as_capability() + assert cap.context is not None + assert cap.context.agent_id == "pydantic-ai-hooks" + + +# ── before_run ──────────────────────────────────────────────────── + + +class TestBeforeRun: + """Tests for the before_run hook.""" + + def test_passes_clean_prompt(self, capability): + result = capability.before_run("Hello, how are you?") + assert result == "Hello, how are you?" + + def test_blocks_blocked_pattern_in_prompt(self, capability): + with pytest.raises(PolicyViolationError): + capability.before_run("Please DROP TABLE users") + + def test_records_audit_on_block(self, capability): + try: + capability.before_run("Please DROP TABLE users") + except PolicyViolationError: + pass + assert any(e["event"] == "run_blocked" for e in capability.audit_log) + + def test_records_audit_on_start(self, capability): + capability.before_run("Hello") + assert any(e["event"] == "run_start" for e in capability.audit_log) + + +# ── before_tool_execute ────────────────────────────────────────── + + +class TestBeforeToolExecute: + """Tests for the before_tool_execute hook.""" + + def test_allows_approved_tool(self, capability): + result = capability.before_tool_execute("search", {"query": "Python"}) + assert result == {"query": "Python"} + + def test_blocks_disallowed_tool(self, capability): + with pytest.raises(PolicyViolationError): + capability.before_tool_execute("dangerous_exec", {"cmd": "ls"}) + + def test_blocks_pattern_in_args(self, capability): + with pytest.raises(PolicyViolationError): + capability.before_tool_execute("search", {"query": "rm -rf /"}) + + def test_increments_call_count(self, capability): + capability.before_tool_execute("search", {"query": "test"}) + assert capability.context.call_count == 1 + + def test_enforces_max_tool_calls(self, capability): + for _ in range(5): + capability.before_tool_execute("search", {"query": "test"}) + + with pytest.raises(PolicyViolationError, match="exceeded"): + capability.before_tool_execute("search", {"query": "test"}) + + def test_records_audit_on_block(self, capability): + try: + capability.before_tool_execute("dangerous_exec", {}) + except PolicyViolationError: + pass + assert any( + e["event"] == "tool_blocked" and e["tool"] == "dangerous_exec" + for e in capability.audit_log + ) + + def test_records_audit_on_allow(self, capability): + capability.before_tool_execute("search", {"query": "test"}) + assert any( + e["event"] == "tool_allowed" and e["tool"] == "search" + for e in capability.audit_log + ) + + +# ── after_tool_execute ─────────────────────────────────────────── + + +class TestAfterToolExecute: + """Tests for the after_tool_execute hook.""" + + def test_returns_result_unchanged(self, capability): + result = capability.after_tool_execute("search", {"data": [1, 2, 3]}) + assert result == {"data": [1, 2, 3]} + + def test_records_audit(self, capability): + capability.after_tool_execute("search", "result") + assert any( + e["event"] == "tool_executed" and e["tool"] == "search" + for e in capability.audit_log + ) + + +# ── after_run ───────────────────────────────────────────────────── + + +class TestAfterRun: + """Tests for the after_run hook.""" + + def test_returns_result(self, capability): + result = capability.after_run("final result") + assert result == "final result" + + def test_records_audit(self, capability): + capability.after_run("result") + assert any(e["event"] == "run_complete" for e in capability.audit_log) + + +# ── Deprecation warnings ───────────────────────────────────────── + + +class TestDeprecationWarnings: + """Tests that legacy methods emit DeprecationWarning.""" + + def test_wrap_method_emits_deprecation(self, kernel): + mock_agent = MagicMock() + mock_agent.name = "test" + mock_agent._function_tools = [] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kernel.wrap(mock_agent) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_capability" in str(deprecations[0].message) + + def test_wrap_function_emits_deprecation(self): + mock_agent = MagicMock() + mock_agent.name = "test" + mock_agent._function_tools = [] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + wrap(mock_agent) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_capability" in str(deprecations[0].message) + + +# ── Repr ────────────────────────────────────────────────────────── + + +class TestRepr: + """Tests for GovernanceCapability repr.""" + + def test_repr(self, capability): + assert "GovernanceCapability" in repr(capability) + assert "calls=0" in repr(capability) diff --git a/agent-governance-python/agent-os/tests/test_semantic_kernel_hooks.py b/agent-governance-python/agent-os/tests/test_semantic_kernel_hooks.py new file mode 100644 index 000000000..868ec0541 --- /dev/null +++ b/agent-governance-python/agent-os/tests/test_semantic_kernel_hooks.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Tests for Semantic Kernel native governance filter (GovernanceFunctionFilter). + +Validates: +- GovernanceFunctionFilter creation via as_filter() +- Function allowlist enforcement +- Blocked pattern detection in arguments +- max_tool_calls enforcement +- Pre-execute (Cedar/OPA) gating +- Post-execute drift detection +- Deprecation warnings on wrap() and wrap_kernel() +""" + +import asyncio +import warnings +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agent_os.integrations.semantic_kernel_adapter import ( + GovernanceFunctionFilter, + GovernedSemanticKernel, + SemanticKernelWrapper, + wrap_kernel, +) +from agent_os.integrations.base import GovernancePolicy + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def policy(): + """Create a governance policy for testing.""" + return GovernancePolicy( + max_tool_calls=5, + allowed_tools=["MyPlugin.safe_func", "MyPlugin.*"], + blocked_patterns=["DROP TABLE", "rm -rf"], + ) + + +@pytest.fixture +def wrapper(policy): + """Create a SemanticKernelWrapper with test policy.""" + return SemanticKernelWrapper(policy=policy) + + +@pytest.fixture +def governance_filter(wrapper): + """Create a GovernanceFunctionFilter.""" + return wrapper.as_filter() + + +def _make_context(func_name="safe_func", plugin_name="MyPlugin", args=None): + """Create a mock SK function invocation context.""" + func = SimpleNamespace(name=func_name, plugin_name=plugin_name) + ctx = SimpleNamespace( + function=func, + arguments=args or {}, + result=None, + ) + return ctx + + +# ── as_filter() factory ────────────────────────────────────────── + + +class TestAsFilter: + """Tests for the as_filter() factory method.""" + + def test_returns_governance_filter(self, wrapper): + f = wrapper.as_filter() + assert isinstance(f, GovernanceFunctionFilter) + + def test_filter_registered_in_contexts(self, wrapper): + wrapper.as_filter() + assert "sk-filter" in wrapper._contexts + + def test_filter_has_wrapper_reference(self, wrapper): + f = wrapper.as_filter() + assert f.wrapper is wrapper + + +# ── Function allowlist ──────────────────────────────────────────── + + +class TestFunctionAllowlist: + """Tests for function name validation.""" + + def test_allows_exact_match(self, governance_filter): + ctx = _make_context("safe_func", "MyPlugin") + next_fn = AsyncMock() + + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + next_fn.assert_awaited_once_with(ctx) + + def test_allows_wildcard_match(self, governance_filter): + ctx = _make_context("any_func", "MyPlugin") + next_fn = AsyncMock() + + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + next_fn.assert_awaited_once() + + def test_blocks_disallowed_function(self, governance_filter): + ctx = _make_context("dangerous_func", "OtherPlugin") + next_fn = AsyncMock() + + with pytest.raises(Exception, match="Function not allowed"): + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + next_fn.assert_not_awaited() + + +# ── Blocked patterns ───────────────────────────────────────────── + + +class TestBlockedPatterns: + """Tests for blocked pattern detection in arguments.""" + + def test_blocks_pattern_in_args(self, governance_filter): + ctx = _make_context("safe_func", "MyPlugin", args={"query": "DROP TABLE users"}) + next_fn = AsyncMock() + + with pytest.raises(Exception, match="Blocked pattern"): + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + + def test_clean_args_pass(self, governance_filter): + ctx = _make_context("safe_func", "MyPlugin", args={"query": "SELECT * FROM users"}) + next_fn = AsyncMock() + + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + next_fn.assert_awaited_once() + + +# ── Call count enforcement ──────────────────────────────────────── + + +class TestCallCount: + """Tests for max_tool_calls enforcement.""" + + def test_enforces_max_tool_calls(self, governance_filter): + next_fn = AsyncMock() + + # Exhaust the call limit + for i in range(5): + ctx = _make_context("safe_func", "MyPlugin") + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + + # 6th call should be blocked + ctx = _make_context("safe_func", "MyPlugin") + with pytest.raises(Exception, match="Tool call limit exceeded"): + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + + def test_tracks_call_count(self, governance_filter): + next_fn = AsyncMock() + ctx = _make_context("safe_func", "MyPlugin") + + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + assert governance_filter.context.call_count == 1 + + +# ── Audit trail ─────────────────────────────────────────────────── + + +class TestAuditTrail: + """Tests for function invocation recording.""" + + def test_records_invocation(self, governance_filter): + next_fn = AsyncMock() + ctx = _make_context("safe_func", "MyPlugin") + + asyncio.get_event_loop().run_until_complete( + governance_filter(ctx, next_fn) + ) + assert len(governance_filter.context.functions_invoked) == 1 + assert governance_filter.context.functions_invoked[0]["function"] == "MyPlugin.safe_func" + + +# ── Deprecation warnings ───────────────────────────────────────── + + +class TestDeprecationWarnings: + """Tests that legacy methods emit DeprecationWarning.""" + + def test_wrap_emits_deprecation(self, wrapper): + mock_kernel = MagicMock() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + wrapper.wrap(mock_kernel) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_filter" in str(deprecations[0].message) + + def test_wrap_kernel_emits_deprecation(self): + mock_kernel = MagicMock() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + wrap_kernel(mock_kernel) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_filter" in str(deprecations[0].message) diff --git a/agent-governance-python/agent-os/tests/test_smolagents_hooks.py b/agent-governance-python/agent-os/tests/test_smolagents_hooks.py new file mode 100644 index 000000000..ece1e1a1f --- /dev/null +++ b/agent-governance-python/agent-os/tests/test_smolagents_hooks.py @@ -0,0 +1,226 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Tests for smolagents native governance callback (GovernanceStepCallback). + +Validates: +- GovernanceStepCallback creation via as_step_callback() +- Tool blocklist enforcement +- Tool allowlist enforcement +- Blocked pattern detection in arguments +- Blocked pattern detection in observations +- max_tool_calls enforcement +- Audit trail recording +- Deprecation warnings on wrap() +""" + +import warnings +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from agent_os.integrations.smolagents_adapter import ( + GovernanceStepCallback, + SmolagentsKernel, +) + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def kernel(): + """Create a SmolagentsKernel with test configuration.""" + return SmolagentsKernel( + max_tool_calls=5, + allowed_tools=["web_search", "read_file"], + blocked_tools=["exec_code", "shell"], + blocked_patterns=["DROP TABLE", "rm -rf"], + ) + + +@pytest.fixture +def callback(kernel): + """Create a GovernanceStepCallback from the kernel.""" + return kernel.as_step_callback() + + +@pytest.fixture +def mock_agent(): + """Create a mock smolagents agent.""" + return SimpleNamespace(name="test-agent") + + +def _make_step(tool_name=None, tool_args=None, observation=None): + """Create a mock smolagents step.""" + tool_calls = [] + action = None + if tool_name: + action = SimpleNamespace( + tool_name=tool_name, + tool_arguments=tool_args or {}, + ) + return SimpleNamespace( + tool_calls=tool_calls, + action=action, + observation=observation, + ) + + +# ── as_step_callback() factory ─────────────────────────────────── + + +class TestAsStepCallback: + """Tests for the as_step_callback() factory method.""" + + def test_returns_governance_callback(self, kernel): + cb = kernel.as_step_callback() + assert isinstance(cb, GovernanceStepCallback) + + def test_callback_has_kernel_reference(self, kernel): + cb = kernel.as_step_callback() + assert cb.kernel is kernel + + def test_initial_step_count_is_zero(self, kernel): + cb = kernel.as_step_callback() + assert cb.step_count == 0 + + +# ── Blocked tools ───────────────────────────────────────────────── + + +class TestBlockedTools: + """Tests for tool blocklist enforcement.""" + + def test_blocks_blocked_tool(self, callback, mock_agent): + step = _make_step(tool_name="exec_code") + with pytest.raises(Exception, match="explicitly blocked"): + callback(step, mock_agent) + + def test_blocks_shell_tool(self, callback, mock_agent): + step = _make_step(tool_name="shell") + with pytest.raises(Exception, match="explicitly blocked"): + callback(step, mock_agent) + + +# ── Allowed tools ───────────────────────────────────────────────── + + +class TestAllowedTools: + """Tests for tool allowlist enforcement.""" + + def test_allows_approved_tool(self, callback, mock_agent): + step = _make_step(tool_name="web_search") + callback(step, mock_agent) # Should not raise + assert callback.step_count == 1 + + def test_blocks_unapproved_tool(self, callback, mock_agent): + step = _make_step(tool_name="dangerous_tool") + with pytest.raises(Exception, match="not in the allowed list"): + callback(step, mock_agent) + + +# ── Blocked patterns ───────────────────────────────────────────── + + +class TestBlockedPatterns: + """Tests for blocked pattern detection.""" + + def test_blocks_pattern_in_args(self, callback, mock_agent): + step = _make_step( + tool_name="web_search", + tool_args={"query": "DROP TABLE users"}, + ) + with pytest.raises(Exception, match="Blocked pattern"): + callback(step, mock_agent) + + def test_blocks_pattern_in_observation(self, callback, mock_agent): + step = _make_step(observation="Result: rm -rf / completed") + callback(step, mock_agent) # Step without tool calls passes... + + # But a step with tool calls + blocked observation: + kernel2 = SmolagentsKernel( + allowed_tools=["web_search"], + blocked_patterns=["rm -rf"], + ) + cb2 = kernel2.as_step_callback() + step2 = SimpleNamespace( + tool_calls=[], + action=None, + observation="Dangerous output: rm -rf /", + ) + with pytest.raises(Exception, match="Blocked pattern.*observation"): + cb2(step2, mock_agent) + + def test_clean_args_pass(self, callback, mock_agent): + step = _make_step( + tool_name="web_search", + tool_args={"query": "SELECT * FROM users"}, + ) + callback(step, mock_agent) # Should not raise + assert callback.step_count == 1 + + +# ── Call count enforcement ──────────────────────────────────────── + + +class TestCallCount: + """Tests for max_tool_calls enforcement.""" + + def test_enforces_max_tool_calls(self, callback, mock_agent): + for _ in range(5): + step = _make_step(tool_name="web_search") + callback(step, mock_agent) + + # 6th call should be blocked + step = _make_step(tool_name="web_search") + with pytest.raises(Exception, match="Tool call limit exceeded"): + callback(step, mock_agent) + + +# ── Step counting ───────────────────────────────────────────────── + + +class TestStepCounting: + """Tests for step counting.""" + + def test_increments_step_count(self, callback, mock_agent): + # Step with no tool calls + step = _make_step() + callback(step, mock_agent) + assert callback.step_count == 1 + + # Step with a tool call + step = _make_step(tool_name="web_search") + callback(step, mock_agent) + assert callback.step_count == 2 + + +# ── Deprecation warnings ───────────────────────────────────────── + + +class TestDeprecationWarnings: + """Tests that legacy wrap() emits DeprecationWarning.""" + + def test_wrap_emits_deprecation(self, kernel): + mock_agent = SimpleNamespace( + name="test", + toolbox={"tool1": SimpleNamespace(forward=lambda *a, **k: None)}, + ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kernel.wrap(mock_agent) + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) >= 1 + assert "as_step_callback" in str(deprecations[0].message) + + +# ── Repr ────────────────────────────────────────────────────────── + + +class TestRepr: + """Tests for GovernanceStepCallback repr.""" + + def test_repr(self, callback): + assert "GovernanceStepCallback" in repr(callback) + assert "steps=0" in repr(callback)