diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td index aa073832e1..0b6300726d 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td @@ -44,4 +44,22 @@ def Tracing_TraceMessageOp : Tracing_Op<"trace_message"> { let arguments = (ins OptionalAttr : $msg); } +def Tracing_DebugProbeOp : Tracing_Op<"debug_probe"> { + let summary = "Captures a value into the debug probe buffer."; + + let arguments = (ins + Type.predicate, + MemRefRankOf<[I64], [1]>.predicate, + AnyInteger.predicate + ]>>: $value, + I32Attr: $probe_id, + OptionalAttr: $tag, + OptionalAttr: $nmsb + ); +} + #endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/debug_probes.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/debug_probes.h new file mode 100644 index 0000000000..8dd3065db4 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/debug_probes.h @@ -0,0 +1,58 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_RUNTIME_DEBUG_PROBES_H +#define CONCRETELANG_RUNTIME_DEBUG_PROBES_H + +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { +namespace debug { + +struct ProbeEntry { + uint32_t probe_id; + std::string tag; + int64_t value; + uint32_t nmsb; +}; + +class ProbeBuffer { + std::mutex mutex; + std::vector entries; + +public: + static ProbeBuffer &instance(); + + void reset(); + + void record_plaintext(uint32_t probe_id, int64_t value, const char *tag_ptr, + uint32_t tag_len, uint32_t nmsb); + + size_t size() const; + + const ProbeEntry &get(size_t index) const; + + const std::vector &all() const; +}; + +} // namespace debug +} // namespace concretelang +} // namespace mlir + +extern "C" { +void memref_debug_probe_plaintext(int64_t value, int64_t input_width, + int32_t probe_id, char *tag_ptr, + int32_t tag_len, int32_t nmsb); + +void debug_probe_buffer_reset(); + +uint64_t debug_probe_buffer_size(); +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h index c234674817..2c67ab1c09 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/wrappers.h @@ -310,6 +310,15 @@ void memref_trace_plaintext(uint64_t input, uint64_t input_width, void memref_trace_message(char *message_ptr, uint32_t message_len); +// Debug probes //////////////////////////////////////////////////////////////// +void memref_debug_probe_plaintext(int64_t value, int64_t input_width, + int32_t probe_id, char *tag_ptr, + int32_t tag_len, int32_t nmsb); + +void debug_probe_buffer_reset(); + +uint64_t debug_probe_buffer_size(); + /// @brief Allocate memory using malloc and check for nullptr /// @param size number of bytes to allocate /// @return pointer to the allocated memory or nullptr diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 660835d173..fba0972c98 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -13,6 +13,7 @@ #include "concretelang/Common/Values.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Runtime/GPUDFG.hpp" +#include "concretelang/Runtime/debug_probes.h" #include "concretelang/ServerLib/ServerLib.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" @@ -2347,4 +2348,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( } return result.value(); }); + + // Debug probes + m.def("debug_probe_buffer_reset", &debug_probe_buffer_reset); + m.def("debug_probe_buffer_size", &debug_probe_buffer_size); + m.def("debug_probe_get_entries", []() -> pybind11::list { + auto &buf = mlir::concretelang::debug::ProbeBuffer::instance(); + pybind11::list result; + for (const auto &entry : buf.all()) { + pybind11::dict d; + d["probe_id"] = entry.probe_id; + d["tag"] = entry.tag; + d["value"] = entry.value; + d["nmsb"] = entry.nmsb; + result.append(d); + } + return result; + }); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 2d7f5190ef..52be77c9c5 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -696,6 +696,39 @@ struct ZeroTensorOpPattern }; }; +struct DebugProbeOpPattern + : public mlir::OpConversionPattern { + DebugProbeOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(Tracing::DebugProbeOp debugProbeOp, + Tracing::DebugProbeOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // After simulation type conversion, the value is now i64. + // Re-create the op with the converted operand. + auto newOp = rewriter.replaceOpWithNewOp( + debugProbeOp, mlir::TypeRange{}, adaptor.getValue()); + + newOp.setProbeId(debugProbeOp.getProbeId()); + + if (auto tag = debugProbeOp.getTag()) + newOp.setTag(tag); + + if (auto nmsb = debugProbeOp.getNmsb()) + newOp.setNmsb(nmsb); + + auto inputWidth = + newOp.getValue().getType().cast().getWidth(); + newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth)); + + return ::mlir::success(); + }; +}; + struct TraceCiphertextOpPattern : public mlir::OpConversionPattern { TraceCiphertextOpPattern(mlir::MLIRContext *context, @@ -747,6 +780,10 @@ void SimulateTFHEPass::runOnOperation() { mlir::tensor::CastOp, mlir::LLVM::GlobalOp, mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp, Tracing::TracePlaintextOp>(); + // DebugProbeOp is dynamically legal after type conversion + target.addDynamicallyLegalOp([&](mlir::Operation *op) { + return converter.isLegal(op->getOperandTypes()); + }); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); @@ -821,7 +858,8 @@ void SimulateTFHEPass::runOnOperation() { patterns.insert(&getContext(), converter); + TraceCiphertextOpPattern, DebugProbeOpPattern>( + &getContext(), converter); patterns.insert(&getContext()); // if overflow detection is enable, then rewrite to CAPI functions that diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp index a87646d079..33d6b8c3c6 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp @@ -22,6 +22,7 @@ namespace memref = mlir::memref; char memref_trace_ciphertext[] = "memref_trace_ciphertext"; char memref_trace_plaintext[] = "memref_trace_plaintext"; char memref_trace_message[] = "memref_trace_message"; +char memref_debug_probe_plaintext[] = "memref_debug_probe_plaintext"; mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, size_t rank) { @@ -76,6 +77,14 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( {mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), rewriter.getI32Type()}, {}); + } else if (funcName == memref_debug_probe_plaintext) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI32Type(), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getI32Type(), rewriter.getI32Type()}, + {}); } else { op->emitError("unknwon external function") << funcName; return mlir::failure(); @@ -190,6 +199,38 @@ void traceMessageAddOperands(Tracing::TraceMessageOp op, op.getLoc(), rewriter.getI32IntegerAttr(msg.size()))); } +void debugProbeAddOperands(Tracing::DebugProbeOp op, + mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) { + auto tag = op.getTag().value_or(""); + auto nmsb = op.getNmsb().value_or(0); + auto probeId = op.getProbeId(); + + // input_width (set as attribute by SimulateTFHE pass) + operands.push_back(rewriter.create( + op.getLoc(), op->getAttr("input_width"))); + + // probe_id + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(probeId))); + + // tag string + std::string tagName; + std::stringstream stream; + stream << rand(); + stream >> tagName; + auto tagVal = mlir::LLVM::createGlobalString( + op.getLoc(), rewriter, tagName, tag, + mlir::LLVM::linkage::Linkage::Linkonce, false); + operands.push_back(tagVal); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(tag.size()))); + + // nmsb + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(nmsb))); +} + struct TracingToCAPIPass : public TracingToCAPIBase { TracingToCAPIPass() {} @@ -219,6 +260,9 @@ struct TracingToCAPIPass : public TracingToCAPIBase { patterns.add>( &getContext(), traceMessageAddOperands); + patterns.add>( + &getContext(), debugProbeAddOperands); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt index 2ea777750d..168d863eb7 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Runtime/CMakeLists.txt @@ -9,7 +9,8 @@ if(CONCRETELANG_CUDA_SUPPORT) GPUDFG.cpp simulation.cpp wrappers.cpp - time_util.cpp) + time_util.cpp + debug_probes.cpp) else() add_library( ConcretelangRuntime SHARED @@ -19,7 +20,8 @@ else() GPUDFG.cpp simulation.cpp wrappers.cpp - time_util.cpp) + time_util.cpp + debug_probes.cpp) endif() add_dependencies(ConcretelangRuntime rust_deps_bundle concrete-protocol) @@ -83,7 +85,8 @@ add_library( GPUDFG.cpp simulation.cpp wrappers.cpp - time_util.cpp) + time_util.cpp + debug_probes.cpp) if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") target_link_libraries(ConcretelangRuntimeStatic PUBLIC omp) diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/debug_probes.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/debug_probes.cpp new file mode 100644 index 0000000000..fd1923ff22 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Runtime/debug_probes.cpp @@ -0,0 +1,63 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Runtime/debug_probes.h" +#include + +namespace mlir { +namespace concretelang { +namespace debug { + +ProbeBuffer &ProbeBuffer::instance() { + static ProbeBuffer buf; + return buf; +} + +void ProbeBuffer::reset() { + std::lock_guard lock(mutex); + entries.clear(); +} + +void ProbeBuffer::record_plaintext(uint32_t probe_id, int64_t value, + const char *tag_ptr, uint32_t tag_len, + uint32_t nmsb) { + std::lock_guard lock(mutex); + entries.push_back( + {probe_id, std::string(tag_ptr, tag_len), value, nmsb}); +} + +size_t ProbeBuffer::size() const { return entries.size(); } + +const ProbeEntry &ProbeBuffer::get(size_t index) const { + assert(index < entries.size()); + return entries[index]; +} + +const std::vector &ProbeBuffer::all() const { return entries; } + +} // namespace debug +} // namespace concretelang +} // namespace mlir + +extern "C" { + +void memref_debug_probe_plaintext(int64_t value, int64_t input_width, + int32_t probe_id, char *tag_ptr, + int32_t tag_len, int32_t nmsb) { + mlir::concretelang::debug::ProbeBuffer::instance().record_plaintext( + static_cast(probe_id), value, + tag_ptr, static_cast(tag_len), + static_cast(nmsb)); +} + +void debug_probe_buffer_reset() { + mlir::concretelang::debug::ProbeBuffer::instance().reset(); +} + +uint64_t debug_probe_buffer_size() { + return static_cast( + mlir::concretelang::debug::ProbeBuffer::instance().size()); +} +} diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 0e66b7d9ce..45631f0671 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -36,6 +36,7 @@ * [Simulation](execution-analysis/simulation.md) * [Debugging and artifact](execution-analysis/debug.md) +* [Interactive debugger](interactive-debugger.md) * [Performance](optimization/summary.md) * [GPU acceleration](execution-analysis/gpu_acceleration.md) * [Rust integration](execution-analysis/rust_integration.md) diff --git a/docs/interactive-debugger.md b/docs/interactive-debugger.md new file mode 100644 index 0000000000..67650bbce2 --- /dev/null +++ b/docs/interactive-debugger.md @@ -0,0 +1,329 @@ +# Interactive Debugger + +The Concrete Interactive Debugger lets you inspect what happens inside your FHE circuits. Instead of treating a circuit as a black box, you can see every intermediate value, detect overflows, compare cleartext against simulation, and step through execution in VS Code. + +## Inspecting intermediate values + +`circuit.inspect()` evaluates your circuit in cleartext (no encryption, no noise) and returns a snapshot of every node's computed value. + +```python +from concrete import fhe + +@fhe.compiler({"x": "encrypted", "y": "encrypted"}) +def add(x, y): + return x + y + +circuit = add.compile([(i, j) for i in range(8) for j in range(8)]) + +result = circuit.inspect(2, 6) +print(result.summary()) +``` + +Output: + +``` +%0 = x => 2 +%1 = y => 6 +%2 = add(%0, %1) => 8 +return %2 +``` + +Access the final output directly with `result.output`. + +### Overflow detection + +When values exceed the range the circuit was compiled for, `inspect()` tells you exactly which nodes overflowed: + +```python +bad = circuit.inspect(100, 100) +print(bad.has_overflow) # True + +for snap in bad.overflows: + print(snap) +``` + +### Stopping early + +Halt evaluation before a specific node — like a breakpoint: + +```python +# Stop by predicate +result = circuit.inspect(2, 6, stop_at=lambda node: node.properties.get("name") == "add") + +# Stop by source location +result = circuit.inspect(2, 6, stop_at="/path/to/my_file.py:42") +``` + +When stopped early, `result.output` raises `RuntimeError` since the final output was never computed. + +### Filtering snapshots + +```python +result = circuit.inspect(2, 6) + +result.filter(operation_filter="add") +result.filter(tag_filter="my_tag") +result.filter(is_encrypted_filter=True) +result.filter(overflow_only=True) +result.filter(custom_filter=lambda snap: snap.value > 5) +``` + +### Snapshot properties + +Each `NodeSnapshot` carries: + +| Property | Description | +|----------|-------------| +| `snap.value` | Computed value at this node | +| `snap.index` | Position in evaluation order | +| `snap.operation_name` | `"input"`, `"add"`, `"multiply"`, etc. | +| `snap.location` | Source file and line number | +| `snap.tag` | User-assigned tag (from `fhe.tag(...)`) | +| `snap.is_encrypted` | Whether the output is encrypted | +| `snap.overflow` | Whether the value exceeds the dtype range | + +For multi-function modules, inspect per-function: + +```python +module.my_func.inspect(x) +``` + +--- + +## Simulation probes + +`inspect()` evaluates in pure Python. To verify what happens in the actual compiled MLIR simulation pipeline, use `circuit.run_with_probes()`. This injects debug probe ops into the MLIR, runs simulation, and captures values at each probed node. + +```python +@fhe.compiler({"x": "encrypted"}) +def f(x): + return (x + 1) * 2 + +circuit = f.compile(range(8)) + +probed = circuit.run_with_probes(5) +print(probed.output) # 12 +print(probed.summary()) +``` + +Output: + +``` +Output: 12 +Probes: 2 + + ID Operation Tag Value Overflow +-------------------------------------------------------------------------------- + 1 add 6 + 2 multiply 12 +``` + +### Comparing cleartext vs. simulation + +```python +inspection = circuit.inspect(5) +probed = circuit.run_with_probes(5) + +print(probed.compare_with(inspection)) +``` + +``` +Node Inspect (cleartext) Probe (simulation) Match +---------------------------------------------------------------------------------------------------- +input 5 - (no probe) +add 6 6 OK +multiply 12 12 OK +``` + +If they disagree, the `Match` column shows `MISMATCH`. + +### Choosing which nodes to probe + +```python +# All encrypted nodes (default) +probed = circuit.run_with_probes(5) + +# By tag +probed = circuit.run_with_probes(5, probes=["step1"]) + +# By predicate +probed = circuit.run_with_probes(5, probes=lambda node: node.converted_to_table_lookup) +``` + +Filtering and overflow detection work the same way as `inspect()`: + +```python +probed.filter(tag_filter="step1") +probed.has_overflow +probed.overflows +``` + +Each `ProbeSnapshot` has the same properties as `NodeSnapshot`, plus `snap.probe_id`. + +For modules: `module.my_func.run_with_probes(x, probes=["my_tag"])`. + +> **Note:** Each `run_with_probes()` call recompiles the MLIR. This is fast for small circuits but may be noticeable for large ones. + +--- + +## VS Code debugger + +The VS Code extension gives you a standard IDE debugging experience for FHE circuits: breakpoints, stepping, a variables pane, and a debug console. + +### Installation + +Build and package the extension: + +```bash +cd tools/vscode-concrete-debugger +npm install +npm run build +npx vsce package +``` + +In VS Code: **Cmd+Shift+P** → **"Extensions: Install from VSIX..."** → select the `.vsix` file. + +### Launch configuration + +Add to your `.vscode/launch.json`: + +```jsonc +{ + "type": "concrete", + "request": "launch", + "name": "Debug FHE Circuit", + "program": "${file}", + "function": "my_circuit", // variable name of the compiled Circuit + "args": [3, 5], // input values + "pythonPath": "python3", // must have concrete-python installed + "stopOnEntry": true, // pause before first operation + "stopOnOverflow": false // stop on bit-width overflow +} +``` + +For `@fhe.module`: + +```jsonc +{ + "type": "concrete", + "request": "launch", + "name": "Debug FHE Module", + "program": "${file}", + "function": "my_module", + "functions": [ + { "name": "scale", "args": [5] }, + { "name": "shift", "args": [7] } + ], + "stopOnEntry": true +} +``` + +### Example script + +```python +from concrete import fhe + +@fhe.compiler({"x": "encrypted", "y": "encrypted"}) +def my_circuit(x, y): + with fhe.tag("compute"): + return (x + y) * 2 + +inputset = [(i, j) for i in range(8) for j in range(8)] +my_circuit = my_circuit.compile(inputset) +``` + +Open the script, select your launch configuration, and press **F5**. + +### Stepping + +| Action | Key | Behavior | +|---|---|---| +| Step Over | F10 | Evaluate next DAG node (groups nodes from the same source line) | +| Continue | F5 | Run to next breakpoint or end | +| Step Into | F11 | Enter next function (`@fhe.module`) or step one node | +| Step Out | Shift+F11 | Finish current function | + +### Breakpoints + +Set breakpoints on Python source lines as usual. The debugger maps each line to DAG nodes at that location. A breakpoint shows as a solid red dot if at least one node exists at that line. + +### Variables pane + +When stopped, two scopes are available: + +**Current Node:** + +| Variable | Example | +|---|---| +| `value` | `42` | +| `operation` | `add` | +| `encrypted` | `True` | +| `bit_width` | `8` | +| `overflow` | `False` | +| `tag` | `compute` | +| `location` | `script.py:12` | +| `bounds` | `[0, 255]` | + +**All Evaluated** — expandable list of every node snapshot so far. + +Large arrays show a summary and can be expanded to see individual elements. + +### Call stack + +Synthesized from the tag hierarchy. A node tagged `layer1.matmul.relu` shows: + +``` +relu [layer1.matmul.relu] @ script.py:12 +matmul [layer1.matmul] @ script.py:12 +layer1 [layer1] @ script.py:12 +``` + +### Debug console + +| Expression | Result | +|---|---| +| `value` | Current node's computed value | +| `nodes` | Total nodes and how many evaluated | +| `overflow` | Summary of overflow nodes | +| `snap[3]` | Snapshot at index 3 | +| `function` | Current function name (modules) | +| `functions` | All function names (modules) | + +--- + +## API reference + +### `Circuit.inspect(*args, stop_at=None)` + +Evaluate in cleartext and return an `InspectionResult`. + +- `stop_at` — `str` (location prefix) or `Callable[[Node], bool]` + +### `InspectionResult` + +| Method / Property | Returns | +|---|---| +| `result.output` | Final output (raises `RuntimeError` if stopped early) | +| `result.has_overflow` | `True` if any node overflowed | +| `result.overflows` | List of overflow snapshots | +| `result.filter(...)` | Query by tag, operation, location, encryption, overflow, or predicate | +| `result.summary()` | Formatted table string | +| `len(result)`, `result[i]`, `for snap in result` | List-like access | + +### `Circuit.run_with_probes(*args, probes=None)` + +Run in simulation mode with MLIR debug probes and return a `ProbeResult`. + +- `probes` — `None` (all encrypted nodes), `list[str]` (by tag), or `Callable[[Node], bool]` + +### `ProbeResult` + +| Method / Property | Returns | +|---|---| +| `probed.output` | Simulation output | +| `probed.has_overflow` | `True` if any probe overflowed | +| `probed.overflows` | List of overflow snapshots | +| `probed.filter(...)` | Same filtering as `InspectionResult` | +| `probed.summary()` | Formatted table | +| `probed.compare_with(inspection)` | Side-by-side comparison with `InspectionResult` | +| `len(probed)`, `probed[i]`, `for snap in probed` | List-like access | diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 88a1bbf2a8..e1c1f93856 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -72,7 +72,16 @@ zeros_like, ) from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH -from .representation import Graph, GraphProcessor, Node, Operation +from .representation import ( + Graph, + GraphProcessor, + InspectionResult, + Node, + NodeSnapshot, + Operation, + ProbeResult, + ProbeSnapshot, +) from .tracing.typing import ( f32, f64, diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 093bc27463..3fd44b9fed 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -5,13 +5,14 @@ # pylint: disable=import-error,no-member,no-name-in-module from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import numpy as np from concrete.compiler import CompilationContext, LweSecretKey, Parameter from mlir.ir import Module as MlirModule from ..representation import Graph +from ..representation.probes import ProbeResult from .client import Client from .configuration import Configuration from .keys import Keys @@ -208,6 +209,53 @@ def decrypt( return self._function.decrypt(*results) + def inspect( + self, + *args: Any, + stop_at=None, + ) -> "InspectionResult": + """ + Inspect intermediate values of the circuit evaluation without noise. + + Args: + *args (Any): + inputs to the circuit + + stop_at (Optional[Union[str, Callable[[Node], bool]]]): + stop condition — string (location prefix) or predicate on Node + + Returns: + InspectionResult: + inspection result with per-node snapshots + """ + return self._function.graph.inspect(*args, stop_at=stop_at) + + def run_with_probes( + self, + *args: Any, + probes: Optional[Union[list[str], Callable]] = None, + ) -> "ProbeResult": + """ + Run the circuit in simulation mode with debug probes inserted. + + Probes capture intermediate values during MLIR simulation execution. + + Args: + *args (Any): + inputs to the circuit + + probes (Optional[Union[list[str], Callable[[Node], bool]]]): + probe specification: + - None: probe all encrypted nodes + - list[str]: probe nodes matching these tags + - Callable: predicate on Node, probe where True + + Returns: + ProbeResult: + result containing the output and captured probe snapshots + """ + return self._function.run_with_probes(*args, probes=probes) + def encrypt_run_decrypt(self, *args: Any) -> Any: """ Encrypt inputs, run the circuit, and decrypt the outputs in one go. diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 4b71e7e13e..dee915b4f2 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -9,7 +9,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path from threading import Thread -from typing import Any, NamedTuple, Optional, Union +from typing import Any, Callable, NamedTuple, Optional, Union import numpy as np from concrete.compiler import CompilationContext, LweSecretKey, Parameter @@ -17,6 +17,7 @@ from ..internal.utils import assert_that from ..representation import Graph +from ..representation.probes import ProbeResult, ProbeSnapshot from ..tfhers.specs import TFHERSClientSpecs from .client import Client from .composition import CompositionRule @@ -355,6 +356,213 @@ def decrypt( results = [res.result() if isinstance(res, Future) else res for res in results] return self.execution_runtime.val.client.decrypt(*results, function_name=self.name) + def inspect( + self, + *args: Any, + stop_at=None, + ) -> "InspectionResult": + """ + Inspect intermediate values of the function evaluation without noise. + + Args: + *args (Any): + inputs to the function + + stop_at (Optional[Union[str, Callable[[Node], bool]]]): + stop condition — string (location prefix) or predicate on Node + + Returns: + InspectionResult: + inspection result with per-node snapshots + """ + return self.graph.inspect(*args, stop_at=stop_at) + + def run_with_probes( + self, + *args: Any, + probes: Optional[Union[list[str], Callable]] = None, + ) -> "ProbeResult": + """ + Run the function in simulation mode with debug probes inserted. + + Probes capture intermediate values during MLIR simulation execution. + + Args: + *args (Any): + inputs to the function + + probes (Optional[Union[list[str], Callable[[Node], bool]]]): + probe specification: + - None: probe all encrypted nodes + - list[str]: probe nodes matching these tags + - Callable: predicate on Node, probe where True + + Returns: + ProbeResult: + result containing the output and captured probe snapshots + """ + import warnings + + import concrete.lang + import concrete.lang.dialects.tracing + import networkx as nx + from concrete.compiler import CompilationContext + from mlir.ir import Context as MlirContext + from mlir.ir import InsertionPoint as MlirInsertionPoint + from mlir.ir import Location as MlirLocation + from mlir.ir import Module as MlirModule + + from ..mlir.converter import Converter + from ..representation import Node, Operation + + graph = self.graph + + # Resolve probe spec to set[Node] + probed_nodes: set[Node] = set() + all_nodes = list(nx.lexicographical_topological_sort(graph.graph)) + + if probes is None: + # Probe all encrypted nodes (skip inputs) + for node in all_nodes: + if node.operation != Operation.Input and node.output.is_encrypted: + probed_nodes.add(node) + elif isinstance(probes, list): + # Match by tag + for node in all_nodes: + if node.operation != Operation.Input and node.tag in probes: + probed_nodes.add(node) + elif callable(probes): + for node in all_nodes: + if node.operation != Operation.Input and probes(node): + probed_nodes.add(node) + + if len(probed_nodes) == 0: + # No probes matched — run normally, return empty ProbeResult + output = self.simulate(*args) + return ProbeResult(output, [], graph) + + if len(probed_nodes) > 100: + warnings.warn( + f"Large number of probes ({len(probed_nodes)}). " + "This may use significant memory.", + stacklevel=2, + ) + + # Recompile MLIR with probes inserted + converter = Converter(self.configuration) + compilation_context = CompilationContext() + mlir_context = compilation_context.mlir_context() + + # Set probed_nodes on the converter's context during conversion + original_convert_many = converter.convert_many + + probe_ctx_ref = [None] + + def patched_convert_many(graphs, mlir_ctx): + with mlir_ctx as context, MlirLocation.unknown(): + concrete.lang.register_dialects(context) + + module = MlirModule.create() + with MlirInsertionPoint(module.body): + for name, g in graphs.items(): + from ..mlir.context import Context + from ..mlir.conversion import Conversion + + ctx = Context(context, g, converter.configuration) + ctx.probed_nodes = probed_nodes + probe_ctx_ref[0] = ctx + + from mlir.dialects import func + + input_types = [ctx.typeof(node).mlir for node in g.ordered_inputs()] + + location = g.location.split(":") + with MlirLocation.file( + location[0], line=int(location[1]), col=0, context=context + ): + + @func.FuncOp.from_py_func(*input_types, name=name) + def main(*fn_args): + for index, node in enumerate(g.ordered_inputs()): + conversion = Conversion(node, fn_args[index]) + if "original_bit_width" in node.properties: + conversion.set_original_bit_width( + node.properties["original_bit_width"] + ) + ctx.conversions[node] = conversion + + ordered_nodes = [ + node + for node in nx.lexicographical_topological_sort(g.graph) + if node.operation != Operation.Input + ] + + for node in ordered_nodes: + preds = [ + ctx.conversions[pred] + for pred in g.ordered_preds_of(node) + ] + converter.node(ctx, node, preds) + + outputs = [] + for node in g.ordered_outputs(): + assert node in ctx.conversions + outputs.append(ctx.conversions[node].result) + + return tuple(outputs) + + return module + + # Process graphs first (assigns bit widths, etc.) + converter.process({graph.name: graph}) + + probed_mlir = patched_convert_many({graph.name: graph}, mlir_context) + probe_id_to_node = probe_ctx_ref[0].probe_id_to_node if probe_ctx_ref[0] else {} + + # Create temporary simulation runtime with probed MLIR + probed_server = Server.create( + probed_mlir, + self.configuration.fork(fhe_simulation=True), + is_simulated=True, + compilation_context=compilation_context, + ) + probed_client = Client(probed_server.client_specs, is_simulated=True) + + # Reset probe buffer, run simulation, read back + from concrete.compiler import ( + debug_probe_buffer_reset, + debug_probe_buffer_size, + debug_probe_get_entries, + ) + + debug_probe_buffer_reset() + + encrypted = probed_client.simulate_encrypt(*args, function_name=self.name) + result = probed_server.run(encrypted, function_name=self.name) + output = probed_client.simulate_decrypt(result, function_name=self.name) + + # Read probe entries + entries = debug_probe_get_entries() + snapshots = [] + for entry in entries: + pid = entry["probe_id"] + node = probe_id_to_node.get(pid) + if node is not None: + snapshots.append( + ProbeSnapshot( + node=node, + probe_id=pid, + value=entry["value"], + tag=entry["tag"], + nmsb=entry["nmsb"], + ) + ) + + # Cleanup + probed_server.cleanup() + + return ProbeResult(output, snapshots, graph) + def encrypt_run_decrypt(self, *args: Any) -> Any: """ Encrypt inputs, run the function, and decrypt the outputs in one go. diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 73dc0616f1..69b62fa716 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -85,6 +85,11 @@ class Context: tfhers_partition: dict[tfhers.CryptoParams, str] + # Debug probe support + probed_nodes: Optional[set[Node]] + _probe_id_counter: int + probe_id_to_node: dict[int, Node] + def __init__(self, context: MlirContext, graph: Graph, configuration: Configuration): self.context = context @@ -99,6 +104,15 @@ def __init__(self, context: MlirContext, graph: Graph, configuration: Configurat self.tfhers_partition = {} + self.probed_nodes = None + self._probe_id_counter = 0 + self.probe_id_to_node = {} + + def next_probe_id(self) -> int: + """Allocate and return the next probe ID.""" + self._probe_id_counter += 1 + return self._probe_id_counter + # types def i(self, width: int) -> ConversionType: diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 7a823d4e99..70c8bf53c1 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -312,6 +312,18 @@ def node(self, ctx: Context, node: Node, preds: list[Conversion]) -> Conversion: conversion.set_original_bit_width(node.properties["original_bit_width"]) ctx.conversions[node] = conversion + + # Insert debug probe if this node is being probed + if ctx.probed_nodes is not None and node in ctx.probed_nodes: + probe_id = ctx.next_probe_id() + ctx.probe_id_to_node[probe_id] = node + tag = node.tag if node.tag else "" + concrete.lang.dialects.tracing.DebugProbeOp( # pylint: disable=no-member + conversion.result, + probe_id=probe_id, + tag=tag, + ) + return conversion # The name of the remaining methods all correspond to node names. diff --git a/frontends/concrete-python/concrete/fhe/representation/__init__.py b/frontends/concrete-python/concrete/fhe/representation/__init__.py index 30825b1f7a..1b0779cdbe 100644 --- a/frontends/concrete-python/concrete/fhe/representation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/representation/__init__.py @@ -3,5 +3,7 @@ """ from .graph import Graph, GraphProcessor, MultiGraphProcessor +from .inspection import InspectionResult, NodeSnapshot from .node import Node from .operation import Operation +from .probes import ProbeResult, ProbeSnapshot diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index 454da7820e..1b7b0b8921 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -208,6 +208,77 @@ def evaluate( return node_results + def inspect( + self, + *args: Any, + stop_at: Optional[Union[str, Callable[["Node"], bool]]] = None, + ) -> "InspectionResult": + """ + Evaluate the graph without noise and return an InspectionResult with per-node snapshots. + + Args: + *args (Any): + inputs to the computation + + stop_at (Optional[Union[str, Callable[[Node], bool]]]): + stop condition — if a string, halts before any node whose location starts with it; + if a callable, halts before the first node where the predicate returns True. + Input nodes are never stopped at. + + Returns: + InspectionResult: + inspection result containing snapshots of all evaluated nodes + """ + + # pylint: disable=import-outside-toplevel + from .inspection import InspectionResult, NodeSnapshot + # pylint: enable=import-outside-toplevel + + def should_stop(node: Node) -> bool: + if node.operation == Operation.Input: + return False + if stop_at is None: + return False + if isinstance(stop_at, str): + return node.location.startswith(stop_at) + return stop_at(node) + + snapshots: list[NodeSnapshot] = [] + node_results: dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {} + stopped = False + stopped_at_node: Optional[Node] = None + index = 0 + + for node in nx.topological_sort(self.graph): + if should_stop(node): + stopped = True + stopped_at_node = node + break + + if node.operation == Operation.Input: + node_results[node] = node(args[self.input_indices[node]]) + snapshots.append(NodeSnapshot(node, deepcopy(node_results[node]), index)) + index += 1 + continue + + pred_results = [deepcopy(node_results[pred]) for pred in self.ordered_preds_of(node)] + + try: + node_results[node] = node(*pred_results) + except Exception as error: + raise RuntimeError( + "Evaluation of the graph failed\n\n" + + self.format( + highlighted_nodes={node: ["evaluation of this node failed"]}, + show_bounds=False, + ) + ) from error + + snapshots.append(NodeSnapshot(node, deepcopy(node_results[node]), index)) + index += 1 + + return InspectionResult(snapshots, self, stopped, stopped_at_node) + def draw( self, *, diff --git a/frontends/concrete-python/concrete/fhe/representation/inspection.py b/frontends/concrete-python/concrete/fhe/representation/inspection.py new file mode 100644 index 0000000000..b5dffbf61e --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/representation/inspection.py @@ -0,0 +1,340 @@ +""" +Declaration of `NodeSnapshot` and `InspectionResult` classes for interactive value inspection. +""" + +import re +from copy import deepcopy +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .node import Node +from .operation import Operation + + +class NodeSnapshot: + """ + Snapshot of one node's evaluation during inspection. + """ + + node: Node + value: Union[np.bool_, np.integer, np.floating, np.ndarray] + index: int + overflow: bool + overflow_min: Optional[int] + overflow_max: Optional[int] + + def __init__( + self, + node: Node, + value: Union[np.bool_, np.integer, np.floating, np.ndarray], + index: int, + ): + self.node = node + self.value = value + self.index = index + + # Overflow detection: check if value exceeds the node's output dtype range + self.overflow = False + self.overflow_min = None + self.overflow_max = None + + from ..dtypes import Integer + + if isinstance(node.output.dtype, Integer): + dtype_min = node.output.dtype.min() + dtype_max = node.output.dtype.max() + + if isinstance(value, np.ndarray): + val_min = int(value.min()) + val_max = int(value.max()) + else: + val_min = int(value) + val_max = int(value) + + if val_min < dtype_min or val_max > dtype_max: + self.overflow = True + self.overflow_min = val_min + self.overflow_max = val_max + + @property + def location(self) -> str: + """Get the source location of the node.""" + return self.node.location + + @property + def tag(self) -> str: + """Get the tag of the node.""" + return self.node.tag + + @property + def operation_name(self) -> str: + """Get the operation name of the node.""" + if self.node.operation == Operation.Input: + return "input" + if self.node.operation == Operation.Constant: + return "constant" + return self.node.properties["name"] + + @property + def is_encrypted(self) -> bool: + """Get whether the node output is encrypted.""" + return self.node.output.is_encrypted + + @property + def exceeds_bounds(self) -> bool: + """Check if the value is outside the node's measured bounds (softer than overflow).""" + if self.node.bounds is None: + return False + + lower, upper = self.node.bounds + + if isinstance(self.value, np.ndarray): + val_min = self.value.min() + val_max = self.value.max() + else: + val_min = self.value + val_max = self.value + + return val_min < lower or val_max > upper + + def __repr__(self) -> str: + return ( + f"NodeSnapshot(index={self.index}, op={self.operation_name}, " + f"value={_format_value_short(self.value)}, overflow={self.overflow})" + ) + + +class InspectionResult: + """ + Collection of node snapshots from a graph inspection, with query and display methods. + """ + + _snapshots: list[NodeSnapshot] + _graph: Any # Graph (avoid circular import in type hint) + _stopped: bool + _stopped_at_node: Optional[Node] + + def __init__( + self, + snapshots: list[NodeSnapshot], + graph: Any, + stopped: bool = False, + stopped_at_node: Optional[Node] = None, + ): + self._snapshots = snapshots + self._graph = graph + self._stopped = stopped + self._stopped_at_node = stopped_at_node + + def __iter__(self): + return iter(self._snapshots) + + def __len__(self) -> int: + return len(self._snapshots) + + def __getitem__(self, index: int) -> NodeSnapshot: + return self._snapshots[index] + + def __repr__(self) -> str: + overflow_count = sum(1 for s in self._snapshots if s.overflow) + stopped_str = f", stopped=True" if self._stopped else "" + return ( + f"InspectionResult(nodes={len(self._snapshots)}, " + f"overflows={overflow_count}{stopped_str})" + ) + + def filter( + self, + tag_filter: Optional[Union[str, list[str], re.Pattern]] = None, + operation_filter: Optional[Union[str, list[str], re.Pattern]] = None, + is_encrypted_filter: Optional[bool] = None, + location_filter: Optional[Union[str, re.Pattern]] = None, + custom_filter: Optional[Callable[[NodeSnapshot], bool]] = None, + overflow_only: bool = False, + ) -> list[NodeSnapshot]: + """ + Filter snapshots by various criteria. + + Args: + tag_filter: filter by node tag + operation_filter: filter by operation name + is_encrypted_filter: filter by encryption status + location_filter: filter by source location + custom_filter: arbitrary predicate on NodeSnapshot + overflow_only: if True, only return snapshots with overflow + + Returns: + list of matching NodeSnapshot objects + """ + + def match_text(text_filter, text): + if text_filter is None: + return True + if isinstance(text_filter, str): + return text == text_filter + if isinstance(text_filter, re.Pattern): + return text_filter.match(text) is not None + return any(text == alt for alt in text_filter) + + results = [] + for snap in self._snapshots: + if not match_text(tag_filter, snap.tag): + continue + if not match_text(operation_filter, snap.operation_name): + continue + if is_encrypted_filter is not None and snap.is_encrypted != is_encrypted_filter: + continue + if location_filter is not None and not match_text(location_filter, snap.location): + continue + if custom_filter is not None and not custom_filter(snap): + continue + if overflow_only and not snap.overflow: + continue + results.append(snap) + + return results + + @property + def overflows(self) -> list[NodeSnapshot]: + """Get all snapshots that have overflow.""" + return self.filter(overflow_only=True) + + @property + def has_overflow(self) -> bool: + """Check if any node has overflow.""" + return any(s.overflow for s in self._snapshots) + + @property + def output( + self, + ) -> Union[ + np.bool_, + np.integer, + np.floating, + np.ndarray, + tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...], + ]: + """ + Get the final output value(s) of the inspected graph. + + Raises: + RuntimeError: if evaluation was stopped early via stop_at + """ + if self._stopped: + raise RuntimeError( + "Cannot retrieve output: inspection was stopped early " + f"(stopped before node at {self._stopped_at_node.location if self._stopped_at_node else 'unknown'})" + ) + + node_to_snapshot = {s.node: s for s in self._snapshots} + outputs = [] + for node in self._graph.ordered_outputs(): + if node not in node_to_snapshot: + raise RuntimeError( + "Cannot retrieve output: output node was not evaluated" + ) + outputs.append(node_to_snapshot[node].value) + + return tuple(outputs) if len(outputs) > 1 else outputs[0] + + def summary( + self, + show_values: bool = True, + show_overflow: bool = True, + maximum_constant_length: int = 25, + ) -> str: + """ + Build a formatted summary table of the inspection. + + Args: + show_values: whether to show computed values + show_overflow: whether to show overflow status + maximum_constant_length: max length for constant formatting + + Returns: + formatted string summary + """ + if len(self._snapshots) == 0: + return "(empty graph)" + + import networkx as nx + + # Build id_map and node_to_snapshot + id_map: dict[Node, int] = {} + node_to_snapshot: dict[Node, NodeSnapshot] = {s.node: s for s in self._snapshots} + + for node in nx.lexicographical_topological_sort(self._graph.graph): + id_map[node] = len(id_map) + + lines: list[str] = [] + extra_columns: list[dict[str, str]] = [] + + for node in nx.lexicographical_topological_sort(self._graph.graph): + predecessors = [] + for pred in self._graph.ordered_preds_of(node): + predecessors.append(f"%{id_map[pred]}") + + line = f"%{id_map[node]} = {node.format(predecessors, maximum_constant_length)}" + lines.append(line) + + snap = node_to_snapshot.get(node) + cols: dict[str, str] = {} + + if show_values: + if snap is not None: + cols["value"] = f"=> {_format_value_short(snap.value)}" + else: + cols["value"] = "=> (not evaluated)" + + if show_overflow: + if snap is not None and snap.overflow: + cols["overflow"] = f"OVERFLOW [{snap.overflow_min}, {snap.overflow_max}]" + else: + cols["overflow"] = "" + + extra_columns.append(cols) + + # Align = signs + longest_before_eq = max(len(line.split("=")[0]) for line in lines) + for i, line in enumerate(lines): + before_eq_len = len(line.split("=")[0]) + lines[i] = " " * (longest_before_eq - before_eq_len) + line + + # Add extra columns + shown_keys = [] + if show_values: + shown_keys.append("value") + if show_overflow: + shown_keys.append("overflow") + + indent = 4 + for key in shown_keys: + longest = max(len(line) for line in lines) + lines = [ + line + " " * (longest - len(line) + indent) + cols.get(key, "") + for line, cols in zip(lines, extra_columns) + ] + + # Add return line + returns = [] + for node in self._graph.ordered_outputs(): + returns.append(f"%{id_map[node]}") + lines.append(f"return {', '.join(returns)}") + + if self._stopped: + lines.append( + f"(inspection stopped before reaching " + f"{self._stopped_at_node.location if self._stopped_at_node else 'unknown node'})" + ) + + return "\n".join(line.rstrip() for line in lines) + + +def _format_value_short(value: Union[np.bool_, np.integer, np.floating, np.ndarray]) -> str: + """Format a value concisely for display.""" + if isinstance(value, np.ndarray): + if value.size <= 8: + return repr(value) + return f"array(shape={value.shape}, min={value.min()}, max={value.max()})" + return repr(value) diff --git a/frontends/concrete-python/concrete/fhe/representation/probes.py b/frontends/concrete-python/concrete/fhe/representation/probes.py new file mode 100644 index 0000000000..9dad061013 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/representation/probes.py @@ -0,0 +1,255 @@ +""" +Declaration of `ProbeSnapshot` and `ProbeResult` classes for simulation-mode debug probes. +""" + +import re +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .node import Node +from .operation import Operation + + +class ProbeSnapshot: + """ + Snapshot of one probed node's value captured during simulation execution. + """ + + node: Node + probe_id: int + value: int + tag: str + nmsb: int + + def __init__( + self, + node: Node, + probe_id: int, + value: int, + tag: str = "", + nmsb: int = 0, + ): + self.node = node + self.probe_id = probe_id + self.value = value + self.tag = tag + self.nmsb = nmsb + + @property + def location(self) -> str: + """Get the source location of the node.""" + return self.node.location + + @property + def node_tag(self) -> str: + """Get the tag of the node.""" + return self.node.tag + + @property + def operation_name(self) -> str: + """Get the operation name of the node.""" + if self.node.operation == Operation.Input: + return "input" + if self.node.operation == Operation.Constant: + return "constant" + return self.node.properties["name"] + + @property + def is_encrypted(self) -> bool: + """Get whether the node output is encrypted.""" + return self.node.output.is_encrypted + + @property + def overflow(self) -> bool: + """Check if the probed value overflows the node's output dtype range.""" + from ..dtypes import Integer + + if not isinstance(self.node.output.dtype, Integer): + return False + + dtype_min = self.node.output.dtype.min() + dtype_max = self.node.output.dtype.max() + return self.value < dtype_min or self.value > dtype_max + + def __repr__(self) -> str: + return ( + f"ProbeSnapshot(probe_id={self.probe_id}, op={self.operation_name}, " + f"value={self.value}, tag={self.tag!r})" + ) + + +class ProbeResult: + """ + Collection of probe snapshots from a simulation run, with query and display methods. + """ + + _output: Any + _snapshots: list[ProbeSnapshot] + _graph: Any + + def __init__( + self, + output: Any, + snapshots: list[ProbeSnapshot], + graph: Any, + ): + self._output = output + self._snapshots = snapshots + self._graph = graph + + @property + def output(self) -> Any: + """Get the circuit output from the probed run.""" + return self._output + + def __iter__(self): + return iter(self._snapshots) + + def __len__(self) -> int: + return len(self._snapshots) + + def __getitem__(self, index: int) -> ProbeSnapshot: + return self._snapshots[index] + + def __repr__(self) -> str: + overflow_count = sum(1 for s in self._snapshots if s.overflow) + return ( + f"ProbeResult(output={self._output}, probes={len(self._snapshots)}, " + f"overflows={overflow_count})" + ) + + def filter( + self, + tag_filter: Optional[Union[str, list[str], re.Pattern]] = None, + operation_filter: Optional[Union[str, list[str], re.Pattern]] = None, + is_encrypted_filter: Optional[bool] = None, + location_filter: Optional[Union[str, re.Pattern]] = None, + custom_filter: Optional[Callable[[ProbeSnapshot], bool]] = None, + overflow_only: bool = False, + ) -> list[ProbeSnapshot]: + """ + Filter probe snapshots by various criteria. + + Args: + tag_filter: filter by node tag + operation_filter: filter by operation name + is_encrypted_filter: filter by encryption status + location_filter: filter by source location + custom_filter: arbitrary predicate on ProbeSnapshot + overflow_only: if True, only return snapshots with overflow + + Returns: + list of matching ProbeSnapshot objects + """ + + def match_text(text_filter, text): + if text_filter is None: + return True + if isinstance(text_filter, str): + return text == text_filter + if isinstance(text_filter, re.Pattern): + return text_filter.match(text) is not None + return any(text == alt for alt in text_filter) + + results = [] + for snap in self._snapshots: + if not match_text(tag_filter, snap.node_tag): + continue + if not match_text(operation_filter, snap.operation_name): + continue + if is_encrypted_filter is not None and snap.is_encrypted != is_encrypted_filter: + continue + if location_filter is not None and not match_text(location_filter, snap.location): + continue + if custom_filter is not None and not custom_filter(snap): + continue + if overflow_only and not snap.overflow: + continue + results.append(snap) + + return results + + @property + def overflows(self) -> list[ProbeSnapshot]: + """Get all snapshots that have overflow.""" + return self.filter(overflow_only=True) + + @property + def has_overflow(self) -> bool: + """Check if any probed node has overflow.""" + return any(s.overflow for s in self._snapshots) + + def summary(self) -> str: + """ + Build a formatted summary table of the probe results. + + Returns: + formatted string summary + """ + if len(self._snapshots) == 0: + return "(no probes captured)" + + lines = [] + lines.append(f"Output: {self._output}") + lines.append(f"Probes: {len(self._snapshots)}") + lines.append("") + + # Header + lines.append(f"{'ID':>4} {'Operation':<20} {'Tag':<20} {'Value':>12} {'Overflow'}") + lines.append("-" * 80) + + for snap in self._snapshots: + overflow_str = "OVERFLOW" if snap.overflow else "" + lines.append( + f"{snap.probe_id:>4} {snap.operation_name:<20} " + f"{snap.node_tag:<20} {snap.value:>12} {overflow_str}" + ) + + return "\n".join(line.rstrip() for line in lines) + + def compare_with(self, inspection: "InspectionResult") -> str: + """ + Side-by-side comparison of simulation probes vs cleartext inspection. + + Args: + inspection: InspectionResult from circuit.inspect() + + Returns: + formatted comparison string + """ + from .inspection import InspectionResult + + lines = [] + lines.append(f"{'Node':<30} {'Inspect (cleartext)':>20} {'Probe (simulation)':>20} {'Match'}") + lines.append("-" * 100) + + # Build lookup from node to probe value + node_to_probe = {} + for snap in self._snapshots: + node_to_probe[snap.node] = snap.value + + for inspect_snap in inspection: + node = inspect_snap.node + inspect_val = inspect_snap.value + probe_val = node_to_probe.get(node, None) + + if probe_val is not None: + if isinstance(inspect_val, np.ndarray): + match = "array" + else: + match = "OK" if int(inspect_val) == probe_val else "MISMATCH" + else: + match = "(no probe)" + + probe_str = str(probe_val) if probe_val is not None else "-" + inspect_str = str(inspect_val) + if len(inspect_str) > 20: + inspect_str = inspect_str[:17] + "..." + if len(probe_str) > 20: + probe_str = probe_str[:17] + "..." + + op_name = inspect_snap.operation_name + lines.append(f"{op_name:<30} {inspect_str:>20} {probe_str:>20} {match}") + + return "\n".join(line.rstrip() for line in lines) diff --git a/tools/concrete-debugger/.gitignore b/tools/concrete-debugger/.gitignore new file mode 100644 index 0000000000..ec69fe3834 --- /dev/null +++ b/tools/concrete-debugger/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +*.pyc +.pytest_cache/ +*.egg-info/ +dist/ +build/ diff --git a/tools/concrete-debugger/concrete_dap/__init__.py b/tools/concrete-debugger/concrete_dap/__init__.py new file mode 100644 index 0000000000..0c0df88ee6 --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/__init__.py @@ -0,0 +1 @@ +"""Concrete FHE Debug Adapter Protocol server.""" diff --git a/tools/concrete-debugger/concrete_dap/breakpoints.py b/tools/concrete-debugger/concrete_dap/breakpoints.py new file mode 100644 index 0000000000..b93d7fd3be --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/breakpoints.py @@ -0,0 +1,92 @@ +"""Breakpoint manager: maps file:line to DAG node sets.""" + +import os +from typing import Optional + + +class BreakpointManager: + """Maps (filename, lineno) pairs to sets of node indices in topological order.""" + + def __init__(self): + self._location_index: dict[tuple[str, int], list[int]] = {} + self._active: set[tuple[str, int]] = set() + self._script_dir: str = "" + + def set_script_dir(self, path: str) -> None: + """Set the script directory for resolving relative node locations.""" + self._script_dir = path + + def build_index(self, topo_order: list) -> None: + """Build a reverse index from (file, line) to node indices in topo_order.""" + self._location_index.clear() + for idx, node in enumerate(topo_order): + loc = getattr(node, "location", "") + if not loc or ":" not in loc: + continue + parts = loc.rsplit(":", 1) + try: + filepath = _normalize_path(parts[0], self._script_dir) + lineno = int(parts[1]) + except (ValueError, IndexError): + continue + key = (filepath, lineno) + if key not in self._location_index: + self._location_index[key] = [] + self._location_index[key].append(idx) + + def set_breakpoints(self, source_path: str, lines: list[int]) -> list[dict]: + """Set breakpoints for a source file, returning DAP Breakpoint objects.""" + norm_path = _normalize_path(source_path, self._script_dir) + + # Remove old breakpoints for this file + self._active = { + (f, l) for f, l in self._active if f != norm_path + } + + results = [] + for line in lines: + key = (norm_path, line) + verified = key in self._location_index + if verified: + self._active.add(key) + results.append({ + "verified": verified, + "line": line, + "source": {"path": source_path}, + }) + return results + + def is_breakpoint(self, node, topo_index: int) -> bool: + """Check if a node at given topo index is at a breakpoint location.""" + loc = getattr(node, "location", "") + if not loc or ":" not in loc: + return False + parts = loc.rsplit(":", 1) + try: + filepath = _normalize_path(parts[0], self._script_dir) + lineno = int(parts[1]) + except (ValueError, IndexError): + return False + key = (filepath, lineno) + if key not in self._active: + return False + # Only stop at the first node at this location (in topo order) + first_idx = self._location_index.get(key, [None])[0] + return topo_index == first_idx + + def get_available_lines(self, source_path: str) -> list[int]: + """Get all lines in a source file that have DAG nodes.""" + norm_path = _normalize_path(source_path, self._script_dir) + return sorted({l for f, l in self._location_index if f == norm_path}) + + +def _normalize_path(path: str, base_dir: str = "") -> str: + """Normalize a file path for consistent comparison. + + Relative paths are resolved against *base_dir* (typically the script's + directory) so that node locations like ``"script.py:12"`` match the + absolute paths sent by VS Code. + """ + if base_dir and not os.path.isabs(path): + path = os.path.join(base_dir, path) + return os.path.normcase(os.path.normpath(path)) diff --git a/tools/concrete-debugger/concrete_dap/protocol.py b/tools/concrete-debugger/concrete_dap/protocol.py new file mode 100644 index 0000000000..8f8eebb04f --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/protocol.py @@ -0,0 +1,89 @@ +"""DAP message types and Content-Length framed I/O.""" + +import json +import sys +from typing import Any, Optional + + +def read_message(stream=None) -> Optional[dict]: + """Read a DAP message with Content-Length framing from stream.""" + if stream is None: + stream = sys.stdin.buffer + + headers = {} + while True: + line = stream.readline() + if not line: + return None + line = line.decode("utf-8").rstrip("\r\n") + if line == "": + break + if ":" in line: + key, value = line.split(":", 1) + headers[key.strip()] = value.strip() + + content_length = int(headers.get("Content-Length", "0")) + if content_length == 0: + return None + + body = stream.read(content_length) + if not body: + return None + + return json.loads(body.decode("utf-8")) + + +def write_message(msg: dict, stream=None) -> None: + """Write a DAP message with Content-Length framing to stream.""" + if stream is None: + stream = sys.stdout.buffer + + body = json.dumps(msg, default=_json_default).encode("utf-8") + header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + stream.write(header) + stream.write(body) + stream.flush() + + +def _json_default(obj): + """JSON serializer for numpy types.""" + import numpy as np + + if isinstance(obj, (np.integer,)): + return int(obj) + if isinstance(obj, (np.floating,)): + return float(obj) + if isinstance(obj, (np.bool_,)): + return bool(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + +def make_response(request: dict, body: Optional[dict] = None, success: bool = True, + message: str = "") -> dict: + """Build a DAP response message.""" + resp = { + "seq": 0, + "type": "response", + "request_seq": request.get("seq", 0), + "command": request.get("command", ""), + "success": success, + } + if body is not None: + resp["body"] = body + if message: + resp["message"] = message + return resp + + +def make_event(event: str, body: Optional[dict] = None) -> dict: + """Build a DAP event message.""" + evt = { + "seq": 0, + "type": "event", + "event": event, + } + if body is not None: + evt["body"] = body + return evt diff --git a/tools/concrete-debugger/concrete_dap/server.py b/tools/concrete-debugger/concrete_dap/server.py new file mode 100644 index 0000000000..fea1ea62a1 --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/server.py @@ -0,0 +1,477 @@ +"""DAP server: stdin/stdout message loop and dispatch.""" + +import os +import runpy +import sys +import traceback +from typing import Any, Optional + +from .breakpoints import BreakpointManager +from .protocol import make_event, make_response, read_message, write_message +from .session import ConcreteDebugSession, ModuleDebugSession, StopReason +from .variables import VariableStore + +THREAD_ID = 1 + + +class DAPServer: + """Debug Adapter Protocol server for Concrete FHE circuits.""" + + def __init__(self, input_stream=None, output_stream=None): + self._input = input_stream or sys.stdin.buffer + self._output = output_stream or sys.stdout.buffer + self._seq = 1 + self._session: Optional[ConcreteDebugSession] = None + self._breakpoints = BreakpointManager() + self._variables = VariableStore() + self._launch_config: dict = {} + self._running = True + self._initialized = False + + self._handlers = { + "initialize": self._handle_initialize, + "launch": self._handle_launch, + "disconnect": self._handle_disconnect, + "setBreakpoints": self._handle_set_breakpoints, + "setFunctionBreakpoints": self._handle_set_function_breakpoints, + "setExceptionBreakpoints": self._handle_set_exception_breakpoints, + "configurationDone": self._handle_configuration_done, + "threads": self._handle_threads, + "stackTrace": self._handle_stack_trace, + "scopes": self._handle_scopes, + "variables": self._handle_variables, + "continue": self._handle_continue, + "next": self._handle_next, + "stepIn": self._handle_step_in, + "stepOut": self._handle_step_out, + "evaluate": self._handle_evaluate, + "pause": self._handle_pause, + "source": self._handle_source, + } + + def run(self) -> None: + """Main message loop.""" + while self._running: + msg = read_message(self._input) + if msg is None: + break + self._dispatch(msg) + + def _dispatch(self, msg: dict) -> None: + """Route a DAP message to its handler.""" + msg_type = msg.get("type", "") + if msg_type != "request": + return + + command = msg.get("command", "") + handler = self._handlers.get(command) + + if handler is None: + self._send(make_response(msg, success=False, + message=f"Unknown command: {command}")) + return + + try: + handler(msg) + except Exception as e: + self._send(make_response(msg, success=False, + message=f"Error handling {command}: {e}")) + traceback.print_exc(file=sys.stderr) + + def _send(self, msg: dict) -> None: + """Send a DAP message with auto-incrementing sequence number.""" + msg["seq"] = self._seq + self._seq += 1 + write_message(msg, self._output) + + def _send_output(self, text: str, category: str = "console") -> None: + """Send a DAP output event to the Debug Console.""" + self._send(make_event("output", { + "category": category, + "output": text + "\n", + })) + + # ── Lifecycle ── + + def _handle_initialize(self, request: dict) -> None: + capabilities = { + "supportsConfigurationDoneRequest": True, + "supportsEvaluateForHovers": True, + "supportsSingleThreadExecutionRequests": True, + } + self._send(make_response(request, body=capabilities)) + self._send(make_event("initialized")) + self._initialized = True + + def _handle_launch(self, request: dict) -> None: + args = request.get("arguments", {}) + self._launch_config = args + + program = args.get("program", "") + function_name = args.get("function", "") + functions_config = args.get("functions") + input_args = tuple(args.get("args", [])) + stop_on_entry = args.get("stopOnEntry", True) + stop_on_overflow = args.get("stopOnOverflow", False) + + if not program: + self._send(make_response(request, success=False, + message="'program' is required in launch config")) + return + + if not function_name: + self._send(make_response(request, success=False, + message="'function' is required in launch config")) + return + + self._send_output("Loading script...") + + # Execute the user script to find the circuit. + # Change to the script's directory so that .artifacts/ and other + # relative paths land next to the script, not in the (potentially + # read-only) working directory inherited from VS Code. + script_dir = os.path.dirname(os.path.abspath(program)) + self._breakpoints.set_script_dir(script_dir) + prev_cwd = os.getcwd() + try: + os.chdir(script_dir) + except OSError: + pass # best-effort; if it fails we'll still try to run + + try: + namespace = runpy.run_path(program, run_name="__main__") + except Exception as e: + self._send_output(f"Error: {e}", "stderr") + self._send(make_response(request, success=False, + message=f"Failed to execute {program}: {e}")) + return + finally: + os.chdir(prev_cwd) + + # Find the circuit object + circuit_obj = namespace.get(function_name) + if circuit_obj is None: + self._send(make_response(request, success=False, + message=f"'{function_name}' not found in {program}")) + return + + # Check for module with functions config + if functions_config: + module_graphs = _extract_module_graphs(circuit_obj) + if module_graphs is None: + self._send(make_response(request, success=False, + message=f"'{function_name}' is not a module (no .graphs attribute)")) + return + self._handle_launch_module(request, module_graphs, functions_config, + stop_on_entry, stop_on_overflow) + return + + graph = _extract_graph(circuit_obj) + if graph is None: + self._send(make_response(request, success=False, + message=f"'{function_name}' is not a Circuit or Compiler object")) + return + + self._send_output(f"Found circuit with {len(graph.graph.nodes)} nodes") + self._send_output(f"Evaluating inputs: {input_args}") + + # Create debug session + self._session = ConcreteDebugSession(graph, input_args, self._breakpoints, + stop_on_overflow=stop_on_overflow) + self._emit_graph_summary(graph) + self._send(make_response(request)) + + def _handle_disconnect(self, request: dict) -> None: + self._send(make_response(request)) + self._running = False + + def _handle_configuration_done(self, request: dict) -> None: + self._send(make_response(request)) + if not self._session: + return + # Now that VS Code is fully configured, start evaluation. + reason = self._session.evaluate_inputs_and_stop_on_entry() + if self._launch_config.get("stopOnEntry", True): + # Pause before the first non-input node + self._send_stopped_event(reason) + else: + # Run until breakpoint or end + if reason == StopReason.ENTRY: + reason = self._session.continue_to_breakpoint() + self._send_stopped_event(reason) + + # ── Breakpoints ── + + def _handle_set_breakpoints(self, request: dict) -> None: + args = request.get("arguments", {}) + source = args.get("source", {}) + source_path = source.get("path", "") + bp_lines = [bp.get("line", 0) for bp in args.get("breakpoints", [])] + + results = self._breakpoints.set_breakpoints(source_path, bp_lines) + self._send(make_response(request, body={"breakpoints": results})) + + def _handle_set_function_breakpoints(self, request: dict) -> None: + self._send(make_response(request, body={"breakpoints": []})) + + def _handle_set_exception_breakpoints(self, request: dict) -> None: + self._send(make_response(request)) + + # ── Threads ── + + def _handle_threads(self, request: dict) -> None: + threads = [{"id": THREAD_ID, "name": "FHE Circuit Evaluation"}] + self._send(make_response(request, body={"threads": threads})) + + # ── Stack / Scopes / Variables ── + + def _handle_stack_trace(self, request: dict) -> None: + if self._session is None: + self._send(make_response(request, body={"stackFrames": [], "totalFrames": 0})) + return + + frames = self._session.get_stack_frames() + self._send(make_response(request, body={ + "stackFrames": frames, + "totalFrames": len(frames), + })) + + def _handle_scopes(self, request: dict) -> None: + if self._session is None or self._session.current_snapshot is None: + self._send(make_response(request, body={"scopes": []})) + return + + self._variables.reset() + scopes = self._variables.scopes_for_stop( + self._session.current_snapshot, + list(self._session.snapshots), + ) + + if isinstance(self._session, ModuleDebugSession): + scopes.extend(self._variables.scopes_for_module_stop( + self._session.current_function_name, + self._session._current_idx, + len(self._session._sessions), + self._session.all_snapshots, + )) + + self._send(make_response(request, body={"scopes": scopes})) + + def _handle_variables(self, request: dict) -> None: + args = request.get("arguments", {}) + ref = args.get("variablesReference", 0) + variables = self._variables.get_variables(ref) + self._send(make_response(request, body={"variables": variables})) + + # ── Stepping ── + + def _handle_continue(self, request: dict) -> None: + self._send(make_response(request, body={"allThreadsContinued": True})) + if self._session: + reason = self._session.continue_to_breakpoint() + self._send_stopped_event(reason) + + def _handle_next(self, request: dict) -> None: + self._send(make_response(request)) + if self._session: + reason = self._session.step_one() + self._send_stopped_event(reason) + + def _handle_step_in(self, request: dict) -> None: + self._send(make_response(request)) + if self._session: + if isinstance(self._session, ModuleDebugSession): + reason = self._session.step_into_next_function() + else: + reason = self._session.step_one() + self._send_stopped_event(reason) + + def _handle_step_out(self, request: dict) -> None: + self._send(make_response(request)) + if self._session: + reason = self._session.step_out() + self._send_stopped_event(reason) + + def _handle_pause(self, request: dict) -> None: + # Graph evaluation is synchronous, pause is a no-op + self._send(make_response(request)) + + # ── Module Launch ── + + def _handle_launch_module(self, request: dict, module_graphs: dict, + functions_config: list, stop_on_entry: bool, + stop_on_overflow: bool) -> None: + """Launch a module debug session with multiple functions.""" + named_sessions = [] + for func_conf in functions_config: + name = func_conf.get("name", "") + func_args = tuple(func_conf.get("args", [])) + + if name not in module_graphs: + avail = list(module_graphs.keys()) + self._send(make_response( + request, success=False, + message=f"Function '{name}' not found in module. Available: {avail}")) + return + + graph = module_graphs[name] + session = ConcreteDebugSession(graph, func_args, self._breakpoints, + stop_on_overflow=stop_on_overflow) + named_sessions.append((name, session)) + self._send_output(f"Function '{name}': {len(graph.graph.nodes)} nodes") + self._emit_graph_summary(graph) + + self._session = ModuleDebugSession(named_sessions, self._breakpoints) + self._send(make_response(request)) + + def _emit_graph_summary(self, graph) -> None: + """Emit a compact circuit summary to the Debug Console.""" + total = len(graph.graph.nodes) + input_count = len(getattr(graph, 'input_nodes', {})) + output_count = len(getattr(graph, 'output_nodes', {})) + + ops = set() + input_set = set(getattr(graph, 'input_nodes', {}).values()) + for n in graph.graph.nodes: + if n not in input_set: + name = getattr(n, 'properties', {}).get("name", "") + if name: + ops.add(name) + + has_tags = any(getattr(n, 'tag', '') for n in graph.graph.nodes) + + lines = [ + "=== Circuit Summary ===", + f" Nodes: {total} ({input_count} inputs, {output_count} outputs)", + f" Operations: {', '.join(sorted(ops)) if ops else '(none)'}", + f" Tags: {'yes' if has_tags else 'no'}", + ] + self._send_output("\n".join(lines)) + + # ── Evaluate ── + + def _handle_evaluate(self, request: dict) -> None: + args = request.get("arguments", {}) + expression = args.get("expression", "") + + if self._session is None: + self._send(make_response(request, body={"result": "(no active session)", "variablesReference": 0})) + return + + result = self._evaluate_expression(expression) + self._send(make_response(request, body={"result": result, "variablesReference": 0})) + + def _handle_source(self, request: dict) -> None: + self._send(make_response(request, body={"content": ""})) + + # ── Helpers ── + + def _send_stopped_event(self, reason: StopReason) -> None: + """Send a DAP 'stopped' event based on the stop reason.""" + if reason == StopReason.FINISHED: + self._send(make_event("terminated")) + return + + reason_map = { + StopReason.STEP: "step", + StopReason.BREAKPOINT: "breakpoint", + StopReason.ENTRY: "entry", + StopReason.EXCEPTION: "exception", + StopReason.OVERFLOW: "data breakpoint", + } + + body: dict = { + "reason": reason_map.get(reason, "step"), + "threadId": THREAD_ID, + "allThreadsStopped": True, + } + + if reason == StopReason.EXCEPTION and self._session and self._session.error: + body["text"] = str(self._session.error) + body["description"] = "Node evaluation failed" + snap = self._session.current_snapshot + if snap: + self._send_output( + f"Exception at node [{snap.index}] {snap.operation_name}: " + f"{self._session.error}", "stderr") + + if reason == StopReason.OVERFLOW and self._session: + snap = self._session.current_snapshot + if snap: + body["text"] = f"Overflow at [{snap.index}] {snap.operation_name}" + body["description"] = "Value exceeds bit width" + self._send_output( + f"Overflow: node [{snap.index}] {snap.operation_name}, " + f"value={repr(snap.value)}", "important") + + self._send(make_event("stopped", body)) + + def _evaluate_expression(self, expression: str) -> str: + """Evaluate a debug console expression against the session state.""" + session = self._session + if session is None: + return "(no session)" + + # Support querying by snapshot index: e.g. "snap[3]" + if expression.startswith("snap[") and expression.endswith("]"): + try: + idx = int(expression[5:-1]) + if 0 <= idx < len(session.snapshots): + snap = session.snapshots[idx] + return repr(snap) + return f"(index {idx} out of range, {len(session.snapshots)} snapshots)" + except ValueError: + pass + + # Support "nodes" to get count + if expression == "nodes": + return f"{len(session.topo_order)} nodes total, {len(session.snapshots)} evaluated" + + # Support "value" for current + if expression == "value" and session.current_snapshot: + return repr(session.current_snapshot.value) + + # Support "overflow" check + if expression == "overflow": + overflows = [s for s in session.snapshots if s.overflow] + if overflows: + return f"{len(overflows)} overflow(s): " + ", ".join( + f"[{s.index}] {s.operation_name}" for s in overflows + ) + return "No overflows detected" + + # Module-specific expressions + if isinstance(self._session, ModuleDebugSession): + if expression == "functions": + names = [name for name, _ in self._session._sessions] + return f"Functions: {', '.join(names)}" + if expression == "function": + return f"Current function: {self._session.current_function_name}" + + return f"(unknown expression: {expression})" + + +def _extract_graph(obj): + """Extract a Graph from a Circuit, Compiler/Compilable, or object with .graph.""" + # Circuit object + if hasattr(obj, "graph"): + graph = obj.graph + if hasattr(graph, "graph") and hasattr(graph, "input_indices"): + return graph + + # Compilable that has been traced + if hasattr(obj, "_graph"): + return obj._graph + + return None + + +def _extract_module_graphs(obj) -> dict | None: + """Extract a dict of graphs from an FheModule (obj.graphs).""" + graphs = getattr(obj, 'graphs', None) + if isinstance(graphs, dict): + # Verify at least one entry looks like a Graph + for g in graphs.values(): + if hasattr(g, 'graph') and hasattr(g, 'input_indices'): + return graphs + return None diff --git a/tools/concrete-debugger/concrete_dap/session.py b/tools/concrete-debugger/concrete_dap/session.py new file mode 100644 index 0000000000..a014abfb91 --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/session.py @@ -0,0 +1,427 @@ +"""ConcreteDebugSession: graph walker with pause/resume for DAP stepping.""" + +from copy import deepcopy +from enum import Enum, auto +from typing import Any, Optional + +import networkx as nx +import numpy as np + +from .breakpoints import BreakpointManager + +# Module-level references — populated lazily from concrete.fhe or overridden by tests. +_NodeSnapshot = None +_OperationInput = None + + +def _ensure_imports(): + """Populate _NodeSnapshot and _OperationInput from concrete.fhe (if available).""" + global _NodeSnapshot, _OperationInput + if _NodeSnapshot is None: + try: + from concrete.fhe.representation.inspection import NodeSnapshot + _NodeSnapshot = NodeSnapshot + except ImportError: + pass + if _OperationInput is None: + try: + from concrete.fhe.representation.operation import Operation + _OperationInput = Operation.Input + except ImportError: + pass + + +def _is_input_node(node) -> bool: + """Check whether *node* is an input node, compatible with both real and mock types.""" + op = node.operation + if _OperationInput is not None and op == _OperationInput: + return True + # Fallback: duck-type check for mock/string operations + op_val = getattr(op, "value", op) + return op_val == "input" + + +class StopReason(Enum): + STEP = auto() + BREAKPOINT = auto() + ENTRY = auto() + FINISHED = auto() + EXCEPTION = auto() + OVERFLOW = auto() + + +class ConcreteDebugSession: + """Walks a Concrete FHE graph node-by-node with pause/resume state.""" + + def __init__(self, graph, args: tuple, breakpoints: BreakpointManager, + stop_on_overflow: bool = False): + _ensure_imports() + self.graph = graph + self.args = args + self.breakpoints = breakpoints + self.stop_on_overflow = stop_on_overflow + + self.topo_order: list = list(nx.topological_sort(graph.graph)) + self.current_index: int = 0 + self.node_results: dict = {} + self.snapshots: list = [] + self.finished: bool = False + self._error: Optional[Exception] = None + + # Build breakpoint index + breakpoints.build_index(self.topo_order) + + @property + def current_node(self): + """The node about to be evaluated (or just evaluated after a step).""" + if self.current_index > 0 and self.current_index <= len(self.topo_order): + return self.topo_order[self.current_index - 1] + return None + + @property + def current_snapshot(self): + """The most recent snapshot after stepping.""" + if self.snapshots: + return self.snapshots[-1] + return None + + @property + def error(self) -> Optional[Exception]: + return self._error + + def evaluate_inputs_and_stop_on_entry(self) -> StopReason: + """Evaluate all Input nodes and stop before the first non-Input node.""" + while self.current_index < len(self.topo_order): + node = self.topo_order[self.current_index] + if not _is_input_node(node): + return StopReason.ENTRY + self._evaluate_node(node) + self.current_index += 1 + + self.finished = True + return StopReason.FINISHED + + def step_one(self) -> StopReason: + """Evaluate the next node and stop (with same-line grouping).""" + if self.finished or self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.FINISHED + + node = self.topo_order[self.current_index] + try: + self._evaluate_node(node) + except Exception as e: + self._error = e + self.finished = True + return StopReason.EXCEPTION + + if self.stop_on_overflow and self.snapshots[-1].overflow: + self.current_index += 1 + if self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.OVERFLOW + + current_location = node.location + self.current_index += 1 + + # Keep stepping while the next node has the same source location + while self.current_index < len(self.topo_order): + next_node = self.topo_order[self.current_index] + if _is_input_node(next_node): + break + if next_node.location != current_location: + break + try: + self._evaluate_node(next_node) + except Exception as e: + self._error = e + self.finished = True + return StopReason.EXCEPTION + if self.stop_on_overflow and self.snapshots[-1].overflow: + self.current_index += 1 + if self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.OVERFLOW + self.current_index += 1 + + if self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.FINISHED + + return StopReason.STEP + + def continue_to_breakpoint(self, _skip_first: bool = True) -> StopReason: + """Evaluate nodes until a breakpoint is hit or the graph finishes.""" + if self.finished: + return StopReason.FINISHED + + # If currently stopped at a breakpoint node, step past it first + first_step = _skip_first + offset = getattr(self, '_topo_offset', 0) + while self.current_index < len(self.topo_order): + node = self.topo_order[self.current_index] + + # Check breakpoint (skip on first step to avoid re-stopping at same spot) + if not first_step and self.breakpoints.is_breakpoint( + node, self.current_index + offset + ): + return StopReason.BREAKPOINT + first_step = False + + try: + self._evaluate_node(node) + except Exception as e: + self._error = e + self.finished = True + return StopReason.EXCEPTION + if self.stop_on_overflow and self.snapshots[-1].overflow: + self.current_index += 1 + if self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.OVERFLOW + self.current_index += 1 + + self.finished = True + return StopReason.FINISHED + + def step_out(self) -> StopReason: + """Step out — in single-graph mode, runs to completion.""" + while self.current_index < len(self.topo_order): + node = self.topo_order[self.current_index] + try: + self._evaluate_node(node) + except Exception as e: + self._error = e + self.finished = True + return StopReason.EXCEPTION + if self.stop_on_overflow and self.snapshots[-1].overflow: + self.current_index += 1 + if self.current_index >= len(self.topo_order): + self.finished = True + return StopReason.OVERFLOW + self.current_index += 1 + + self.finished = True + return StopReason.FINISHED + + def _evaluate_node(self, node): + """Evaluate a single node and create a snapshot.""" + if _is_input_node(node): + self.node_results[node] = node(self.args[self.graph.input_indices[node]]) + else: + pred_results = [ + deepcopy(self.node_results[pred]) + for pred in self.graph.ordered_preds_of(node) + ] + self.node_results[node] = node(*pred_results) + + snapshot_cls = _NodeSnapshot + if snapshot_cls is None: + # Fallback: plain object (for testing without concrete) + snapshot_cls = _FallbackSnapshot + snapshot = snapshot_cls(node, deepcopy(self.node_results[node]), len(self.snapshots)) + self.snapshots.append(snapshot) + return snapshot + + def get_stack_frames(self) -> list[dict]: + """Synthesize DAP stack frames from the tag hierarchy of the current node.""" + snapshot = self.current_snapshot + if snapshot is None: + return [] + + tag = snapshot.tag + location = snapshot.location + frames = [] + + # Parse file:line from location + source_info = _parse_location(location) + + if tag: + parts = tag.split(".") + # Innermost frame first + for i in range(len(parts), 0, -1): + frame_tag = ".".join(parts[:i]) + frame_name = parts[i - 1] + frame = { + "id": i - 1, + "name": f"{frame_name} [{frame_tag}]", + "line": source_info.get("line", 0), + "column": 0, + } + if source_info.get("path"): + frame["source"] = {"path": source_info["path"]} + frames.append(frame) + else: + # No tag — single frame with operation name + frame = { + "id": 0, + "name": snapshot.operation_name, + "line": source_info.get("line", 0), + "column": 0, + } + if source_info.get("path"): + frame["source"] = {"path": source_info["path"]} + frames.append(frame) + + return frames + + +class _FallbackSnapshot: + """Minimal snapshot for when concrete.fhe is not available (testing).""" + + def __init__(self, node, value, index): + self.node = node + self.value = value + self.index = index + # Bounds-based overflow check from node dtype + self.overflow = False + try: + dtype = node.output.dtype + lo, hi = dtype.min(), dtype.max() + if isinstance(value, np.ndarray): + self.overflow = bool(int(value.min()) < lo or int(value.max()) > hi) + else: + self.overflow = bool(int(value) < lo or int(value) > hi) + except (AttributeError, TypeError, ValueError): + pass + + @property + def location(self): + return self.node.location + + @property + def tag(self): + return self.node.tag + + @property + def operation_name(self): + op = self.node.operation + op_val = getattr(op, "value", op) + if op_val == "input": + return "input" + if op_val == "constant": + return "constant" + return self.node.properties.get("name", "unknown") + + @property + def is_encrypted(self): + return getattr(self.node.output, "is_encrypted", False) + + +class ModuleDebugSession: + """Debug session for @fhe.module circuits with multiple functions. + + Wraps an ordered list of ConcreteDebugSession instances, one per function. + Tracks the active function and advances to the next when it finishes. + """ + + def __init__(self, named_sessions: list, breakpoints: BreakpointManager): + self._sessions = named_sessions # list of (name, ConcreteDebugSession) + self._current_idx = 0 + self.breakpoints = breakpoints + + # Build combined topo order for cross-function breakpoints + offset = 0 + combined_topo = [] + for _name, session in self._sessions: + session._topo_offset = offset + combined_topo.extend(session.topo_order) + offset += len(session.topo_order) + breakpoints.build_index(combined_topo) + + @property + def _active(self): + return self._sessions[self._current_idx][1] + + @property + def current_function_name(self) -> str: + return self._sessions[self._current_idx][0] + + @property + def current_snapshot(self): + return self._active.current_snapshot + + @property + def snapshots(self) -> list: + return self._active.snapshots + + @property + def all_snapshots(self) -> list: + result = [] + for _, session in self._sessions: + result.extend(session.snapshots) + return result + + @property + def error(self): + return self._active.error + + @property + def finished(self) -> bool: + return (self._current_idx >= len(self._sessions) - 1 + and self._active.finished) + + @property + def topo_order(self) -> list: + return self._active.topo_order + + @property + def stop_on_overflow(self) -> bool: + return self._active.stop_on_overflow + + def evaluate_inputs_and_stop_on_entry(self) -> StopReason: + return self._active.evaluate_inputs_and_stop_on_entry() + + def step_one(self) -> StopReason: + reason = self._active.step_one() + if reason == StopReason.FINISHED and self._advance_to_next(): + return StopReason.ENTRY + return reason + + def continue_to_breakpoint(self) -> StopReason: + reason = self._active.continue_to_breakpoint() + while reason == StopReason.FINISHED and self._advance_to_next(): + reason = self._active.continue_to_breakpoint(_skip_first=False) + return reason + + def step_out(self) -> StopReason: + reason = self._active.step_out() + if reason == StopReason.FINISHED and self._advance_to_next(): + return StopReason.ENTRY + return reason + + def step_into_next_function(self) -> StopReason: + if self._active.finished: + if self._advance_to_next(): + return StopReason.ENTRY + return StopReason.FINISHED + return self.step_one() + + def _advance_to_next(self) -> bool: + if self._current_idx < len(self._sessions) - 1: + self._current_idx += 1 + self._active.evaluate_inputs_and_stop_on_entry() + return True + return False + + def get_stack_frames(self) -> list[dict]: + frames = self._active.get_stack_frames() + progress = f"{self._current_idx + 1}/{len(self._sessions)}" + frames.append({ + "id": len(frames), + "name": f"Module [{self.current_function_name}] ({progress})", + "line": 0, + "column": 0, + }) + return frames + + +def _parse_location(location: str) -> dict: + """Parse 'file.py:42' into {path, line}.""" + if not location or ":" not in location: + return {} + parts = location.rsplit(":", 1) + try: + return {"path": parts[0], "line": int(parts[1])} + except (ValueError, IndexError): + return {} diff --git a/tools/concrete-debugger/concrete_dap/variables.py b/tools/concrete-debugger/concrete_dap/variables.py new file mode 100644 index 0000000000..0577143d5e --- /dev/null +++ b/tools/concrete-debugger/concrete_dap/variables.py @@ -0,0 +1,224 @@ +"""Convert NodeSnapshots into DAP variable trees.""" + +from typing import Any, Optional + +import numpy as np + + +class VariableStore: + """Manages DAP variable references for lazy expansion of complex values.""" + + def __init__(self): + self._next_ref = 1 + self._refs: dict[int, Any] = {} + + def reset(self): + """Clear all variable references.""" + self._next_ref = 1 + self._refs.clear() + + def _alloc_ref(self, obj: Any) -> int: + """Allocate a variables reference for an expandable object.""" + ref = self._next_ref + self._next_ref += 1 + self._refs[ref] = obj + return ref + + def scopes_for_stop(self, current_snapshot, all_snapshots: list) -> list[dict]: + """Build DAP Scope objects for a stopped state.""" + scopes = [] + + current_ref = self._alloc_ref(("current_node", current_snapshot)) + scopes.append({ + "name": "Current Node", + "variablesReference": current_ref, + "expensive": False, + }) + + all_ref = self._alloc_ref(("all_evaluated", all_snapshots)) + scopes.append({ + "name": "All Evaluated", + "variablesReference": all_ref, + "expensive": False, + }) + + return scopes + + def get_variables(self, variables_ref: int) -> list[dict]: + """Resolve a variablesReference into DAP Variable objects.""" + obj = self._refs.get(variables_ref) + if obj is None: + return [] + + if isinstance(obj, tuple) and len(obj) == 2: + tag, data = obj + + if tag == "current_node": + return self._snapshot_variables(data) + + if tag == "all_evaluated": + return self._snapshot_list_variables(data) + + if tag == "ndarray": + return self._ndarray_variables(data) + + if tag == "module_context": + return self._module_context_variables(data) + + return [] + + def _snapshot_variables(self, snapshot) -> list[dict]: + """Build variables for a single NodeSnapshot.""" + variables = [] + value = snapshot.value + + # Value — possibly expandable if ndarray + if isinstance(value, np.ndarray) and value.size > 8: + ref = self._alloc_ref(("ndarray", value)) + variables.append({ + "name": "value", + "value": f"array(shape={value.shape}, min={value.min()}, max={value.max()})", + "variablesReference": ref, + }) + else: + variables.append({ + "name": "value", + "value": _format_value(value), + "variablesReference": 0, + }) + + variables.append({ + "name": "operation", + "value": snapshot.operation_name, + "variablesReference": 0, + }) + variables.append({ + "name": "encrypted", + "value": str(snapshot.is_encrypted), + "variablesReference": 0, + }) + + # Bit width + bit_width = _get_bit_width(snapshot.node) + if bit_width is not None: + variables.append({ + "name": "bit_width", + "value": str(bit_width), + "variablesReference": 0, + }) + + variables.append({ + "name": "overflow", + "value": str(snapshot.overflow), + "variablesReference": 0, + }) + + variables.append({ + "name": "tag", + "value": snapshot.tag or "(none)", + "variablesReference": 0, + }) + variables.append({ + "name": "location", + "value": snapshot.location or "(unknown)", + "variablesReference": 0, + }) + + # Bounds + if snapshot.node.bounds is not None: + lower, upper = snapshot.node.bounds + variables.append({ + "name": "bounds", + "value": f"[{lower}, {upper}]", + "variablesReference": 0, + }) + + return variables + + def _snapshot_list_variables(self, snapshots: list) -> list[dict]: + """Build variables for a list of snapshots (expandable).""" + variables = [] + for snap in snapshots: + ref = self._alloc_ref(("current_node", snap)) + variables.append({ + "name": f"[{snap.index}] {snap.operation_name}", + "value": _format_value(snap.value), + "variablesReference": ref, + }) + return variables + + def scopes_for_module_stop(self, function_name: str, current_idx: int, + total_functions: int, + all_snapshots: list) -> list[dict]: + """Build a Module Context scope for module debug sessions.""" + ref = self._alloc_ref(("module_context", { + "function_name": function_name, + "current_idx": current_idx, + "total_functions": total_functions, + "all_snapshots": all_snapshots, + })) + return [{ + "name": "Module Context", + "variablesReference": ref, + "expensive": False, + }] + + def _module_context_variables(self, ctx: dict) -> list[dict]: + """Build variables for the Module Context scope.""" + variables = [ + { + "name": "function", + "value": ctx["function_name"], + "variablesReference": 0, + }, + { + "name": "progress", + "value": f"{ctx['current_idx'] + 1}/{ctx['total_functions']}", + "variablesReference": 0, + }, + ] + + if ctx["all_snapshots"]: + ref = self._alloc_ref(("all_evaluated", ctx["all_snapshots"])) + variables.append({ + "name": "all_functions_snapshots", + "value": f"{len(ctx['all_snapshots'])} snapshots", + "variablesReference": ref, + }) + + return variables + + def _ndarray_variables(self, arr: np.ndarray) -> list[dict]: + """Expand an ndarray into indexed elements.""" + variables = [] + flat = arr.flatten() + for i, val in enumerate(flat[:200]): # cap at 200 elements + variables.append({ + "name": f"[{i}]", + "value": str(val), + "variablesReference": 0, + }) + if flat.size > 200: + variables.append({ + "name": "...", + "value": f"({flat.size - 200} more elements)", + "variablesReference": 0, + }) + return variables + + +def _format_value(value) -> str: + """Format a numpy value for display.""" + if isinstance(value, np.ndarray): + if value.size <= 8: + return repr(value) + return f"array(shape={value.shape}, min={value.min()}, max={value.max()})" + return repr(value) + + +def _get_bit_width(node) -> Optional[int]: + """Extract bit width from a node's output dtype, if available.""" + try: + return node.output.dtype.bit_width + except AttributeError: + return None diff --git a/tools/concrete-debugger/concrete_dap_server.py b/tools/concrete-debugger/concrete_dap_server.py new file mode 100644 index 0000000000..b14a9a134d --- /dev/null +++ b/tools/concrete-debugger/concrete_dap_server.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +"""Entry point for the Concrete FHE DAP debug server. + +VS Code spawns this process and communicates via stdin/stdout using +the Debug Adapter Protocol (Content-Length framed JSON). +""" + +import sys +import os + +# Ensure the concrete-debugger package is importable +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from concrete_dap.server import DAPServer + + +def main(): + server = DAPServer() + server.run() + + +if __name__ == "__main__": + main() diff --git a/tools/concrete-debugger/requirements.txt b/tools/concrete-debugger/requirements.txt new file mode 100644 index 0000000000..e6b87730f1 --- /dev/null +++ b/tools/concrete-debugger/requirements.txt @@ -0,0 +1,3 @@ +concrete-python +numpy +networkx diff --git a/tools/concrete-debugger/tests/__init__.py b/tools/concrete-debugger/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/concrete-debugger/tests/test_breakpoints.py b/tools/concrete-debugger/tests/test_breakpoints.py new file mode 100644 index 0000000000..0d5abc63c4 --- /dev/null +++ b/tools/concrete-debugger/tests/test_breakpoints.py @@ -0,0 +1,81 @@ +"""Tests for BreakpointManager.""" + +import pytest + +from concrete_dap.breakpoints import BreakpointManager + + +class MockNode: + def __init__(self, location=""): + self.location = location + + +class TestBreakpointManager: + def test_build_index(self): + bp = BreakpointManager() + nodes = [ + MockNode("script.py:10"), + MockNode("script.py:11"), + MockNode("script.py:10"), # same line as first + ] + bp.build_index(nodes) + lines = bp.get_available_lines("script.py") + assert sorted(lines) == [10, 11] + + def test_set_breakpoints_verified(self): + bp = BreakpointManager() + nodes = [MockNode("app.py:5"), MockNode("app.py:10")] + bp.build_index(nodes) + + results = bp.set_breakpoints("app.py", [5, 7, 10]) + assert results[0]["verified"] is True + assert results[0]["line"] == 5 + assert results[1]["verified"] is False # line 7 has no nodes + assert results[2]["verified"] is True + + def test_is_breakpoint(self): + bp = BreakpointManager() + n1 = MockNode("app.py:5") + n2 = MockNode("app.py:5") # same line, second node + n3 = MockNode("app.py:10") + nodes = [n1, n2, n3] + bp.build_index(nodes) + bp.set_breakpoints("app.py", [5]) + + # Only the first node at line 5 should trigger + assert bp.is_breakpoint(n1, 0) is True + assert bp.is_breakpoint(n2, 1) is False # not first at this line + assert bp.is_breakpoint(n3, 2) is False # not a breakpoint line + + def test_clear_old_breakpoints(self): + bp = BreakpointManager() + nodes = [MockNode("a.py:1"), MockNode("a.py:2")] + bp.build_index(nodes) + bp.set_breakpoints("a.py", [1, 2]) + + # Now set only line 2 — line 1 should be cleared + bp.set_breakpoints("a.py", [2]) + assert bp.is_breakpoint(nodes[0], 0) is False + assert bp.is_breakpoint(nodes[1], 1) is True + + def test_node_without_location(self): + bp = BreakpointManager() + nodes = [MockNode(""), MockNode("valid.py:1")] + bp.build_index(nodes) + lines = bp.get_available_lines("valid.py") + assert lines == [1] + + def test_is_breakpoint_no_location(self): + bp = BreakpointManager() + node = MockNode("") + assert bp.is_breakpoint(node, 0) is False + + def test_path_normalization(self): + """Paths should be normalized for comparison.""" + bp = BreakpointManager() + nodes = [MockNode("/home/user/./scripts/../scripts/app.py:5")] + bp.build_index(nodes) + + # Setting breakpoints with a different but equivalent path + results = bp.set_breakpoints("/home/user/scripts/app.py", [5]) + assert results[0]["verified"] is True diff --git a/tools/concrete-debugger/tests/test_module_session.py b/tools/concrete-debugger/tests/test_module_session.py new file mode 100644 index 0000000000..2f0386d06c --- /dev/null +++ b/tools/concrete-debugger/tests/test_module_session.py @@ -0,0 +1,343 @@ +"""Tests for ModuleDebugSession — multi-function stepping.""" + +import networkx as nx +import numpy as np +import pytest + +from concrete_dap.breakpoints import BreakpointManager +from concrete_dap.session import ConcreteDebugSession, ModuleDebugSession, StopReason + + +# ── Helpers ── + + +class MockValueDescription: + def __init__(self, shape=(), is_encrypted=True, dtype=None): + self.shape = shape + self.is_encrypted = is_encrypted + self.dtype = dtype or MockDtype() + + +class MockDtype: + def __init__(self, bit_width=8): + self.bit_width = bit_width + + def min(self): + return -(2 ** (self.bit_width - 1)) + + def max(self): + return 2 ** (self.bit_width - 1) - 1 + + +class MockNode: + def __init__(self, name, operation, evaluator, inputs=None, location="test.py:1", + tag="", output=None): + self.name = name + self.operation = operation + self.evaluator = evaluator + self.inputs = inputs or [] + self.location = location + self.tag = tag + self.output = output or MockValueDescription() + self.bounds = None + self.properties = {"name": name} + self.created_at = 0.0 + + def __call__(self, *args): + return self.evaluator(*args) + + def __hash__(self): + return hash(id(self)) + + def __eq__(self, other): + return self is other + + +class MockGraph: + def __init__(self): + self.graph = nx.MultiDiGraph() + self.input_nodes = {} + self.output_nodes = {} + self.input_indices = {} + + def ordered_preds_of(self, node): + idx_to_pred = {} + for pred in self.graph.predecessors(node): + edge_data = self.graph.get_edge_data(pred, node) + for data in edge_data.values(): + idx_to_pred[data["input_idx"]] = pred + return [idx_to_pred[i] for i in range(len(idx_to_pred))] + + +def _make_simple_graph(name_prefix, location_file): + """Build: input(x) -> double(x).""" + g = MockGraph() + + inp = MockNode(f"{name_prefix}_x", "input", lambda x: np.int64(x), + location=f"{location_file}:1") + double = MockNode(f"{name_prefix}_double", "generic", + lambda a: np.int64(a * 2), + inputs=[MockValueDescription()], + location=f"{location_file}:2", + tag=name_prefix) + + g.graph.add_node(inp) + g.graph.add_node(double) + g.graph.add_edge(inp, double, input_idx=0) + + g.input_nodes = {0: inp} + g.output_nodes = {0: double} + g.input_indices = {inp: 0} + + return g + + +def _make_chain_graph(name_prefix, location_file): + """Build: input(x) -> double(x) -> add1(x). Two non-input nodes for breakpoint tests.""" + g = MockGraph() + + inp = MockNode(f"{name_prefix}_x", "input", lambda x: np.int64(x), + location=f"{location_file}:1") + double = MockNode(f"{name_prefix}_double", "generic", + lambda a: np.int64(a * 2), + inputs=[MockValueDescription()], + location=f"{location_file}:2", + tag=name_prefix) + add1 = MockNode(f"{name_prefix}_add1", "generic", + lambda a: np.int64(a + 1), + inputs=[MockValueDescription()], + location=f"{location_file}:3", + tag=name_prefix) + + g.graph.add_node(inp) + g.graph.add_node(double) + g.graph.add_node(add1) + g.graph.add_edge(inp, double, input_idx=0) + g.graph.add_edge(double, add1, input_idx=0) + + g.input_nodes = {0: inp} + g.output_nodes = {0: add1} + g.input_indices = {inp: 0} + + return g + + +def _make_module(bp=None): + """Create a two-function module session.""" + if bp is None: + bp = BreakpointManager() + g1 = _make_simple_graph("f1", "f1.py") + g2 = _make_simple_graph("f2", "f2.py") + s1 = ConcreteDebugSession(g1, (5,), bp) + s2 = ConcreteDebugSession(g2, (10,), bp) + module = ModuleDebugSession([("func1", s1), ("func2", s2)], bp) + return module, bp + + +# ── Tests ── + + +class TestModuleStepOne: + def test_step_advances_to_next_function(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + assert module.current_function_name == "func1" + + reason = module.step_one() + # func1 double evaluated → func1 finishes → advance to func2 entry + assert reason == StopReason.ENTRY + assert module.current_function_name == "func2" + + def test_step_through_all_functions(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + reason = module.step_one() + assert reason == StopReason.ENTRY # advanced to func2 + + reason = module.step_one() + assert reason == StopReason.FINISHED # func2 done, no more functions + + def test_snapshots_per_function(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + module.step_one() # finish func1, enter func2 + + # snapshots should be func2's snapshots (input was evaluated on advance) + assert len(module.snapshots) == 1 # func2's input + assert module.current_function_name == "func2" + + +class TestModuleContinue: + def test_continue_runs_all(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + reason = module.continue_to_breakpoint() + assert reason == StopReason.FINISHED + assert module.finished + + def test_continue_stops_at_breakpoint_in_second_function(self): + bp = BreakpointManager() + module, _ = _make_module(bp) + + # Set breakpoint in func2 + bp.set_breakpoints("f2.py", [2]) + + module.evaluate_inputs_and_stop_on_entry() + reason = module.continue_to_breakpoint() + + assert reason == StopReason.BREAKPOINT + assert module.current_function_name == "func2" + assert not module.finished + + +class TestModuleStepOut: + def test_step_out_finishes_function_and_enters_next(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + reason = module.step_out() + assert reason == StopReason.ENTRY + assert module.current_function_name == "func2" + + def test_step_out_last_function(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + module.step_out() # finish func1 → enter func2 + reason = module.step_out() # finish func2 → no more functions + assert reason == StopReason.FINISHED + assert module.finished + + +class TestModuleStepIntoNextFunction: + def test_step_into_when_function_finished(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + # Finish func1 via its inner session + module._active.step_out() + assert module._active.finished + + reason = module.step_into_next_function() + assert reason == StopReason.ENTRY + assert module.current_function_name == "func2" + + def test_step_into_when_not_finished_acts_as_step(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + # func1 is not finished, step_into falls back to step_one + reason = module.step_into_next_function() + assert reason == StopReason.ENTRY # func1 finishes → advance + + +class TestModuleStackFrames: + def test_includes_module_frame(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + module.step_one() # advance to func2 + + frames = module.get_stack_frames() + assert len(frames) >= 2 + bottom = frames[-1] + assert "Module" in bottom["name"] + assert "func2" in bottom["name"] + assert "2/2" in bottom["name"] + + def test_module_frame_shows_progress(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + + frames = module.get_stack_frames() + bottom = frames[-1] + assert "1/2" in bottom["name"] + + +class TestModuleProperties: + def test_finished_false_while_running(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + assert not module.finished + + def test_finished_true_when_all_done(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + module.continue_to_breakpoint() + assert module.finished + + def test_all_snapshots_across_functions(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + module.continue_to_breakpoint() + + # Each function: 1 input + 1 double = 2 snapshots + all_snaps = module.all_snapshots + assert len(all_snaps) == 4 + + def test_current_function_name(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + assert module.current_function_name == "func1" + + module.step_one() + assert module.current_function_name == "func2" + + def test_error_from_active_session(self): + module, _ = _make_module() + module.evaluate_inputs_and_stop_on_entry() + assert module.error is None + + +class TestModuleBreakpointsAcrossFunctions: + def test_breakpoint_in_first_function(self): + bp = BreakpointManager() + # Use chain graphs so breakpoint isn't on the first non-input node + g1 = _make_chain_graph("f1", "f1.py") + g2 = _make_chain_graph("f2", "f2.py") + s1 = ConcreteDebugSession(g1, (5,), bp) + s2 = ConcreteDebugSession(g2, (10,), bp) + module = ModuleDebugSession([("func1", s1), ("func2", s2)], bp) + + bp.set_breakpoints("f1.py", [3]) # breakpoint on second non-input node + module.evaluate_inputs_and_stop_on_entry() + reason = module.continue_to_breakpoint() + + assert reason == StopReason.BREAKPOINT + assert module.current_function_name == "func1" + + def test_breakpoint_in_second_function(self): + bp = BreakpointManager() + g1 = _make_chain_graph("f1", "f1.py") + g2 = _make_chain_graph("f2", "f2.py") + s1 = ConcreteDebugSession(g1, (5,), bp) + s2 = ConcreteDebugSession(g2, (10,), bp) + module = ModuleDebugSession([("func1", s1), ("func2", s2)], bp) + + bp.set_breakpoints("f2.py", [3]) + module.evaluate_inputs_and_stop_on_entry() + reason = module.continue_to_breakpoint() + + assert reason == StopReason.BREAKPOINT + assert module.current_function_name == "func2" + + def test_breakpoints_in_both_functions(self): + bp = BreakpointManager() + g1 = _make_chain_graph("f1", "f1.py") + g2 = _make_chain_graph("f2", "f2.py") + s1 = ConcreteDebugSession(g1, (5,), bp) + s2 = ConcreteDebugSession(g2, (10,), bp) + module = ModuleDebugSession([("func1", s1), ("func2", s2)], bp) + + bp.set_breakpoints("f1.py", [3]) + bp.set_breakpoints("f2.py", [3]) + module.evaluate_inputs_and_stop_on_entry() + + reason = module.continue_to_breakpoint() + assert reason == StopReason.BREAKPOINT + assert module.current_function_name == "func1" + + reason = module.continue_to_breakpoint() + assert reason == StopReason.BREAKPOINT + assert module.current_function_name == "func2" diff --git a/tools/concrete-debugger/tests/test_protocol.py b/tools/concrete-debugger/tests/test_protocol.py new file mode 100644 index 0000000000..c96e429070 --- /dev/null +++ b/tools/concrete-debugger/tests/test_protocol.py @@ -0,0 +1,114 @@ +"""Tests for DAP protocol message I/O.""" + +import io +import json + +import numpy as np +import pytest + +from concrete_dap.protocol import ( + make_event, + make_response, + read_message, + write_message, +) + + +def _encode_dap(msg: dict) -> bytes: + """Encode a dict as a Content-Length framed DAP message.""" + body = json.dumps(msg).encode("utf-8") + return f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + body + + +class TestReadMessage: + def test_basic_read(self): + msg = {"seq": 1, "type": "request", "command": "initialize"} + stream = io.BytesIO(_encode_dap(msg)) + result = read_message(stream) + assert result == msg + + def test_empty_stream(self): + stream = io.BytesIO(b"") + result = read_message(stream) + assert result is None + + def test_multiple_messages(self): + msg1 = {"seq": 1, "type": "request", "command": "initialize"} + msg2 = {"seq": 2, "type": "request", "command": "launch"} + stream = io.BytesIO(_encode_dap(msg1) + _encode_dap(msg2)) + assert read_message(stream) == msg1 + assert read_message(stream) == msg2 + + def test_zero_content_length(self): + stream = io.BytesIO(b"Content-Length: 0\r\n\r\n") + result = read_message(stream) + assert result is None + + +class TestWriteMessage: + def test_basic_write(self): + msg = {"seq": 1, "type": "response", "success": True} + stream = io.BytesIO() + write_message(msg, stream) + stream.seek(0) + result = read_message(stream) + assert result == msg + + def test_numpy_serialization(self): + msg = {"value": np.int64(42)} + stream = io.BytesIO() + write_message(msg, stream) + stream.seek(0) + result = read_message(stream) + assert result["value"] == 42 + + def test_numpy_array_serialization(self): + msg = {"arr": np.array([1, 2, 3])} + stream = io.BytesIO() + write_message(msg, stream) + stream.seek(0) + result = read_message(stream) + assert result["arr"] == [1, 2, 3] + + def test_numpy_bool_serialization(self): + msg = {"flag": np.bool_(True)} + stream = io.BytesIO() + write_message(msg, stream) + stream.seek(0) + result = read_message(stream) + assert result["flag"] is True + + +class TestMakeResponse: + def test_success_response(self): + req = {"seq": 5, "type": "request", "command": "initialize"} + resp = make_response(req, body={"supportsConfigurationDone": True}) + assert resp["type"] == "response" + assert resp["request_seq"] == 5 + assert resp["command"] == "initialize" + assert resp["success"] is True + assert resp["body"]["supportsConfigurationDone"] is True + + def test_error_response(self): + req = {"seq": 3, "type": "request", "command": "launch"} + resp = make_response(req, success=False, message="file not found") + assert resp["success"] is False + assert resp["message"] == "file not found" + + def test_no_body(self): + req = {"seq": 1, "type": "request", "command": "disconnect"} + resp = make_response(req) + assert "body" not in resp + + +class TestMakeEvent: + def test_event_with_body(self): + evt = make_event("stopped", {"reason": "breakpoint", "threadId": 1}) + assert evt["type"] == "event" + assert evt["event"] == "stopped" + assert evt["body"]["reason"] == "breakpoint" + + def test_event_without_body(self): + evt = make_event("initialized") + assert evt["event"] == "initialized" + assert "body" not in evt diff --git a/tools/concrete-debugger/tests/test_server.py b/tools/concrete-debugger/tests/test_server.py new file mode 100644 index 0000000000..fb283f16f1 --- /dev/null +++ b/tools/concrete-debugger/tests/test_server.py @@ -0,0 +1,339 @@ +"""Integration tests for the DAP server — send raw DAP messages, verify responses.""" + +import io +import json +import os +import tempfile +from typing import Optional + +import numpy as np +import pytest + +from concrete_dap.protocol import read_message +from concrete_dap.server import DAPServer + + +def _encode_dap(msg: dict) -> bytes: + body = json.dumps(msg).encode("utf-8") + return f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + body + + +def _make_request(seq: int, command: str, arguments: Optional[dict] = None) -> dict: + req = {"seq": seq, "type": "request", "command": command} + if arguments is not None: + req["arguments"] = arguments + return req + + +class DAPTestClient: + """Helper to drive a DAPServer with in-memory streams.""" + + def __init__(self): + self._input = io.BytesIO() + self._output = io.BytesIO() + self._seq = 1 + + def send(self, command: str, arguments: Optional[dict] = None): + req = _make_request(self._seq, command, arguments) + self._seq += 1 + self._input.write(_encode_dap(req)) + + def run_server(self): + self._input.seek(0) + server = DAPServer(self._input, self._output) + server.run() + self._output.seek(0) + + def read_all_messages(self) -> list[dict]: + messages = [] + while True: + msg = read_message(self._output) + if msg is None: + break + messages.append(msg) + return messages + + def find_response(self, messages: list[dict], command: str) -> Optional[dict]: + for msg in messages: + if msg.get("type") == "response" and msg.get("command") == command: + return msg + return None + + def find_event(self, messages: list[dict], event: str) -> Optional[dict]: + for msg in messages: + if msg.get("type") == "event" and msg.get("event") == event: + return msg + return None + + def find_events(self, messages: list[dict], event: str) -> list[dict]: + return [m for m in messages if m.get("type") == "event" and m.get("event") == event] + + +class TestInitializeDisconnect: + def test_initialize(self): + client = DAPTestClient() + client.send("initialize", {"adapterID": "concrete"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + init_resp = client.find_response(msgs, "initialize") + assert init_resp is not None + assert init_resp["success"] is True + assert init_resp["body"]["supportsConfigurationDoneRequest"] is True + + init_evt = client.find_event(msgs, "initialized") + assert init_evt is not None + + def test_disconnect(self): + client = DAPTestClient() + client.send("initialize", {"adapterID": "concrete"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + disc_resp = client.find_response(msgs, "disconnect") + assert disc_resp is not None + assert disc_resp["success"] is True + + +class TestThreads: + def test_threads(self): + client = DAPTestClient() + client.send("initialize") + client.send("threads") + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "threads") + assert resp is not None + assert len(resp["body"]["threads"]) == 1 + assert resp["body"]["threads"][0]["name"] == "FHE Circuit Evaluation" + + +class TestLaunchErrors: + def test_missing_program(self): + client = DAPTestClient() + client.send("initialize") + client.send("launch", {"function": "f"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "launch") + assert resp["success"] is False + assert "program" in resp["message"] + + def test_missing_function(self): + client = DAPTestClient() + client.send("initialize") + client.send("launch", {"program": "test.py"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "launch") + assert resp["success"] is False + assert "function" in resp["message"] + + def test_nonexistent_program(self): + client = DAPTestClient() + client.send("initialize") + client.send("launch", {"program": "/nonexistent.py", "function": "f"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "launch") + assert resp["success"] is False + + +class TestUnknownCommand: + def test_unknown(self): + client = DAPTestClient() + client.send("initialize") + client.send("foobar") + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "foobar") + assert resp is not None + assert resp["success"] is False + + +class TestSetBreakpointsWithoutSession: + def test_set_breakpoints_no_session(self): + client = DAPTestClient() + client.send("initialize") + client.send("setBreakpoints", { + "source": {"path": "test.py"}, + "breakpoints": [{"line": 5}], + }) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "setBreakpoints") + assert resp is not None + assert resp["success"] is True + # No nodes, so breakpoints won't be verified + assert resp["body"]["breakpoints"][0]["verified"] is False + + +class TestStackTraceWithoutSession: + def test_empty_stack(self): + client = DAPTestClient() + client.send("initialize") + client.send("stackTrace", {"threadId": 1}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "stackTrace") + assert resp["body"]["stackFrames"] == [] + assert resp["body"]["totalFrames"] == 0 + + +class TestScopesWithoutSession: + def test_empty_scopes(self): + client = DAPTestClient() + client.send("initialize") + client.send("scopes", {"frameId": 0}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "scopes") + assert resp["body"]["scopes"] == [] + + +class TestVariablesWithoutSession: + def test_empty_variables(self): + client = DAPTestClient() + client.send("initialize") + client.send("variables", {"variablesReference": 1}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "variables") + assert resp["body"]["variables"] == [] + + +class TestEvaluateWithoutSession: + def test_no_session(self): + client = DAPTestClient() + client.send("initialize") + client.send("evaluate", {"expression": "value"}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "evaluate") + assert "no active session" in resp["body"]["result"] + + +class TestSetFunctionBreakpoints: + def test_returns_empty(self): + client = DAPTestClient() + client.send("initialize") + client.send("setFunctionBreakpoints", {"breakpoints": []}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "setFunctionBreakpoints") + assert resp["success"] is True + + +class TestSetExceptionBreakpoints: + def test_returns_ok(self): + client = DAPTestClient() + client.send("initialize") + client.send("setExceptionBreakpoints", {"filters": []}) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "setExceptionBreakpoints") + assert resp["success"] is True + + +class TestConfigurationDoneWithoutSession: + def test_ok(self): + client = DAPTestClient() + client.send("initialize") + client.send("configurationDone") + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "configurationDone") + assert resp["success"] is True + + +class TestOutputEvents: + def test_launch_sends_loading_output(self): + # Launch with a nonexistent program still emits "Loading script..." before failing + client = DAPTestClient() + client.send("initialize") + client.send("launch", { + "program": "/nonexistent_script_for_test.py", + "function": "f", + }) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + output_events = client.find_events(msgs, "output") + outputs = [e["body"]["output"] for e in output_events] + assert any("Loading" in o for o in outputs) + + +class TestStopOnOverflowConfig: + def test_overflow_in_launch_config_accepted(self): + # Verify stopOnOverflow doesn't cause errors (program doesn't exist, but config is parsed) + client = DAPTestClient() + client.send("initialize") + client.send("launch", { + "program": "/nonexistent.py", + "function": "f", + "stopOnOverflow": True, + }) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + # Should fail on program execution, not on config parsing + resp = client.find_response(msgs, "launch") + assert resp["success"] is False + assert "nonexistent" in resp["message"].lower() or "failed" in resp["message"].lower() + + +class TestModuleLaunchErrors: + def test_functions_without_module(self): + """When functions config is given but object isn't a module, should error.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write("my_obj = 42\n") + f.flush() + temp_path = f.name + + try: + client = DAPTestClient() + client.send("initialize") + client.send("launch", { + "program": temp_path, + "function": "my_obj", + "functions": [{"name": "f1", "args": []}], + }) + client.send("disconnect") + client.run_server() + + msgs = client.read_all_messages() + resp = client.find_response(msgs, "launch") + assert resp["success"] is False + assert "module" in resp["message"].lower() + finally: + os.unlink(temp_path) diff --git a/tools/concrete-debugger/tests/test_session.py b/tools/concrete-debugger/tests/test_session.py new file mode 100644 index 0000000000..c410c32a1f --- /dev/null +++ b/tools/concrete-debugger/tests/test_session.py @@ -0,0 +1,461 @@ +"""Tests for ConcreteDebugSession — graph walker with stepping.""" + +import networkx as nx +import numpy as np +import pytest + +from concrete_dap.breakpoints import BreakpointManager +from concrete_dap.session import ConcreteDebugSession, StopReason, _FallbackSnapshot + + +# ── Helpers to build a mock graph ── + + +class MockValueDescription: + def __init__(self, shape=(), is_encrypted=True, dtype=None): + self.shape = shape + self.is_encrypted = is_encrypted + self.dtype = dtype or MockDtype() + + +class MockDtype: + def __init__(self, bit_width=8): + self.bit_width = bit_width + + def min(self): + return -(2 ** (self.bit_width - 1)) + + def max(self): + return 2 ** (self.bit_width - 1) - 1 + + +class MockNode: + """Minimal Node-like object for testing.""" + + def __init__(self, name, operation, evaluator, inputs=None, location="test.py:1", + tag="", output=None): + self.name = name + self.operation = operation # string: "input", "generic", "constant" + self.evaluator = evaluator + self.inputs = inputs or [] + self.location = location + self.tag = tag + self.output = output or MockValueDescription() + self.bounds = None + self.properties = {"name": name} + self.created_at = 0.0 + + def __call__(self, *args): + return self.evaluator(*args) + + def __hash__(self): + return hash(id(self)) + + def __eq__(self, other): + return self is other + + def label(self): + return self.name + + +class MockGraph: + """Minimal Graph-like object wrapping a networkx digraph.""" + + def __init__(self): + self.graph = nx.MultiDiGraph() + self.input_nodes = {} + self.output_nodes = {} + self.input_indices = {} + + def ordered_preds_of(self, node): + idx_to_pred = {} + for pred in self.graph.predecessors(node): + edge_data = self.graph.get_edge_data(pred, node) + for data in edge_data.values(): + idx_to_pred[data["input_idx"]] = pred + return [idx_to_pred[i] for i in range(len(idx_to_pred))] + + +def _make_add_graph(): + """Build: input(x) -> input(y) -> add(x, y).""" + g = MockGraph() + + inp_x = MockNode("x", "input", lambda x: np.int64(x), + location="script.py:5", tag="") + inp_y = MockNode("y", "input", lambda y: np.int64(y), + location="script.py:5", tag="") + add = MockNode("add", "generic", + lambda a, b: np.int64(a + b), + inputs=[MockValueDescription(), MockValueDescription()], + location="script.py:6", tag="") + + g.graph.add_node(inp_x) + g.graph.add_node(inp_y) + g.graph.add_node(add) + g.graph.add_edge(inp_x, add, input_idx=0) + g.graph.add_edge(inp_y, add, input_idx=1) + + g.input_nodes = {0: inp_x, 1: inp_y} + g.output_nodes = {0: add} + g.input_indices = {inp_x: 0, inp_y: 1} + + return g, inp_x, inp_y, add + + +def _make_chain_graph(): + """Build: input(x) -> double(x) -> triple(x).""" + g = MockGraph() + + inp = MockNode("x", "input", lambda x: np.int64(x), + location="chain.py:1", tag="layer1") + double = MockNode("double", "generic", + lambda a: np.int64(a * 2), + inputs=[MockValueDescription()], + location="chain.py:2", tag="layer1.double") + triple = MockNode("triple", "generic", + lambda a: np.int64(a * 3), + inputs=[MockValueDescription()], + location="chain.py:3", tag="layer1.triple") + + g.graph.add_node(inp) + g.graph.add_node(double) + g.graph.add_node(triple) + g.graph.add_edge(inp, double, input_idx=0) + g.graph.add_edge(double, triple, input_idx=0) + + g.input_nodes = {0: inp} + g.output_nodes = {0: triple} + g.input_indices = {inp: 0} + + return g, inp, double, triple + + +def _make_same_line_graph(): + """Build: input(x) -> mul(x, x) -> add(mul, x). Both ops on same line.""" + g = MockGraph() + + inp = MockNode("x", "input", lambda x: np.int64(x), + location="sameline.py:1") + mul = MockNode("mul", "generic", + lambda a, b: np.int64(a * b), + inputs=[MockValueDescription(), MockValueDescription()], + location="sameline.py:5") + add = MockNode("add", "generic", + lambda a, b: np.int64(a + b), + inputs=[MockValueDescription(), MockValueDescription()], + location="sameline.py:5") + + g.graph.add_node(inp) + g.graph.add_node(mul) + g.graph.add_node(add) + g.graph.add_edge(inp, mul, input_idx=0) + g.graph.add_edge(inp, mul, input_idx=1) + g.graph.add_edge(mul, add, input_idx=0) + g.graph.add_edge(inp, add, input_idx=1) + + g.input_nodes = {0: inp} + g.output_nodes = {0: add} + g.input_indices = {inp: 0} + + return g, inp, mul, add + + +# ── Tests ── + + +class TestStopOnEntry: + def test_evaluates_inputs_and_stops(self): + graph, inp_x, inp_y, add = _make_add_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (3, 5), bp) + + reason = session.evaluate_inputs_and_stop_on_entry() + assert reason == StopReason.ENTRY + assert len(session.snapshots) == 2 + assert not session.finished + assert session.current_index == 2 + + def test_empty_graph(self): + g = MockGraph() + bp = BreakpointManager() + session = ConcreteDebugSession(g, (), bp) + reason = session.evaluate_inputs_and_stop_on_entry() + assert reason == StopReason.FINISHED + assert session.finished + + def test_input_only_graph(self): + g = MockGraph() + inp = MockNode("x", "input", lambda x: np.int64(x), location="test.py:1") + g.graph.add_node(inp) + g.input_nodes = {0: inp} + g.input_indices = {inp: 0} + + bp = BreakpointManager() + session = ConcreteDebugSession(g, (42,), bp) + reason = session.evaluate_inputs_and_stop_on_entry() + assert reason == StopReason.FINISHED + assert len(session.snapshots) == 1 + assert int(session.snapshots[0].value) == 42 + + +class TestStepOne: + def test_step_through_add(self): + graph, *_ = _make_add_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (3, 5), bp) + + session.evaluate_inputs_and_stop_on_entry() + reason = session.step_one() + assert reason == StopReason.FINISHED + assert len(session.snapshots) == 3 + assert int(session.snapshots[-1].value) == 8 + + def test_step_through_chain(self): + graph, *_ = _make_chain_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (4,), bp) + + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + assert reason == StopReason.STEP + assert int(session.snapshots[-1].value) == 8 + + reason = session.step_one() + assert reason == StopReason.FINISHED + assert int(session.snapshots[-1].value) == 24 + + def test_same_line_grouping(self): + graph, *_ = _make_same_line_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (3,), bp) + + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + assert reason == StopReason.FINISHED + assert len(session.snapshots) == 3 # input + mul + add + assert int(session.snapshots[-1].value) == 12 # 3*3 + 3 + + def test_step_on_finished(self): + graph, *_ = _make_add_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (3, 5), bp) + session.evaluate_inputs_and_stop_on_entry() + session.step_one() + assert session.finished + reason = session.step_one() + assert reason == StopReason.FINISHED + + +class TestContinueToBreakpoint: + def test_continue_no_breakpoints(self): + graph, *_ = _make_chain_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (4,), bp) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.continue_to_breakpoint() + assert reason == StopReason.FINISHED + assert session.finished + assert int(session.snapshots[-1].value) == 24 + + def test_continue_with_breakpoint(self): + graph, *_ = _make_chain_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (4,), bp) + # Set breakpoint at chain.py:3 (triple node) + bp.set_breakpoints("chain.py", [3]) + + session.evaluate_inputs_and_stop_on_entry() + reason = session.continue_to_breakpoint() + assert reason == StopReason.BREAKPOINT + assert len(session.snapshots) == 2 # input + double + assert int(session.snapshots[-1].value) == 8 + + def test_continue_on_finished(self): + graph, *_ = _make_add_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (1, 2), bp) + session.evaluate_inputs_and_stop_on_entry() + session.step_one() + assert session.finished + reason = session.continue_to_breakpoint() + assert reason == StopReason.FINISHED + + +class TestStepOut: + def test_step_out_runs_to_end(self): + graph, *_ = _make_chain_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (2,), bp) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_out() + assert reason == StopReason.FINISHED + assert session.finished + assert int(session.snapshots[-1].value) == 12 + + +class TestStackFrames: + def test_tag_hierarchy(self): + graph, *_ = _make_chain_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (4,), bp) + session.evaluate_inputs_and_stop_on_entry() + session.step_one() + + frames = session.get_stack_frames() + assert len(frames) == 2 + assert "double" in frames[0]["name"] + assert "layer1" in frames[1]["name"] + + def test_no_tag(self): + graph, *_ = _make_add_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (1, 2), bp) + session.evaluate_inputs_and_stop_on_entry() + session.step_one() + + frames = session.get_stack_frames() + assert len(frames) == 1 + assert frames[0]["name"] == "add" + + def test_no_snapshot(self): + g = MockGraph() + bp = BreakpointManager() + session = ConcreteDebugSession(g, (), bp) + frames = session.get_stack_frames() + assert frames == [] + + +class TestEvaluationError: + def test_step_with_error(self): + g = MockGraph() + inp = MockNode("x", "input", lambda x: np.int64(x), location="err.py:1") + + def bad_eval(a): + raise ValueError("kaboom") + + bad = MockNode("bad", "generic", bad_eval, + inputs=[MockValueDescription()], location="err.py:2") + g.graph.add_node(inp) + g.graph.add_node(bad) + g.graph.add_edge(inp, bad, input_idx=0) + g.input_nodes = {0: inp} + g.input_indices = {inp: 0} + + bp = BreakpointManager() + session = ConcreteDebugSession(g, (1,), bp) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + assert reason == StopReason.EXCEPTION + assert session.error is not None + assert "kaboom" in str(session.error) + + +def _make_overflow_graph(): + """Build: input(x) -> mul(x, 100). With bit_width=8, mul overflows for x > 1.""" + g = MockGraph() + + inp = MockNode("x", "input", lambda x: np.int64(x), + location="overflow.py:1") + mul = MockNode("mul", "generic", + lambda a: np.int64(a * 100), + inputs=[MockValueDescription()], + location="overflow.py:2", + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + + g.graph.add_node(inp) + g.graph.add_node(mul) + g.graph.add_edge(inp, mul, input_idx=0) + + g.input_nodes = {0: inp} + g.output_nodes = {0: mul} + g.input_indices = {inp: 0} + + return g, inp, mul + + +class TestOverflowStop: + def test_step_one_stops_on_overflow(self): + graph, inp, mul = _make_overflow_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (2,), bp, stop_on_overflow=True) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + assert reason == StopReason.OVERFLOW + assert session.snapshots[-1].overflow is True + + def test_step_one_no_stop_when_disabled(self): + graph, inp, mul = _make_overflow_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (2,), bp, stop_on_overflow=False) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + assert reason == StopReason.FINISHED # doesn't stop, runs to end + assert session.snapshots[-1].overflow is True + + def test_continue_stops_on_overflow(self): + graph, inp, mul = _make_overflow_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (2,), bp, stop_on_overflow=True) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.continue_to_breakpoint() + assert reason == StopReason.OVERFLOW + + def test_step_out_stops_on_overflow(self): + graph, inp, mul = _make_overflow_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (2,), bp, stop_on_overflow=True) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_out() + assert reason == StopReason.OVERFLOW + + def test_no_overflow_no_stop(self): + graph, inp, mul = _make_overflow_graph() + bp = BreakpointManager() + session = ConcreteDebugSession(graph, (1,), bp, stop_on_overflow=True) + session.evaluate_inputs_and_stop_on_entry() + + reason = session.step_one() + # 1*100=100, within [-128, 127] + assert reason == StopReason.FINISHED + assert session.snapshots[-1].overflow is False + + +class TestFallbackSnapshotOverflow: + def test_overflow_detected(self): + node = MockNode("mul", "generic", lambda a: a, + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + snap = _FallbackSnapshot(node, np.int64(200), 0) + assert snap.overflow is True + + def test_no_overflow(self): + node = MockNode("mul", "generic", lambda a: a, + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + snap = _FallbackSnapshot(node, np.int64(50), 0) + assert snap.overflow is False + + def test_negative_overflow(self): + node = MockNode("mul", "generic", lambda a: a, + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + snap = _FallbackSnapshot(node, np.int64(-200), 0) + assert snap.overflow is True + + def test_array_overflow(self): + node = MockNode("mul", "generic", lambda a: a, + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + snap = _FallbackSnapshot(node, np.array([50, 200]), 0) + assert snap.overflow is True + + def test_array_no_overflow(self): + node = MockNode("mul", "generic", lambda a: a, + output=MockValueDescription(dtype=MockDtype(bit_width=8))) + snap = _FallbackSnapshot(node, np.array([50, 100]), 0) + assert snap.overflow is False diff --git a/tools/concrete-debugger/tests/test_variables.py b/tools/concrete-debugger/tests/test_variables.py new file mode 100644 index 0000000000..f22fee5616 --- /dev/null +++ b/tools/concrete-debugger/tests/test_variables.py @@ -0,0 +1,211 @@ +"""Tests for VariableStore.""" + +import numpy as np +import pytest + +from concrete_dap.variables import VariableStore + + +class MockDtype: + def __init__(self, bit_width=8): + self.bit_width = bit_width + + +class MockValueDescription: + def __init__(self, is_encrypted=True, dtype=None): + self.is_encrypted = is_encrypted + self.dtype = dtype or MockDtype() + + +class MockNode: + def __init__(self, location="test.py:1", tag="", operation="generic", + is_encrypted=True, bounds=None, bit_width=8): + self.location = location + self.tag = tag + self.operation = operation + self.output = MockValueDescription(is_encrypted, MockDtype(bit_width)) + self.bounds = bounds + self.properties = {"name": "add"} + + +class MockSnapshot: + def __init__(self, value, index=0, node=None, overflow=False): + self.value = value + self.index = index + self.node = node or MockNode() + self.overflow = overflow + self.overflow_min = None + self.overflow_max = None + + @property + def location(self): + return self.node.location + + @property + def tag(self): + return self.node.tag + + @property + def operation_name(self): + return self.node.properties.get("name", "unknown") + + @property + def is_encrypted(self): + return self.node.output.is_encrypted + + +class TestVariableStore: + def test_scopes(self): + store = VariableStore() + snap = MockSnapshot(np.int64(42)) + scopes = store.scopes_for_stop(snap, [snap]) + assert len(scopes) == 2 + assert scopes[0]["name"] == "Current Node" + assert scopes[1]["name"] == "All Evaluated" + assert scopes[0]["variablesReference"] > 0 + assert scopes[1]["variablesReference"] > 0 + + def test_current_node_variables(self): + store = VariableStore() + node = MockNode(location="test.py:10", tag="layer1", bounds=(0, 100)) + snap = MockSnapshot(np.int64(42), node=node) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + + var_dict = {v["name"]: v["value"] for v in variables} + assert var_dict["value"] == "np.int64(42)" + assert var_dict["operation"] == "add" + assert var_dict["encrypted"] == "True" + assert var_dict["bit_width"] == "8" + assert var_dict["overflow"] == "False" + assert var_dict["tag"] == "layer1" + assert var_dict["location"] == "test.py:10" + assert var_dict["bounds"] == "[0, 100]" + + def test_no_tag(self): + store = VariableStore() + snap = MockSnapshot(np.int64(1), node=MockNode(tag="")) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + var_dict = {v["name"]: v["value"] for v in variables} + assert var_dict["tag"] == "(none)" + + def test_array_value_expandable(self): + store = VariableStore() + arr = np.arange(100) + snap = MockSnapshot(arr) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + + value_var = next(v for v in variables if v["name"] == "value") + assert value_var["variablesReference"] > 0 # expandable + assert "shape" in value_var["value"] + + # Expand the array + arr_vars = store.get_variables(value_var["variablesReference"]) + assert len(arr_vars) == 100 + assert arr_vars[0]["name"] == "[0]" + assert arr_vars[0]["value"] == "0" + + def test_small_array_inline(self): + store = VariableStore() + arr = np.array([1, 2, 3]) + snap = MockSnapshot(arr) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + + value_var = next(v for v in variables if v["name"] == "value") + assert value_var["variablesReference"] == 0 # not expandable + + def test_all_evaluated_scope(self): + store = VariableStore() + snaps = [ + MockSnapshot(np.int64(1), index=0), + MockSnapshot(np.int64(2), index=1), + MockSnapshot(np.int64(3), index=2), + ] + scopes = store.scopes_for_stop(snaps[-1], snaps) + variables = store.get_variables(scopes[1]["variablesReference"]) + + assert len(variables) == 3 + assert "[0]" in variables[0]["name"] + assert "[2]" in variables[2]["name"] + # Each should be expandable + assert variables[0]["variablesReference"] > 0 + + def test_reset(self): + store = VariableStore() + snap = MockSnapshot(np.int64(1)) + scopes = store.scopes_for_stop(snap, [snap]) + ref = scopes[0]["variablesReference"] + assert len(store.get_variables(ref)) > 0 + + store.reset() + assert store.get_variables(ref) == [] + + def test_large_array_capped(self): + store = VariableStore() + arr = np.arange(500) + snap = MockSnapshot(arr) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + value_var = next(v for v in variables if v["name"] == "value") + + arr_vars = store.get_variables(value_var["variablesReference"]) + assert len(arr_vars) == 201 # 200 elements + "..." entry + assert arr_vars[-1]["name"] == "..." + + def test_unknown_ref(self): + store = VariableStore() + assert store.get_variables(999) == [] + + def test_overflow_shown(self): + store = VariableStore() + snap = MockSnapshot(np.int64(42), overflow=True) + scopes = store.scopes_for_stop(snap, [snap]) + variables = store.get_variables(scopes[0]["variablesReference"]) + var_dict = {v["name"]: v["value"] for v in variables} + assert var_dict["overflow"] == "True" + + +class TestModuleContextScope: + def test_module_scope_variables(self): + store = VariableStore() + scopes = store.scopes_for_module_stop("encrypt_layer", 0, 3, []) + assert len(scopes) == 1 + assert scopes[0]["name"] == "Module Context" + + variables = store.get_variables(scopes[0]["variablesReference"]) + var_dict = {v["name"]: v["value"] for v in variables} + assert var_dict["function"] == "encrypt_layer" + assert var_dict["progress"] == "1/3" + + def test_module_scope_with_snapshots(self): + store = VariableStore() + snaps = [ + MockSnapshot(np.int64(1), index=0), + MockSnapshot(np.int64(2), index=1), + ] + scopes = store.scopes_for_module_stop("compute", 1, 2, snaps) + variables = store.get_variables(scopes[0]["variablesReference"]) + + var_dict = {v["name"]: v for v in variables} + assert var_dict["progress"]["value"] == "2/2" + assert "all_functions_snapshots" in var_dict + assert var_dict["all_functions_snapshots"]["variablesReference"] > 0 + + # Expand all_functions_snapshots + all_vars = store.get_variables( + var_dict["all_functions_snapshots"]["variablesReference"] + ) + assert len(all_vars) == 2 + + def test_module_scope_no_snapshots_no_expand(self): + store = VariableStore() + scopes = store.scopes_for_module_stop("func", 0, 1, []) + variables = store.get_variables(scopes[0]["variablesReference"]) + + names = [v["name"] for v in variables] + assert "function" in names + assert "progress" in names + assert "all_functions_snapshots" not in names diff --git a/tools/vscode-concrete-debugger/.gitignore b/tools/vscode-concrete-debugger/.gitignore new file mode 100644 index 0000000000..36a7572b00 --- /dev/null +++ b/tools/vscode-concrete-debugger/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +out/ +dap-server/ +*.vsix diff --git a/tools/vscode-concrete-debugger/.vscodeignore b/tools/vscode-concrete-debugger/.vscodeignore new file mode 100644 index 0000000000..a81d226f17 --- /dev/null +++ b/tools/vscode-concrete-debugger/.vscodeignore @@ -0,0 +1,6 @@ +.vscode/** +src/** +node_modules/** +tsconfig.json +**/*.ts +**/*.map diff --git a/tools/vscode-concrete-debugger/LICENSE b/tools/vscode-concrete-debugger/LICENSE new file mode 100644 index 0000000000..94b078be87 --- /dev/null +++ b/tools/vscode-concrete-debugger/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022-present, Zama +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tools/vscode-concrete-debugger/package-lock.json b/tools/vscode-concrete-debugger/package-lock.json new file mode 100644 index 0000000000..828ccaf52e --- /dev/null +++ b/tools/vscode-concrete-debugger/package-lock.json @@ -0,0 +1,59 @@ +{ + "name": "vscode-concrete-debugger", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "vscode-concrete-debugger", + "version": "0.1.0", + "license": "BSD-3-Clause", + "devDependencies": { + "@types/node": "^25.5.0", + "@types/vscode": "^1.80.0", + "typescript": "^5.0.0" + }, + "engines": { + "vscode": "^1.80.0" + } + }, + "node_modules/@types/node": { + "version": "25.5.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-25.5.0.tgz", + "integrity": "sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~7.18.0" + } + }, + "node_modules/@types/vscode": { + "version": "1.110.0", + "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.110.0.tgz", + "integrity": "sha512-AGuxUEpU4F4mfuQjxPPaQVyuOMhs+VT/xRok1jiHVBubHK7lBRvCuOMZG0LKUwxncrPorJ5qq/uil3IdZBd5lA==", + "dev": true, + "license": "MIT" + }, + "node_modules/typescript": { + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "7.18.2", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.18.2.tgz", + "integrity": "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==", + "dev": true, + "license": "MIT" + } + } +} diff --git a/tools/vscode-concrete-debugger/package.json b/tools/vscode-concrete-debugger/package.json new file mode 100644 index 0000000000..be0e2173de --- /dev/null +++ b/tools/vscode-concrete-debugger/package.json @@ -0,0 +1,132 @@ +{ + "name": "vscode-concrete-debugger", + "displayName": "Concrete FHE Debugger", + "description": "Debug Adapter for Concrete FHE circuits — step through DAG nodes, inspect intermediate values, and detect overflow.", + "version": "0.1.0", + "publisher": "zama-ai", + "license": "BSD-3-Clause", + "repository": { + "type": "git", + "url": "https://github.com/zama-ai/concrete" + }, + "engines": { + "vscode": "^1.80.0" + }, + "categories": [ + "Debuggers" + ], + "activationEvents": [ + "onDebugResolve:concrete" + ], + "main": "./out/extension.js", + "contributes": { + "debuggers": [ + { + "type": "concrete", + "label": "Concrete FHE", + "languages": [ + "python" + ], + "configurationAttributes": { + "launch": { + "required": [ + "program" + ], + "properties": { + "program": { + "type": "string", + "description": "Path to the Python script containing the FHE circuit.", + "default": "${file}" + }, + "function": { + "type": "string", + "description": "Name of the circuit/function variable to debug." + }, + "args": { + "type": "array", + "description": "Input values for the circuit.", + "default": [] + }, + "pythonPath": { + "type": "string", + "description": "Path to the Python interpreter.", + "default": "python3" + }, + "stopOnEntry": { + "type": "boolean", + "description": "Stop on the first non-input node.", + "default": true + }, + "stopOnOverflow": { + "type": "boolean", + "description": "Stop execution when a node overflows its bit width.", + "default": false + }, + "functions": { + "type": "array", + "description": "Ordered list of function calls for @fhe.module debugging. Each entry has a 'name' and optional 'args'.", + "items": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string", + "description": "Function name within the module." + }, + "args": { + "type": "array", + "description": "Input values for this function.", + "default": [] + } + } + } + } + } + } + }, + "configurationSnippets": [ + { + "label": "Concrete FHE: Debug Circuit", + "description": "Debug a Concrete FHE circuit with the simulator inspector.", + "body": { + "type": "concrete", + "request": "launch", + "name": "Debug FHE Circuit", + "program": "${file}", + "function": "${1:circuit_name}", + "args": [], + "stopOnEntry": true + } + }, + { + "label": "Concrete FHE: Debug Module", + "description": "Debug an @fhe.module with multiple functions.", + "body": { + "type": "concrete", + "request": "launch", + "name": "Debug FHE Module", + "program": "${file}", + "function": "${1:module_name}", + "functions": [ + { "name": "${2:function_name}", "args": [] } + ], + "stopOnEntry": true, + "stopOnOverflow": false + } + } + ] + } + ] + }, + "scripts": { + "bundle-dap": "rm -rf dap-server && mkdir -p dap-server/concrete_dap && cp ../concrete-debugger/concrete_dap_server.py dap-server/ && cp ../concrete-debugger/concrete_dap/*.py dap-server/concrete_dap/", + "compile": "tsc -p ./", + "build": "npm run bundle-dap && npm run compile", + "watch": "tsc -watch -p ./" + }, + "devDependencies": { + "@types/node": "^25.5.0", + "@types/vscode": "^1.80.0", + "typescript": "^5.0.0" + } +} diff --git a/tools/vscode-concrete-debugger/src/extension.ts b/tools/vscode-concrete-debugger/src/extension.ts new file mode 100644 index 0000000000..cce9ab0cfb --- /dev/null +++ b/tools/vscode-concrete-debugger/src/extension.ts @@ -0,0 +1,34 @@ +import * as vscode from 'vscode'; +import * as path from 'path'; + +export function activate(context: vscode.ExtensionContext) { + const factory = new ConcreteDebugAdapterFactory(context); + context.subscriptions.push( + vscode.debug.registerDebugAdapterDescriptorFactory('concrete', factory) + ); +} + +export function deactivate() { + // Nothing to clean up. +} + +class ConcreteDebugAdapterFactory implements vscode.DebugAdapterDescriptorFactory { + constructor(private readonly context: vscode.ExtensionContext) {} + + createDebugAdapterDescriptor( + session: vscode.DebugSession, + _executable: vscode.DebugAdapterExecutable | undefined + ): vscode.ProviderResult { + const config = session.configuration; + const pythonPath: string = config.pythonPath || 'python3'; + + // The DAP server Python script is bundled at dap-server/ inside the extension + const serverScript = path.join( + this.context.extensionPath, + 'dap-server', + 'concrete_dap_server.py' + ); + + return new vscode.DebugAdapterExecutable(pythonPath, [serverScript]); + } +} diff --git a/tools/vscode-concrete-debugger/tsconfig.json b/tools/vscode-concrete-debugger/tsconfig.json new file mode 100644 index 0000000000..5c341c948f --- /dev/null +++ b/tools/vscode-concrete-debugger/tsconfig.json @@ -0,0 +1,18 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "commonjs", + "lib": ["ES2020"], + "outDir": "./out", + "rootDir": "./src", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "declaration": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "out"] +}