refactor and increase the scope of parfactor multiplication

This commit is contained in:
Tiago Gomes 2012-05-15 19:05:39 +01:00
parent f4bca3ceea
commit db0d2c9772
9 changed files with 443 additions and 136 deletions

View File

@ -359,10 +359,6 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
void void
ConstraintTree::applySubstitution (const Substitution& theta) ConstraintTree::applySubstitution (const Substitution& theta)
{ {
LogVars discardedLvs = theta.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs.size(); i++) {
remove(discardedLvs[i]);
}
for (unsigned i = 0; i < logVars_.size(); i++) { for (unsigned i = 0; i < logVars_.size(); i++) {
logVars_[i] = theta.newNameFor (logVars_[i]); logVars_[i] = theta.newNameFor (logVars_[i]);
} }

View File

@ -13,17 +13,23 @@ LiftedOperator::getValidOps (
const Grounds& query) const Grounds& query)
{ {
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
vector<ProductOperator*> multOps;
multOps = ProductOperator::getValidOps (pfList);
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
if (Globals::verbosity > 1 || multOps.empty()) {
vector<SumOutOperator*> sumOutOps; vector<SumOutOperator*> sumOutOps;
vector<CountingOperator*> countOps; vector<CountingOperator*> countOps;
vector<GroundOperator*> groundOps; vector<GroundOperator*> groundOps;
sumOutOps = SumOutOperator::getValidOps (pfList, query); sumOutOps = SumOutOperator::getValidOps (pfList, query);
countOps = CountingOperator::getValidOps (pfList); countOps = CountingOperator::getValidOps (pfList);
groundOps = GroundOperator::getValidOps (pfList); groundOps = GroundOperator::getValidOps (pfList);
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end()); validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
validOps.insert (validOps.end(), countOps.begin(), countOps.end()); validOps.insert (validOps.end(), countOps.begin(), countOps.end());
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end()); validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
}
return validOps; return validOps;
} }
@ -44,14 +50,6 @@ LiftedOperator::printValidOps (
// Placeholder: the parfactor list argument is ignored and an empty
// vector of group identifiers is always returned.
vector<unsigned>
LiftedOperator::getAllGroupss (ParfactorList& )
{
  return vector<unsigned>();
}
vector<ParfactorList::iterator> vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup ( LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, unsigned group) ParfactorList& pfList, unsigned group)
@ -69,6 +67,105 @@ LiftedOperator::getParfactorsWithGroup (
// Returns the logarithm of the cost of this multiplication.
// The cost is fixed at zero, so the value returned is std::log (0.0)
// (negative infinity on IEEE-754 platforms).
double
ProductOperator::getLogCost (void)
{
  const double cost = 0.0;
  return std::log (cost);
}
void
ProductOperator::apply (void)
{
Parfactor* g1 = *g1_;
Parfactor* g2 = *g2_;
g1->multiply (*g2);
pfList_.remove (g1_);
pfList_.removeAndDelete (g2_);
pfList_.addShattered (g1);
}
// Collects the multiplication operators that can be applied to pairs
// of parfactors in pfList.
//
// Each parfactor takes part in at most one returned operator. When
// Globals::verbosity is below 2 the search stops at the first valid
// pair, so the result then holds at most one operator.
//
// The returned operators are allocated with new; releasing them is up
// to the caller.
vector<ProductOperator*>
ProductOperator::getValidOps (ParfactorList& pfList)
{
  vector<ProductOperator*> validOps;
  // An empty list has no pairs. Return before stepping back from
  // end(), which is undefined when the list has no elements.
  if (pfList.begin() == pfList.end()) {
    return validOps;
  }
  ParfactorList::iterator last = pfList.end();
  -- last;
  // parfactors that already belong to a collected operator
  set<Parfactor*> paired;
  for (ParfactorList::iterator it1 = pfList.begin(); it1 != last; ++ it1) {
    if (Util::contains (paired, *it1)) {
      continue;
    }
    ParfactorList::iterator it2 = it1;
    for (++ it2; it2 != pfList.end(); ++ it2) {
      if (Util::contains (paired, *it2) ||
          validOp (*it1, *it2) == false) {
        continue;
      }
      paired.insert (*it1);
      paired.insert (*it2);
      validOps.push_back (new ProductOperator (it1, it2, pfList));
      if (Globals::verbosity < 2) {
        // one operator is enough unless all of them will be printed
        return validOps;
      }
      // *it1 is now paired, move on to the next first operand
      break;
    }
  }
  return validOps;
}
// Builds a one-line, newline-terminated description of this operator:
// the groups of both parfactors followed by the cost, which is the
// exponential of getLogCost().
string
ProductOperator::toString (void)
{
  stringstream out;
  out << "just multiplicate " << (*g1_)->getAllGroups();
  out << " x " << (*g2_)->getAllGroups();
  out << " [cost=" << std::exp (getLogCost()) << "]" << endl;
  return out.str();
}
// Tells whether g1 and g2 can be multiplied by a ProductOperator.
//
// The set of groups of one parfactor must contain the set of groups of
// the other. Every shared group must occur in exactly one formula of
// each parfactor and have the same range in both. The final answer is
// then given by Parfactor::canMultiply.
bool
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
{
  TinySet<unsigned> groups1 (g1->getAllGroups());
  TinySet<unsigned> groups2 (g2->getAllGroups());
  if (groups1.contains (groups2) == false &&
      groups2.contains (groups1) == false) {
    return false;
  }
  TinySet<unsigned> shared = groups1 & groups2;
  for (unsigned i = 0; i < shared.size(); i++) {
    const unsigned group = shared[i];
    const bool uniqueInBoth =
        g1->nrFormulasWithGroup (group) == 1 &&
        g2->nrFormulasWithGroup (group) == 1;
    if (uniqueInBoth == false) {
      return false;
    }
    if (g1->range (g1->indexOfGroup (group)) !=
        g2->range (g2->indexOfGroup (group))) {
      return false;
    }
  }
  return Parfactor::canMultiply (g1, g2);
}
double double
SumOutOperator::getLogCost (void) SumOutOperator::getLogCost (void)
{ {
@ -84,7 +181,8 @@ SumOutOperator::getLogCost (void)
++ pfIter; ++ pfIter;
} }
if (nrProdFactors == 1) { if (nrProdFactors == 1) {
return std::log (0.0); // best possible case // best possible case
return std::log (0.0);
} }
double cost = 1.0; double cost = 1.0;
for (unsigned i = 0; i < groupSet.size(); i++) { for (unsigned i = 0; i < groupSet.size(); i++) {
@ -190,40 +288,22 @@ SumOutOperator::validOp (
if (isToEliminate (*pfIters[0], group, query) == false) { if (isToEliminate (*pfIters[0], group, query) == false) {
return false; return false;
} }
int range = -1;
unordered_map<unsigned, unsigned> groupToRange;
for (unsigned i = 0; i < pfIters.size(); i++) { for (unsigned i = 0; i < pfIters.size(); i++) {
const ProbFormulas& formulas = (*pfIters[i])->arguments(); if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) {
int fIdx = -1;
for (unsigned j = 0; j < formulas.size(); j++) {
if (formulas[j].group() == group) {
if (fIdx != -1) {
// only summout a group of rand vars if they don't
// appear in another position on the factor
return false; return false;
} else {
fIdx = j;
}
}
} }
int fIdx = (*pfIters[i])->indexOfGroup (group);
if ((*pfIters[i])->argument (fIdx).contains ( if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) { (*pfIters[i])->elimLogVars()) == false) {
return false; return false;
} }
vector<unsigned> ranges = (*pfIters[i])->ranges(); if (range == -1) {
vector<unsigned> groups = (*pfIters[i])->getAllGroups(); range = (*pfIters[i])->range (fIdx);
for (unsigned i = 0; i < groups.size(); i++) { } else if ((int)(*pfIters[i])->range (fIdx) != range) {
unordered_map<unsigned, unsigned>::iterator it;
it = groupToRange.find (groups[i]);
if (it == groupToRange.end()) {
groupToRange.insert (make_pair (groups[i], ranges[i]));
} else {
if (it->second != ranges[i]) {
return false; return false;
} }
} }
}
}
return true; return true;
} }
@ -265,8 +345,23 @@ CountingOperator::getLogCost (void)
for (unsigned i = 0; i < counts.size(); i++) { for (unsigned i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range); cost += size * HistogramSet::nrHistograms (counts[i], range);
} }
if ((*pfIter_)->nrArguments() == 1) { unsigned group = (*pfIter_)->argument (fIdx).group();
cost *= 3; // avoid counting conversion in the beginning int lvIndex = Util::vectorIndex (
(*pfIter_)->argument (fIdx).logVars(), X_);
assert (lvIndex != -1);
ParfactorList::iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if (pfIter != pfIter_) {
int fIdx2 = (*pfIter)->indexOfGroup (group);
if (fIdx2 != -1) {
LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]);
if ((*pfIter)->canCountConvert (Y) == false) {
// the real cost should be the cost of grounding Y
cost *= 10.0;
}
}
}
++ pfIter;
} }
return std::log (cost); return std::log (cost);
} }
@ -308,6 +403,7 @@ CountingOperator::getValidOps (ParfactorList& pfList)
if (validOp (*it, candidates[i])) { if (validOp (*it, candidates[i])) {
validOps.push_back (new CountingOperator ( validOps.push_back (new CountingOperator (
it, candidates[i], pfList)); it, candidates[i], pfList));
} else {
} }
} }
++ it; ++ it;
@ -350,12 +446,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
} }
bool countNormalized = g->constr()->isCountNormalized (X); bool countNormalized = g->constr()->isCountNormalized (X);
if (countNormalized) { if (countNormalized) {
unsigned condCount = g->constr()->getConditionalCount (X); return g->canCountConvert (X);
bool cartProduct = g->constr()->isCarteesianProduct (
g->countedLogVars() | X);
if (condCount == 1 || cartProduct == false) {
return false;
}
} }
return true; return true;
} }
@ -438,6 +529,11 @@ GroundOperator::apply (void)
} }
delete pf; delete pf;
} }
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
(*pflIt)->simplifyGrounds();
++ pflIt;
}
} }
@ -625,6 +721,39 @@ FoveSolver::countNormalize (
// Grounds pf on every logical variable of its constraint that is not a
// singleton, builds a ParfactorList from the resulting parfactors and
// multiplies all of the list entries into the first one, whose copy is
// returned.
//
// Counted logical variables are handled with fullExpand; the others by
// grounding the constraint tree and creating one parfactor per
// resulting tree.
//
// NOTE(review): the parfactors allocated here are handed to a local
// ParfactorList; whether that list releases them is not visible from
// this function -- confirm.
Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf)
{
  LogVarSet toGround = pf.constr()->logVarSet();
  toGround -= pf.constr()->singletons();
  Parfactors current = {new Parfactor (pf)};
  for (unsigned i = 0; i < toGround.size(); i++) {
    Parfactors previous = current;
    current.clear();
    for (unsigned j = 0; j < previous.size(); j++) {
      Parfactor* g = previous[j];
      if (g->countedLogVars().contains (toGround[i]) == false) {
        // replace g by one parfactor per grounded constraint tree
        ConstraintTrees cts = g->constr()->ground (toGround[i]);
        for (unsigned k = 0; k < cts.size(); k++) {
          current.push_back (new Parfactor (g, cts[k]));
        }
        delete g;
        continue;
      }
      // counted logical variable: expand in place and keep g
      g->fullExpand (toGround[i]);
      current.push_back (g);
    }
  }
  ParfactorList pfList (current);
  Parfactors grounded (pfList.begin(), pfList.end());
  for (unsigned i = 1; i < grounded.size(); i++) {
    grounded[0]->multiply (*grounded[i]);
  }
  return Parfactor (*grounded[0]);
}
void void
FoveSolver::runSolver (const Grounds& query) FoveSolver::runSolver (const Grounds& query)
{ {
@ -665,6 +794,7 @@ FoveSolver::runSolver (const Grounds& query)
cout << "largest cost = " << std::exp (largestCost_) << endl; cout << "largest cost = " << std::exp (largestCost_) << endl;
cout << endl; cout << endl;
} }
(*pfList_.begin())->simplifyGrounds();
(*pfList_.begin())->reorderAccordingGrounds (query); (*pfList_.begin())->reorderAccordingGrounds (query);
} }
@ -769,8 +899,11 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
while (it != pfList_.end()) { while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) { if ((*it)->containsGround (query[i])) {
found = true; found = true;
std::pair<ConstraintTree*, ConstraintTree*> split = std::pair<ConstraintTree*, ConstraintTree*> split;
(*it)->constr()->split (query[i].args(), query[i].arity()); LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first; ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second; ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt)); newPfs.push_back (new Parfactor (*it, commCt));
@ -826,8 +959,11 @@ FoveSolver::absorve (
} }
g->constr()->moveToTop (formulas[i].logVars()); g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res std::pair<ConstraintTree*, ConstraintTree*> res;
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity()); res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first; ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second; ConstraintTree* exclCt = res.second;

