Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/mori/io/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ class Backend {
TransferStatus* status) = 0;

virtual bool CanHandle(const MemoryDesc& local, const MemoryDesc& remote) const { return true; }

// Returns the maximum memory region size the backend can register in a
// single ibv_reg_mr call. SIZE_MAX means no known limit.
// This base implementation always reports SIZE_MAX (no limit); a backend
// with a real registration constraint overrides it.
virtual size_t GetMaxMemoryRegionSize() const { return SIZE_MAX; }
};

} // namespace io
Expand Down
4 changes: 4 additions & 0 deletions include/mori/io/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class IOEngine {
std::optional<IOEngineSession> CreateSession(const MemoryDesc& local, const MemoryDesc& remote);
void LoadScatterGatherModule(const std::string& hsacoPath);

// Returns the minimum max_mr_size across all RDMA backends/devices.
// SIZE_MAX means no known limit.
size_t GetMaxMemoryRegionSize() const;

private:
struct RouteCacheKey {
EngineKey remoteEngineKey;
Expand Down
8 changes: 8 additions & 0 deletions src/io/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,13 @@ bool IOEngine::PopInboundTransferStatus(EngineKey remote, TransferUniqueId id,
return false;
}

size_t IOEngine::GetMaxMemoryRegionSize() const {
size_t min_size = SIZE_MAX;
for (const auto& [type, be] : backends) {
min_size = std::min(min_size, be->GetMaxMemoryRegionSize());
}
return min_size;
}

} // namespace io
} // namespace mori
22 changes: 22 additions & 0 deletions src/io/rdma/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,5 +989,27 @@ void RdmaBackend::InvalidateSessionsForMemory(MemoryUniqueId id) {
}
}

