52 lines
1018 B
C++
52 lines
1018 B
C++
#include "../CountingBp.h"
|
|
#include "../FactorGraph.h"
|
|
#include "Common.h"
|
|
|
|
|
|
namespace Horus {
|
|
|
|
namespace UnitTests {
|
|
|
|
class CountingBpTest : public CppUnit::TestFixture {
|
|
CPPUNIT_TEST_SUITE (CountingBpTest);
|
|
CPPUNIT_TEST (testMarginals);
|
|
CPPUNIT_TEST (testJoint);
|
|
CPPUNIT_TEST_SUITE_END();
|
|
public:
|
|
void testMarginals();
|
|
void testJoint();
|
|
};
|
|
|
|
|
|
|
|
void
|
|
CountingBpTest::testMarginals()
|
|
{
|
|
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
|
CountingBp solver (fg);
|
|
for (unsigned i = 0; i < marginalProbs.size(); i++) {
|
|
Params params = solver.solveQuery ({i});
|
|
CPPUNIT_ASSERT (similiar (params, marginalProbs[i]));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void
|
|
CountingBpTest::testJoint()
|
|
{
|
|
FactorGraph fg = FactorGraph::readFromLibDaiFormat (modelFile.c_str());
|
|
CountingBp solver (fg);
|
|
Params params = solver.solveQuery ({0, 4, 6});
|
|
CPPUNIT_ASSERT (similiar (params, jointProbs));
|
|
}
|
|
|
|
|
|
|
|
CPPUNIT_TEST_SUITE_REGISTRATION (CountingBpTest);
|
|
|
|
} // namespace UnitTests
|
|
|
|
} // namespace Horus
|
|
|