%%%%
%%%%  Profile HMM --- phmm.psm
%%%%      
%%%%  Copyright (C) 2004,2006,2007,2008
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  Profile HMMs are a variant of HMMs that have three types of states,
%%  i.e. `match states', `insert states' and `delete states'. Match states
%%  constitute an HMM that outputs a `true' string. Insert states
%%  emit extra symbols in addition to the `true' string, whereas delete
%%  (skip) states emit no symbol.
%%
%%  Profile HMMs are used to align amino-acid sequences by inserting
%%  and skipping symbols as well as matching symbols.  For example
%%  amino-acid sequences below
%%
%%  HLKIANRKDKHHNKEFGGHHLA
%%  HLKATHRKDQHHNREFGGHHLA
%%  VLKFANRKSKHHNKEMGAHHLA
%%  ...
%% 
%%  are aligned by the profile HMM program in this file as follows.
%%
%%  -HLKIA-NRKDK-H-H----NKEFGGHH-LA
%%  -HLK-A-T-HRK-DQHHN--R-EFGGHH-LA
%%  -VLKFA-NRKSK-H-H----NKEMGAHH-LA
%%  ...

%%-------------------------------------
%%  Quick start : sample session, align the sample data in phmm.dat.
%%
%%  To run on an interactive session:
%%  ?- prism(phmm),go.         (ML/MAP)
%%  ?- prism(phmm),go_vb.      (variational Bayes)
%%
%%  To perform a batch execution:
%%  > upprism phmm

%% go: run the whole sample session on phmm.dat.
go :-
   % Read the observed goals (the sequence data) from phmm.dat.
   read_goals(Goals,'phmm.dat'),
   % Learn the switch parameters from the data.
   learn(Goals),
   % Compute the Viterbi path of every sequence using the learned
   % parameters and print the sequences aligned against each other.
   wmag(Goals).

% To enable variational Bayes, we need some additional flag settings:
%% go_vb: same session as go/0, run after switching on the additional
%% PRISM flags that variational Bayes needs.  The flags are set globally
%% and are not restored afterwards.
go_vb :-
   set_prism_flag(learn_mode,both),
   set_prism_flag(viterbi_mode,hparams),
   set_prism_flag(reset_hparams,on),
   go.

%% prism_main: delegates to go/0 (ML/MAP).  Presumably this is the entry
%% point picked up by the batch command `upprism phmm' -- confirm against
%% the PRISM manual.  Swap the two lines below to run variational Bayes.
prism_main :- go.
%prism_main :- go_vb.


%%%--------------------- model ---------------------

%% observe(Symbols): Symbols is a sequence generated by the profile HMM
%% when it is run from the `start' state.
observe(Symbols) :-
   hmm(Symbols,start).

%% hmm(Sequence,State): Sequence is emitted by the model from State on.
%% In every state other than `end' one switch picks the successor state
%% and another picks the emitted symbol; the symbol `epsilon' stands for
%% "no output", so Sequence is left untouched in that case.
hmm([],end).
hmm(Sequence,State) :-
   State \== end,
   msw(move_from(State),Successor),
   msw(emit_at(State),Emitted),
   ( Emitted = epsilon ->
       Remaining = Sequence               % silent state: consume nothing
   ; Sequence = [Emitted|Remaining]       % emitting state: consume one symbol
   ),
   hmm(Remaining,Successor).

%% amino_acids(Symbols): the emission alphabet used by insert and match
%% states (see values(emit_at(_),_)).  It holds 21 one-letter codes;
%% presumably 'X' denotes an unknown residue -- confirm with the data.
amino_acids(['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R',
             'S','T','V','W','X','Y']).
