refactor and increase the scope of parfactor multiplication
This commit is contained in:
parent
f4bca3ceea
commit
db0d2c9772
@ -359,10 +359,6 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
||||
void
|
||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||
remove(discardedLvs[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||
}
|
||||
|
@ -13,17 +13,23 @@ LiftedOperator::getValidOps (
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
vector<CountingOperator*> countOps;
|
||||
vector<GroundOperator*> groundOps;
|
||||
vector<ProductOperator*> multOps;
|
||||
|
||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||
countOps = CountingOperator::getValidOps (pfList);
|
||||
groundOps = GroundOperator::getValidOps (pfList);
|
||||
multOps = ProductOperator::getValidOps (pfList);
|
||||
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||
|
||||
if (Globals::verbosity > 1 || multOps.empty()) {
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
vector<CountingOperator*> countOps;
|
||||
vector<GroundOperator*> groundOps;
|
||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||
countOps = CountingOperator::getValidOps (pfList);
|
||||
groundOps = GroundOperator::getValidOps (pfList);
|
||||
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
||||
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
||||
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
||||
}
|
||||
|
||||
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
||||
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
||||
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
||||
return validOps;
|
||||
}
|
||||
|
||||
@ -44,17 +50,9 @@ LiftedOperator::printValidOps (
|
||||
|
||||
|
||||
|
||||
// Placeholder: ignores the parfactor list and always yields an empty vector.
vector<unsigned>
LiftedOperator::getAllGroupss (ParfactorList& )
{
  return vector<unsigned>();
}
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group)
|
||||
ParfactorList& pfList, unsigned group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
@ -69,6 +67,105 @@ LiftedOperator::getParfactorsWithGroup (
|
||||
|
||||
|
||||
|
||||
double
|
||||
ProductOperator::getLogCost (void)
|
||||
{
|
||||
return std::log (0.0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ProductOperator::apply (void)
|
||||
{
|
||||
Parfactor* g1 = *g1_;
|
||||
Parfactor* g2 = *g2_;
|
||||
g1->multiply (*g2);
|
||||
pfList_.remove (g1_);
|
||||
pfList_.removeAndDelete (g2_);
|
||||
pfList_.addShattered (g1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Collects the product operators that can be applied to pairs of
// parfactors in pfList.
//
// Every parfactor takes part in at most one returned operator: once a
// pair is accepted, both members are remembered and skipped afterwards.
// When Globals::verbosity is below 2 the search stops at the first valid
// pair, so the result then holds at most one operator.
//
// The returned operators are allocated with new; the caller receives the
// raw pointers.
vector<ProductOperator*>
ProductOperator::getValidOps (ParfactorList& pfList)
{
  vector<ProductOperator*> validOps;
  // an empty list has no penultimate element: decrementing end() below
  // would be undefined, so bail out early
  if (pfList.begin() == pfList.end()) {
    return validOps;
  }
  ParfactorList::iterator penultimate = -- pfList.end();
  // parfactors already assigned to some product operator
  set<Parfactor*> pfs;
  for (ParfactorList::iterator it1 = pfList.begin();
       it1 != penultimate; ++ it1) {
    if (Util::contains (pfs, *it1)) {
      continue;
    }
    ParfactorList::iterator it2 = it1;
    for (++ it2; it2 != pfList.end(); ++ it2) {
      if (Util::contains (pfs, *it2)) {
        continue;
      }
      if (validOp (*it1, *it2) == false) {
        continue;
      }
      pfs.insert (*it1);
      pfs.insert (*it2);
      validOps.push_back (new ProductOperator (it1, it2, pfList));
      if (Globals::verbosity < 2) {
        // one operator is enough when we are not listing them all
        return validOps;
      }
      // *it1 is now taken, move on to the next candidate for it1
      break;
    }
  }
  return validOps;
}
|
||||
|
||||
|
||||
|
||||
// Builds a one-line description of this operator: the groups of both
// parfactors followed by the cost (exp of the log cost).
string
ProductOperator::toString (void)
{
  stringstream out;
  out << "just multiplicate " << (*g1_)->getAllGroups();
  out << " x " << (*g2_)->getAllGroups();
  out << " [cost=" << std::exp (getLogCost()) << "]" << endl;
  return out.str();
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
TinySet<unsigned> g1_gs (g1->getAllGroups());
|
||||
TinySet<unsigned> g2_gs (g2->getAllGroups());
|
||||
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
|
||||
TinySet<unsigned> intersect = g1_gs & g2_gs;
|
||||
for (unsigned i = 0; i < intersect.size(); i++) {
|
||||
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
|
||||
g2->nrFormulasWithGroup (intersect[i]) != 1) {
|
||||
return false;
|
||||
}
|
||||
int idx1 = g1->indexOfGroup (intersect[i]);
|
||||
int idx2 = g2->indexOfGroup (intersect[i]);
|
||||
if (g1->range (idx1) != g2->range (idx2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return Parfactor::canMultiply (g1, g2);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
{
|
||||
@ -84,7 +181,8 @@ SumOutOperator::getLogCost (void)
|
||||
++ pfIter;
|
||||
}
|
||||
if (nrProdFactors == 1) {
|
||||
return std::log (0.0); // best possible case
|
||||
// best possible case
|
||||
return std::log (0.0);
|
||||
}
|
||||
double cost = 1.0;
|
||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||
@ -190,38 +288,20 @@ SumOutOperator::validOp (
|
||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
unordered_map<unsigned, unsigned> groupToRange;
|
||||
int range = -1;
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
const ProbFormulas& formulas = (*pfIters[i])->arguments();
|
||||
int fIdx = -1;
|
||||
for (unsigned j = 0; j < formulas.size(); j++) {
|
||||
if (formulas[j].group() == group) {
|
||||
if (fIdx != -1) {
|
||||
// only summout a group of rand vars if they don't
|
||||
// appear in another position on the factor
|
||||
return false;
|
||||
} else {
|
||||
fIdx = j;
|
||||
}
|
||||
}
|
||||
if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) {
|
||||
return false;
|
||||
}
|
||||
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||
if ((*pfIters[i])->argument (fIdx).contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
return false;
|
||||
}
|
||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
||||
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
unordered_map<unsigned, unsigned>::iterator it;
|
||||
it = groupToRange.find (groups[i]);
|
||||
if (it == groupToRange.end()) {
|
||||
groupToRange.insert (make_pair (groups[i], ranges[i]));
|
||||
} else {
|
||||
if (it->second != ranges[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (range == -1) {
|
||||
range = (*pfIters[i])->range (fIdx);
|
||||
} else if ((int)(*pfIters[i])->range (fIdx) != range) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@ -265,8 +345,23 @@ CountingOperator::getLogCost (void)
|
||||
for (unsigned i = 0; i < counts.size(); i++) {
|
||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||
}
|
||||
if ((*pfIter_)->nrArguments() == 1) {
|
||||
cost *= 3; // avoid counting conversion in the beginning
|
||||
unsigned group = (*pfIter_)->argument (fIdx).group();
|
||||
int lvIndex = Util::vectorIndex (
|
||||
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||
assert (lvIndex != -1);
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
if (pfIter != pfIter_) {
|
||||
int fIdx2 = (*pfIter)->indexOfGroup (group);
|
||||
if (fIdx2 != -1) {
|
||||
LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]);
|
||||
if ((*pfIter)->canCountConvert (Y) == false) {
|
||||
// the real cost should be the cost of grounding Y
|
||||
cost *= 10.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
return std::log (cost);
|
||||
}
|
||||
@ -308,6 +403,7 @@ CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
if (validOp (*it, candidates[i])) {
|
||||
validOps.push_back (new CountingOperator (
|
||||
it, candidates[i], pfList));
|
||||
} else {
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
@ -350,12 +446,7 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
}
|
||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||
if (countNormalized) {
|
||||
unsigned condCount = g->constr()->getConditionalCount (X);
|
||||
bool cartProduct = g->constr()->isCarteesianProduct (
|
||||
g->countedLogVars() | X);
|
||||
if (condCount == 1 || cartProduct == false) {
|
||||
return false;
|
||||
}
|
||||
return g->canCountConvert (X);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -438,6 +529,11 @@ GroundOperator::apply (void)
|
||||
}
|
||||
delete pf;
|
||||
}
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
(*pflIt)->simplifyGrounds();
|
||||
++ pflIt;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -625,6 +721,39 @@ FoveSolver::countNormalize (
|
||||
|
||||
|
||||
|
||||
// Grounds a copy of pf over its non-singleton logical variables and
// multiplies all the resulting parfactors together, returning the product
// by value.
//
// For each logical variable, a parfactor where the variable is counted is
// fully expanded in place; otherwise it is replaced by one new parfactor
// per constraint tree returned by ground(), and the original is deleted.
// NOTE(review): ownership of the pointers handed to ParfactorList is not
// visible here -- confirm nothing leaks.
Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf)
{
  LogVarSet lvs = pf.constr()->logVarSet();
  lvs -= pf.constr()->singletons();
  Parfactors newPfs = {new Parfactor (pf)};
  for (unsigned i = 0; i < lvs.size(); i++) {
    // work on a snapshot of the previous generation and rebuild newPfs
    Parfactors current = newPfs;
    newPfs.clear();
    for (unsigned j = 0; j < current.size(); j++) {
      if (current[j]->countedLogVars().contains (lvs[i]) == false) {
        ConstraintTrees cts = current[j]->constr()->ground (lvs[i]);
        for (unsigned k = 0; k < cts.size(); k++) {
          newPfs.push_back (new Parfactor (current[j], cts[k]));
        }
        // replaced by its groundings
        delete current[j];
        continue;
      }
      current[j]->fullExpand (lvs[i]);
      newPfs.push_back (current[j]);
    }
  }
  ParfactorList pfList (newPfs);
  Parfactors groundShatteredPfs (pfList.begin(), pfList.end());
  // fold every parfactor into the first one
  for (unsigned i = 1; i < groundShatteredPfs.size(); i++) {
    groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
  }
  return Parfactor (*groundShatteredPfs[0]);
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
{
|
||||
@ -665,6 +794,7 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||
cout << endl;
|
||||
}
|
||||
(*pfList_.begin())->simplifyGrounds();
|
||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||
}
|
||||
|
||||
@ -769,8 +899,11 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||
LogVars queryLvs (
|
||||
(*it)->constr()->logVars().begin(),
|
||||
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||
split = (*it)->constr()->split (query[i].args());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
@ -826,8 +959,11 @@ FoveSolver::absorve (
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||
res = g->constr()->split (
|
||||
formulas[i].logVars(),
|
||||
&(obsFormula.constr()),
|
||||
obsFormula.constr().logVars());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
|
@ -19,14 +19,37 @@ class LiftedOperator
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<unsigned> getAllGroupss (ParfactorList&);
|
||||
|
||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, unsigned group);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class ProductOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
ProductOperator (
|
||||
ParfactorList::iterator g1, ParfactorList::iterator g2,
|
||||
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
@ -123,6 +146,8 @@ class FoveSolver
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
|
@ -68,6 +68,7 @@ int createLiftedNetwork (void)
|
||||
Util::printHeader ("INITIAL PARFACTORS");
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,14 +137,20 @@ class Substitution
|
||||
|
||||
// Returns the replacement registered for X, or X itself when this
// substitution holds no replacement for it.
LogVar newNameFor (LogVar X) const
{
  // single lookup: reuse the iterator instead of searching twice
  unordered_map<LogVar, LogVar>::const_iterator it = subs_.find (X);
  if (it != subs_.end()) {
    return it->second;
  }
  return X;
}
|
||||
|
||||
// True when a replacement has been registered for X.
bool containsReplacementFor (LogVar X) const
{
  return subs_.find (X) != subs_.end();
}
|
||||
|
||||
// Number of replacements stored in this substitution.
unsigned nrReplacements (void) const
{
  return subs_.size();
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
|
@ -193,6 +193,32 @@ Parfactor::multiply (Parfactor& g)
|
||||
alignAndExponentiate (this, &g);
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
simplifyGrounds();
|
||||
assert (constr_->isCarteesianProduct (countedLogVars()));
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::canCountConvert (LogVar X)
|
||||
{
|
||||
if (nrFormulas (X) != 1) {
|
||||
return false;
|
||||
}
|
||||
int fIdx = indexOfLogVar (X);
|
||||
if (args_[fIdx].isCounting()) {
|
||||
return false;
|
||||
}
|
||||
if (constr_->isCountNormalized (X) == false) {
|
||||
return false;
|
||||
}
|
||||
if (constr_->getConditionalCount (X) == 1) {
|
||||
return false;
|
||||
}
|
||||
if (constr_->isCarteesianProduct (countedLogVars() | X) == false) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -201,11 +227,10 @@ void
|
||||
Parfactor::countConvert (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (constr_->isCountNormalized (X));
|
||||
assert (constr_->getConditionalCount (X) > 1);
|
||||
assert (constr_->isCarteesianProduct (countedLogVars() | X));
|
||||
|
||||
assert (canCountConvert (X));
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = ranges_[fIdx];
|
||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||
@ -245,6 +270,7 @@ Parfactor::countConvert (LogVar X)
|
||||
++ mapIndexer;
|
||||
}
|
||||
args_[fIdx].setCountedLogVar (X);
|
||||
simplifyCountingFormulas (fIdx);
|
||||
}
|
||||
|
||||
|
||||
@ -307,10 +333,9 @@ Parfactor::fullExpand (LogVar X)
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = args_[fIdx].range();
|
||||
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
|
||||
assert (ranges_[fIdx] == originHists.size());
|
||||
vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (N * R);
|
||||
|
||||
@ -496,6 +521,20 @@ Parfactor::indexOfGroup (unsigned group) const
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulasWithGroup (unsigned group) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
Parfactor::getAllGroups (void) const
|
||||
{
|
||||
@ -672,22 +711,119 @@ Parfactor::expandPotential (
|
||||
|
||||
|
||||
void
|
||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
Parfactor::simplifyCountingFormulas (int fIdx)
|
||||
{
|
||||
LogVars X_1, X_2;
|
||||
const ProbFormulas& formulas1 = g1->arguments();
|
||||
const ProbFormulas& formulas2 = g2->arguments();
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].group() == formulas2[j].group()) {
|
||||
Util::addToVector (X_1, formulas1[i].logVars());
|
||||
Util::addToVector (X_2, formulas2[j].logVars());
|
||||
// check if we can simplify the parfactor
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if ((int)i != fIdx &&
|
||||
args_[i].isCounting() &&
|
||||
args_[i].group() == args_[fIdx].group()) {
|
||||
// if they only differ in the name of the counting log var
|
||||
if ((args_[i].logVarSet() - args_[i].countedLogVar()) ==
|
||||
(args_[fIdx].logVarSet()) - args_[fIdx].countedLogVar() &&
|
||||
ranges_[i] == ranges_[fIdx]) {
|
||||
simplifyParfactor (fIdx, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||
// counting log vars were already raised on counting conversion
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::simplifyGrounds (void)
|
||||
{
|
||||
LogVarSet singletons = constr_->singletons();
|
||||
for (int i = 0; i < (int)args_.size() - 1; i++) {
|
||||
for (unsigned j = i + 1; j < args_.size(); j++) {
|
||||
if (args_[i].group() == args_[j].group() &&
|
||||
singletons.contains (args_[i].logVarSet()) &&
|
||||
singletons.contains (args_[j].logVarSet())) {
|
||||
simplifyParfactor (i, j);
|
||||
i --;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::canMultiply (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||
LogVarSet Xs_1 (res.first);
|
||||
LogVarSet Xs_2 (res.second);
|
||||
LogVarSet Y_1 = g1->logVarSet() - Xs_1;
|
||||
LogVarSet Y_2 = g2->logVarSet() - Xs_2;
|
||||
Y_1 -= g1->countedLogVars();
|
||||
Y_2 -= g2->countedLogVars();
|
||||
return g1->constr()->isCountNormalized (Y_1) &&
|
||||
g2->constr()->isCountNormalized (Y_2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::simplifyParfactor (unsigned fIdx1, unsigned fIdx2)
|
||||
{
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
StatesIndexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
if (indexer[fIdx1] == indexer[fIdx2]) {
|
||||
params_.push_back (copy[indexer]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
for (unsigned i = 0; i < args_[fIdx2].logVars().size(); i++) {
|
||||
if (nrFormulas (args_[fIdx2].logVars()[i]) == 1) {
|
||||
constr_->remove ({ args_[fIdx2].logVars()[i] });
|
||||
}
|
||||
}
|
||||
args_.erase (args_.begin() + fIdx2);
|
||||
ranges_.erase (ranges_.begin() + fIdx2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<LogVars, LogVars>
|
||||
Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
g1->simplifyGrounds();
|
||||
g2->simplifyGrounds();
|
||||
LogVars Xs_1, Xs_2;
|
||||
TinySet<unsigned> matchedI;
|
||||
TinySet<unsigned> matchedJ;
|
||||
ProbFormulas& formulas1 = g1->arguments();
|
||||
ProbFormulas& formulas2 = g2->arguments();
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].group() == formulas2[j].group() &&
|
||||
g1->range (i) == g2->range (j) &&
|
||||
matchedI.contains (i) == false &&
|
||||
matchedJ.contains (j) == false) {
|
||||
Util::addToVector (Xs_1, formulas1[i].logVars());
|
||||
Util::addToVector (Xs_2, formulas2[j].logVars());
|
||||
matchedI.insert (i);
|
||||
matchedJ.insert (j);
|
||||
}
|
||||
}
|
||||
}
|
||||
return make_pair (Xs_1, Xs_2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
alignLogicalVars (g1, g2);
|
||||
LogVarSet comm = g1->logVarSet() & g2->logVarSet();
|
||||
LogVarSet Y_1 = g1->logVarSet() - comm;
|
||||
LogVarSet Y_2 = g2->logVarSet() - comm;
|
||||
Y_1 -= g1->countedLogVars();
|
||||
Y_2 -= g2->countedLogVars();
|
||||
assert (g1->constr()->isCountNormalized (Y_1));
|
||||
@ -696,30 +832,27 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||
//cout << "g1::::::::::::::::::::::::" << endl;
|
||||
//g1->print();
|
||||
//cout << "g2::::::::::::::::::::::::" << endl;
|
||||
//g2->print();
|
||||
// the alignment should be done in the end or else X_1 and X_2
|
||||
// will refer to the old log var names on the code above
|
||||
align (g1, X_1, g2, X_2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::align (
|
||||
Parfactor* g1, const LogVars& alignLvs1,
|
||||
Parfactor* g2, const LogVars& alignLvs2)
|
||||
Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
LogVar freeLogVar = 0;
|
||||
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||
const LogVars& alignLvs1 = res.first;
|
||||
const LogVars& alignLvs2 = res.second;
|
||||
// cout << "ALIGNING :::::::::::::::::" << endl;
|
||||
// g1->print();
|
||||
// cout << "AND" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> align lvs1 = " << alignLvs1 << endl;
|
||||
// cout << "-> align lvs2 = " << alignLvs2 << endl;
|
||||
LogVar freeLogVar (0);
|
||||
Substitution theta1, theta2;
|
||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||
bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
|
||||
bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
|
||||
// handle this type of situation:
|
||||
// g1 = p(X), q(X) ; X in {x1}
|
||||
// g2 = p(X), q(Y) ; X in {x1}, Y in {x1}
|
||||
if (b1 == false && b2 == false) {
|
||||
theta1.add (alignLvs1[i], freeLogVar);
|
||||
theta2.add (alignLvs2[i], freeLogVar);
|
||||
@ -730,6 +863,7 @@ Parfactor::align (
|
||||
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||
if (theta1.containsReplacementFor (allLvs1[i]) == false) {
|
||||
@ -744,25 +878,33 @@ Parfactor::align (
|
||||
++ freeLogVar;
|
||||
}
|
||||
}
|
||||
//cout << "theta1: " << theta1 << endl;
|
||||
//cout << "theta2: " << theta2 << endl;
|
||||
|
||||
// handle this type of situation:
|
||||
// g1 = p(X), q(X) ; X in {(p1),(p2)}
|
||||
// g2 = p(X), q(Y) ; (X,Y) in {(p1,p2),(p2,p1)}
|
||||
LogVars discardedLvs1 = theta1.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs1.size(); i++) {
|
||||
unsigned condCount = g1->constr()->getConditionalCount (discardedLvs1[i]);
|
||||
cout << "discarding g1" << discardedLvs1[i];
|
||||
cout << " cc = " << condCount << endl;
|
||||
LogAware::pow (g1->params(), condCount);
|
||||
if (g1->constr()->isSingleton (discardedLvs1[i]) &&
|
||||
g1->nrFormulas (discardedLvs1[i]) == 1) {
|
||||
g1->constr()->remove (discardedLvs1[i]);
|
||||
} else {
|
||||
LogVar X_new = ++ g1->constr()->logVarSet().back();
|
||||
theta1.rename (discardedLvs1[i], X_new);
|
||||
}
|
||||
}
|
||||
LogVars discardedLvs2 = theta2.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs2.size(); i++) {
|
||||
unsigned condCount = g2->constr()->getConditionalCount (discardedLvs2[i]);
|
||||
cout << "discarding g2" << discardedLvs2[i];
|
||||
cout << " cc = " << condCount << endl;
|
||||
//if (condCount != 1) {
|
||||
// theta2.rename (discardedLvs2[i], freeLogVar);
|
||||
//}
|
||||
LogAware::pow (g2->params(), condCount);
|
||||
if (g2->constr()->isSingleton (discardedLvs2[i]) &&
|
||||
g2->nrFormulas (discardedLvs2[i]) == 1) {
|
||||
g2->constr()->remove (discardedLvs2[i]);
|
||||
} else {
|
||||
LogVar X_new = ++ g2->constr()->logVarSet().back();
|
||||
theta2.rename (discardedLvs2[i], X_new);
|
||||
}
|
||||
}
|
||||
|
||||
// cout << "theta1: " << theta1 << endl;
|
||||
// cout << "theta2: " << theta2 << endl;
|
||||
g1->applySubstitution (theta1);
|
||||
g2->applySubstitution (theta2);
|
||||
}
|
||||
|
@ -50,6 +50,8 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void multiply (Parfactor&);
|
||||
|
||||
bool canCountConvert (LogVar X);
|
||||
|
||||
void countConvert (LogVar);
|
||||
|
||||
void expand (LogVar, LogVar, LogVar);
|
||||
@ -76,6 +78,8 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
int indexOfGroup (unsigned) const;
|
||||
|
||||
unsigned nrFormulasWithGroup (unsigned) const;
|
||||
|
||||
vector<unsigned> getAllGroups (void) const;
|
||||
|
||||
void print (bool = false) const;
|
||||
@ -86,16 +90,28 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
string getLabel (void) const;
|
||||
|
||||
void simplifyGrounds (void);
|
||||
|
||||
static bool canMultiply (Parfactor*, Parfactor*);
|
||||
|
||||
private:
|
||||
|
||||
void simplifyCountingFormulas (int fIdx);
|
||||
|
||||
void simplifyParfactor (unsigned fIdx1, unsigned fIdx2);
|
||||
|
||||
static std::pair<LogVars, LogVars> getAlignLogVars (
|
||||
Parfactor* g1, Parfactor* g2);
|
||||
|
||||
void expandPotential (int fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void align (
|
||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||
static void alignLogicalVars (Parfactor*, Parfactor*);
|
||||
|
||||
ConstraintTree* constr_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@ -1,4 +1,10 @@
|
||||
TODO
|
||||
- add a way to calculate combinations and factorials with large numbers
|
||||
- refactor sumOut in parfactor -> is really ugly code
|
||||
- Indexer: start receiving ranges as constant reference
|
||||
- Refactor sum out in factor
|
||||
- Add a way to sum out several vars at the same time
|
||||
- Receive ranges as a constant reference in Indexer
|
||||
- Merge TinySet and SortedVector classes
|
||||
- Check if evidence remains in the compressed factor graph
|
||||
- Consider using hashes instead of vectors of colors to calculate the groups in
|
||||
counting bp
|
||||
- use more psize_t instead of unsigned for looping through params
|
||||
- use more Util::abort and Util::vectorIndex
|
||||
- LogVar should not cast to int
|
||||
|
@ -1,21 +0,0 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_pfl_flag(solver,fove).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
|
||||
|
||||
|
||||
t(ann).
|
||||
t(dave).
|
||||
|
||||
% p(ann,t).
|
||||
|
||||
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
||||
|
||||
% use standard Prolog queries: provide evidence first.
|
||||
|
||||
?- p(ann,t), p(ann,X).
|
||||
% ?- p(ann,X).
|
||||
|
Reference in New Issue
Block a user