Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,22 @@ def Tracing_TraceMessageOp : Tracing_Op<"trace_message"> {
let arguments = (ins OptionalAttr<StrAttr> : $msg);
}

def Tracing_DebugProbeOp : Tracing_Op<"debug_probe"> {
let summary = "Captures a value into the debug probe buffer.";

let arguments = (ins
Type<Or<[
FHE_EncryptedUnsignedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate,
TFHE_GLWECipherTextType.predicate,
1DTensorOf<[I64]>.predicate,
MemRefRankOf<[I64], [1]>.predicate,
AnyInteger.predicate
]>>: $value,
I32Attr: $probe_id,
OptionalAttr<StrAttr>: $tag,
OptionalAttr<I32Attr>: $nmsb
);
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_RUNTIME_DEBUG_PROBES_H
#define CONCRETELANG_RUNTIME_DEBUG_PROBES_H

#include <cstdint>
#include <mutex>
#include <string>
#include <vector>

namespace mlir {
namespace concretelang {
namespace debug {

struct ProbeEntry {
uint32_t probe_id;
std::string tag;
int64_t value;
uint32_t nmsb;
};

class ProbeBuffer {
std::mutex mutex;
std::vector<ProbeEntry> entries;

public:
static ProbeBuffer &instance();

void reset();

void record_plaintext(uint32_t probe_id, int64_t value, const char *tag_ptr,
uint32_t tag_len, uint32_t nmsb);

size_t size() const;

const ProbeEntry &get(size_t index) const;

const std::vector<ProbeEntry> &all() const;
};

} // namespace debug
} // namespace concretelang
} // namespace mlir

extern "C" {
void memref_debug_probe_plaintext(int64_t value, int64_t input_width,
int32_t probe_id, char *tag_ptr,
int32_t tag_len, int32_t nmsb);

void debug_probe_buffer_reset();

uint64_t debug_probe_buffer_size();
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ void memref_trace_plaintext(uint64_t input, uint64_t input_width,

void memref_trace_message(char *message_ptr, uint32_t message_len);

// Debug probes ////////////////////////////////////////////////////////////////
void memref_debug_probe_plaintext(int64_t value, int64_t input_width,
int32_t probe_id, char *tag_ptr,
int32_t tag_len, int32_t nmsb);

void debug_probe_buffer_reset();

uint64_t debug_probe_buffer_size();

/// @brief Allocate memory using malloc and check for nullptr
/// @param size number of bytes to allocate
/// @return pointer to the allocated memory or nullptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "concretelang/Common/Values.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/GPUDFG.hpp"
#include "concretelang/Runtime/debug_probes.h"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
Expand Down Expand Up @@ -2347,4 +2348,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
}
return result.value();
});

// Debug probes
m.def("debug_probe_buffer_reset", &debug_probe_buffer_reset);
m.def("debug_probe_buffer_size", &debug_probe_buffer_size);
m.def("debug_probe_get_entries", []() -> pybind11::list {
auto &buf = mlir::concretelang::debug::ProbeBuffer::instance();
pybind11::list result;
for (const auto &entry : buf.all()) {
pybind11::dict d;
d["probe_id"] = entry.probe_id;
d["tag"] = entry.tag;
d["value"] = entry.value;
d["nmsb"] = entry.nmsb;
result.append(d);
}
return result;
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,39 @@ struct ZeroTensorOpPattern
};
};

struct DebugProbeOpPattern
: public mlir::OpConversionPattern<Tracing::DebugProbeOp> {
DebugProbeOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<Tracing::DebugProbeOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

::mlir::LogicalResult
matchAndRewrite(Tracing::DebugProbeOp debugProbeOp,
Tracing::DebugProbeOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// After simulation type conversion, the value is now i64.
// Re-create the op with the converted operand.
auto newOp = rewriter.replaceOpWithNewOp<Tracing::DebugProbeOp>(
debugProbeOp, mlir::TypeRange{}, adaptor.getValue());

newOp.setProbeId(debugProbeOp.getProbeId());

if (auto tag = debugProbeOp.getTag())
newOp.setTag(tag);

if (auto nmsb = debugProbeOp.getNmsb())
newOp.setNmsb(nmsb);

auto inputWidth =
newOp.getValue().getType().cast<mlir::IntegerType>().getWidth();
newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));

return ::mlir::success();
};
};