%% hmm_len(Len): the largest state index of the model.  Match/delete
%% states are numbered 1..Len, insert states 0..Len, and get_index/2
%% gives `end' the index Len+1.
hmm_len(17).

%%%--------------------- values ---------------------

%% values(Switch,Outcomes): outcome space of each random switch.
%%
%% move_from(State): from a state with index Pos, 0 =< Pos < hmm_len, the
%% model may stay at Pos via insert(Pos) or advance to match(Pos+1) or
%% delete(Pos+1).  At any other index the only choices are insert(Pos)
%% and end.
values(move_from(State),Outcomes) :-
   get_index(State,Pos),
   hmm_len(Len),
   ( Pos >= 0, Pos < Len ->
       Next is Pos + 1,
       Outcomes = [insert(Pos),match(Next),delete(Next)]
   ; Outcomes = [insert(Pos),end]
   ).

%% emit_at(State): insert and match states emit an amino-acid symbol;
%% every other state emits only the silent symbol `epsilon'.
values(emit_at(State),Outcomes) :-
   ( State = insert(_) ->
       amino_acids(Outcomes)
   ; State = match(_) ->
       amino_acids(Outcomes)
   ; Outcomes = [epsilon]
   ).

%%%--------------------- set_sw ---------------------

:- init_set_sw.

%% init_set_sw: call set_sw/1 on every switch of the model.  The switches
%% of the two index-0 states (start and insert(0)) are handled here;
%% init_set_sw/1 then covers the indices hmm_len down to 1.
init_set_sw :-
%  tell('/dev/null'),           % Suppress output (on Linux only)
   hmm_len(Len),
   set_sw( move_from(start) ),
   set_sw( move_from(insert(0)) ),
   set_sw( emit_at(start) ),
   set_sw( emit_at(insert(0)) ),
%  told,
   init_set_sw(Len).

%% init_set_sw(Pos): call set_sw/1 on the transition and emission switches
%% of the insert, match and delete states at indices Pos, Pos-1, ..., 1.
init_set_sw(0).
init_set_sw(Pos) :-
   Pos > 0,
   Prev is Pos - 1,
   set_sw( move_from(insert(Pos)) ),
   set_sw( move_from(match(Pos)) ),
   set_sw( move_from(delete(Pos)) ),
   set_sw( emit_at(insert(Pos)) ),
   set_sw( emit_at(match(Pos)) ),
   set_sw( emit_at(delete(Pos)) ),
   init_set_sw(Prev).

%%%--------------------- estimation ---------------------
%% most likely path
%% mlpath(['A','E'],Path) => Path = [start,match(1),end]

%% mlpath(Sequence,Path): Path is the list of states read off the result
%% that viterbif/3 returns for hmm(Sequence,start).
mlpath(Sequence,Path) :-
   mlpath(Sequence,Path,_Prob).

%% mlpath(Sequence,Path,Prob): as mlpath/2, and Prob is the probability
%% reported by viterbif/3.
mlpath(Sequence,Path,Prob) :-
   viterbif(hmm(Sequence,start),Prob,ExplNodes),
   nodes2path(ExplNodes,Path).

%% nodes2path(Nodes,Path): Nodes is a list of node(hmm(_,State),_) terms;
%% Path collects the State argument of each node, in the same order.
nodes2path([node(hmm(_,State),_)|Nodes],[State|States]) :-
   nodes2path(Nodes,States).
nodes2path([],[]).

%% mlpaths(Seqs,Paths,X): Paths holds, in order, the mlpath/2 result of
%% every sequence in Seqs.
%%
%% The third argument is unused.  It is kept only so that existing
%% callers such as mlpaths(Seqs,Paths,Paths) keep working.  It used to
%% feed leftover debugging code: writeln/1 printed the first path once per
%% sequence, and stop_low_level_trace/0 is a YAP-specific tracing hook
%% that is not defined on other Prolog systems.  Both calls were removed.
mlpaths([Seq|Seqs],[Path|Paths],X) :-
   mlpath(Seq,Path),
   mlpaths(Seqs,Paths,X).
mlpaths([],[],_).

%%%--------------------- alignment ---------------------

%% wmag(Goals): Goals is a list of observe(Seq) terms; strip the wrapper
%% from each goal and print the multiple alignment of the sequences.
wmag(Goals) :-
   seqs2goals(Seqs,Goals),
   wma(Seqs).
%% wma(Sequences): short name for write_multiple_alignments/1.
wma(Sequences) :-
   write_multiple_alignments(Sequences).
