diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py b/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py index 92b4b1926..ad4eca6c1 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/__init__.py @@ -71,7 +71,10 @@ PolicyConfig as ADKPolicyConfig, ) from agent_os.integrations.guardrails_adapter import GuardrailsKernel -from agent_os.integrations.langchain_adapter import LangChainKernel +from agent_os.integrations.langchain_adapter import ( + GovernanceMiddleware as LangChainGovernanceMiddleware, + LangChainKernel, +) try: from agent_os.integrations.maf_adapter import ( AuditTrailMiddleware as MAFAuditTrailMiddleware, diff --git a/agent-governance-python/agent-os/src/agent_os/integrations/langchain_adapter.py b/agent-governance-python/agent-os/src/agent_os/integrations/langchain_adapter.py index 3d5a07f34..1957faeff 100644 --- a/agent-governance-python/agent-os/src/agent_os/integrations/langchain_adapter.py +++ b/agent-governance-python/agent-os/src/agent_os/integrations/langchain_adapter.py @@ -1,18 +1,53 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """ -LangChain Integration +LangChain Integration for Agent-OS +========================================== -Wraps LangChain agents/chains with Agent OS governance. +Provides kernel-level governance for LangChain agents and chains using +the LangChain ``AgentMiddleware`` system. -Usage: - from agent_os.integrations import LangChainKernel - - kernel = LangChainKernel() - governed_chain = kernel.wrap(my_langchain_chain) +**Preferred (native middleware)**:: - # Now all invocations go through Agent OS - result = governed_chain.invoke({"input": "..."}) + from agent_os.integrations import LangChainKernel + from langchain.agents import create_agent + + kernel = LangChainKernel( + policy=GovernancePolicy(blocked_patterns=["DROP TABLE"]), + ) + + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) + result = agent.invoke({"messages": [...]}) + +**With Cedar/OPA policy evaluation**:: + + kernel = LangChainKernel.from_cedar("policies/governance.cedar") + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) + +**Legacy (deprecated — kept for backward compatibility)**:: + + governed = kernel.wrap(my_chain) + result = governed.invoke({"input": "..."}) + +Features +-------- +- Native ``AgentMiddleware`` integration (``wrap_tool_call``, ``wrap_model_call``) +- Inherits ``BaseIntegration`` — Cedar/OPA, ``pre_execute``, ``post_execute`` +- Tool allowlist/blocklist enforcement via ``wrap_tool_call`` +- Content filtering on model input/output and tool arguments +- PII / secrets detection in memory writes +- Full audit trail with event recording +- Health check endpoint +- Backward-compatible ``wrap()`` / deep hooks + (deprecated, will be removed in a future release) """ import asyncio @@ -20,6 +55,7 @@ import logging import re import time +import warnings from datetime import datetime from typing import Any, Optional @@ -27,6 +63,16 @@ logger = logging.getLogger("agent_os.langchain") + +# ── Graceful import of LangChain AgentMiddleware ────────────────────── +try: + from langchain.agents.middleware import AgentMiddleware as _SDKMiddleware # type: ignore[import-untyped] + + _HAS_MIDDLEWARE = True +except ImportError: + _SDKMiddleware = None + _HAS_MIDDLEWARE = False + # Patterns used to detect potential PII / secrets in memory writes _PII_PATTERNS = [ re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), # SSN @@ -36,15 +82,34 @@ class LangChainKernel(BaseIntegration): - """ - LangChain adapter for Agent OS. + """Governance kernel for LangChain agents and chains. + + Extends :class:`BaseIntegration` so that it inherits Cedar/OPA policy + evaluation, ``pre_execute`` / ``post_execute`` governance, and + ``from_cedar`` factory support. + + The primary integration path is via :meth:`as_middleware`, which returns + a :class:`GovernanceMiddleware` instance that can be passed directly to + ``create_agent(middleware=[...])``. Supports: - - Chains (invoke, ainvoke) - - Agents (run, arun) - - Runnables (invoke, batch, stream) + + - **Native middleware** (recommended): ``wrap_tool_call``, ``wrap_model_call`` + - **Legacy proxy** (deprecated): chains (``invoke``, ``ainvoke``), agents + (``run``), runnables (``batch``, ``stream``) - Deep hooks: tool registry interception, memory write validation, and sub-agent spawn detection (when ``deep_hooks_enabled`` is True). + + Example:: + + kernel = LangChainKernel( + policy=GovernancePolicy(blocked_patterns=["DROP TABLE"]), + ) + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) """ def __init__( @@ -320,28 +385,51 @@ def governed_invoke(input_data: Any, **kwargs: Any) -> Any: agent.invoke = governed_invoke agent._spawn_governed = True - # ── wrap / unwrap ───────────────────────────────────────────── + # ================================================================ + # Native AgentMiddleware Integration (PRIMARY API) + # ================================================================ - def wrap(self, agent: Any) -> Any: - """Wrap a LangChain chain, agent, or runnable with governance. + def as_middleware(self, name: str = "governance") -> "GovernanceMiddleware": + """Return a :class:`GovernanceMiddleware` backed by this kernel. - Creates a proxy object that intercepts all execution methods - (``invoke``, ``ainvoke``, ``run``, ``batch``, ``stream``) and - applies pre-/post-execution policy checks. + Pass the returned object to ``create_agent(middleware=[...])`` + or to any LangChain component that accepts middleware:: - When :attr:`deep_hooks_enabled` is ``True`` (the default) the - following additional hooks are applied: + kernel = LangChainKernel( + policy=GovernancePolicy(blocked_patterns=["DROP TABLE"]), + ) + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) - * **Tool registry interception** — each tool's ``_run`` / ``_arun`` - is wrapped with governance checks. - * **Memory write interception** — ``memory.save_context`` is - validated for PII and blocked patterns. - * **Sub-agent spawn detection** — ``invoke`` calls are monitored - for delegation depth. + Args: + name: Optional label for logging/identification. + + Returns: + A :class:`GovernanceMiddleware` instance. + """ + return GovernanceMiddleware(kernel=self, name=name) - The wrapping strategy uses a dynamically created inner class so that - attribute access for non-execution methods (e.g. ``name``, - ``verbose``) is transparently forwarded to the original object. + # ================================================================ + # Deprecated wrap()-Based API (BACKWARD COMPAT) + # ================================================================ + + def wrap(self, agent: Any) -> Any: + """Wrap a LangChain chain, agent, or runnable with governance. + + .. deprecated:: + Use :meth:`as_middleware` instead. The ``wrap()`` approach + creates a fragile proxy object that cannot intercept model-level + events and mutates tool/memory objects. Prefer the native + ``AgentMiddleware`` path:: + + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) Args: agent: Any LangChain-compatible object that exposes ``invoke``, @@ -354,14 +442,13 @@ def wrap(self, agent: Any) -> Any: Raises: PolicyViolationError: Raised at execution time if input or output violates the active policy. - - Example: - >>> kernel = LangChainKernel(policy=GovernancePolicy( - ... blocked_patterns=["DROP TABLE"] - ... )) - >>> governed = kernel.wrap(my_chain) - >>> result = governed.invoke({"input": "safe query"}) """ + warnings.warn( + "LangChainKernel.wrap() is deprecated. " + "Use kernel.as_middleware() with create_agent(middleware=[...]) instead.", + DeprecationWarning, + stacklevel=2, + ) # Get agent ID from the object agent_id = getattr(agent, 'name', None) or f"langchain-{id(agent)}" ctx = self.create_context(agent_id) @@ -634,7 +721,302 @@ class PolicyViolationError(Exception): pass -# Convenience function +# ===================================================================== +# GovernanceMiddleware (native AgentMiddleware) +# ===================================================================== + +# Build the base class dynamically so the module stays importable even +# when ``langchain`` is not installed. +_MiddlewareBase: type = _SDKMiddleware if _HAS_MIDDLEWARE else object + + +class GovernanceMiddleware(_MiddlewareBase): + """Native LangChain ``AgentMiddleware`` for Agent-OS governance. + + Acts as the **primary** integration surface between Agent-OS and the + LangChain framework. By implementing ``wrap_tool_call`` and + ``wrap_model_call``, the middleware can intercept every tool invocation + and every model request *without* proxy-wrapping or monkey-patching. + + Compared to the deprecated :meth:`LangChainKernel.wrap` approach: + + * No proxy objects — no ``__getattr__`` fragility. + * No object mutation — tools, memory, and agents are left untouched. + * **New capability**: model-level governance via ``wrap_model_call`` + (prompt injection guards, system prompt integrity, dynamic tool + filtering) which the proxy-based approach *cannot* achieve. + * Composable — multiple middleware instances stack naturally. + + Usage:: + + kernel = LangChainKernel( + policy=GovernancePolicy(blocked_patterns=["DROP TABLE"]), + ) + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) + result = agent.invoke({"messages": [...]}) + """ + + def __init__( + self, + kernel: "LangChainKernel", + name: str = "governance", + ): + """Initialise the governance middleware. + + Args: + kernel: The :class:`LangChainKernel` that supplies the active + governance policy and Cedar/OPA evaluator. + name: Label used in log messages and audit records. + """ + self._kernel = kernel + self._name = name + self._ctx = kernel.create_context(f"langchain-middleware-{name}") + logger.info( + "GovernanceMiddleware '%s' initialised with policy=%r", + name, + kernel.policy, + ) + + # ── wrap_tool_call ──────────────────────────────────────────── + # + # Intercepts every tool execution. Has full access to the tool + # name and arguments before execution, and the result after. + # Can BLOCK by raising PolicyViolationError. + + def wrap_tool_call(self, request: Any, handler: Any) -> Any: + """Governance gate around each tool execution. + + Performs the following checks **before** the tool runs: + + 1. Tool allowlist / blocklist enforcement. + 2. Blocked-pattern scan on tool arguments. + 3. Cedar/OPA ``pre_execute`` gate (when an evaluator is + configured on the kernel). + + After the tool completes, a ``post_execute`` check validates + the output against the governance policy. + + Args: + request: LangChain ``ToolCallRequest`` with ``tool_call`` + dict containing ``name``, ``args``, and ``id``. + handler: Callable that executes the actual tool. + + Returns: + The tool's result (``ToolMessage`` or ``Command``). + + Raises: + PolicyViolationError: If the tool call violates the + governance policy. + """ + tool_call = getattr(request, "tool_call", {}) + tool_name = tool_call.get("name", "") if isinstance(tool_call, dict) else str(tool_call) + tool_args = tool_call.get("args", {}) if isinstance(tool_call, dict) else {} + + logger.debug( + "[%s] wrap_tool_call: tool=%s args=%s", + self._name, + tool_name, + tool_args, + ) + + # ─── 1. Tool allowlist / blocklist ──────────────────────── + self._kernel._check_tool_policy( + tool_name, (tool_args,), {}, self._ctx + ) + + # ─── 2. Cedar/OPA pre_execute gate ──────────────────────── + input_data = { + "tool_name": tool_name, + "tool_args": tool_args, + } + allowed, reason = self._kernel.pre_execute(self._ctx, input_data) + if not allowed: + logger.info( + "[%s] Policy DENY (pre_execute) on tool '%s': %s", + self._name, + tool_name, + reason, + ) + raise PolicyViolationError(reason) + logger.info("[%s] Policy ALLOW on tool '%s'", self._name, tool_name) + + # ─── 3. Record invocation ───────────────────────────────── + self._kernel._record_tool_invocation(tool_name, (tool_args,), {}) + + # ─── 4. Execute the tool ────────────────────────────────── + try: + result = handler(request) + except Exception as exc: + logger.error( + "[%s] Tool '%s' raised: %s", self._name, tool_name, exc + ) + self._kernel._last_error = str(exc) + raise + + # ─── 5. Post-execution validation ───────────────────────── + result_str = str(getattr(result, "content", result)) + + # Blocked-pattern check on output + matched = self._kernel.policy.matches_pattern(result_str) + if matched: + logger.info( + "[%s] Policy DENY: blocked pattern '%s' in tool '%s' output", + self._name, + matched[0], + tool_name, + ) + raise PolicyViolationError( + f"Blocked pattern '{matched[0]}' detected in tool '{tool_name}' output" + ) + + # Drift detection / checkpointing via base post_execute + valid, reason = self._kernel.post_execute(self._ctx, result_str) + if not valid: + logger.info( + "[%s] Policy DENY (post_execute) on tool '%s' result: %s", + self._name, + tool_name, + reason, + ) + raise PolicyViolationError(reason) + + return result + + # ── wrap_model_call ─────────────────────────────────────────── + # + # Intercepts every model (LLM/chat) call. Can modify the request + # (tools, system prompt) or block entirely. This is a *new* + # capability not possible via the proxy-based wrap() approach. + + def wrap_model_call(self, request: Any, handler: Any) -> Any: + """Governance gate around each model invocation. + + Performs the following checks **before** the model call: + + 1. Content-filter scan on input messages for blocked patterns. + 2. Cedar/OPA ``pre_execute`` gate on the model input. + + After the model responds, a ``post_execute`` check validates + the output for blocked patterns. + + This hook also enables **model-level governance** — a capability + that the proxy-based ``wrap()`` approach cannot achieve: + + * System prompt integrity validation. + * Dynamic tool filtering (remove dangerous tools before the + model sees them). + * Prompt injection detection on input messages. + + Args: + request: LangChain ``ModelRequest`` with ``messages``, + ``tools``, and ``system_message`` attributes. + handler: Callable that executes the model call. + + Returns: + The model's response. + + Raises: + PolicyViolationError: If the model input or output + violates the governance policy. + """ + # Extract input text for content filtering + messages = getattr(request, "messages", None) or [] + input_text = "" + for msg in messages: + content = getattr(msg, "content", str(msg)) + if isinstance(content, str): + input_text += " " + content + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + input_text += " " + block.get("text", "") + + logger.debug( + "[%s] wrap_model_call: input_len=%d", + self._name, + len(input_text), + ) + + # ─── 1. Content filter on input ─────────────────────────── + if input_text.strip(): + allowed, reason = self._kernel.pre_execute( + self._ctx, input_text.strip() + ) + if not allowed: + logger.info( + "[%s] Policy DENY (pre_execute) on model input: %s", + self._name, + reason, + ) + raise PolicyViolationError(reason) + logger.info("[%s] Policy ALLOW on model input", self._name) + + # ─── 2. Execute the model call ──────────────────────────── + try: + response = handler(request) + except Exception as exc: + logger.error("[%s] Model call failed: %s", self._name, exc) + self._kernel._last_error = str(exc) + raise + + # ─── 3. Content filter on output ────────────────────────── + response_msg = getattr(response, "message", response) + output_text = getattr(response_msg, "content", str(response_msg)) + if isinstance(output_text, str) and output_text.strip(): + # Blocked-pattern check on model output + matched = self._kernel.policy.matches_pattern(output_text) + if matched: + logger.info( + "[%s] Policy DENY: blocked pattern '%s' in model output", + self._name, + matched[0], + ) + raise PolicyViolationError( + f"Blocked pattern '{matched[0]}' detected in model output" + ) + + # Drift detection / checkpointing via base post_execute + valid, reason = self._kernel.post_execute( + self._ctx, output_text.strip() + ) + if not valid: + logger.info( + "[%s] Policy DENY (post_execute) on model output: %s", + self._name, + reason, + ) + raise PolicyViolationError(reason) + + return response + + # ── Convenience properties ──────────────────────────────────── + + @property + def kernel(self) -> "LangChainKernel": + """Return the governing kernel.""" + return self._kernel + + @property + def context(self) -> Any: + """Return the execution context.""" + return self._ctx + + def __repr__(self) -> str: + return ( + f"GovernanceMiddleware(name={self._name!r}, " + f"policy={self._kernel.policy!r})" + ) + + +# ===================================================================== +# Convenience function (deprecated) +# ===================================================================== + + def wrap( agent: Any, policy: Optional[GovernancePolicy] = None, @@ -642,6 +1024,16 @@ def wrap( ) -> Any: """Convenience wrapper for LangChain agents and chains. + .. deprecated:: + Use :meth:`LangChainKernel.as_middleware` instead:: + + kernel = LangChainKernel(policy=GovernancePolicy(...)) + agent = create_agent( + model="gpt-4o", + tools=[...], + middleware=[kernel.as_middleware()], + ) + Args: agent: Any LangChain-compatible object. policy: Optional governance policy (uses defaults when ``None``). @@ -649,10 +1041,11 @@ def wrap( Returns: A governed proxy around *agent*. - - Example: - >>> from agent_os.integrations.langchain_adapter import wrap - >>> governed = wrap(my_chain, policy=GovernancePolicy(max_tokens=5000)) - >>> result = governed.invoke({"input": "hello"}) """ + warnings.warn( + "langchain_adapter.wrap() is deprecated. " + "Use LangChainKernel.as_middleware() with create_agent(middleware=[...]) instead.", + DeprecationWarning, + stacklevel=2, + ) return LangChainKernel(policy, timeout_seconds=timeout_seconds).wrap(agent) diff --git a/agent-governance-python/agent-os/tests/test_langchain_middleware.py b/agent-governance-python/agent-os/tests/test_langchain_middleware.py new file mode 100644 index 000000000..9c645e0b0 --- /dev/null +++ b/agent-governance-python/agent-os/tests/test_langchain_middleware.py @@ -0,0 +1,561 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Tests for LangChain GovernanceMiddleware (native AgentMiddleware). + +Covers: +- GovernanceMiddleware.wrap_tool_call (tool-level governance) +- GovernanceMiddleware.wrap_model_call (model-level governance) +- LangChainKernel.as_middleware() (factory method) +- Deprecation warnings on wrap() and module-level wrap() +- Backward compatibility of existing wrap() API + +Run with: python -m pytest tests/test_langchain_middleware.py -v --tb=short +""" + +import warnings +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from agent_os.integrations.langchain_adapter import ( + GovernanceMiddleware, + LangChainKernel, + PolicyViolationError, + wrap as module_wrap, +) +from agent_os.integrations.base import GovernancePolicy + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _make_kernel(**policy_kw) -> LangChainKernel: + """Create a LangChainKernel with the given policy overrides.""" + return LangChainKernel(policy=GovernancePolicy(**policy_kw)) + + +def _make_tool_request(name="get_weather", args=None): + """Create a mock LangChain ToolCallRequest.""" + req = MagicMock() + req.tool_call = { + "name": name, + "args": args or {"city": "NY"}, + "id": "call_001", + } + return req + + +def _make_model_request(messages=None, tools=None): + """Create a mock LangChain ModelRequest.""" + req = MagicMock() + if messages is None: + msg = MagicMock() + msg.content = "Hello, what is the weather?" + messages = [msg] + req.messages = messages + req.tools = tools or [] + req.system_message = MagicMock() + req.system_message.content_blocks = [] + return req + + +def _make_tool_result(content="sunny, 72F"): + """Create a mock tool result (ToolMessage).""" + result = MagicMock() + result.content = content + return result + + +def _make_model_response(content="The weather is sunny."): + """Create a mock model response.""" + resp = MagicMock() + resp.message = MagicMock() + resp.message.content = content + return resp + + +# ============================================================================= +# GovernanceMiddleware — Construction / Properties +# ============================================================================= + + +class TestGovernanceMiddlewareInit: + """Tests for GovernanceMiddleware construction and properties.""" + + def test_as_middleware_returns_middleware_instance(self): + kernel = _make_kernel() + mw = kernel.as_middleware() + assert isinstance(mw, GovernanceMiddleware) + + def test_as_middleware_custom_name(self): + kernel = _make_kernel() + mw = kernel.as_middleware(name="custom") + assert mw._name == "custom" + + def test_kernel_property(self): + kernel = _make_kernel() + mw = kernel.as_middleware() + assert mw.kernel is kernel + + def test_context_property(self): + kernel = _make_kernel() + mw = kernel.as_middleware() + assert mw.context is not None + assert mw.context.agent_id == "langchain-middleware-governance" + + def test_repr(self): + kernel = _make_kernel() + mw = kernel.as_middleware() + r = repr(mw) + assert "GovernanceMiddleware" in r + assert "governance" in r + + def test_context_has_correct_policy(self): + kernel = _make_kernel(max_tokens=2048) + mw = kernel.as_middleware() + assert mw.context.policy.max_tokens == 2048 + + def test_multiple_middleware_are_independent(self): + kernel = _make_kernel() + mw1 = kernel.as_middleware(name="mw1") + mw2 = kernel.as_middleware(name="mw2") + assert mw1._name != mw2._name + assert mw1.context.agent_id != mw2.context.agent_id + + +# ============================================================================= +# GovernanceMiddleware.wrap_tool_call +# ============================================================================= + + +class TestWrapToolCall: + """Tests for tool-level governance via wrap_tool_call.""" + + def test_allowed_tool_passes(self): + """Tool execution succeeds when policy allows it.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_tool_request("get_weather", {"city": "Seattle"}) + handler = MagicMock(return_value=_make_tool_result("rainy, 55F")) + + result = mw.wrap_tool_call(request, handler) + + handler.assert_called_once_with(request) + assert result.content == "rainy, 55F" + + def test_blocked_pattern_in_args_raises(self): + """Tool args containing a blocked pattern trigger denial.""" + kernel = _make_kernel(blocked_patterns=["DROP TABLE"]) + mw = kernel.as_middleware() + request = _make_tool_request("sql_query", {"query": "DROP TABLE users"}) + handler = MagicMock() + + with pytest.raises(PolicyViolationError, match="Blocked pattern"): + mw.wrap_tool_call(request, handler) + + handler.assert_not_called() + + def test_blocked_tool_name_raises(self): + """Tool whose name matches a blocked pattern is denied.""" + kernel = _make_kernel(blocked_patterns=["delete_all"]) + mw = kernel.as_middleware() + request = _make_tool_request("delete_all_records") + handler = MagicMock() + + with pytest.raises(PolicyViolationError, match="delete_all"): + mw.wrap_tool_call(request, handler) + + def test_allowed_tools_enforcement(self): + """Tool not in the allowlist is denied.""" + kernel = _make_kernel(allowed_tools=["get_weather", "search"]) + mw = kernel.as_middleware() + request = _make_tool_request("execute_code") + handler = MagicMock() + + with pytest.raises(PolicyViolationError, match="not in allowed list"): + mw.wrap_tool_call(request, handler) + + def test_allowed_tools_pass_when_listed(self): + """Tool in the allowlist is permitted.""" + kernel = _make_kernel(allowed_tools=["get_weather", "search"]) + mw = kernel.as_middleware() + request = _make_tool_request("get_weather") + handler = MagicMock(return_value=_make_tool_result()) + + result = mw.wrap_tool_call(request, handler) + handler.assert_called_once() + assert result is not None + + def test_tool_invocation_recorded(self): + """Tool invocations are logged to the audit trail.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_tool_request("calculator", {"expr": "2+2"}) + handler = MagicMock(return_value=_make_tool_result("4")) + + mw.wrap_tool_call(request, handler) + + assert len(kernel._tool_invocations) == 1 + assert kernel._tool_invocations[0]["tool_name"] == "calculator" + + def test_post_execute_blocks_on_output_violation(self): + """Post-execution check catches blocked patterns in tool output.""" + kernel = _make_kernel(blocked_patterns=["secret_key"]) + mw = kernel.as_middleware() + request = _make_tool_request("read_config") + handler = MagicMock( + return_value=_make_tool_result("api_secret_key=abc123") + ) + + with pytest.raises(PolicyViolationError, match="secret_key"): + mw.wrap_tool_call(request, handler) + + def test_tool_exception_records_error(self): + """Tool exceptions are recorded in the kernel's last_error.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_tool_request("failing_tool") + handler = MagicMock(side_effect=RuntimeError("tool crashed")) + + with pytest.raises(RuntimeError, match="tool crashed"): + mw.wrap_tool_call(request, handler) + + assert kernel._last_error == "tool crashed" + + def test_call_count_incremented(self): + """Each tool call increments the execution context call count.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + handler = MagicMock(return_value=_make_tool_result()) + + mw.wrap_tool_call(_make_tool_request(), handler) + mw.wrap_tool_call(_make_tool_request(), handler) + + # post_execute increments call_count + assert mw.context.call_count == 2 + + def test_max_tool_calls_blocks_after_limit(self): + """Tool calls are blocked after max_tool_calls is reached.""" + kernel = _make_kernel(max_tool_calls=1) + mw = kernel.as_middleware() + handler = MagicMock(return_value=_make_tool_result()) + + # First call succeeds — post_execute increments to 1 + mw.wrap_tool_call(_make_tool_request(), handler) + + # Second call blocked by pre_execute check + with pytest.raises(PolicyViolationError, match="Max tool calls"): + mw.wrap_tool_call(_make_tool_request(), handler) + + def test_non_dict_tool_call_handled(self): + """Gracefully handles non-dict tool_call attribute.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = MagicMock() + request.tool_call = "plain_string_tool" + handler = MagicMock(return_value=_make_tool_result()) + + result = mw.wrap_tool_call(request, handler) + handler.assert_called_once() + + def test_missing_tool_call_attr_handled(self): + """Gracefully handles request without tool_call attribute.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = MagicMock(spec=[]) # no attributes + handler = MagicMock(return_value=_make_tool_result()) + + result = mw.wrap_tool_call(request, handler) + handler.assert_called_once() + + +# ============================================================================= +# GovernanceMiddleware.wrap_model_call +# ============================================================================= + + +class TestWrapModelCall: + """Tests for model-level governance via wrap_model_call.""" + + def test_allowed_model_call_passes(self): + """Model call succeeds when policy allows the input.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_model_request() + handler = MagicMock(return_value=_make_model_response()) + + result = mw.wrap_model_call(request, handler) + + handler.assert_called_once_with(request) + assert result.message.content == "The weather is sunny." + + def test_blocked_pattern_in_model_input_raises(self): + """Blocked pattern in input messages triggers denial.""" + kernel = _make_kernel(blocked_patterns=["password"]) + mw = kernel.as_middleware() + msg = MagicMock() + msg.content = "My password is hunter2" + request = _make_model_request(messages=[msg]) + handler = MagicMock() + + with pytest.raises(PolicyViolationError, match="password"): + mw.wrap_model_call(request, handler) + + handler.assert_not_called() + + def test_blocked_pattern_in_model_output_raises(self): + """Blocked pattern in model output triggers post-execution denial.""" + kernel = _make_kernel(blocked_patterns=["SSN"]) + mw = kernel.as_middleware() + request = _make_model_request() + handler = MagicMock( + return_value=_make_model_response("Your SSN is 123-45-6789") + ) + + with pytest.raises(PolicyViolationError, match="SSN"): + mw.wrap_model_call(request, handler) + + def test_model_call_with_list_content(self): + """Model call with list-type content blocks works correctly.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + msg = MagicMock() + msg.content = [ + {"type": "text", "text": "What is the weather?"}, + {"type": "image", "url": "http://example.com/img.png"}, + ] + request = _make_model_request(messages=[msg]) + handler = MagicMock(return_value=_make_model_response()) + + result = mw.wrap_model_call(request, handler) + handler.assert_called_once() + + def test_model_call_blocked_in_list_content(self): + """Blocked pattern in list content blocks triggers denial.""" + kernel = _make_kernel(blocked_patterns=["secret"]) + mw = kernel.as_middleware() + msg = MagicMock() + msg.content = [{"type": "text", "text": "Reveal the secret code"}] + request = _make_model_request(messages=[msg]) + handler = MagicMock() + + with pytest.raises(PolicyViolationError, match="secret"): + mw.wrap_model_call(request, handler) + + def test_model_exception_records_error(self): + """Model call exceptions are recorded in the kernel.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_model_request() + handler = MagicMock(side_effect=RuntimeError("API error")) + + with pytest.raises(RuntimeError, match="API error"): + mw.wrap_model_call(request, handler) + + assert kernel._last_error == "API error" + + def test_empty_messages_pass(self): + """Model call with no messages passes without content check.""" + kernel = _make_kernel(blocked_patterns=["secret"]) + mw = kernel.as_middleware() + request = _make_model_request(messages=[]) + handler = MagicMock(return_value=_make_model_response("safe output")) + + result = mw.wrap_model_call(request, handler) + handler.assert_called_once() + + def test_none_messages_pass(self): + """Model call with None messages attribute passes.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = MagicMock() + request.messages = None + handler = MagicMock(return_value=_make_model_response("result")) + + result = mw.wrap_model_call(request, handler) + handler.assert_called_once() + + def test_model_response_non_string_content(self): + """Model response with non-string content is handled gracefully.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + request = _make_model_request() + resp = MagicMock() + resp.message = MagicMock() + resp.message.content = [{"type": "tool_use", "name": "search"}] + handler = MagicMock(return_value=resp) + + result = mw.wrap_model_call(request, handler) + handler.assert_called_once() + # Should not raise — non-string content is skipped + + +# ============================================================================= +# LangChainKernel.as_middleware() integration +# ============================================================================= + + +class TestAsMiddlewareIntegration: + """Tests for the as_middleware() factory and combined flows.""" + + def test_tool_then_model_flow(self): + """End-to-end: tool call followed by model call both pass.""" + kernel = _make_kernel() + mw = kernel.as_middleware() + + # Tool call + tool_req = _make_tool_request("search", {"query": "AI safety"}) + tool_handler = MagicMock( + return_value=_make_tool_result("AI safety research") + ) + mw.wrap_tool_call(tool_req, tool_handler) + + # Model call + model_req = _make_model_request() + model_handler = MagicMock( + return_value=_make_model_response("Here are the results.") + ) + mw.wrap_model_call(model_req, model_handler) + + assert mw.context.call_count >= 2 + + def test_cedar_evaluator_passed_through(self): + """Cedar evaluator on the kernel is accessible via the middleware.""" + mock_evaluator = MagicMock() + kernel = LangChainKernel(evaluator=mock_evaluator) + mw = kernel.as_middleware() + assert mw.kernel._evaluator is mock_evaluator + + def test_as_middleware_returns_new_instance_each_call(self): + """Each call to as_middleware() returns a fresh middleware.""" + kernel = _make_kernel() + mw1 = kernel.as_middleware() + mw2 = kernel.as_middleware() + assert mw1 is not mw2 + + def test_shared_kernel_state(self): + """Multiple middleware instances share the kernel's audit state.""" + kernel = _make_kernel() + mw1 = kernel.as_middleware(name="mw1") + mw2 = kernel.as_middleware(name="mw2") + + tool_handler = MagicMock(return_value=_make_tool_result()) + mw1.wrap_tool_call(_make_tool_request("tool_a"), tool_handler) + mw2.wrap_tool_call(_make_tool_request("tool_b"), tool_handler) + + # Both recorded on the shared kernel + assert len(kernel._tool_invocations) == 2 + + +# ============================================================================= +# Deprecation warnings +# ============================================================================= + + +class TestDeprecationWarnings: + """Verify that deprecated APIs emit DeprecationWarnings.""" + + def test_wrap_emits_deprecation_warning(self): + """LangChainKernel.wrap() emits a DeprecationWarning.""" + kernel = _make_kernel() + chain = MagicMock() + chain.invoke = MagicMock(return_value="result") + chain.name = "test-chain" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + kernel.wrap(chain) + + dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(dep_warnings) >= 1 + assert "as_middleware" in str(dep_warnings[0].message) + + def test_module_wrap_emits_deprecation_warning(self): + """Module-level wrap() emits a DeprecationWarning.""" + chain = MagicMock() + chain.invoke = MagicMock(return_value="result") + chain.name = "test-chain" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + module_wrap(chain) + + dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + # Should get at least 1 from module_wrap, and 1 from the inner .wrap() + assert len(dep_warnings) >= 1 + + +# ============================================================================= +# Backward compatibility — existing wrap() still works +# ============================================================================= + + +class TestBackwardCompatibility: + """Ensure the deprecated wrap() API still functions correctly.""" + + def test_wrap_invoke_still_works(self): + """Deprecated wrap().invoke() still executes and returns results.""" + kernel = _make_kernel() + chain = MagicMock() + chain.invoke = MagicMock(return_value="invoke-result") + chain.name = "legacy-chain" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + governed = kernel.wrap(chain) + + result = governed.invoke("hello") + assert result == "invoke-result" + chain.invoke.assert_called_once_with("hello") + + def test_wrap_blocks_on_policy_violation(self): + """Deprecated wrap() still enforces policy.""" + kernel = _make_kernel(blocked_patterns=["DROP TABLE"]) + chain = MagicMock() + chain.invoke = MagicMock(return_value="ok") + chain.name = "legacy-chain" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + governed = kernel.wrap(chain) + + with pytest.raises(PolicyViolationError, match="Blocked pattern"): + governed.invoke("please DROP TABLE users") + + def test_unwrap_still_works(self): + """Deprecated unwrap() returns the original object.""" + kernel = _make_kernel() + chain = MagicMock() + chain.name = "chain" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + governed = kernel.wrap(chain) + + assert kernel.unwrap(governed) is chain + + +# ============================================================================= +# Health check +# ============================================================================= + + +class TestHealthCheck: + """Health check is unaffected by middleware changes.""" + + def test_health_check_returns_healthy(self): + kernel = _make_kernel() + result = kernel.health_check() + assert result["status"] == "healthy" + assert result["backend"] == "langchain" + assert result["backend_connected"] is True + + def test_health_check_degraded_after_error(self): + kernel = _make_kernel() + kernel._last_error = "something failed" + result = kernel.health_check() + assert result["status"] == "degraded"