struct TraceCiphertextOpPattern
: public mlir::OpConversionPattern<Tracing::TraceCiphertextOp> {
TraceCiphertextOpPattern(mlir::MLIRContext *context,
Expand Down Expand Up @@ -747,6 +780,10 @@ void SimulateTFHEPass::runOnOperation() {
mlir::tensor::CastOp, mlir::LLVM::GlobalOp,
mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp,
Tracing::TracePlaintextOp>();
// DebugProbeOp is dynamically legal after type conversion
target.addDynamicallyLegalOp<Tracing::DebugProbeOp>([&](mlir::Operation *op) {
return converter.isLegal(op->getOperandTypes());
});
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

Expand Down Expand Up @@ -821,7 +858,8 @@ void SimulateTFHEPass::runOnOperation() {
patterns.insert<ZeroOpPattern, ZeroTensorOpPattern, KeySwitchGLWEOpPattern,
WopPBSGLWEOpPattern, EncodeLutForCrtWopPBSOpPattern,
EncodePlaintextWithCrtOpPattern, NegOpPattern,
TraceCiphertextOpPattern>(&getContext(), converter);
TraceCiphertextOpPattern, DebugProbeOpPattern>(
&getContext(), converter);
patterns.insert<SubIntGLWEOpPattern>(&getContext());

// if overflow detection is enable, then rewrite to CAPI functions that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace memref = mlir::memref;
char memref_trace_ciphertext[] = "memref_trace_ciphertext";
char memref_trace_plaintext[] = "memref_trace_plaintext";
char memref_trace_message[] = "memref_trace_message";
char memref_debug_probe_plaintext[] = "memref_debug_probe_plaintext";

mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
Expand Down Expand Up @@ -76,6 +77,14 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
{mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type()},
{});
} else if (funcName == memref_debug_probe_plaintext) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{rewriter.getI64Type(), rewriter.getI64Type(),
rewriter.getI32Type(),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
Expand Down Expand Up @@ -190,6 +199,38 @@ void traceMessageAddOperands(Tracing::TraceMessageOp op,
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
}

void debugProbeAddOperands(Tracing::DebugProbeOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto tag = op.getTag().value_or("");
auto nmsb = op.getNmsb().value_or(0);
auto probeId = op.getProbeId();

// input_width (set as attribute by SimulateTFHE pass)
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("input_width")));

// probe_id
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(probeId)));

// tag string
std::string tagName;
std::stringstream stream;
stream << rand();
stream >> tagName;
auto tagVal = mlir::LLVM::createGlobalString(
op.getLoc(), rewriter, tagName, tag,
mlir::LLVM::linkage::Linkage::Linkonce, false);
operands.push_back(tagVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(tag.size())));

// nmsb
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(nmsb)));
}

struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {

TracingToCAPIPass() {}
Expand Down Expand Up @@ -219,6 +260,9 @@ struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {
patterns.add<TracingToCAPICallPattern<Tracing::TraceMessageOp,
memref_trace_message>>(
&getContext(), traceMessageAddOperands);
patterns.add<TracingToCAPICallPattern<Tracing::DebugProbeOp,
memref_debug_probe_plaintext>>(
&getContext(), debugProbeAddOperands);

// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ if(CONCRETELANG_CUDA_SUPPORT)
GPUDFG.cpp
simulation.cpp
wrappers.cpp
time_util.cpp)
time_util.cpp
debug_probes.cpp)
else()
add_library(
ConcretelangRuntime SHARED
Expand All @@ -19,7 +20,8 @@ else()
GPUDFG.cpp
simulation.cpp
wrappers.cpp
time_util.cpp)
time_util.cpp
debug_probes.cpp)
endif()

add_dependencies(ConcretelangRuntime rust_deps_bundle concrete-protocol)
Expand Down Expand Up @@ -83,7 +85,8 @@ add_library(
GPUDFG.cpp
simulation.cpp
wrappers.cpp
time_util.cpp)
time_util.cpp
debug_probes.cpp)

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_libraries(ConcretelangRuntimeStatic PUBLIC omp)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

#include "concretelang/Runtime/debug_probes.h"
#include <cassert>

namespace mlir {
namespace concretelang {
namespace debug {

ProbeBuffer &ProbeBuffer::instance() {
static ProbeBuffer buf;
return buf;
}

void ProbeBuffer::reset() {
std::lock_guard<std::mutex> lock(mutex);
entries.clear();
}

void ProbeBuffer::record_plaintext(uint32_t probe_id, int64_t value,
const char *tag_ptr, uint32_t tag_len,
uint32_t nmsb) {
std::lock_guard<std::mutex> lock(mutex);
entries.push_back(
{probe_id, std::string(tag_ptr, tag_len), value, nmsb});
}

size_t ProbeBuffer::size() const { return entries.size(); }

const ProbeEntry &ProbeBuffer::get(size_t index) const {
assert(index < entries.size());
return entries[index];
}

const std::vector<ProbeEntry> &ProbeBuffer::all() const { return entries; }

} // namespace debug
} // namespace concretelang
} // namespace mlir

extern "C" {

void memref_debug_probe_plaintext(int64_t value, int64_t input_width,
int32_t probe_id, char *tag_ptr,
int32_t tag_len, int32_t nmsb) {
mlir::concretelang::debug::ProbeBuffer::instance().record_plaintext(
static_cast<uint32_t>(probe_id), value,
tag_ptr, static_cast<uint32_t>(tag_len),
static_cast<uint32_t>(nmsb));
}

void debug_probe_buffer_reset() {
mlir::concretelang::debug::ProbeBuffer::instance().reset();
}

uint64_t debug_probe_buffer_size() {
return static_cast<uint64_t>(
mlir::concretelang::debug::ProbeBuffer::instance().size());
}
}
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

* [Simulation](execution-analysis/simulation.md)
* [Debugging and artifact](execution-analysis/debug.md)
* [Interactive debugger](interactive-debugger.md)
* [Performance](optimization/summary.md)
* [GPU acceleration](execution-analysis/gpu_acceleration.md)
* [Rust integration](execution-analysis/rust_integration.md)
Expand Down
Loading