This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/CountingBp.h

221 lines
4.6 KiB
C
Raw Normal View History

2013-02-07 17:50:02 +00:00
#ifndef YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
#define YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
2012-05-23 14:56:01 +01:00
2013-02-07 20:09:10 +00:00
#include <vector>
#include <unordered_map>
2012-05-23 14:56:01 +01:00
2012-11-14 21:55:51 +00:00
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Horus.h"
2012-05-23 14:56:01 +01:00
2013-02-07 13:37:15 +00:00
namespace Horus {
2013-02-07 23:53:13 +00:00
class VarCluster;
class FacCluster;
class WeightedBp;
typedef long Color;
2013-02-07 13:37:15 +00:00
typedef std::vector<Color> Colors;
typedef std::vector<std::pair<Color,unsigned>> VarSignature;
typedef std::vector<Color> FacSignature;
2013-02-07 13:37:15 +00:00
typedef std::unordered_map<unsigned, Color> DistColorMap;
typedef std::unordered_map<unsigned, Colors> VarColorMap;
2013-02-07 13:37:15 +00:00
typedef std::unordered_map<VarSignature, VarNodes> VarSignMap;
typedef std::unordered_map<FacSignature, FacNodes> FacSignMap;
2013-02-07 13:37:15 +00:00
typedef std::unordered_map<VarId, VarCluster*> VarClusterMap;
2013-02-07 13:37:15 +00:00
typedef std::vector<VarCluster*> VarClusters;
typedef std::vector<FacCluster*> FacClusters;
/// Mixes the std::hash of @p v into @p seed and returns the result.
/// This is the same mixing step as boost::hash_combine: the value's hash,
/// the constant 0x9e3779b9 and two shifted copies of the seed are summed,
/// then XOR-ed into the seed. The result depends on argument order.
template <class T>
inline size_t hash_combine (size_t seed, const T& v)
{
  const size_t valueHash = std::hash<T>()(v);
  const size_t mixed     = valueHash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  return seed ^ mixed;
}
} // namespace Horus
2013-02-07 23:53:13 +00:00
namespace std {
2013-02-08 00:56:42 +00:00
template <typename T1, typename T2> struct hash<std::pair<T1,T2>> {
size_t operator() (const std::pair<T1,T2>& p) const {
return Horus::hash_combine (std::hash<T1>()(p.first), p.second);
2013-02-08 00:56:42 +00:00
}};
2013-02-08 01:11:18 +00:00
template <typename T> struct hash<std::vector<T>>
{
size_t operator() (const std::vector<T>& vec) const
{
size_t h = 0;
2013-02-07 13:37:15 +00:00
typename std::vector<T>::const_iterator first = vec.begin();
typename std::vector<T>::const_iterator last = vec.end();
for (; first != last; ++first) {
h = Horus::hash_combine (h, *first);
}
return h;
2013-02-08 01:11:18 +00:00
}
};
2013-02-07 23:53:13 +00:00
} // namespace std
namespace Horus {
2013-02-07 23:53:13 +00:00
class VarCluster {
2012-05-23 14:56:01 +01:00
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
2012-05-23 14:56:01 +01:00
const VarNode* first (void) const { return members_.front(); }
2012-05-23 14:56:01 +01:00
const VarNodes& members (void) const { return members_; }
2012-05-23 14:56:01 +01:00
VarNode* representative (void) const { return repr_; }
2012-05-23 14:56:01 +01:00
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
2012-12-27 22:25:45 +00:00
DISALLOW_COPY_AND_ASSIGN (VarCluster);
};
class FacCluster {
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
2012-12-17 18:39:42 +00:00
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
VarClusters& varClusters (void) { return varClusters_; }
2012-12-17 18:39:42 +00:00
2012-05-23 14:56:01 +01:00
private:
FacNodes members_;
FacNode* repr_;
VarClusters varClusters_;
2012-12-27 22:25:45 +00:00
DISALLOW_COPY_AND_ASSIGN (FacCluster);
2012-05-23 14:56:01 +01:00
};
class CountingBp : public GroundSolver {
public:
CountingBp (const FactorGraph& fg);
2012-05-23 14:56:01 +01:00
~CountingBp (void);
2012-05-23 14:56:01 +01:00
void printSolverFlags (void) const;
Params solveQuery (VarIds);
2012-12-17 18:39:42 +00:00
2012-12-27 23:21:32 +00:00
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
2012-12-17 18:39:42 +00:00
private:
Color getNewColor (void);
Color getColor (const VarNode* vn) const;
Color getColor (const FacNode* fn) const;
2012-05-23 14:56:01 +01:00
void setColor (const VarNode* vn, Color c);
void setColor (const FacNode* fn, Color c);
void findIdenticalFactors (void);
void setInitialColors (void);
2012-05-23 14:56:01 +01:00
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid);
FacNode* getRepresentative (FacNode*);
2012-05-23 14:56:01 +01:00
FactorGraph* getCompressedFactorGraph (void);
2012-05-23 14:56:01 +01:00
2013-02-07 13:37:15 +00:00
std::vector<std::vector<unsigned>> getWeights (void) const;
2012-05-23 14:56:01 +01:00
unsigned getWeight (const FacCluster*,
const VarCluster*, size_t index) const;
2012-05-23 14:56:01 +01:00
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarClusterMap varClusterMap_;
const FactorGraph* compressedFg_;
WeightedBp* solver_;
2012-12-27 22:25:45 +00:00
2012-12-27 23:21:32 +00:00
static bool fif_;
2012-12-27 22:25:45 +00:00
DISALLOW_COPY_AND_ASSIGN (CountingBp);
2012-05-23 14:56:01 +01:00
};
// Hands out a fresh color: returns the current value of freeColor_ and
// advances the counter by one, so consecutive calls yield distinct colors.
inline Color
CountingBp::getNewColor ()
{
  const Color fresh = freeColor_;
  freeColor_ += 1;
  return fresh;
}
// Current color of variable node vn, looked up in varColors_ by the node's
// index (unchecked operator[]; the index must be in range).
inline Color
CountingBp::getColor (const VarNode* vn) const
{
  const auto slot = vn->getIndex();
  return varColors_[slot];
}
// Current color of factor node fn, looked up in facColors_ by the node's
// index (unchecked operator[]; the index must be in range).
inline Color
CountingBp::getColor (const FacNode* fn) const
{
  const auto slot = fn->getIndex();
  return facColors_[slot];
}
inline void
CountingBp::setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
inline void
CountingBp::setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
} // namespace Horus
2013-02-07 23:53:13 +00:00
2013-02-08 00:20:01 +00:00
#endif // YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
2012-05-23 14:56:01 +01:00