This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/HorusYap.cpp

534 lines
16 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
2012-05-23 14:56:01 +01:00
#include <vector>
2013-02-07 20:09:10 +00:00
#include <unordered_map>
#include <string>
2012-05-23 14:56:01 +01:00
#include <iostream>
#include <sstream>
#include <YapInterface.h>
#include "ParfactorList.h"
#include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h"
#include "VarElim.h"
#include "LiftedBp.h"
#include "CountingBp.h"
2012-06-19 15:10:57 +01:00
#include "BeliefProp.h"
#include "LiftedKc.h"
2012-05-23 14:56:01 +01:00
#include "ElimGraph.h"
#include "BayesBall.h"
namespace Horus {
2012-05-23 14:56:01 +01:00
namespace {
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
Parfactor *readParfactor(YAP_Term);
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
ObservedFormulas *readLiftedEvidence(YAP_Term);
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
std::vector<unsigned> readUnsignedList(YAP_Term);
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
Params readParameters(YAP_Term);
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
YAP_Term fillSolutionList(const std::vector<Params> &);
2015-09-21 23:05:36 +01:00
extern "C" void init_predicates();
2015-09-21 23:05:36 +01:00
void init_predicates();
}
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
// A lifted network handed to Prolog as an integer handle: the parfactors
// (first) and the lifted evidence (second). Both are heap-allocated and are
// released by freeLiftedNetwork.
using LiftedNetwork = std::pair<ParfactorList *, ObservedFormulas *>;
2012-05-23 14:56:01 +01:00
2015-09-21 23:05:36 +01:00
// Prolog: cpp_create_lifted_network(+Parfactors, +Evidence, -Handle).
//
// Reads every parfactor term in the list ARG1, builds a ParfactorList from
// them, reads the lifted evidence in ARG2 and unifies ARG3 with an integer
// handle to a newly allocated LiftedNetwork. The handle must later be
// released with cpp_free_lifted_network/1. If the unification fails, the
// network is freed here, because no one else can reach it.
static YAP_Bool createLiftedNetwork() {
  Parfactors parfactors;
  YAP_Term parfactorList = YAP_ARG1;
  while (parfactorList != YAP_TermNil()) {
    parfactors.push_back(readParfactor(YAP_HeadOfTerm(parfactorList)));
    parfactorList = YAP_TailOfTerm(parfactorList);
  }
  // LiftedUtils::printSymbolDictionary();
  if (Globals::verbosity > 2) {
    Util::printHeader("INITIAL PARFACTORS");
    for (size_t i = 0; i < parfactors.size(); i++) {
      parfactors[i]->print();
      std::cout << std::endl;
    }
  }
  // At high verbosity the constructed list is printed as "SHATTERED
  // PARFACTORS", i.e. the constructor presumably shatters the input.
  ParfactorList *pfList = new ParfactorList(parfactors);
  if (Globals::verbosity > 2) {
    Util::printHeader("SHATTERED PARFACTORS");
    pfList->print();
  }
  // read evidence
  ObservedFormulas *obsFormulas = readLiftedEvidence(YAP_ARG2);

  LiftedNetwork *network = new LiftedNetwork(pfList, obsFormulas);
  YAP_Int handle = reinterpret_cast<YAP_Int>(network);
  if (!YAP_Unify(YAP_MkIntTerm(handle), YAP_ARG3)) {
    // The handle never reached Prolog; release everything (same order as
    // freeLiftedNetwork) instead of leaking it.
    delete network->first;
    delete network->second;
    delete network;
    return false;
  }
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_create_ground_network(+Type, +Factors, +Evidence, -Handle).
//
// ARG1 is an atom; when it is 'bayes', the factor graph is flagged as holding
// Bayesian factors. ARG2 is a list of factor terms whose four arguments are
// the variable ids, the ranges, the parameters and the distribution id. ARG3
// is a list of (VarId, Value) evidence terms. ARG4 is unified with an
// integer handle to the new FactorGraph, to be released with
// cpp_free_ground_network/1.
//
// Fails (and frees the graph) if an evidence term names an unknown variable
// or if the handle cannot be unified.
static YAP_Bool createGroundNetwork() {
  std::string factorsType((char *)YAP_AtomName(YAP_AtomOfTerm(YAP_ARG1)));
  FactorGraph *fg = new FactorGraph();
  if (factorsType == "bayes") {
    fg->setFactorsAsBayesian();
  }

  YAP_Term factorList = YAP_ARG2;
  while (factorList != YAP_TermNil()) {
    YAP_Term factor = YAP_HeadOfTerm(factorList);
    VarIds varIds = readUnsignedList(YAP_ArgOfTerm(1, factor));
    Ranges ranges = readUnsignedList(YAP_ArgOfTerm(2, factor));
    Params params = readParameters(YAP_ArgOfTerm(3, factor));
    unsigned distId = (unsigned)YAP_IntOfTerm(YAP_ArgOfTerm(4, factor));
    fg->addFactor(Factor(varIds, ranges, params, distId));
    factorList = YAP_TailOfTerm(factorList);
  }

  YAP_Term evidenceList = YAP_ARG3;
  while (evidenceList != YAP_TermNil()) {
    YAP_Term evTerm = YAP_HeadOfTerm(evidenceList);
    unsigned vid = (unsigned)YAP_IntOfTerm(YAP_ArgOfTerm(1, evTerm));
    unsigned ev = (unsigned)YAP_IntOfTerm(YAP_ArgOfTerm(2, evTerm));
    auto varNode = fg->getVarNode(vid);
    assert(varNode);
    if (!varNode) {
      // Evidence on a variable that is not in the graph; in release builds
      // the assert is gone, so fail instead of dereferencing null.
      delete fg;
      return false;
    }
    varNode->setEvidence(ev);
    evidenceList = YAP_TailOfTerm(evidenceList);
  }

  if (FactorGraph::exportToLibDai()) {
    fg->exportToLibDai("model.fg");
  }
  if (FactorGraph::exportToUai()) {
    fg->exportToUai("model.uai");
  }
  if (FactorGraph::exportGraphViz()) {
    fg->exportToGraphViz("model.dot");
  }
  if (FactorGraph::printFactorGraph()) {
    fg->print();
  }
  if (Globals::verbosity > 0) {
    std::cout << "factor graph contains ";
    std::cout << fg->nrVarNodes() << " variables and ";
    std::cout << fg->nrFacNodes() << " factors " << std::endl;
  }

  YAP_Int handle = reinterpret_cast<YAP_Int>(fg);
  if (!YAP_Unify(YAP_MkIntTerm(handle), YAP_ARG4)) {
    // The handle never reached Prolog; do not leak the graph.
    delete fg;
    return false;
  }
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_run_lifted_solver(+Handle, +Tasks, -Results).
//
// ARG1 is a lifted-network handle. Evidence is absorbed into a copy of its
// parfactors, so the stored network is left as it was. ARG2 is a list of
// tasks. Each task is a list of ground atoms: either a bare atom or a
// compound whose arguments are all atoms. Each task is solved jointly with
// the solver selected by Globals::liftedSolver. ARG3 is unified with one
// list of floats per task.
//
// Fails if Globals::liftedSolver does not name a known solver.
static YAP_Bool runLiftedSolver() {
  LiftedNetwork *network = reinterpret_cast<LiftedNetwork *>(YAP_IntOfTerm(YAP_ARG1));
  ParfactorList copy(*network->first);
  LiftedOperations::absorveEvidence(copy, *network->second);

  LiftedSolver *solver = nullptr;
  switch (Globals::liftedSolver) {
  case LiftedSolverType::lveSolver:
    solver = new LiftedVe(copy);
    break;
  case LiftedSolverType::lbpSolver:
    solver = new LiftedBp(copy);
    break;
  case LiftedSolverType::lkcSolver:
    solver = new LiftedKc(copy);
    break;
  }
  if (solver == nullptr) {
    // No case matched (e.g. an out-of-range solver value); previously this
    // fell through to a null dereference below.
    return false;
  }
  if (Globals::verbosity > 0) {
    solver->printSolverFlags();
    std::cout << std::endl;
  }

  YAP_Term taskList = YAP_ARG2;
  std::vector<Params> results;
  while (taskList != YAP_TermNil()) {
    Grounds queryVars;
    YAP_Term jointList = YAP_HeadOfTerm(taskList);
    while (jointList != YAP_TermNil()) {
      YAP_Term ground = YAP_HeadOfTerm(jointList);
      if (YAP_IsAtomTerm(ground)) {
        std::string name((char *)YAP_AtomName(YAP_AtomOfTerm(ground)));
        queryVars.push_back(Ground(LiftedUtils::getSymbol(name)));
      } else {
        assert(YAP_IsApplTerm(ground));
        YAP_Functor yapFunctor = YAP_FunctorOfTerm(ground);
        std::string name((char *)(YAP_AtomName(YAP_NameOfFunctor(yapFunctor))));
        unsigned arity = (unsigned)YAP_ArityOfFunctor(yapFunctor);
        Symbol functor = LiftedUtils::getSymbol(name);
        Symbols args;
        for (unsigned i = 1; i <= arity; i++) {
          YAP_Term ti = YAP_ArgOfTerm(i, ground);
          assert(YAP_IsAtomTerm(ti));
          std::string arg((char *)YAP_AtomName(YAP_AtomOfTerm(ti)));
          args.push_back(LiftedUtils::getSymbol(arg));
        }
        queryVars.push_back(Ground(functor, args));
      }
      jointList = YAP_TailOfTerm(jointList);
    }
    results.push_back(solver->solveQuery(queryVars));
    taskList = YAP_TailOfTerm(taskList);
  }
  delete solver;
  return YAP_Unify(fillSolutionList(results), YAP_ARG3);
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_run_ground_solver(+Handle, +Tasks, -Results).
//
// ARG1 is a factor-graph handle. ARG2 is a list of tasks, each a list of
// variable ids. If the graph holds Bayesian factors, the solver runs on the
// graph returned by BayesBall::getMinimalFactorGraph for the union of all
// queried variables, and that graph is freed afterwards. Otherwise the solver
// runs on the graph itself. ARG3 is unified with one list of floats per task.
//
// Fails if Globals::groundSolver does not name a known solver.
static YAP_Bool runGroundSolver() {
  FactorGraph *fg = reinterpret_cast<FactorGraph *>(YAP_IntOfTerm(YAP_ARG1));

  std::vector<VarIds> tasks;
  YAP_Term taskList = YAP_ARG2;
  while (taskList != YAP_TermNil()) {
    tasks.push_back(readUnsignedList(YAP_HeadOfTerm(taskList)));
    taskList = YAP_TailOfTerm(taskList);
  }

  FactorGraph *mfg = fg;
  if (fg->bayesianFactors()) {
    std::set<VarId> vids;
    for (size_t i = 0; i < tasks.size(); i++) {
      Util::addToSet(vids, tasks[i]);
    }
    mfg =
        BayesBall::getMinimalFactorGraph(*fg, VarIds(vids.begin(), vids.end()));
  }

  GroundSolver *solver = nullptr;
  CountingBp::setFindIdenticalFactorsFlag(false);
  switch (Globals::groundSolver) {
  case GroundSolverType::veSolver:
    solver = new VarElim(*mfg);
    break;
  case GroundSolverType::bpSolver:
    solver = new BeliefProp(*mfg);
    break;
  case GroundSolverType::CbpSolver:
    solver = new CountingBp(*mfg);
    break;
  }
  if (solver == nullptr) {
    // No case matched; fail instead of dereferencing null, but still free the
    // temporary minimal graph. Never delete fg: it belongs to the caller's
    // handle.
    if (mfg != fg) {
      delete mfg;
    }
    return false;
  }
  if (Globals::verbosity > 0) {
    solver->printSolverFlags();
    std::cout << std::endl;
  }

  std::vector<Params> results;
  results.reserve(tasks.size());
  for (size_t i = 0; i < tasks.size(); i++) {
    results.push_back(solver->solveQuery(tasks[i]));
  }
  delete solver;
  // Only the graph built above is ours to free.
  if (mfg != fg) {
    delete mfg;
  }

  return YAP_Unify(fillSolutionList(results), YAP_ARG3);
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_set_parfactors_params(+Handle, +DistIds, +ParamsLists).
//
// Pairs the i-th distribution id of ARG2 with the i-th parameter list of ARG3
// (extra parameter lists are ignored). Then replaces the parameters of every
// parfactor in the lifted network ARG1 with those of its distribution id.
//
// Fails, without modifying any parfactor, when ARG3 is shorter than ARG2 or
// when some parfactor's distribution id was not supplied.
static YAP_Bool setParfactorsParams() {
  LiftedNetwork *network = reinterpret_cast<LiftedNetwork *>(YAP_IntOfTerm(YAP_ARG1));
  ParfactorList *pfList = network->first;
  YAP_Term distIdsList = YAP_ARG2;
  YAP_Term paramsList = YAP_ARG3;
  std::unordered_map<unsigned, Params> paramsMap;
  while (distIdsList != YAP_TermNil()) {
    if (paramsList == YAP_TermNil()) {
      // Fewer parameter lists than ids: taking the head of [] is invalid.
      return false;
    }
    unsigned distId = (unsigned)YAP_IntOfTerm(YAP_HeadOfTerm(distIdsList));
    assert(Util::contains(paramsMap, distId) == false);
    paramsMap[distId] = readParameters(YAP_HeadOfTerm(paramsList));
    distIdsList = YAP_TailOfTerm(distIdsList);
    paramsList = YAP_TailOfTerm(paramsList);
  }
  // Validate first, so that a missing id leaves the network untouched. It
  // also avoids operator[] silently inserting empty parameters in release
  // builds.
  for (ParfactorList::iterator it = pfList->begin(); it != pfList->end(); ++it) {
    assert(Util::contains(paramsMap, (*it)->distId()));
    if (paramsMap.find((*it)->distId()) == paramsMap.end()) {
      return false;
    }
  }
  for (ParfactorList::iterator it = pfList->begin(); it != pfList->end(); ++it) {
    auto entry = paramsMap.find((*it)->distId());
    (*it)->setParams(entry->second);
  }
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_set_factors_params(+Handle, +DistIds, +ParamsLists).
//
// Pairs the i-th distribution id of ARG2 with the i-th parameter list of ARG3
// (extra parameter lists are ignored). Then replaces the parameters of every
// factor node of the factor graph ARG1 with those of its distribution id.
//
// Fails, without modifying any factor, when ARG3 is shorter than ARG2 or when
// some factor's distribution id was not supplied.
static YAP_Bool setFactorsParams() {
  FactorGraph *fg = reinterpret_cast<FactorGraph *>(YAP_IntOfTerm(YAP_ARG1));
  YAP_Term distIdsList = YAP_ARG2;
  YAP_Term paramsList = YAP_ARG3;
  std::unordered_map<unsigned, Params> paramsMap;
  while (distIdsList != YAP_TermNil()) {
    if (paramsList == YAP_TermNil()) {
      // Fewer parameter lists than ids: taking the head of [] is invalid.
      return false;
    }
    unsigned distId = (unsigned)YAP_IntOfTerm(YAP_HeadOfTerm(distIdsList));
    assert(Util::contains(paramsMap, distId) == false);
    paramsMap[distId] = readParameters(YAP_HeadOfTerm(paramsList));
    distIdsList = YAP_TailOfTerm(distIdsList);
    paramsList = YAP_TailOfTerm(paramsList);
  }
  const FacNodes &facNodes = fg->facNodes();
  // Validate first, so that a missing id leaves the graph untouched. It also
  // avoids operator[] silently inserting empty parameters in release builds.
  for (size_t i = 0; i < facNodes.size(); i++) {
    unsigned distId = facNodes[i]->factor().distId();
    assert(Util::contains(paramsMap, distId));
    if (paramsMap.find(distId) == paramsMap.end()) {
      return false;
    }
  }
  for (size_t i = 0; i < facNodes.size(); i++) {
    auto entry = paramsMap.find(facNodes[i]->factor().distId());
    facNodes[i]->factor().setParams(entry->second);
  }
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_set_vars_information(+Labels, +StateNames).
//
// Clears the global variable information, then registers variable i (counting
// from 0) with the i-th atom of ARG1 as its label and the i-th list of atoms
// in ARG2 as its state names.
//
// Fails if ARG2 has more entries than ARG1. Variables registered before that
// point are kept.
static YAP_Bool setVarsInformation() {
  Var::clearVarsInfo();
  std::vector<std::string> labels;
  YAP_Term labelsL = YAP_ARG1;
  while (labelsL != YAP_TermNil()) {
    YAP_Atom atom = YAP_AtomOfTerm(YAP_HeadOfTerm(labelsL));
    labels.push_back((char *)YAP_AtomName(atom));
    labelsL = YAP_TailOfTerm(labelsL);
  }
  unsigned count = 0;
  YAP_Term stateNamesL = YAP_ARG2;
  while (stateNamesL != YAP_TermNil()) {
    if (count >= labels.size()) {
      // No label for this variable: labels[count] would be out of bounds.
      return false;
    }
    States states;
    YAP_Term namesL = YAP_HeadOfTerm(stateNamesL);
    while (namesL != YAP_TermNil()) {
      YAP_Atom atom = YAP_AtomOfTerm(YAP_HeadOfTerm(namesL));
      states.push_back((char *)YAP_AtomName(atom));
      namesL = YAP_TailOfTerm(namesL);
    }
    Var::addVarInfo(count, labels[count], states);
    count++;
    stateNamesL = YAP_TailOfTerm(stateNamesL);
  }
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_set_horus_flag(+Option, +Value).
//
// Converts the value of ARG2 to text and forwards it to Util::setHorusFlag
// together with the option name of ARG1:
// - "verbosity" and "bp_max_iter" take an integer, printed as an int;
// - "bp_accuracy" takes a float, narrowed to float and printed with the
//   default stream formatting;
// - every other option takes an atom, whose name is used verbatim.
// Succeeds iff Util::setHorusFlag reports success.
static YAP_Bool setHorusFlag() {
  const std::string option((char *)YAP_AtomName(YAP_AtomOfTerm(YAP_ARG1)));
  const bool takesInteger = (option == "verbosity" || option == "bp_max_iter");

  std::string value;
  if (takesInteger) {
    std::ostringstream text;
    text << (int)YAP_IntOfTerm(YAP_ARG2);
    value = text.str();
  } else if (option == "bp_accuracy") {
    std::ostringstream text;
    text << (float)YAP_FloatOfTerm(YAP_ARG2);
    value = text.str();
  } else {
    value = (char *)YAP_AtomName(YAP_AtomOfTerm(YAP_ARG2));
  }
  return Util::setHorusFlag(option, value);
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_free_ground_network(+Handle).
// Deletes the FactorGraph whose address is encoded in the integer ARG1.
static YAP_Bool freeGroundNetwork() {
  FactorGraph *graph = (FactorGraph *)YAP_IntOfTerm(YAP_ARG1);
  delete graph;
  return TRUE;
}
2015-09-21 23:05:36 +01:00
// Prolog: cpp_free_lifted_network(+Handle).
// Deletes the parfactor list, then the observed formulas, then the
// LiftedNetwork pair itself, whose address is encoded in the integer ARG1.
static YAP_Bool freeLiftedNetwork() {
  LiftedNetwork *net = (LiftedNetwork *)YAP_IntOfTerm(YAP_ARG1);
  ParfactorList *parfactors = net->first;
  ObservedFormulas *evidence = net->second;
  delete parfactors;
  delete evidence;
  delete net;
  return TRUE;
}
namespace {
2015-09-21 23:05:36 +01:00
// Decodes one Prolog parfactor term into a heap-allocated Parfactor.
//
// Arguments of pfTerm, as read below:
//   1: distribution id (integer)
//   2: list of parametric random variables. Each is an atom, or a compound
//      whose arguments are the logical variables. Logical variables are
//      numbered 0, 1, ... in order of first appearance.
//   3: list of ranges, one per random variable (indexed in step with arg 2)
//   4: list of parameters (see readParameters)
//   5: constraint, a list of compounds whose atom arguments form the tuples.
//      Read only if at least one logical variable was seen.
//
// Prints an error and terminates the process with EXIT_FAILURE if a
// constraint tuple contains a non-atom argument.
Parfactor *readParfactor(YAP_Term pfTerm) {
  unsigned distId = (unsigned)YAP_IntOfTerm(YAP_ArgOfTerm(1, pfTerm));
  Ranges ranges = readUnsignedList(YAP_ArgOfTerm(3, pfTerm));

  ProbFormulas formulas;
  std::unordered_map<YAP_Term, LogVar> lvMap;
  unsigned formulaIdx = 0;
  for (YAP_Term pvs = YAP_ArgOfTerm(2, pfTerm); pvs != YAP_TermNil();
       pvs = YAP_TailOfTerm(pvs)) {
    YAP_Term formulaTerm = YAP_HeadOfTerm(pvs);
    if (!YAP_IsAtomTerm(formulaTerm)) {
      YAP_Functor yapFunctor = YAP_FunctorOfTerm(formulaTerm);
      std::string functorName((char *)YAP_AtomName(YAP_NameOfFunctor(yapFunctor)));
      Symbol functor = LiftedUtils::getSymbol(functorName);
      unsigned arity = (unsigned)YAP_ArityOfFunctor(yapFunctor);
      LogVars logVars;
      for (unsigned argIdx = 1; argIdx <= arity; ++argIdx) {
        YAP_Term arg = YAP_ArgOfTerm(argIdx, formulaTerm);
        auto known = lvMap.find(arg);
        if (known == lvMap.end()) {
          // First time this Prolog variable is seen: give it the next index.
          unsigned fresh = lvMap.size();
          lvMap[arg] = fresh;
          logVars.push_back(fresh);
        } else {
          logVars.push_back(known->second);
        }
      }
      formulas.push_back(ProbFormula(functor, logVars, ranges[formulaIdx]));
    } else {
      std::string functorName((char *)YAP_AtomName(YAP_AtomOfTerm(formulaTerm)));
      Symbol functor = LiftedUtils::getSymbol(functorName);
      formulas.push_back(ProbFormula(functor, ranges[formulaIdx]));
    }
    ++formulaIdx;
  }

  Params params = readParameters(YAP_ArgOfTerm(4, pfTerm));

  Tuples tuples;
  if (!lvMap.empty()) {
    for (YAP_Term rows = YAP_ArgOfTerm(5, pfTerm); rows != YAP_TermNil();
         rows = YAP_TailOfTerm(rows)) {
      YAP_Term row = YAP_HeadOfTerm(rows);
      assert(YAP_IsApplTerm(row));
      unsigned arity = (unsigned)YAP_ArityOfFunctor(YAP_FunctorOfTerm(row));
      assert(lvMap.size() == arity);
      Tuple tuple(arity);
      for (unsigned k = 0; k < arity; ++k) {
        YAP_Term cell = YAP_ArgOfTerm(k + 1, row);
        if (!YAP_IsAtomTerm(cell)) {
          std::cerr << "Error: the constraint contains free variables."
                    << std::endl;
          exit(EXIT_FAILURE);
        }
        std::string symbolName((char *)YAP_AtomName(YAP_AtomOfTerm(cell)));
        tuple[k] = LiftedUtils::getSymbol(symbolName);
      }
      tuples.push_back(tuple);
    }
  }
  return new Parfactor(formulas, params, tuples, distId);
}
2015-09-21 23:05:36 +01:00
// Decodes a Prolog list of evidence terms into a heap-allocated
// ObservedFormulas collection, owned by the caller.
//
// Each element is a compound whose first argument is a ground atom (a bare
// atom, or a compound with atom arguments) and whose second argument is the
// observed value (integer). Observations that share functor, arity and
// observed value are merged into a single ObservedFormula, with one tuple per
// observation.
ObservedFormulas *readLiftedEvidence(YAP_Term observedList) {
  ObservedFormulas *obsFormulas = new ObservedFormulas();
  for (YAP_Term rest = observedList; rest != YAP_TermNil();
       rest = YAP_TailOfTerm(rest)) {
    YAP_Term observation = YAP_HeadOfTerm(rest);
    YAP_Term ground = YAP_ArgOfTerm(1, observation);

    Symbol functor;
    Symbols args;
    if (!YAP_IsAtomTerm(ground)) {
      assert(YAP_IsApplTerm(ground));
      YAP_Functor yapFunctor = YAP_FunctorOfTerm(ground);
      std::string functorName((char *)(YAP_AtomName(YAP_NameOfFunctor(yapFunctor))));
      functor = LiftedUtils::getSymbol(functorName);
      unsigned arity = (unsigned)YAP_ArityOfFunctor(yapFunctor);
      for (unsigned argIdx = 1; argIdx <= arity; argIdx++) {
        YAP_Term argTerm = YAP_ArgOfTerm(argIdx, ground);
        assert(YAP_IsAtomTerm(argTerm));
        std::string argName((char *)YAP_AtomName(YAP_AtomOfTerm(argTerm)));
        args.push_back(LiftedUtils::getSymbol(argName));
      }
    } else {
      std::string functorName((char *)YAP_AtomName(YAP_AtomOfTerm(ground)));
      functor = LiftedUtils::getSymbol(functorName);
    }

    unsigned evidence = (unsigned)YAP_IntOfTerm(YAP_ArgOfTerm(2, observation));

    bool merged = false;
    for (size_t idx = 0; idx < obsFormulas->size(); idx++) {
      bool sameFormula = (*obsFormulas)[idx].functor() == functor &&
                         (*obsFormulas)[idx].arity() == args.size() &&
                         (*obsFormulas)[idx].evidence() == evidence;
      if (!sameFormula) {
        continue;
      }
      (*obsFormulas)[idx].addTuple(args);
      merged = true;
    }
    if (!merged) {
      obsFormulas->push_back(ObservedFormula(functor, evidence, args));
    }
  }
  return obsFormulas;
}
2015-09-21 23:05:36 +01:00
// Converts a proper Prolog list of integers into a vector, casting each
// element to unsigned and keeping list order.
std::vector<unsigned> readUnsignedList(YAP_Term list) {
  std::vector<unsigned> values;
  for (YAP_Term cell = list; cell != YAP_TermNil(); cell = YAP_TailOfTerm(cell)) {
    values.push_back((unsigned)YAP_IntOfTerm(YAP_HeadOfTerm(cell)));
  }
  return values;
}
2015-09-21 23:05:36 +01:00
// Converts a non-empty Prolog list of numbers into Params, keeping list order.
// Float elements are read as floats; any other element is read as an
// integer. When Globals::logDomain is set, Util::log is applied to the whole
// result before it is returned.
Params readParameters(YAP_Term paramL) {
  assert(YAP_IsPairTerm(paramL));
  Params params;
  for (YAP_Term cell = paramL; cell != YAP_TermNil(); cell = YAP_TailOfTerm(cell)) {
    YAP_Term head = YAP_HeadOfTerm(cell);
    double value = YAP_IsFloatTerm(head) ? (double)YAP_FloatOfTerm(head)
                                         : (double)YAP_IntOfTerm(head);
    params.push_back(value);
  }
  if (Globals::logDomain) {
    Util::log(params);
  }
  return params;
}
2015-09-21 23:05:36 +01:00
// Builds the Prolog term [[B11, B12, ...], [B21, ...], ...] from the solver
// results, one inner list of floats per query, preserving order.
//
// Lists are built back to front with YAP_MkPairTerm. While a query's inner
// list is built, the partially built outer `list` is not an argument of the
// allocation calls, so it is parked in a YAP slot and re-read afterwards.
// Presumably this keeps it valid if an allocation moves terms. NOTE(review):
// confirm against the YAP GC semantics of YAP_MkPairTerm.
YAP_Term fillSolutionList(const std::vector<Params> &results) {
  YAP_Term list = YAP_TermNil();
  size_t queryIdx = results.size();
  while (queryIdx > 0) {
    --queryIdx;
    const Params &beliefs = results[queryIdx];
    YAP_Term queryBeliefsL = YAP_TermNil();
    size_t beliefIdx = beliefs.size();
    while (beliefIdx > 0) {
      --beliefIdx;
      YAP_Int slot = YAP_InitSlot(list);
      YAP_Term belief = YAP_MkFloatTerm(beliefs[beliefIdx]);
      queryBeliefsL = YAP_MkPairTerm(belief, queryBeliefsL);
      list = YAP_GetFromSlot(slot);
      YAP_RecoverSlots(1, slot);
    }
    list = YAP_MkPairTerm(queryBeliefsL, list);
  }
  return list;
}
}
2015-09-21 23:05:36 +01:00
extern "C" void init_predicates() {
YAP_UserCPredicate("cpp_create_lifted_network", createLiftedNetwork, 3);
2012-11-16 17:10:04 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_create_ground_network", createGroundNetwork, 4);
2012-11-16 17:10:04 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_run_lifted_solver", runLiftedSolver, 3);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_run_ground_solver", runGroundSolver, 3);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_set_parfactors_params", setParfactorsParams, 3);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_set_factors_params", setFactorsParams, 3);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_set_vars_information", setVarsInformation, 2);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_set_horus_flag", setHorusFlag, 2);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_free_lifted_network", freeLiftedNetwork, 1);
2012-12-18 22:47:43 +00:00
2015-09-21 23:05:36 +01:00
YAP_UserCPredicate("cpp_free_ground_network", freeGroundNetwork, 1);
2012-05-23 14:56:01 +01:00
}
2015-09-21 23:05:36 +01:00
} // namespace Horus