#ifndef HORUS_CFACTORGRAPH_H #define HORUS_CFACTORGRAPH_H #include #include "FactorGraph.h" #include "Factor.h" #include "Shared.h" class VarCluster; class FacCluster; class Distribution; class Signature; class SignatureHash; typedef long Color; typedef unordered_map > VarColorMap; typedef unordered_map DistColorMap; typedef unordered_map VarId2VarCluster; typedef vector VarClusterSet; typedef vector FacClusterSet; typedef unordered_map VarSignMap; typedef unordered_map FacSignMap; struct Signature { Signature (unsigned size) { colors.resize (size); } bool operator< (const Signature& sig) const { if (colors.size() < sig.colors.size()) { return true; } else if (colors.size() > sig.colors.size()) { return false; } else { for (unsigned i = 0; i < colors.size(); i++) { if (colors[i] < sig.colors[i]) { return true; } else if (colors[i] > sig.colors[i]) { return false; } } } return false; } bool operator== (const Signature& sig) const { if (colors.size() != sig.colors.size()) { return false; } for (unsigned i = 0; i < colors.size(); i++) { if (colors[i] != sig.colors[i]) { return false; } } return true; } vector colors; }; struct SignatureHash { size_t operator() (const Signature &sig) const { size_t val = hash()(sig.colors.size()); for (unsigned i = 0; i < sig.colors.size(); i++) { val ^= hash()(sig.colors[i]); } return val; } }; class VarCluster { public: VarCluster (const FgVarSet& vs) { for (unsigned i = 0; i < vs.size(); i++) { groundVars_.push_back (vs[i]); } } void addFacCluster (FacCluster* fc) { factorClusters_.push_back (fc); } const FacClusterSet& getFacClusters (void) const { return factorClusters_; } FgVarNode* getRepresentativeVariable (void) const { return representVar_; } void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; } const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; } private: FgVarSet groundVars_; FacClusterSet factorClusters_; FgVarNode* representVar_; }; class FacCluster { public: FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs) { groundFactors_ = groundFactors; varClusters_ = vcs; for (unsigned i = 0; i < varClusters_.size(); i++) { varClusters_[i]->addFacCluster (this); } } const VarClusterSet& getVarClusters (void) const { return varClusters_; } bool containsGround (const FgFacNode* fn) { for (unsigned i = 0; i < groundFactors_.size(); i++) { if (groundFactors_[i] == fn) { return true; } } return false; } FgFacNode* getRepresentativeFactor (void) const { return representFactor_; } void setRepresentativeFactor (FgFacNode* fn) { representFactor_ = fn; } const FgFacSet& getGroundFactors (void) const { return groundFactors_; } private: FgFacSet groundFactors_; VarClusterSet varClusters_; FgFacNode* representFactor_; }; class CFactorGraph { public: CFactorGraph (const FactorGraph&); ~CFactorGraph (void); FactorGraph* getCompressedFactorGraph (void); unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const; FgVarNode* getEquivalentVariable (VarId vid) { VarCluster* vc = vid2VarCluster_.find (vid)->second; return vc->getRepresentativeVariable(); } const VarClusterSet& getVariableClusters (void) { return varClusters_; } const FacClusterSet& getFacClusters (void) { return factorClusters_; } static void enableCheckForIdenticalFactors (void) { checkForIdenticalFactors_ = true; } static void disableCheckForIdenticalFactors (void) { checkForIdenticalFactors_ = false; } private: void setInitialColors (void); void createGroups (void); void createClusters (const VarSignMap&, const FacSignMap&); const Signature& getSignature (const FgVarNode*); const Signature& getSignature (const FgFacNode*); void printGroups (const VarSignMap&, const FacSignMap&) const; Color getFreeColor (void) { ++ freeColor_; return freeColor_ - 1; } Color getColor (const FgVarNode* vn) const { return varColors_[vn->getIndex()]; } Color getColor (const FgFacNode* fn) const { return factorColors_[fn->getIndex()]; } void setColor (const FgVarNode* vn, Color c) { varColors_[vn->getIndex()] = c; } void setColor (const FgFacNode* fn, Color c) { factorColors_[fn->getIndex()] = c; } VarCluster* getVariableCluster (VarId vid) const { return vid2VarCluster_.find (vid)->second; } Color freeColor_; vector varColors_; vector factorColors_; vector varSignatures_; vector factorSignatures_; VarClusterSet varClusters_; FacClusterSet factorClusters_; VarId2VarCluster vid2VarCluster_; const FactorGraph* groundFg_; bool static checkForIdenticalFactors_; }; #endif // HORUS_CFACTORGRAPH_H