Move methods with more than two lines to outside of class definition

This commit is contained in:
Tiago Gomes 2013-02-06 00:24:02 +00:00
parent 0d9d59f5fe
commit 42a5bc493a
25 changed files with 1442 additions and 856 deletions

View File

@ -3,6 +3,23 @@
#include "BayesBall.h" #include "BayesBall.h"
BayesBall::BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph*
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
FactorGraph* FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds) BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{ {

View File

@ -29,19 +29,11 @@ typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall class BayesBall
{ {
public: public:
BayesBall (FactorGraph& fg) BayesBall (FactorGraph& fg);
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&); FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids) static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids);
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private: private:

View File

@ -56,6 +56,8 @@ class BayesBallGraph
public: public:
BayesBallGraph (void) { } BayesBallGraph (void) { }
bool empty (void) const { return nodes_.empty(); }
void addNode (BBNode* n); void addNode (BBNode* n);
void addEdge (VarId vid1, VarId vid2); void addEdge (VarId vid1, VarId vid2);
@ -64,8 +66,6 @@ class BayesBallGraph
BBNode* getNode (VarId vid); BBNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); }
void setIndexes (void); void setIndexes (void);
void clear (void); void clear (void);

View File

@ -9,11 +9,61 @@
#include "Horus.h" #include "Horus.h"
BpLink::BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
void
BpLink::clearResidual (void)
{
residual_ = 0.0;
}
void
BpLink::updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
void
BpLink::updateMessage (void)
{
swap (currMsg_, nextMsg_);
}
string
BpLink::toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
double BeliefProp::accuracy_ = 0.0001; double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000; unsigned BeliefProp::maxIter_ = 1000;
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED; MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{ {
runned_ = false; runned_ = false;
@ -152,6 +202,46 @@ BeliefProp::getFactorJoint (
void
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void
BeliefProp::updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}
void void
BeliefProp::runSolver (void) BeliefProp::runSolver (void)
{ {

View File

@ -24,16 +24,7 @@ enum MsgSchedule {
class BpLink class BpLink
{ {
public: public:
BpLink (FacNode* fn, VarNode* vn) BpLink (FacNode* fn, VarNode* vn);
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
virtual ~BpLink (void) { }; virtual ~BpLink (void) { };
@ -47,26 +38,13 @@ class BpLink
double residual (void) const { return residual_; } double residual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; } void clearResidual (void);
void updateResidual (void) void updateResidual (void);
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
virtual void updateMessage (void) virtual void updateMessage (void);
{
swap (currMsg_, nextMsg_);
}
string toString (void) const string toString (void) const;
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected: protected:
FacNode* fac_; FacNode* fac_;
@ -126,46 +104,15 @@ class BeliefProp : public GroundSolver
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; } static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected: protected:
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const;
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FacNode* fac) const SPNodeInfo* ninf (const FacNode* fac) const;
{
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true) void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true) void calculateMessage (BpLink* link, bool calcResidual = true);
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link) void updateMessage (BpLink* link);
{
link->updateMessage();
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual struct CompareResidual
{ {
@ -213,5 +160,21 @@ class BeliefProp : public GroundSolver
DISALLOW_COPY_AND_ASSIGN (BeliefProp); DISALLOW_COPY_AND_ASSIGN (BeliefProp);
}; };
inline SPNodeInfo*
BeliefProp::ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
inline SPNodeInfo*
BeliefProp::ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
#endif // HORUS_BELIEFPROP_H #endif // HORUS_BELIEFPROP_H

View File

@ -216,6 +216,17 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::ConstraintTree (
const CTChilds& rootChilds,
const LogVars& logVars)
: root_(new CTNode (0, 0, rootChilds)),
logVars_(logVars),
logVarSet_(logVars)
{
}
ConstraintTree::~ConstraintTree (void) ConstraintTree::~ConstraintTree (void)
{ {
CTNode::deleteSubtree (root_); CTNode::deleteSubtree (root_);

View File

@ -59,11 +59,7 @@ class CTNode
bool isLeaf (void) const { return childs_.empty(); } bool isLeaf (void) const { return childs_.empty(); }
CTChilds_::iterator findSymbol (Symbol symb) CTChilds_::iterator findSymbol (Symbol symb);
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
void mergeSubtree (CTNode*, bool = true); void mergeSubtree (CTNode*, bool = true);
@ -91,12 +87,21 @@ class CTNode
DISALLOW_ASSIGN (CTNode); DISALLOW_ASSIGN (CTNode);
}; };
ostream& operator<< (ostream &out, const CTNode&);
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds; typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds;
inline CTChilds::iterator
CTNode::findSymbol (Symbol symb)
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
ostream& operator<< (ostream &out, const CTNode&);
class ConstraintTree class ConstraintTree
{ {
public: public:
@ -110,10 +115,7 @@ class ConstraintTree
ConstraintTree (const ConstraintTree&); ConstraintTree (const ConstraintTree&);
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars) ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars);
: root_(new CTNode (0, 0, rootChilds)),
logVars_(logVars),
logVarSet_(logVars) { }
~ConstraintTree (void); ~ConstraintTree (void);
@ -121,23 +123,11 @@ class ConstraintTree
bool empty (void) const { return root_->childs().empty(); } bool empty (void) const { return root_->childs().empty(); }
const LogVars& logVars (void) const const LogVars& logVars (void) const;
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
const LogVarSet& logVarSet (void) const const LogVarSet& logVarSet (void) const;
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
size_t nrLogVars (void) const size_t nrLogVars (void) const;
{
return logVars_.size();
assert (LogVarSet (logVars_) == logVarSet_);
}
void addTuple (const Tuple&); void addTuple (const Tuple&);
@ -230,5 +220,31 @@ class ConstraintTree
}; };
inline const LogVars&
ConstraintTree::logVars (void) const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
inline const LogVarSet&
ConstraintTree::logVarSet (void) const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
inline size_t
ConstraintTree::nrLogVars (void) const
{
return logVars_.size();
assert (LogVarSet (logVars_) == logVarSet_);
}
#endif // HORUS_CONSTRAINTTREE_H #endif // HORUS_CONSTRAINTTREE_H

View File

@ -27,6 +27,7 @@ typedef unordered_map<VarId, VarCluster*> VarClusterMap;
typedef vector<VarCluster*> VarClusters; typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters; typedef vector<FacCluster*> FacClusters;
template <class T> template <class T>
inline size_t hash_combine (size_t seed, const T& v) inline size_t hash_combine (size_t seed, const T& v)
{ {
@ -35,6 +36,7 @@ inline size_t hash_combine (size_t seed, const T& v)
namespace std { namespace std {
template <typename T1, typename T2> struct hash<std::pair<T1,T2>> template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
{ {
size_t operator() (const std::pair<T1,T2>& p) const size_t operator() (const std::pair<T1,T2>& p) const
@ -56,6 +58,7 @@ namespace std {
return h; return h;
} }
}; };
} }
@ -119,31 +122,15 @@ class CountingBp : public GroundSolver
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; } static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
private: private:
Color getNewColor (void) Color getNewColor (void);
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const Color getColor (const VarNode* vn) const;
{
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const Color getColor (const FacNode* fn) const;
{
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c) void setColor (const VarNode* vn, Color c);
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FacNode* fn, Color c) void setColor (const FacNode* fn, Color c);
{
facColors_[fn->getIndex()] = c;
}
void findIdenticalFactors (void); void findIdenticalFactors (void);
@ -184,5 +171,46 @@ class CountingBp : public GroundSolver
DISALLOW_COPY_AND_ASSIGN (CountingBp); DISALLOW_COPY_AND_ASSIGN (CountingBp);
}; };
inline Color
CountingBp::getNewColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
inline Color
CountingBp::getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
inline Color
CountingBp::getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
inline void
CountingBp::setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
inline void
CountingBp::setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
#endif // HORUS_COUNTINGBP_H #endif // HORUS_COUNTINGBP_H

