Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ set(binary_parser_headers
${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h
${CMAKE_CURRENT_LIST_DIR}/log_converter.h
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h
Expand All @@ -146,7 +146,7 @@ set(binary_parser_sources
${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/log_converter.cc
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/example_joiner.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/joined_event.h"
#include "event_processors/loop.h"
#include "joiners/i_joiner.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/i_joiner.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/reward.h"
#include "generated/v2/CbEvent_generated.h"
#include "generated/v2/FileFormat_generated.h"
#include "generated/v2/Metadata_generated.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/unit_tests/test_lru_dedup_cache.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <boost/test/unit_test.hpp>

#include "lru_dedup_cache.h"
#include "../rlclientlib/lru_dedup_cache.h"
#include "parse_example_external.h"
#include "test_common.h"
#include "vw/config/options_cli.h"
Expand Down
10 changes: 10 additions & 0 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ class live_model
*/
int init(api_status* status = nullptr);

/**
* @brief Add a single action to the LRU dedup cache.
* Registers the action described by action_str in the dedup cache under the
* given hash. The action payload is passed in directly; no file is read.
* Per the original description, this cache is used to prevent duplicate
* actions from being sent to the online trainer.
* @param hash Identifier under which the action is stored in the dedup cache
* @param action_str Action payload as a string
*        (NOTE(review): appears to be parsed as a JSON line by the VW-backed model - confirm)
* @param status api_status object used to report error details
* @return int Return error code.
*         (NOTE(review): not every model implementation appears to populate status - confirm)
*/
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);
Comment thread
bassmang marked this conversation as resolved.
Outdated

/**
* @brief Choose an action, given a list of actions, action features and context features. The
* inference library chooses an action by creating a probability distribution over the actions
Expand Down
1 change: 1 addition & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class i_model
{
public:
virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0;
virtual int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) = 0;
virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) = 0;
virtual int choose_continuous_action(string_view features, float& action, float& pdf_value,
Expand Down
2 changes: 2 additions & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(PROJECT_SOURCES
logger/logger_facade.cc
logger/preamble.cc
logger/preamble_sender.cc
lru_dedup_cache.cc
model_mgmt/data_callback_fn.cc
model_mgmt/empty_data_transport.cc
model_mgmt/file_model_loader.cc
Expand Down Expand Up @@ -149,6 +150,7 @@ set(PROJECT_PRIVATE_HEADERS
logger/async_batcher.h
logger/event_logger.h
logger/logger_facade.h
lru_dedup_cache.h
model_mgmt/data_callback_fn.h
model_mgmt/empty_data_transport.h
model_mgmt/file_model_loader.h
Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ std::vector<int> live_model::c_array_to_vector(const int* c_array, size_t array_
return std::vector<int>(c_array, c_array + array_size);
}

// Public entry point for adding an action to the LRU dedup cache.
// INIT_CHECK() runs first (presumably rejects calls made before init() - confirm against the macro);
// the arguments are then forwarded unchanged to the implementation object and its result is returned.
int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
INIT_CHECK();
return _pimpl->add_lru_dedup_cache(hash, action_str, status);
}

int live_model::choose_rank(
const char* event_id, string_view context_json, ranking_response& response, api_status* status)
{
Expand Down
5 changes: 5 additions & 0 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status)
return error_code::success;
}

// Forwards the dedup-cache insertion to the currently configured model (_model) and
// returns whatever error code the model reports. No validation is done at this layer.
int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return _model->add_lru_dedup_cache(hash, action_str, status);
}

int live_model_impl::choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class live_model_impl

int init(api_status* status);

int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);
int choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status);
// here the event_id is auto-generated
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct lru_dedup_cache
void* context = nullptr);
bool exists(uint64_t dedup_id);
void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr);
// Returns a pointer to the internal dedup_examples map (uint64_t key -> VW::example*).
// The pointer refers to a member of this object, so it must not be used after the cache is destroyed.
// NOTE(review): this exposes mutable internal state; callers can modify the map directly - confirm this is intended.
std::unordered_map<uint64_t, VW::example*>* get_dict() { return &dedup_examples; }
Comment thread
bassmang marked this conversation as resolved.
Outdated

lru_dedup_cache() = default;
~lru_dedup_cache() = default;
Expand Down
8 changes: 7 additions & 1 deletion rlclientlib/vw_model/pdf_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace model_management
// We construct a VW object here to use the example parser to parse joined dsjson-style examples
// to extract the PDF.
pdf_model::pdf_model(i_trace* trace_logger, const utility::configuration& /*unused*/)
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf"))
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf", nullptr))
{
}

Expand All @@ -23,6 +23,12 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta
return error_code::success;
}

// TODO: Implement LRU cache for PDF models.
// Dedup caching is not implemented for pdf_model: all arguments are ignored and
// error_code::not_supported is returned unconditionally.
// NOTE(review): status is not populated on this path - confirm callers do not rely on it.
int pdf_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return error_code::not_supported;
}

