// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_H__
#define DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_H__
#include "find_max_factor_graph_nmplp_abstract.h"
#include <vector>
#include <map>
#include "../matrix.h"
#include "../hash.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
class simple_hash_map
{
public:
simple_hash_map(
) :
scan_dist(6)
{
data.resize(5000);
}
void insert (
const unsigned long a,
const unsigned long b,
const unsigned long value
)
/*!
requires
- a != std::numeric_limits<unsigned long>::max()
ensures
- #(*this)(a,b) == value
!*/
{
const uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist);
const unsigned long empty_bucket = std::numeric_limits<unsigned long>::max();
for (uint32 i = 0; i < scan_dist; ++i)
{
if (data[i+h].key1 == empty_bucket)
{
data[i+h].key1 = a;
data[i+h].key2 = b;
data[i+h].value = value;
return;
}
}
// if we get this far it means the hash table is filling up. So double its size.
std::vector<bucket> new_data;
new_data.resize(data.size()*2);
new_data.swap(data);
for (uint32 i = 0; i < new_data.size(); ++i)
{
if (new_data[i].key1 != empty_bucket)
{
insert(new_data[i].key1, new_data[i].key2, new_data[i].value);
}
}
insert(a,b,value);
}
unsigned long operator() (
const unsigned long a,
const unsigned long b
) const
/*!
requires
- this->insert(a,b,some_value) has been called
ensures
- returns the value stored at key (a,b)
!*/
{
DLIB_ASSERT(a != b, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
<< "\nNode " << a << " is listed as being a neighbor with itself, which is illegal.");
uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist);
for (unsigned long i = 0; i < scan_dist; ++i)
{
if (data[h].key1 == a && data[h].key2 == b)
{
return data[h].value;
}
++h;
}
// this should never happen (since this function requires (a,b) to be in the hash table
DLIB_ASSERT(false, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
<< "\nThe nodes in the map_problem are inconsistent because node "<<a<<" is in the neighbor list"
<< "\nof node "<<b<< " but node "<<b<<" isn't in the neighbor list of node "<<a<<". The neighbor relationship"
<< "\nis supposed to be symmetric."
);
return 0;
}
private:
struct bucket
{
// having max() in key1 indicates that the bucket isn't used.
bucket() : key1(std::numeric_limits<unsigned long>::max()) {}
unsigned long key1;
unsigned long key2;
unsigned long value;
};
std::vector<bucket> data;
const unsigned int scan_dist;
};
}
// ----------------------------------------------------------------------------------------
template <
typename map_problem
>
void find_max_factor_graph_nmplp (
const map_problem& prob,
std::vector<unsigned long>& map_assignment,
unsigned long max_iter,
double eps
)
{
// make sure requires clause is not broken
DLIB_ASSERT( eps > 0,
"\t void find_max_factor_graph_nmplp()"
<< "\n\t eps must be greater than zero"
<< "\n\t eps: " << eps
);
/*
This function is an implementation of the NMPLP algorithm introduced in the
following paper:
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations
by Amir Globerson and Tommi Jaakkola
In particular, see the pseudocode in Figure 1. The code in this function
follows what is described there.
*/
typedef typename map_problem::node_iterator node_iterator;
typedef typename map_problem::neighbor_iterator neighbor_iterator;
map_assignment.resize(prob.number_of_nodes());
if (prob.number_of_nodes() == 0)
return;
std::vector<double> gamma_elements;
gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
impl::simple_hash_map gamma_idx;
// initialize gamma according to the initialization instructions at top of Figure 1
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
{
const unsigned long id_i = prob.node_id(i);
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{
const unsigned long id_j = prob.node_id(j);
gamma_idx.insert(id_i, id_j, gamma_elements.size());
const unsigned long num_states_xj = prob.num_states(j);
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
const unsigned long num_states_xi = prob.num_states(i);
double best_val = -std::numeric_limits<double>::infinity();
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
double val = prob.factor_value(i,j,xi,xj);
double sum_temp = 0;
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
{
if (j == k)
continue;
double max_val = -std::numeric_limits<double>::infinity();
for (unsigned long xk = 0; xk < prob.num_states(k); ++xk)
{
double temp = prob.factor_value(k,i,xk,xi);
if (temp > max_val)
max_val = temp;
}
sum_temp += max_val;
}
val += 0.5*sum_temp;
if (val > best_val)
best_val = val;
}
gamma_elements.push_back(best_val);
}
}
}
double max_change = eps + 1;
// Now do the main body of the optimization.
for (unsigned long iter = 0; iter < max_iter && max_change > eps; ++iter)
{
max_change = -std::numeric_limits<double>::infinity();
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
{
const unsigned long id_i = prob.node_id(i);
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{
const unsigned long id_j = prob.node_id(j);
double* const gamma_ji = &gamma_elements[gamma_idx(id_j,id_i)];
double* const gamma_ij = &gamma_elements[gamma_idx(id_i,id_j)];
const unsigned long num_states_xj = prob.num_states(j);
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
const unsigned long num_states_xi = prob.num_states(i);
double best_val = -std::numeric_limits<double>::infinity();
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
double val = prob.factor_value(i,j,xi,xj) - gamma_ji[xi];
double sum_temp = 0;
int num_neighbors = 0;
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
{
const unsigned long id_k = prob.node_id(k);
++num_neighbors;
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
sum_temp += gamma_ki[xi];
}
val += 2.0/(num_neighbors + 1.0)*sum_temp;
if (val > best_val)
best_val = val;
}
if (std::abs(gamma_ij[xj] - best_val) > max_change)
max_change = std::abs(gamma_ij[xj] - best_val);
gamma_ij[xj] = best_val;
}
}
}
}
// now decode the "beliefs"
std::vector<double> b;
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
{
const unsigned long id_i = prob.node_id(i);
b.assign(prob.num_states(i), 0);
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
{
const unsigned long id_k = prob.node_id(k);
for (unsigned long xi = 0; xi < b.size(); ++xi)
{
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
b[xi] += gamma_ki[xi];
}
}
map_assignment[id_i] = index_of_max(mat(b));
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_H__