View File

@ -63,64 +63,17 @@ class ElimGraph
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; } static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
private: private:
void addEdge (EgNode* n1, EgNode* n2) void addEdge (EgNode* n1, EgNode* n2);
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
unsigned getNeighborsCost (const EgNode* n) const unsigned getNeighborsCost (const EgNode* n) const;
{
return n->neighbors().size();
}
unsigned getWeightCost (const EgNode* n) const unsigned getWeightCost (const EgNode* n) const;
{
unsigned cost = 1;
const EGNeighs& neighs = n->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
cost *= neighs[i]->range();
}
return cost;
}
unsigned getFillCost (const EgNode* n) const unsigned getFillCost (const EgNode* n) const;
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost ++;
}
}
}
}
return cost;
}
unsigned getWeightedFillCost (const EgNode* n) const unsigned getWeightedFillCost (const EgNode* n) const;
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
}
return cost;
}
bool neighbors (EgNode* n1, EgNode* n2) const bool neighbors (EgNode* n1, EgNode* n2) const;
{
return n1->isNeighbor (n2);
}
void addNode (EgNode*); void addNode (EgNode*);
@ -139,5 +92,82 @@ class ElimGraph
DISALLOW_COPY_AND_ASSIGN (ElimGraph); DISALLOW_COPY_AND_ASSIGN (ElimGraph);
}; };
inline void
ElimGraph::addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
inline unsigned
ElimGraph::getNeighborsCost (const EgNode* n) const
{
return n->neighbors().size();
}
inline unsigned
ElimGraph::getWeightCost (const EgNode* n) const
{
unsigned cost = 1;
const EGNeighs& neighs = n->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
cost *= neighs[i]->range();
}
return cost;
}
inline unsigned
ElimGraph::getFillCost (const EgNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost ++;
}
}
}
}
return cost;
}
inline unsigned
ElimGraph::getWeightedFillCost (const EgNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
}
return cost;
}
inline bool
ElimGraph::neighbors (EgNode* n1, EgNode* n2) const
{
return n1->isNeighbor (n2);
}
#endif // HORUS_ELIMGRAPH_H #endif // HORUS_ELIMGRAPH_H

View File