View File

@ -19,14 +19,37 @@ class LiftedOperator
static void printValidOps (ParfactorList&, const Grounds&); static void printValidOps (ParfactorList&, const Grounds&);
static vector<unsigned> getAllGroupss (ParfactorList&);
static vector<ParfactorList::iterator> getParfactorsWithGroup ( static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, unsigned group); ParfactorList&, unsigned group);
}; };
// Lifted operator that multiplies two parfactors of a parfactor list.
class ProductOperator : public LiftedOperator
{
  public:
    // g1 and g2 reference the two operands inside pfList; the list is
    // held by reference.
    ProductOperator (
        ParfactorList::iterator g1,
        ParfactorList::iterator g2,
        ParfactorList& pfList)
        : g1_(g1), g2_(g2), pfList_(pfList)
    {
    }

    // logarithm of the cost of applying this operator
    double getLogCost ();

    // performs the multiplication and updates the parfactor list
    void apply ();

    // operators for the pairs of parfactors that can be multiplied
    static vector<ProductOperator*> getValidOps (ParfactorList&);

    // textual description of this operator
    string toString ();

  private:
    // whether the two parfactors can be multiplied by this operator
    static bool validOp (Parfactor*, Parfactor*);

    ParfactorList::iterator  g1_;
    ParfactorList::iterator  g2_;
    ParfactorList&           pfList_;
};
class SumOutOperator : public LiftedOperator class SumOutOperator : public LiftedOperator
{ {
public: public:
@ -123,6 +146,8 @@ class FoveSolver
static Parfactors countNormalize (Parfactor*, const LogVarSet&); static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private: private:
void runSolver (const Grounds&); void runSolver (const Grounds&);

View File

@ -68,6 +68,7 @@ int createLiftedNetwork (void)
Util::printHeader ("INITIAL PARFACTORS"); Util::printHeader ("INITIAL PARFACTORS");
for (unsigned i = 0; i < parfactors.size(); i++) { for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print(); parfactors[i]->print();
cout << endl;
} }
} }

