%%%%
%%%%  Double coin tossing --- dcoin.psm
%%%%
%%%%  Copyright (C) 2004,2006,2008
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  A sequential mixture of two Bernoulli trials processes.
%%  We have two coins, coin(1) and coin(2).
%%  Starting with coin(1), we keep flipping a coin and observing the outcome.
%%  We change coins according to the following rule:
%%    If the outcome is "head", the next coin to flip is coin(2).
%%    If the outcome is "tail", the next coin to flip is coin(1).
%%  The learning task is to estimate the parameters for coin(1) and coin(2)
%%  from an observed sequence of outcomes.
%%  As there is no hidden variable in this model, EM learning is just
%%  ML estimation from complete data.

%%-------------------------------------
%%  Quick start : sample session
%%
%%  (1) load this program
%%  ?- prism(dcoin).
%%
%%  (2) sampling and probability computations
%%  ?- sample(dcoin(10,X)),prob(dcoin(10,X)).
%%  ?- sample(dcoin(10,X)),probf(dcoin(10,X)).
%%
%%  (3) EM learning
%%  ?- go.

%%  go: run the learning demo on one sampled sequence of 500 outcomes.
go :- dcoin_learn(500).

%%------------------------------------
%%  Declarations:

values(coin(1),[head,tail],[0.5,0.5]).
                        % Declare msw(coin(1),V) s.t. V = head or
                        % V = tail, where P(msw(coin(1),head)) = 0.5
                        % and P(msw(coin(1),tail)) = 0.5.

values(coin(2),[head,tail],[0.7,0.3]).
                        % Declare msw(coin(2),V) s.t. V = head or
                        % V = tail, where P(msw(coin(2),head)) = 0.7
                        % and P(msw(coin(2),tail)) = 0.3.

%%------------------------------------
%%  Modeling part:

%%  dcoin(N,Rs): Rs is a list with length N of outcomes from the two
%%  Bernoulli trials processes.  The first flip always uses coin(1).
dcoin(N,Rs) :-
    dcoin(N,coin(1),Rs).

%%  dcoin(N,Coin,Rs): flip Coin, record the outcome R as the head of Rs,
%%  choose the next coin from R, and continue for the remaining N-1 flips.
%%  The N > 0 guard keeps this clause from applying when N is 0, which is
%%  handled by the base clause below.
dcoin(N,Coin,[R|Rs]) :-
    N > 0,
    msw(Coin,R),
    (  R == head, NextCoin = coin(2)     % "head" -> flip coin(2) next
    ;  R == tail, NextCoin = coin(1)     % "tail" -> flip coin(1) next
    ),
    N1 is N-1,
    dcoin(N1,NextCoin,Rs).
%%  Base case: zero flips left yields the empty outcome list,
%%  whichever coin would have been flipped next.
dcoin(0,_,[]).

%%------------------------------------
%%  Utility part:

%%  dcoin_learn(N): reset the switch parameters, draw one sample sequence
%%  of N outcomes from the model, then estimate the parameters from it.
dcoin_learn(N) :-
    set_params,                   % Set parameters.
    sample(dcoin(N,Rs)),          % Get a sample Rs of size N.
    Goals = [dcoin(N,Rs)],        % Estimate the parameters from Rs.
    learn(Goals).
%%  set_params: assign the outcome probabilities of both coin switches,
%%  given in the order [head,tail]:
%%    coin(1) -> [0.5,0.5]
%%    coin(2) -> [0.7,0.3]
set_params :-
    set_sw(coin(1),[0.5,0.5]),
    set_sw(coin(2),[0.7,0.3]).