refactor and increase the scope of parfactor multiplication
This commit is contained in:
parent
f4bca3ceea
commit
db0d2c9772
@ -359,10 +359,6 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
|||||||
void
|
void
|
||||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||||
{
|
{
|
||||||
LogVars discardedLvs = theta.getDiscardedLogVars();
|
|
||||||
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
|
||||||
remove(discardedLvs[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||||
}
|
}
|
||||||
|
@ -13,17 +13,23 @@ LiftedOperator::getValidOps (
|
|||||||
const Grounds& query)
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
|
vector<ProductOperator*> multOps;
|
||||||
|
|
||||||
|
multOps = ProductOperator::getValidOps (pfList);
|
||||||
|
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||||
|
|
||||||
|
if (Globals::verbosity > 1 || multOps.empty()) {
|
||||||
vector<SumOutOperator*> sumOutOps;
|
vector<SumOutOperator*> sumOutOps;
|
||||||
vector<CountingOperator*> countOps;
|
vector<CountingOperator*> countOps;
|
||||||
vector<GroundOperator*> groundOps;
|
vector<GroundOperator*> groundOps;
|
||||||
|
|
||||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||||
countOps = CountingOperator::getValidOps (pfList);
|
countOps = CountingOperator::getValidOps (pfList);
|
||||||
groundOps = GroundOperator::getValidOps (pfList);
|
groundOps = GroundOperator::getValidOps (pfList);
|
||||||
|
|
||||||
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
||||||
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
||||||
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
||||||
|
}
|
||||||
|
|
||||||
return validOps;
|
return validOps;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,14 +50,6 @@ LiftedOperator::printValidOps (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<unsigned>
|
|
||||||
LiftedOperator::getAllGroupss (ParfactorList& )
|
|
||||||
{
|
|
||||||
return { };
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<ParfactorList::iterator>
|
vector<ParfactorList::iterator>
|
||||||
LiftedOperator::getParfactorsWithGroup (
|
LiftedOperator::getParfactorsWithGroup (
|
||||||
ParfactorList& pfList, unsigned group)
|
ParfactorList& pfList, unsigned group)
|
||||||
@ -69,6 +67,105 @@ LiftedOperator::getParfactorsWithGroup (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
ProductOperator::getLogCost (void)
|
||||||
|
{
|
||||||
|
return std::log (0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ProductOperator::apply (void)
|
||||||
|
{
|
||||||
|
Parfactor* g1 = *g1_;
|
||||||
|
Parfactor* g2 = *g2_;
|
||||||
|
g1->multiply (*g2);
|
||||||
|
pfList_.remove (g1_);
|
||||||
|
pfList_.removeAndDelete (g2_);
|
||||||
|
pfList_.addShattered (g1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Enumerates the pairwise multiplications that are currently possible.
//
// Each parfactor is tested, with validOp(), against the parfactors that
// come after it in pfList. A parfactor that already takes part in a
// proposed multiplication is not paired again. When Globals::verbosity
// is below 2 the search stops at the first valid pair.
//
// The returned operators are allocated with new and are not freed here.
vector<ProductOperator*>
ProductOperator::getValidOps (ParfactorList& pfList)
{
  vector<ProductOperator*> validOps;
  // an empty list has no last element, and decrementing end() on it
  // would be undefined behaviour
  if (pfList.begin() == pfList.end()) {
    return validOps;
  }
  ParfactorList::iterator last = -- pfList.end();
  set<Parfactor*> paired;
  for (ParfactorList::iterator it1 = pfList.begin(); it1 != last; ++ it1) {
    if (Util::contains (paired, *it1)) {
      continue;
    }
    ParfactorList::iterator it2 = it1;
    for (++ it2; it2 != pfList.end(); ++ it2) {
      if (Util::contains (paired, *it2)) {
        continue;
      }
      if (validOp (*it1, *it2) == false) {
        continue;
      }
      paired.insert (*it1);
      paired.insert (*it2);
      validOps.push_back (new ProductOperator (it1, it2, pfList));
      if (Globals::verbosity < 2) {
        // one multiplication is enough unless all operators are listed
        return validOps;
      }
      // it1 is now taken, move on to the next candidate
      break;
    }
  }
  return validOps;
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Builds a one-line description of this operator: the groups of both
// parfactors followed by the (exponentiated) cost.
string
ProductOperator::toString (void)
{
  stringstream out;
  out << "just multiplicate " << (*g1_)->getAllGroups();
  out << " x " << (*g2_)->getAllGroups();
  out << " [cost=" << std::exp (getLogCost()) << "]" << endl;
  return out.str();
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
TinySet<unsigned> g1_gs (g1->getAllGroups());
|
||||||
|
TinySet<unsigned> g2_gs (g2->getAllGroups());
|
||||||
|
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
|
||||||
|
TinySet<unsigned> intersect = g1_gs & g2_gs;
|
||||||
|
for (unsigned i = 0; i < intersect.size(); i++) {
|
||||||
|
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
|
||||||
|
g2->nrFormulasWithGroup (intersect[i]) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int idx1 = g1->indexOfGroup (intersect[i]);
|
||||||
|
int idx2 = g2->indexOfGroup (intersect[i]);
|
||||||
|
if (g1->range (idx1) != g2->range (idx2)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Parfactor::canMultiply (g1, g2);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
SumOutOperator::getLogCost (void)
|
SumOutOperator::getLogCost (void)
|
||||||
{
|
{
|
||||||
@ -84,7 +181,8 @@ SumOutOperator::getLogCost (void)
|
|||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
if (nrProdFactors == 1) {
|
if (nrProdFactors == 1) {
|
||||||
return std::log (0.0); // best possible case
|
// best possible case
|
||||||
|
return std::log (0.0);
|
||||||
}
|
}
|
||||||
double cost = 1.0;
|
double cost = 1.0;
|
||||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||||
@ -190,40 +288,22 @@ SumOutOperator::validOp (
|
|||||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
int range = -1;
|
||||||
unordered_map<unsigned, unsigned> groupToRange;
|
|
||||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||||
const ProbFormulas& formulas = (*pfIters[i])->arguments();
|
if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) {
|
||||||
int fIdx = -1;
|
|
||||||
for (unsigned j = 0; j < formulas.size(); j++) {
|
|
||||||
if (formulas[j].group() == group) {
|
|
||||||
if (fIdx != -1) {
|
|
||||||
// only summout a group of rand vars if they don't
|
|
||||||
// appear in another position on the factor
|
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
fIdx = j;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||||
if ((*pfIters[i])->argument (fIdx).contains (
|
if ((*pfIters[i])->argument (fIdx).contains (
|
||||||
(*pfIters[i])->elimLogVars()) == false) {
|
(*pfIters[i])->elimLogVars()) == false) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
if (range == -1) {
|
||||||
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
|
range = (*pfIters[i])->range (fIdx);
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
} else if ((int)(*pfIters[i])->range (fIdx) != range) {
|
||||||
unordered_map<unsigned, unsigned>::iterator it;
|
|
||||||
it = groupToRange.find (groups[i]);
|
|
||||||
if (it == groupToRange.end()) {
|
|
||||||
groupToRange.insert (make_pair (groups[i], ranges[i]));
|
|
||||||
} else {
|
|
||||||
if (it->second != ranges[i]) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -265,8 +345,23 @@ CountingOperator::getLogCost (void)
|
|||||||
for (unsigned i = 0; i < counts.size(); i++) {
|
for (unsigned i = 0; i < counts.size(); i++) {
|
||||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||||
}
|
}
|
||||||
if ((*pfIter_)->nrArguments() == 1) {
|
unsigned group = (*pfIter_)->argument (fIdx).group();
|
||||||
cost *= 3; // avoid counting conversion in the beginning
|
int lvIndex = Util::vectorIndex (
|
||||||
|
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||||
|
assert (lvIndex != -1);
|
||||||
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
if (pfIter != pfIter_) {
|
||||||
|
int fIdx2 = (*pfIter)->indexOfGroup (group);
|
||||||
|
if (fIdx2 != -1) {
|
||||||
|
LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]);
|
||||||
|
if ((*pfIter)->canCountConvert (Y) == false) {
|
||||||
|
// the real cost should be the cost of grounding Y
|
||||||
|
cost *= 10.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ pfIter;
|
||||||
}
|
}
|
||||||
return std::log (cost);
|
return std::log (cost);
|
||||||
}
|
}
|
||||||
@ -308,6 +403,7 @@ CountingOperator::getValidOps (ParfactorList& pfList)
|
|||||||
if (validOp (*it, candidates[i])) {
|
if (validOp (*it, candidates[i])) {
|
||||||
validOps.push_back (new CountingOperator (
|
validOps.push_back (new CountingOperator (
|
||||||
it, candidates[i], pfList));
|
it, candidates[i], pfList));
|
||||||
|
} else {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
++ it;
|
++ it;
|
||||||
@ -350,12 +446,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
|||||||
}
|
}
|
||||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||||
if (countNormalized) {
|
if (countNormalized) {
|
||||||
unsigned condCount = g->constr()->getConditionalCount (X);
|
return g->canCountConvert (X);
|
||||||
bool cartProduct = g->constr()->isCarteesianProduct (
|
|
||||||
g->countedLogVars() | X);
|
|
||||||
if (condCount == 1 || cartProduct == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -438,6 +529,11 @@ GroundOperator::apply (void)
|
|||||||
}
|
}
|
||||||
delete pf;
|
delete pf;
|
||||||
}
|
}
|
||||||
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
|
while (pflIt != pfList_.end()) {
|
||||||
|
(*pflIt)->simplifyGrounds();
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -625,6 +721,39 @@ FoveSolver::countNormalize (
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Grounds pf and returns the product of the resulting parfactors.
//
// Every logical variable of pf that is not a singleton is processed in
// turn: a counted variable is fully expanded in place, any other variable
// is grounded through the constraint tree, yielding one new parfactor per
// resulting tree. The parfactors are then handed to a ParfactorList
// (presumably shattering them -- confirm against ParfactorList) and all
// of them are multiplied into the first one, a copy of which is returned.
Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf)
{
  LogVarSet toGround = pf.constr()->logVarSet();
  toGround -= pf.constr()->singletons();
  Parfactors current = {new Parfactor (pf)};
  for (unsigned i = 0; i < toGround.size(); i++) {
    Parfactors previous = current;
    current.clear();
    for (unsigned j = 0; j < previous.size(); j++) {
      if (previous[j]->countedLogVars().contains (toGround[i])) {
        previous[j]->fullExpand (toGround[i]);
        current.push_back (previous[j]);
        continue;
      }
      ConstraintTrees trees = previous[j]->constr()->ground (toGround[i]);
      for (unsigned k = 0; k < trees.size(); k++) {
        current.push_back (new Parfactor (previous[j], trees[k]));
      }
      // replaced by the parfactors created above
      delete previous[j];
    }
  }
  ParfactorList pfList (current);
  Parfactors groundPfs (pfList.begin(),pfList.end());
  for (unsigned i = 1; i < groundPfs.size(); i++) {
    groundPfs[0]->multiply (*groundPfs[i]);
  }
  return Parfactor (*groundPfs[0]);
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
FoveSolver::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
@ -665,6 +794,7 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
(*pfList_.begin())->simplifyGrounds();
|
||||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -769,8 +899,11 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|||||||
while (it != pfList_.end()) {
|
while (it != pfList_.end()) {
|
||||||
if ((*it)->containsGround (query[i])) {
|
if ((*it)->containsGround (query[i])) {
|
||||||
found = true;
|
found = true;
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
LogVars queryLvs (
|
||||||
|
(*it)->constr()->logVars().begin(),
|
||||||
|
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||||
|
split = (*it)->constr()->split (query[i].args());
|
||||||
ConstraintTree* commCt = split.first;
|
ConstraintTree* commCt = split.first;
|
||||||
ConstraintTree* exclCt = split.second;
|
ConstraintTree* exclCt = split.second;
|
||||||
newPfs.push_back (new Parfactor (*it, commCt));
|
newPfs.push_back (new Parfactor (*it, commCt));
|
||||||
@ -826,8 +959,11 @@ FoveSolver::absorve (
|
|||||||
}
|
}
|
||||||
|
|
||||||
g->constr()->moveToTop (formulas[i].logVars());
|
g->constr()->moveToTop (formulas[i].logVars());
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
res = g->constr()->split (
|
||||||
|
formulas[i].logVars(),
|
||||||
|
&(obsFormula.constr()),
|
||||||
|
obsFormula.constr().logVars());
|
||||||
ConstraintTree* commCt = res.first;
|
ConstraintTree* commCt = res.first;
|
||||||
ConstraintTree* exclCt = res.second;
|
ConstraintTree* exclCt = res.second;
|
||||||
|
|
||||||
|
@ -19,14 +19,37 @@ class LiftedOperator
|
|||||||
|
|
||||||
static void printValidOps (ParfactorList&, const Grounds&);
|
static void printValidOps (ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
static vector<unsigned> getAllGroupss (ParfactorList&);
|
|
||||||
|
|
||||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||||
ParfactorList&, unsigned group);
|
ParfactorList&, unsigned group);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ProductOperator : public LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ProductOperator (
|
||||||
|
ParfactorList::iterator g1, ParfactorList::iterator g2,
|
||||||
|
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||||
|
|
||||||
|
double getLogCost (void);
|
||||||
|
|
||||||
|
void apply (void);
|
||||||
|
|
||||||
|
static vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
|
string toString (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static bool validOp (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
ParfactorList::iterator g1_;
|
||||||
|
ParfactorList::iterator g2_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SumOutOperator : public LiftedOperator
|
class SumOutOperator : public LiftedOperator
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -123,6 +146,8 @@ class FoveSolver
|
|||||||
|
|
||||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||||
|
|
||||||
|
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (const Grounds&);
|
void runSolver (const Grounds&);
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ int createLiftedNetwork (void)
|
|||||||
Util::printHeader ("INITIAL PARFACTORS");
|
Util::printHeader ("INITIAL PARFACTORS");
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
parfactors[i]->print();
|
parfactors[i]->print();
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,15 +137,21 @@ class Substitution
|
|||||||
|
|
||||||
// Returns the logical variable that replaces X under this substitution,
// or X itself when the substitution holds no replacement for it.
LogVar newNameFor (LogVar X) const
{
  unordered_map<LogVar, LogVar>::const_iterator it = subs_.find (X);
  if (it != subs_.end()) {
    // reuse the iterator instead of looking X up a second time
    return it->second;
  }
  return X;
}
|
||||||
|
|
||||||
// Tells whether this substitution has an entry for X.
bool containsReplacementFor (LogVar X) const
{
  const bool hasEntry = Util::contains (subs_, X);
  return hasEntry;
}
|
||||||
|
|
||||||
|
// Number of entries stored in this substitution.
unsigned nrReplacements (void) const
{
  return subs_.size();
}
|
||||||
|
|
||||||
LogVars getDiscardedLogVars (void) const;
|
LogVars getDiscardedLogVars (void) const;
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||||
|
@ -193,6 +193,32 @@ Parfactor::multiply (Parfactor& g)
|
|||||||
alignAndExponentiate (this, &g);
|
alignAndExponentiate (this, &g);
|
||||||
TFactor<ProbFormula>::multiply (g);
|
TFactor<ProbFormula>::multiply (g);
|
||||||
constr_->join (g.constr(), true);
|
constr_->join (g.constr(), true);
|
||||||
|
simplifyGrounds();
|
||||||
|
assert (constr_->isCarteesianProduct (countedLogVars()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::canCountConvert (LogVar X)
|
||||||
|
{
|
||||||
|
if (nrFormulas (X) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int fIdx = indexOfLogVar (X);
|
||||||
|
if (args_[fIdx].isCounting()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->isCountNormalized (X) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->getConditionalCount (X) == 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->isCarteesianProduct (countedLogVars() | X) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -201,10 +227,9 @@ void
|
|||||||
Parfactor::countConvert (LogVar X)
|
Parfactor::countConvert (LogVar X)
|
||||||
{
|
{
|
||||||
int fIdx = indexOfLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
|
||||||
assert (constr_->isCountNormalized (X));
|
assert (constr_->isCountNormalized (X));
|
||||||
assert (constr_->getConditionalCount (X) > 1);
|
assert (constr_->getConditionalCount (X) > 1);
|
||||||
assert (constr_->isCarteesianProduct (countedLogVars() | X));
|
assert (canCountConvert (X));
|
||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
unsigned R = ranges_[fIdx];
|
unsigned R = ranges_[fIdx];
|
||||||
@ -245,6 +270,7 @@ Parfactor::countConvert (LogVar X)
|
|||||||
++ mapIndexer;
|
++ mapIndexer;
|
||||||
}
|
}
|
||||||
args_[fIdx].setCountedLogVar (X);
|
args_[fIdx].setCountedLogVar (X);
|
||||||
|
simplifyCountingFormulas (fIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -307,10 +333,9 @@ Parfactor::fullExpand (LogVar X)
|
|||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
unsigned R = args_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
|
|
||||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||||
|
assert (ranges_[fIdx] == originHists.size());
|
||||||
vector<unsigned> sumIndexes;
|
vector<unsigned> sumIndexes;
|
||||||
sumIndexes.reserve (N * R);
|
sumIndexes.reserve (N * R);
|
||||||
|
|
||||||
@ -496,6 +521,20 @@ Parfactor::indexOfGroup (unsigned group) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
Parfactor::nrFormulasWithGroup (unsigned group) const
|
||||||
|
{
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].group() == group) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<unsigned>
|
vector<unsigned>
|
||||||
Parfactor::getAllGroups (void) const
|
Parfactor::getAllGroups (void) const
|
||||||
{
|
{
|
||||||
@ -672,22 +711,119 @@ Parfactor::expandPotential (
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
Parfactor::simplifyCountingFormulas (int fIdx)
|
||||||
{
|
{
|
||||||
LogVars X_1, X_2;
|
// check if we can simplify the parfactor
|
||||||
const ProbFormulas& formulas1 = g1->arguments();
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
const ProbFormulas& formulas2 = g2->arguments();
|
if ((int)i != fIdx &&
|
||||||
|
args_[i].isCounting() &&
|
||||||
|
args_[i].group() == args_[fIdx].group()) {
|
||||||
|
// if they only differ in the name of the counting log var
|
||||||
|
if ((args_[i].logVarSet() - args_[i].countedLogVar()) ==
|
||||||
|
(args_[fIdx].logVarSet()) - args_[fIdx].countedLogVar() &&
|
||||||
|
ranges_[i] == ranges_[fIdx]) {
|
||||||
|
simplifyParfactor (fIdx, i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::simplifyGrounds (void)
|
||||||
|
{
|
||||||
|
LogVarSet singletons = constr_->singletons();
|
||||||
|
for (int i = 0; i < (int)args_.size() - 1; i++) {
|
||||||
|
for (unsigned j = i + 1; j < args_.size(); j++) {
|
||||||
|
if (args_[i].group() == args_[j].group() &&
|
||||||
|
singletons.contains (args_[i].logVarSet()) &&
|
||||||
|
singletons.contains (args_[j].logVarSet())) {
|
||||||
|
simplifyParfactor (i, j);
|
||||||
|
i --;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::canMultiply (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||||
|
LogVarSet Xs_1 (res.first);
|
||||||
|
LogVarSet Xs_2 (res.second);
|
||||||
|
LogVarSet Y_1 = g1->logVarSet() - Xs_1;
|
||||||
|
LogVarSet Y_2 = g2->logVarSet() - Xs_2;
|
||||||
|
Y_1 -= g1->countedLogVars();
|
||||||
|
Y_2 -= g2->countedLogVars();
|
||||||
|
return g1->constr()->isCountNormalized (Y_1) &&
|
||||||
|
g2->constr()->isCountNormalized (Y_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::simplifyParfactor (unsigned fIdx1, unsigned fIdx2)
|
||||||
|
{
|
||||||
|
Params copy = params_;
|
||||||
|
params_.clear();
|
||||||
|
StatesIndexer indexer (ranges_);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
if (indexer[fIdx1] == indexer[fIdx2]) {
|
||||||
|
params_.push_back (copy[indexer]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < args_[fIdx2].logVars().size(); i++) {
|
||||||
|
if (nrFormulas (args_[fIdx2].logVars()[i]) == 1) {
|
||||||
|
constr_->remove ({ args_[fIdx2].logVars()[i] });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args_.erase (args_.begin() + fIdx2);
|
||||||
|
ranges_.erase (ranges_.begin() + fIdx2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<LogVars, LogVars>
|
||||||
|
Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
g1->simplifyGrounds();
|
||||||
|
g2->simplifyGrounds();
|
||||||
|
LogVars Xs_1, Xs_2;
|
||||||
|
TinySet<unsigned> matchedI;
|
||||||
|
TinySet<unsigned> matchedJ;
|
||||||
|
ProbFormulas& formulas1 = g1->arguments();
|
||||||
|
ProbFormulas& formulas2 = g2->arguments();
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||||
if (formulas1[i].group() == formulas2[j].group()) {
|
if (formulas1[i].group() == formulas2[j].group() &&
|
||||||
Util::addToVector (X_1, formulas1[i].logVars());
|
g1->range (i) == g2->range (j) &&
|
||||||
Util::addToVector (X_2, formulas2[j].logVars());
|
matchedI.contains (i) == false &&
|
||||||
|
matchedJ.contains (j) == false) {
|
||||||
|
Util::addToVector (Xs_1, formulas1[i].logVars());
|
||||||
|
Util::addToVector (Xs_2, formulas2[j].logVars());
|
||||||
|
matchedI.insert (i);
|
||||||
|
matchedJ.insert (j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
return make_pair (Xs_1, Xs_2);
|
||||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
}
|
||||||
// counting log vars were already raised on counting conversion
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
alignLogicalVars (g1, g2);
|
||||||
|
LogVarSet comm = g1->logVarSet() & g2->logVarSet();
|
||||||
|
LogVarSet Y_1 = g1->logVarSet() - comm;
|
||||||
|
LogVarSet Y_2 = g2->logVarSet() - comm;
|
||||||
Y_1 -= g1->countedLogVars();
|
Y_1 -= g1->countedLogVars();
|
||||||
Y_2 -= g2->countedLogVars();
|
Y_2 -= g2->countedLogVars();
|
||||||
assert (g1->constr()->isCountNormalized (Y_1));
|
assert (g1->constr()->isCountNormalized (Y_1));
|
||||||
@ -696,30 +832,27 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
|||||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||||
//cout << "g1::::::::::::::::::::::::" << endl;
|
|
||||||
//g1->print();
|
|
||||||
//cout << "g2::::::::::::::::::::::::" << endl;
|
|
||||||
//g2->print();
|
|
||||||
// the alignment should be done in the end or else X_1 and X_2
|
|
||||||
// will refer to the old log var names on the code above
|
|
||||||
align (g1, X_1, g2, X_2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::align (
|
Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||||
Parfactor* g1, const LogVars& alignLvs1,
|
|
||||||
Parfactor* g2, const LogVars& alignLvs2)
|
|
||||||
{
|
{
|
||||||
LogVar freeLogVar = 0;
|
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||||
|
const LogVars& alignLvs1 = res.first;
|
||||||
|
const LogVars& alignLvs2 = res.second;
|
||||||
|
// cout << "ALIGNING :::::::::::::::::" << endl;
|
||||||
|
// g1->print();
|
||||||
|
// cout << "AND" << endl;
|
||||||
|
// g2->print();
|
||||||
|
// cout << "-> align lvs1 = " << alignLvs1 << endl;
|
||||||
|
// cout << "-> align lvs2 = " << alignLvs2 << endl;
|
||||||
|
LogVar freeLogVar (0);
|
||||||
Substitution theta1, theta2;
|
Substitution theta1, theta2;
|
||||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||||
bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
|
bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
|
||||||
bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
|
bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
|
||||||
// handle this type of situation:
|
|
||||||
// g1 = p(X), q(X) ; X in {x1}
|
|
||||||
// g2 = p(X), q(Y) ; X in {x1}, Y in {x1}
|
|
||||||
if (b1 == false && b2 == false) {
|
if (b1 == false && b2 == false) {
|
||||||
theta1.add (alignLvs1[i], freeLogVar);
|
theta1.add (alignLvs1[i], freeLogVar);
|
||||||
theta2.add (alignLvs2[i], freeLogVar);
|
theta2.add (alignLvs2[i], freeLogVar);
|
||||||
@ -730,6 +863,7 @@ Parfactor::align (
|
|||||||
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
|
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||||
if (theta1.containsReplacementFor (allLvs1[i]) == false) {
|
if (theta1.containsReplacementFor (allLvs1[i]) == false) {
|
||||||
@ -744,25 +878,33 @@ Parfactor::align (
|
|||||||
++ freeLogVar;
|
++ freeLogVar;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//cout << "theta1: " << theta1 << endl;
|
|
||||||
//cout << "theta2: " << theta2 << endl;
|
// handle this type of situation:
|
||||||
|
// g1 = p(X), q(X) ; X in {(p1),(p2)}
|
||||||
|
// g2 = p(X), q(Y) ; (X,Y) in {(p1,p2),(p2,p1)}
|
||||||
LogVars discardedLvs1 = theta1.getDiscardedLogVars();
|
LogVars discardedLvs1 = theta1.getDiscardedLogVars();
|
||||||
for (unsigned i = 0; i < discardedLvs1.size(); i++) {
|
for (unsigned i = 0; i < discardedLvs1.size(); i++) {
|
||||||
unsigned condCount = g1->constr()->getConditionalCount (discardedLvs1[i]);
|
if (g1->constr()->isSingleton (discardedLvs1[i]) &&
|
||||||
cout << "discarding g1" << discardedLvs1[i];
|
g1->nrFormulas (discardedLvs1[i]) == 1) {
|
||||||
cout << " cc = " << condCount << endl;
|
g1->constr()->remove (discardedLvs1[i]);
|
||||||
LogAware::pow (g1->params(), condCount);
|
} else {
|
||||||
|
LogVar X_new = ++ g1->constr()->logVarSet().back();
|
||||||
|
theta1.rename (discardedLvs1[i], X_new);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LogVars discardedLvs2 = theta2.getDiscardedLogVars();
|
LogVars discardedLvs2 = theta2.getDiscardedLogVars();
|
||||||
for (unsigned i = 0; i < discardedLvs2.size(); i++) {
|
for (unsigned i = 0; i < discardedLvs2.size(); i++) {
|
||||||
unsigned condCount = g2->constr()->getConditionalCount (discardedLvs2[i]);
|
if (g2->constr()->isSingleton (discardedLvs2[i]) &&
|
||||||
cout << "discarding g2" << discardedLvs2[i];
|
g2->nrFormulas (discardedLvs2[i]) == 1) {
|
||||||
cout << " cc = " << condCount << endl;
|
g2->constr()->remove (discardedLvs2[i]);
|
||||||
//if (condCount != 1) {
|
} else {
|
||||||
// theta2.rename (discardedLvs2[i], freeLogVar);
|
LogVar X_new = ++ g2->constr()->logVarSet().back();
|
||||||
//}
|
theta2.rename (discardedLvs2[i], X_new);
|
||||||
LogAware::pow (g2->params(), condCount);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cout << "theta1: " << theta1 << endl;
|
||||||
|
// cout << "theta2: " << theta2 << endl;
|
||||||
g1->applySubstitution (theta1);
|
g1->applySubstitution (theta1);
|
||||||
g2->applySubstitution (theta2);
|
g2->applySubstitution (theta2);
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,8 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
void multiply (Parfactor&);
|
void multiply (Parfactor&);
|
||||||
|
|
||||||
|
bool canCountConvert (LogVar X);
|
||||||
|
|
||||||
void countConvert (LogVar);
|
void countConvert (LogVar);
|
||||||
|
|
||||||
void expand (LogVar, LogVar, LogVar);
|
void expand (LogVar, LogVar, LogVar);
|
||||||
@ -76,6 +78,8 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
int indexOfGroup (unsigned) const;
|
int indexOfGroup (unsigned) const;
|
||||||
|
|
||||||
|
unsigned nrFormulasWithGroup (unsigned) const;
|
||||||
|
|
||||||
vector<unsigned> getAllGroups (void) const;
|
vector<unsigned> getAllGroups (void) const;
|
||||||
|
|
||||||
void print (bool = false) const;
|
void print (bool = false) const;
|
||||||
@ -86,16 +90,28 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
string getLabel (void) const;
|
string getLabel (void) const;
|
||||||
|
|
||||||
|
void simplifyGrounds (void);
|
||||||
|
|
||||||
|
static bool canMultiply (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
void simplifyCountingFormulas (int fIdx);
|
||||||
|
|
||||||
|
void simplifyParfactor (unsigned fIdx1, unsigned fIdx2);
|
||||||
|
|
||||||
|
static std::pair<LogVars, LogVars> getAlignLogVars (
|
||||||
|
Parfactor* g1, Parfactor* g2);
|
||||||
|
|
||||||
void expandPotential (int fIdx, unsigned newRange,
|
void expandPotential (int fIdx, unsigned newRange,
|
||||||
const vector<unsigned>& sumIndexes);
|
const vector<unsigned>& sumIndexes);
|
||||||
|
|
||||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
static void align (
|
static void alignLogicalVars (Parfactor*, Parfactor*);
|
||||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
|
||||||
|
|
||||||
ConstraintTree* constr_;
|
ConstraintTree* constr_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
TODO
|
- Refactor sum out in factor
|
||||||
- add a way to calculate combinations and factorials with large numbers
|
- Add a way to sum out several vars at the same time
|
||||||
- refactor sumOut in parfactor -> is really ugly code
|
- Receive ranges as a constant reference in Indexer
|
||||||
- Indexer: start receiving ranges as constant reference
|
- Merge TinySet and SortedVector classes
|
||||||
|
- Check if evidence remains in the compressed factor graph
|
||||||
|
- Consider using hashes instead of vectors of colors to calculate the groups in
|
||||||
|
counting bp
|
||||||
|
- use more psize_t instead of unsigned for looping through params
|
||||||
|
- use more Util::abort and Util::vectorIndex
|
||||||
|
- LogVar should not cast to int
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
:- set_pfl_flag(solver,fove).
|
|
||||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
|
||||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
|
||||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
|
|
||||||
|
|
||||||
|
|
||||||
t(ann).
|
|
||||||
t(dave).
|
|
||||||
|
|
||||||
% p(ann,t).
|
|
||||||
|
|
||||||
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
|
||||||
|
|
||||||
% use standard Prolog queries: provide evidence first.
|
|
||||||
|
|
||||||
?- p(ann,t), p(ann,X).
|
|
||||||
% ?- p(ann,X).
|
|
||||||
|
|
Reference in New Issue
Block a user