@ -34,43 +34,103 @@ class TFactor
void normalize (void) { LogAware::normalize (params_); } void normalize (void) { LogAware::normalize (params_); }
void randomize (void) void randomize (void);
void setParams (const Params& newParams);
size_t indexOf (const T& t) const;
const T& argument (size_t idx) const;
T& argument (size_t idx);
unsigned range (size_t idx) const;
void multiply (TFactor<T>& g);
void sumOutIndex (size_t idx);
void absorveEvidence (const T& arg, unsigned obsIdx);
void reorderArguments (const vector<T> new_args);
bool contains (const T& arg) const;
bool contains (const vector<T>& args) const;
double& operator[] (size_t idx);
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void extend (unsigned range_prod);
void cartesianProduct (
Params::const_iterator first2, Params::const_iterator last2);
};
template <typename T> inline void
TFactor<T>::randomize (void)
{ {
for (size_t i = 0; i < params_.size(); ++i) { for (size_t i = 0; i < params_.size(); ++i) {
params_[i] = (double) std::rand() / RAND_MAX; params_[i] = (double) std::rand() / RAND_MAX;
} }
} }
void setParams (const Params& newParams)
template <typename T> inline void
TFactor<T>::setParams (const Params& newParams)
{ {
params_ = newParams; params_ = newParams;
assert (params_.size() == Util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
size_t indexOf (const T& t) const
template <typename T> inline size_t
TFactor<T>::indexOf (const T& t) const
{ {
return Util::indexOf (args_, t); return Util::indexOf (args_, t);
} }
const T& argument (size_t idx) const
template <typename T> inline const T&
TFactor<T>::argument (size_t idx) const
{ {
assert (idx < args_.size()); assert (idx < args_.size());
return args_[idx]; return args_[idx];
} }
T& argument (size_t idx)
template <typename T> inline T&
TFactor<T>::argument (size_t idx)
{ {
assert (idx < args_.size()); assert (idx < args_.size());
return args_[idx]; return args_[idx];
} }
unsigned range (size_t idx) const
template <typename T> inline unsigned
TFactor<T>::range (size_t idx) const
{ {
assert (idx < ranges_.size()); assert (idx < ranges_.size());
return ranges_[idx]; return ranges_[idx];
} }
void multiply (TFactor<T>& g)
template <typename T> inline void
TFactor<T>::multiply (TFactor<T>& g)
{ {
if (args_ == g.arguments()) { if (args_ == g.arguments()) {
// optimization // optimization
@ -113,7 +173,10 @@ class TFactor
} }
} }
void sumOutIndex (size_t idx)
template <typename T> inline void
TFactor<T>::sumOutIndex (size_t idx)
{ {
assert (idx < args_.size()); assert (idx < args_.size());
assert (args_.size() > 1); assert (args_.size() > 1);
@ -136,7 +199,10 @@ class TFactor
ranges_.erase (ranges_.begin() + idx); ranges_.erase (ranges_.begin() + idx);
} }
void absorveEvidence (const T& arg, unsigned obsIdx)
template <typename T> inline void
TFactor<T>::absorveEvidence (const T& arg, unsigned obsIdx)
{ {
size_t idx = indexOf (arg); size_t idx = indexOf (arg);
assert (idx != args_.size()); assert (idx != args_.size());
@ -156,7 +222,10 @@ class TFactor
ranges_.erase (ranges_.begin() + idx); ranges_.erase (ranges_.begin() + idx);
} }
void reorderArguments (const vector<T> new_args)
template <typename T> inline void
TFactor<T>::reorderArguments (const vector<T> new_args)
{ {
assert (new_args.size() == args_.size()); assert (new_args.size() == args_.size());
if (new_args == args_) { if (new_args == args_) {
@ -179,12 +248,18 @@ class TFactor
ranges_ = new_ranges; ranges_ = new_ranges;
} }
bool contains (const T& arg) const
template <typename T> inline bool
TFactor<T>::contains (const T& arg) const
{ {
return Util::contains (args_, arg); return Util::contains (args_, arg);
} }
bool contains (const vector<T>& args) const
template <typename T> inline bool
TFactor<T>::contains (const vector<T>& args) const
{ {
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) { if (contains (args[i]) == false) {
@ -194,21 +269,19 @@ class TFactor
return true; return true;
} }
double& operator[] (size_t idx)
template <typename T> inline double&
TFactor<T>::operator[] (size_t idx)
{ {
assert (idx < params_.size()); assert (idx < params_.size());
return params_[idx]; return params_[idx];
} }
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private: template <typename T> inline void
void extend (unsigned range_prod) TFactor<T>::extend (unsigned range_prod)
{ {
Params backup = params_; Params backup = params_;
params_.clear(); params_.clear();
@ -222,7 +295,10 @@ class TFactor
} }
} }
void cartesianProduct (
template <typename T> inline void
TFactor<T>::cartesianProduct (
Params::const_iterator first2, Params::const_iterator first2,
Params::const_iterator last2) Params::const_iterator last2)
{ {
@ -247,8 +323,6 @@ class TFactor
} }
} }
};
class Factor : public TFactor<VarId> class Factor : public TFactor<VarId>

View File

@ -82,11 +82,7 @@ class FactorGraph
size_t nrFacNodes (void) const { return facNodes_.size(); } size_t nrFacNodes (void) const { return facNodes_.size(); }
VarNode* getVarNode (VarId vid) const VarNode* getVarNode (VarId vid) const;
{
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*); void readFromUaiFormat (const char*);
@ -166,6 +162,15 @@ class FactorGraph
inline VarNode*
FactorGraph::getVarNode (VarId vid) const
{
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
struct sortByVarId struct sortByVarId
{ {
bool operator()(VarNode* vn1, VarNode* vn2) { bool operator()(VarNode* vn1, VarNode* vn2) {

View File

@ -30,6 +30,7 @@ class GroundSolver
protected: protected:
const FactorGraph& fg; const FactorGraph& fg;
private:
DISALLOW_COPY_AND_ASSIGN (GroundSolver); DISALLOW_COPY_AND_ASSIGN (GroundSolver);
}; };

View File

@ -118,14 +118,6 @@ HistogramSet::getNumAssigns (unsigned N, unsigned R)
ostream& operator<< (ostream &os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}
unsigned unsigned
HistogramSet::maxCount (size_t idx) const HistogramSet::maxCount (size_t idx) const
{ {
@ -144,3 +136,11 @@ HistogramSet::clearAfter (size_t idx)
std::fill (hist_.begin() + idx + 1, hist_.end(), 0); std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
} }
ostream& operator<< (ostream &os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}

View File

@ -33,8 +33,6 @@ class HistogramSet
static vector<double> getNumAssigns (unsigned, unsigned); static vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
private: private:
unsigned maxCount (size_t) const; unsigned maxCount (size_t) const;
@ -43,6 +41,8 @@ class HistogramSet
unsigned size_; unsigned size_;
Histogram hist_; Histogram hist_;
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
DISALLOW_COPY_AND_ASSIGN (HistogramSet); DISALLOW_COPY_AND_ASSIGN (HistogramSet);
}; };

View File

@ -13,7 +13,46 @@
class Indexer class Indexer
{ {
public: public:
Indexer (const Ranges& ranges, bool calcOffsets = true) Indexer (const Ranges& ranges, bool calcOffsets = true);
void increment (void);
void incrementDimension (size_t dim);
void incrementExceptDimension (size_t dim);
Indexer& operator++ (void);
operator size_t (void) const;
unsigned operator[] (size_t dim) const;
bool valid (void) const;
void reset (void);
void resetDimension (size_t dim);
size_t size (void) const;
private:
void calculateOffsets (void);
size_t index_;
Ranges indices_;
const Ranges& ranges_;
size_t size_;
vector<size_t> offsets_;
friend std::ostream& operator<< (std::ostream&, const Indexer&);
DISALLOW_COPY_AND_ASSIGN (Indexer);
};
inline
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(Util::sizeExpected (ranges)) size_(Util::sizeExpected (ranges))
{ {
@ -22,7 +61,10 @@ class Indexer
} }
} }
void increment (void)
inline void
Indexer::increment (void)
{ {
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++; indices_[i] ++;
@ -35,7 +77,10 @@ class Indexer
index_ ++; index_ ++;
} }
void incrementDimension (size_t dim)
inline void
Indexer::incrementDimension (size_t dim)
{ {
assert (dim < ranges_.size()); assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
@ -44,7 +89,10 @@ class Indexer
index_ += offsets_[dim]; index_ += offsets_[dim];
} }
void incrementExceptDimension (size_t dim)
inline void
Indexer::incrementExceptDimension (size_t dim)
{ {
assert (ranges_.size() == offsets_.size()); assert (ranges_.size() == offsets_.size());
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
@ -62,50 +110,71 @@ class Indexer
index_ = size_; index_ = size_;
} }
Indexer& operator++ (void)
inline Indexer&
Indexer::operator++ (void)
{ {
increment(); increment();
return *this; return *this;
} }
operator size_t (void) const
inline
Indexer::operator size_t (void) const
{ {
return index_; return index_;
} }
unsigned operator[] (size_t dim) const
inline unsigned
Indexer::operator[] (size_t dim) const
{ {
assert (valid()); assert (valid());
assert (dim < ranges_.size()); assert (dim < ranges_.size());
return indices_[dim]; return indices_[dim];
} }
bool valid (void) const
inline bool
Indexer::valid (void) const
{ {
return index_ < size_; return index_ < size_;
} }
void reset (void)
inline void
Indexer::reset (void)
{ {
std::fill (indices_.begin(), indices_.end(), 0); std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0; index_ = 0;
} }
void resetDimension (size_t dim)
inline void
Indexer::resetDimension (size_t dim)
{ {
indices_[dim] = 0; indices_[dim] = 0;
index_ -= offsets_[dim] * ranges_[dim]; index_ -= offsets_[dim] * ranges_[dim];
} }
size_t size (void) const
inline size_t
Indexer::size (void) const
{ {
return size_ ; return size_ ;
} }
friend std::ostream& operator<< (std::ostream&, const Indexer&);
private:
void calculateOffsets (void) inline void
Indexer::calculateOffsets (void)
{ {
size_t prod = 1; size_t prod = 1;
offsets_.resize (ranges_.size()); offsets_.resize (ranges_.size());
@ -115,15 +184,6 @@ class Indexer
} }
} }
size_t index_;
Ranges indices_;
const Ranges& ranges_;
size_t size_;
vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (Indexer);
};
inline std::ostream& inline std::ostream&
@ -141,7 +201,43 @@ operator<< (std::ostream& os, const Indexer& indexer)
class MapIndexer class MapIndexer
{ {
public: public:
MapIndexer (const Ranges& ranges, const vector<bool>& mask) MapIndexer (const Ranges& ranges, const vector<bool>& mask);
MapIndexer (const Ranges& ranges, size_t dim);
template <typename T>
MapIndexer (
const vector<T>& allArgs,
const Ranges& allRanges,
const vector<T>& wantedArgs,
const Ranges& wantedRanges);
MapIndexer& operator++ (void);
operator size_t (void) const;
unsigned operator[] (size_t dim) const;
bool valid (void) const;
void reset (void);
private:
size_t index_;
Ranges indices_;
const Ranges& ranges_;
bool valid_;
vector<size_t> offsets_;
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
};
inline
MapIndexer::MapIndexer (const Ranges& ranges, const vector<bool>& mask)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -156,7 +252,10 @@ class MapIndexer
assert (ranges.size() == mask.size()); assert (ranges.size() == mask.size());
} }
MapIndexer (const Ranges& ranges, size_t dim)
inline
MapIndexer::MapIndexer (const Ranges& ranges, size_t dim)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true) valid_(true)
{ {
@ -170,8 +269,10 @@ class MapIndexer
} }
} }
template <typename T>
MapIndexer (
template <typename T> inline
MapIndexer::MapIndexer (
const vector<T>& allArgs, const vector<T>& allArgs,
const Ranges& allRanges, const Ranges& allRanges,
const vector<T>& wantedArgs, const vector<T>& wantedArgs,
@ -192,7 +293,10 @@ class MapIndexer
} }
} }
MapIndexer& operator++ (void)
inline MapIndexer&
MapIndexer::operator++ (void)
{ {
assert (valid_); assert (valid_);
for (size_t i = ranges_.size(); i-- > 0; ) { for (size_t i = ranges_.size(); i-- > 0; ) {
@ -209,42 +313,42 @@ class MapIndexer
return *this; return *this;
} }
operator size_t (void) const
inline
MapIndexer::operator size_t (void) const
{ {
assert (valid()); assert (valid());
return index_; return index_;
} }
unsigned operator[] (size_t dim) const
inline unsigned
MapIndexer::operator[] (size_t dim) const
{ {
assert (valid()); assert (valid());
assert (dim < ranges_.size()); assert (dim < ranges_.size());
return indices_[dim]; return indices_[dim];
} }
bool valid (void) const
inline bool
MapIndexer::valid (void) const
{ {
return valid_; return valid_;
} }
void reset (void)
inline void
MapIndexer::reset (void)
{ {
std::fill (indices_.begin(), indices_.end(), 0); std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0; index_ = 0;
} }
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
private:
size_t index_;
Ranges indices_;
const Ranges& ranges_;
bool valid_;
vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
};
inline std::ostream& inline std::ostream&
@ -257,6 +361,5 @@ operator<< (std::ostream &os, const MapIndexer& indexer)
return os; return os;
} }
#endif // HORUS_INDEXER_H #endif // HORUS_INDEXER_H