int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/vw_model/pdf_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class pdf_model : public i_model
public:
pdf_model(i_trace* trace_logger, const utility::configuration& config);
int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override;
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;
int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version,
Expand Down
68 changes: 52 additions & 16 deletions rlclientlib/vw_model/safe_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ namespace reinforcement_learning
{
static const std::string SEED_TAG = "seed=";

safe_vw::safe_vw(std::shared_ptr<safe_vw> master) : _master(std::move(master))
safe_vw::safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache)
Comment thread
bassmang marked this conversation as resolved.
Outdated
: _master(std::move(master)), _dedup_cache(dedup_cache)
{
_vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr);
init();
}

safe_vw::safe_vw(const char* model_data, size_t len)
safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -34,7 +35,8 @@ safe_vw::safe_vw(const char* model_data, size_t len)
init();
}

safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline)
safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache)
: _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -43,7 +45,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_comma
init();
}

safe_vw::safe_vw(const std::string& vw_commandline)
safe_vw::safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
_vw = VW::initialize(vw_commandline);
init();
Expand Down Expand Up @@ -120,6 +122,24 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector<int>& acti
for (auto&& ex : examples) { _example_pool.emplace_back(ex); }
}

// Parses action_str as a JSON line into a VW example and stores that example in the
// dedup cache under the given hash.
void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str)
{
// Lazily create a cache when this instance does not have one yet.
// NOTE(review): allocated with raw new and no matching delete is visible in this view - confirm ownership/cleanup.
if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); }
Comment thread
bassmang marked this conversation as resolved.
Outdated
VW::multi_ex examples;
// Seed the example list with one example obtained from get_or_create_example().
examples.push_back(get_or_create_example());

// The audit and non-audit paths differ only in the template argument passed to the JSON reader;
// the audit path additionally clears the audit buffer first.
if (_vw->audit)
{
_vw->audit_buffer->clear();
// &action_str[0] hands the parser a mutable char buffer over the by-value copy of the string.
VW::read_line_json_s<true>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
else
{
VW::read_line_json_s<false>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
// Only the first example in the list is stored in the cache.
_dedup_cache->add(hash, examples[0]);
}

void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<float>& scores)
{
VW::multi_ex examples;
Expand All @@ -131,9 +151,14 @@ void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<f
if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this);
VW::read_line_json_s<true>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else
{
VW::read_line_json_s<false>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else { VW::read_line_json_s<false>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); }

// finalize example
VW::setup_examples(*_vw, examples);
Expand Down Expand Up @@ -372,19 +397,30 @@ void safe_vw::init()
}
}

safe_vw_factory::safe_vw_factory(std::string command_line) : _command_line(std::move(command_line)) {}
safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache)
: _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