%% write_multiple_alignments(Sequences): compute the path of every
%% sequence with mlpaths/3, then print the sequences between two banner
%% lines, padded so that equal state indices line up.
write_multiple_alignments(Sequences) :-
   nl,
   write('search Viterbi paths...'), nl,
   mlpaths(Sequences,StatePaths,StatePaths),
   write('done.'), nl,
   write('------------ALIGNMENTS------------'), nl,
   write_multiple_alignments(Sequences,StatePaths),
   write('----------------------------------'), nl.

%% make_max_length_list(Paths,MaxLenList): MaxLenList is the element-wise
%% maximum of the length lists (see make_length_list/2) of all Paths.
%% Paths must contain at least one path; there is no clause for [].
make_max_length_list([Only],MaxLenList) :-
   !,
   make_length_list(Only,MaxLenList).
make_max_length_list([First|Others],MaxLenList) :-
   make_max_length_list(Others,OthersMax),
   make_length_list(First,FirstLens),
   marge_len_list(FirstLens,OthersMax,MaxLenList).

%% marge_len_list(As,Bs,Maxes): merge two length lists of equal size into
%% the list of their element-wise maxima.  Fails when the two lists
%% differ in size.
marge_len_list([],[],[]).
marge_len_list([A|As],[B|Bs],[Larger|Maxes]) :-
   max(Larger,[A,B]),
   marge_len_list(As,Bs,Maxes).

%% make_length_list([start,insert(0),match(1),end],LenList)
%% -> LenList = [1,1]
%% make_length_list([start,delete(1),insert(1),insert(1),end],LenList)
%% -> LenList = [0,2]
%% (start and delete states emit nothing, so they are not counted)

%% make_length_list(Path,LenList): split Path into runs of states that
%% share one index.  For each run before `end', LenList holds the number
%% of states in it that are neither `start' nor a delete state.
make_length_list([end],[]).
make_length_list(Path,[RunLen|RunLens]) :-
   count_emission(Path,RunLen,Remainder),
   make_length_list(Remainder,RunLens).

%% count_emission(Path,Count,NextIndexPath): count the emitting states in
%% the leading run of Path, i.e. the states sharing the index of its
%% first state; NextIndexPath is the part of Path after that run.
count_emission([State|Rest],Count,NextIndexPath) :-
   get_index(State,RunIndex),
   count_emission2([State|Rest],Count,RunIndex,NextIndexPath).

%% count_emission2([start,insert(0),match(1),end],Count,0,NextIndexPath)
%% -> Count = 1, NextIndexPath = [match(1),end]   (start is not counted)
%% count_emission2([delete(1),insert(1),insert(1),end],Count,1,NextIndexPath)
%% -> Count = 2, NextIndexPath = [end]