View File

@ -137,15 +137,21 @@ class Substitution
LogVar newNameFor (LogVar X) const LogVar newNameFor (LogVar X) const
{ {
assert (Util::contains (subs_, X)); unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.find (X);
if (it != subs_.end()) {
return subs_.find (X)->second; return subs_.find (X)->second;
} }
return X;
}
bool containsReplacementFor (LogVar X) const bool containsReplacementFor (LogVar X) const
{ {
return Util::contains (subs_, X); return Util::contains (subs_, X);
} }
unsigned nrReplacements (void) const { return subs_.size(); }
LogVars getDiscardedLogVars (void) const; LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta); friend ostream& operator<< (ostream &os, const Substitution& theta);

View File

@ -193,6 +193,32 @@ Parfactor::multiply (Parfactor& g)
alignAndExponentiate (this, &g); alignAndExponentiate (this, &g);
TFactor<ProbFormula>::multiply (g); TFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true); constr_->join (g.constr(), true);
simplifyGrounds();
assert (constr_->isCarteesianProduct (countedLogVars()));
}
// Tells whether a counting conversion can be performed on X.
//
// All of the following must hold, checked in this order:
//  - nrFormulas (X) is exactly 1;
//  - the formula at indexOfLogVar (X) is not a counting formula;
//  - the constraint is count normalized for X;
//  - the conditional count of X is different from 1;
//  - the constraint is a cartesian product over the already counted
//    logical variables together with X.
bool
Parfactor::canCountConvert (LogVar X)
{
  if (nrFormulas (X) != 1) {
    return false;
  }
  const int fIdx = indexOfLogVar (X);
  return args_[fIdx].isCounting() == false
      && constr_->isCountNormalized (X)
      && constr_->getConditionalCount (X) != 1
      && constr_->isCarteesianProduct (countedLogVars() | X);
}
@ -201,10 +227,9 @@ void
Parfactor::countConvert (LogVar X) Parfactor::countConvert (LogVar X)
{ {
int fIdx = indexOfLogVar (X); int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (constr_->isCountNormalized (X)); assert (constr_->isCountNormalized (X));
assert (constr_->getConditionalCount (X) > 1); assert (constr_->getConditionalCount (X) > 1);
assert (constr_->isCarteesianProduct (countedLogVars() | X)); assert (canCountConvert (X));
unsigned N = constr_->getConditionalCount (X); unsigned N = constr_->getConditionalCount (X);
unsigned R = ranges_[fIdx]; unsigned R = ranges_[fIdx];
@ -245,6 +270,7 @@ Parfactor::countConvert (LogVar X)
++ mapIndexer; ++ mapIndexer;
} }
args_[fIdx].setCountedLogVar (X); args_[fIdx].setCountedLogVar (X);
simplifyCountingFormulas (fIdx);
} }
@ -307,10 +333,9 @@ Parfactor::fullExpand (LogVar X)
unsigned N = constr_->getConditionalCount (X); unsigned N = constr_->getConditionalCount (X);
unsigned R = args_[fIdx].range(); unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R); vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R); vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
assert (ranges_[fIdx] == originHists.size());
vector<unsigned> sumIndexes; vector<unsigned> sumIndexes;
sumIndexes.reserve (N * R); sumIndexes.reserve (N * R);
@ -496,6 +521,20 @@ Parfactor::indexOfGroup (unsigned group) const
// Returns how many of this parfactor's formulas have the given group.
unsigned
Parfactor::nrFormulasWithGroup (unsigned group) const
{
  unsigned matches = 0;
  unsigned pos = 0;
  while (pos < args_.size()) {
    matches += (args_[pos].group() == group) ? 1 : 0;
    ++ pos;
  }
  return matches;
}
vector<unsigned> vector<unsigned>
Parfactor::getAllGroups (void) const Parfactor::getAllGroups (void) const
{ {
@ -672,22 +711,119 @@ Parfactor::expandPotential (
void void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2) Parfactor::simplifyCountingFormulas (int fIdx)
{ {
LogVars X_1, X_2; // check if we can simplify the parfactor
const ProbFormulas& formulas1 = g1->arguments(); for (unsigned i = 0; i < args_.size(); i++) {
const ProbFormulas& formulas2 = g2->arguments(); if ((int)i != fIdx &&
args_[i].isCounting() &&
args_[i].group() == args_[fIdx].group()) {
// if they only differ in the name of the counting log var
if ((args_[i].logVarSet() - args_[i].countedLogVar()) ==
(args_[fIdx].logVarSet()) - args_[fIdx].countedLogVar() &&
ranges_[i] == ranges_[fIdx]) {
simplifyParfactor (fIdx, i);
break;
}
}
}
}
// Merges formulas that share a group and whose logical variables are
// all singletons of the constraint. Each such pair is collapsed into
// one formula with simplifyParfactor.
//
// The singleton set is read once, before any merge is done.
void
Parfactor::simplifyGrounds (void)
{
  LogVarSet singletons = constr_->singletons();
  int first = 0;
  while (first < (int)args_.size() - 1) {
    bool merged = false;
    for (unsigned second = first + 1; second < args_.size(); second++) {
      if (args_[first].group() == args_[second].group() &&
          singletons.contains (args_[first].logVarSet()) &&
          singletons.contains (args_[second].logVarSet())) {
        simplifyParfactor (first, second);
        merged = true;
        break;
      }
    }
    // after a merge the same formula is examined again, since it may
    // match yet another formula further ahead
    if (merged == false) {
      first ++;
    }
  }
}
// Tells whether g1 and g2 can be multiplied: for each parfactor, the
// logical variables that are neither aligned with the other parfactor
// (as given by getAlignLogVars) nor counted must be count normalized
// in that parfactor's constraint.
bool
Parfactor::canMultiply (Parfactor* g1, Parfactor* g2)
{
  std::pair<LogVars, LogVars> aligned = getAlignLogVars (g1, g2);
  LogVarSet alignedLvs1 (aligned.first);
  LogVarSet alignedLvs2 (aligned.second);
  LogVarSet rest1 = g1->logVarSet() - alignedLvs1;
  LogVarSet rest2 = g2->logVarSet() - alignedLvs2;
  rest1 -= g1->countedLogVars();
  rest2 -= g2->countedLogVars();
  if (g1->constr()->isCountNormalized (rest1) == false) {
    return false;
  }
  return g2->constr()->isCountNormalized (rest2);
}
void
Parfactor::simplifyParfactor (unsigned fIdx1, unsigned fIdx2)
{
Params copy = params_;
params_.clear();
StatesIndexer indexer (ranges_);
while (indexer.valid()) {
if (indexer[fIdx1] == indexer[fIdx2]) {
params_.push_back (copy[indexer]);
}
++ indexer;
}
for (unsigned i = 0; i < args_[fIdx2].logVars().size(); i++) {
if (nrFormulas (args_[fIdx2].logVars()[i]) == 1) {
constr_->remove ({ args_[fIdx2].logVars()[i] });
}
}
args_.erase (args_.begin() + fIdx2);
ranges_.erase (ranges_.begin() + fIdx2);
}
std::pair<LogVars, LogVars>
Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
{
g1->simplifyGrounds();
g2->simplifyGrounds();
LogVars Xs_1, Xs_2;
TinySet<unsigned> matchedI;
TinySet<unsigned> matchedJ;
ProbFormulas& formulas1 = g1->arguments();
ProbFormulas& formulas2 = g2->arguments();
for (unsigned i = 0; i < formulas1.size(); i++) { for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) { for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].group() == formulas2[j].group()) { if (formulas1[i].group() == formulas2[j].group() &&
Util::addToVector (X_1, formulas1[i].logVars()); g1->range (i) == g2->range (j) &&
Util::addToVector (X_2, formulas2[j].logVars()); matchedI.contains (i) == false &&
matchedJ.contains (j) == false) {
Util::addToVector (Xs_1, formulas1[i].logVars());
Util::addToVector (Xs_2, formulas2[j].logVars());
matchedI.insert (i);
matchedJ.insert (j);
} }
} }
} }
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1); return make_pair (Xs_1, Xs_2);
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2); }
// counting log vars were already raised on counting conversion
void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
{
alignLogicalVars (g1, g2);
LogVarSet comm = g1->logVarSet() & g2->logVarSet();
LogVarSet Y_1 = g1->logVarSet() - comm;
LogVarSet Y_2 = g2->logVarSet() - comm;
Y_1 -= g1->countedLogVars(); Y_1 -= g1->countedLogVars();
Y_2 -= g2->countedLogVars(); Y_2 -= g2->countedLogVars();
assert (g1->constr()->isCountNormalized (Y_1)); assert (g1->constr()->isCountNormalized (Y_1));
@ -696,30 +832,27 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
LogAware::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g1->params(), 1.0 / condCount2);
LogAware::pow (g2->params(), 1.0 / condCount1); LogAware::pow (g2->params(), 1.0 / condCount1);
//cout << "g1::::::::::::::::::::::::" << endl;
//g1->print();
//cout << "g2::::::::::::::::::::::::" << endl;
//g2->print();
// the alignment should be done in the end or else X_1 and X_2
// will refer to the old log var names on the code above
align (g1, X_1, g2, X_2);
} }
void void
Parfactor::align ( Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
Parfactor* g1, const LogVars& alignLvs1,
Parfactor* g2, const LogVars& alignLvs2)
{ {
LogVar freeLogVar = 0; std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
const LogVars& alignLvs1 = res.first;
const LogVars& alignLvs2 = res.second;
// cout << "ALIGNING :::::::::::::::::" << endl;
// g1->print();
// cout << "AND" << endl;
// g2->print();
// cout << "-> align lvs1 = " << alignLvs1 << endl;
// cout << "-> align lvs2 = " << alignLvs2 << endl;
LogVar freeLogVar (0);
Substitution theta1, theta2; Substitution theta1, theta2;
for (unsigned i = 0; i < alignLvs1.size(); i++) { for (unsigned i = 0; i < alignLvs1.size(); i++) {
bool b1 = theta1.containsReplacementFor (alignLvs1[i]); bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
bool b2 = theta2.containsReplacementFor (alignLvs2[i]); bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
// handle this type of situation:
// g1 = p(X), q(X) ; X in {x1}
// g2 = p(X), q(Y) ; X in {x1}, Y in {x1}
if (b1 == false && b2 == false) { if (b1 == false && b2 == false) {
theta1.add (alignLvs1[i], freeLogVar); theta1.add (alignLvs1[i], freeLogVar);
theta2.add (alignLvs2[i], freeLogVar); theta2.add (alignLvs2[i], freeLogVar);
@ -730,6 +863,7 @@ Parfactor::align (
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i])); theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
} }
} }
const LogVarSet& allLvs1 = g1->logVarSet(); const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) { for (unsigned i = 0; i < allLvs1.size(); i++) {
if (theta1.containsReplacementFor (allLvs1[i]) == false) { if (theta1.containsReplacementFor (allLvs1[i]) == false) {
@ -744,25 +878,33 @@ Parfactor::align (
++ freeLogVar; ++ freeLogVar;
} }
} }
//cout << "theta1: " << theta1 << endl;
//cout << "theta2: " << theta2 << endl; // handle this type of situation:
// g1 = p(X), q(X) ; X in {(p1),(p2)}
// g2 = p(X), q(Y) ; (X,Y) in {(p1,p2),(p2,p1)}
LogVars discardedLvs1 = theta1.getDiscardedLogVars(); LogVars discardedLvs1 = theta1.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs1.size(); i++) { for (unsigned i = 0; i < discardedLvs1.size(); i++) {
unsigned condCount = g1->constr()->getConditionalCount (discardedLvs1[i]); if (g1->constr()->isSingleton (discardedLvs1[i]) &&
cout << "discarding g1" << discardedLvs1[i]; g1->nrFormulas (discardedLvs1[i]) == 1) {
cout << " cc = " << condCount << endl; g1->constr()->remove (discardedLvs1[i]);
LogAware::pow (g1->params(), condCount); } else {
LogVar X_new = ++ g1->constr()->logVarSet().back();
theta1.rename (discardedLvs1[i], X_new);
}
} }
LogVars discardedLvs2 = theta2.getDiscardedLogVars(); LogVars discardedLvs2 = theta2.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs2.size(); i++) { for (unsigned i = 0; i < discardedLvs2.size(); i++) {
unsigned condCount = g2->constr()->getConditionalCount (discardedLvs2[i]); if (g2->constr()->isSingleton (discardedLvs2[i]) &&
cout << "discarding g2" << discardedLvs2[i]; g2->nrFormulas (discardedLvs2[i]) == 1) {
cout << " cc = " << condCount << endl; g2->constr()->remove (discardedLvs2[i]);
//if (condCount != 1) { } else {
// theta2.rename (discardedLvs2[i], freeLogVar); LogVar X_new = ++ g2->constr()->logVarSet().back();
//} theta2.rename (discardedLvs2[i], X_new);
LogAware::pow (g2->params(), condCount);
} }
}
// cout << "theta1: " << theta1 << endl;
// cout << "theta2: " << theta2 << endl;
g1->applySubstitution (theta1); g1->applySubstitution (theta1);
g2->applySubstitution (theta2); g2->applySubstitution (theta2);
} }