Expand All @@ -393,13 +429,13 @@ safe_vw* safe_vw_factory::operator()()
if ((_master_data.data() != nullptr) && !_command_line.empty())
{
// Construct new vw object from raw model data and command line argument
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line);
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line, _dedup_cache);
}
if (_master_data.data() != nullptr)
{
// Construct new vw object from raw model data.
return new safe_vw(_master_data.data(), _master_data.data_sz());
return new safe_vw(_master_data.data(), _master_data.data_sz(), _dedup_cache);
}
return new safe_vw(_command_line);
return new safe_vw(_command_line, _dedup_cache);
}
} // namespace reinforcement_learning
24 changes: 15 additions & 9 deletions rlclientlib/vw_model/safe_vw.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "lru_dedup_cache.h"
#include "model_mgmt.h"
#include "vw/core/vw.h"

Expand All @@ -14,19 +15,21 @@ class safe_vw
std::shared_ptr<safe_vw> _master;
VW::workspace* _vw;
std::vector<VW::example*> _example_pool;
lru_dedup_cache* _dedup_cache;

VW::example* get_or_create_example();
static VW::example& get_or_create_example_f(void* vw);

public:
safe_vw(std::shared_ptr<safe_vw> master);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline);
safe_vw(const char* model_data, size_t len);
safe_vw(const std::string& vw_commandline);
safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache);
safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache);

~safe_vw();

void parse_context_with_pdf(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void add_lru_dedup_cache(uint64_t hash, std::string action_str);
void rank(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void choose_continuous_action(string_view context, float& action, float& pdf_value);
// Used for CCB
Expand Down Expand Up @@ -57,14 +60,17 @@ class safe_vw_factory
{
model_management::model_data _master_data;
std::string _command_line;
lru_dedup_cache* _dedup_cache;

public:
// model_data is copied and stored in the factory object.
safe_vw_factory(std::string command_line);
safe_vw_factory(const model_management::model_data& master_data);
safe_vw_factory(const model_management::model_data&& master_data);
safe_vw_factory(const model_management::model_data& master_data, std::string command_line);
safe_vw_factory(const model_management::model_data&& master_data, std::string command_line);
safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache);

safe_vw* operator()();
};
Expand Down
13 changes: 10 additions & 3 deletions rlclientlib/vw_model/vw_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ vw_model::vw_model(i_trace* trace_logger, const utility::configuration& config)
, _initial_command_line(std::string(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE,
"--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id N/A")) +
(_audit ? " --audit" : ""))
, _vw_pool(safe_vw_factory(_initial_command_line),
, _vw_pool(safe_vw_factory(_initial_command_line, _dedup_cache),
config.get_int(name::VW_POOL_INIT_SIZE, value::DEFAULT_VW_POOL_INIT_SIZE), trace_logger)
, _trace_logger(trace_logger)
{
Expand All @@ -34,13 +34,13 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat
{
std::string cmd_line = add_optional_audit_flag(_quiet_commandline_options);

std::unique_ptr<safe_vw> init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line));
std::unique_ptr<safe_vw> init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line, _dedup_cache));
if (init_vw->is_CB_to_CCB_model_upgrade(_initial_command_line))
{
cmd_line = add_optional_audit_flag(_upgrade_to_CCB_vw_commandline_options);
}

safe_vw_factory factory(data, cmd_line);
safe_vw_factory factory(data, cmd_line, _dedup_cache);
std::unique_ptr<safe_vw> test_vw(factory());
if (test_vw->is_compatible(_initial_command_line))
{
Expand All @@ -67,6 +67,13 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat
return error_code::success;
}

int vw_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably move the "parse action and populate the cache" functionality into vw_model and just keep a pointer to that cache in each vw instance.

that way whenever you use vw for a rank call we could just do

auto vw = _vw_pool.get_or_create()
vw.set_action_cache(&action_cache)
vw.rank(...)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way it is done here, the specific action is only passed into one of the vw instances in the object pool.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed lru_cache from safe_vw, its constructors, etc. It is now passed as an argument to load_action and rank in safe_vw.

{
auto vw = _vw_pool.get_or_create();
vw->add_lru_dedup_cache(hash, action_str);
return error_code::success;
}

int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
Loading