%%%%
%%%%  Bayesian networks (2) -- sbn.psm
%%%%
%%%%  Copyright (C) 2004,2008
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  This example shows how to simulate Pearl's message passing
%%  (without normalization) for singly connected BNs (Bayesian networks).
%%
%%  Suppose that we have a Bayesian network in Figure 1 and that
%%  we wish to compute marginal probabilities P(B) of B.
%%  The distribution defined by the BN in Figure 1 is expressed
%%  by a BN program in Figure 3. We transform it into another
%%  program that defines the same marginal distribution for B.
%%
%%    Original graph       Transformed graph
%%
%%       A      B                 B
%%      /  \   /                  |
%%     /    \ /                   v
%%    C      D       ==>          D
%%          / \                 / | \
%%         /   \               /  v  v
%%        E     F             A   E  F 
%%                           /
%%                          v
%%                          C
%%    (Figure 1)             (Figure 2)
%%
%%    Original BN program for Figure 1
%%
        % world(A,B,C,D,E,F): one joint realization of the six nodes of
        % Figure 1.  Each node is drawn by an msw whose switch id
        % par(Node,ParentValues) carries the values already drawn for the
        % node's parents, so every parent is drawn before its children.
        world(A,B,C,D,E,F) :-
           msw(par('A',[]),A),        % root node A
           msw(par('B',[]),B),        % root node B
           msw(par('C',[A]),C),       % C depends on A
           msw(par('D',[A,B]),D),     % D depends on A and B
           msw(par('E',[D]),E),       % E depends on D
           msw(par('F',[D]),F).       % F depends on D
        % check_B(B): B takes this value in some realization of world/6;
        % the other five nodes are left unconstrained.
        check_B(B) :-
           world(_A,B,_C,_D,_E,_F).
%%
%%             (Figure 3)
%%
%%  Transformation:
%%  [Step 1] Transform the original BN in Figure 1 into Figure 2 by letting
%%     B be the top node and other nodes dangle from B.
%%  [Step 2] Construct a program that calls nodes in Figure 2 from the top
%%     node to leaves. For example for D, add clause
%%
%%        call_BD(VB):- call_DA(VA),call_DE(VD),call_DF(VD).
%%
%%     while inserting an msw expressing the CPT P(D|A,B) in the body. Here,
%%
%%        call_XY(V) <=>
%%           node Y is called from X with ground term V (=X's realization)
%%
%%  It can be proved by unfolding that the transformed program is equivalent
%%  in distribution semantics to the original program in Figure 3.
%%     => Both programs compute the same marginal distribution for B.
%%        Confirm by ?- prob(ask_B(2),X),prob(check_B(2),Y).

%%-------------------------------------
%%  Quick start : sample session
%%
%%  ?- prism(sbn),go.           % Learn parameters from randomly generated
%%                              % 100 samples while preserving the marginal
%%                              % distribution P(B)
%%
%%  ?- prob(ask_B(2)).
%%  ?- prob(ask_B(2),X),prob(check_B(2),Y). % => X=Y
%%  ?- probf(ask_B(2)).
%%  ?- sample(ask_B(X)).
%%
%%  ?- viterbi(ask_B(2)).
%%  ?- viterbif(ask_B(2),P,E),print_graph(E).

% go: entry point of the sample session; runs sbn_learn/1 with 100 samples.
go :-
   sbn_learn(100).

%%------------------------------------
%%  Declarations:

% values(SwitchId,Outcomes): outcome space of each switch.  A switch id has
% the form par(Node,ParentValues); e.g. msw(par('A',[]),VA) yields VA in
% {0,1} and msw(par('D',[VA,VB]),VD) yields VD in {6,7}.
values(par('A',[]),      [0,1]).
values(par('B',[]),      [2,3]).
values(par('C',[_A]),    [4,5]).     % one parent slot (A)
values(par('D',[_A,_B]), [6,7]).     % two parent slots (A and B)
values(par('E',[_D]),    [8,9]).     % one parent slot (D)
values(par('F',[_D]),    [10,11]).   % one parent slot (D)

% set_params: assign parameters to every switch instance with the set_sw/2
% built-in, one call per combination of parent values.  The "=>" notes give
% the resulting probability list for the distribution keywords.
set_params :-
   % nodes without parents
   set_sw(par('A',[]), [0.3,0.7]),
   set_sw(par('B',[]), uniform),                 % => [0.5,0.5]
   % C given A
   set_sw(par('C',[0]), f_geometric(3,asc)),     % => [0.25,0.75]
   set_sw(par('C',[1]), f_geometric(3,desc)),    % => [0.75,0.25]
   % D given A and B
   set_sw(par('D',[0,2]), f_geometric(3)),       % => [0.75,0.25]
   set_sw(par('D',[1,2]), f_geometric(2)),       % => [0.666...,0.333...]
   set_sw(par('D',[0,3]), f_geometric),          % => [0.666...,0.333...]
   set_sw(par('D',[1,3]), [0.3,0.7]),
   % E given D
   set_sw(par('E',[6]), [0.3,0.7]),
   set_sw(par('E',[7]), [0.1,0.9]),
   % F given D
   set_sw(par('F',[6]), [0.3,0.7]),
   set_sw(par('F',[7]), [0.1,0.9]).

% Directive: invoke set_params while this file is being loaded.
:- set_params.

%%------------------------------------
%%  Modeling part: transformed program defining P(B)

% ask_B(B): top of the transformed program (node B of Figure 2).
%   ?- prob(ask_B(2),X)  =>  X = P(B=2)
ask_B(B) :-
   msw(par('B',[]),B),          % draw B from its switch
   call_BD(B).                  % then descend from B into D
% call_BD(B): node D called from B, where B is B's realization.
% An msw's switch id must be ground, so A has to be known before the
% switch par('D',[A,B]) is used: call_DA/1 is therefore run first.
call_BD(B) :-
   call_DA(A),
   msw(par('D',[A,B]),D),
   call_DE(D),                  % children of D receive D's value
   call_DF(D).
% call_DA(A): node A reached from D; draws A, then descends into C.
call_DA(A) :-
   msw(par('A',[]),A),
   call_AC(A).
% call_AC(A): node C called from A with A's realization; the value drawn
% for C is not used any further.
call_AC(A) :-
   msw(par('C',[A]),_).
% call_DE(D): node E called from D with D's realization; the value drawn
% for E is not used any further.
call_DE(D) :-
   msw(par('E',[D]),_).
% call_DF(D): node F called from D with D's realization; the value drawn
% for F is not used any further.
call_DF(D) :-
   msw(par('F',[D]),_).

%%------------------------------------
%%  Utility part:

% sbn_learn(SampleCount): learn the parameters (CPTs) from SampleCount
% randomly generated ask_B(.) atoms.
sbn_learn(SampleCount) :-
   random_set_seed(123456),                        % fixed seed
   set_params,                                     % parameters used for sampling
   get_samples(SampleCount,ask_B(_),Samples),      % generate the training goals
   learn(Samples).