View File

@ -50,6 +50,8 @@ class Parfactor : public TFactor<ProbFormula>
void multiply (Parfactor&); void multiply (Parfactor&);
bool canCountConvert (LogVar X);
void countConvert (LogVar); void countConvert (LogVar);
void expand (LogVar, LogVar, LogVar); void expand (LogVar, LogVar, LogVar);
@ -76,6 +78,8 @@ class Parfactor : public TFactor<ProbFormula>
int indexOfGroup (unsigned) const; int indexOfGroup (unsigned) const;
unsigned nrFormulasWithGroup (unsigned) const;
vector<unsigned> getAllGroups (void) const; vector<unsigned> getAllGroups (void) const;
void print (bool = false) const; void print (bool = false) const;
@ -86,16 +90,28 @@ class Parfactor : public TFactor<ProbFormula>
string getLabel (void) const; string getLabel (void) const;
void simplifyGrounds (void);
static bool canMultiply (Parfactor*, Parfactor*);
private: private:
void simplifyCountingFormulas (int fIdx);
void simplifyParfactor (unsigned fIdx1, unsigned fIdx2);
static std::pair<LogVars, LogVars> getAlignLogVars (
Parfactor* g1, Parfactor* g2);
void expandPotential (int fIdx, unsigned newRange, void expandPotential (int fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes); const vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*); static void alignAndExponentiate (Parfactor*, Parfactor*);
static void align ( static void alignLogicalVars (Parfactor*, Parfactor*);
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
ConstraintTree* constr_; ConstraintTree* constr_;
}; };

View File

@ -1,4 +1,10 @@
TODO - Refactor sum out in factor
- add a way to calculate combinations and factorials with large numbers - Add a way to sum out several vars at the same time
- refactor sumOut in parfactor -> is really ugly code - Receive ranges as a constant reference in Indexer
- Indexer: start receiving ranges as constant reference - Merge TinySet and SortedVector classes
- Check if evidence remains in the compressed factor graph
- Consider using hashs instead of vectors of colors to calculate the groups in
counting bp
- use more psize_t instead of unsigned for looping through params
- use more Util::abort and Util::vectorIndex
- LogVar should not cast to int

View File

@ -1,21 +0,0 @@
:- use_module(library(pfl)).
:- set_pfl_flag(solver,fove).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
t(ann).
t(dave).
% p(ann,t).
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
% use standard Prolog queries: provide evidence first.
?- p(ann,t), p(ann,X).
% ?- p(ann,X).