more work to support inference with lifted knowledge compilation
This commit is contained in:
parent
c53220aa61
commit
8ab622e0aa
@ -190,7 +190,7 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
|||||||
Clauses clauses = lwcnf->clauses();
|
Clauses clauses = lwcnf->clauses();
|
||||||
compile (&root_, clauses);
|
compile (&root_, clauses);
|
||||||
exportToGraphViz("circuit.dot");
|
exportToGraphViz("circuit.dot");
|
||||||
smoothCircuit();
|
smoothCircuit (root_);
|
||||||
exportToGraphViz("circuit.smooth.dot");
|
exportToGraphViz("circuit.smooth.dot");
|
||||||
cout << "--------------------------------------------------" << endl;
|
cout << "--------------------------------------------------" << endl;
|
||||||
cout << "--------------------------------------------------" << endl;
|
cout << "--------------------------------------------------" << endl;
|
||||||
@ -199,14 +199,6 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedCircuit::smoothCircuit (void)
|
|
||||||
{
|
|
||||||
smoothCircuit (root_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
LiftedCircuit::getWeightedModelCount (void) const
|
LiftedCircuit::getWeightedModelCount (void) const
|
||||||
{
|
{
|
||||||
|
@ -211,8 +211,6 @@ class LiftedCircuit
|
|||||||
public:
|
public:
|
||||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||||
|
|
||||||
void smoothCircuit (void);
|
|
||||||
|
|
||||||
double getWeightedModelCount (void) const;
|
double getWeightedModelCount (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*);
|
void exportToGraphViz (const char*);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include "LiftedKc.h"
|
#include "LiftedKc.h"
|
||||||
#include "LiftedWCNF.h"
|
#include "LiftedWCNF.h"
|
||||||
#include "LiftedCircuit.h"
|
#include "LiftedCircuit.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
|
||||||
|
|
||||||
LiftedKc::LiftedKc (const ParfactorList& pfList)
|
LiftedKc::LiftedKc (const ParfactorList& pfList)
|
||||||
@ -19,9 +20,46 @@ LiftedKc::~LiftedKc (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
LiftedKc::solveQuery (const Grounds&)
|
LiftedKc::solveQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
return Params();
|
vector<PrvGroup> groups;
|
||||||
|
Ranges ranges;
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
size_t idx = (*it)->indexOfGround (query[i]);
|
||||||
|
if (idx != (*it)->nrArguments()) {
|
||||||
|
groups.push_back ((*it)->argument (idx).group());
|
||||||
|
ranges.push_back ((*it)->range (idx));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "groups: " << groups << endl;
|
||||||
|
cout << "ranges: " << ranges << endl;
|
||||||
|
Params params;
|
||||||
|
Indexer indexer (ranges);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||||
|
for (size_t j = 0; j < litIds.size(); j++) {
|
||||||
|
if (indexer[i] == j) {
|
||||||
|
lwcnf_->addWeight (litIds[j], 1.0, 1.0); // TODO not log aware
|
||||||
|
} else {
|
||||||
|
lwcnf_->addWeight (litIds[j], 0.0, 1.0); // TODO not log aware
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cout << "new weights ----- ----- -----" << endl;
|
||||||
|
// lwcnf_->printWeights();
|
||||||
|
// circuit_->exportToGraphViz ("ccircuit.dot");
|
||||||
|
params.push_back (circuit_->getWeightedModelCount());
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
cout << "params: " << params << endl;
|
||||||
|
LogAware::normalize (params);
|
||||||
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -352,8 +352,8 @@ Clause::getLogVarSetExcluding (size_t idx) const
|
|||||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||||
: freeLiteralId_(0), pfList_(pfList)
|
: freeLiteralId_(0), pfList_(pfList)
|
||||||
{
|
{
|
||||||
//addIndicatorClauses (pfList);
|
addIndicatorClauses (pfList);
|
||||||
//addParameterClauses (pfList);
|
addParameterClauses (pfList);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
vector<vector<string>> names = {
|
vector<vector<string>> names = {
|
||||||
@ -378,24 +378,24 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
|||||||
freeLiteralId_ = 2;
|
freeLiteralId_ = 2;
|
||||||
*/
|
*/
|
||||||
|
|
||||||
Literal lit1 (0, {0});
|
//Literal lit1 (0, {0});
|
||||||
Literal lit2 (1, {0});
|
//Literal lit2 (1, {0});
|
||||||
Literal lit3 (2, {1});
|
//Literal lit3 (2, {1});
|
||||||
Literal lit4 (3, {1});
|
//Literal lit4 (3, {1});
|
||||||
|
|
||||||
vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
|
//vector<vector<string>> names = {{"p1","p2"},{"p3","p4"}};
|
||||||
Clause c1 (names);
|
//Clause c1 (names);
|
||||||
c1.addLiteral (lit1);
|
//c1.addLiteral (lit1);
|
||||||
c1.addLiteral (lit2);
|
//c1.addLiteral (lit2);
|
||||||
c1.addLiteral (lit3);
|
//c1.addLiteral (lit3);
|
||||||
c1.addLiteral (lit4);
|
//c1.addLiteral (lit4);
|
||||||
//c1.addPosCountedLogVar (0);
|
//c1.addPosCountedLogVar (0);
|
||||||
clauses_.push_back (c1);
|
//clauses_.push_back (c1);
|
||||||
|
|
||||||
Clause c2 (names);
|
//Clause c2 (names);
|
||||||
c2.addLiteral (lit1);
|
//c2.addLiteral (lit1);
|
||||||
c2.addLiteral (lit3);
|
//c2.addLiteral (lit3);
|
||||||
c2.addNegCountedLogVar (0);
|
//c2.addNegCountedLogVar (0);
|
||||||
//clauses_.push_back (c2);
|
//clauses_.push_back (c2);
|
||||||
/*
|
/*
|
||||||
Clause c3;
|
Clause c3;
|
||||||
@ -408,17 +408,18 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
|||||||
c4.addLiteral (lit3);
|
c4.addLiteral (lit3);
|
||||||
clauses_.push_back (c4);
|
clauses_.push_back (c4);
|
||||||
*/
|
*/
|
||||||
freeLiteralId_ = 4;
|
//freeLiteralId_ = 4;
|
||||||
|
|
||||||
cout << "FORMULA INDICATORS:" << endl;
|
cout << "FORMULA INDICATORS:" << endl;
|
||||||
// printFormulaIndicators();
|
printFormulaIndicators();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
|
||||||
cout << "WEIGHTS:" << endl;
|
cout << "WEIGHTS:" << endl;
|
||||||
printWeights();
|
printWeights();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
|
||||||
cout << "CLAUSES:" << endl;
|
cout << "CLAUSES:" << endl;
|
||||||
printClauses();
|
printClauses();
|
||||||
// abort();
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -431,6 +432,14 @@ LiftedWCNF::~LiftedWCNF (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||||
|
{
|
||||||
|
weights_[lid] = make_pair (posW, negW);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
LiftedWCNF::posWeight (LiteralId lid) const
|
LiftedWCNF::posWeight (LiteralId lid) const
|
||||||
{
|
{
|
||||||
@ -451,6 +460,15 @@ LiftedWCNF::negWeight (LiteralId lid) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<LiteralId>
|
||||||
|
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||||
|
{
|
||||||
|
assert (Util::contains (map_, prvGroup));
|
||||||
|
return map_[prvGroup];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Clause
|
Clause
|
||||||
LiftedWCNF::createClause (LiteralId lid) const
|
LiftedWCNF::createClause (LiteralId lid) const
|
||||||
{
|
{
|
||||||
@ -481,14 +499,6 @@ LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
|
||||||
{
|
|
||||||
weights_[lid] = make_pair (posW, negW);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||||
{
|
{
|
||||||
|
@ -191,10 +191,14 @@ class LiftedWCNF
|
|||||||
|
|
||||||
const Clauses& clauses (void) const { return clauses_; }
|
const Clauses& clauses (void) const { return clauses_; }
|
||||||
|
|
||||||
|
void addWeight (LiteralId lid, double posW, double negW);
|
||||||
|
|
||||||
double posWeight (LiteralId lid) const;
|
double posWeight (LiteralId lid) const;
|
||||||
|
|
||||||
double negWeight (LiteralId lid) const;
|
double negWeight (LiteralId lid) const;
|
||||||
|
|
||||||
|
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||||
|
|
||||||
Clause createClause (LiteralId lid) const;
|
Clause createClause (LiteralId lid) const;
|
||||||
|
|
||||||
void printFormulaIndicators (void) const;
|
void printFormulaIndicators (void) const;
|
||||||
@ -207,8 +211,6 @@ class LiftedWCNF
|
|||||||
|
|
||||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||||
|
|
||||||
void addWeight (LiteralId lid, double posW, double negW);
|
|
||||||
|
|
||||||
void addIndicatorClauses (const ParfactorList& pfList);
|
void addIndicatorClauses (const ParfactorList& pfList);
|
||||||
|
|
||||||
void addParameterClauses (const ParfactorList& pfList);
|
void addParameterClauses (const ParfactorList& pfList);
|
||||||
|
Reference in New Issue
Block a user