View File

@ -26,10 +26,10 @@ class Symbol
static Symbol invalid (void) { return Symbol(); } static Symbol invalid (void) { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s);
private: private:
unsigned id_; unsigned id_;
friend ostream& operator<< (ostream &os, const Symbol& s);
}; };
@ -42,35 +42,50 @@ class LogVar
operator unsigned (void) const { return id_; } operator unsigned (void) const { return id_; }
LogVar& operator++ (void) LogVar& operator++ (void);
bool valid (void) const;
private:
unsigned id_;
friend ostream& operator<< (ostream &os, const LogVar& X);
};
inline LogVar&
LogVar::operator++ (void)
{ {
assert (valid()); assert (valid());
id_ ++; id_ ++;
return *this; return *this;
} }
bool valid (void) const
inline bool
LogVar::valid (void) const
{ {
return id_ != Util::maxUnsigned(); return id_ != Util::maxUnsigned();
} }
friend ostream& operator<< (ostream &os, const LogVar& X);
private:
unsigned id_;
};
namespace std { namespace std {
template <> struct hash<Symbol> { template <> struct hash<Symbol> {
size_t operator() (const Symbol& s) const { size_t operator() (const Symbol& s) const {
return std::hash<unsigned>() (s); return std::hash<unsigned>() (s);
}}; }
};
template <> struct hash<LogVar> { template <> struct hash<LogVar> {
size_t operator() (const LogVar& X) const { size_t operator() (const LogVar& X) const {
return std::hash<unsigned>() (X); return std::hash<unsigned>() (X);
}}; }
};
}; };
@ -87,8 +102,11 @@ ostream& operator<< (ostream &os, const Tuple& t);
namespace LiftedUtils { namespace LiftedUtils {
Symbol getSymbol (const string&); Symbol getSymbol (const string&);
void printSymbolDictionary (void); void printSymbolDictionary (void);
} }
@ -108,11 +126,11 @@ class Ground
bool isAtom (void) const { return args_.empty(); } bool isAtom (void) const { return args_.empty(); }
friend ostream& operator<< (ostream &os, const Ground& gr);
private: private:
Symbol functor_; Symbol functor_;
Symbols args_; Symbols args_;
friend ostream& operator<< (ostream &os, const Ground& gr);
}; };
typedef vector<Ground> Grounds; typedef vector<Ground> Grounds;
@ -122,19 +140,48 @@ typedef vector<Ground> Grounds;
class Substitution class Substitution
{ {
public: public:
void add (LogVar X_old, LogVar X_new) void add (LogVar X_old, LogVar X_new);
void rename (LogVar X_old, LogVar X_new);
LogVar newNameFor (LogVar X) const;
bool containsReplacementFor (LogVar X) const;
size_t nrReplacements (void) const;
LogVars getDiscardedLogVars (void) const;
private:
unordered_map<LogVar, LogVar> subs_;
friend ostream& operator<< (ostream &os, const Substitution& theta);
};
inline void
Substitution::add (LogVar X_old, LogVar X_new)
{ {
assert (Util::contains (subs_, X_old) == false); assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new)); subs_.insert (make_pair (X_old, X_new));
} }
void rename (LogVar X_old, LogVar X_new)
inline void
Substitution::rename (LogVar X_old, LogVar X_new)
{ {
assert (Util::contains (subs_, X_old)); assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new; subs_.find (X_old)->second = X_new;
} }
LogVar newNameFor (LogVar X) const
inline LogVar
Substitution::newNameFor (LogVar X) const
{ {
unordered_map<LogVar, LogVar>::const_iterator it; unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.find (X); it = subs_.find (X);
@ -144,21 +191,22 @@ class Substitution
return X; return X;
} }
bool containsReplacementFor (LogVar X) const
inline bool
Substitution::containsReplacementFor (LogVar X) const
{ {
return Util::contains (subs_, X); return Util::contains (subs_, X);
} }
size_t nrReplacements (void) const { return subs_.size(); }
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta); inline size_t
Substitution::nrReplacements (void) const
{
return subs_.size();
}
private:
unordered_map<LogVar, LogVar> subs_;
};
#endif // HORUS_LIFTEDUTILS_H #endif // HORUS_LIFTEDUTILS_H

