From 689244a0d8818bb072f41c6a0929c38654a1f030 Mon Sep 17 00:00:00 2001 From: Tiago Gomes Date: Thu, 26 Apr 2012 00:54:06 +0100 Subject: [PATCH] improve cbp by supporting factors connected to a single var two or more times --- packages/CLPBN/clpbn/bp/BpSolver.cpp | 14 +- packages/CLPBN/clpbn/bp/BpSolver.h | 2 +- packages/CLPBN/clpbn/bp/CFactorGraph.cpp | 144 ++++++++---------- packages/CLPBN/clpbn/bp/CFactorGraph.h | 122 +++++---------- packages/CLPBN/clpbn/bp/CbpSolver.cpp | 179 +++++++++++++++++------ packages/CLPBN/clpbn/bp/CbpSolver.h | 11 +- packages/CLPBN/clpbn/bp/Factor.cpp | 4 +- packages/CLPBN/clpbn/bp/FactorGraph.cpp | 2 +- 8 files changed, 252 insertions(+), 226 deletions(-) diff --git a/packages/CLPBN/clpbn/bp/BpSolver.cpp b/packages/CLPBN/clpbn/bp/BpSolver.cpp index 49a203388..3271371e2 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/BpSolver.cpp @@ -325,7 +325,6 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) } } } - Factor result (src->factor().arguments(), src->factor().ranges(), msgProduct); result.multiply (src->factor()); @@ -336,18 +335,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) } result.sumOutAllExcept (dst->varId()); if (Constants::DEBUG >= 5) { - cout << " marginalized: " ; - cout << result.params() << endl; + cout << " marginalized: " << result.params() << endl; } - const Params& resultParams = result.params(); - Params& message = link->getNextMessage(); - for (unsigned i = 0; i < resultParams.size(); i++) { - message[i] = resultParams[i]; - } - LogAware::normalize (message); + link->getNextMessage() = result.params(); + LogAware::normalize (link->getNextMessage()); if (Constants::DEBUG >= 5) { cout << " curr msg: " << link->getMessage() << endl; - cout << " next msg: " << message << endl; + cout << " next msg: " << link->getNextMessage() << endl; } } diff --git a/packages/CLPBN/clpbn/bp/BpSolver.h b/packages/CLPBN/clpbn/bp/BpSolver.h index 388346ddd..688b6bb15 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.h +++ b/packages/CLPBN/clpbn/bp/BpSolver.h @@ -183,7 +183,7 @@ class BpSolver : public Solver bool converged (void); - void printLinkInformation (void) const; + virtual void printLinkInformation (void) const; }; #endif // HORUS_BPSOLVER_H diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.cpp b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp index 70e98db74..e24bc0be0 100644 --- a/packages/CLPBN/clpbn/bp/CFactorGraph.cpp +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp @@ -7,20 +7,6 @@ bool CFactorGraph::checkForIdenticalFactors = true; CFactorGraph::CFactorGraph (const FactorGraph& fg) : freeColor_(0), groundFg_(&fg) { - const VarNodes& varNodes = fg.varNodes(); - varSignatures_.reserve (varNodes.size()); - for (unsigned i = 0; i < varNodes.size(); i++) { - unsigned c = (varNodes[i]->neighbors().size() * 2) + 1; - varSignatures_.push_back (Signature (c)); - } - const FacNodes& facNodes = fg.facNodes(); - facSignatures_.reserve (facNodes.size()); - for (unsigned i = 0; i < facNodes.size(); i++) { - unsigned c = facNodes[i]->neighbors().size() + 1; - facSignatures_.push_back (Signature (c)); - } - varColors_.resize (varNodes.size()); - facColors_.resize (facNodes.size()); findIdenticalFactors(); setInitialColors(); createGroups(); @@ -77,6 +63,9 @@ CFactorGraph::findIdenticalFactors() void CFactorGraph::setInitialColors (void) { + varColors_.resize (groundFg_->nrVarNodes()); + facColors_.resize (groundFg_->nrFacNodes()); + // create the initial variable colors VarColorMap colorMap; const VarNodes& varNodes = groundFg_->varNodes(); @@ -127,31 +116,11 @@ CFactorGraph::createGroups (void) while (groupsHaveChanged || nIters == 1) { nIters ++; - unsigned prevFactorGroupsSize = facGroups.size(); - facGroups.clear(); - // set a new color to the factors with the same signature - for (unsigned i = 0; i < facNodes.size(); i++) { - const Signature& signature = getSignature (facNodes[i]); - FacSignMap::iterator it = facGroups.find (signature); - if (it == facGroups.end()) { - it = facGroups.insert (make_pair (signature, FacNodes())).first; - } - it->second.push_back (facNodes[i]); - } - for (FacSignMap::iterator it = facGroups.begin(); - it != facGroups.end(); it++) { - Color newColor = getFreeColor(); - FacNodes& groupMembers = it->second; - for (unsigned i = 0; i < groupMembers.size(); i++) { - setColor (groupMembers[i], newColor); - } - } - // set a new color to the variables with the same signature unsigned prevVarGroupsSize = varGroups.size(); varGroups.clear(); for (unsigned i = 0; i < varNodes.size(); i++) { - const Signature& signature = getSignature (varNodes[i]); + const VarSignature& signature = getSignature (varNodes[i]); VarSignMap::iterator it = varGroups.find (signature); if (it == varGroups.end()) { it = varGroups.insert (make_pair (signature, VarNodes())).first; @@ -167,6 +136,26 @@ CFactorGraph::createGroups (void) } } + unsigned prevFactorGroupsSize = facGroups.size(); + facGroups.clear(); + // set a new color to the factors with the same signature + for (unsigned i = 0; i < facNodes.size(); i++) { + const FacSignature& signature = getSignature (facNodes[i]); + FacSignMap::iterator it = facGroups.find (signature); + if (it == facGroups.end()) { + it = facGroups.insert (make_pair (signature, FacNodes())).first; + } + it->second.push_back (facNodes[i]); + } + for (FacSignMap::iterator it = facGroups.begin(); + it != facGroups.end(); it++) { + Color newColor = getFreeColor(); + FacNodes& groupMembers = it->second; + for (unsigned i = 0; i < groupMembers.size(); i++) { + setColor (groupMembers[i], newColor); + } + } + groupsHaveChanged = prevVarGroupsSize != varGroups.size() || prevFactorGroupsSize != facGroups.size(); } @@ -183,7 +172,7 @@ CFactorGraph::createClusters ( { varClusters_.reserve (varGroups.size()); for (VarSignMap::const_iterator it = varGroups.begin(); - it != varGroups.end(); it++) { + it != varGroups.end(); it++) { const VarNodes& groupVars = it->second; VarCluster* vc = new VarCluster (groupVars); for (unsigned i = 0; i < groupVars.size(); i++) { @@ -194,7 +183,7 @@ CFactorGraph::createClusters ( facClusters_.reserve (facGroups.size()); for (FacSignMap::const_iterator it = facGroups.begin(); - it != facGroups.end(); it++) { + it != facGroups.end(); it++) { FacNode* groupFactor = it->second[0]; const VarNodes& neighs = groupFactor->neighbors(); VarClusters varClusters; @@ -209,67 +198,61 @@ CFactorGraph::createClusters ( -const Signature& +VarSignature CFactorGraph::getSignature (const VarNode* varNode) { - Signature& sign = varSignatures_[varNode->getIndex()]; - Colors::iterator it = sign.colors.begin(); const FacNodes& neighs = varNode->neighbors(); + VarSignature sign; + sign.reserve (neighs.size() + 1); for (unsigned i = 0; i < neighs.size(); i++) { - *it = getColor (neighs[i]); - it ++; - *it = neighs[i]->factor().indexOf (varNode->varId()); - it ++; + sign.push_back (make_pair ( + getColor (neighs[i]), + neighs[i]->factor().indexOf (varNode->varId()))); } - *it = getColor (varNode); + std::sort (sign.begin(), sign.end()); + sign.push_back (make_pair (getColor (varNode), 0)); return sign; } -const Signature& +FacSignature CFactorGraph::getSignature (const FacNode* facNode) { - Signature& sign = facSignatures_[facNode->getIndex()]; - Colors::iterator it = sign.colors.begin(); const VarNodes& neighs = facNode->neighbors(); + FacSignature sign; + sign.reserve (neighs.size() + 1); for (unsigned i = 0; i < neighs.size(); i++) { - *it = getColor (neighs[i]); - it ++; + sign.push_back (getColor (neighs[i])); } - std::sort (sign.colors.begin(), -- sign.colors.end()); - *it = getColor (facNode); + sign.push_back (getColor (facNode)); return sign; } FactorGraph* -CFactorGraph::getGroundFactorGraph (void) const +CFactorGraph::getGroundFactorGraph (void) { FactorGraph* fg = new FactorGraph(); for (unsigned i = 0; i < varClusters_.size(); i++) { - VarNode* newVar = new VarNode (varClusters_[i]->members()[0]); + VarNode* newVar = new VarNode (varClusters_[i]->first()); varClusters_[i]->setRepresentative (newVar); fg->addVarNode (newVar); } - for (unsigned i = 0; i < facClusters_.size(); i++) { - const VarClusters& myVarClusters = facClusters_[i]->varClusters(); - Vars myGroundVars; - myGroundVars.reserve (myVarClusters.size()); - for (unsigned j = 0; j < myVarClusters.size(); j++) { - VarNode* v = myVarClusters[j]->getRepresentative(); - myGroundVars.push_back (v); + Vars vars; + const VarClusters& clusters = facClusters_[i]->varClusters(); + for (unsigned j = 0; j < clusters.size(); j++) { + vars.push_back (clusters[j]->representative()); } + const Factor& groundFac = facClusters_[i]->first()->factor(); FacNode* fn = new FacNode (Factor ( - myGroundVars, - facClusters_[i]->members()[0]->factor().params(), - facClusters_[i]->members()[0]->factor().distId())); + vars, groundFac.params(), groundFac.distId())); facClusters_[i]->setRepresentative (fn); fg->addFacNode (fn); - for (unsigned j = 0; j < myGroundVars.size(); j++) { - fg->addEdge (static_cast (myGroundVars[j]), fn); + for (unsigned j = 0; j < vars.size(); j++) { + fg->addEdge (static_cast (vars[j]), fn); } } return fg; @@ -280,29 +263,21 @@ CFactorGraph::getGroundFactorGraph (void) const unsigned CFactorGraph::getEdgeCount ( const FacCluster* fc, - const VarCluster* vc) const + const VarCluster* vc, + unsigned index) const { unsigned count = 0; - VarId vid = vc->members().front()->varId(); - const FacNodes& members = fc->members(); - for (unsigned i = 0; i < members.size(); i++) { - if (members[i]->factor().contains (vid)) { + VarId reprVid = vc->representative()->varId(); + VarNode* groundVar = groundFg_->getVarNode (reprVid); + const FacNodes& neighs = groundVar->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + FacNodes::const_iterator it; + it = std::find (fc->members().begin(), fc->members().end(), neighs[i]); + if (it != fc->members().end() && + (*it)->factor().indexOf (reprVid) == (int)index) { count ++; } } - if (Constants::DEBUG > 0) { - const VarNodes& vars = vc->members(); - for (unsigned i = 1; i < vars.size(); i++) { - VarId vid = vars[i]->varId(); - unsigned count2 = 0; - for (unsigned i = 0; i < members.size(); i++) { - if (members[i]->factor().contains (vid)) { - count2 ++; - } - } - assert (count == count2); - } - } return count; } @@ -327,7 +302,6 @@ CFactorGraph::printGroups ( cout << endl; } } - count = 1; cout << endl << "factor groups:" << endl; for (FacSignMap::const_iterator it = facGroups.begin(); diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.h b/packages/CLPBN/clpbn/bp/CFactorGraph.h index de4791a82..468787b89 100644 --- a/packages/CLPBN/clpbn/bp/CFactorGraph.h +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.h @@ -10,17 +10,19 @@ class VarCluster; class FacCluster; -class Signature; -class SignatureHash; +class VarSignatureHash; +class FacSignatureHash; typedef long Color; typedef vector Colors; +typedef vector> VarSignature; +typedef vector FacSignature; typedef unordered_map DistColorMap; typedef unordered_map VarColorMap; -typedef unordered_map VarSignMap; -typedef unordered_map FacSignMap; +typedef unordered_map VarSignMap; +typedef unordered_map FacSignMap; typedef vector VarClusters; typedef vector FacClusters; @@ -28,53 +30,27 @@ typedef vector FacClusters; typedef unordered_map VarId2VarCluster; -struct Signature +struct VarSignatureHash { - Signature (unsigned size) : colors(size) { } - - bool operator< (const Signature& sig) const + size_t operator() (const VarSignature &sig) const { - if (colors.size() < sig.colors.size()) { - return true; - } else if (colors.size() > sig.colors.size()) { - return false; - } else { - for (unsigned i = 0; i < colors.size(); i++) { - if (colors[i] < sig.colors[i]) { - return true; - } else if (colors[i] > sig.colors[i]) { - return false; - } - } + size_t val = hash()(sig.size()); + for (unsigned i = 0; i < sig.size(); i++) { + val ^= hash()(sig[i].first); + val ^= hash()(sig[i].second); } - return false; + return val; } - - bool operator== (const Signature& sig) const - { - if (colors.size() != sig.colors.size()) { - return false; - } - for (unsigned i = 0; i < colors.size(); i++) { - if (colors[i] != sig.colors[i]) { - return false; - } - } - return true; - } - - Colors colors; }; - -struct SignatureHash +struct FacSignatureHash { - size_t operator() (const Signature &sig) const + size_t operator() (const FacSignature &sig) const { - size_t val = hash()(sig.colors.size()); - for (unsigned i = 0; i < sig.colors.size(); i++) { - val ^= hash()(sig.colors[i]); + size_t val = hash()(sig.size()); + for (unsigned i = 0; i < sig.size(); i++) { + val ^= hash()(sig[i]); } return val; } @@ -87,19 +63,16 @@ class VarCluster public: VarCluster (const VarNodes& vs) : members_(vs) { } + const VarNode* first (void) const { return members_.front(); } + const VarNodes& members (void) const { return members_; } - const FacClusters& facClusters (void) const { return facClusters_; } - - void addFacCluster (FacCluster* fc) { facClusters_.push_back (fc); } - - VarNode* getRepresentative (void) const { return repr_; } + VarNode* representative (void) const { return repr_; } void setRepresentative (VarNode* vn) { repr_ = vn; } private: VarNodes members_; - FacClusters facClusters_; VarNode* repr_; }; @@ -108,26 +81,17 @@ class FacCluster { public: FacCluster (const FacNodes& fcs, const VarClusters& vcs) - : members_(fcs), varClusters_(vcs) - { - for (unsigned i = 0; i < varClusters_.size(); i++) { - varClusters_[i]->addFacCluster (this); - } - } + : members_(fcs), varClusters_(vcs) { } + + const FacNode* first (void) const { return members_.front(); } const FacNodes& members (void) const { return members_; } - const VarClusters& varClusters (void) const { return varClusters_; } + VarClusters& varClusters (void) { return varClusters_; } - FacNode* getRepresentative (void) const { return repr_; } + FacNode* representative (void) const { return repr_; } void setRepresentative (FacNode* fn) { repr_ = fn; } - - bool containsGround (const FacNode* fn) const - { - return std::find (members_.begin(), members_.end(), fn) - != members_.end(); - } private: FacNodes members_; @@ -147,15 +111,16 @@ class CFactorGraph const FacClusters& facClusters (void) { return facClusters_; } - VarNode* getEquivalentVariable (VarId vid) + VarNode* getEquivalent (VarId vid) { VarCluster* vc = vid2VarCluster_.find (vid)->second; - return vc->getRepresentative(); + return vc->representative(); } - FactorGraph* getGroundFactorGraph (void) const; + FactorGraph* getGroundFactorGraph (void); - unsigned getEdgeCount (const FacCluster*, const VarCluster*) const; + unsigned getEdgeCount (const FacCluster*, + const VarCluster*, unsigned index) const; static bool checkForIdenticalFactors; @@ -184,11 +149,6 @@ class CFactorGraph facColors_[fn->getIndex()] = c; } - VarCluster* getVariableCluster (VarId vid) const - { - return vid2VarCluster_.find (vid)->second; - } - void findIdenticalFactors (void); void setInitialColors (void); @@ -197,21 +157,19 @@ class CFactorGraph void createClusters (const VarSignMap&, const FacSignMap&); - const Signature& getSignature (const VarNode*); + VarSignature getSignature (const VarNode*); - const Signature& getSignature (const FacNode*); + FacSignature getSignature (const FacNode*); void printGroups (const VarSignMap&, const FacSignMap&) const; - Color freeColor_; - Colors varColors_; - Colors facColors_; - vector varSignatures_; - vector facSignatures_; - VarClusters varClusters_; - FacClusters facClusters_; - VarId2VarCluster vid2VarCluster_; - const FactorGraph* groundFg_; + Color freeColor_; + Colors varColors_; + Colors facColors_; + VarClusters varClusters_; + FacClusters facClusters_; + VarId2VarCluster vid2VarCluster_; + const FactorGraph* groundFg_; }; #endif // HORUS_CFACTORGRAPH_H diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.cpp b/packages/CLPBN/clpbn/bp/CbpSolver.cpp index 812f0e85a..4cecc8f86 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/CbpSolver.cpp @@ -1,5 +1,6 @@ #include "CbpSolver.h" +vector CbpSolver::counts; CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) { @@ -24,16 +25,14 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) Statistics::updateCompressingStatistics (nrGroundVars, nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless); } - // cout << "uncompressed factor graph:" << endl; - // cout << " " << fg.nrVarNodes() << " variables " << endl; - // cout << " " << fg.nrFacNodes() << " factors " << endl; - // cout << "compressed factor graph:" << endl; - // cout << " " << fg_->nrVarNodes() << " variables " << endl; - // cout << " " << fg_->nrFacNodes() << " factors " << endl; - // Util::printHeader ("Compressed Factor Graph"); - // fg_->print(); - // Util::printHeader ("Uncompressed Factor Graph"); - // fg.print(); + if (Constants::DEBUG >= 5) { + cout << "uncompressed factor graph:" << endl; + cout << " " << fg.nrVarNodes() << " variables " << endl; + cout << " " << fg.nrFacNodes() << " factors " << endl; + cout << "compressed factor graph:" << endl; + cout << " " << fg_->nrVarNodes() << " variables " << endl; + cout << " " << fg_->nrFacNodes() << " factors " << endl; + } } @@ -81,8 +80,8 @@ CbpSolver::getPosterioriOf (VarId vid) if (runned_ == false) { runSolver(); } - assert (cfg_->getEquivalentVariable (vid)); - VarNode* var = cfg_->getEquivalentVariable (vid); + assert (cfg_->getEquivalent (vid)); + VarNode* var = cfg_->getEquivalent (vid); Params probs; if (var->hasEvidence()) { probs.resize (var->range(), LogAware::noEvidence()); @@ -115,7 +114,7 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids) { VarIds eqVarIds; for (unsigned i = 0; i < jointVids.size(); i++) { - VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]); + VarNode* vn = cfg_->getEquivalent (jointVids[i]); eqVarIds.push_back (vn->varId()); } return BpSolver::getJointDistributionOf (eqVarIds); @@ -125,15 +124,20 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids) void CbpSolver::createLinks (void) -{ +{ const FacClusters& fcs = cfg_->facClusters(); for (unsigned i = 0; i < fcs.size(); i++) { const VarClusters& vcs = fcs[i]->varClusters(); for (unsigned j = 0; j < vcs.size(); j++) { - unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]); + unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j); + if (Constants::DEBUG >= 5) { + cout << "creating edge " ; + cout << fcs[i]->representative()->getLabel() << " -> " ; + cout << vcs[j]->representative()->label(); + cout << " idx=" << j << ", count=" << count << endl; + } links_.push_back (new CbpSolverLink ( - fcs[i]->getRepresentative(), - vcs[j]->getRepresentative(), c)); + fcs[i]->representative(), vcs[j]->representative(), j, count)); } } } @@ -213,47 +217,135 @@ CbpSolver::maxResidualSchedule (void) -Params -CbpSolver::getVar2FactorMsg (const SpLink* link) const +void +CbpSolver::calculateFactor2VariableMsg (SpLink* _link) { - Params msg; + CbpSolverLink* link = static_cast (_link); + FacNode* src = link->getFactor(); + const VarNode* dst = link->getVariable(); + const SpLinkSet& links = ninf(src)->getLinks(); + // calculate the product of messages that were sent + // to factor `src', except from var `dst' + unsigned msgSize = 1; + for (unsigned i = 0; i < links.size(); i++) { + msgSize *= links[i]->getVariable()->range(); + } + unsigned repetitions = 1; + Params msgProduct (msgSize, LogAware::multIdenty()); + if (Globals::logDomain) { + for (int i = links.size() - 1; i >= 0; i--) { + const CbpSolverLink* cl = static_cast (links[i]); + if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { + if (Constants::DEBUG >= 5) { + cout << " message from " << links[i]->getVariable()->label(); + cout << ": " ; + } + Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->range(); + if (Constants::DEBUG >= 5) { + cout << endl; + } + } else { + unsigned range = links[i]->getVariable()->range(); + Util::add (msgProduct, Params (range, 0.0), repetitions); + repetitions *= range; + } + } + } else { + for (int i = links.size() - 1; i >= 0; i--) { + const CbpSolverLink* cl = static_cast (links[i]); + if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { + if (Constants::DEBUG >= 5) { + cout << " message from " << links[i]->getVariable()->label(); + cout << ": " ; + } + Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->range(); + if (Constants::DEBUG >= 5) { + cout << endl; + } + } else { + unsigned range = links[i]->getVariable()->range(); + Util::multiply (msgProduct, Params (range, 1.0), repetitions); + repetitions *= range; + } + } + } + Factor result (src->factor().arguments(), + src->factor().ranges(), msgProduct); + assert (msgProduct.size() == src->factor().size()); + if (Globals::logDomain) { + for (unsigned i = 0; i < result.size(); i++) { + result[i] += src->factor()[i]; + } + } else { + for (unsigned i = 0; i < result.size(); i++) { + result[i] *= src->factor()[i]; + } + } + if (Constants::DEBUG >= 5) { + cout << " message product: " << msgProduct << endl; + cout << " original factor: " << src->factor().params() << endl; + cout << " factor product: " << result.params() << endl; + } + result.sumOutAllExceptIndex (link->index()); + if (Constants::DEBUG >= 5) { + cout << " marginalized: " << result.params() << endl; + } + link->getNextMessage() = result.params(); + LogAware::normalize (link->getNextMessage()); + if (Constants::DEBUG >= 5) { + cout << " curr msg: " << link->getMessage() << endl; + cout << " next msg: " << link->getNextMessage() << endl; + } +} + + + +Params +CbpSolver::getVar2FactorMsg (const SpLink* _link) const +{ + const CbpSolverLink* link = static_cast (_link); const VarNode* src = link->getVariable(); const FacNode* dst = link->getFactor(); - const CbpSolverLink* l = static_cast (link); + Params msg; if (src->hasEvidence()) { msg.resize (src->range(), LogAware::noEvidence()); double value = link->getMessage()[src->getEvidence()]; - msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1); + if (Constants::DEBUG >= 5) { + msg[src->getEvidence()] = value; + cout << msg << "^" << link->nrEdges() << "-1" ; + } + msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1); } else { msg = link->getMessage(); - LogAware::pow (msg, l->nrEdges() - 1); - } - if (Constants::DEBUG >= 5) { - cout << " " << "init: " << msg << " " << src->hasEvidence() << endl; + if (Constants::DEBUG >= 5) { + cout << msg << "^" << link->nrEdges() << "-1" ; + } + LogAware::pow (msg, link->nrEdges() - 1); } const SpLinkSet& links = ninf(src)->getLinks(); if (Globals::logDomain) { for (unsigned i = 0; i < links.size(); i++) { - if (links[i]->getFactor() != dst) { - CbpSolverLink* l = static_cast (links[i]); - Util::add (msg, l->poweredMessage()); + CbpSolverLink* cl = static_cast (links[i]); + if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { + CbpSolverLink* cl = static_cast (links[i]); + Util::add (msg, cl->poweredMessage()); } } } else { for (unsigned i = 0; i < links.size(); i++) { - if (links[i]->getFactor() != dst) { - CbpSolverLink* l = static_cast (links[i]); - Util::multiply (msg, l->poweredMessage()); + CbpSolverLink* cl = static_cast (links[i]); + if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { + Util::multiply (msg, cl->poweredMessage()); if (Constants::DEBUG >= 5) { - cout << " msg from " << l->getFactor()->getLabel() << ": " ; - cout << l->poweredMessage() << endl; + cout << " x " << cl->getNextMessage() << "^" << link->nrEdges(); } } } } - if (Constants::DEBUG >= 5) { - cout << " result = " << msg << endl; + cout << " = " << msg; } return msg; } @@ -264,13 +356,14 @@ void CbpSolver::printLinkInformation (void) const { for (unsigned i = 0; i < links_.size(); i++) { - CbpSolverLink* l = static_cast (links_[i]); - cout << l->toString() << ":" << endl; - cout << " curr msg = " << l->getMessage() << endl; - cout << " next msg = " << l->getNextMessage() << endl; - cout << " nr edges = " << l->nrEdges() << endl; - cout << " powered = " << l->poweredMessage() << endl; - cout << " residual = " << l->getResidual() << endl; + CbpSolverLink* cl = static_cast (links_[i]); + cout << cl->toString() << ":" << endl; + cout << " curr msg = " << cl->getMessage() << endl; + cout << " next msg = " << cl->getNextMessage() << endl; + cout << " index = " << cl->index() << endl; + cout << " nr edges = " << cl->nrEdges() << endl; + cout << " powered = " << cl->poweredMessage() << endl; + cout << " residual = " << cl->getResidual() << endl; } } diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.h b/packages/CLPBN/clpbn/bp/CbpSolver.h index 7ff3a2dc4..115700b78 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.h +++ b/packages/CLPBN/clpbn/bp/CbpSolver.h @@ -9,10 +9,12 @@ class Factor; class CbpSolverLink : public SpLink { public: - CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c) - : SpLink (fn, vn), nrEdges_(c), + CbpSolverLink (FacNode* fn, VarNode* vn, unsigned idx, unsigned count) + : SpLink (fn, vn), index_(idx), nrEdges_(count), pwdMsg_(vn->range(), LogAware::one()) { } + unsigned index (void) const { return index_; } + unsigned nrEdges (void) const { return nrEdges_; } const Params& poweredMessage (void) const { return pwdMsg_; } @@ -26,6 +28,7 @@ class CbpSolverLink : public SpLink } private: + unsigned index_; unsigned nrEdges_; Params pwdMsg_; }; @@ -35,6 +38,8 @@ class CbpSolverLink : public SpLink class CbpSolver : public BpSolver { public: + static vector counts; + CbpSolver (const FactorGraph& fg); ~CbpSolver (void); @@ -51,6 +56,8 @@ class CbpSolver : public BpSolver void maxResidualSchedule (void); + void calculateFactor2VariableMsg (SpLink*); + Params getVar2FactorMsg (const SpLink*) const; void printLinkInformation (void) const; diff --git a/packages/CLPBN/clpbn/bp/Factor.cpp b/packages/CLPBN/clpbn/bp/Factor.cpp index 6b883e204..7f8b7e611 100644 --- a/packages/CLPBN/clpbn/bp/Factor.cpp +++ b/packages/CLPBN/clpbn/bp/Factor.cpp @@ -105,7 +105,7 @@ Factor::sumOutIndex (unsigned idx) // on the left of `var', with the states of the remaining vars fixed unsigned leftVarOffset = 1; - for (int i = args_.size() - 1; i > idx; i--) { + for (int i = args_.size() - 1; i > (int)idx; i--) { varOffset *= ranges_[i]; leftVarOffset *= ranges_[i]; } @@ -151,7 +151,7 @@ Factor::sumOutIndex (unsigned idx) void Factor::sumOutAllExceptIndex (unsigned idx) { - int i = idx; + int i = (int)idx; while (args_.size() > i + 1) { sumOutLastVariable(); } diff --git a/packages/CLPBN/clpbn/bp/FactorGraph.cpp b/packages/CLPBN/clpbn/bp/FactorGraph.cpp index ae6605208..acf5d3adf 100644 --- a/packages/CLPBN/clpbn/bp/FactorGraph.cpp +++ b/packages/CLPBN/clpbn/bp/FactorGraph.cpp @@ -163,7 +163,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName) reverse (vids.begin(), vids.end()); Factor f (vids, ranges, params); reverse (vids.begin(), vids.end()); - f.reorderArguments (vids); + f.reorderArguments (vids); addFactor (f); } is.close();