%% count_emission2(Path,Count,Index,NextIndexPath): walk Path while its
%% states have index Index.  Count is the number of those states that are
%% neither `start' nor a delete state.  NextIndexPath starts at the first
%% state with another index.  There is no clause for [], so the call
%% fails if Path runs out before such a state is found.
count_emission2([State|Rest],Count,Index,NextIndexPath) :-
   ( get_index(State,Index) ->
       count_emission2(Rest,RestCount,Index,NextIndexPath),
       ( State = delete(_) ->
           Count = RestCount              % silent state
       ; State == start ->
           Count = RestCount              % silent state
       ; Count is RestCount + 1
       )
   ; Count = 0,
       NextIndexPath = [State|Rest]
   ).

%% write_multiple_alignments(Sequences,StatePaths): derive the column
%% widths shared by all paths, then print every sequence with them.
write_multiple_alignments(Sequences,StatePaths) :-
   make_max_length_list(StatePaths,Widths),
   write_multiple_alignments(Sequences,StatePaths,Widths).
%% write_multiple_alignments(Sequences,StatePaths,Widths): print one
%% aligned line per sequence; Widths gives the column width per index.
write_multiple_alignments([],[],_).
write_multiple_alignments([Sequence|Sequences],[StatePath|StatePaths],Widths) :-
   write_alignment(Sequence,StatePath,Widths),
   write_multiple_alignments(Sequences,StatePaths,Widths).

%% write_alignment(Sequence,StatePath,Widths): print one aligned line,
%% starting the walk along StatePath at state index 0.
write_alignment(Sequence,StatePath,Widths) :-
   StartIndex = 0,
   write_alignment(Sequence,StatePath,Widths,StartIndex).

%% write_alignment(Seq,Path,LenList,Index): print Seq along Path.
%% Index is the state index of the column being filled and the head of
%% LenList is the number of cells still free in that column.
%%  - A state at Index that is `start' or a delete state prints nothing.
%%  - Any other state at Index prints the next symbol of Seq and uses up
%%    one cell.
%%  - A state with another index closes the column: the free cells are
%%    filled with '-' and the walk moves on to Index+1.
%% The line is ended once only `end' remains and Seq and LenList are
%% both exhausted.
write_alignment([],[end],[],_) :-
   !,
   nl.
write_alignment(Seq,[State|Path],LenList,Index) :-
   get_index(State,Index),
   !,
   ( State = delete(_) ->
       write_alignment(Seq,Path,LenList,Index)
   ; State == start ->
       write_alignment(Seq,Path,LenList,Index)
   ; Seq = [Symbol|RestSeq],
       LenList = [Free|OtherColumns],
       write(Symbol),
       StillFree is Free - 1,
       write_alignment(RestSeq,Path,[StillFree|OtherColumns],Index)
   ).
write_alignment(Seq,[State|Path],[Free|OtherColumns],Index) :-
   NextIndex is Index + 1,
   pad(Free),
   write_alignment(Seq,[State|Path],OtherColumns,NextIndex).

%% pad(Len): write Len copies of '-'.  Succeeds without output for 0 and
%% fails for a negative Len.
pad(Len) :-
   Len > 0,
   !,
   write('-'),
   Rest is Len - 1,
   pad(Rest).
pad(0).

%%%--------------------- utility ---------------------

%% get_index(State,Index): Index is the position of State in the model.
%% match(I), insert(I) and delete(I) have index I, `start' has index 0
%% and `end' has index hmm_len+1.  Fails for any other term.  Once the
%% kind of State is recognised the choice is committed, so a mismatching
%% Index makes the call fail.
get_index(match(I),Index)  :- !, Index = I.
get_index(insert(I),Index) :- !, Index = I.
get_index(delete(I),Index) :- !, Index = I.
get_index(start,Index)     :- !, Index = 0.
get_index(end,Index) :-
   hmm_len(Len),
   Index is Len + 1.

%% seqs2goals(Seqs,Goals): Goals is Seqs with every sequence wrapped as
%% observe(Seq).  Works in both directions: wmag/1 calls it with Goals
%% bound to recover the sequences.
seqs2goals([Seq|Seqs],[observe(Seq)|Goals]) :-
   seqs2goals(Seqs,Goals).
seqs2goals([],[]).

%% max(Max,List): Max is the largest element of the non-empty List.
%% Elements are compared with >/2; on a tie the earlier element is kept.
%% Fails for the empty list.
max(Max,[Only]) :-
   !,
   Max = Only.
max(Max,[Head|Tail]) :-
   max(TailMax,Tail),
   ( TailMax > Head ->
       Max = TailMax
   ; Max = Head
   ).

%% read_goals(Goals,FileName): Goals is the list of all terms stored in
%% the file FileName, in file order.
%%
%% The file is made the current input with see/1 and is always released
%% with seen/0.  This holds even when reading throws (e.g. a syntax error
%% in the data) or fails, so the stream is not leaked and current input
%% does not stay redirected to the file.  An error raised while reading
%% is re-thrown after the file has been closed.
read_goals(Goals,FileName) :-
   see(FileName),
   ( catch(read_goals(Goals),Error,(seen,throw(Error))) ->
       seen
   ; seen,
       fail
   ).
%% read_goals(Goals): read terms from the current input with read/1
%% until end_of_file is returned; Goals lists the terms in input order.
read_goals(Goals) :-
   read(Term),
   ( Term = end_of_file ->
       Goals = []
   ; Goals = [Term|MoreGoals],
       read_goals(MoreGoals)
   ).