View File

@ -51,12 +51,12 @@ class Literal
LogVarSet posCountedLvs = LogVarSet(), LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const; LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
private: private:
LiteralId lid_; LiteralId lid_;
LogVars logVars_; LogVars logVars_;
bool negated_; bool negated_;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
}; };
typedef vector<Literal> Literals; typedef vector<Literal> Literals;
@ -138,8 +138,6 @@ class Clause
static void deleteClauses (vector<Clause*>& clauses); static void deleteClauses (vector<Clause*>& clauses);
friend std::ostream& operator<< (ostream &os, const Clause& clause);
private: private:
LogVarSet getLogVarSetExcluding (size_t idx) const; LogVarSet getLogVarSetExcluding (size_t idx) const;
@ -149,6 +147,8 @@ class Clause
LogVarSet negCountedLvs_; LogVarSet negCountedLvs_;
ConstraintTree constr_; ConstraintTree constr_;
friend std::ostream& operator<< (ostream &os, const Clause& clause);
DISALLOW_ASSIGN (Clause); DISALLOW_ASSIGN (Clause);
}; };
@ -185,11 +185,11 @@ class LitLvTypes
void setAllFullLogVars (void) { void setAllFullLogVars (void) {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); } std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
private: private:
LiteralId lid_; LiteralId lid_;
LogVarTypes lvTypes_; LogVarTypes lvTypes_;
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
}; };
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet; typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;

