fix align of logical variables

This commit is contained in:
Tiago Gomes 2012-04-03 11:58:21 +01:00
parent fd8980642b
commit 911b241ad6
11 changed files with 60 additions and 40 deletions

View File

@ -347,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
void
ConstraintTree::applySubstitution (const Substitution& theta)
{
LogVars discardedLvs = theta.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs.size(); i++) {
remove(discardedLvs[i]);
}
for (unsigned i = 0; i < logVars_.size(); i++) {
logVars_[i] = theta.newNameFor (logVars_[i]);
}

View File

@ -245,7 +245,6 @@ Factor::multiply (Factor& g)
return;
}
TFactor<VarId>::multiply (g);
cout << "Factor mult called" << endl;
}

View File

@ -62,6 +62,12 @@ class TFactor
return args_[idx];
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());

View File

@ -554,7 +554,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->groupWithGround (query[i]);
int group = (*it)->findGroup (query[i]);
if (group != -1) {
todo.push (group);
done.insert (group);

View File

@ -452,7 +452,6 @@ setBayesNetParams (void)
int
setExtraVarsInfo (void)
{
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
GraphicalModel::clearVariablesInformation();
YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) {
@ -583,15 +582,15 @@ freeParfactors (void)
extern "C" void
init_predicates (void)
{
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
}

View File

@ -95,6 +95,25 @@ ostream& operator<< (ostream &os, const Ground& gr)
LogVars
Substitution::getDiscardedLogVars (void) const
{
LogVars discardedLvs;
set<LogVar> doneLvs;
unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.begin();
while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first);
} else {
doneLvs.insert (it->second);
}
it ++;
}
return discardedLvs;
}
ostream& operator<< (ostream &os, const Substitution& theta)
{

View File

@ -141,10 +141,13 @@ class Substitution
return subs_.find (X)->second;
}
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);
private:
unordered_map<LogVar, LogVar> subs_;
};

View File

@ -191,7 +191,6 @@ Parfactor::multiply (Parfactor& g)
{
alignAndExponentiate (this, &g);
TFactor<ProbFormula>::multiply (g);
cout << "calling lifted mult" << endl;
constr_->join (g.constr(), true);
}
@ -377,15 +376,6 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
void
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
{
assert (indexOf (f) != -1);
args_[indexOf (f)].setGroup (group);
}
void
Parfactor::setNewGroups (void)
{
@ -415,7 +405,7 @@ Parfactor::applySubstitution (const Substitution& theta)
int
Parfactor::groupWithGround (const Ground& ground) const
Parfactor::findGroup (const Ground& ground) const
{
int group = -1;
for (unsigned i = 0; i < args_.size(); i++) {
@ -436,7 +426,7 @@ Parfactor::groupWithGround (const Ground& ground) const
bool
Parfactor::containsGround (const Ground& ground) const
{
return groupWithGround (ground) != -1;
return findGroup (ground) != -1;
}
@ -670,7 +660,6 @@ Parfactor::align (
LogVar freeLogVar = 0;
Substitution theta1;
Substitution theta2;
const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar);

View File

@ -60,13 +60,11 @@ class Parfactor : public TFactor<ProbFormula>
void absorveEvidence (const ProbFormula&, unsigned);
void setFormulaGroup (const ProbFormula&, int);
void setNewGroups (void);
void applySubstitution (const Substitution&);
int groupWithGround (const Ground&) const;
int findGroup (const Ground&) const;
bool containsGround (const Ground&) const;

View File

@ -197,8 +197,8 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
std::pair<Parfactors, Parfactors> res
= shatter (formulas1[i], g1, formulas2[j], g2);
std::pair<Parfactors, Parfactors> res;
res = shatter (i, g1, j, g2);
if (res.first.empty() == false ||
res.second.empty() == false) {
return res;
@ -213,9 +213,11 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (
ProbFormula& f1, Parfactor* g1,
ProbFormula& f2, Parfactor* g2)
unsigned fIdx1, Parfactor* g1,
unsigned fIdx2, Parfactor* g2)
{
ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2);
// cout << endl;
// Util::printDashLine();
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
@ -299,8 +301,8 @@ ParfactorList::shatter (
} else {
group = ProbFormula::getNewGroup();
}
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
return make_pair (res1, res2);
}
@ -309,15 +311,16 @@ ParfactorList::shatter (
Parfactors
ParfactorList::shatter (
Parfactor* g,
const ProbFormula& f,
unsigned fIdx,
ConstraintTree* commCt,
ConstraintTree* exclCt,
unsigned commGroup)
{
ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) {
delete commCt;
delete exclCt;
g->setFormulaGroup (f, commGroup);
f.setGroup (commGroup);
return { };
}
@ -346,7 +349,7 @@ ParfactorList::shatter (
} else {
Parfactor* newPf = new Parfactor (g, commCt);
newPf->setNewGroups();
newPf->setFormulaGroup (f, commGroup);
newPf->argument (fIdx).setGroup (commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();

View File

@ -67,11 +67,11 @@ class ParfactorList
Parfactor*, Parfactor*);
std::pair<Parfactors, Parfactors> shatter (
ProbFormula&, Parfactor*, ProbFormula&, Parfactor*);
unsigned, Parfactor*, unsigned, Parfactor*);
Parfactors shatter (
Parfactor*,
const ProbFormula&,
unsigned,
ConstraintTree*,
ConstraintTree*,
unsigned);