fix align of logical variables
This commit is contained in:
parent
fd8980642b
commit
911b241ad6
@ -347,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
||||
void
|
||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||
remove(discardedLvs[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||
}
|
||||
|
@ -245,7 +245,6 @@ Factor::multiply (Factor& g)
|
||||
return;
|
||||
}
|
||||
TFactor<VarId>::multiply (g);
|
||||
cout << "Factor mult called" << endl;
|
||||
}
|
||||
|
||||
|
||||
|
@ -62,6 +62,12 @@ class TFactor
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (unsigned idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (unsigned idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
|
@ -554,7 +554,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->groupWithGround (query[i]);
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
|
@ -452,7 +452,6 @@ setBayesNetParams (void)
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
{
|
||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
GraphicalModel::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
@ -583,15 +582,15 @@ freeParfactors (void)
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
}
|
||||
|
||||
|
@ -95,6 +95,25 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
||||
|
||||
|
||||
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
{
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
} else {
|
||||
doneLvs.insert (it->second);
|
||||
}
|
||||
it ++;
|
||||
}
|
||||
return discardedLvs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
{
|
||||
|
@ -141,10 +141,13 @@ class Substitution
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@ -191,7 +191,6 @@ Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
cout << "calling lifted mult" << endl;
|
||||
constr_->join (g.constr(), true);
|
||||
}
|
||||
|
||||
@ -377,15 +376,6 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
|
||||
{
|
||||
assert (indexOf (f) != -1);
|
||||
args_[indexOf (f)].setGroup (group);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
{
|
||||
@ -415,7 +405,7 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
int
|
||||
Parfactor::groupWithGround (const Ground& ground) const
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
int group = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
@ -436,7 +426,7 @@ Parfactor::groupWithGround (const Ground& ground) const
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
{
|
||||
return groupWithGround (ground) != -1;
|
||||
return findGroup (ground) != -1;
|
||||
}
|
||||
|
||||
|
||||
@ -670,7 +660,6 @@ Parfactor::align (
|
||||
LogVar freeLogVar = 0;
|
||||
Substitution theta1;
|
||||
Substitution theta2;
|
||||
|
||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||
theta1.add (allLvs1[i], freeLogVar);
|
||||
|
@ -60,13 +60,11 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setFormulaGroup (const ProbFormula&, int);
|
||||
|
||||
void setNewGroups (void);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
int groupWithGround (const Ground&) const;
|
||||
int findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
|
@ -197,8 +197,8 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||
std::pair<Parfactors, Parfactors> res
|
||||
= shatter (formulas1[i], g1, formulas2[j], g2);
|
||||
std::pair<Parfactors, Parfactors> res;
|
||||
res = shatter (i, g1, j, g2);
|
||||
if (res.first.empty() == false ||
|
||||
res.second.empty() == false) {
|
||||
return res;
|
||||
@ -213,9 +213,11 @@ ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
ProbFormula& f1, Parfactor* g1,
|
||||
ProbFormula& f2, Parfactor* g2)
|
||||
unsigned fIdx1, Parfactor* g1,
|
||||
unsigned fIdx2, Parfactor* g2)
|
||||
{
|
||||
ProbFormula& f1 = g1->argument (fIdx1);
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
// cout << endl;
|
||||
// Util::printDashLine();
|
||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||
@ -299,8 +301,8 @@ ParfactorList::shatter (
|
||||
} else {
|
||||
group = ProbFormula::getNewGroup();
|
||||
}
|
||||
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
|
||||
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||
return make_pair (res1, res2);
|
||||
}
|
||||
|
||||
@ -309,15 +311,16 @@ ParfactorList::shatter (
|
||||
Parfactors
|
||||
ParfactorList::shatter (
|
||||
Parfactor* g,
|
||||
const ProbFormula& f,
|
||||
unsigned fIdx,
|
||||
ConstraintTree* commCt,
|
||||
ConstraintTree* exclCt,
|
||||
unsigned commGroup)
|
||||
{
|
||||
ProbFormula& f = g->argument (fIdx);
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
g->setFormulaGroup (f, commGroup);
|
||||
f.setGroup (commGroup);
|
||||
return { };
|
||||
}
|
||||
|
||||
@ -346,7 +349,7 @@ ParfactorList::shatter (
|
||||
} else {
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->setFormulaGroup (f, commGroup);
|
||||
newPf->argument (fIdx).setGroup (commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
|
@ -67,11 +67,11 @@ class ParfactorList
|
||||
Parfactor*, Parfactor*);
|
||||
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
ProbFormula&, Parfactor*, ProbFormula&, Parfactor*);
|
||||
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||
|
||||
Parfactors shatter (
|
||||
Parfactor*,
|
||||
const ProbFormula&,
|
||||
unsigned,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
|
Reference in New Issue
Block a user