View File

@ -131,6 +131,22 @@ ProbFormula::getNewGroup (void)
ObservedFormula::ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a)
{
}
ObservedFormula::ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
ostream& operator<< (ostream &os, const ObservedFormula& of) ostream& operator<< (ostream &os, const ObservedFormula& of)
{ {
os << of.functor_ << "/" << of.arity_; os << of.functor_ << "/" << of.arity_;

View File

@ -58,10 +58,10 @@ class ProbFormula
static PrvGroup getNewGroup (void); static PrvGroup getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2); friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
private: private:
Symbol functor_; Symbol functor_;
LogVars logVars_; LogVars logVars_;
@ -77,14 +77,9 @@ typedef vector<ProbFormula> ProbFormulas;
class ObservedFormula class ObservedFormula
{ {
public: public:
ObservedFormula (Symbol f, unsigned a, unsigned ev) ObservedFormula (Symbol f, unsigned a, unsigned ev);
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple) ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple);
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
Symbol functor (void) const { return functor_; } Symbol functor (void) const { return functor_; }
@ -100,13 +95,13 @@ class ObservedFormula
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); } void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private: private:
Symbol functor_; Symbol functor_;
unsigned arity_; unsigned arity_;
unsigned evidence_; unsigned evidence_;
ConstraintTree constr_; ConstraintTree constr_;
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
}; };
typedef vector<ObservedFormula> ObservedFormulas; typedef vector<ObservedFormula> ObservedFormulas;

