%%%%
%%%%  Hidden Markov model --- hmm.psm
%%%%      
%%%%  Copyright (C) 2004,2006,2008
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  [state diagram:]  (2 states and 2 output symbols)
%%
%%         +--------+                +--------+
%%         |        |                |        |
%%         |     +------+        +------+     |
%%         |     |      |------->|      |     |
%%         +---->|  s0  |        |  s1  |<----+
%%               |      |<-------|      |
%%               +------+        +------+
%%
%%    - In each state, possible output symbols are `a' and `b'.

%%-------------------------------------
%%  Quick start : sample session
%%
%%  ?- prism(hmm),hmm_learn(100).   % Learn parameters from 100 randomly
%%                                  % generated samples
%%
%%  ?- show_sw.                     % Confirm the learned parameters
%%
%%  ?- prob(hmm([a,a,a,a,a,b,b,b,b,b])).        % Calculate the probability
%%  ?- probf(hmm([a,a,a,a,a,b,b,b,b,b])).       % Get the explanation graph
%%
%%  ?- viterbi(hmm([a,a,a,a,a,b,b,b,b,b])).     % Run the Viterbi computation
%%  ?- viterbif(hmm([a,a,a,a,a,b,b,b,b,b])).    % Get the Viterbi explanation
%%
%%  ?- hindsight(hmm([a,a,a,a,a,b,b,b,b,b])).   % Get hindsight probabilities

%%------------------------------------
%%  Declarations:

%% Random switch declarations: each values/2 fact names a switch and
%% lists the outcomes it can take.  out(_) and tr(_) are used below as
%% out(S) and tr(S), i.e. one switch per current state S.
values(init,[s0,s1]).       % initial state: s0 or s1
values(out(_),[a,b]).       % symbol emitted in a state: a or b
values(tr(_),[s0,s1]).      % state reached by a transition: s0 or s1

% :- set_prism_flag(default_sw_d,1.0).
% :- set_prism_flag(epsilon,1.0e-2).
% :- set_prism_flag(restart,10).
% :- set_prism_flag(log_scale,on).

%%------------------------------------
%%  Modeling part:

%% hmm(Symbols): Symbols is a string of output symbols produced by the HMM.
%% The string length comes from str_length/1 and the chain starts at time 1.
hmm(Symbols) :-
   str_length(Len),             % fixed length of the string
   msw(init,Start),             % draw the initial state from switch init
   hmm(1,Len,Start,Symbols).    % hand over to the hmm/4 loop

%% hmm(T,N,S,L): L is the list of symbols emitted from time T through
%% time N when the chain is in state S at time T.
%%   T : current time step (hmm/1 starts it at 1)
%%   N : last time step, i.e. the string length
%%   S : current state
%%   L : output symbols still to be emitted/observed
%% The loop clause is guarded by T=<N so the two clauses are mutually
%% exclusive.  Without the guard, an observed list longer than N symbols
%% never matched the stop clause at time N+1 and went on emitting past
%% the string length, so over-long strings got a non-zero probability.
hmm(T,N,_,[]):- T>N,!.          % Stop the loop once time N has passed
hmm(T,N,S,[Ob|Y]) :-            % Loop: current state is S, current time is T
   T=<N,                        %   Never emit beyond the string length
   msw(out(S),Ob),              %   Output Ob at the state S
   msw(tr(S),Next),             %   Transit from S to Next
   T1 is T+1,                   %   Count up time
   hmm(T1,N,Next,Y).            %   Go next (recursion)

%% str_length(N): the fixed length of the strings handled by hmm/1,
%% which reads it as the upper time bound of the hmm/4 loop.
str_length(10).                 % String length is 10

%%------------------------------------
%%  Utility part:

%% hmm_learn(N): sets the switch parameters with set_params/0, obtains N
%% sampled goals of the form hmm(_) and passes them to learn/1.
%% The cuts commit to the first solution of each preparatory step.
hmm_learn(N) :-
   set_params, !,                         % fix the parameters by hand
   get_samples(N, hmm(_), Samples), !,    % Samples: N sampled hmm/1 goals
   learn(Samples).                        % learn from those samples

%% set_params: assigns hand-picked probabilities to every switch of the
%% model.  Each list is given in the order of the outcomes declared by
%% values/2 ([s0,s1] for init and tr(_), [a,b] for out(_)).
set_params :-
   % initial state
   set_sw(init, [0.9,0.1]),
   % state s0: transition, then emission
   set_sw(tr(s0), [0.2,0.8]),
   set_sw(out(s0), [0.5,0.5]),
   % state s1: transition, then emission
   set_sw(tr(s1), [0.8,0.2]),
   set_sw(out(s1), [0.6,0.4]).

%%  prism_main/1 is a special predicate for batch execution.
%%  The following command conducts learning from 50 randomly
%%  generated samples:
%%      > upprism hmm 50

%% prism_main([CountAtom]): batch entry point taking exactly one
%% command-line argument, the number of samples given as an atom.
prism_main([CountAtom]) :-
   parse_atom(CountAtom, Count),   % e.g. the atom '50' becomes the number 50
   hmm_learn(Count).               % learn from Count generated samples

%%  viterbi_states(Os,Ss) returns the most probable sequence Ss
%%  of state transitions for an output sequence Os.
%%
%%  | ?- viterbi_states([a,a,a,a,a,b,b,b,b,b],States).
%%
%%  States = [s0,s1,s0,s1,s0,s1,s0,s1,s0,s1,s0] ?

%% viterbi_states(Outputs,States): runs viterbif/3 on hmm(Outputs), takes
%% the subgoals of the resulting explanation and collects the third
%% argument (the state) of each hmm/4 subgoal into States.
%% NOTE(review): the five-argument maplist is presumably the system's
%% pattern form (map each element matching hmm(_,_,St,_) to St) -- confirm
%% against the PRISM/B-Prolog manual.
viterbi_states(Outputs, States) :-
   viterbif(hmm(Outputs), _, Expl),
   viterbi_subgoals(Expl, Goals),
   maplist(hmm(_,_,St,_), St, true, Goals, States).