%%%%
%%%%  Bayesian networks (1) -- alarm.psm
%%%%
%%%%  Copyright (C) 2004,2006,2008
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  This example is borrowed from:
%%    Poole, D., Probabilistic Horn abduction and Bayesian networks,
%%    In Proc. of Artificial Intelligence 64, pp.81-129, 1993.
%%
%%                (Fire)  (Tampering)
%%                /    \   /
%%          ((Smoke)) (Alarm)
%%                       |
%%                   (Leaving)       ((  )) -- observable node
%%                       |            (  )  -- hidden node
%%                   ((Report))
%%
%%  In this network, we assume that all rvs (random variables)
%%  take on {yes,no} and also assume that only two nodes, `Smoke'
%%  and `Report', are observable.

%%-------------------------------------
%%  Quick start : sample session
%%
%%  ?- prism(alarm),go.         % Learn parameters from 100 randomly
%%                              % generated samples
%%
%%  Get the probability and the explanation graph:
%%  ?- prob(world(yes,no)).
%%  ?- probf(world(yes,no)).
%%
%%  Get the most likely explanation and its probability:
%%  ?- viterbif(world(yes,no)).
%%  ?- viterbi(world(yes,no)).
%%
%%  Compute conditional hindsight probabilities:
%%  ?- chindsight(world(yes,no)).
%%  ?- chindsight_agg(world(yes,no),world(_,_,query,yes,_,no)).

go :-
   %% Quick-start entry point: run the learning demo on 100 samples.
   alarm_learn(100).

%%-------------------------------------
%%  Declarations:

:- set_prism_flag(data_source,file('world.dat')).
                            % When we run learn/0 (no argument), the data
                            % are supplied from `world.dat'.  Note that
                            % alarm_learn/1 as defined below calls learn/1
                            % with an in-memory list of samples; only its
                            % commented-out variant calls learn/0.

values(_,[yes,no]).         % Every random switch msw(I,V) used in this
                            % program, whatever its name I, has an
                            % outcome V that is one of {yes,no}. The '_'
                            % is an anonymous logical variable in Prolog,
                            % so this one declaration covers all names.

                            % The distribution of V is specified by
                            % set_params below.

%%------------------------------------
%%  Modeling part:
%%
%%  The above BN defines a joint distribution 
%%     P(Fire,Tampering,Smoke,Alarm,Leaving,Report).
%%  We assume `Smoke' and `Report' are observable while others are not.
%%  Our modeling simulates random sampling of the BN from top nodes
%%  using msws. For each rv, say `Fire', we introduce a corresponding
%%  msw, say msw(fi,Fi) such that
%%     msw(fi,Fi) <=> sampling msw named fi yields the outcome Fi.
%%  Here fi is a constant intended for the name of rv `Fire.'
%%

world(Fire,Tampering,Alarm,Smoke,Leaving,Report) :-
   %% Joint distribution over all six rvs, given by world/6 such that e.g.
   %%    P(Fire=yes,Tampering=yes,Alarm=no,Smoke=no,Leaving=no,Report=no)
   %%       = P(world(yes,yes,no,no,no,no))
   %% Each rv is drawn from a switch; for a CPT the switch name carries
   %% the values already drawn for the parent rvs.
   msw(fi,Fire),                      % P(Fire)
   msw(ta,Tampering),                 % P(Tampering)
   msw(sm(Fire),Smoke),               % CPT P(Smoke | Fire)
   msw(al(Fire,Tampering),Alarm),     % CPT P(Alarm | Fire,Tampering)
   msw(le(Alarm),Leaving),            % CPT P(Leaving | Alarm)
   msw(re(Leaving),Report).           % CPT P(Report | Leaving)

world(Smoke,Report) :-
   %% Marginal distribution of the two observable rvs, `Smoke' and
   %% `Report': the four hidden arguments of world/6 are left unbound.
   world(_,_,_,Smoke,_,Report).

%%------------------------------------
%%  Utility part:

alarm_learn(N) :-
   %% Generate N observations world(Smoke,Report) under the parameters
   %% given by set_params/0, then learn from those observations.
   unfix_sw(_),                          % Make all parameters changeable
   set_params,                           % Install the specified parameters
   get_samples(N,world(_,_),Samples),    % Draw N samples of world/2
   fix_sw(fi),                           % Preserve the parameter values
   learn(Samples).                       %   for {msw(fi,yes), msw(fi,no)}

% alarm_learn(N) :-
%    %% generate teacher data and write them to `world.dat'
%    %% before learn/0 is called.
%    write_world(N,'world.dat'),
%    learn.

set_params :-
   %% Parameter table: one Switch-Probs entry per switch.  Probs follows
   %% the outcome order [yes,no] declared by values/2, and a CPT has one
   %% entry per combination of parent values.
   set_params_from(
      [ fi-[0.1,0.9],                  % P(Fire)
        ta-[0.15,0.85],                % P(Tampering)
        sm(yes)-[0.95,0.05],           % P(Smoke | Fire)
        sm(no)-[0.05,0.95],
        al(yes,yes)-[0.50,0.50],       % P(Alarm | Fire,Tampering)
        al(yes,no)-[0.90,0.10],
        al(no,yes)-[0.85,0.15],
        al(no,no)-[0.05,0.95],
        le(yes)-[0.88,0.12],           % P(Leaving | Alarm)
        le(no)-[0.01,0.99],
        re(yes)-[0.75,0.25],           % P(Report | Leaving)
        re(no)-[0.10,0.90]
      ]).

%% set_params_from(+Pairs): call set_sw(Switch,Probs) for every
%% Switch-Probs pair, in list order.
set_params_from([]).
set_params_from([Switch-Probs|Rest]) :-
   set_sw(Switch,Probs),
   set_params_from(Rest).

write_world(N,File) :-
   %% Draw N samples of world(Smoke,Report) and write them to File, one
   %% clause per line.
   get_samples(N,world(_,_),Gs),
   tell(File),
   %% told/0 must run on every exit path: otherwise a failure or an I/O
   %% error while writing would leave File open and still selected as the
   %% current output.  Failure and exceptions are passed on to the caller
   %% after the cleanup.
   (  catch(write_world(Gs),Err,(told,throw(Err)))
   -> told
   ;  told, fail
   ).

%% write_world(+Goals): write each world(Smoke,Report) goal in Goals to the
%% current output as a clause, i.e. the term followed by '.' and a newline.
write_world([world(Sm,Re)|Gs]) :-
   %% writeq/1 quotes atoms where needed, so the output can still be read
   %% back as Prolog terms if the outcomes are ever changed to atoms that
   %% require quoting.  For yes/no the text is the same as with write/1.
   writeq(world(Sm,Re)),write('.'),nl,
   write_world(Gs).
write_world([]).