// Reports the smallest usable memory-region size across the available RDMA
// devices. Returns SIZE_MAX when no device yields a non-zero limit.
size_t RdmaBackend::GetMaxMemoryRegionSize() const {
  // Pensando (IONIC/AINIC) NICs advertise an unreliable max_mr_size through
  // ibv_query_device, so a fixed 2 GB value is used for them instead.
  static constexpr size_t kIonicMaxMrSize = 2ULL * 1024 * 1024 * 1024;
  const auto pensandoVendor =
      static_cast<uint32_t>(application::RdmaDeviceVendorId::Pensando);

  size_t smallest = SIZE_MAX;
  for (const auto& [device, portId] : rdma->GetAvailDevices()) {
    const auto* devAttr = device->GetDeviceAttr();
    if (devAttr == nullptr) {
      continue;  // No attributes available for this device; it contributes nothing.
    }

    const bool isPensando = devAttr->orig_attr.vendor_id == pensandoVendor;
    const size_t limit =
        isPensando ? kIonicMaxMrSize : static_cast<size_t>(devAttr->orig_attr.max_mr_size);

    // A zero limit carries no information, so the running minimum is left alone.
    if (limit != 0 && limit < smallest) {
      smallest = limit;
    }
  }
  return smallest;
}

} // namespace io
} // namespace mori
2 changes: 2 additions & 0 deletions src/io/rdma/backend_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class RdmaManager {
std::vector<std::shared_ptr<EndpointRuntime>> SnapshotEndpointRuntimes();

application::RdmaDeviceContext* GetRdmaDeviceContext(int devId);
const application::ActiveDevicePortList& GetAvailDevices() const { return availDevices; }

private:
application::RdmaDeviceContext* GetOrCreateDeviceContext(int devId);
Expand Down Expand Up @@ -250,6 +251,7 @@ class RdmaBackend : public Backend {
bool isRead);
BackendSession* CreateSession(const MemoryDesc& local, const MemoryDesc& remote);
bool PopInboundTransferStatus(EngineKey remote, TransferUniqueId id, TransferStatus* status);
size_t GetMaxMemoryRegionSize() const override;

private:
void CreateSession(const MemoryDesc& local, const MemoryDesc& remote, RdmaBackendSession& sess);
Expand Down
3 changes: 2 additions & 1 deletion src/pybind/pybind_umbp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ void RegisterMoriUmbp(py::module_& m) {
.def_readwrite("io_engine_port", &UMBPDistributedConfig::io_engine_port)
.def_readwrite("staging_buffer_size", &UMBPDistributedConfig::staging_buffer_size)
.def_readwrite("peer_service_port", &UMBPDistributedConfig::peer_service_port)
.def_readwrite("cache_remote_fetches", &UMBPDistributedConfig::cache_remote_fetches);
.def_readwrite("cache_remote_fetches", &UMBPDistributedConfig::cache_remote_fetches)
.def_readwrite("max_mr_chunk_size", &UMBPDistributedConfig::max_mr_chunk_size);

py::class_<UMBPConfig>(m, "UMBPConfig")
.def(py::init<>())
Expand Down
6 changes: 5 additions & 1 deletion src/umbp/distributed/master/client_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ size_t ClientRegistry::UnregisterClient(const std::string& node_id) {
// PA-3 fix: exclusive lock because we mutate last_heartbeat and tier_capacities
ClientStatus ClientRegistry::Heartbeat(const std::string& node_id,
const std::map<TierType, TierCapacity>& tier_capacities) {
(void)tier_capacities;
std::unique_lock lock(mutex_);
auto it = clients_.find(node_id);
if (it == clients_.end()) {
Expand All @@ -239,6 +238,11 @@ ClientStatus ClientRegistry::Heartbeat(const std::string& node_id,
it->second.last_heartbeat = std::chrono::steady_clock::now();
it->second.status = ClientStatus::ALIVE;

// Update tier capacities reported by the client.
for (const auto& [tier, cap] : tier_capacities) {
it->second.tier_capacities[tier] = cap;
}

return ClientStatus::ALIVE;
}

Expand Down
83 changes: 65 additions & 18 deletions src/umbp/distributed/pool_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,44 @@ bool PoolClient::Init() {
staging_mem_ = io_engine_->RegisterMemory(staging_buffer_.get(), config_.staging_buffer_size,
-1, mori::io::MemoryLocationType::CPU);

// Determine effective chunk size for DRAM MR registration.
size_t device_max_mr = io_engine_->GetMaxMemoryRegionSize();
size_t effective_chunk = config_.max_mr_chunk_size > 0
? std::min(config_.max_mr_chunk_size, device_max_mr)
: device_max_mr;
// Align to system page size.
size_t page_size = static_cast<size_t>(sysconf(_SC_PAGE_SIZE));
if (effective_chunk != SIZE_MAX && effective_chunk > page_size) {
effective_chunk = (effective_chunk / page_size) * page_size;
}
// If effective_chunk covers all buffers, normalize to SIZE_MAX (no chunking).
size_t max_buffer_size = 0;
for (const auto& dram : config_.dram_buffers) {
if (dram.buffer && dram.size > 0) {
auto mem = io_engine_->RegisterMemory(dram.buffer, dram.size, -1,
max_buffer_size = std::max(max_buffer_size, dram.size);
}
if (effective_chunk >= max_buffer_size) {
effective_chunk = SIZE_MAX;
}
dram_chunk_size_ = effective_chunk;

if (dram_chunk_size_ != SIZE_MAX) {
MORI_UMBP_INFO("[PoolClient] DRAM MR chunk size: {} bytes (device_max={}, config={})",
dram_chunk_size_, device_max_mr, config_.max_mr_chunk_size);
}

// Register DRAM buffers, splitting into chunks if needed.
for (const auto& dram : config_.dram_buffers) {
if (!dram.buffer || dram.size == 0) continue;
size_t chunk = (dram_chunk_size_ != SIZE_MAX) ? dram_chunk_size_ : dram.size;
for (size_t off = 0; off < dram.size; off += chunk) {
size_t sz = std::min(chunk, dram.size - off);
auto mem = io_engine_->RegisterMemory(static_cast<char*>(dram.buffer) + off, sz, -1,
mori::io::MemoryLocationType::CPU);
export_dram_mems_.push_back(mem);
}
}

MORI_UMBP_INFO("[PoolClient] IOEngine initialized on {}:{} ({} DRAM buffers)",
MORI_UMBP_INFO("[PoolClient] IOEngine initialized on {}:{} ({} DRAM MR chunks)",
config_.io_engine_host, config_.io_engine_port, export_dram_mems_.size());
}

Expand All @@ -112,7 +141,7 @@ bool PoolClient::Init() {
msgpack::sbuffer mbuf;
msgpack::pack(mbuf, export_dram_mems_[i]);
dram_memory_desc_bytes_list.emplace_back(mbuf.data(), mbuf.data() + mbuf.size());
dram_buffer_sizes.push_back(config_.dram_buffers[i].size);
dram_buffer_sizes.push_back(export_dram_mems_[i].size);
}
}

Expand Down Expand Up @@ -214,20 +243,32 @@ bool PoolClient::RegisterMemory(void* ptr, size_t size) {
MORI_UMBP_ERROR("[PoolClient] RegisterMemory: IOEngine not available");
return false;
}
auto mem_desc = io_engine_->RegisterMemory(ptr, size, -1, mori::io::MemoryLocationType::CPU);
// Split into chunks matching dram_chunk_size_ to stay within MR limits.
size_t chunk = (dram_chunk_size_ != 0 && dram_chunk_size_ != SIZE_MAX) ? dram_chunk_size_ : size;
std::lock_guard<std::mutex> lock(registered_mem_mutex_);
registered_regions_.push_back({ptr, size, mem_desc});
MORI_UMBP_INFO("[PoolClient] RegisterMemory: ptr={}, size={}", ptr, size);
size_t num_chunks = 0;
for (size_t off = 0; off < size; off += chunk) {
size_t sz = std::min(chunk, size - off);
auto mem_desc = io_engine_->RegisterMemory(static_cast<char*>(ptr) + off, sz, -1,
mori::io::MemoryLocationType::CPU);
registered_regions_.push_back({static_cast<char*>(ptr) + off, sz, mem_desc, ptr});
++num_chunks;
}
MORI_UMBP_INFO("[PoolClient] RegisterMemory: ptr={}, size={}, chunks={}", ptr, size, num_chunks);
return true;
}

// Deregisters every chunk that was produced by one RegisterMemory(ptr, ...) call.
//
// Chunks are matched on RegisteredRegion::group_base, so `ptr` must be the
// base pointer originally handed to RegisterMemory(); any other pointer
// matches no entry and the call does nothing.
void PoolClient::DeregisterMemory(void* ptr) {
  std::lock_guard<std::mutex> lock(registered_mem_mutex_);
  // Remove all chunk entries belonging to the same original RegisterMemory() call.
  auto it = registered_regions_.begin();
  while (it != registered_regions_.end()) {
    if (it->group_base == ptr) {
      // The bookkeeping entry is dropped even when there is no IOEngine to
      // deregister the descriptor with.
      if (io_engine_) io_engine_->DeregisterMemory(it->mem_desc);
      it = registered_regions_.erase(it);  // erase() yields the next valid iterator
    } else {
      ++it;
    }
  }
}

Expand Down Expand Up @@ -491,13 +532,19 @@ PoolClient::PeerConnection& PoolClient::GetOrConnectPeer(
std::lock_guard<std::mutex> lock(peers_mutex_);
auto it = peers_.find(node_id);
if (it != peers_.end()) {
// Ensure dram_memories vector has the requested index populated
// Ensure dram_memories vector has the requested index populated.
// Always fill the slot when we have desc bytes, even if the vector
// was previously resized past this index by an out-of-order arrival.
auto& peer = *it->second;
if (buffer_index >= peer.dram_memories.size() && !dram_memory_desc_bytes.empty()) {
peer.dram_memories.resize(buffer_index + 1);
auto handle = msgpack::unpack(reinterpret_cast<const char*>(dram_memory_desc_bytes.data()),
dram_memory_desc_bytes.size());
peer.dram_memories[buffer_index] = handle.get().as<mori::io::MemoryDesc>();
if (!dram_memory_desc_bytes.empty()) {
if (buffer_index >= peer.dram_memories.size()) {
peer.dram_memories.resize(buffer_index + 1);
}
if (!IsValidMemoryDesc(peer.dram_memories[buffer_index])) {
auto handle = msgpack::unpack(reinterpret_cast<const char*>(dram_memory_desc_bytes.data()),
dram_memory_desc_bytes.size());
peer.dram_memories[buffer_index] = handle.get().as<mori::io::MemoryDesc>();
}
}
return peer;
}
Expand Down
9 changes: 9 additions & 0 deletions src/umbp/include/umbp/common/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ struct UMBPDistributedConfig {
uint16_t peer_service_port = 0; // gRPC peer service port

bool cache_remote_fetches = true; // cache remotely-fetched blocks locally

// Maximum single MR size for RDMA memory registration (bytes).
// 0 (default) = auto-detect from ibv_device_attr.max_mr_size.
// Set explicitly when auto-detection is unavailable or for testing.
// Env: UMBP_MAX_MR_CHUNK_SIZE
size_t max_mr_chunk_size = 0;
};

struct UMBPConfig {
Expand Down Expand Up @@ -365,6 +371,9 @@ struct PoolClientConfig {
std::map<TierType, TierCapacity> tier_capacities;

uint16_t peer_service_port = 0;

// Passed from UMBPDistributedConfig::max_mr_chunk_size.
size_t max_mr_chunk_size = 0;
};

} // namespace mori::umbp
12 changes: 10 additions & 2 deletions src/umbp/include/umbp/distributed/pool_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class PoolClient {

const std::string& NodeId() const { return config_.master_config.node_id; }

// Returns the effective chunk size used for DRAM MR registration.
// SIZE_MAX means no chunking was applied.
size_t DramChunkSize() const { return dram_chunk_size_; }

bool RegisterMemory(void* ptr, size_t size);
void DeregisterMemory(void* ptr);

Expand Down Expand Up @@ -150,9 +154,10 @@ class PoolClient {

// Zero-copy registered memory regions
// One registered chunk of a zero-copy region. A single RegisterMemory() call
// may be split into several chunks; all of them carry the same group_base.
// Member order matters: entries are built with aggregate initialization as
// {base, size, mem_desc, group_base}.
struct RegisteredRegion {
  void* base;                     // this chunk's actual start address
  size_t size;                    // this chunk's size
  mori::io::MemoryDesc mem_desc;  // descriptor obtained when this chunk was registered
  void* group_base;               // original RegisterMemory() caller's base pointer
};
std::mutex registered_mem_mutex_;
std::vector<RegisteredRegion> registered_regions_;
Expand All @@ -162,6 +167,9 @@ class PoolClient {

mutable std::mutex cache_mutex_;
std::unordered_map<std::string, Location> cluster_locations_;

// Effective MR chunk size for DRAM registration. SIZE_MAX = no chunking.
size_t dram_chunk_size_ = SIZE_MAX;
};

} // namespace mori::umbp
Loading
Loading