%%%%
%%%%  Evaluation of a naive Bayes classifier for `votes' dataset
%%%%  --- votes.psm
%%%%
%%%%  Copyright (C) 2009
%%%%    Sato Laboratory, Dept. of Computer Science,
%%%%    Tokyo Institute of Technology

%%  In this program, we conduct n-fold cross validation of a naive Bayes
%%  classifier.  This program was created to demonstrate the usefulness of
%%  the built-in predicates introduced since version 1.12.  The target
%%  dataset is the congressional voting records (`votes') dataset, which
%%  is available from the UCI Machine Learning Repository
%%  (http://archive.ics.uci.edu/ml/).
%%  
%%  This program shows that, by using new built-in predicates such as
%%  maplist/5, avglist/2 and random_shuffle/2, we can keep the utility
%%  part, as well as the modeling part, compact.  One may also note that
%%  n-fold cross validation is implemented here only by combining such
%%  general-purpose built-ins.
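%%
%%  For reference, these built-ins behave roughly as follows (a sketch
%%  inferred from their usage in this program; the lists below are made
%%  up just for illustration):
%%
%%  ?- maplist(X,Y,Y is X*10,[1,2,3],Ys).   % Ys = [10,20,30]
%%  ?- avglist([1,2,3,4],Avg).              % Avg = 2.5
%%  ?- random_shuffle([a,b,c,d],Xs).        % Xs is a random permutation
%%                                          %   of [a,b,c,d]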

%%-------------------------------------
%%  Quick start : sample session
%%
%%  (Preparation: Download the data file `house-votes-84.data' from UCI ML
%%  repository, and put it `as-is' in the current directory)
%%
%%  ?- prism(votes),votes_learn.     % Learn parameters from the whole dataset
%%
%%  ?- prism(votes),votes_cv(10).    % Conduct 10-fold cross validation
%%

%%-------------------------------------
%%  Declarations

values(class,[democrat,republican]).                % class labels
values(attr(_,_),[y,n]).  % all attributes have two values: y or n
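%%  (Here attr(J,C) is intended as the switch for the J-th attribute
%%  conditioned on the class label C; see choose/3 in the modeling part.)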

%%-------------------------------------
%%  Modeling part (a naive Bayes model)
%%
%%  [Note]
%%  According to `house-votes-84.names', the data description file for the
%%  `votes' dataset, '?' simply denotes that the value is neither "yea" nor
%%  "nay".  In this program, on the other hand, we treat '?' as a missing
%%  value just for demonstration purposes.

nbayes(C,Vals):- msw(class,C),nbayes(1,C,Vals).

nbayes(_,_,[]).
nbayes(J,C,[V|Vals]):-
    choose(J,C,V),
    J1 is J+1,
    nbayes(J1,C,Vals).

choose(J,C,V):- 
    ( V == '?' -> msw(attr(J,C),_)   % handling '?' as a missing value
    ; msw(attr(J,C),V0),
      V = V0
    ).
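
%%  [Example]
%%  Once parameters have been learned, the model can also be queried
%%  directly (a sketch: sample/1 and viterbig/1 are PRISM built-ins, and
%%  the attribute values in the second query are made up for illustration):
%%
%%  ?- prism(votes),votes_learn,
%%     length(Vs,16),sample(nbayes(C,Vs)).
%%                        % Sample a class label and 16 attribute values
%%
%%  ?- viterbig(nbayes(C,[n,y,n,y,y,y,n,n,n,y,'?',y,y,y,n,y])).
%%                        % Most probable class label for a single record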

%%-------------------------------------
%%  Utility part:

%% Batch routine for a simple learning

votes_learn:-
    load_data_file(Gs),
    learn(Gs).

%% Batch routine for N-fold cross validation

votes_cv(N):-
    random_set_seed(81729), % Fix the random seed to keep the same splitting
    load_data_file(Gs0),    % Load the entire data
    random_shuffle(Gs0,Gs), % Randomly reorder the data
    numlist(1,N,Ks),        % Get Ks = [1,...,N] (B-Prolog built-in)
    maplist(K,Rate,votes_cv(Gs,K,N,Rate),Ks,Rates),
                            % Call votes_cv/4 for K=1...N
    avglist(Rates,AvgRate), % Get the avg. of the precisions
    maplist(K,Rate,format("Test #~d: ~2f%~n",[K,Rate*100]),Ks,Rates),
    format("Average: ~2f%~n",[AvgRate*100]).

%% Subroutine for learning and testing on the K-th data split (K = 1...N)

votes_cv(Gs,K,N,Rate):-
    format("<<<< Test #~d >>>>~n",[K]),
    separate_data(Gs,K,N,Gs0,Gs1),  % Gs0: training data, Gs1: test data
    learn(Gs0),                     % Learn by PRISM's built-in
    maplist(nbayes(C,Vs),R,(viterbig(nbayes(C0,Vs)),(C0==C->R=1;R=0)),Gs1,Rs),
                       % Predict the class by viterbig/1 for each test example
                       %           and evaluate it with the answer class label
    avglist(Rs,Rate),  % Get the accuracy for the K-th splitting
    format("Done (~2f%).~n~n",[Rate*100]).

%% Split the entire data (Data) into the training data (Train)
%% and the test data (Test) for the K-th evaluation (K=1...N)

separate_data(Data,K,N,Train,Test):-
    length(Data,L),
    L0 is L*(K-1)//N,    % L0: offset of the test data (// - integer division)
    L1 is L*(K-0)//N-L0, % L1: size of the test data
    splitlist(Train0,Rest,Data,L0),   % Length of Train0 = L0
    splitlist(Test,Train1,Rest,L1),   % Length of Test = L1
    append(Train0,Train1,Train).
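
%% For example, with L = 435 records (the size of the `votes' dataset),
%% N = 10 and K = 3, we get L0 = 87 and L1 = 43, i.e. the 88th-130th
%% records become the test data and the remaining 392 records become the
%% training data.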

%% Load the `votes' data in CSV form and convert it to suitable
%% Prolog terms

load_data_file(Gs):-
    load_csv('house-votes-84.data',Gs0),
    maplist(csvrow([C|Vs]),nbayes(C,Vs),true,Gs0,Gs).
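
%% (A sketch of the conversion: load_csv/2 represents each record as a
%%  csvrow/1 term whose first element is the class label, so a row like
%%  csvrow([republican,n,y,n,...]) is converted into the goal
%%  nbayes(republican,[n,y,n,...]).  The values shown here are made up;
%%  '?' entries are kept as-is and later handled by choose/3.)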