View File

@ -25,191 +25,76 @@ class TinySet
TinySet (const T& t, const Compare& cmp = Compare()) TinySet (const T& t, const Compare& cmp = Compare())
: vec_(1, t), cmp_(cmp) { } : vec_(1, t), cmp_(cmp) { }
TinySet (const vector<T>& elements, const Compare& cmp = Compare()) TinySet (const vector<T>& elements, const Compare& cmp = Compare());
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
iterator insert (const T& t) iterator insert (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
void insert_sorted (const T& t) void insert_sorted (const T& t);
{
vec_.push_back (t);
assert (consistent());
}
void remove (const T& t) void remove (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
const_iterator find (const T& t) const const_iterator find (const T& t) const;
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
iterator find (const T& t) iterator find (const T& t);
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
/* set union */ /* set union */
TinySet operator| (const TinySet& s) const TinySet operator| (const TinySet& s) const;
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set intersection */ /* set intersection */
TinySet operator& (const TinySet& s) const TinySet operator& (const TinySet& s) const;
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set difference */ /* set difference */
TinySet operator- (const TinySet& s) const TinySet operator- (const TinySet& s) const;
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
TinySet& operator|= (const TinySet& s) TinySet& operator|= (const TinySet& s);
{
return *this = (*this | s);
}
TinySet& operator&= (const TinySet& s) TinySet& operator&= (const TinySet& s);
{
return *this = (*this & s);
}
TinySet& operator-= (const TinySet& s) TinySet& operator-= (const TinySet& s);
{
return *this = (*this - s);
}
bool contains (const T& t) const bool contains (const T& t) const;
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
bool contains (const TinySet& s) const bool contains (const TinySet& s) const;
{
return std::includes (
vec_.begin(),
vec_.end(),
s.vec_.begin(),
s.vec_.end(),
cmp_);
}
bool in (const TinySet& s) const bool in (const TinySet& s) const;
{
return std::includes (
s.vec_.begin(),
s.vec_.end(),
vec_.begin(),
vec_.end(),
cmp_);
}
bool intersects (const TinySet& s) const bool intersects (const TinySet& s) const;
{
return (*this & s).size() > 0;
}
const T& operator[] (typename vector<T>::size_type i) const const T& operator[] (typename vector<T>::size_type i) const;
{
return vec_[i];
}
T& operator[] (typename vector<T>::size_type i) T& operator[] (typename vector<T>::size_type i);
{
return vec_[i];
}
T front (void) const T front (void) const;
{
return vec_.front();
}
T& front (void) T& front (void);
{
return vec_.front();
}
T back (void) const T back (void) const;
{
return vec_.back();
}
T& back (void) T& back (void);
{
return vec_.back();
}
const vector<T>& elements (void) const const vector<T>& elements (void) const;
{
return vec_;
}
bool empty (void) const bool empty (void) const;
{
return vec_.empty();
}
typename vector<T>::size_type size (void) const typename vector<T>::size_type size (void) const;
{
return vec_.size();
}
void clear (void) void clear (void);
{
vec_.clear();
}
void reserve (typename vector<T>::size_type size) void reserve (typename vector<T>::size_type size);
{
vec_.reserve (size);
}
iterator begin (void) { return vec_.begin(); } iterator begin (void) { return vec_.begin(); }
iterator end (void) { return vec_.end(); } iterator end (void) { return vec_.end(); }
const_iterator begin (void) const { return vec_.begin(); } const_iterator begin (void) const { return vec_.begin(); }
const_iterator end (void) const { return vec_.end(); } const_iterator end (void) const { return vec_.end(); }
private:
iterator unique_cmp (iterator first, iterator last);
bool consistent (void) const;
vector<T> vec_;
Compare cmp_;
friend bool operator== (const TinySet& s1, const TinySet& s2) friend bool operator== (const TinySet& s1, const TinySet& s2)
{ {
return s1.vec_ == s2.vec_; return s1.vec_ == s2.vec_;
@ -231,8 +116,269 @@ class TinySet
return out; return out;
} }
private: };
iterator unique_cmp (iterator first, iterator last)
template <typename T, typename C> inline
TinySet<T,C>::TinySet (const vector<T>& elements, const C& cmp)
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::insert (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
template <typename T, typename C> inline void
TinySet<T,C>::insert_sorted (const T& t)
{
vec_.push_back (t);
assert (consistent());
}
template <typename T, typename C> inline void
TinySet<T,C>::remove (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
template <typename T, typename C> inline typename TinySet<T,C>::const_iterator
TinySet<T,C>::find (const T& t) const
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::find (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
/* set union */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator| (const TinySet& s) const
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set intersection */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator& (const TinySet& s) const
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set difference */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator- (const TinySet& s) const
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator|= (const TinySet& s)
{
return *this = (*this | s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator&= (const TinySet& s)
{
return *this = (*this & s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator-= (const TinySet& s)
{
return *this = (*this - s);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const T& t) const
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const TinySet& s) const
{
return std::includes (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::in (const TinySet& s) const
{
return std::includes (
s.vec_.begin(), s.vec_.end(),
vec_.begin(), vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::intersects (const TinySet& s) const
{
return (*this & s).size() > 0;
}
template <typename T, typename C> inline const T&
TinySet<T,C>::operator[] (typename vector<T>::size_type i) const
{
return vec_[i];
}
template <typename T, typename C> inline T&
TinySet<T,C>::operator[] (typename vector<T>::size_type i)
{
return vec_[i];
}
template <typename T, typename C> inline T
TinySet<T,C>::front (void) const
{
return vec_.front();
}
template <typename T, typename C> inline T&
TinySet<T,C>::front (void)
{
return vec_.front();
}
template <typename T, typename C> inline T
TinySet<T,C>::back (void) const
{
return vec_.back();
}
template <typename T, typename C> inline T&
TinySet<T,C>::back (void)
{
return vec_.back();
}
template <typename T, typename C> inline const vector<T>&
TinySet<T,C>::elements (void) const
{
return vec_;
}
template <typename T, typename C> inline bool
TinySet<T,C>::empty (void) const
{
return vec_.empty();
}
template <typename T, typename C> inline typename vector<T>::size_type
TinySet<T,C>::size (void) const
{
return vec_.size();
}
template <typename T, typename C> inline void
TinySet<T,C>::clear (void)
{
vec_.clear();
}
template <typename T, typename C> inline void
TinySet<T,C>::reserve (typename vector<T>::size_type size)
{
vec_.reserve (size);
}
template <typename T, typename C> typename TinySet<T,C>::iterator
TinySet<T,C>::unique_cmp (iterator first, iterator last)
{ {
if (first == last) { if (first == last) {
return last; return last;
@ -246,7 +392,10 @@ class TinySet
return ++result; return ++result;
} }
bool consistent (void) const
template <typename T, typename C> inline bool
TinySet<T,C>::consistent (void) const
{ {
typename vector<T>::size_type i; typename vector<T>::size_type i;
for (i = 0; i < vec_.size() - 1; i++) { for (i = 0; i < vec_.size() - 1; i++) {
@ -257,9 +406,5 @@ class TinySet
return true; return true;
} }
vector<T> vec_;
Compare cmp_;
};
#endif // HORUS_TINYSET_H #endif // HORUS_TINYSET_H

View File

@ -21,7 +21,9 @@ using namespace std;
namespace { namespace {
const double NEG_INF = -std::numeric_limits<double>::infinity(); const double NEG_INF = -std::numeric_limits<double>::infinity();
}; };

View File

@ -73,3 +73,38 @@ Var::states (void) const
return states; return states;
} }
inline void
Var::addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
inline VarInfo
Var::getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
inline bool
Var::varsHaveInfo (void)
{
return varsInfo_.empty() == false;
}
inline void
Var::clearVarsInfo (void)
{
varsInfo_.clear();
}

View File

@ -39,23 +39,13 @@ class Var
void setIndex (size_t idx) { index_ = idx; } void setIndex (size_t idx) { index_ = idx; }
bool hasEvidence (void) const bool hasEvidence (void) const;
{
return evidence_ != Constants::NO_EVIDENCE;
}
operator size_t (void) const { return index_; } operator size_t (void) const;
bool operator== (const Var& var) const bool operator== (const Var& var) const;
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator!= (const Var& var) const bool operator!= (const Var& var) const;
{
return !(*this == var);
}
bool isValidState (int); bool isValidState (int);
@ -66,27 +56,13 @@ class Var
States states (void) const; States states (void) const;
static void addVarInfo ( static void addVarInfo (
VarId vid, string label, const States& states) VarId vid, string label, const States& states);
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInfo (VarId vid) static VarInfo getVarInfo (VarId vid);
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo (void) static bool varsHaveInfo (void);
{
return varsInfo_.empty() == false;
}
static void clearVarsInfo (void) static void clearVarsInfo (void);
{
varsInfo_.clear();
}
private: private:
VarId varId_; VarId varId_;
@ -95,8 +71,41 @@ class Var
size_t index_; size_t index_;
static unordered_map<VarId, VarInfo> varsInfo_; static unordered_map<VarId, VarInfo> varsInfo_;
}; };
inline bool
Var::hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
inline
Var::operator size_t (void) const
{
return index_;
}
inline bool
Var::operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
inline bool
Var::operator!= (const Var& var) const
{
return !(*this == var);
}
#endif // HORUS_VAR_H #endif // HORUS_VAR_H

View File

@ -3,6 +3,7 @@
#include "BeliefProp.h" #include "BeliefProp.h"
class WeightedLink : public BpLink class WeightedLink : public BpLink
{ {
public: public:
@ -16,12 +17,7 @@ class WeightedLink : public BpLink
const Params& powMessage (void) const { return pwdMsg_; } const Params& powMessage (void) const { return pwdMsg_; }
void updateMessage (void) void updateMessage (void);
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_);
}
private: private:
DISALLOW_COPY_AND_ASSIGN (WeightedLink); DISALLOW_COPY_AND_ASSIGN (WeightedLink);
@ -33,6 +29,16 @@ class WeightedLink : public BpLink
inline void
WeightedLink::updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_);
}
class WeightedBp : public BeliefProp class WeightedBp : public BeliefProp
{ {
public: public: