186 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			186 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
 | 
						|
#define YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
 | 
						|
 | 
						|
#include <vector>
 | 
						|
#include <unordered_map>
 | 
						|
#include <string>
 | 
						|
#include <fstream>
 | 
						|
 | 
						|
#include "Factor.h"
 | 
						|
#include "BayesBallGraph.h"
 | 
						|
#include "Horus.h"
 | 
						|
 | 
						|
 | 
						|
namespace Horus {
 | 
						|
 | 
						|
class FacNode;
 | 
						|
 | 
						|
 | 
						|
class VarNode : public Var {
 | 
						|
  public:
 | 
						|
    VarNode (VarId varId, unsigned nrStates,
 | 
						|
        int evidence = Constants::unobserved)
 | 
						|
        : Var (varId, nrStates, evidence) { }
 | 
						|
 | 
						|
    VarNode (const Var* v) : Var (v) { }
 | 
						|
 | 
						|
    void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
 | 
						|
 | 
						|
    const FacNodes& neighbors() const { return neighs_; }
 | 
						|
 | 
						|
  private:
 | 
						|
    FacNodes neighs_;
 | 
						|
 | 
						|
    DISALLOW_COPY_AND_ASSIGN (VarNode);
 | 
						|
};
 | 
						|
 | 
						|
 | 
						|
 | 
						|
class FacNode {
 | 
						|
  public:
 | 
						|
    FacNode (const Factor& f) : factor_(f), index_(-1) { }
 | 
						|
 | 
						|
    const Factor& factor() const { return factor_; }
 | 
						|
 | 
						|
    Factor& factor() { return factor_; }
 | 
						|
 | 
						|
    void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
 | 
						|
 | 
						|
    const VarNodes& neighbors() const { return neighs_; }
 | 
						|
 | 
						|
    size_t getIndex() const { return index_; }
 | 
						|
 | 
						|
    void setIndex (size_t index) { index_ = index; }
 | 
						|
 | 
						|
    std::string getLabel() { return factor_.getLabel(); }
 | 
						|
 | 
						|
  private:
 | 
						|
    VarNodes  neighs_;
 | 
						|
    Factor    factor_;
 | 
						|
    size_t    index_;
 | 
						|
 | 
						|
    DISALLOW_COPY_AND_ASSIGN (FacNode);
 | 
						|
};
 | 
						|
 | 
						|
 | 
						|
 | 
						|
class FactorGraph {
 | 
						|
  public:
 | 
						|
    FactorGraph() : bayesFactors_(false) { }
 | 
						|
 | 
						|
    FactorGraph (const FactorGraph&);
 | 
						|
 | 
						|
   ~FactorGraph();
 | 
						|
 | 
						|
    const VarNodes& varNodes() const { return varNodes_; }
 | 
						|
 | 
						|
    const FacNodes& facNodes() const { return facNodes_; }
 | 
						|
 | 
						|
    void setFactorsAsBayesian() { bayesFactors_ = true; }
 | 
						|
 | 
						|
    bool bayesianFactors() const { return bayesFactors_; }
 | 
						|
 | 
						|
    size_t nrVarNodes() const { return varNodes_.size(); }
 | 
						|
 | 
						|
    size_t nrFacNodes() const { return facNodes_.size(); }
 | 
						|
 | 
						|
    VarNode* getVarNode (VarId vid) const;
 | 
						|
 | 
						|
    void addFactor (const Factor& factor);
 | 
						|
 | 
						|
    void addVarNode (VarNode*);
 | 
						|
 | 
						|
    void addFacNode (FacNode*);
 | 
						|
 | 
						|
    void addEdge (VarNode*, FacNode*);
 | 
						|
 | 
						|
    bool isTree() const;
 | 
						|
 | 
						|
    BayesBallGraph& getStructure();
 | 
						|
 | 
						|
    void print() const;
 | 
						|
 | 
						|
    void exportToLibDai (const char*) const;
 | 
						|
 | 
						|
    void exportToUai (const char*) const;
 | 
						|
 | 
						|
    void exportToGraphViz (const char*) const;
 | 
						|
 | 
						|
    FactorGraph& operator= (const FactorGraph&);
 | 
						|
 | 
						|
    static FactorGraph readFromUaiFormat (const char*);
 | 
						|
 | 
						|
    static FactorGraph readFromLibDaiFormat (const char*);
 | 
						|
 | 
						|
    static bool exportToLibDai() { return exportLd_; }
 | 
						|
 | 
						|
    static bool exportToUai() { return exportUai_; }
 | 
						|
 | 
						|
    static bool exportGraphViz() { return exportGv_; }
 | 
						|
 | 
						|
    static bool printFactorGraph() { return printFg_; }
 | 
						|
 | 
						|
    static void enableExportToLibDai() { exportLd_ = true; }
 | 
						|
 | 
						|
    static void disableExportToLibDai() { exportLd_ = false; }
 | 
						|
 | 
						|
    static void enableExportToUai() { exportUai_ = true; }
 | 
						|
 | 
						|
    static void disableExportToUai() { exportUai_ = false; }
 | 
						|
 | 
						|
    static void enableExportToGraphViz() { exportGv_ = true; }
 | 
						|
 | 
						|
    static void disableExportToGraphViz() { exportGv_ = false; }
 | 
						|
 | 
						|
    static void enablePrintFactorGraph() { printFg_ = true; }
 | 
						|
 | 
						|
    static void disablePrintFactorGraph() { printFg_ = false; }
 | 
						|
 | 
						|
  private:
 | 
						|
    typedef std::unordered_map<unsigned, VarNode*> VarMap;
 | 
						|
 | 
						|
    void clone (const FactorGraph& fg);
 | 
						|
 | 
						|
    bool containsCycle() const;
 | 
						|
 | 
						|
    bool containsCycle (const VarNode*, const FacNode*,
 | 
						|
        std::vector<bool>&, std::vector<bool>&) const;
 | 
						|
 | 
						|
    bool containsCycle (const FacNode*, const VarNode*,
 | 
						|
        std::vector<bool>&, std::vector<bool>&) const;
 | 
						|
 | 
						|
    static void ignoreLines (std::ifstream&);
 | 
						|
 | 
						|
    VarNodes        varNodes_;
 | 
						|
    FacNodes        facNodes_;
 | 
						|
    VarMap          varMap_;
 | 
						|
    BayesBallGraph  structure_;
 | 
						|
    bool            bayesFactors_;
 | 
						|
 | 
						|
    static bool     exportLd_;
 | 
						|
    static bool     exportUai_;
 | 
						|
    static bool     exportGv_;
 | 
						|
    static bool     printFg_;
 | 
						|
};
 | 
						|
 | 
						|
 | 
						|
 | 
						|
inline VarNode*
 | 
						|
FactorGraph::getVarNode (VarId vid) const
 | 
						|
{
 | 
						|
  VarMap::const_iterator it = varMap_.find (vid);
 | 
						|
  return it != varMap_.end() ? it->second : 0;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
struct sortByVarId {
 | 
						|
  bool operator()(VarNode* vn1, VarNode* vn2) {
 | 
						|
    return vn1->varId() < vn2->varId();
 | 
						|
}};
 | 
						|
 | 
						|
}  // namespace Horus
 | 
						|
 | 
						|
#endif  // YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
 | 
						|
 |