diff --git a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp index 46899b47b..9bdd7f9fb 100644 --- a/packages/CLPBN/clpbn/bp/ConstraintTree.cpp +++ b/packages/CLPBN/clpbn/bp/ConstraintTree.cpp @@ -12,7 +12,7 @@ CTNode::mergeSubtree (CTNode* n, bool updateLevels) if (updateLevels) { updateChildLevels (n, level_ + 1); } - CTChilds::iterator chIt = findChild (n); + CTChilds::iterator chIt = childs_.find (n); if (chIt != childs_.end()) { assert ((*chIt)->symbol() == n->symbol()); const CTChilds& childsToAdd = n->childs(); @@ -31,10 +31,8 @@ CTNode::mergeSubtree (CTNode* n, bool updateLevels) void CTNode::removeChild (CTNode* child) { - CTChilds::iterator it; - it = findChild (child); - assert (it != childs_.end()); - childs_.erase (it); + assert (childs_.contains (child)); + childs_.remove (child); } @@ -119,7 +117,7 @@ CTNode::copySubtree (const CTNode* root1) for (CTChilds::const_iterator chIt = n1->childs().begin(); chIt != n1->childs().end(); ++ chIt) { CTNode* chCopy = new CTNode (**chIt); - n2->childs().push_back (chCopy); + n2->childs().insert_sorted (chCopy); if ((*chIt)->nrChilds() != 0) { stack.push_back (StackPair (*chIt, chCopy)); } @@ -860,7 +858,7 @@ ConstraintTree::copyLogVar (LogVar X_1, LogVar X_2) moveToBottom ({X_1}); CTNodes leafs = getNodesAtLevel (logVars_.size()); for (unsigned i = 0; i < leafs.size(); i++) { - leafs[i]->childs().push_back ( + leafs[i]->childs().insert_sorted ( new CTNode (leafs[i]->symbol(), leafs[i]->level() + 1)); } logVars_.push_back (X_2); @@ -989,7 +987,7 @@ ConstraintTree::swapLogVar (LogVar X) CTNode* childCopy = new CTNode ( (*ccIt)->symbol(), (*ccIt)->level() + 1, (*gsIt)->childs()); (*gsIt)->removeChilds(); - (*gsIt)->childs().push_back (childCopy); + (*gsIt)->childs().insert_sorted (childCopy); (*gsIt)->setLevel ((*gsIt)->level() - 1); (*nodeIt)->mergeSubtree ((*gsIt), false); } @@ -1114,18 +1112,18 @@ ConstraintTree::split ( chIt1 != childs1.end(); ++ chIt1) { CTChilds::iterator chIt2 = n2->findSymbol ((*chIt1)->symbol()); if (chIt2 == n2->childs().end()) { - exclChilds.push_back (CTNode::copySubtree (*chIt1)); + exclChilds.insert_sorted (CTNode::copySubtree (*chIt1)); } else { if ((*chIt1)->level() == stopLevel) { - commChilds.push_back (CTNode::copySubtree (*chIt1)); + commChilds.insert_sorted (CTNode::copySubtree (*chIt1)); } else { CTChilds lowerCommChilds, lowerExclChilds; split (*chIt1, *chIt2, lowerCommChilds, lowerExclChilds, stopLevel); if (lowerCommChilds.empty() == false) { - commChilds.push_back (new CTNode (**chIt1, lowerCommChilds)); + commChilds.insert_sorted (new CTNode (**chIt1, lowerCommChilds)); } if (lowerExclChilds.empty() == false) { - exclChilds.push_back (new CTNode (**chIt1, lowerExclChilds)); + exclChilds.insert_sorted (new CTNode (**chIt1, lowerExclChilds)); } } } diff --git a/packages/CLPBN/clpbn/bp/ConstraintTree.h b/packages/CLPBN/clpbn/bp/ConstraintTree.h index f1206ba0e..802422858 100644 --- a/packages/CLPBN/clpbn/bp/ConstraintTree.h +++ b/packages/CLPBN/clpbn/bp/ConstraintTree.h @@ -34,7 +34,7 @@ class CTNode private: - typedef SortedVector CTChilds_; + typedef TinySet CTChilds_; public: @@ -64,11 +64,6 @@ class CTNode bool isLeaf (void) const { return childs_.empty(); } - CTChilds_::iterator findChild (CTNode* n) - { - return childs_.find (n); - } - CTChilds_::iterator findSymbol (Symbol symb) { CTNode tmp (symb, 0); @@ -102,7 +97,7 @@ class CTNode ostream& operator<< (ostream &out, const CTNode&); -typedef SortedVector CTChilds; +typedef TinySet CTChilds; class ConstraintTree diff --git a/packages/CLPBN/clpbn/bp/TODO b/packages/CLPBN/clpbn/bp/TODO index e247b6de5..9001fdd32 100644 --- a/packages/CLPBN/clpbn/bp/TODO +++ b/packages/CLPBN/clpbn/bp/TODO @@ -1,10 +1,10 @@ - Refactor sum out in factor - Add a way to sum out several vars at the same time - Receive ranges as a constant reference in Indexer -- Merge TinySet and SortedVector classes - Check if evidence remains in the compressed factor graph - Consider using hashs instead of vectors of colors to calculate the groups in counting bp - use more psize_t instead of unsigned for looping through params - use more Util::abort and Util::vectorIndex -- LogVar should not cast to int + + diff --git a/packages/CLPBN/clpbn/bp/TinySet.h b/packages/CLPBN/clpbn/bp/TinySet.h index 459307702..8c1d01f1c 100644 --- a/packages/CLPBN/clpbn/bp/TinySet.h +++ b/packages/CLPBN/clpbn/bp/TinySet.h @@ -7,53 +7,73 @@ using namespace std; -template +template > class TinySet { public: - TinySet (void) { } - TinySet (const T& t) + TinySet (const TinySet& s) + : vec_(s.vec_), cmp_(s.cmp_) { } + + TinySet (const Compare& cmp = Compare()) + : vec_(), cmp_(cmp) { } + + TinySet (const T& t, const Compare& cmp = Compare()) + : vec_(1, t), cmp_(cmp) { } + + TinySet (const vector& elements, const Compare& cmp = Compare()) + : vec_(elements), cmp_(cmp) { - elements_.push_back (t); + std::sort (begin(), end(), cmp_); } - TinySet (const vector& elements) + typedef typename vector::iterator iterator; + typedef typename vector::const_iterator const_iterator; + + iterator insert (const T& t) { - elements_.reserve (elements.size()); - for (unsigned i = 0; i < elements.size(); i++) { - insert (elements[i]); + iterator it = std::lower_bound (begin(), end(), t, cmp_); + if (it == end() || cmp_(t, *it)) { + vec_.insert (it, t); } + return it; } - TinySet (const TinySet& s) : elements_(s.elements_) { } - - void insert (const T& t) + void insert_sorted (const T& t) { - typename vector::iterator it; - it = std::lower_bound (elements_.begin(), elements_.end(), t); - if (it == elements_.end() || *it != t) { - elements_.insert (it, t); - } + vec_.push_back (t); + assert (consistent()); } void remove (const T& t) { - typename vector::iterator it; - it = std::lower_bound (elements_.begin(), elements_.end(), t); - if (it != elements_.end()) { - elements_.erase (it); + iterator it = std::lower_bound (begin(), end(), t, cmp_); + if (it != end()) { + vec_.erase (it); } } + const_iterator find (const T& t) const + { + const_iterator it = std::lower_bound (begin(), end(), t, cmp_); + return it == end() || cmp_(t, *it) ? end() : it; + } + + iterator find (const T& t) + { + iterator it = std::lower_bound (begin(), end(), t, cmp_); + return it == end() || cmp_(t, *it) ? end() : it; + } + /* set union */ TinySet operator| (const TinySet& s) const { TinySet res; std::set_union ( - elements_.begin(), elements_.end(), - s.elements_.begin(), s.elements_.end(), - std::back_inserter (res.elements_)); + vec_.begin(), vec_.end(), + s.vec_.begin(), s.vec_.end(), + std::back_inserter (res.vec_), + cmp_); return res; } @@ -62,9 +82,10 @@ class TinySet { TinySet res; std::set_intersection ( - elements_.begin(), elements_.end(), - s.elements_.begin(), s.elements_.end(), - std::back_inserter (res.elements_)); + vec_.begin(), vec_.end(), + s.vec_.begin(), s.vec_.end(), + std::back_inserter (res.vec_), + cmp_); return res; } @@ -73,9 +94,10 @@ class TinySet { TinySet res; std::set_difference ( - elements_.begin(), elements_.end(), - s.elements_.begin(), s.elements_.end(), - std::back_inserter (res.elements_)); + vec_.begin(), vec_.end(), + s.vec_.begin(), s.vec_.end(), + std::back_inserter (res.vec_), + cmp_); return res; } @@ -97,25 +119,27 @@ class TinySet bool contains (const T& t) const { return std::binary_search ( - elements_.begin(), elements_.end(), t); + vec_.begin(), vec_.end(), t, cmp_); } bool contains (const TinySet& s) const { return std::includes ( - elements_.begin(), - elements_.end(), - s.elements_.begin(), - s.elements_.end()); + vec_.begin(), + vec_.end(), + s.vec_.begin(), + s.vec_.end(), + cmp_); } bool in (const TinySet& s) const { return std::includes ( - s.elements_.begin(), - s.elements_.end(), - elements_.begin(), - elements_.end()); + s.vec_.begin(), + s.vec_.end(), + vec_.begin(), + vec_.end(), + cmp_); } bool intersects (const TinySet& s) const @@ -123,148 +147,97 @@ class TinySet return (*this & s).size() > 0; } - T operator[] (unsigned i) const + const T& operator[] (typename vector::size_type i) const { - return elements_[i]; - } - - const vector& elements (void) const - { - return elements_; + return vec_[i]; } T front (void) const { - return elements_.front(); + return vec_.front(); + } + + T& front (void) + { + return vec_.front(); } T back (void) const { - return elements_.back(); + return vec_.back(); } - unsigned size (void) const + T& back (void) { - return elements_.size(); + return vec_.back(); + } + + const vector& elements (void) const + { + return vec_; } bool empty (void) const { - return elements_.size() == 0; + return size() == 0; } - typedef typename std::vector::const_iterator const_iterator; - - const_iterator begin (void) const + typename vector::size_type size (void) const { - return elements_.begin(); + return vec_.size(); } - const_iterator end (void) const + void clear (void) { - return elements_.end(); + vec_.clear(); } + void reserve (typename vector::size_type size) + { + vec_.reserve (size); + } + + iterator begin (void) { return vec_.begin(); } + iterator end (void) { return vec_.end(); } + const_iterator begin (void) const { return vec_.begin(); } + const_iterator end (void) const { return vec_.end(); } + friend bool operator== (const TinySet& s1, const TinySet& s2) { - return s1.elements_ == s2.elements_; + return s1.vec_ == s2.vec_; } friend bool operator!= (const TinySet& s1, const TinySet& s2) { - return s1.elements_ != s2.elements_; + return ! (s1.vec_ == s2.vec_); } - friend std::ostream& operator << (std::ostream& out, const TinySet& s) + friend std::ostream& operator << (std::ostream& out, const TinySet& s) { out << "{" ; - for (unsigned i = 0; i < s.size(); i++) { - out << ((i != 0) ? "," : "") << s.elements()[i]; + typename vector::size_type i; + for (i = 0; i < s.size(); i++) { + out << ((i != 0) ? "," : "") << s.vec_[i]; } out << "}" ; return out; } - protected: - vector elements_; -}; - - - -template > -class SortedVector -{ - public: - SortedVector (const Compare& c = Compare()) : vec_(), cmp_(c) { } - /* - template - SortedVector (InputIterator first, InputIterator last, - const Compare& c = Compare()) : vec_(first, last), cmp_(c) - { - std::sort (begin(), end(), cmp_); - } - */ - typedef typename vector::iterator iterator; - typedef typename vector::const_iterator const_iterator; - iterator begin (void) { return vec_.begin(); } - iterator end (void) { return vec_.end(); } - const_iterator begin (void) const { return vec_.begin(); } - const_iterator end (void) const { return vec_.end(); } - - iterator insert (const T& t) - { - iterator i = std::lower_bound (begin(), end(), t, cmp_); - if (i == end() || cmp_(t, *i)) - vec_.insert(i, t); - return i; - } - - void push_back (const T& t) - { - vec_.push_back (t); - assert (consistent()); - } - - const_iterator find (const T& t) const - { - const_iterator i = std::lower_bound (begin(), end(), t, cmp_); - return i == end() || cmp_(t, *i) ? end() : i; - } - - iterator find (const T& t) - { - iterator i = std::lower_bound (begin(), end(), t, cmp_); - return i == end() || cmp_(t, *i) ? end() : i; - } - - const vector& elements (void) { return vec_; } - - void reserve (unsigned space) { vec_.reserve (space); } - - unsigned size (void) const { return vec_.size(); } - - bool empty (void) const { return vec_.empty(); } - - void clear (void) { vec_.clear(); } - - iterator erase (iterator it) { return vec_.erase (it); } - private: - bool consistent (void) const { - for (unsigned i = 0; i < vec_.size() - 1; i++) { - if (cmp_(vec_[i], vec_[i+1]) == false) { + typename vector::size_type i; + for (i = 0; i < vec_.size() - 1; i++) { + if (cmp_(vec_[i], vec_[i + 1]) == false) { return false; } } return true; } - std::vector vec_; - Compare cmp_; + vector vec_; + Compare cmp_; }; - #endif // HORUS_TINYSET_H