From e865248dce4c5dc28db4545517042e7344e45556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADtor=20Santos=20Costa?= Date: Thu, 10 Nov 2011 12:24:47 +0000 Subject: [PATCH] prism logical probabilistic system. --- packages/prism/LICENSE | 93 ++ packages/prism/LICENSE.src | 39 + packages/prism/README | 24 + packages/prism/exs/README | 65 + packages/prism/exs/alarm.psm | 122 ++ packages/prism/exs/bloodABO.psm | 111 ++ packages/prism/exs/bloodAaBb.psm | 114 ++ packages/prism/exs/bloodtype.dat | 100 ++ packages/prism/exs/dcoin.psm | 72 ++ packages/prism/exs/direction.psm | 46 + packages/prism/exs/hmm.psm | 99 ++ packages/prism/exs/jtree/README | 8 + packages/prism/exs/jtree/asia.psm | 84 ++ packages/prism/exs/jtree/jasia.psm | 153 +++ packages/prism/exs/jtree/jasia_a.psm | 167 +++ packages/prism/exs/noisy_or/README | 7 + .../prism/exs/noisy_or/alarm_nor_basic.psm | 160 +++ .../prism/exs/noisy_or/alarm_nor_generic.psm | 174 +++ packages/prism/exs/noisy_or/noisy_or.psm | 65 + packages/prism/exs/pdcg.psm | 89 ++ packages/prism/exs/pdcg_c.psm | 121 ++ packages/prism/exs/phmm.dat | 44 + packages/prism/exs/phmm.psm | 263 ++++ packages/prism/exs/plc.dat | 60 + packages/prism/exs/plc.psm | 215 ++++ packages/prism/exs/sbn.psm | 130 ++ packages/prism/exs/votes.psm | 112 ++ packages/prism/src/README | 16 + packages/prism/src/c/Makefile.in | 91 ++ packages/prism/src/c/core/bpx.c | 401 ++++++ packages/prism/src/c/core/bpx.h | 323 +++++ packages/prism/src/c/core/error.c | 108 ++ packages/prism/src/c/core/error.h | 66 + packages/prism/src/c/core/fputil.c | 11 + packages/prism/src/c/core/fputil.h | 51 + packages/prism/src/c/core/gamma.c | 306 +++++ packages/prism/src/c/core/gamma.h | 7 + packages/prism/src/c/core/glue.c | 197 +++ packages/prism/src/c/core/glue.h | 9 + packages/prism/src/c/core/idtable.c | 175 +++ packages/prism/src/c/core/idtable.h | 29 + packages/prism/src/c/core/idtable_preds.c | 249 ++++ packages/prism/src/c/core/idtable_preds.h | 41 + 
packages/prism/src/c/core/random.c | 360 ++++++ packages/prism/src/c/core/random.h | 14 + packages/prism/src/c/core/stuff.h | 23 + packages/prism/src/c/core/termpool.c | 424 +++++++ packages/prism/src/c/core/termpool.h | 20 + packages/prism/src/c/core/vector.c | 87 ++ packages/prism/src/c/core/vector.h | 59 + packages/prism/src/c/core/xmalloc.c | 35 + packages/prism/src/c/core/xmalloc.h | 25 + packages/prism/src/c/makefiles/Makefile.files | 56 + packages/prism/src/c/makefiles/README | 11 + packages/prism/src/c/mp/mp.h | 21 + packages/prism/src/c/mp/mp_core.c | 101 ++ packages/prism/src/c/mp/mp_core.h | 19 + packages/prism/src/c/mp/mp_em_aux.c | 256 ++++ packages/prism/src/c/mp/mp_em_aux.h | 29 + packages/prism/src/c/mp/mp_em_ml.c | 265 ++++ packages/prism/src/c/mp/mp_em_ml.h | 15 + packages/prism/src/c/mp/mp_em_preds.c | 167 +++ packages/prism/src/c/mp/mp_em_preds.h | 19 + packages/prism/src/c/mp/mp_em_vb.c | 256 ++++ packages/prism/src/c/mp/mp_em_vb.h | 15 + packages/prism/src/c/mp/mp_flags.c | 77 ++ packages/prism/src/c/mp/mp_flags.h | 13 + packages/prism/src/c/mp/mp_preds.c | 191 +++ packages/prism/src/c/mp/mp_preds.h | 22 + packages/prism/src/c/mp/mp_sw.c | 206 +++ packages/prism/src/c/mp/mp_sw.h | 22 + packages/prism/src/c/up/em.h | 106 ++ packages/prism/src/c/up/em_aux.c | 151 +++ packages/prism/src/c/up/em_aux.h | 16 + packages/prism/src/c/up/em_aux_ml.c | 777 ++++++++++++ packages/prism/src/c/up/em_aux_ml.h | 26 + packages/prism/src/c/up/em_aux_vb.c | 569 +++++++++ packages/prism/src/c/up/em_aux_vb.h | 25 + packages/prism/src/c/up/em_ml.c | 162 +++ packages/prism/src/c/up/em_ml.h | 8 + packages/prism/src/c/up/em_preds.c | 181 +++ packages/prism/src/c/up/em_preds.h | 11 + packages/prism/src/c/up/em_vb.c | 170 +++ packages/prism/src/c/up/em_vb.h | 8 + packages/prism/src/c/up/flags.c | 158 +++ packages/prism/src/c/up/flags.h | 48 + packages/prism/src/c/up/graph.c | 888 +++++++++++++ packages/prism/src/c/up/graph.h | 82 ++ packages/prism/src/c/up/graph_aux.c | 
299 +++++ packages/prism/src/c/up/graph_aux.h | 15 + packages/prism/src/c/up/hindsight.c | 300 +++++ packages/prism/src/c/up/hindsight.h | 15 + packages/prism/src/c/up/up.h | 118 ++ packages/prism/src/c/up/util.c | 147 +++ packages/prism/src/c/up/util.h | 23 + packages/prism/src/c/up/viterbi.c | 1121 +++++++++++++++++ packages/prism/src/c/up/viterbi.h | 13 + packages/prism/src/prolog/Makefile.in | 108 ++ packages/prism/src/prolog/README | 40 + packages/prism/src/prolog/bp/eval.pl | 388 ++++++ packages/prism/src/prolog/core/error.pl | 909 +++++++++++++ packages/prism/src/prolog/core/format.pl | 55 + packages/prism/src/prolog/core/message.pl | 194 +++ packages/prism/src/prolog/core/random.pl | 286 +++++ packages/prism/src/prolog/mp/mp_learn.pl | 151 +++ packages/prism/src/prolog/mp/mp_main.pl | 112 ++ packages/prism/src/prolog/prism.yap | 50 + packages/prism/src/prolog/trans/bpif.pl | 53 + packages/prism/src/prolog/trans/dump.pl | 150 +++ packages/prism/src/prolog/trans/trans.pl | 735 +++++++++++ packages/prism/src/prolog/trans/verify.pl | 130 ++ packages/prism/src/prolog/up/batch.pl | 5 + packages/prism/src/prolog/up/bigarray.pl | 154 +++ packages/prism/src/prolog/up/dist.pl | 193 +++ packages/prism/src/prolog/up/dynamic.pl | 41 + packages/prism/src/prolog/up/expl.pl | 410 ++++++ packages/prism/src/prolog/up/flags.pl | 291 +++++ packages/prism/src/prolog/up/hash.pl | 42 + packages/prism/src/prolog/up/hindsight.pl | 497 ++++++++ packages/prism/src/prolog/up/learn.pl | 435 +++++++ packages/prism/src/prolog/up/list.pl | 882 +++++++++++++ packages/prism/src/prolog/up/main.pl | 338 +++++ packages/prism/src/prolog/up/prob.pl | 412 ++++++ packages/prism/src/prolog/up/sample.pl | 113 ++ packages/prism/src/prolog/up/switch.pl | 844 +++++++++++++ packages/prism/src/prolog/up/util.pl | 923 ++++++++++++++ packages/prism/src/prolog/up/viterbi.pl | 785 ++++++++++++ 127 files changed, 22788 insertions(+) create mode 100644 packages/prism/LICENSE create mode 100644 
packages/prism/LICENSE.src create mode 100644 packages/prism/README create mode 100644 packages/prism/exs/README create mode 100644 packages/prism/exs/alarm.psm create mode 100644 packages/prism/exs/bloodABO.psm create mode 100644 packages/prism/exs/bloodAaBb.psm create mode 100644 packages/prism/exs/bloodtype.dat create mode 100644 packages/prism/exs/dcoin.psm create mode 100644 packages/prism/exs/direction.psm create mode 100644 packages/prism/exs/hmm.psm create mode 100644 packages/prism/exs/jtree/README create mode 100644 packages/prism/exs/jtree/asia.psm create mode 100644 packages/prism/exs/jtree/jasia.psm create mode 100644 packages/prism/exs/jtree/jasia_a.psm create mode 100644 packages/prism/exs/noisy_or/README create mode 100644 packages/prism/exs/noisy_or/alarm_nor_basic.psm create mode 100644 packages/prism/exs/noisy_or/alarm_nor_generic.psm create mode 100644 packages/prism/exs/noisy_or/noisy_or.psm create mode 100644 packages/prism/exs/pdcg.psm create mode 100644 packages/prism/exs/pdcg_c.psm create mode 100644 packages/prism/exs/phmm.dat create mode 100644 packages/prism/exs/phmm.psm create mode 100644 packages/prism/exs/plc.dat create mode 100644 packages/prism/exs/plc.psm create mode 100644 packages/prism/exs/sbn.psm create mode 100644 packages/prism/exs/votes.psm create mode 100644 packages/prism/src/README create mode 100644 packages/prism/src/c/Makefile.in create mode 100644 packages/prism/src/c/core/bpx.c create mode 100644 packages/prism/src/c/core/bpx.h create mode 100644 packages/prism/src/c/core/error.c create mode 100644 packages/prism/src/c/core/error.h create mode 100644 packages/prism/src/c/core/fputil.c create mode 100644 packages/prism/src/c/core/fputil.h create mode 100644 packages/prism/src/c/core/gamma.c create mode 100644 packages/prism/src/c/core/gamma.h create mode 100644 packages/prism/src/c/core/glue.c create mode 100644 packages/prism/src/c/core/glue.h create mode 100644 packages/prism/src/c/core/idtable.c create mode 100644 
packages/prism/src/c/core/idtable.h create mode 100644 packages/prism/src/c/core/idtable_preds.c create mode 100644 packages/prism/src/c/core/idtable_preds.h create mode 100644 packages/prism/src/c/core/random.c create mode 100644 packages/prism/src/c/core/random.h create mode 100644 packages/prism/src/c/core/stuff.h create mode 100644 packages/prism/src/c/core/termpool.c create mode 100644 packages/prism/src/c/core/termpool.h create mode 100644 packages/prism/src/c/core/vector.c create mode 100644 packages/prism/src/c/core/vector.h create mode 100644 packages/prism/src/c/core/xmalloc.c create mode 100644 packages/prism/src/c/core/xmalloc.h create mode 100644 packages/prism/src/c/makefiles/Makefile.files create mode 100644 packages/prism/src/c/makefiles/README create mode 100644 packages/prism/src/c/mp/mp.h create mode 100644 packages/prism/src/c/mp/mp_core.c create mode 100644 packages/prism/src/c/mp/mp_core.h create mode 100644 packages/prism/src/c/mp/mp_em_aux.c create mode 100644 packages/prism/src/c/mp/mp_em_aux.h create mode 100644 packages/prism/src/c/mp/mp_em_ml.c create mode 100644 packages/prism/src/c/mp/mp_em_ml.h create mode 100644 packages/prism/src/c/mp/mp_em_preds.c create mode 100644 packages/prism/src/c/mp/mp_em_preds.h create mode 100644 packages/prism/src/c/mp/mp_em_vb.c create mode 100644 packages/prism/src/c/mp/mp_em_vb.h create mode 100644 packages/prism/src/c/mp/mp_flags.c create mode 100644 packages/prism/src/c/mp/mp_flags.h create mode 100644 packages/prism/src/c/mp/mp_preds.c create mode 100644 packages/prism/src/c/mp/mp_preds.h create mode 100644 packages/prism/src/c/mp/mp_sw.c create mode 100644 packages/prism/src/c/mp/mp_sw.h create mode 100644 packages/prism/src/c/up/em.h create mode 100644 packages/prism/src/c/up/em_aux.c create mode 100644 packages/prism/src/c/up/em_aux.h create mode 100644 packages/prism/src/c/up/em_aux_ml.c create mode 100644 packages/prism/src/c/up/em_aux_ml.h create mode 100644 packages/prism/src/c/up/em_aux_vb.c 
create mode 100644 packages/prism/src/c/up/em_aux_vb.h create mode 100644 packages/prism/src/c/up/em_ml.c create mode 100644 packages/prism/src/c/up/em_ml.h create mode 100644 packages/prism/src/c/up/em_preds.c create mode 100644 packages/prism/src/c/up/em_preds.h create mode 100644 packages/prism/src/c/up/em_vb.c create mode 100644 packages/prism/src/c/up/em_vb.h create mode 100644 packages/prism/src/c/up/flags.c create mode 100644 packages/prism/src/c/up/flags.h create mode 100644 packages/prism/src/c/up/graph.c create mode 100644 packages/prism/src/c/up/graph.h create mode 100644 packages/prism/src/c/up/graph_aux.c create mode 100644 packages/prism/src/c/up/graph_aux.h create mode 100644 packages/prism/src/c/up/hindsight.c create mode 100644 packages/prism/src/c/up/hindsight.h create mode 100644 packages/prism/src/c/up/up.h create mode 100644 packages/prism/src/c/up/util.c create mode 100644 packages/prism/src/c/up/util.h create mode 100644 packages/prism/src/c/up/viterbi.c create mode 100644 packages/prism/src/c/up/viterbi.h create mode 100644 packages/prism/src/prolog/Makefile.in create mode 100644 packages/prism/src/prolog/README create mode 100644 packages/prism/src/prolog/bp/eval.pl create mode 100644 packages/prism/src/prolog/core/error.pl create mode 100644 packages/prism/src/prolog/core/format.pl create mode 100644 packages/prism/src/prolog/core/message.pl create mode 100644 packages/prism/src/prolog/core/random.pl create mode 100644 packages/prism/src/prolog/mp/mp_learn.pl create mode 100644 packages/prism/src/prolog/mp/mp_main.pl create mode 100644 packages/prism/src/prolog/prism.yap create mode 100644 packages/prism/src/prolog/trans/bpif.pl create mode 100644 packages/prism/src/prolog/trans/dump.pl create mode 100644 packages/prism/src/prolog/trans/trans.pl create mode 100644 packages/prism/src/prolog/trans/verify.pl create mode 100644 packages/prism/src/prolog/up/batch.pl create mode 100644 packages/prism/src/prolog/up/bigarray.pl create mode 100644 
packages/prism/src/prolog/up/dist.pl create mode 100644 packages/prism/src/prolog/up/dynamic.pl create mode 100644 packages/prism/src/prolog/up/expl.pl create mode 100644 packages/prism/src/prolog/up/flags.pl create mode 100644 packages/prism/src/prolog/up/hash.pl create mode 100644 packages/prism/src/prolog/up/hindsight.pl create mode 100644 packages/prism/src/prolog/up/learn.pl create mode 100644 packages/prism/src/prolog/up/list.pl create mode 100644 packages/prism/src/prolog/up/main.pl create mode 100644 packages/prism/src/prolog/up/prob.pl create mode 100644 packages/prism/src/prolog/up/sample.pl create mode 100644 packages/prism/src/prolog/up/switch.pl create mode 100644 packages/prism/src/prolog/up/util.pl create mode 100644 packages/prism/src/prolog/up/viterbi.pl diff --git a/packages/prism/LICENSE b/packages/prism/LICENSE new file mode 100644 index 000000000..a862be461 --- /dev/null +++ b/packages/prism/LICENSE @@ -0,0 +1,93 @@ +LICENSE AGREEMENT OF THE PRISM SYSTEM + +Copyright (c) 2009, +Taisuke Sato, Neng-Fa Zhou, Yoshitaka Kameya, Yusuke Izumi +All rights reserved. + +The PRISM system ("the Software") is built on top of B-Prolog +(http://www.probp.com/), which is provided by Afany Software. +The Software is developed subject to the C source code license +of B-Prolog (http://www.probp.com/license.htm) and distributed +with the permission from Afany Software. + +The PRISM development team, which consists of the members from +Tokyo Institute of Technology and from Afany Software, hereby +grants a non-exclusive and non-transferable license to the +person who uses the Software ("the User"), subject to this +agreement. + +1. RELATION WITH B-PROLOG. The Software consists of the +standard routines of B-Prolog ("the B-Prolog part") and the +extensional routines by the PRISM development team ("the PRISM +part"). 
The User must agree that the use of the B-Prolog part +is also restricted by the license agreement of B-Prolog with +the exception stated in Paragraphs 3 and 4. + +2. RIGHT TO USE. The User may use the Software provided +that the User has right to use B-Prolog according to the User's +license agreement of B-Prolog. Given the license agreement of +B-Prolog as of the release date of the Software, the User may +use the Software free of charge for academic and non-commercial +purposes, and must purchase a license for other use. + +3. DISTRIBUTION. The User may distribute the Software, only +for non-commercial purposes, provided that the Software is +distributed along with this agreement. + +4. SOURCE CODE AND DERIVED SOFTWARE. The PRISM development +team may make the source code of the PRISM part ("the Public +Source Code") publicly available under a separate license ("the +Additional License"), along with a minimal set of source and +binary files coming from the B-Prolog part and required to build +the Software ("the Build Kit"). The User may use and distribute +the Public Source Code and the Build Kit subject to the +following subparagraphs. + + 4.1. SOURCE CODE. The User may use and distribute the +Public Source Code, entirely or in part, subject to the +Additional License. + + 4.2. BUILD KIT. The User may use and distribute the Build +Kit according to the remaining subparagraphs, provided that +the User has right to use B-Prolog the User's license agreement +of B-Prolog. The Additional License shall not apply to the +Build Kit. + + 4.3. DERIVED SOFTWARE. The User may build software ("the +Derived Software") from the Public Source Code, modified or +unmodified, along with the Build Kit provided that (a) the User +has right to use the Build Kit as stated in Subparagraph 4.2, +and that (b) the Derived Software presents the following +message in the same way as the Software. 
+ + This edition of B-Prolog is for evaluation, learning, and + non-profit research purposes only, and a license is needed for + any other uses. Please visit http://www.probp.com/license.htm + for the detail. + + 4.4. DISTRIBUTION OF DERIVED SOFTWARE. The User may distribute +the Derived Software built according to Subparagraph 4.3, only +for non-commercial purposes, provided that the Derived Software +is distributed (a) along with this agreement and (b) under the +license consistent with this agreement. + +5. COPYRIGHT. The B-Prolog part is copyrighted by Afany +Software and the PRISM part is copyrighted by the PRISM +development team. The Software contains several public domain +modules as listed in the B-Prolog's manual and the implementation +of Mersenne Twister copyrighted by its authors +(http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html), and +some portion of code in the PRISM part is based on the SPECFUN +library available in the NETLIB repository (http://www.netlib.org/). +The User shall own the copyright for the modified part of the +Software according to Subparagraph 3.3. + +6. NO WARRANTY. The Software is provided "as-is", without +any warranties express or implied. The User may report any +defects of the Software to the PRISM development team, but +there is no guarantee for those defects to be fixed. The User +who purchased a license from Afany Software might receive a +warranty according to the license agreement of B-Prolog, only +when the defects obviously originate from the B-Prolog part. +Neither Afany Software nor the PRISM development team is +responsible for any damages caused by the use of the Software. 
diff --git a/packages/prism/LICENSE.src b/packages/prism/LICENSE.src new file mode 100644 index 000000000..1be8900d5 --- /dev/null +++ b/packages/prism/LICENSE.src @@ -0,0 +1,39 @@ +The following license agreement is referred to as the "Additional +License" in Paragraph 4 of a license agreement on the use of the +software, which is titled "LICENSE AGREEMENT OF THE PRISM SYSTEM." + +-------------------------------------------------------------------- + +SOURCE CODE LICENSE AGREEMENT OF THE PRISM SYSTEM + +Copyright (c) 2009, +Taisuke Sato, Neng-Fa Zhou, Yoshitaka Kameya, Yusuke Izumi +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * None of the name of Tokyo Institute of Technology, the name of + City University of New York, nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE +COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/packages/prism/README b/packages/prism/README new file mode 100644 index 000000000..22c99b919 --- /dev/null +++ b/packages/prism/README @@ -0,0 +1,24 @@ +========================== README (top) ========================== + +This is a software package of PRISM version 2.0, a logic-based +programming system for statistical modeling, which is built +on top of B-Prolog (http://www.probp.com/). Since version 2.0, +the source code of the PRISM part is included in the released +package. Please use PRISM based on the agreement described in +LICENSE and LICENSE.src. + + LICENSE ... license agreement of PRISM + LICENSE.src ... additional license agreement on the source + code of PRISM + bin/ ... executables + doc/ ... documents + src/ ... source code + exs/ ... example programs + exs_fail/ ... example programs for generative modeling + with failure + exs_foc/ ... additional examples that demonstrate the + First Order Compiler + +For the files under each directory, please read the README file +in the directory. For the papers or additional information +on PRISM, please visit http://sato-www.cs.titech.ac.jp/prism/ . diff --git a/packages/prism/exs/README b/packages/prism/exs/README new file mode 100644 index 000000000..1dc2490a5 --- /dev/null +++ b/packages/prism/exs/README @@ -0,0 +1,65 @@ +========================== README (exs) ========================== + +Files/Directories: + README ... this file + direction.psm ... 
the first example in the user's manual + dcoin.psm ... simple program modeling two Bernoulli trial processes + bloodABO.psm ... ABO blood type program (ABO gene model) + bloodAaBb.psm ... ABO blood type program (AaBb gene model) + bloodtype.dat ... data file for bloodABO.psm and bloodAaBb.psm + alarm.psm ... Bayesian network program + sbn.psm ... Singly connected Bayesian network program + hmm.psm ... discrete hidden Markov model + phmm.psm ... profile HMM for the alignment of amino-acid sequences + phmm.dat ... data file for phmm.psm + pdcg.psm ... PCFG program for top-down parsing + pdcg_c.psm ... PCFG program for Charniak's example + plc.psm ... probabilistic left-corner parsing + votes.psm ... cross-validation of naive Bayes with the `votes' data + jtree/ ... Bayesian network program in a junction-tree form + noisy_or/ ... Bayesian network program using noisy OR + +How to use: + All programs are self-contained, hopefully. Try first a sample + session in each program to get familiar with a model. + +Comment: + The above programs contain no negation. When a program contains + negation, you have to compile away negation by FOC (first order + compiler). For PRISM programs with negation, see ../exs_fail. + +References: + +(PRISM) + Parameter Learning of Logic Programs for Symbolic-statistical Modeling, + Sato,T. and Kameya,Y., + Journal of Artificial Intelligence Research 15, pp.391-454, 2001. + + New advances in logic-based probabilistic modeling by PRISM, + Sato,T. and Kameya,Y., + Probabilistic Inductive Logic Programming, LNCS 4911, Springer, + pp.118-155, 2008. + +(PCFGs) + Foundations of Statistical Natural Language Processing, + Manning,C.D. and Schutze,H., + The MIT Press, 1999. + + A Separate-and-Learn Approach to EM Learning of PCFGs, + Sato,T., Abe,S., Kameya,Y. and Shirai,K., + Proc. of the 6th Natural Language Processing Pacific Rim Symposium + (NLPRS-2001), pp.255-262, 2001.
+ +(BNs) + Probabilistic Reasoning in Intelligent Systems, + Pearl,J., + Morgan Kaufmann, 1988. + + Expert Systems and Probabilistic Network Models, + Castillo,E., Gutierrez,J.M. and Hadi,A.S., + Springer-Verlag, 1997. + +(HMMs) + Fundamentals of Speech Recognition, + Rabiner,L.R. and Juang,B., + Prentice-Hall, 1993. diff --git a/packages/prism/exs/alarm.psm b/packages/prism/exs/alarm.psm new file mode 100644 index 000000000..4ab044d34 --- /dev/null +++ b/packages/prism/exs/alarm.psm @@ -0,0 +1,122 @@ +%%%% +%%%% Bayesian networks (1) -- alarm.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is borrowed from: +%% Poole, D., Probabilistic Horn abduction and Bayesian networks, +%% Artificial Intelligence 64, pp.81-129, 1993. +%% +%% (Fire) (Tampering) +%% / \ / +%% ((Smoke)) (Alarm) +%% | +%% (Leaving) (( )) -- observable node +%% | ( ) -- hidden node +%% ((Report)) +%% +%% In this network, we assume that all rvs (random variables) +%% take on {yes,no} and also assume that only two nodes, `Smoke' +%% and `Report', are observable. + +%%------------------------------------- +%% Quick start : sample session +%% +%% ?- prism(alarm),go. % Learn parameters from randomly generated +%% % 100 samples +%% +%% Get the probability and the explanation graph: +%% ?- prob(world(yes,no)). +%% ?- probf(world(yes,no)). +%% +%% Get the most likely explanation and its probability: +%% ?- viterbif(world(yes,no)). +%% ?- viterbi(world(yes,no)). +%% +%% Compute conditional hindsight probabilities: +%% ?- chindsight(world(yes,no)). +%% ?- chindsight_agg(world(yes,no),world(_,_,query,yes,_,no)). + +go:- alarm_learn(100). + +%%------------------------------------- +%% Declarations: + +:- set_prism_flag(data_source,file('world.dat')). + % When we run learn/0, the data are supplied + % from `world.dat'. + +values(_,[yes,no]).
% We declare multiary random switch msw(.,V) + % used in this program such that V (outcome) + % is one of {yes,no}. Note that '_' is + % an anonymous logical variable in Prolog. + + % The distribution of V is specified by + % set_params below. + +%%------------------------------------ +%% Modeling part: +%% +%% The above BN defines a joint distribution +%% P(Fire,Tampering,Alarm,Smoke,Leaving,Report). +%% We assume `Smoke' and `Report' are observable while others are not. +%% Our modeling simulates random sampling of the BN from top nodes +%% using msws. For each rv, say `Fire', we introduce a corresponding +%% msw, say msw(fi,Fi) such that +%% msw(fi,Fi) <=> sampling msw named fi yields the outcome Fi. +%% Here fi is a constant intended for the name of rv `Fire'. +%% + +world(Fi,Ta,Al,Sm,Le,Re) :- + %% Define a distribution for world/6 such that e.g. + %% P(Fire=yes,Tampering=yes,Alarm=no,Smoke=no,Leaving=no,Report=no) + %% = P(world(yes,yes,no,no,no,no)) + msw(fi,Fi), % P(Fire) + msw(ta,Ta), % P(Tampering) + msw(sm(Fi),Sm), % CPT P(Smoke | Fire) + msw(al(Fi,Ta),Al), % CPT P(Alarm | Fire,Tampering) + msw(le(Al),Le), % CPT P(Leaving | Alarm) + msw(re(Le),Re). % CPT P(Report | Leaving) + +world(Sm,Re):- + %% Define marginal distribution for `Smoke' and `Report' + world(_,_,_,Sm,_,Re). + +%%------------------------------------ +%% Utility part: + +alarm_learn(N) :- + unfix_sw(_), % Make all parameters changeable + set_params, % Set parameters as you specified + get_samples(N,world(_,_),Gs), % Get N samples + fix_sw(fi), % Preserve the parameter values + learn(Gs). % for {msw(fi,yes), msw(fi,no)} + +% alarm_learn(N) :- +% %% generate teacher data and write them to `world.dat' +% %% before learn/0 is called. +% write_world(N,'world.dat'), +% learn.
+ +set_params :- + set_sw(fi,[0.1,0.9]), + set_sw(ta,[0.15,0.85]), + set_sw(sm(yes),[0.95,0.05]), + set_sw(sm(no),[0.05,0.95]), + set_sw(al(yes,yes),[0.50,0.50]), + set_sw(al(yes,no),[0.90,0.10]), + set_sw(al(no,yes),[0.85,0.15]), + set_sw(al(no,no),[0.05,0.95]), + set_sw(le(yes),[0.88,0.12]), + set_sw(le(no),[0.01,0.99]), + set_sw(re(yes),[0.75,0.25]), + set_sw(re(no),[0.10,0.90]). + +write_world(N,File) :- + get_samples(N,world(_,_),Gs),tell(File),write_world(Gs),told. + +write_world([world(Sm,Re)|Gs]) :- + write(world(Sm,Re)),write('.'),nl,write_world(Gs). +write_world([]). diff --git a/packages/prism/exs/bloodABO.psm b/packages/prism/exs/bloodABO.psm new file mode 100644 index 000000000..03a2e1bbc --- /dev/null +++ b/packages/prism/exs/bloodABO.psm @@ -0,0 +1,111 @@ +%%%% +%%%% ABO blood type --- bloodABO.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% ABO blood type consists of A, B, O and AB. They are observable +%% (phenotypes) and determined by a pair of blood type genes (genotypes). +%% There are three ABO genes, namely a, b and o located on the 9th +%% chromosome of a human being. There are 6 genotypes ({a,a},{a,b},{a,o}, +%% {b,b},{b,o},{o,o}) and each determines a blood type. For example {a,b} +%% gives blood type AB etc. Our task is to estimate frequencies of ABO +%% genes from a random sample of ABO blood type, assuming random mating. + +%%------------------------------------- +%% Quick start : sample session +%% +%% ?- prism(bloodABO),go,print_blood. +%% % Learn parameters from randomly generated +%% % 100 samples with A:B:O:AB = 38:22:31:9 +%% +%% ?- sample(bloodtype(X)). +%% % Pick up a person with blood type X randomly +%% % according to the current parameter settings +%% +%% ?- get_samples(100,bloodtype(X),_Gs),countlist(_Gs,Cs). +%% % Pick up 100 persons and get the frequencies +%% % of their blood types +%% +%% ?- probf(bloodtype(ab),E),print_graph(E).
+%% % Print all explanations for bloodtype(ab) in +%% % a compressed form +%% +%% ?- prob(bloodtype(ab),P). +%% % P is the probability of bloodtype(ab) being true +%% +%% ?- viterbif(bloodtype(ab)). +%% ?- viterbif(bloodtype(ab),P,E),print_graph(E). +%% ?- viterbi(bloodtype(ab),P). +%% % P is the probability of a most likely +%% % explanation E for bloodtype(ab). +%% +%% ?- viterbit(bloodtype(ab)). +%% % Print the most likely explanation for +%% % bloodtype(ab) in a tree form. + +go:- learn_bloodtype(100). + +%%------------------------------------- +%% Declarations: + +:- set_prism_flag(data_source,file('bloodtype.dat')). + % When we run learn/0, the data are supplied + % by `bloodtype.dat'. + +values(gene,[a,b,o],[0.5,0.2,0.3]). + % We declare msw(gene,V) s.t. V takes on + % one of the genes {a,b,o} when executed, + % with the freq.: a 50%, b 20%, o 30%. + +%%------------------------------------ +%% Modeling part: + +bloodtype(P) :- + genotype(X,Y), + ( X=Y -> P=X % {a,a},{b,b},{o,o}: both genes equal + ; X=o -> P=Y % o with a or b: the non-o gene decides + ; Y=o -> P=X % a or b with o: the non-o gene decides + ; P=ab % remaining case: one a and one b + ). + +genotype(X,Y) :- msw(gene,X),msw(gene,Y). + % We assume random mating. Note that msw(gene,X) + % and msw(gene,Y) are i.i.d. (independent and + % identically distributed) random variables + % in PRISM because they have the same id but + % different subgoals. + +%%------------------------------------ +%% Utility part: + +learn_bloodtype(N) :- % Learn parameters from N observations + random_set_seed(214857), % Set seed of the random number generator + gen_bloodtype(N,Gs),!, % Sample bloodtype/1 of size N + learn(Gs). % Perform search and graphical EM learning +% learn. % <= when using the file `bloodtype.dat' + +gen_bloodtype(N,Gs) :- + N > 0, + random_select([a,b,o,ab],[0.38,0.22,0.31,0.09],X), + Gs = [bloodtype(X)|Gs1], % Sample a blood type with an empirical + N1 is N-1,!, % ratio for Japanese people. + gen_bloodtype(N1,Gs1). +gen_bloodtype(0,[]).
+ +print_blood :- + prob(bloodtype(a),PA),prob(bloodtype(b),PB), + prob(bloodtype(o),PO),prob(bloodtype(ab),PAB), + nl, + format("P(A) = ~6f~n",[PA]), + format("P(B) = ~6f~n",[PB]), + format("P(O) = ~6f~n",[PO]), + format("P(AB) = ~6f~n",[PAB]). + +print_gene :- + get_sw(gene,[_,[a,b,o],[GA,GB,GO]]), + nl, + format("P(a) = ~6f~n",[GA]), + format("P(b) = ~6f~n",[GB]), + format("P(o) = ~6f~n",[GO]). diff --git a/packages/prism/exs/bloodAaBb.psm b/packages/prism/exs/bloodAaBb.psm new file mode 100644 index 000000000..1048ebff5 --- /dev/null +++ b/packages/prism/exs/bloodAaBb.psm @@ -0,0 +1,114 @@ +%%%% +%%%% Another hypothesis on ABO blood type inheritance --- bloodAaBb.psm +%%%% +%%%% Copyright (C) 2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% ABO blood type consists of A, B, O and AB. They are observable +%% (phenotypes) and determined by a pair of blood type genes (genotypes). +%% At present, it is known that there are three ABO genes, namely a, b and +%% o located on the 9th chromosome of a human being, but in the early 20th +%% century, there was another hypothesis that we have two loci for ABO +%% blood type with dominant alleles A/a and B/b. That is, genotypes aabb, +%% A*bb, aaB* and A*B* correspond to the blood types (phenotypes) O, A, B +%% and AB, respectively, where * stands for a don't care symbol. We call +%% this hypothesis the AaBb gene model, and assume random mating. + +%%------------------------------------- +%% Quick start : sample session -- the same as that of bloodABO.psm +%% +%% ?- prism(bloodAaBb),go,print_blood. +%% % Learn parameters from randomly generated +%% % 100 samples with A:B:O:AB = 38:22:31:9 +%% +%% ?- probf(bloodtype(ab),E),print_graph(E). +%% ?- prob(bloodtype(ab),P). +%% +%% ?- viterbif(bloodtype(ab),P,E),print_graph(E). +%% ?- viterbi(bloodtype(ab),P). +%% % P is the probability of a most likely +%% % explanation E for bloodtype(ab). + +go:- learn_bloodtype(100).
+ +%%------------------------------------- +%% Session for model selection: +%% +%% -- we try to evaluate the plausibilities of the correct model (ABO +%% gene model) and this AaBb gene model according to the data in +%% `bloodtype.dat'. The data file `bloodtype.dat' contains 38 +%% persons of blood type A, 22 persons of blood type B, 31 persons +%% of blood type O, and 9 persons of blood type AB (the ratio is +%% almost the same as that in Japanese people). +%% +%% 1. Modify bloodABO.psm and bloodAaBb.psm: +%% - Use learn/0 instead of learn/1. +%% +%% 2. Get the BIC value for the ABO gene model (bloodABO.psm) +%% ?- prism(bloodABO). +%% ?- learn. +%% ?- learn_statistics(bic,BIC). +%% +%% 3. Get the BIC value for the AaBb gene model (this file) +%% ?- prism(bloodAaBb). +%% ?- learn. +%% ?- learn_statistics(bic,BIC). +%% + +:- set_prism_flag(data_source,file('bloodtype.dat')). + % When we run learn/0, the data are supplied + % by `bloodtype.dat'. + +values(locus1,['A',a]). +values(locus2,['B',b]). + +%%------------------------------------ +%% Modeling part: + +bloodtype(P) :- + genotype(locus1,X1,Y1), + genotype(locus2,X2,Y2), + ( X1=a, Y1=a, X2=b, Y2=b -> P=o + ; ( X1='A' ; Y1='A' ), X2=b, Y2=b -> P=a + ; X1=a, Y1=a, ( X2='B' ; Y2='B') -> P=b + ; P=ab + ). + +genotype(L,X,Y) :- msw(L,X),msw(L,Y). + +%%------------------------------------ +%% Utility part: +%% (the same as that in bloodABO.psm) + +learn_bloodtype(N) :- % Learn parameters from N observations + random_set_seed(214857), % Set seed of the random number generator + gen_bloodtype(N,Gs),!, % Sample bloodtype/1 of size N + learn(Gs). % Perform search and graphical EM learning +% learn. % <= when using the file `bloodtype.dat' + +gen_bloodtype(N,Gs) :- + N > 0, + random_select([a,b,o,ab],[0.38,0.22,0.31,0.09],X), + Gs = [bloodtype(X)|Gs1], % Sample a blood type with an empirical + N1 is N-1,!, % ratio for Japanese people. + gen_bloodtype(N1,Gs1). +gen_bloodtype(0,[]). 
+ +print_blood :- + prob(bloodtype(a),PA),prob(bloodtype(b),PB), + prob(bloodtype(o),PO),prob(bloodtype(ab),PAB), + nl, + format("P(A) = ~6f~n",[PA]), + format("P(B) = ~6f~n",[PB]), + format("P(O) = ~6f~n",[PO]), + format("P(AB) = ~6f~n",[PAB]). + +print_gene :- + get_sw(locus1,[_,['A',a],[GA,Ga]]), + get_sw(locus2,[_,['B',b],[GB,Gb]]), + nl, + format("P(A) = ~6f~n",[GA]), + format("P(a) = ~6f~n",[Ga]), + format("P(B) = ~6f~n",[GB]), + format("P(b) = ~6f~n",[Gb]). diff --git a/packages/prism/exs/bloodtype.dat b/packages/prism/exs/bloodtype.dat new file mode 100644 index 000000000..01512e0a6 --- /dev/null +++ b/packages/prism/exs/bloodtype.dat @@ -0,0 +1,100 @@ +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(a). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(b). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(o). 
+bloodtype(o). +bloodtype(o). +bloodtype(o). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). +bloodtype(ab). diff --git a/packages/prism/exs/dcoin.psm b/packages/prism/exs/dcoin.psm new file mode 100644 index 000000000..a8caa9d47 --- /dev/null +++ b/packages/prism/exs/dcoin.psm @@ -0,0 +1,72 @@ +%%%% +%%%% Double coin tossing --- dcoin.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% A sequential mixture of two Bernoulli trials processes. +%% We have two coins, coin(1) and coin(2). +%% Start with coin(1), we keep flipping a coin and observe the outcome. +%% We change coins according to the rule in the process. +%% If the outcome is "head", the next coin to flip is coin(2). +%% If the outcome is "tail", the next coin to flip is coin(1). +%% The learning task is to estimate parameters for coin(1) and coin(2), +%% observing a sequence of outcomes. +%% As there is no hidden variable in this model, EM learning is just +%% ML estimation from complete data. + +%%------------------------------------- +%% Quick start : sample session +%% +%% (1) load this program +%% ?- prism(dcoin). +%% +%% (2) sampling and probability computations +%% ?- sample(dcoin(10,X)),prob(dcoin(10,X)). +%% ?- sample(dcoin(10,X)),probf(dcoin(10,X)). +%% +%% (3) EM learning +%% ?- go. + +go:- dcoin_learn(500). + +%%------------------------------------ +%% Declarations: + +values(coin(1),[head,tail],[0.5,0.5]). + % Declare msw(coin(1),V) s.t. V = head or + % V = tail, where P(msw(coin(1),head)) = 0.5 + % and P(msw(coin(1),tail)) = 0.5. +values(coin(2),[head,tail],[0.7,0.3]). + % Declare msw(coin(2),V) s.t. V = head or + % V = tail, where P(msw(coin(2),head)) = 0.7 + % and P(msw(coin(2),tail)) = 0.3. 
+ +%%------------------------------------ +%% Modeling part: + +dcoin(N,Rs) :- % Rs is a list with length N of outcomes + dcoin(N,coin(1),Rs). % from two Bernoulli trials processes. + +dcoin(N,Coin,[R|Rs]) :- + N > 0, + msw(Coin,R), + ( R == head, NextCoin = coin(2) + ; R == tail, NextCoin = coin(1) ), + N1 is N-1, + dcoin(N1,NextCoin,Rs). +dcoin(0,_,[]). + +%%------------------------------------ +%% Utility part: + +dcoin_learn(N) :- + set_params, % Set parameters. + sample(dcoin(N,Rs)), % Get a sample Rs of size N. + Goals = [dcoin(N,Rs)], % Estimate the parameters from Rs. + learn(Goals). + +set_params :- + set_sw(coin(1),[0.5,0.5]), + set_sw(coin(2),[0.7,0.3]). diff --git a/packages/prism/exs/direction.psm b/packages/prism/exs/direction.psm new file mode 100644 index 000000000..3ca8b476d --- /dev/null +++ b/packages/prism/exs/direction.psm @@ -0,0 +1,46 @@ +%%%% +%%%% Decision of the direction by a coin tossing -- direction.psm +%%%% +%%%% This program has just one random switch named `coin'. +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%%------------------------------------- +%% Sample session +%% +%% (1) Load this program: +%% ?- prism(direction). +%% +%% (2) Get a sample: +%% ?- sample(direction(D)). +%% +%% (3) Display the information about the switch `coin': +%% ?- show_sw. +%% +%% (4) Set the probability distribution to the switch `coin': +%% ?- set_sw(coin,[0.7,0.3]). +%% +%% (5) Display the switch information again with the distribution set +%% at step 4: +%% ?- show_sw. +%% +%% (6) Get a sample again with the distribution set at step 4: +%% ?- sample(direction(D)). +%% +%% [Note1] +%% Since 1.9, without any extra settings, the probability distribution +%% of every switch is set to a uniform distribution. +%% +%% [Note2] +%% If you go (3) with skipping (2), nothing should be displayed. 
This +%% is because any random switch will not be registered by the system until +%% it is explicitly used or referred to. + +values(coin,[head,tail]). % The switch `coin' takes `head' or `tail' as its value + +direction(D):- + msw(coin,Face), % Make a coin tossing + ( Face==head -> D=left ; D=right). % Decide the direction according to + % the result of coin tossing diff --git a/packages/prism/exs/hmm.psm b/packages/prism/exs/hmm.psm new file mode 100644 index 000000000..c1cecf6d5 --- /dev/null +++ b/packages/prism/exs/hmm.psm @@ -0,0 +1,99 @@ +%%%% +%%%% Hidden Markov model --- hmm.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% [state diagram:] (2 states and 2 output symbols) +%% +%% +--------+ +--------+ +%% | | | | +%% | +------+ +------+ | +%% | | |------->| | | +%% +---->| s0 | | s1 |<----+ +%% | |<-------| | +%% +------+ +------+ +%% +%% - In each state, possible output symbols are `a' and `b'. + +%%------------------------------------- +%% Quick start : sample session +%% +%% ?- prism(hmm),hmm_learn(100). % Learn parameters from 100 randomly +%% % generated samples +%% +%% ?- show_sw. % Confirm the learned parameter +%% +%% ?- prob(hmm([a,a,a,a,a,b,b,b,b,b])). % Calculate the probability +%% ?- probf(hmm([a,a,a,a,a,b,b,b,b,b])). % Get the explanation graph +%% +%% ?- viterbi(hmm([a,a,a,a,a,b,b,b,b,b])). % Run the Viterbi computation +%% ?- viterbif(hmm([a,a,a,a,a,b,b,b,b,b])). % Get the Viterbi explanation +%% +%% ?- hindsight(hmm([a,a,a,a,a,b,b,b,b,b])). % Get hindsight probabilities + +%%------------------------------------ +%% Declarations: + +values(init,[s0,s1]). % state initialization +values(out(_),[a,b]). % symbol emission +values(tr(_),[s0,s1]). % state transition + +% :- set_prism_flag(default_sw_d,1.0). +% :- set_prism_flag(epsilon,1.0e-2). +% :- set_prism_flag(restart,10). +% :- set_prism_flag(log_scale,on). 
+ +%%------------------------------------ +%% Modeling part: + +hmm(L):- % To observe a string L: + str_length(N), % Get the string length as N + msw(init,S), % Choose an initial state randomly + hmm(1,N,S,L). % Start stochastic transition (loop) + +hmm(T,N,_,[]):- T>N,!. % Stop the loop +hmm(T,N,S,[Ob|Y]) :- % Loop: current state is S, current time is T + msw(out(S),Ob), % Output Ob at the state S + msw(tr(S),Next), % Transit from S to Next. + T1 is T+1, % Count up time + hmm(T1,N,Next,Y). % Go next (recursion) + +str_length(10). % String length is 10 + +%%------------------------------------ +%% Utility part: + +hmm_learn(N):- + set_params,!, % Set parameters manually + get_samples(N,hmm(_),Gs),!, % Get N samples + learn(Gs). % Learn with the samples + +set_params:- + set_sw(init, [0.9,0.1]), + set_sw(tr(s0), [0.2,0.8]), + set_sw(tr(s1), [0.8,0.2]), + set_sw(out(s0),[0.5,0.5]), + set_sw(out(s1),[0.6,0.4]). + +%% prism_main/1 is a special predicate for batch execution. +%% The following command conducts learning from 50 randomly +%% generated samples: +%% > upprism hmm 50 + +prism_main([Arg]):- + parse_atom(Arg,N), % Convert an atom ('50') to a number (50) + hmm_learn(N). % Learn with N samples + +%% viterbi_states(Os,Ss) returns the most probable sequence Ss +%% of state transitions for an output sequence Os. +%% +%% | ?- viterbi_states([a,a,a,a,a,b,b,b,b,b],States). +%% +%% States = [s0,s1,s0,s1,s0,s1,s0,s1,s0,s1,s0] ? + +viterbi_states(Outputs,States):- + viterbif(hmm(Outputs),_,E), + viterbi_subgoals(E,E1), + maplist(hmm(_,_,S,_),S,true,E1,States). diff --git a/packages/prism/exs/jtree/README b/packages/prism/exs/jtree/README new file mode 100644 index 000000000..66eb57bc2 --- /dev/null +++ b/packages/prism/exs/jtree/README @@ -0,0 +1,8 @@ +================== README (exs/jtree) ========================== + +Files: + README ... This file + asia.psm ... BN for Asia network (naive) + jasia.psm ... 
BN for Asia network (junction-tree; evidences kept in D-list) + jasia_a.psm ... BN for Asia network (junction-tree; evidences asserted first) + bn2prism/ ... Java translator from BNs to join-tree PRISM programs diff --git a/packages/prism/exs/jtree/asia.psm b/packages/prism/exs/jtree/asia.psm new file mode 100644 index 000000000..d2db8fda5 --- /dev/null +++ b/packages/prism/exs/jtree/asia.psm @@ -0,0 +1,84 @@ +%%%% +%%%% Bayesian networks for Asia network -- asia.psm +%%%% +%%%% Copyright (C) 2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is known as the Asia network, and was borrowed from: +%% S. L. Lauritzen and D. J. Spiegelhalter (1988). +%% Local computations with probabilities on graphical structures +%% and their application to expert systems. +%% Journal of Royal Statistical Society, Vol.B50, No.2, pp.157-194. +%% +%% ((Smoking[S])) +%% ((Visit to Asia[A])) / \ +%% | / \ +%% v v \ +%% (Tuberculosis[T]) (Lung cancer[L]) \ +%% \ / \ +%% \ / v +%% v v (Bronchitis[B]) +%% (Tuberculosis or lung cancer[TL]) / +%% / \ / +%% / \ / +%% v \ / +%% ((X-ray[X])) v v +%% ((Dyspnea[D])) +%% +%% We assume that the nodes A, S, X and D are observable. This +%% program provides a naive representation of the Asia network, as +%% shown in ../alarm.psm. The junction-tree version of the Asia +%% network program is given in jasia.psm + +%%------------------------------------- +%% Quick start: +%% +%% ?- prism(asia),go. + +go:- chindsight_agg(world(f,_,_,t),world(f,query,_,_,_,_,_,t)). + % we compute a conditional distribution P(T | A=false, D=true) + +%%------------------------------------- +%% Declarations: + +values(bn(_,_),[t,f]). % each switch takes on true or false + +%%------------------------------------- +%% Modeling part: + +world(A,S,X,D):- world(A,_,S,_,_,X,_,D). 
+ +world(A,T,S,L,TL,X,B,D) :- + msw(bn(a,[]),A),msw(bn(t,[A]),T), + msw(bn(s,[]),S),msw(bn(l,[S]),L), + incl_or(T,L,TL), + msw(bn(x,[TL]),X),msw(bn(b,[S]),B), + msw(bn(d,[TL,B]),D). + +% inclusive OR +incl_or(t,t,t). +incl_or(t,f,t). +incl_or(f,t,t). +incl_or(f,f,f). + +%%------------------------------------- +%% Utility part: + +:- set_params. + +set_params:- + set_sw(bn(a,[]),[0.01,0.99]), + set_sw(bn(t,[t]),[0.05,0.95]), + set_sw(bn(t,[f]),[0.01,0.99]), + set_sw(bn(s,[]),[0.5,0.5]), + set_sw(bn(l,[t]),[0.1,0.9]), + set_sw(bn(l,[f]),[0.01,0.99]), + set_sw(bn(x,[t]),[0.98,0.02]), + set_sw(bn(x,[f]),[0.05,0.95]), + set_sw(bn(b,[t]),[0.60,0.40]), + set_sw(bn(b,[f]),[0.30,0.70]), + set_sw(bn(d,[t,t]),[0.90,0.10]), + set_sw(bn(d,[t,f]),[0.70,0.30]), + set_sw(bn(d,[f,t]),[0.80,0.20]), + set_sw(bn(d,[f,f]),[0.10,0.90]). diff --git a/packages/prism/exs/jtree/jasia.psm b/packages/prism/exs/jtree/jasia.psm new file mode 100644 index 000000000..6f0753466 --- /dev/null +++ b/packages/prism/exs/jtree/jasia.psm @@ -0,0 +1,153 @@ +%%%% +%%%% Join-tree PRISM program for Asia network -- jasia.psm +%%%% +%%%% Copyright (C) 2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is known as the Asia network, and was borrowed from: +%% S. L. Lauritzen and D. J. Spiegelhalter (1988). +%% Local computations with probabilities on graphical structures +%% and their application to expert systems. +%% Journal of Royal Statistical Society, Vol.B50, No.2, pp.157-194. +%% +%% ((Smoking[S])) +%% ((Visit to Asia[A])) / \ +%% | / \ +%% v v \ +%% (Tuberculosis[T]) (Lung cancer[L]) \ +%% \ / \ +%% \ / v +%% v v (Bronchitis[B]) +%% (Tuberculosis or lung cancer[TL]) / +%% / \ / +%% / \ / +%% v \ / +%% ((X-ray[X])) v v +%% ((Dyspnea[D])) +%% +%% We assume that the nodes A, S, X and D are observable. One may +%% notice that this network is multiply-connected (there is an undirected +%% loop: S-L-TL-D-B-S).
To perform efficient probabilistic inferences, +%% one popular method is the join-tree (JT) algorithm. In the JT +%% algorithm, we first convert the original network (DAG) into a tree- +%% structured undirected graph, called join tree (junction tree), in +%% which a node corresponds to a set of nodes in the original network. +%% Then we compute the conditional probabilities based on the join +%% tree. For example, the above network is converted into the +%% following join tree: +%% +%% node4(A,T) node2(S,L,B) +%% \ \ +%% [T] [L,B] +%% \ \ node1 +%% node3(T,L,TL)--[L,TL]--(L,TL,B) +%% / +%% [TL,B] +%% node6 / +%% (TL,X)--[TL]--(TL,B,D) +%% node5 +%% +%% where (...) corresponds to a node and [...] corresponds to a +%% separator. In this join tree, node2 corresponds to a set {S,L,B} of +%% the original nodes. We consider that node1 is the root of this join +%% tree. +%% +%% Here we write a PRISM program that represents the above join tree. +%% The predicate named msg_i_j corresponds to the edge from node i to +%% node j in the join tree. The predicate named node_i corresponds to +%% node i. +%% +%% The directory `bn2prism' in the same directory contains BN2Prism, a +%% Java translator from a Bayesian network to a PRISM program in join- +%% tree style, like the one shown here. + +%%------------------------------------- +%% Quick start: +%% +%% ?- prism(jasia),go. + +go:- chindsight_agg(world([(a,f),(d,t)]),node_4(_,query,_)). + % we compute a conditional distribution P(T | A=false, D=true) + +go2:- prob(world([(a,f),(d,t)])). + % we compute a marginal probability P(A=false, D=true) + +%%------------------------------------- +%% Declarations: + +values(bn(_,_),[t,f]). % each switch takes on true or false + +%%------------------------------------- +%% Modeling part: +%% +%% [Note] +%% Evidences are kept in a difference list in the last argument of +%% the msg_i_j and the node_i predicates. 
For simplicity, it is +%% assumed that the evidences are given in the same order as that +%% of appearances of msw/2 in the top-down execution of world/1. + +world(E):- msg_1_0(E-[]). + +msg_1_0(E0-E1) :- node_1(_L,_TL,_B,E0-E1). +msg_2_1(L,B,E0-E1 ):- node_2(_S,L,B,E0-E1). +msg_3_1(L,TL,E0-E1):- node_3(_T,L,TL,E0-E1). +msg_4_3(T,E0-E1) :- node_4(_A,T,E0-E1). +msg_5_1(TL,B,E0-E1):- node_5(TL,B,_D,E0-E1). +msg_6_5(TL,E0-E1) :- node_6(TL,_X,E0-E1). + +node_1(L,TL,B,E0-E1):- + msg_2_1(L,B,E0-E2), + msg_3_1(L,TL,E2-E3), + msg_5_1(TL,B,E3-E1). + +node_2(S,L,B,E0-E1):- + cpt(s,[],S,E0-E2), + cpt(l,[S],L,E2-E3), + cpt(b,[S],B,E3-E1). + +node_3(T,L,TL,E0-E1):- + incl_or(L,T,TL), + msg_4_3(T,E0-E1). + +node_4(A,T,E0-E1):- + cpt(a,[],A,E0-E2), + cpt(t,[A],T,E2-E1). + +node_5(TL,B,D,E0-E1):- + cpt(d,[TL,B],D,E0-E2), + msg_6_5(TL,E2-E1). + +node_6(TL,X,E0-E1):- + cpt(x,[TL],X,E0-E1). + +cpt(X,Par,V,E0-E1):- + ( E0=[(X,V)|E1] -> true ; E0=E1 ), + msw(bn(X,Par),V). + +% inclusive OR +incl_or(t,t,t). +incl_or(t,f,t). +incl_or(f,t,t). +incl_or(f,f,f). + +%%------------------------------------- +%% Utility part: + +:- set_params. + +set_params:- + set_sw(bn(a,[]),[0.01,0.99]), + set_sw(bn(t,[t]),[0.05,0.95]), + set_sw(bn(t,[f]),[0.01,0.99]), + set_sw(bn(s,[]),[0.5,0.5]), + set_sw(bn(l,[t]),[0.1,0.9]), + set_sw(bn(l,[f]),[0.01,0.99]), + set_sw(bn(x,[t]),[0.98,0.02]), + set_sw(bn(x,[f]),[0.05,0.95]), + set_sw(bn(b,[t]),[0.60,0.40]), + set_sw(bn(b,[f]),[0.30,0.70]), + set_sw(bn(d,[t,t]),[0.90,0.10]), + set_sw(bn(d,[t,f]),[0.70,0.30]), + set_sw(bn(d,[f,t]),[0.80,0.20]), + set_sw(bn(d,[f,f]),[0.10,0.90]). diff --git a/packages/prism/exs/jtree/jasia_a.psm b/packages/prism/exs/jtree/jasia_a.psm new file mode 100644 index 000000000..c8545e2b0 --- /dev/null +++ b/packages/prism/exs/jtree/jasia_a.psm @@ -0,0 +1,167 @@ +%%%% +%%%% Join-tree PRISM program for Asia network -- jasia_a.psm +%%%% +%%%% Copyright (C) 2009 +%%%% Sato Laboratory, Dept.
of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is known as the Asia network, and was borrowed from: +%% S. L. Lauritzen and D. J. Spiegelhalter (1988). +%% Local computations with probabilities on graphical structures +%% and their application to expert systems. +%% Journal of Royal Statistical Society, Vol.B50, No.2, pp.157-194. +%% +%% ((Smoking[S])) +%% ((Visit to Asia[A])) / \ +%% | / \ +%% v v \ +%% (Tuberculosis[T]) (Lung cancer[L]) \ +%% \ / \ +%% \ / v +%% v v (Bronchitis[B]) +%% (Tuberculosis or lung cancer[TL]) / +%% / \ / +%% / \ / +%% v \ / +%% ((X-ray[X])) v v +%% ((Dyspnea[D])) +%% +%% We assume that the nodes A, S, X and D are observable. One may +%% notice that this network is multiply-connected (there is an undirected +%% loop: S-L-TL-D-B-S). To perform efficient probabilistic inferences, +%% one popular method is the join-tree (JT) algorithm. In the JT +%% algorithm, we first convert the original network (DAG) into a tree- +%% structured undirected graph, called join tree (junction tree), in +%% which a node corresponds to a set of nodes in the original network. +%% Then we compute the conditional probabilities based on the join +%% tree. For example, the above network is converted into the +%% following join tree: +%% +%% node4(A,T) node2(S,L,B) +%% \ \ +%% [T] [L,B] +%% \ \ node1 +%% node3(T,L,TL)--[L,TL]--(L,TL,B) +%% / +%% [TL,B] +%% node6 / +%% (TL,X)--[TL]--(TL,B,D) +%% node5 +%% +%% where (...) corresponds to a node and [...] corresponds to a +%% separator. In this join tree, node2 corresponds to a set {S,L,B} of +%% the original nodes. We consider that node1 is the root of this join +%% tree. +%% +%% Here we write a PRISM program that represents the above join tree. +%% The predicate named msg_i_j corresponds to the edge from node i to +%% node j in the join tree. The predicate named node_i corresponds to +%% node i.
+%% +%% The directory `bn2prism' in the same directory contains BN2Prism, a +%% Java translator from a Bayesian network to a PRISM program in join- +%% tree style, like the one shown here. + +%%------------------------------------- +%% Quick start: +%% +%% ?- prism(jasia_a),go. + +go:- chindsight_agg(world([(a,f),(d,t)]),node_4(_,query)). + % we compute a conditional distribution P(T | A=false, D=true) + +go2:- prob(world([(a,f),(d,t)])). + % we compute a marginal probability P(A=false, D=true) + +%%------------------------------------- +%% Declarations: + +values(bn(_,_),[t,f]). % each switch takes on true or false + +%%------------------------------------- +%% Modeling part: +%% +%% [Note] +%% Evidences are added first into the Prolog database. This is a +%% simpler method than keeping the evidences in difference list +%% (as done in jasia.psm). However, in learning, the subgoals are +%% inappropriately shared among the observed goals, each of which +%% is associated with a different set of evidences (This optimization +%% is called inter-goal sharing, and unconditionally enabled in the +%% current PRISM system). An ad-hoc workaround is to introduce an +%% ID for each set of evidences and keep the ID through the arguments +%% (e.g. we define world(ID,E), msg_2_1(ID,L,B), and so on). + +world(E):- assert_evid(E),msg_1_0. + +msg_1_0 :- node_1(_L,_TL,_B). +msg_2_1(L,B) :- node_2(_S,L,B). +msg_3_1(L,TL):- node_3(_T,L,TL). +msg_4_3(T) :- node_4(_A,T). +msg_5_1(TL,B):- node_5(TL,B,_D). +msg_6_5(TL) :- node_6(TL,_X). + +node_1(L,TL,B):- + msg_2_1(L,B), + msg_3_1(L,TL), + msg_5_1(TL,B). + +node_2(S,L,B):- + cpt(s,[],S), + cpt(l,[S],L), + cpt(b,[S],B). + +node_3(T,L,TL):- + incl_or(L,T,TL), + msg_4_3(T). + +node_4(A,T):- + cpt(a,[],A), + cpt(t,[A],T). + +node_5(TL,B,D):- + cpt(d,[TL,B],D), + msg_6_5(TL). + +node_6(TL,X):- + cpt(x,[TL],X). + +cpt(X,Par,V):- + ( evid(X,V) -> true ; true ), + msw(bn(X,Par),V). + +% inclusive OR +incl_or(t,t,t). +incl_or(t,f,t). 
+incl_or(f,t,t). +incl_or(f,f,f). + +% adding evidences to Prolog database +assert_evid(Es):- + retractall(evid(_,_)), + assert_evid0(Es). +assert_evid0([]). +assert_evid0([(X,V)|Es]):- + assert(evid(X,V)),!, + assert_evid0(Es). + +%%------------------------------------- +%% Utility part: + +:- set_params. + +set_params:- + set_sw(bn(a,[]),[0.01,0.99]), + set_sw(bn(t,[t]),[0.05,0.95]), + set_sw(bn(t,[f]),[0.01,0.99]), + set_sw(bn(s,[]),[0.5,0.5]), + set_sw(bn(l,[t]),[0.1,0.9]), + set_sw(bn(l,[f]),[0.01,0.99]), + set_sw(bn(x,[t]),[0.98,0.02]), + set_sw(bn(x,[f]),[0.05,0.95]), + set_sw(bn(b,[t]),[0.60,0.40]), + set_sw(bn(b,[f]),[0.30,0.70]), + set_sw(bn(d,[t,t]),[0.90,0.10]), + set_sw(bn(d,[t,f]),[0.70,0.30]), + set_sw(bn(d,[f,t]),[0.80,0.20]), + set_sw(bn(d,[f,f]),[0.10,0.90]). diff --git a/packages/prism/exs/noisy_or/README b/packages/prism/exs/noisy_or/README new file mode 100644 index 000000000..c7640e4d4 --- /dev/null +++ b/packages/prism/exs/noisy_or/README @@ -0,0 +1,7 @@ +================== README (exs/noisy_or) ========================== + +Files: + README ... this file + alarm_nor_basic.psm ... BN program using noisy OR (network-specific) + alarm_nor_generic.psm ... BN program using noisy OR (network-independent) + noisy_or.psm ... library for noisy OR diff --git a/packages/prism/exs/noisy_or/alarm_nor_basic.psm b/packages/prism/exs/noisy_or/alarm_nor_basic.psm new file mode 100644 index 000000000..1e801f31b --- /dev/null +++ b/packages/prism/exs/noisy_or/alarm_nor_basic.psm @@ -0,0 +1,160 @@ +%%%% +%%%% Bayesian networks using noisy OR (1) -- alarm_nor_basic.psm +%%%% +%%%% Copyright (C) 2004,2006,2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is borrowed from: +%% Poole, D., Probabilistic Horn abduction and Bayesian networks, +%% In Proc. of Artificial Intelligence 64, pp.81-129, 1993. 
+%% +%% (Fire) (Tampering) +%% / \ / +%% ((Smoke)) (Alarm) +%% | +%% (Leaving) (( )) -- observable node +%% | ( ) -- hidden node +%% ((Report)) +%% +%% In this network, we assume that all rvs (random variables) take on +%% {yes,no} and also assume that only two nodes, `Smoke' and `Report', are +%% observable. +%% +%% Furthermore, in this program, we consider that the Alarm variable's CPT +%% (conditional probability table) given through the noisy-OR rule. That is, +%% let us assume that we have the following inhibition probabilities: +%% +%% P(Alarm=no | Fire=yes, Tampering=no) = 0.3 +%% P(Alarm=no | Fire=no, Tampering=yes) = 0.2 +%% +%% The CPT for the Alarm variable is then constructed from these inhibition +%% probabilities and the noisy-OR rule: +%% +%% +------+-----------+--------------------+----------------+ +%% | Fire | Tampering | P(Alarm=yes) | P(Alarm=no) | +%% +------+-----------+--------------------+----------------+ +%% | yes | yes | 0.94 = 1 - 0.3*0.2 | 0.06 = 0.3*0.2 | +%% | yes | no | 0.7 = 1 - 0.3 | 0.3 | +%% | no | yes | 0.8 = 1 - 0.2 | 0.2 | +%% | no | no | 0 | 1.0 | +%% +------+-----------+--------------------+----------------+ +%% +%% cpt_al/3 in this program implements the above CPT with random switches. +%% The key step is to consider the generation process underlying the noisy-OR +%% rule. One may notice that this program is written in a network-specific +%% form, but a more generic, network-independent program is given in +%% alarm_nor_generic.psm. +%% +%% Please note that this program shares a considerably large part with +%% ../alarm.psm, so some comments are omitted for simplicity. + +%%------------------------------------- +%% Quick start: +%% +%% ?- prism(alarm_nor_basic). +%% +%% Print the CPT of the Alarm variable constructed from the noisy OR rule: +%% ?- print_dist_al. +%% +%% Print logical formulas that express the probabilistic behavior of +%% the noisy OR rule for Alarm: +%% ?- print_expl_al. 
+%% +%% Get the probability and the explanation graph: +%% ?- prob(world(yes,no)). +%% ?- probf(world(yes,no)). +%% +%% Get the most likely explanation and its probability: +%% ?- viterbif(world(yes,no)). +%% ?- viterbi(world(yes,no)). +%% +%% Compute conditional hindsight probabilities: +%% ?- chindsight(world(yes,no),world(_,_,_,_,_,_)). +%% ?- chindsight_agg(world(yes,no),world(_,_,query,yes,_,no)). +%% +%% Learn parameters from randomly generated 100 samples +%% ?- alarm_learn(100). + +go:- alarm_learn(100). + +%%------------------------------------- +%% Declarations: + +values(_,[yes,no]). + +%%------------------------------------ +%% Modeling part: + +world(Sm,Re):- world(_,_,_,Sm,_,Re). + +world(Fi,Ta,Al,Sm,Le,Re) :- + cpt_fi(Fi), % P(Fire) + cpt_ta(Ta), % P(Tampering) + cpt_sm(Fi,Sm), % CPT P(Smoke | Fire) + cpt_al(Fi,Ta,Al), % CPT P(Alarm | Fire,Tampering) + cpt_le(Al,Le), % CPT P(Leaving | Alarm) + cpt_re(Le,Re). % CPT P(Report | Leaving) + +cpt_fi(Fi):- msw(fi,Fi). +cpt_ta(Ta):- msw(ta,Ta). +cpt_sm(Fi,Sm):- msw(sm(Fi),Sm). +cpt_al(Fi,Ta,Al):- % implementation of noisy OR: + ( Fi = yes, Ta = yes -> + msw(cause_al_fi,N_Al_Fi), + msw(cause_al_ta,N_Al_Ta), + ( N_Al_Fi = no, N_Al_Ta = no -> Al = no + ; Al = yes + ) + ; Fi = yes, Ta = no -> msw(cause_al_fi,Al) + ; Fi = no, Ta = yes -> msw(cause_al_ta,Al) + ; Fi = no, Ta = no -> Al = no + ). +cpt_le(Al,Le):- msw(le(Al),Le). +cpt_re(Le,Re):- msw(re(Le),Re). + +%%------------------------------------ +%% Utility part: + +alarm_learn(N) :- + unfix_sw(_), % Make all parameters changeable + set_params, % Set parameters as you specified + get_samples(N,world(_,_),Gs), % Get N samples + fix_sw(fi), % Preserve the parameter values + learn(Gs). 
% for {msw(fi,yes), msw(fi,no)} + +set_params :- + set_sw(fi,[0.1,0.9]), + set_sw(ta,[0.15,0.85]), + set_sw(sm(yes),[0.95,0.05]), + set_sw(sm(no),[0.05,0.95]), + set_sw(le(yes),[0.88,0.12]), + set_sw(le(no),[0.01,0.99]), + set_sw(re(yes),[0.75,0.25]), + set_sw(re(no),[0.10,0.90]), + set_sw(cause_al_fi,[0.7,0.3]), % switch for an inhibition prob + set_sw(cause_al_ta,[0.8,0.2]). % switch for an inhibition prob + +:- set_params. + +%% Check routine for Noisy-OR +print_dist_al:- + set_params, + ( member(Fi,[yes,no]), + member(Ta,[yes,no]), + member(Al,[yes,no]), + prob(cpt_al(Fi,Ta,Al),P), + format("P(al=~w | fi=~w, ta=~w):~t~6f~n",[Al,Fi,Ta,P]), + fail + ; true + ). + +print_expl_al:- + set_params, + ( member(Fi,[yes,no]), + member(Ta,[yes,no]), + member(Al,[yes,no]), + probf(cpt_al(Fi,Ta,Al)), + fail + ; true + ). diff --git a/packages/prism/exs/noisy_or/alarm_nor_generic.psm b/packages/prism/exs/noisy_or/alarm_nor_generic.psm new file mode 100644 index 000000000..5db78b863 --- /dev/null +++ b/packages/prism/exs/noisy_or/alarm_nor_generic.psm @@ -0,0 +1,174 @@ +%%%% +%%%% Bayesian networks using noisy OR (2) -- alarm_nor_generic.psm +%%%% +%%%% Copyright (C) 2004,2006,2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example is borrowed from: +%% Poole, D., Probabilistic Horn abduction and Bayesian networks, +%% In Proc. of Artificial Intelligence 64, pp.81-129, 1993. +%% +%% (Fire) (Tampering) +%% / \ / +%% ((Smoke)) (Alarm) +%% | +%% (Leaving) (( )) -- observable node +%% | ( ) -- hidden node +%% ((Report)) +%% +%% In this network, we assume that all rvs (random variables) take on +%% {yes,no} and also assume that only two nodes, `Smoke' and `Report', are +%% observable. +%% +%% Furthermore, as did in alarm_nor_basic.psm, we consider that the Alarm +%% variable's CPT given through the noisy-OR rule. 
That is, we have the +%% following inhibition probabilities: +%% +%% P(Alarm=no | Fire=yes, Tampering=no) = 0.3 +%% P(Alarm=no | Fire=no, Tampering=yes) = 0.2 +%% +%% The CPT for the Alarm variable is then constructed from these inhibition +%% probabilities and the noisy-OR rule: +%% +%% +------+-----------+--------------------+----------------+ +%% | Fire | Tampering | P(Alarm=yes) | P(Alarm=no) | +%% +------+-----------+--------------------+----------------+ +%% | yes | yes | 0.94 = 1 - 0.3*0.2 | 0.06 = 0.3*0.2 | +%% | yes | no | 0.7 = 1 - 0.3 | 0.3 | +%% | no | yes | 0.8 = 1 - 0.2 | 0.2 | +%% | no | no | 0 | 1.0 | +%% +------+-----------+--------------------+----------------+ +%% +%% While alarm_nor_basic.psm uses network-specific implementation, in this +%% program, we attempt to introduce a more generic routine that can handle +%% noisy OR. To be more concrete: +%% +%% - We specify noisy OR nodes in a declarative form (with noisy_or/3). +%% - We introduce generic probabilistic predicates that make probabilistic +%% choices, following the specifications of noisy OR nodes. +%% +%% The definition of these generic probabilistic predicates are given in +%% noisy_or.psm, and we will include noisy_or.psm into this program. +%% + +%%------------------------------------- +%% Quick start (the same as those listed in alarm_nor_basic.psm): +%% +%% ?- prism(alarm_nor_generic). +%% +%% Print the CPT of the Alarm variable constructed from the noisy OR rule: +%% ?- print_dist_al. +%% +%% Print logical formulas that express the probabilistic behavior of +%% the noisy OR rule for Alarm: +%% ?- print_expl_al. +%% +%% Get the probability and the explanation graph: +%% ?- prob(world(yes,no)). +%% ?- probf(world(yes,no)). +%% +%% Get the most likely explanation and its probability: +%% ?- viterbif(world(yes,no)). +%% ?- viterbi(world(yes,no)). +%% +%% Compute conditional hindsight probabilities: +%% ?- chindsight(world(yes,no),world(_,_,_,_,_,_)). 
+%% ?- chindsight_agg(world(yes,no),world(_,_,query,yes,_,no)). +%% +%% Learn parameters from randomly generated 100 samples +%% ?- alarm_learn(100). + +%%------------------------------------- +%% Declarations: + +values(_,[yes,no]). + +:- include('noisy_or.psm'). + % We include generic probabilistic predicates that can handle + % noisy-OR. The following predicates will be available: + % + % - cpt(X,PaVs,V) represents a probabilistic choice where a + % random variable X given instantiations PaVs of parents + % takes a value V. If X is an ordinary node, a random + % switch bn(X,PaVs) will be used. On the other hand, if + % X is a noisy-OR node, switch cause(X,Y) will be used, + % where Y is one of parents of X. + % + % - set_nor_params/0 sets inhibition probabilisties (i.e. + % the parameters of switches cause(X,Y)) according to + % the specifications for noisy-OR nodes with noisy_or/3. + +%%------------------------------------ +%% Modeling part: + +world(Sm,Re):- world(_,_,_,Sm,_,Re). + +world(Fi,Ta,Al,Sm,Le,Re) :- + cpt(fi,[],Fi), % P(Fire) + cpt(ta,[],Ta), % P(Tampering) + cpt(sm,[Fi],Sm), % CPT P(Smoke | Fire) + cpt(al,[Fi,Ta],Al), % CPT P(Alarm | Fire,Tampering) + cpt(le,[Al],Le), % CPT P(Leaving | Alarm) + cpt(re,[Le],Re). % CPT P(Report | Leaving) + + +% declarations for noisy OR nodes: +noisy_or(al,[fi,ta],[[0.7,0.3],[0.8,0.2]]). + +%%------------------------------------ +%% Utility part: + +alarm_learn(N) :- + unfix_sw(_), % Make all parameters changeable + set_params, % Set ordinary parameters + set_nor_params, % Set inhibition parameters + get_samples(N,world(_,_),Gs), % Get N samples + fix_sw(bn(fi,[])), % Preserve the parameter values + learn(Gs). % for {msw(bn(fi,[]),yes), msw(bn(fi,[]),no)} + +:- set_params. +:- set_nor_params. 
+ +set_params:- + set_sw(bn(fi,[]),[0.1,0.9]), + set_sw(bn(ta,[]),[0.15,0.85]), + set_sw(bn(sm,[yes]),[0.95,0.05]), + set_sw(bn(sm,[no]),[0.05,0.95]), + set_sw(bn(le,[yes]),[0.88,0.12]), + set_sw(bn(le,[no]),[0.01,0.99]), + set_sw(bn(re,[yes]),[0.75,0.25]), + set_sw(bn(re,[no]),[0.10,0.90]). + +%% Check routine for Noisy-OR + +print_dist_al:- + ( member(Fi,[yes,no]), + member(Ta,[yes,no]), + member(Al,[yes,no]), + get_cpt_prob(al,[Fi,Ta],Al,P), + format("P(al=~w | fi=~w, ta=~w):~t~6f~n",[Al,Fi,Ta,P]), + fail + ; true + ). + +print_expl_al:- + ( member(Fi,[yes,no]), + member(Ta,[yes,no]), + member(Al,[yes,no]), + get_cpt_probf(al,[Fi,Ta],Al), + fail + ; true + ). + +%% [Note] prob/1 and probf/1 will fail if its argument fails + +get_cpt_prob(X,PaVs,V,P):- + ( prob(cpt(X,PaVs,V),P) + ; P = 0.0 + ),!. + +get_cpt_probf(X,PaVs,V):- + ( probf(cpt(X,PaVs,V)) + ; format("cpt(~w,~w,~w): always false~n",[X,PaVs,V]) + ),!. diff --git a/packages/prism/exs/noisy_or/noisy_or.psm b/packages/prism/exs/noisy_or/noisy_or.psm new file mode 100644 index 000000000..06eb4b3c1 --- /dev/null +++ b/packages/prism/exs/noisy_or/noisy_or.psm @@ -0,0 +1,65 @@ +%%%% +%%%% Library for generic noisy OR predicates --- noisy_or.psm +%%%% +%%%% Copyright (C) 2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% When this file included, the following predicates will be available: +%% +%% - cpt(X,PaVs,V) represents a probabilistic choice where a +%% random variable X given instantiations PaVs of parents +%% takes a value V. If X is an ordinary node, a random +%% switch bn(X,PaVs) will be used. On the other hand, if +%% X is a noisy-OR node, switch cause(X,Y) will be used, +%% where Y is one of parents of X. +%% +%% - set_nor_params/0 sets inhibition probabilisties (i.e. +%% the parameters of switches cause(X,Y)) according to +%% the specifications for noisy-OR nodes with noisy_or/3. 
+ +%%--------------------------------------- +%% Declarations: + +% added just for making the results of probabilistic inference +% simple and readable: +:- p_not_table choose_noisy_or/4, choose_noisy_or/6. + +%%--------------------------------------- +%% Modeling part: + +cpt(X,PaVs,V):- + ( noisy_or(X,Pa,_) -> choose_noisy_or(X,Pa,PaVs,V) % for noisy OR nodes + ; msw(bn(X,PaVs),V) % for ordinary nodes + ). + +choose_noisy_or(X,Pa,PaVs,V):- choose_noisy_or(X,Pa,PaVs,no,no,V). + +choose_noisy_or(_,[],[],yes,V,V). +choose_noisy_or(_,[],[],no,_,no). +choose_noisy_or(X,[Y|Pa],[PaV|PaVs],PaHasYes0,ValHasYes0,V):- + ( PaV=yes -> + msw(cause(X,Y),V0), + PaHasYes=yes, + ( ValHasYes0=no, V0=no -> ValHasYes=no + ; ValHasYes=yes + ) + ; PaHasYes=PaHasYes0, + ValHasYes=ValHasYes0 + ), % do not insert the cut symbol here + choose_noisy_or(X,Pa,PaVs,PaHasYes,ValHasYes,V). + + +%%--------------------------------------- +%% Utility part: + +set_nor_params:- + ( noisy_or(X,Pa,DistList), % spec for a noisy OR node + set_nor_params(X,Pa,DistList), + fail + ; true + ). +set_nor_params(_,[],[]). +set_nor_params(X,[Y|Pa],[Dist|DistList]):- + set_sw(cause(X,Y),Dist),!, + set_nor_params(X,Pa,DistList). diff --git a/packages/prism/exs/pdcg.psm b/packages/prism/exs/pdcg.psm new file mode 100644 index 000000000..2134f0432 --- /dev/null +++ b/packages/prism/exs/pdcg.psm @@ -0,0 +1,89 @@ +%%%% +%%%% Probabilistic DCG --- pdcg.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% PCFGs (probabilistic contex free grammars) are a stochastic extension +%% of CFG grammar such that in a (leftmost) derivation, each production +%% rule is selected probabilistically and applied. Look at the following +%% sample PCFG in which S is a start symbol and {a,b} are terminals. +%% +%% Rule 1: S -> SS (0.4) +%% Rule 2: S -> a (0.5) +%% Rule 3: S -> b (0.1) +%% +%% When S is expanded, three rules, Rule 1, 2 and 3 are applicable. 
+%% To determine a rule to apply, probabilistic selection is made in +%% such a way that Rule 1 is selected with probability 0.4, Rule 2 +%% with probability 0.5 and Rule 3 with probability 0.1, respectively. +%% The probability of a derivation tree is defined to be the product +%% of probabilities associated with rules used in the derivation, +%% and that of a sentence is defined to be the sum of proabibities of +%% derivations for the sentence. +%% +%% When modeling PCFGs, we follow DCG (definite clause grammar) +%% formalism. So we write down a top-down parser using difference +%% list which represents the rest of the sentence to parse. Note that +%% the grammar is left-recursive, and hence running the program below +%% without a tabling mechanism goes into an infinite loop. + +%%------------------------------------- +%% Quick start : learning experiment with the sample grammar +%% +%% ?- prism(pdcg),go. % Learn parameters of the PCFG above from +%% % randomly generated 100 samples +%% +%% ?- prob(pdcg([a,b,b])). +%% ?- prob(pdcg([a,b,b]),P). +%% ?- probf(pdcg([a,b,b])). +%% ?- probf(pdcg([a,b,b]),E),print_graph(E). +%% ?- sample(pdcg(X)). +%% +%% ?- viterbi(pdcg([a,b,b]),P). % P is the prob. of the most likely +%% ?- viterbif(pdcg([a,b,b]),P,E). % explanation E for pdcg([a,b,b]) +%% ?- viterbif(pdcg([a,b,b]),P,E),print_graph(E). + +go:- pdcg_learn(100). +max_str_len(20). % Maximum string length is 20. + +%%------------------------------------ +%% Declarations: + +values('S',[['S','S'],a,b],[0.4,0.5,0.1]). + % We use a msw of the form msw('S',V) such + % that V is one of { ['S','S'], a, b }, + % and when msw('S',V) is executed, the prob. + % of V=['S','S'] is 0.4, that of V=a is 0.5 + % and that of V=b is 0.1. + +%%------------------------------------ +%% Modeling part: + +start_symbol('S'). % Start symbol is S + +pdcg(L):- + start_symbol(I), + pdcg2(I,L-[]). + % I is a category to expand. +pdcg2(I,L0-L2):- % L0-L2 is a list for I to span. 
+ msw(I,RHS), % Choose a rule I -> RHS probabilistically. + ( RHS == ['S','S'], + pdcg2('S',L0-L1), + pdcg2('S',L1-L2) + ; RHS == a, + L0 = [RHS | L2] + ; RHS == b, + L0 = [RHS | L2] ). + +%%------------------------------------ +%% Utility part: + +pdcg_learn(N):- + max_str_len(MaxStrL), + get_samples_c(N,pdcg(X),(length(X,Y),Y =< MaxStrL),Goals,[Ns,_]), + format("#sentences= ~d~n",[Ns]), + unfix_sw('S'), % Make parameters of msw('S',.) changable + learn(Goals). % Conduct ML estimation by graphical EM learning + diff --git a/packages/prism/exs/pdcg_c.psm b/packages/prism/exs/pdcg_c.psm new file mode 100644 index 000000000..f4847b385 --- /dev/null +++ b/packages/prism/exs/pdcg_c.psm @@ -0,0 +1,121 @@ +%%%% +%%%% Probabilistic DCG for Charniak's example --- pdcg_c.psm +%%%% +%%%% Copyright (C) 2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% As described in the comments in pdcg.psm, PCFGs (probabilistic context- +%% free grammars) are a stochastic extension of CFG grammar such that in a +%% (leftmost) derivation, each production rule is selected probabilistically +%% and applied. This program presents an implementation of an example from +%% Charniak's textbook (Statistical Language Learning, The MIT Press, 1993): +%% +%% s --> np vp (0.8) | verb --> swat (0.2) +%% s --> vp (0.2) | verb --> flies (0.4) +%% np --> noun (0.4) | verb --> like (0.4) +%% np --> noun pp (0.4) | noun --> swat (0.05) +%% np --> noun np (0.2) | noun --> flies (0.45) +%% vp --> verb (0.3) | noun --> ants (0.5) +%% vp --> verb np (0.3) | prep --> like (1.0) +%% vp --> verb pp (0.2) | +%% vp --> verb np pp (0.2) | +%% pp --> prep np (1.0) | +%% (`s' is the start symbol) +%% +%% This program has a grammar-independent part (pcfg/1-2 and proj/2), +%% which can work with any underlying CFG which has no epsilon rules +%% and produces no unit cycles. + +%%---------------------------------- +%% Quick start: +%% +%% ?- prism(pdcg_c). 
+%% +%% ?- prob(pcfg([swat,flies,like,ants])). +%% % get the generative probability of a sentence +%% % "swat flies like ants" +%% +%% ?- sample(pcfg(_X)),viterbif(pcfg(_X)). +%% % parse a sampled sentence +%% +%% ?- get_samples(50,pcfg(X),_Gs),learn(_Gs),show_sw. +%% % conduct an artificial learning experiments +%% +%% ?- viterbif(pcfg([swat,flies,like,ants])). +%% % get the most probabile parse for "swat flies like ants" +%% +%% ?- n_viterbif(3,pcfg([swat,flies,like,ants])). +%% % get top 3 ranked parses for "swat flies like ants" +%% +%% ?- viterbit(pcfg([swat,flies,like,ants])). +%% % print the most probabile parse for "swat flies like ants" in +%% % a tree form. +%% +%% ?- viterbit(pcfg([swat,flies,like,ants]),P,E), build_tree(E,T). +%% % get the most probabile parse for "swat flies like ants" in a +%% % tree form, and convert it to a more readable Prolog term. +%% +%% ?- probfi(pcfg([swat,flies,like,ants])). +%% % print the parse forest with inside probabilities +%% + +%%---------------------------------- +%% Declarations: + +values(s,[[np,vp],[vp]]). +values(np,[[noun],[noun,pp],[noun,np]]). +values(vp,[[verb],[verb,np],[verb,pp],[verb,np,pp]]). +values(pp,[[prep,np]]). +values(verb,[[swat],[flies],[like]]). +values(noun,[[swat],[flies],[ants]]). +values(prep,[[like]]). + +:- p_not_table proj/2. % This declaration is introduced just for + % making the results of probabilistic inferences + % simple and readable. + +%%---------------------------------- +%% Modeling part: + +pcfg(L):- pcfg(s,L-[]). +pcfg(LHS,L0-L1):- + ( nonterminal(LHS) -> msw(LHS,RHS),proj(RHS,L0-L1) + ; L0 = [LHS|L1] + ). + +proj([],L-L). +proj([X|Xs],L0-L1):- + pcfg(X,L0-L2),proj(Xs,L2-L1). + +nonterminal(s). +nonterminal(np). +nonterminal(vp). +nonterminal(pp). +nonterminal(verb). +nonterminal(noun). +nonterminal(prep). + +%%---------------------------------- +%% Utility part: + +% set the rule probabilities: +:- set_sw(s,[0.8,0.2]). +:- set_sw(np,[0.4,0.4,0.2]). 
+:- set_sw(vp,[0.3,0.3,0.2,0.2]). +:- set_sw(pp,[1.0]). +:- set_sw(verb,[0.2,0.4,0.4]). +:- set_sw(noun,[0.05,0.45,0.5]). +:- set_sw(prep,[1.0]). + +% build_tree(E,T):- +% Build a parse tree T from a tree-formed explanation E. + +build_tree([],[]). +build_tree([pcfg(_),Gs],T) :- build_tree(Gs,T). +build_tree([pcfg(Sym,_)|Gs],T) :- build_tree1(Gs,T0),T=..[Sym|T0]. + +build_tree1([],[]). +build_tree1([pcfg(Sym,_)|Gs],[Sym|T]) :- !,build_tree1(Gs,T). +build_tree1([msw(_,_)|Gs],T) :- !, build_tree1(Gs,T). +build_tree1([G|Gs],[T0|T]) :- build_tree(G,T0),!,build_tree1(Gs,T). diff --git a/packages/prism/exs/phmm.dat b/packages/prism/exs/phmm.dat new file mode 100644 index 000000000..2c81f004a --- /dev/null +++ b/packages/prism/exs/phmm.dat @@ -0,0 +1,44 @@ +%% This data was created by Rose. +%% see http://bibiserv.techfak.uni-bielefeld.de/rose + +%% Rose +%% Copyright (c) 1997-2000 University of Bielefeld, Germany and +%% Deutsches Krebsforschungszentrum (DKFZ) Heidelberg, Germany. +%% All rights reserved. + +%% +%% correct alignments +%% +%% HLKIANRKDK----HHNKEFGGHHLA +%% HLKATHRKDQ----HHNREFGGHHLA +%% VLKFANRKSK----HHNKEMGAHHLA +%% HKKGAT---------------PVNVS +%% HKKGATATG-----------NPKHVC +%% QFKVAAAVGK----HQDASRGVHHID +%% SFKGQGAVSK----HQDPEWGVHHID +%% SFKGQGAVSV----PQAPAWGINHID +%% HFKSQAEVNK----HDRPEWGLNQID +%% HFRSQAEVNQRQFNHHRPQWSFNQIG +%% SFNVVKGASK----RENGGMGAEPVD +%% KFKKVDGLGK----KEHPALGVH--- +%% KFMVGGKDGK----NRKDAHAHRKVE +%% KYKVPEKDGK----KRTNAHSHRKVE +%% RYKIPESDGK----KRTNSHRHRKVE +%% RYKIASMDGK----KRYAEHKHKKLE + +observe( ['H','L','K','I','A','N','R','K','D','K','H','H','N','K','E','F','G','G','H','H','L','A'] ). +observe( ['H','L','K','A','T','H','R','K','D','Q','H','H','N','R','E','F','G','G','H','H','L','A'] ). +observe( ['V','L','K','F','A','N','R','K','S','K','H','H','N','K','E','M','G','A','H','H','L','A'] ). +observe( ['H','K','K','G','A','T','P','V','N','V','S'] ). +observe( ['H','K','K','G','A','T','A','T','G','N','P','K','H','V','C'] ). 
+observe( ['Q','F','K','V','A','A','A','V','G','K','H','Q','D','A','S','R','G','V','H','H','I','D'] ). +observe( ['S','F','K','G','Q','G','A','V','S','K','H','Q','D','P','E','W','G','V','H','H','I','D'] ). +observe( ['S','F','K','G','Q','G','A','V','S','V','P','Q','A','P','A','W','G','I','N','H','I','D'] ). +observe( ['H','F','K','S','Q','A','E','V','N','K','H','D','R','P','E','W','G','L','N','Q','I','D'] ). +observe( ['H','F','R','S','Q','A','E','V','N','Q','R','Q','F','N','H','H','R','P','Q','W','S','F','N','Q','I','G'] ). +observe( ['S','F','N','V','V','K','G','A','S','K','R','E','N','G','G','M','G','A','E','P','V','D'] ). +observe( ['K','F','K','K','V','D','G','L','G','K','K','E','H','P','A','L','G','V','H'] ). +observe( ['K','F','M','V','G','G','K','D','G','K','N','R','K','D','A','H','A','H','R','K','V','E'] ). +observe( ['K','Y','K','V','P','E','K','D','G','K','K','R','T','N','A','H','S','H','R','K','V','E'] ). +observe( ['R','Y','K','I','P','E','S','D','G','K','K','R','T','N','S','H','R','H','R','K','V','E'] ). +observe( ['R','Y','K','I','A','S','M','D','G','K','K','R','Y','A','E','H','K','H','K','K','L','E'] ). diff --git a/packages/prism/exs/phmm.psm b/packages/prism/exs/phmm.psm new file mode 100644 index 000000000..9bcae526a --- /dev/null +++ b/packages/prism/exs/phmm.psm @@ -0,0 +1,263 @@ +%%%% +%%%% Profile HMM --- phmm.psm +%%%% +%%%% Copyright (C) 2004,2006,2007,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% Profile HMMs are a variant of HMMs that have three types of states, +%% i.e. `match state',`insert state' and `delete state.' Match states +%% constitute an HMM that outputs a `true' string. Insertion states +%% emit a symbol additionally to the `true' string whereas delete (skip) +%% states emit no symbol. +%% +%% Profile HMMs are used to align amino-acid sequences by inserting +%% and skipping symbols as well as matching symbols. 
For example +%% amino-acid sequences below +%% +%% HLKIANRKDKHHNKEFGGHHLA +%% HLKATHRKDQHHNREFGGHHLA +%% VLKFANRKSKHHNKEMGAHHLA +%% ... +%% +%% are aligned by the profile HMM program in this file as follows. +%% +%% -HLKIA-NRKDK-H-H----NKEFGGHH-LA +%% -HLK-A-T-HRK-DQHHN--R-EFGGHH-LA +%% -VLKFA-NRKSK-H-H----NKEMGAHH-LA +%% ... + +%%------------------------------------- +%% Quick start : sample session, align the sample data in phmm.dat. +%% +%% To run on an interactive session: +%% ?- prism(phmm),go. (ML/MAP) +%% ?- prism(phmm),go_vb. (variational Bayes) +%% +%% To perform a batch execution: +%% > upprism phmm + +go :- + read_goals(Gs,'phmm.dat'), % Read the sequence data from phmm.dat. + learn(Gs), % Learn parameters from the data. + wmag(Gs). % Compute viterbi paths using the learned + % parameters and aligns sequences in Gs. + +% To enable variational Bayes, we need some additional flag settings: +go_vb :- + set_prism_flag(learn_mode,both), + set_prism_flag(viterbi_mode,hparams), + set_prism_flag(reset_hparams,on), + go. + +prism_main :- go. +%prism_main :- go_vb. + + +%%%--------------------- model --------------------- + +observe(Sequence) :- hmm(Sequence,start). + +hmm([],end). +hmm(Sequence,State) :- + State \== end, + msw(move_from(State),NextState), + msw(emit_at(State), Symbol), + ( Symbol = epsilon -> + hmm( Sequence, NextState ) + ; Sequence = [Symbol|TailSeq], + hmm( TailSeq , NextState ) + ). + +amino_acids(['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R', + 'S','T','V','W','X','Y']). +hmm_len(17). + +%%%--------------------- values --------------------- + +values(move_from(State),Values) :- + hmm_len(Len), + get_index(State,X), + ( 0 =< X, X < Len -> + Y is X + 1, + Values = [insert(X),match(Y),delete(Y)] + ; Values = [insert(X),end] ). + +values(emit_at(State),Vs) :- + ((State = insert(_) ; State = match(_)) -> + amino_acids(Vs) + ; Vs = [epsilon] ). + +%%%--------------------- set_sw --------------------- + +:- init_set_sw. 

%% init_set_sw/0
%%   Calls set_sw/1 on every switch of the profile HMM: first the
%%   move_from/emit_at switches of `start' and insert(0), then (through
%%   init_set_sw/1) those of positions hmm_len down to 1.
init_set_sw :-
%   tell('/dev/null'),     % Suppress output (on Linux only)
    set_sw( move_from(start) ),
    set_sw( move_from(insert(0)) ),
    set_sw( emit_at(start) ),
    set_sw( emit_at(insert(0)) ),
    hmm_len(Len),
%   told,
    init_set_sw(Len).

%% init_set_sw(+X)
%%   Calls set_sw/1 on the six switches of position X (move_from and
%%   emit_at for insert(X), match(X) and delete(X)), then recurses on
%%   X-1 until X reaches 0.
init_set_sw(0).
init_set_sw(X) :-
    X > 0,
    set_sw( move_from(insert(X)) ),
    set_sw( move_from(match(X)) ),
    set_sw( move_from(delete(X)) ),
    set_sw( emit_at(insert(X)) ),
    set_sw( emit_at(match(X)) ),
    set_sw( emit_at(delete(X)) ),
    Y is X - 1,
    init_set_sw(Y).

%%%--------------------- estimation ---------------------
%% most likely path
%%   mlpath(['A','E'],Path) => Path = [start,match(1),end]

%% mlpath(+Sequence,-Path) / mlpath(+Sequence,-Path,-Prob)
%%   Runs viterbif/3 on hmm(Sequence,start) and converts the returned
%%   node list into the list of HMM states (Path).  Prob is the
%%   probability reported by viterbif/3.
mlpath(Sequence,Path):-
    mlpath(Sequence,Path,_).
mlpath(Sequence,Path,Prob):-
    viterbif(hmm(Sequence,start),Prob,Nodes),
    nodes2path(Nodes,Path).

%% nodes2path(+Nodes,-Path)
%%   Extracts State from every node(hmm(_,State),_) element of Nodes.
nodes2path([Node|Nodes],[State|Path]):-
    Node = node(hmm(_,State),_),
    nodes2path(Nodes,Path).
nodes2path([],[]).

%% mlpaths(+Seqs,-Paths)
%%   Paths is the list of most likely paths (mlpath/2), one per sequence
%%   in Seqs, in the same order.
mlpaths([Seq|Seqs],[Path|Paths]):-
    mlpath(Seq,Path),
    mlpaths(Seqs,Paths).
mlpaths([],[]).

%% mlpaths(+Seqs,-Paths,_)
%%   Kept only so that existing three-argument calls keep working; the
%%   third argument is ignored.  It used to feed leftover debugging code
%%   (writeln/1 of the first path and a call to stop_low_level_trace/0)
%%   that ran once per sequence and is removed here.
mlpaths(Seqs,Paths,_):-
    mlpaths(Seqs,Paths).

%%%--------------------- alignment ---------------------

%% wmag(+Gs)
%%   Gs is a list of observe(Seq) goals; the sequences are extracted
%%   with seqs2goals/2 and handed to wma/1.
wmag(Gs):-
    seqs2goals(S,Gs),wma(S).

%% wma(+Seqs): short name for write_multiple_alignments/1.
wma(Seqs):-
    write_multiple_alignments(Seqs).

%% write_multiple_alignments(+Seqs)
%%   Computes the most likely path of every sequence and prints the
%%   resulting alignment between two ruler lines.
write_multiple_alignments(Seqs):-
    nl,
    write('search Viterbi paths...'),nl,
    mlpaths(Seqs,Paths),
    write('done.'),
    nl,
    write('------------ALIGNMENTS------------'),
    nl,
    write_multiple_alignments( Seqs, Paths ),
    write('----------------------------------'),
    nl.

%% make_max_length_list(+Paths,-MaxLenList)
%%   MaxLenList is the element-wise maximum of the length lists
%%   (make_length_list/2) of all paths in Paths.  Fails on an empty list.
%%   The singleton case comes first so that it is not tried through the
%%   recursive clause (whose recursion on [] can only fail).
make_max_length_list([Path],MaxLenList) :-
    !,make_length_list(Path,MaxLenList).
make_max_length_list([Path|Paths],MaxLenList) :-
    make_max_length_list(Paths, TmpLenList),
    make_length_list(Path,LenList),
    marge_len_list(LenList,TmpLenList,MaxLenList).

%% marge_len_list(+L1,+L2,-Merged)
%%   Merged holds, position by position, max/2 of the two elements of
%%   L1 and L2.  Fails when the lists differ in length.
marge_len_list([H1|T1],[H2|T2],[MargedH|MargedT]) :-
    max(MargedH,[H1,H2]),
    marge_len_list(T1,T2,MargedT).
marge_len_list([],[],[]).

%% make_length_list(+Path,-LenList)
%%   LenList has one entry per position index visited by Path, in order.
%%   Each entry is the number of states at that index other than `start'
%%   and delete(_).  The final [end] contributes no entry.
%%   Examples, traced through the clauses below:
%%
%% make_length_list([start,insert(0),match(1),end],LenList)
%%   -> LenList = [1,1]
%% make_length_list([start,delete(1),insert(1),insert(1),end],LenList)
%%   -> LenList = [0,2]

make_length_list(Path,[Len|LenList]) :-
    count_emission(Path,Len,NextIndexPath),
    make_length_list(NextIndexPath,LenList).
make_length_list([end],[]).    % reached when the clause above fails on [end]

%% count_emission(+Path,-Count,-NextIndexPath)
%%   Takes the index of the first state of Path (get_index/2) and counts,
%%   with count_emission2/4, the leading states that share this index.
count_emission(Path,Count,NextIndexPath) :-
    Path = [State|_],
    get_index(State,Index),
    count_emission2(Path,Count,Index,NextIndexPath).

%% count_emission2(+Path,-Count,+Index,-NextIndexPath)
%%   Walks the leading states of Path whose index is Index.  Count is the
%%   number of those states that are neither `start' nor delete(_);
%%   NextIndexPath is the rest of Path, starting at the first state with
%%   a different index.  There is no clause for [], so the call fails if
%%   Path runs out before such a state is found.
%%
%% count_emission2([start,insert(0),match(1),end],Count,0,NextIndexPath)
%%   -> Count = 1, NextIndexPath = [match(1),end]
%% count_emission2([delete(1),insert(1),insert(1),end],Count,1,NextIndexPath)
%%   -> Count = 2, NextIndexPath = [end]

count_emission2([State|Path],Count,Index,NextIndexPath) :-
    (   get_index(State,Index) ->
        count_emission2( Path, Count2, Index, NextIndexPath ),
        (   (State = delete(_); State==start) ->
            Count = Count2               % start/delete: not counted
        ;   Count is Count2 + 1 )
    ;   Count = 0,                       % first state of the next index
        NextIndexPath = [State|Path]
    ).

%% write_multiple_alignments(+Seqs,+Paths)
%%   Computes the per-index maximum widths over all Paths and prints one
%%   aligned line per sequence with write_alignment/3.
write_multiple_alignments(Seqs,Paths) :-
    make_max_length_list(Paths,LenList),
    write_multiple_alignments(Seqs,Paths,LenList).
write_multiple_alignments([Seq|Seqs],[Path|Paths],LenList) :-
    write_alignment(Seq,Path,LenList),
    write_multiple_alignments(Seqs,Paths,LenList).
write_multiple_alignments([],[],_).

%% write_alignment(+Seq,+Path,+LenList)
%%   Prints Seq on one line, following Path, starting at index 0.
write_alignment(Seq,Path,LenList) :-
    write_alignment(Seq,Path,LenList,0).

%% write_alignment(+Seq,+Path,+LenList,+Index)
%%   Index is the position currently being printed and the head of
%%   LenList is the width still unused at that position.
%%   - 1st clause: everything consumed, end the line.
%%   - 2nd clause: the next state belongs to Index.  `start' and delete(_)
%%     print nothing; any other state prints one symbol of Seq and uses
%%     up one unit of width.
%%   - 3rd clause: the next state belongs to a later index.  The unused
%%     width is filled with pad/1 and Index is advanced.
write_alignment([],[end],[],_):- !,nl.
write_alignment(Seq,[State|Path],LenList,Index) :-
    get_index(State,Index),!,
    (   (State = delete(_) ; State == start) ->
        write_alignment( Seq, Path, LenList, Index )
    ;   Seq = [Symbol|Seq2],
        LenList = [Len|LenList2],
        write(Symbol),
        Len2 is Len - 1,
        write_alignment(Seq2,Path,[Len2|LenList2],Index)
    ).
write_alignment(Seq,[State|Path],LenList,Index) :-
    LenList = [Len|LenList2],
    Index2 is Index + 1,
    pad(Len),
    write_alignment(Seq,[State|Path],LenList2,Index2).
+ +pad(Len) :- + Len > 0, + write('-'), + Len2 is Len - 1,!, + pad(Len2). +pad(0). + +%%%--------------------- utility --------------------- + +get_index(State,Index) :- + (State=match(_),!,State=match(Index)); + (State=insert(_),!,State=insert(Index)); + (State=delete(_),!,State=delete(Index)); + (State=start,!,Index=0); + (State=end,!,hmm_len(X),Index is X+1). + +seqs2goals([Seq|Seqs],[Goal|Goals]) :- + Goal = observe(Seq), + seqs2goals(Seqs,Goals). +seqs2goals([],[]). + +max(Max,[Head|Tail]) :- + max(Tmp,Tail),!, + ( Tmp > Head -> Max = Tmp ; Max = Head ). +max(Max,[Max]). + +read_goals(Goals,FileName) :- + see(FileName), + read_goals(Goals), + seen. +read_goals(Goals) :- + read(Term), + ( Term = end_of_file -> + Goals = [] + ; Goals = [Term|Goals1], + read_goals(Goals1) + ). diff --git a/packages/prism/exs/plc.dat b/packages/prism/exs/plc.dat new file mode 100644 index 000000000..1dd2e1083 --- /dev/null +++ b/packages/prism/exs/plc.dat @@ -0,0 +1,60 @@ +pslc([adv,n,p,v,n,adv,adv,adv,adv,v,n,p,v]). +pslc([v,n,c,v,n,p,v,n,c,n,p,v]). +pslc([adv,n,p,v,n,adv,adv,v,n,p,v,n,c,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([n,p,v]). +pslc([adv,adv,v,n,p,v,n,c,adv,adv,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([n,p,v]). +pslc([adv,adv,n,c,n,p,v,n,p,v,n,p,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([n,p,v]). +pslc([adv,adv,v,n,c,adv,v,n,p,v]). +pslc([n,p,v]). +pslc([v,n,c,adv,v,n,c,n,p,v,n,p,v]). +pslc([v,n,c,n,c,v,n,p,v]). +pslc([adv,adv,v,n,c,adv,v,n,c,adv,n,p,v,n,c,n,p,v,n,v,n,p,v]). +pslc([n,p,v]). +pslc([adv,n,p,v,n,c,v,n,p,v,n,v,n,p,v]). +pslc([v,n,c,n,p,v,n,p,v]). +pslc([n,c,v,n,c,n,c,n,p,v,n,p,v,n,p,v]). +pslc([v,n,c,n,p,v,n,c,adv,adv,v,n,p,v]). +pslc([adv,adv,v,n,c,v,n,p,v]). +pslc([n,p,v,n,c,adv,v,n,v,n,p,v]). +pslc([v,n,c,n,p,v,n,c,v,n,p,v]). +pslc([n,p,v]). +pslc([adv,adv,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([v,n,p,v]). +pslc([adv,adv,adv,n,p,v,n,p,v,n,c,v,n,v,n,c,v,n,p,v,n,c,n,p,v,n,c,n,p,v]). +pslc([v,n,p,v,n,p,v]). +pslc([v,n,p,v]). +pslc([n,c,n,p,v,n,p,v]). 
+pslc([n,p,v]). +pslc([adv,adv,v,n,v,n,c,adv,v,n,n,p,v,n,c,n,c,n,p,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([n,p,v,n,p,v]). +pslc([adv,n,adv,adv,v]). +pslc([adv,v,n,p,v,n,v,n,c,v,n,c,v,n,c,n,p,v,n,p,v,n,c,v,n,c,v,n,p,v]). +pslc([adv,adv,v,n,p,v,n,c,v,n,c,v,n,c,adv,v,n,p,v,n,p,v,n,p,v]). +pslc([n,p,v,n,p,v,n,p,v]). +pslc([n,p,v,n,c,adv,adv,v,n,p,v,n,v,n,p,v]). +pslc([adv,v,n,p,v,n,p,v]). +pslc([adv,adv,v,n,p,v]). +pslc([adv,adv,v,n,p,v,n,p,v]). +pslc([v,n,p,v]). +pslc([adv,n,p,v,n,c,adv,adv,v,n,v,n,n,p,v]). +pslc([n,p,v]). +pslc([adv,n,p,v,n,p,v]). +pslc([adv,n,p,v,n,adv,adv,v,n,c,n,p,v,n,p,v,n,c,v,n,p,v]). +pslc([n,p,v]). +pslc([n,c,v,n,c,n,p,v,n,c,adv,v,n,v,n,p,v]). +pslc([n,p,v,n,p,v,n,p,v,n,p,v]). +pslc([v,n,p,v,n,p,v]). +pslc([v,n,c,adv,v,n,c,n,p,v,n,p,v,n,c,adv,adv,v,n,p,v,n,p,v]). +pslc([n,p,v]). +pslc([v,n,p,v,n,p,v,n,c,adv,adv,v,n,p,v,n,v,n,p,v,n,p,v,n,p,v,n,p,v]). +pslc([v,n,p,v]). +pslc([n,p,v]). +pslc([n,c,adv,adv,v,n,p,v]). +pslc([n,p,v]). diff --git a/packages/prism/exs/plc.psm b/packages/prism/exs/plc.psm new file mode 100644 index 000000000..7f3295003 --- /dev/null +++ b/packages/prism/exs/plc.psm @@ -0,0 +1,215 @@ +%%%% +%%%% Probablistic left corner grammar --- plc.psm +%%%% +%%%% Copyright (C) 2004,2006,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This is a PRISM program modeling a probabilistic left-corner +%% parser (stack version) described in +%% +%% "Probabilistic Parsing using left corner language models", +%% C.D.Manning, +%% Proc. of the 5th Int'l Conf. on Parsing Technologies (IWPT-97), +%% MIT Press, pp.147-158. +%% +%% Note that this program defines a distribution over sentences +%% procedurally, i.e. the derivation process is described in terms +%% of stack operations. Also note that we automatically get +%% a correctness-guaranteed EM procedure for probablistic +%% left-corner grammars. 
+ +%%------------------------------------- +%% Quick start : sample session with Grammar_1 (attached below) +%% +%% (1) Move to a directory where this program is placed. +%% (2) Start PRISM (no options needed since 1.10) +%% +%% > prism +%% +%% (3) Load this program (by default, every msw is given a uniform +%% distribution) +%% +%% ?- prism(plc). +%% +%% (4) Use uitilities, e.g. +%% (4-1) Computing explanation (support) graphs and probabilities +%% +%% ?- prob(pslc([n,p,v])). +%% ?- probf(pslc([n,p,v])). +%% ?- probf(pslc([n,p,v]),E),print_graph(E). +%% ?- prob(pslc([adv,adv,n,c,n,p,v])). +%% ?- probf(pslc([adv,adv,n,c,n,p,v])). +%% ?- probf(pslc([adv,adv,n,c,n,p,v]),E),print_graph(E). +%% +%% Pv is prob. of a most likely explanation E for pslc([adv,...,v]) +%% ?- viterbif(pslc([adv,adv,n,c,n,p,v]),Pv,E). +%% ?- viterbi(pslc([adv,adv,n,c,n,p,v]),Pv). +%% +%% (4-2) Sampling +%% +%% ?- sample(pslc(X)), sample(pslc(Y)), sample(pslc(Z)). +%% +%% (4-3) Graphical EM learning for Grammar_1 (wait for some time) +%% +%% ?- go. + +go:- plc_learn(50). % Generate randomly 50 sentences and learn +max_str_len(30). % Sentence length <= 30 + +%%------------------------------------ +%% Modeling part: + +pslc(Ws) :- + start_symbol(C), % asserted in Grammar_1 + pslc(Ws,[g(C)]). % C is a top-goal category + +pslc([],[]). +pslc(L0,Stack0) :- + process(Stack0,Stack,L0,L), + pslc(L,Stack). + +%% shift operation +process([g(A)|Rest],Stack,[Wd|L],L):- % g(A) is a goal category + ( terminal(A), % Stack given = [g(A),g(F),D...] created + A = Wd, Stack = Rest % by e.g. projection using E -> D,A,F + ; \+ terminal(A), % Select probabilistically one of first(A) + ( get_values(first(A),[Wd]) % No choice if the first set is a singleton + ; get_values(first(A),[_,_|_]), % Select 1st word by msw + msw(first(A),Wd) ), + Stack = [Wd,g(A)|Rest] + ). 

%% projection and attachment
%%   Second clause of process/4 (process(+Stack0,-Stack,+L0,-L)).  It
%%   applies when the top of the stack is a completed category A lying
%%   directly on a goal g(C).  No input word is consumed: the two list
%%   arguments are the same variable L.
process([A|Rest],Stack,L,L):-       % a subtree with top=A is completed
    \+ A = g(_),                    % A's right neighbor has the form g(_)
    Rest = [g(C)|Stack0],           %  => A is not a terminal
    (   A == C,                     % g(A) is waiting for an A-tree
        (   get_values(lc(A,A),_),  % lc(X,Y) means X - left-corner -> Y
            msw(attach(A),Op),      % A must have a chance of not attaching
            (   Op == attach, Stack = Stack0               % attachment
            ;   Op == project, next_Stack(A,Rest,Stack) )  % projection
        ;   \+ get_values(lc(A,A),_),
            Stack = Stack0 )        % forcible attachment for nonterminal
    ;   A \== C,
        next_Stack(A,Rest,Stack) ).

%% projection                       % subtree A completed, waited for by g(C)
%%   next_Stack(+A,+[g(C)|Rest2],-Stack): picks a rule LHS -> A RHS2 from
%%   switch lc(C,A).  msw/2 is used when the switch has two or more values;
%%   a single value is taken directly.  The goals for RHS2 are then pushed
%%   on top of [LHS,g(C)|Rest2] with predict/3.
next_Stack(A,[g(C)|Rest2],Stack) :- % rule I -> A J K
    (   get_values(lc(C,A),[_,_|_]),  %   => Stack=[g(J),g(K),I,g(C)...]
        msw(lc(C,A),rule(LHS,[A|RHS2])) % if C - left-corner -> A
    ;   get_values(lc(C,A),[rule(LHS,[A|RHS2])]) ), % no other rules for projection
    predict(RHS2,[LHS,g(C)|Rest2],Stack).

%% predict(+Cats,+Tail,-Stack)
%%   Stack is Tail with every element X of Cats wrapped as g(X) and put
%%   in front, in order: predict([j,k],[i],S) gives S = [g(j),g(k),i].
predict([],L,L).
predict([A|Ls],L2,[g(A)|NewLs]):-
    predict(Ls,L2,NewLs).

%%------------------------------------
%%  Utility part:

%% plc_learn(+N): generate goals with gen_plc/2 and pass them to learn/1.
plc_learn(N):-
    gen_plc(N,Goals),
    learn(Goals).

%% gen_plc(+N,-Goals)
%%   Draws N samples of pslc(L) with sample/1.  A sample longer than
%%   max_str_len/1 is dropped; every other sample is printed and
%%   collected, so Goals may contain fewer than N goals.
gen_plc(0,[]).
gen_plc(N,Goals):-
    N > 0,
    N1 is N-1,
    sample(pslc(L)),
    length(L,K),
    max_str_len(StrL),
    (   K > StrL,
        Goals = G2                  % too long: skip this sample
    ;   Goals=[pslc(L)|G2],
        format(" G = ~w~n",[pslc(L)])
    ),!,                            % commit to the branch taken above
    gen_plc(N1,G2).


%%--------------- Grammar_1 -----------------

start_symbol(s).

%% rule(LHS,RHS): production LHS -> RHS of Grammar_1.
rule(s,[pp,v]).
rule(s,[ap,vp]).
rule(vp,[pp,v]).
rule(vp,[ap,v]).
rule(np,[vp,n]).
rule(np,[v,n]).
rule(np,[n]).
rule(np,[np,c,np]).
rule(np,[ap,np]).
rule(pp,[np,p]).
rule(pp,[n,p]).
rule(ap,[adv,adv]).
rule(ap,[adv]).
rule(ap,[adv,np]).

%% terminal(Sym): Sym is a terminal symbol (tested in the shift clause
%% of process/4).
terminal(v).
terminal(n).
terminal(c).
terminal(p).
terminal(adv).

%% first set computed from Grammar_1
%% NOTE(review): process/4 reads the first sets through
%% get_values(first(A),_), not through these first/2 facts; no caller of
%% first/2 is visible in this part of the file.
first(vp,v).
first(np,v).
first(pp,v).
first(s,v).
first(vp,n).
first(np,n).
first(pp,n).
first(s,n).
first(vp,adv).
first(ap,adv).
first(np,adv).
first(pp,adv).
+first(s,adv). + +%%------------------------------------ +%% Declarations: +%% +%% created from Grammar_1 + +values(lc(s,pp),[rule(s,[pp,v]),rule(vp,[pp,v])]). +values(lc(s,np),[rule(np,[np,c,np]),rule(pp,[np,p])]). +values(lc(s,vp),[rule(np,[vp,n])]). +values(lc(pp,np),[rule(np,[np,c,np]),rule(pp,[np,p])]). +values(lc(pp,vp),[rule(np,[vp,n])]). +values(lc(pp,pp),[rule(vp,[pp,v])]). +values(lc(np,vp),[rule(np,[vp,n])]). +values(lc(np,pp),[rule(vp,[pp,v])]). +values(lc(np,np),[rule(np,[np,c,np]),rule(pp,[np,p])]). +values(lc(vp,pp),[rule(vp,[pp,v])]). +values(lc(vp,np),[rule(np,[np,c,np]),rule(pp,[np,p])]). +values(lc(vp,vp),[rule(np,[vp,n])]). +values(lc(vp,ap),[rule(np,[ap,np]),rule(vp,[ap,v])]). +values(lc(vp,adv),[rule(ap,[adv]),rule(ap,[adv,adv]),rule(ap,[adv,np])]). +values(lc(ap,adv),[rule(ap,[adv]),rule(ap,[adv,adv]),rule(ap,[adv,np])]). +values(lc(vp,v),[rule(np,[v,n])]). +values(lc(vp,n),[rule(np,[n]),rule(pp,[n,p])]). +values(lc(np,v),[rule(np,[v,n])]). +values(lc(np,n),[rule(np,[n]),rule(pp,[n,p])]). +values(lc(np,ap),[rule(np,[ap,np]),rule(vp,[ap,v])]). +values(lc(np,adv),[rule(ap,[adv]),rule(ap,[adv,adv]),rule(ap,[adv,np])]). +values(lc(pp,n),[rule(np,[n]),rule(pp,[n,p])]). +values(lc(pp,ap),[rule(np,[ap,np]),rule(vp,[ap,v])]). +values(lc(pp,adv),[rule(ap,[adv]),rule(ap,[adv,adv]),rule(ap,[adv,np])]). +values(lc(pp,v),[rule(np,[v,n])]). +values(lc(s,ap),[rule(np,[ap,np]),rule(s,[ap,vp]),rule(vp,[ap,v])]). +values(lc(s,adv),[rule(ap,[adv]),rule(ap,[adv,adv]),rule(ap,[adv,np])]). +values(lc(s,v),[rule(np,[v,n])]). +values(lc(s,n),[rule(np,[n]),rule(pp,[n,p])]). + +values(first(s),[adv,n,v]). +values(first(vp),[adv,n,v]). +values(first(np),[adv,n,v]). +values(first(pp),[adv,n,v]). +values(first(ap),[adv]). + +values(attach(s),[attach,project]). +values(attach(vp),[attach,project]). +values(attach(np),[attach,project]). +values(attach(pp),[attach,project]). +values(attach(ap),[attach,project]). 
diff --git a/packages/prism/exs/sbn.psm b/packages/prism/exs/sbn.psm new file mode 100644 index 000000000..0ee82c1eb --- /dev/null +++ b/packages/prism/exs/sbn.psm @@ -0,0 +1,130 @@ +%%%% +%%%% Bayesian networks (2) -- sbn.psm +%%%% +%%%% Copyright (C) 2004,2008 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% This example shows how to simulate Pearl's message passing +%% (without normalization) for singly connected BNs (Bayesian networks). +%% +%% Suppose that we have a Bayesian network in Fiugre 1 and that +%% we wish to compute marginal probabilites P(B) of B. +%% The distribution defined by the BN in Figure 1 is expressed +%% by a BN program in Figure 3. We transform it into another +%% program that defines the same marginal distribuion for B. +%% +%% Original graph Transformed graph +%% +%% A B B +%% / \ / | +%% / \ / v +%% C D ==> D +%% / \ / | \ +%% / \ / v v +%% E F A E F +%% / +%% v +%% C +%% (Figure 1) (Figure 2) +%% +%% Original BN program for Figure 1 +%% + world(VA,VB,VC,VD,VE,VF):- + msw(par('A',[]),VA), msw(par('B',[]),VB), + msw(par('C',[VA]),VC), msw(par('D',[VA,VB]),VD), + msw(par('E',[VD]),VE), msw(par('F',[VD]),VF). + check_B(VB):- world(_,VB,_,_,_,_). +%% +%% (Figure 3) +%% +%% Transformation: +%% [Step 1] Transform the orignal BN in Figure 1 into Figure 2 by letting +%% B be the top node and other nodes dangle from B. +%% [Step 2] Construct a program that calls nodes in Figure 2 from the top +%% node to leaves. For example for D, add clause +%% +%% call_BD(VB):- call_DA(VA),call_DE(VE),call_DF(VF). +%% +%% while inserting an msw expressing the CPT P(D|A,B) in the body. Here, +%% +%% call_XY(V) <=> +%% node Y is called from X with ground term V (=X's realization) +%% +%% It can be proved by unfolding that the transformed program is equivalent +%% in distribution semantics to the original program in Figure 3. +%% => Both programs compute the same marginal distribution for B. 
%% Confirm by ?- prob(ask_B(2),X),prob(check_B(2),Y).

%%-------------------------------------
%%  Quick start : sample session
%%
%%  ?- prism(sbn),go.   % Learn parameters from randomly generated
%%                      % 100 samples while preserving the marginal
%%                      % distribution P(B)
%%
%%  ?- prob(ask_B(2)).
%%  ?- prob(ask_B(2),X),prob(check_B(2),Y).  % => X=Y
%%  ?- probf(ask_B(2)).
%%  ?- sample(ask_B(X)).
%%
%%  ?- viterbi(ask_B(2)).
%%  ?- viterbif(ask_B(2),P,E),print_graph(E).

%% Entry point of the sample session: learn from 100 samples.
go:- sbn_learn(100).

%%------------------------------------
%%  Declarations:
%%
%%  One switch par(Node,ParentValues) per node.  The second argument of
%%  par/2 has as many elements as the node has parents in Figure 1
%%  (none for A and B, one for C, E and F, two for D).  Each node uses
%%  its own pair of outcome values, so values of different nodes never
%%  coincide.

values(par('A',[]),   [0,1]).    % Declare msw(par('A',[]),VA) where
values(par('B',[]),   [2,3]).    %   VA is one of {0,1}
values(par('C',[_]),  [4,5]).
values(par('D',[_,_]),[6,7]).    % Declare msw(par('D',[VA,VB]),VD) where
values(par('E',[_]),  [8,9]).    %   VD is one of {6,7}
values(par('F',[_]),  [10,11]).

%% Set the distribution of every switch instance: one set_sw/2 call for
%% each combination of parent values allowed by the declarations above
%% (A in {0,1}, B in {2,3}, D in {6,7}).  The "=>" comments give the
%% resulting probability vectors.
set_params:-                                      % Call set_sw/2 built-in
    set_sw(par('A',[]),   [0.3,0.7]),
    set_sw(par('B',[]),   uniform),               % => [0.5,0.5]
    set_sw(par('C',[0]),  f_geometric(3,asc)),    % => [0.25,0.75]
    set_sw(par('C',[1]),  f_geometric(3,desc)),   % => [0.75,0.25]
    set_sw(par('D',[0,2]),f_geometric(3)),        % => [0.75,0.25]
    set_sw(par('D',[1,2]),f_geometric(2)),        % => [0.666...,0.333...]
    set_sw(par('D',[0,3]),f_geometric),           % => [0.666...,0.333...]
    set_sw(par('D',[1,3]),[0.3,0.7]),
    set_sw(par('E',[6]),  [0.3,0.7]),
    set_sw(par('E',[7]),  [0.1,0.9]),
    set_sw(par('F',[6]),  [0.3,0.7]),
    set_sw(par('F',[7]),  [0.1,0.9]).

%% Directive: the parameters are set as soon as this file is loaded.
:- set_params.

%%------------------------------------
%%  Modeling part: transformed program defining P(B)
%%
%%  One predicate call_XY per edge X-Y of the transformed graph
%%  (Figure 2); each one draws the value of the called node with msw/2
%%  and then calls the nodes hanging below it.

ask_B(VB) :-                       % ?- prob(ask_B(2),X)
    msw(par('B',[]),VB),           %  => X = P(B=2)
    call_BD(VB).

%% D depends on both A and B, so A has to be sampled (call_DA/1) before
%% the switch for D can be named.
call_BD(VB):-                      % msw's Id must be ground
    call_DA(VA),                   %  => VA must be ground
    msw(par('D',[VA,VB]),VD),      %  => call_DA(VA)
    call_DE(VD),                   %     before msw(par('D',[VA,VB]),VD)
    call_DF(VD).

%% A is a root of the original network: its switch has no parent values.
call_DA(VA):-
    msw(par('A',[]),VA),
    call_AC(VA).

%% C and E are leaves of Figure 2: their outcomes (_VC, _VE) are drawn
%% but not used any further.
call_AC(VA):-
    msw(par('C',[VA]),_VC).
call_DE(VD):-
    msw(par('E',[VD]),_VE).
+call_DF(VD):- + msw(par('F',[VD]),_VF). + +%%------------------------------------ +%% Utility part: + +sbn_learn(N):- % Learn parameters (CPTs) from a list of + random_set_seed(123456), % N randomly generated ask_B(.) atoms + set_params, + get_samples(N,ask_B(_),Goals), + learn(Goals). diff --git a/packages/prism/exs/votes.psm b/packages/prism/exs/votes.psm new file mode 100644 index 000000000..40fd615d4 --- /dev/null +++ b/packages/prism/exs/votes.psm @@ -0,0 +1,112 @@ +%%%% +%%%% Evaluation of a naive Bayes classifier for `votes' dataset +%%%% --- votes.psm +%%%% +%%%% Copyright (C) 2009 +%%%% Sato Laboratory, Dept. of Computer Science, +%%%% Tokyo Institute of Technology + +%% In this program, we conduct n-fold cross validation of a naive Bayes +%% classifier. This program was created to demonstrate the usefulness of +%% the built-in predicates introduced since version 1.12. The target +%% dataset is the congressional voting records (`votes') dataset, which +%% is available from UCI machine learning repository (http://archive.ics. +%% uci.edu/ml/). +%% +%% From this program, it is seen that, using new built-in predicates such +%% as maplist/5, avglist/2, random_shuffle/2, and so on, we can make the +%% utility part compact, as well as the modeling part. Also one may find +%% that we only combine general-purpose built-ins to implement n-fold cross +%% validation. + +%%------------------------------------- +%% Quick start : sample session +%% +%% (Preparation: Download the data file `house-votes-84.data' from UCI ML +%% repository, and put it `as-is' in the current directory) +%% +%% ?- prism(votes),votes_learn. % Learn parameters from the whole dataset +%% +%% ?- prism(votes),votes_cv(10). % Conduct 10-fold cross validation +%% + +%%------------------------------------- +%% Declarations + +values(class,[democrat,republican]). % class labels +values(attr(_,_),[y,n]).
                                     % all attributes have two values: y or n

%%-------------------------------------
%%  Modeling part (a naive Bayes model)
%%
%%  [Note]
%%  According to `house-votes-84.names', a data description file for the
%%  `votes' dataset, '?' simply denotes that the value is not "yea" nor
%%  "nay".  On the other hand, in this program, we consider '?' as a missing
%%  value just for demonstration purpose.

%% nbayes(C,Vals): draw the class C, then generate the attribute values
%% Vals one by one, numbering the attributes from 1.
nbayes(C,Vals):- msw(class,C),nbayes(1,C,Vals).

%% nbayes(J,C,Vals): Vals are the values of attributes J, J+1, ... given
%% class C.
nbayes(_,_,[]).
nbayes(J,C,[V|Vals]):-
    choose(J,C,V),
    J1 is J+1,
    nbayes(J1,C,Vals).

%% choose(J,C,V): value V of the J-th attribute under class C.  For '?'
%% the switch is still sampled but its outcome is left unconstrained
%% (anonymous variable); otherwise the outcome must equal V.
choose(J,C,V):-
    ( V == '?' -> msw(attr(J,C),_)    % handling '?' as a missing value
    ; msw(attr(J,C),V0),
      V = V0
    ).

%%-------------------------------------
%%  Utility part:

%%  Batch routine for a simple learning

votes_learn:-
    load_data_file(Gs),
    learn(Gs).

%%  Batch routine for N-fold cross validation

votes_cv(N):-
    random_set_seed(81729),   % Fix the random seed to keep the same splitting
    load_data_file(Gs0),      % Load the entire data
    random_shuffle(Gs0,Gs),   % Randomly reorder the data
    numlist(1,N,Ks),          % Get Ks = [1,...,N] (B-Prolog built-in)
    maplist(K,Rate,votes_cv(Gs,K,N,Rate),Ks,Rates),
                              % Call votes_cv/4 for K=1...N
    avglist(Rates,AvgRate),   % Get the avg. of the per-split accuracies
    maplist(K,Rate,format("Test #~d: ~2f%~n",[K,Rate*100]),Ks,Rates),
    format("Average: ~2f%~n",[AvgRate*100]).

%%  Subroutine for learning and testing for K-th split data (K = 1...N)
%%  Rate is the fraction of test examples whose predicted class C0 is
%%  identical to the recorded class C.

votes_cv(Gs,K,N,Rate):-
    format("<<<< Test #~d >>>>~n",[K]),
    separate_data(Gs,K,N,Gs0,Gs1),   % Gs0: training data, Gs1: test data
    learn(Gs0),                      % Learn by PRISM's built-in
    maplist(nbayes(C,Vs),R,(viterbig(nbayes(C0,Vs)),(C0==C->R=1;R=0)),Gs1,Rs),
                    % Predict the class by viterbig/1 for each test example
                    % and evaluate it with the answer class label
    avglist(Rs,Rate),                % Get the accuracy for the K-th splitting
    format("Done (~2f%).~n~n",[Rate*100]).
+ +%% Split the entire data (Data) into the training data (Train) +%% and the test data (Test) for the K-th evaluation (K=1...N) + +separate_data(Data,K,N,Train,Test):- + length(Data,L), + L0 is L*(K-1)//N, % L0: offset of the test data (// - integer division) + L1 is L*(K-0)//N-L0, % L1: size of the test data + splitlist(Train0,Rest,Data,L0), % Length of Train0 = L0 + splitlist(Test,Train1,Rest,L1), % Length of Test = L1 + append(Train0,Train1,Train). + +%% Load the `votes' data in CSV form and convert it to suitable +%% Prolog terms + +load_data_file(Gs):- + load_csv('house-votes-84.data',Gs0), + maplist(csvrow([C|Vs]),nbayes(C,Vs),true,Gs0,Gs). diff --git a/packages/prism/src/README b/packages/prism/src/README new file mode 100644 index 000000000..5a127a878 --- /dev/null +++ b/packages/prism/src/README @@ -0,0 +1,16 @@ +========================== README (src) ========================== + +This directory contains the source files of the PRISM part, along +with a minimal set of source and binary files from B-Prolog, +required to build the PRISM system: + + c/ ... C code + prolog/ ... Prolog code + +Please use/modify/distribute the source code based on the license +agreements described in $(TOP)/LICENSE and $(TOP)/LICENSE.src, where +$(TOP) is the top directory in the unfolded package. + +To build the PRISM system, we need to compile both C and Prolog +source files. Please follow the instructions described in the READMEs +in the `c' and `prolog' directories.
diff --git a/packages/prism/src/c/Makefile.in b/packages/prism/src/c/Makefile.in new file mode 100644 index 000000000..96ada7b1e --- /dev/null +++ b/packages/prism/src/c/Makefile.in @@ -0,0 +1,91 @@ +# -*- Makefile -*- + +# +# default base directory for YAP installation +# (EROOT for architecture-dependent files) +# +prefix = @prefix@ +exec_prefix = @exec_prefix@ +ROOTDIR = $(prefix) +EROOTDIR = @exec_prefix@ +abs_top_builddir = @abs_top_builddir@ +# +# where the binary should be +# +BINDIR = $(EROOTDIR)/bin +# +# where YAP should look for libraries +# +LIBDIR=@libdir@ +YAPLIBDIR=@libdir@/Yap +YAP_EXTRAS=@YAP_EXTRAS@ -D_YAP_NOT_INSTALLED_=1 -D__YAP_PROLOG__=1 +# +# +CC=@CC@ +CFLAGS= @SHLIB_CFLAGS@ $(YAP_EXTRAS) $(DEFS) -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H -I$(srcdir)/../../../../library/dialect/bprolog/fli +LDFLAGS=@LDFLAGS@ +# +# +# You shouldn't need to change what follows. +# +INSTALL=@INSTALL@ +INSTALL_DATA=@INSTALL_DATA@ +INSTALL_PROGRAM=@INSTALL_PROGRAM@ +SHELL=/bin/sh +RANLIB=@RANLIB@ +srcdir=@srcdir@ +SO=@SO@ +#4.1VPATH=@srcdir@:@srcdir@/OPTYap +CWD=$(PWD) +# + +##---------------------------------------------------------------------- + +ifeq ($(PROCTYPE),mp) +SUBDIRS += $(MP_DIR) +OBJS += $(MP_OBJS) +endif + +##---------------------------------------------------------------------- + +include $(srcdir)/makefiles/Makefile.files +S=/ +O=o + +SOBJS=prism.@SO@ + +#in some systems we just create a single object, in others we need to +# create a library + +all: $(SOBJS) + +core/%.o: $(srcdir)/core/%.c + $(CC) -c $(CFLAGS) $< -o $@ + +up/%.o: $(srcdir)/up/%.c + $(CC) -c $(CFLAGS) $< -o $@ + +mp/%.o: $(srcdir)/mp/%.c + $(CC) -c $(CFLAGS) $< -o $@ + +@DO_SECOND_LD@prism.@SO@: $(OBJS) +@DO_SECOND_LD@ @SHLIB_LD@ $(LDFLAGS) -o $@ $(OBJS) @EXTRA_LIBS_FOR_DLLS@ + +all: $(TARGET) + +install: $(TARGET) + $(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR) + +clean: clean_subdirs + $(RM) $(TARGET) + +clean_subdirs: + for i in
$(SUBDIRS); do \ + ($(MAKE) -f $(MAKEFILE) -C $$i clean ) \ + done + +##---------------------------------------------------------------------- + +.PHONY: all install clean $(SUBDIRS) + +##---------------------------------------------------------------------- diff --git a/packages/prism/src/c/core/bpx.c b/packages/prism/src/c/core/bpx.c new file mode 100644 index 000000000..eaa5bcf5c --- /dev/null +++ b/packages/prism/src/c/core/bpx.c @@ -0,0 +1,401 @@ +#include +#include +#include +#include +#include +#include "core/bpx.h" +#include "core/vector.h" + +/*--------------------------------------------------------------------*/ + +#define REQUIRE_HEAP(n) \ + ( heap_top + (n) <= local_top ? \ + (void)(0) : myquit(STACK_OVERFLOW, "stack + heap") ) + +/*--------------------------------------------------------------------*/ +/* Functions from B-Prolog */ + +/* cpred.c */ +int bp_string_2_term(const char *, TERM, TERM); +char* bp_term_2_string(TERM); +int bp_call_term(TERM); +int bp_mount_query_term(TERM); +int bp_next_solution(void); + +/* file.c */ +void write_term(TERM); + +/* float1.c */ +double floatval(TERM); +TERM encodefloat1(double); + +/* loader.c */ +SYM_REC_PTR insert(const char *, int, int); + +/* mic.c */ +NORET quit(const char *); +NORET myquit(int, const char *); + +/* unify.c */ +int unify(TERM, TERM); +int is_UNIFIABLE(TERM, TERM); +int is_IDENTICAL(TERM, TERM); + +/* prism.c */ +NORET bp4p_quit(int); + +/*--------------------------------------------------------------------*/ + +static NORET bpx_raise(const char *fmt, ...) 
+{ + va_list ap; + + fprintf(curr_out, "*** {PRISM BPX ERROR: "); + va_start(ap, fmt); + vfprintf(curr_out, fmt, ap); + va_end(ap); + fprintf(curr_out, "}\n"); + + bp4p_quit(1); +} + +/*--------------------------------------------------------------------*/ + +bool bpx_is_var(TERM t) +{ + XDEREF(t); + return ISREF(t); +} + +bool bpx_is_atom(TERM t) +{ + XDEREF(t); + return ISATOM(t); +} + +bool bpx_is_integer(TERM t) +{ + XDEREF(t); + return ISINT(t); +} + +bool bpx_is_float(TERM t) +{ + XDEREF(t); + return ISNUM(t); +} + +bool bpx_is_nil(TERM t) +{ + XDEREF(t); + return ISNIL(t); +} + +bool bpx_is_list(TERM t) +{ + XDEREF(t); + return ISLIST(t); +} + +bool bpx_is_structure(TERM t) +{ + XDEREF(t); + return ISSTRUCT(t); +} + +bool bpx_is_compound(TERM t) +{ + XDEREF(t); + return ISCOMPOUND(t); +} + +bool bpx_is_unifiable(TERM t1, TERM t2) +{ + XDEREF(t1); + XDEREF(t2); + return (bool)(is_UNIFIABLE(t1, t2)); +} + +bool bpx_is_identical(TERM t1, TERM t2) +{ + XDEREF(t1); + XDEREF(t2); + return (bool)(is_IDENTICAL(t1, t2)); +} + +/*--------------------------------------------------------------------*/ + +TERM bpx_get_call_arg(BPLONG i, BPLONG arity) +{ + if (i < 1 || i > arity) { + bpx_raise("index out of range"); + } + return ARG(i, arity); +} + +BPLONG bpx_get_integer(TERM t) +{ + XDEREF(t); + + if (ISINT(t)) { + return INTVAL(t); + } + else { + bpx_raise("integer expected"); + } +} + +double bpx_get_float(TERM t) +{ + XDEREF(t); + + if (ISINT(t)) { + return (double)(INTVAL(t)); + } + else if (ISFLOAT(t)) { + return floatval(t); + } + else { + bpx_raise("integer or floating number expected"); + } +} + +const char * bpx_get_name(TERM t) +{ + XDEREF(t); + + switch (XTAG(t)) { + case STR: + return GET_NAME_STR(GET_STR_SYM_REC(t)); + case ATM: + return GET_NAME_ATOM(GET_ATM_SYM_REC(t)); + case LST: + return "."; + default: + bpx_raise("callable expected"); + } +} + +int bpx_get_arity(TERM t) +{ + XDEREF(t); + + switch (XTAG(t)) { + case STR: + return 
GET_ARITY_STR(GET_STR_SYM_REC(t)); + case ATM: + return GET_ARITY_ATOM(GET_ATM_SYM_REC(t)); + case LST: + return 2; + default: + bpx_raise("callable expected"); + } +} + +TERM bpx_get_arg(BPLONG i, TERM t) +{ + BPLONG n, j; + + XDEREF(t); + + switch (XTAG(t)) { + case STR: + n = GET_ARITY_STR(GET_STR_SYM_REC(t)); + j = 0; + break; + case LST: + n = 2; + j = 1; + break; + default: + bpx_raise("compound expected"); + } + + if (i < 1 || i > n) { + bpx_raise("bad argument index"); + } + return GET_ARG(t, i - j); +} + +TERM bpx_get_car(TERM t) +{ + XDEREF(t); + + if (ISLIST(t)) { + return GET_CAR(t); + } + else { + bpx_raise("list expected"); + } +} + +TERM bpx_get_cdr(TERM t) +{ + XDEREF(t); + + if (ISLIST(t)) { + return GET_CDR(t); + } + else { + bpx_raise("list expected"); + } +} + +/*--------------------------------------------------------------------*/ + +TERM bpx_build_var(void) +{ + TERM term; + + REQUIRE_HEAP(1); + term = (TERM)(heap_top); + NEW_HEAP_FREE; + return term; +} + +TERM bpx_build_integer(BPLONG n) +{ + return MAKEINT(n); +} + +TERM bpx_build_float(double x) +{ + REQUIRE_HEAP(4); + return encodefloat1(x); +} + +TERM bpx_build_atom(const char *name) +{ + SYM_REC_PTR sym; + + sym = insert(name, strlen(name), 0); + return ADDTAG(sym, ATM); +} + +TERM bpx_build_list(void) +{ + TERM term; + + REQUIRE_HEAP(2); + term = ADDTAG(heap_top, LST); + NEW_HEAP_FREE; + NEW_HEAP_FREE; + return term; +} + +TERM bpx_build_nil(void) +{ + return nil_sym; +} + +TERM bpx_build_structure(const char *name, BPLONG arity) +{ + SYM_REC_PTR sym; + TERM term; + + REQUIRE_HEAP(arity + 1); + term = ADDTAG(heap_top, STR); + sym = insert(name, strlen(name), arity); + NEW_HEAP_NODE((TERM)(sym)); + while (--arity >= 0) { + NEW_HEAP_FREE; + } + return term; +} + +/*--------------------------------------------------------------------*/ + +bool bpx_unify(TERM t1, TERM t2) +{ + return (bool)(unify(t1, t2)); +} + +/*--------------------------------------------------------------------*/ + 
+TERM bpx_string_2_term(const char *s) +{ + TERM term, vars; + int result; + + REQUIRE_HEAP(2); + term = (TERM)(heap_top); + NEW_HEAP_FREE; + vars = (TERM)(heap_top); + NEW_HEAP_FREE; + + result = bp_string_2_term(s, term, vars); + if (result != BP_TRUE) { + bpx_raise("parsing failed -- %s", s); + } + return term; +} + +const char * bpx_term_2_string(TERM t) +{ + XDEREF(t); + return bp_term_2_string(t); +} + +/*--------------------------------------------------------------------*/ + +int bpx_call_term(TERM t) +{ + XDEREF(t); + return bp_call_term(t); +} + +int bpx_call_string(const char *s) +{ + return bp_call_term(bpx_string_2_term(s)); +} + +int bpx_mount_query_term(TERM t) +{ + XDEREF(t); + return bp_mount_query_term(t); +} + +int bpx_mount_query_string(const char *s) +{ + return bp_mount_query_term(bpx_string_2_term(s)); +} + +int bpx_next_solution(void) +{ + if (curr_toam_status == TOAM_NOTSET) { + bpx_raise("no goal mounted"); + } + return bp_next_solution(); +} + +/*--------------------------------------------------------------------*/ + +void bpx_write(TERM t) +{ + XDEREF(t); + write_term(t); +} + +/*--------------------------------------------------------------------*/ + +int bpx_printf(const char *fmt, ...) 
+{ + va_list ap; + int r; + + va_start(ap, fmt); + r = vfprintf(curr_out, fmt, ap); + va_end(ap); + + return r; +} + +/*--------------------------------------------------------------------*/ + +#ifdef __YAP_PROLOG__ +BPLONG toam_signal_vec; + +BPLONG illegal_arguments; +BPLONG failure_atom; +BPLONG number_var_exception; +#endif diff --git a/packages/prism/src/c/core/bpx.h b/packages/prism/src/c/core/bpx.h new file mode 100644 index 000000000..451841425 --- /dev/null +++ b/packages/prism/src/c/core/bpx.h @@ -0,0 +1,323 @@ +#ifndef BPX_H +#define BPX_H + +#include "bprolog.h" +#include "stuff.h" + +#ifdef __YAP_PROLOG__ + +#include +#include +#include +#include +#include + +typedef void *SYM_REC_PTR; + +#define heap_top H +#define local_top ASP +#define trail_top TR +#define trail_up_addr ((tr_fr_ptr)LCL0) + +#define UNDO_TRAILING while (TR > (tr_fr_ptr)trail_top0) { RESET_VARIABLE(VarOfTerm(TrailTerm(TR--))); } + +#define NEW_HEAP_NODE(x) (*heap_top++ = (x)) + +#define STACK_OVERFLOW 1 + +/*====================================================================*/ + +#define ARG(X,Y) XREGS[X] +#define XDEREF(T) while (IsVarTerm(T)) { CELL *next = VarOfTerm(T); if (IsUnboundVar(next)) break; (T) = *next; } +#define MAKEINT(I) bp_build_integer(I) +#define INTVAL(T) bp_get_integer(T) + +#define MAX_ARITY 256 + +#define BP_MALLOC(X,Y,Z) ( X = malloc((Y)*sizeof(BPLONG)) ) + +#define NULL_TERM ((TERM)(0)) + +#define REF0 0x0L +#define REF1 0x1L +#define SUSP 0x2L +#define LST 0x4L +#define ATM 0x8L +#define INT 0x10L +#define STR 0x20L +#define NVAR (LST|ATM|INT|STR) + +#define GET_STR_SYM_REC(p) ((SYM_REC_PTR)*RepAppl(p)) +#define GET_ATM_SYM_REC(p) ((SYM_REC_PTR)AtomOfTerm(p)) + +#define GET_ARITY_STR(s) YAP_ArityOfFunctor((YAP_Functor)(s)) +#define GET_ARITY_ATOM(s) 0 + +#define GET_NAME_STR(f) YAP_AtomName(YAP_NameOfFunctor((YAP_Functor)(f))) +#define GET_NAME_ATOM(a) YAP_AtomName((YAP_Atom)(a)) + +static inline +long int XTAG(TERM t) +{ + switch(YAP_TagOfTerm(t)) { + 
case YAP_TAG_UNBOUND: + return REF0; + case YAP_TAG_ATT: + return SUSP; + case YAP_TAG_REF: + return REF1; + case YAP_TAG_PAIR: + return LST; + case YAP_TAG_ATOM: + return ATM; + case YAP_TAG_INT: + return INT; + case YAP_TAG_LONG_INT: + return INT; + case YAP_TAG_APPL: + default: + return STR; + } +} + +extern inline TERM ADDTAG(void * t,int tag) { + if (tag == ATM) + return MkAtomTerm((Atom)t); + if (tag == LST) + return AbsPair((CELL *)t); + return AbsAppl((CELL *)t); +} + +#define ISREF(t) IsVarTerm(t) +#define ISATOM(t) IsAtomTerm(t) +#define ISINT(t) IsIntegerTerm(t) +#define ISNUM(t) YAP_IsNumberTerm(t) +#define ISNIL(t) YAP_IsTermNil(t) +#define ISLIST(t) IsPairTerm(t) +#define ISSTRUCT(t) IsApplTerm(t) +#define ISFLOAT(t) IsFloatTerm(t) +#define ISCOMPOUND(t) YAP_IsCompoundTerm(t) + +#define floatval FloatOfTerm +#define encodefloat1 MkFloatTerm + +extern inline int is_UNIFIABLE(TERM t1, TERM t2) +{ + return YAP_Unifiable(t1, t2); +} + +extern inline int is_IDENTICAL(TERM t1, TERM t2) +{ + return YAP_ExactlyEqual(t1, t2); +} + + +#define SWITCH_OP(T,NDEREF,VCODE,ACODE,LCODE,SCODE,SUCODE) \ + switch (XTAG((T))) { \ + case REF0: \ + VCODE \ + case LST: \ + LCODE \ + case SUSP: \ + SUCODE \ + case STR: \ + SCODE \ + default: \ + ACODE \ + } + +#define XNDEREF(X,LAB) + +#define GET_ARG(A,I) YAP_ArgOfTerm((I),(A)) +#define GET_CAR(A) YAP_HeadOfTerm(A) +#define GET_CDR(A) YAP_TailOfTerm(A) + +#define MAKE_NVAR(id) ( (YAP_Term)(id) ) + +#define float_psc ((YAP_Functor)FunctorDouble) + +#define NEW_HEAP_FREE (*H = (CELL)H); H++ + +#define nil_sym YAP_TermNil() + +extern BPLONG illegal_arguments; +extern BPLONG failure_atom; +extern BPLONG number_var_exception; + +extern BPLONG toam_signal_vec; + +#define unify YAP_Unify + +extern inline char * +bp_term_2_string(TERM t) +{ + char *buf = malloc(256); + if (!buf) return NULL; + YAP_WriteBuffer(t, buf, 256, 0); + return buf; +} + +// char *bp_get_name(TERM t) +extern inline int +bp_string_2_term(const char *s, TERM 
to, TERM tv) +{ + TERM t0 = YAP_ReadBuffer(s, NULL); + TERM t1 = YAP_TermNil(); // for now + return unify(t0, to) && unify(t1,tv); +} + +extern inline SYM_REC_PTR +insert(const char *name, int size, int arity) +{ + if (!arity) { + return (SYM_REC_PTR)YAP_LookupAtom(name); + } + return (SYM_REC_PTR)YAP_MkFunctor(YAP_LookupAtom(name), arity); +} + +extern inline int +compare(TERM t1, TERM t2) +{ + // compare terms?? + return YAP_CompareTerms(t1,t2); +} + +extern inline void +write_term(TERM t) +{ + YAP_Write(t,NULL,0); +} + +static NORET quit(const char *s) +{ + fprintf(stderr,"PRISM QUIT: %s\n",s); + exit(0); +} + + +static NORET myquit(int i, const char *s) +{ + fprintf(stderr,"PRISM QUIT: %s\n",s); + exit(i); +} + +// vsc: why two arguments? +static inline int +list_length(BPLONG t1, BPLONG t2) +{ + return YAP_ListLength((TERM)t1); +} + +#define PRE_NUMBER_VAR(X) + +static inline void +numberVarTermOpt(TERM t) +{ + YAP_NumberVars(t, 0); +} + +static inline TERM +unnumberVarTerm(TERM t, BPLONG_PTR pt1, BPLONG_PTR pt2) +{ + return YAP_UnNumberVars(t); +} + +extern inline int +unifyNumberedTerms(TERM t1, TERM t2) +{ + if (YAP_Unify(t1,t2)) + return TRUE; + return FALSE; +} + +#define IsNumberedVar YAP_IsNumberedVariable + +#else + +#define GET_ARITY_ATOM GET_ARITY +#define GET_ARITY_STR GET_ARITY + +#define GET_NAME_STR GET_NAME +#define GET_NAME_ATOM GET_NAME + +/*====================================================================*/ + +#define NULL_TERM ((TERM)(0)) + +/*--------------------------------*/ + +/* These are the safer versions of DEREF and NDEREF macros. */ + +#define XDEREF(op) \ + do { if(TAG(op) || (op) == FOLLOW(op)) { break; } (op) = FOLLOW(op); } while(1) +#define XNDEREF(op, label) \ + do { if(TAG(op) || (op) == FOLLOW(op)) { break; } (op) = FOLLOW(op); goto label; } while(1) + +/*--------------------------------*/ + +/* This low-level macro provides more detailed information about the */ +/* type of a given term than TAG(op). 
*/ + +#define XTAG(op) ((op) & TAG_MASK) + +#define REF0 0x0L +#define REF1 TOP_BIT +#define INT INT_TAG +#define NVAR TAG_MASK + +/*--------------------------------*/ + +/* The following macros are the same as IsNumberedVar and NumberVar */ +/* respectively, provided just for more consistent naming. */ + +#define IS_NVAR(op) ( ((op) & TAG_MASK) == NVAR ) +#define MAKE_NVAR(id) ( (((BPLONG)(id)) << 2) | NVAR ) + +/*--------------------------------*/ + +/* This macro is redefined to reduce warnings on GCC 4.x. */ + +#if defined LINUX && ! defined M64BITS +#undef UNTAGGED_ADDR +#define UNTAGGED_ADDR(op) ( (((BPLONG)(op)) & VAL_MASK0) | addr_top_bit ) +#endif + +/*====================================================================*/ + +#endif /* YAP */ + +bool bpx_is_var(TERM); +bool bpx_is_atom(TERM); +bool bpx_is_integer(TERM); +bool bpx_is_float(TERM); +bool bpx_is_nil(TERM); +bool bpx_is_list(TERM); +bool bpx_is_structure(TERM); +bool bpx_is_compound(TERM); +bool bpx_is_unifiable(TERM, TERM); +bool bpx_is_identical(TERM, TERM); + +TERM bpx_get_call_arg(BPLONG, BPLONG); + +BPLONG bpx_get_integer(TERM); +double bpx_get_float(TERM); +const char* bpx_get_name(TERM); +int bpx_get_arity(TERM); +TERM bpx_get_arg(BPLONG, TERM); +TERM bpx_get_car(TERM); +TERM bpx_get_cdr(TERM); + +TERM bpx_build_var(void); +TERM bpx_build_integer(BPLONG); +TERM bpx_build_float(double); +TERM bpx_build_atom(const char *); +TERM bpx_build_list(void); +TERM bpx_build_nil(void); +TERM bpx_build_structure(const char *, BPLONG); + +bool bpx_unify(TERM, TERM); + +TERM bpx_string_2_term(const char *); +const char* bpx_term_2_string(TERM); + +#endif /* BPX_H */ diff --git a/packages/prism/src/c/core/error.c b/packages/prism/src/c/core/error.c new file mode 100644 index 000000000..cf22d004a --- /dev/null +++ b/packages/prism/src/c/core/error.c @@ -0,0 +1,108 @@ +#include +#include "bprolog.h" +#include "core/bpx.h" + +/*--------------------------------------------------------------------*/ + 
+#ifndef __YAP_PROLOG__ +TERM bpx_build_atom(const char *); +#endif + +/*--------------------------------------------------------------------*/ + +TERM err_runtime; +TERM err_internal; + +TERM err_cycle_detected; +TERM err_invalid_likelihood; +TERM err_invalid_free_energy; +TERM err_invalid_numeric_value; +TERM err_invalid_goal_id; +TERM err_invalid_switch_instance_id; +TERM err_underflow; +TERM err_overflow; +TERM err_ctrl_c_pressed; + +TERM ierr_invalid_likelihood; +TERM ierr_invalid_free_energy; +TERM ierr_function_not_implemented; +TERM ierr_unmatched_branches; + +/*--------------------------------------------------------------------*/ + +TERM build_runtime_error(const char *s) +{ + TERM t; + + if (s == NULL) return bpx_build_atom("prism_runtime_error"); + + t = bpx_build_structure("prism_runtime_error",1); + bpx_unify(bpx_get_arg(1,t),bpx_build_atom(s)); + + return t; +} + +TERM build_internal_error(const char *s) +{ + TERM t; + + if (s == NULL) return bpx_build_atom("prism_internal_error"); + + t = bpx_build_structure("prism_internal_error",1); + bpx_unify(bpx_get_arg(1,t),bpx_build_atom(s)); + + return t; +} + +/*--------------------------------------------------------------------*/ + +void register_prism_errors(void) +{ + err_runtime = build_runtime_error(NULL); + err_internal = build_internal_error(NULL); + + err_cycle_detected = build_runtime_error("cycle_detected"); + err_invalid_likelihood = build_runtime_error("invalid_likelihood"); + err_invalid_free_energy = build_runtime_error("invalid_free_energy"); + err_invalid_numeric_value = build_runtime_error("invalid_numeric_value"); + err_invalid_goal_id = build_runtime_error("invalid_goal_id"); + err_invalid_switch_instance_id = build_runtime_error("invalid_switch_instance_id"); + err_underflow = build_runtime_error("underflow"); + err_overflow = build_runtime_error("overflow"); + err_ctrl_c_pressed = build_runtime_error("ctrl_c_pressed"); + + ierr_invalid_likelihood = 
build_internal_error("invalid_likelihood"); + ierr_invalid_free_energy = build_internal_error("invalid_free_energy"); + ierr_function_not_implemented = build_internal_error("function_not_implemented"); + ierr_unmatched_branches = build_internal_error("unmatched_branches"); +} + +/*--------------------------------------------------------------------*/ + +void emit_error(const char *fmt, ...) +{ + va_list ap; + + fprintf(curr_out, "*** PRISM ERROR: "); + va_start(ap, fmt); + vfprintf(curr_out, fmt, ap); + va_end(ap); + fprintf(curr_out, "\n"); + + fflush(curr_out); +} + +void emit_internal_error(const char *fmt, ...) +{ + va_list ap; + + fprintf(curr_out, "*** PRISM INTERNAL ERROR: "); + va_start(ap, fmt); + vfprintf(curr_out, fmt, ap); + va_end(ap); + fprintf(curr_out, "\n"); + + fflush(curr_out); +} + +/*--------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/core/error.h b/packages/prism/src/c/core/error.h new file mode 100644 index 000000000..f52727d83 --- /dev/null +++ b/packages/prism/src/c/core/error.h @@ -0,0 +1,66 @@ +#ifndef ERROR_H +#define ERROR_H + +/*--------------------------------------------------------------------*/ + +#define RET_ERR(err) \ + do { \ + exception = (err); \ + return BP_ERROR; \ + } while (0) + +#define RET_RUNTIME_ERR \ + do { \ + exception = err_runtime; \ + return BP_ERROR; \ + } while (0) + +#define RET_INTERNAL_ERR \ + do { \ + exception = err_internal; \ + return BP_ERROR; \ + } while (0) + +#define RET_ON_ERR(expr) \ + do { \ + if ((expr) == BP_ERROR) return BP_ERROR; \ + } while (0) + +#define RET_ERR_ON_ERR(expr,err) \ + do { \ + if ((expr) == BP_ERROR) { \ + exception = (err); \ + return BP_ERROR; \ + } \ + } while (0) + +/*--------------------------------------------------------------------*/ + +extern TERM err_runtime; +extern TERM err_internal; + +extern TERM err_cycle_detected; +extern TERM err_invalid_likelihood; +extern TERM err_invalid_free_energy; +extern TERM 
err_invalid_numeric_value; +extern TERM err_invalid_goal_id; +extern TERM err_invalid_switch_instance_id; +extern TERM err_underflow; +extern TERM err_overflow; +extern TERM err_ctrl_c_pressed; + +extern TERM ierr_invalid_likelihood; +extern TERM ierr_invalid_free_energy; +extern TERM ierr_function_not_implemented; +extern TERM ierr_unmatched_branches; + +/*--------------------------------------------------------------------*/ + +TERM build_runtime_error(const char *); +TERM build_internal_error(const char *); +void emit_error(const char *, ...); +void emit_internal_error(const char *, ...); + +/*--------------------------------------------------------------------*/ + +#endif /* ERROR_H */ diff --git a/packages/prism/src/c/core/fputil.c b/packages/prism/src/c/core/fputil.c new file mode 100644 index 000000000..2436d66d3 --- /dev/null +++ b/packages/prism/src/c/core/fputil.c @@ -0,0 +1,11 @@ +#include "core/fputil.h" + +double fputil_snan(void) +{ + return +sqrt(-1); +} + +double fputil_qnan(void) +{ + return -sqrt(-1); +} diff --git a/packages/prism/src/c/core/fputil.h b/packages/prism/src/c/core/fputil.h new file mode 100644 index 000000000..2d8c1e4bb --- /dev/null +++ b/packages/prism/src/c/core/fputil.h @@ -0,0 +1,51 @@ +#ifndef FPUTIL_H +#define FPUTIL_H + +/*--------------------------------------------------------------------*/ + +#include + +#ifdef __STDC_VERSION__ +#if __STDC_VERSION__ >= 199901L +#define C99 +#endif +#endif + +/*--------------------------------------------------------------------*/ + +#if defined C99 +/* (empty) */ +#elif defined _MSC_VER +#include +#define isfinite _finite +#define isnan _isnan +#define INFINITY HUGE_VAL +#elif defined LINUX +# ifndef isfinite +# define isfinite finite +# endif +# ifndef isnan +# define isnan isnan +# endif +# ifndef INFINITY +# define INFINITY HUGE_VAL +# endif +#elif defined DARWIN +/* (empty) */ +#else +#define isfinite(x) (0.0 * (x) != 0.0) +#define isnan(x) ((x) != (x)) +#define INFINITY HUGE_VAL 
+#endif + +#define SNAN fputil_snan() +#define QNAN fputil_qnan() + +/*--------------------------------------------------------------------*/ + +double fputil_snan(void); +double fputil_qnan(void); + +/*--------------------------------------------------------------------*/ + +#endif /* FPUTIL_H */ diff --git a/packages/prism/src/c/core/gamma.c b/packages/prism/src/c/core/gamma.c new file mode 100644 index 000000000..f928db75b --- /dev/null +++ b/packages/prism/src/c/core/gamma.c @@ -0,0 +1,306 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- + + This file contains a portable implementation for a couple of gamma- + family functions, originally written for the PRISM programming system + . + + The code is based on SPECFUN (Fortran program collection for special + functions by W. J. Cody et al. at Argonne National Laboratory), which + is available in public domain at . + + Here is the license terms for this file (just provided to explicitly + state that the code can be used for any purpose): + +------------------------------------------------------------------------------ + + Copyright (c) 2007-2009 Yusuke Izumi + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + + 3. This notice may not be removed or altered from any source distribution. 
+ +------------------------------------------------------------------------------ + +*/ + +#include +#include "core/gamma.h" + +#define PI (3.14159265358979323846) /* pi */ +#define PI_2 (1.57079632679489661923) /* pi / 2 */ +#define PI_4 (0.78539816339744830962) /* pi / 4 */ +#define LN_SQRT2PI (0.91893853320467274178) /* ln(sqrt(2 * pi)) */ + +/** + * Computes ln(|Gamma(x)|). + */ +double lngamma(double x) +{ + /* Constants for [0.5,1.5) -------------------------------------------*/ + + const double D1 = -5.772156649015328605195174e-01; + + const double P1[] = { + +4.945235359296727046734888e+00, +2.018112620856775083915565e+02, + +2.290838373831346393026739e+03, +1.131967205903380828685045e+04, + +2.855724635671635335736389e+04, +3.848496228443793359990269e+04, + +2.637748787624195437963534e+04, +7.225813979700288197698961e+03 + }; + + const double Q1[] = { + +6.748212550303777196073036e+01, +1.113332393857199323513008e+03, + +7.738757056935398733233834e+03, +2.763987074403340708898585e+04, + +5.499310206226157329794414e+04, +6.161122180066002127833352e+04, + +3.635127591501940507276287e+04, +8.785536302431013170870835e+03 + }; + + /* Constants for [1.5,4.0) -------------------------------------------*/ + + const double D2 = +4.227843350984671393993777e-01; + + const double P2[] = { + +4.974607845568932035012064e+00, +5.424138599891070494101986e+02, + +1.550693864978364947665077e+04, +1.847932904445632425417223e+05, + +1.088204769468828767498470e+06, +3.338152967987029735917223e+06, + +5.106661678927352456275255e+06, +3.074109054850539556250927e+06 + }; + + const double Q2[] = { + +1.830328399370592604055942e+02, +7.765049321445005871323047e+03, + +1.331903827966074194402448e+05, +1.136705821321969608938755e+06, + +5.267964117437946917577538e+06, +1.346701454311101692290052e+07, + +1.782736530353274213975932e+07, +9.533095591844353613395747e+06 + }; + + /* Constants for [4.0,12.0) ------------------------------------------*/ + + const double D4 = 
+1.791759469228055000094023e+00; + + const double P4[] = { + +1.474502166059939948905062e+04, +2.426813369486704502836312e+06, + +1.214755574045093227939592e+08, +2.663432449630976949898078e+09, + +2.940378956634553899906876e+10, +1.702665737765398868392998e+11, + +4.926125793377430887588120e+11, +5.606251856223951465078242e+11 + }; + + const double Q4[] = { + +2.690530175870899333379843e+03, +6.393885654300092398984238e+05, + +4.135599930241388052042842e+07, +1.120872109616147941376570e+09, + +1.488613728678813811542398e+10, +1.016803586272438228077304e+11, + +3.417476345507377132798597e+11, +4.463158187419713286462081e+11 + }; + + /* Constants for [12.0,Infinity) -------------------------------------*/ + + const double C[] = { + -2.955065359477124231624146e-02, +6.410256410256410034009811e-03, + -1.917526917526917633674555e-03, +8.417508417508417139715760e-04, + -5.952380952380952917890600e-04, +7.936507936507936501052685e-04, + -2.777777777777777883788657e-03, +8.333333333333332870740406e-02 + }; + + /*--------------------------------------------------------------------*/ + + const double EPS = 2.22e-16; + const double P68 = 87.0 / 128.0; + const double BIG = 2.25e+76; + + /*--------------------------------------------------------------------*/ + + double p, q, y; + int i, n; + + if (x != x) /* NaN */ + return x; + else if (0 * x != 0) /* Infinity */ + return HUGE_VAL; + else if (x <= 0.0) { + q = modf(-2.0 * x, &p); + n = (int)(p); + q = sin(PI_2 * (n % 2 == 0 ? 
q : 1.0 - q)); + return log(PI / q) - lngamma(1.0 - x); + } + else if (x < EPS) + return -log(x); + else if (x < 0.5) { + p = 0.0; + q = 1.0; + y = x; + for (i = 0; i < 8; i++) { + p = p * y + P1[i]; + q = q * y + Q1[i]; + } + return x * (D1 + y * (p / q)) - log(x); + } + else if (x < P68) { + p = 0.0; + q = 1.0; + y = x - 1.0; + for (i = 0; i < 8; i++) { + p = p * y + P2[i]; + q = q * y + Q2[i]; + } + return y * (D2 + y * (p / q)) - log(x); + } + else if (x < 1.5) { + p = 0.0; + q = 1.0; + y = x - 1.0; + for (i = 0; i < 8; i++) { + p = p * y + P1[i]; + q = q * y + Q1[i]; + } + return y * (D1 + y * (p / q)); + } + else if (x < 4.0) { + p = 0.0; + q = 1.0; + y = x - 2.0; + for (i = 0; i < 8; i++) { + p = p * y + P2[i]; + q = q * y + Q2[i]; + } + return y * (D2 + y * (p / q)); + } + else if (x < 12.0) { + p = 0.0; + q = -1.0; + y = x - 4.0; + for (i = 0; i < 8; i++) { + p = p * y + P4[i]; + q = q * y + Q4[i]; + } + return D4 + y * (p / q); + } + else if (x < BIG) { + p = 0.0; + q = log(x); + y = 1.0 / (x * x); + for (i = 0; i < 8; i++) { + p = p * y + C[i]; + } + return p / x + LN_SQRT2PI - 0.5 * q + x * (q - 1.0); + } + else { + q = log(x); + return LN_SQRT2PI - 0.5 * q + x * (q - 1.0); + } + + /*--------------------------------------------------------------------*/ +} + +/** + * Computes Psi(x) = (d/dx)(ln(Gamma(x))) + */ +double digamma(double x) +{ + /* Constants for [0.5,3.0] -------------------------------------------*/ + + const double P1[] = { + +4.5104681245762934160e-03, +5.4932855833000385356e+00, + +3.7646693175929276856e+02, +7.9525490849151998065e+03, + +7.1451595818951933210e+04, +3.0655976301987365674e+05, + +6.3606997788964458797e+05, +5.8041312783537569993e+05, + +1.6585695029761022321e+05 + }; + + const double Q1[] = { + +9.6141654774222358525e+01, +2.6287715790581193330e+03, + +2.9862497022250277920e+04, +1.6206566091533671639e+05, + +4.3487880712768329037e+05, +5.4256384537269993733e+05, + +2.4242185002017985252e+05, +6.4155223783576225996e-08 + 
}; + + /* Constants for (3.0,Infinity) --------------------------------------*/ + + const double P2[] = { + -2.7103228277757834192e+00, -1.5166271776896121383e+01, + -1.9784554148719218667e+01, -8.8100958828312219821e+00, + -1.4479614616899842986e+00, -7.3689600332394549911e-02, + -6.5135387732718171306e-21 + }; + + const double Q2[] = { + +4.4992760373789365846e+01, +2.0240955312679931159e+02, + +2.4736979003315290057e+02, +1.0742543875702278326e+02, + +1.7463965060678569906e+01, +8.8427520398873480342e-01 + }; + + /*--------------------------------------------------------------------*/ + + const double MIN = 2.23e-308; + const double MAX = 4.50e+015; + const double SMALL = 5.80e-009; + const double LARGE = 2.71e+014; + + const double X01 = 187.0 / 128.0; + const double X02 = 6.9464496836234126266e-04; + + /*--------------------------------------------------------------------*/ + + double p, q, y, sgn; + int i, n; + + sgn = (x > 0.0) ? +1.0 : -1.0; + + y = fabs(x); + + if (x != x) /* NaN */ + return x; + else if (x < -MAX || y < MIN) + return -1.0 * sgn * HUGE_VAL; + else if (y < SMALL) + return digamma(1.0 - x) - 1.0 / x; + else if (x < 0.5) { + q = modf(4.0 * y, &p); + n = (int)(p); + + switch (n % 4) { + case 0: + return digamma(1.0 - x) - sgn * PI / tan(PI_4 * q); + case 1: + return digamma(1.0 - x) - sgn * PI * tan(PI_4 * (1.0 - q)); + case 2: + return digamma(1.0 - x) + sgn * PI * tan(PI_4 * q); + case 3: + return digamma(1.0 - x) + sgn * PI / tan(PI_4 * (1.0 - q)); + } + } + else if (x <= 3.0) { + p = 0.0; + q = 1.0; + for (i = 0; i < 8; i++) { + p = p * x + P1[i]; + q = q * x + Q1[i]; + } + p = p * x + P1[8]; + return p / q * ((x - X01) - X02); + } + else if (x < LARGE) { + p = 0.0; + q = 1.0; + y = 1.0 / (x * x); + for (i = 0; i < 6; i++) { + p = p * y + P2[i]; + q = q * y + Q2[i]; + } + p = p * y + P2[6]; + return p / q - 0.5 / x + log(x); + } + + return log(x); +} diff --git a/packages/prism/src/c/core/gamma.h b/packages/prism/src/c/core/gamma.h new 
file mode 100644 index 000000000..bb50ca76d --- /dev/null +++ b/packages/prism/src/c/core/gamma.h @@ -0,0 +1,7 @@ +#ifndef GAMMA_H +#define GAMMA_H + +double lngamma(double); +double digamma(double); + +#endif /* GAMMA_H */ diff --git a/packages/prism/src/c/core/glue.c b/packages/prism/src/c/core/glue.c new file mode 100644 index 000000000..e34bea2e2 --- /dev/null +++ b/packages/prism/src/c/core/glue.c @@ -0,0 +1,197 @@ +#include + +/*--------------------------------------------------------------------*/ + +#define REGISTER_CPRED(p,n) \ + do { extern int pc_ ## p ## _ ## n (void); insert_cpred("$pc_" #p, n, pc_ ## p ## _ ## n); } while (0) + +/*--------------------------------------------------------------------*/ + +typedef struct sym_rec * SYM_REC_PTR; +typedef long int TERM; +SYM_REC_PTR insert_cpred(const char *, int, int(*)(void)); +void exit(int); + +#ifdef __YAP_PROLOG__ + +int YAP_UserCpredicate(const char *s, int (*f)(void), unsigned long int n); + +SYM_REC_PTR insert_cpred(const char *s, int n, int(*f)(void)) +{ + YAP_UserCPredicate(s, f, n); + return NULL; +} + +#endif + +/*--------------------------------------------------------------------*/ + +void register_prism_errors(void); +#ifdef MPI +void mp_init(int *argc, char **argv[]); +void mp_done(void); +void mp_quit(int); +#endif + +/*--------------------------------------------------------------------*/ + +void bp4p_init(int *argc, char **argv[]) +{ +#ifdef MPI + mp_init(argc, argv); +#endif +} + +void bp4p_exit(int status) +{ +#ifdef MPI + mp_done(); +#endif + exit(status); +} + +void bp4p_quit(int status) +{ +#ifdef MPI + mp_quit(status); +#else + exit(status); +#endif +} + +void bp4p_register_preds(void) +{ + /* core/idtable.c */ + REGISTER_CPRED(prism_id_table_init,0); + REGISTER_CPRED(prism_goal_id_register,2); + REGISTER_CPRED(prism_sw_id_register,2); + REGISTER_CPRED(prism_sw_ins_id_register,2); + REGISTER_CPRED(prism_goal_id_get,2); + REGISTER_CPRED(prism_sw_id_get,2); + 
REGISTER_CPRED(prism_sw_ins_id_get,2); + REGISTER_CPRED(prism_goal_count,1); + REGISTER_CPRED(prism_sw_count,1); + REGISTER_CPRED(prism_sw_ins_count,1); + REGISTER_CPRED(prism_goal_term,2); + REGISTER_CPRED(prism_sw_term,2); + REGISTER_CPRED(prism_sw_ins_term,2); + + /* core/random.c */ + REGISTER_CPRED(random_auto_seed, 1); + REGISTER_CPRED(random_init_by_seed, 1); + REGISTER_CPRED(random_init_by_list, 1); + REGISTER_CPRED(random_float, 1); + REGISTER_CPRED(random_gaussian, 1); + REGISTER_CPRED(random_int, 2); + REGISTER_CPRED(random_int, 3); + REGISTER_CPRED(random_get_state, 1); + REGISTER_CPRED(random_set_state, 1); + + /* core/util.c */ + REGISTER_CPRED(lngamma, 2); + + /* up/em_preds.c */ + REGISTER_CPRED(prism_prepare,4); + REGISTER_CPRED(prism_em,6); + REGISTER_CPRED(prism_vbem,2); + REGISTER_CPRED(prism_both_em,2); + REGISTER_CPRED(compute_inside,2); + REGISTER_CPRED(compute_probf,1); + + /* up/viterbi.c */ + REGISTER_CPRED(compute_viterbi,5); + REGISTER_CPRED(compute_n_viterbi,3); + REGISTER_CPRED(compute_n_viterbi_rerank,4); + + /* up/hindsight.c */ + REGISTER_CPRED(compute_hindsight,4); + + /* up/graph.c */ + REGISTER_CPRED(alloc_egraph,0); + REGISTER_CPRED(clean_base_egraph,0); + REGISTER_CPRED(clean_egraph,0); + REGISTER_CPRED(export_switch,2); + REGISTER_CPRED(add_egraph_path,3); + REGISTER_CPRED(alloc_sort_egraph,1); + REGISTER_CPRED(clean_external_tables,0); + REGISTER_CPRED(export_sw_info,1); + REGISTER_CPRED(import_sorted_graph_size,1); + REGISTER_CPRED(import_sorted_graph_gid,2); + REGISTER_CPRED(import_sorted_graph_paths,2); + REGISTER_CPRED(get_gnode_inside,2); + REGISTER_CPRED(get_gnode_outside,2); + REGISTER_CPRED(get_gnode_viterbi,2); + REGISTER_CPRED(get_snode_inside,2); + REGISTER_CPRED(get_snode_expectation,2); + REGISTER_CPRED(import_occ_switches,3); + REGISTER_CPRED(import_graph_stats,4); + + /* up/flags.c */ + REGISTER_CPRED(set_daem,1); + REGISTER_CPRED(set_em_message,1); + REGISTER_CPRED(set_em_progress,1); + 
REGISTER_CPRED(set_error_on_cycle,1); + REGISTER_CPRED(set_explicit_empty_expls,1); + REGISTER_CPRED(set_fix_init_order,1); + REGISTER_CPRED(set_init_method,1); + REGISTER_CPRED(set_itemp_init,1); + REGISTER_CPRED(set_itemp_rate,1); + REGISTER_CPRED(set_log_scale,1); + REGISTER_CPRED(set_max_iterate,1); + REGISTER_CPRED(set_num_restart,1); + REGISTER_CPRED(set_prism_epsilon,1); + REGISTER_CPRED(set_show_itemp,1); + REGISTER_CPRED(set_std_ratio,1); + REGISTER_CPRED(set_verb_em,1); + REGISTER_CPRED(set_verb_graph,1); + REGISTER_CPRED(set_warn,1); + REGISTER_CPRED(set_debug_level,1); + + /* up/util.c */ + REGISTER_CPRED(mp_mode,0); + REGISTER_CPRED(get_term_depth,2); + REGISTER_CPRED(mtrace,0); + REGISTER_CPRED(muntrace,0); + REGISTER_CPRED(sleep,1); + +#ifdef MPI + /* mp/mp_preds.c */ + REGISTER_CPRED(mp_size,1); + REGISTER_CPRED(mp_rank,1); + REGISTER_CPRED(mp_master,0); + REGISTER_CPRED(mp_abort,0); + REGISTER_CPRED(mp_wtime,1); + REGISTER_CPRED(mp_sync,2); + REGISTER_CPRED(mp_send_goal,1); + REGISTER_CPRED(mp_recv_goal,1); + REGISTER_CPRED(mpm_bcast_command,1); + REGISTER_CPRED(mps_bcast_command,1); + REGISTER_CPRED(mps_revert_stdout,0); + + /* mp/mp_em_preds.c */ + REGISTER_CPRED(mpm_prism_em,6); + REGISTER_CPRED(mps_prism_em,0); + REGISTER_CPRED(mpm_prism_vbem,2); + REGISTER_CPRED(mps_prism_vbem,0); + REGISTER_CPRED(mpm_prism_both_em,2); + REGISTER_CPRED(mps_prism_both_em,0); + REGISTER_CPRED(mpm_import_graph_stats,4); + REGISTER_CPRED(mps_import_graph_stats,0); + + /* mp/mp_sw.c */ + REGISTER_CPRED(mp_send_switches,0); + REGISTER_CPRED(mp_recv_switches,0); + REGISTER_CPRED(mp_send_swlayout,0); + REGISTER_CPRED(mp_recv_swlayout,0); + REGISTER_CPRED(mpm_alloc_occ_switches,0); + + /* mp/mp_flags.c */ + REGISTER_CPRED(mpm_share_prism_flags,0); + REGISTER_CPRED(mps_share_prism_flags,0); +#endif + + /* up/error.c; FIXME: There would be a better place to call */ + register_prism_errors(); +} + +/*--------------------------------------------------------------------*/ 
diff --git a/packages/prism/src/c/core/glue.h b/packages/prism/src/c/core/glue.h new file mode 100644 index 000000000..7d4da46a8 --- /dev/null +++ b/packages/prism/src/c/core/glue.h @@ -0,0 +1,9 @@ +#ifndef GLUE_H +#define GLUE_H + +void bp4p_init(void); +void bp4p_exit(int); +void bp4p_quit(int); +void bp4p_register_preds(void); + +#endif /* GLUE_H */ diff --git a/packages/prism/src/c/core/idtable.c b/packages/prism/src/c/core/idtable.c new file mode 100644 index 000000000..ed1bac53f --- /dev/null +++ b/packages/prism/src/c/core/idtable.c @@ -0,0 +1,175 @@ +#include "core/xmalloc.h" +#include "core/vector.h" +#include "core/termpool.h" +#include "core/idtable.h" +#include "core/stuff.h" + +/*--------------------------------------------------------------------*/ + +/* table.c */ +TERM unnumberVarTerm(TERM, BPLONG_PTR, BPLONG_PTR); + +/*--------------------------------------------------------------------*/ + +struct id_table { + TERM_POOL *store; + struct id_table_entry *elems; + IDNUM *bucks; + IDNUM nbucks; +}; + +struct id_table_entry { + TERM term; + IDNUM next; +}; + +/*--------------------------------------------------------------------*/ + +static void id_table_rehash(ID_TABLE *this) +{ + IDNUM *bucks, nbucks, i, j; + + nbucks = 2 * this->nbucks + 1; + + /* find the next prime number */ + for (i = 3; i * i <= nbucks; ) { + if (nbucks % i == 0) { + nbucks += 2; + i = 3; + } + else { + i += 2; + } + } + + bucks = MALLOC(sizeof(struct hash_entry *) * nbucks); + + for (i = 0; i < nbucks; i++) + bucks[i] = ID_NONE; + + for (i = 0; i < VECTOR_SIZE(this->elems); i++) { + j = (IDNUM)((BPULONG)(this->elems[i].term) % nbucks); + this->elems[i].next = bucks[j]; + bucks[j] = i; + } + + FREE(this->bucks); + + this->nbucks = nbucks; + this->bucks = bucks; +} + +static IDNUM id_table_search(const ID_TABLE *this, TERM term) +{ + BPULONG hash; + IDNUM i; + + hash = (BPULONG)(term); + + i = this->bucks[hash % this->nbucks]; + + while (i != ID_NONE) { + if (term == 
this->elems[i].term) { + return i; + } + i = this->elems[i].next; + } + + return ID_NONE; +} + +static IDNUM id_table_insert(ID_TABLE *this, TERM term) +{ + BPULONG hash; + IDNUM n; + const char *bpx_term_2_string(TERM); + + hash = (BPULONG)(term); + + n = (IDNUM)(VECTOR_SIZE(this->elems)); + + if (n >= this->nbucks) { + id_table_rehash(this); + } + + VECTOR_PUSH_NONE(this->elems); + this->elems[n].term = term; + this->elems[n].next = this->bucks[hash % this->nbucks]; + this->bucks[hash % this->nbucks] = n; + + /* fprintf(curr_out,">> TERM: %s = %d\n",bpx_term_2_string(term),n); */ + + return n; +} + +/*--------------------------------------------------------------------*/ + +ID_TABLE * id_table_create(void) +{ + ID_TABLE *this; + IDNUM i; + + this = MALLOC(sizeof(struct id_table)); + + this->elems = NULL; + this->nbucks = 17; /* prime number */ + this->bucks = MALLOC(sizeof(IDNUM) * this->nbucks); + this->store = term_pool_create(); + + for (i = 0; i < this->nbucks; i++) + this->bucks[i] = ID_NONE; + + VECTOR_INIT(this->elems); + return this; +} + +void id_table_delete(ID_TABLE *this) +{ + VECTOR_FREE(this->elems); + FREE(this->bucks); + term_pool_delete(this->store); + + FREE(this); +} + +/*--------------------------------------------------------------------*/ + +TERM id_table_id2term(const ID_TABLE *this, IDNUM i) +{ + return this->elems[i].term; /* numbered */ +} + +IDNUM id_table_retrieve(const ID_TABLE *this, TERM term) +{ + term = term_pool_retrieve(this->store, term); + + return id_table_search(this, term); +} + +IDNUM id_table_register(ID_TABLE *this, TERM term) +{ + BPULONG hash; + IDNUM i; + + term = term_pool_register(this->store, term); + hash = (BPULONG)(term); + + i = id_table_search(this, term); + if (i == ID_NONE) { + i = id_table_insert(this, term); + } + return i; +} + +int id_table_count(const ID_TABLE *this) +{ + return (int)VECTOR_SIZE(this->elems); +} + +/*--------------------------------------------------------------------*/ + +TERM 
unnumber_var_term(TERM term) +{ + BPLONG mvn = -1; + return unnumberVarTerm(term, local_top, &mvn); +} diff --git a/packages/prism/src/c/core/idtable.h b/packages/prism/src/c/core/idtable.h new file mode 100644 index 000000000..9b0e316f9 --- /dev/null +++ b/packages/prism/src/c/core/idtable.h @@ -0,0 +1,29 @@ +#ifndef IDTABLE_H +#define IDTABLE_H + +#include "bpx.h" + +/*--------------------------------------------------------------------*/ + +#define ID_NONE ((IDNUM)(-1)) + +/*--------------------------------------------------------------------*/ + +typedef struct id_table ID_TABLE; +typedef unsigned int IDNUM; + +/*--------------------------------------------------------------------*/ + +ID_TABLE * id_table_create(void); +void id_table_delete(ID_TABLE *); +TERM id_table_id2term(const ID_TABLE *, IDNUM); +IDNUM id_table_retrieve(const ID_TABLE *, TERM); +IDNUM id_table_register(ID_TABLE *, TERM); +int id_table_count(const ID_TABLE *); + +TERM unnumber_var_term(TERM); + +/*--------------------------------------------------------------------*/ + +#endif /* IDTABLE_H */ + diff --git a/packages/prism/src/c/core/idtable_preds.c b/packages/prism/src/c/core/idtable_preds.c new file mode 100644 index 000000000..8f0c1e802 --- /dev/null +++ b/packages/prism/src/c/core/idtable_preds.c @@ -0,0 +1,249 @@ +#include +#include "core/idtable.h" + +/*--------------------------------------------------------------------*/ + +static ID_TABLE *g_table = NULL; /* goals */ +static ID_TABLE *s_table = NULL; /* switches */ +static ID_TABLE *i_table = NULL; /* switch instances */ + +/*--------------------------------------------------------------------*/ + +/* cpreds.c */ +char * bp_term_2_string(TERM); + +/* unify.c */ +int unify(TERM, TERM); + +/*--------------------------------------------------------------------*/ + +int prism_goal_id_register(TERM term) +{ + return id_table_register(g_table, term); +} + +int prism_sw_id_register(TERM term) +{ + return id_table_register(s_table, term); 
+} + +int prism_sw_ins_id_register(TERM term) +{ + return id_table_register(i_table, term); +} + +int prism_goal_id_get(TERM term) +{ + return id_table_retrieve(g_table, term); +} + +int prism_sw_id_get(TERM term) +{ + return id_table_retrieve(s_table, term); +} + +int prism_sw_ins_id_get(TERM term) +{ + return id_table_retrieve(i_table, term); +} + +int prism_goal_count(void) +{ + return id_table_count(g_table); +} + +int prism_sw_count(void) +{ + return id_table_count(s_table); +} + +int prism_sw_ins_count(void) +{ + return id_table_count(i_table); +} + +TERM prism_goal_term(IDNUM i) +{ + return id_table_id2term(g_table, i); +} + +TERM prism_sw_term(IDNUM i) +{ + return id_table_id2term(s_table, i); +} + +TERM prism_sw_ins_term(IDNUM i) +{ + return id_table_id2term(i_table, i); +} + +char * prism_goal_string(IDNUM i) +{ + return bp_term_2_string(prism_goal_term(i)); +} + +char * prism_sw_string(IDNUM i) +{ + return bp_term_2_string(prism_sw_term(i)); +} + +char * prism_sw_ins_string(IDNUM i) +{ + return bp_term_2_string(prism_sw_ins_term(i)); +} + +/* Note: the strings returned by strdup() should be released by the caller. 
*/ +char * copy_prism_goal_string(IDNUM i) +{ + return strdup(prism_goal_string(i)); +} + +char * copy_prism_sw_string(IDNUM i) +{ + return strdup(prism_sw_string(i)); +} + +char * copy_prism_sw_ins_string(IDNUM i) +{ + return strdup(prism_sw_ins_string(i)); +} + +/*--------------------------------------------------------------------*/ + +int pc_prism_id_table_init_0(void) +{ + if (g_table != NULL) id_table_delete(g_table); + if (s_table != NULL) id_table_delete(s_table); + if (i_table != NULL) id_table_delete(i_table); + + g_table = id_table_create(); + s_table = id_table_create(); + i_table = id_table_create(); + + return BP_TRUE; +} + +int pc_prism_goal_id_register_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + id = prism_goal_id_register(term); + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_sw_id_register_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + id = prism_sw_id_register(term); + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_sw_ins_id_register_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + id = prism_sw_ins_id_register(term); + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_goal_id_get_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + + id = prism_goal_id_get(term); + if (id == ID_NONE) return BP_FALSE; + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_sw_id_get_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + id = prism_sw_id_get(term); + if (id == ID_NONE) return BP_FALSE; + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_sw_ins_id_get_2(void) +{ + TERM term; + IDNUM id; + + term = ARG(1,2); + XDEREF(term); + id = prism_sw_ins_id_get(term); + if (id == ID_NONE) return BP_FALSE; + + return unify(MAKEINT(id), ARG(2,2)); +} + +int pc_prism_goal_count_1(void) +{ + return unify(MAKEINT(prism_goal_count()), ARG(1,1)); +} + +int pc_prism_sw_count_1(void) +{ + return 
unify(MAKEINT(prism_sw_count()), ARG(1,1)); +} + +int pc_prism_sw_ins_count_1(void) +{ + return unify(MAKEINT(prism_sw_ins_count()), ARG(1,1)); +} + +int pc_prism_goal_term_2(void) +{ + TERM id, term; + + id = ARG(1,2); + XDEREF(id); + term = unnumber_var_term(prism_goal_term((IDNUM)INTVAL(id))); + + return unify(term, ARG(2,2)); +} + +int pc_prism_sw_term_2(void) +{ + TERM id, term; + + id = ARG(1,2); + XDEREF(id); + + term = unnumber_var_term(prism_sw_term((IDNUM)INTVAL(id))); + + return unify(term, ARG(2,2)); +} + +int pc_prism_sw_ins_term_2(void) +{ + TERM id, term; + + id = ARG(1,2); + XDEREF(id); + term = unnumber_var_term(prism_sw_ins_term((IDNUM)INTVAL(id))); + + return unify(term, ARG(2,2)); +} diff --git a/packages/prism/src/c/core/idtable_preds.h b/packages/prism/src/c/core/idtable_preds.h new file mode 100644 index 000000000..d88109053 --- /dev/null +++ b/packages/prism/src/c/core/idtable_preds.h @@ -0,0 +1,41 @@ +#ifndef IDTABLE_AUX_H +#define IDTABLE_AUX_H + +/*--------------------------------------------------------------------*/ + +int prism_goal_id_register(TERM); +int prism_sw_id_register(TERM); +int prism_sw_ins_id_register(TERM); +int prism_goal_id_get(TERM); +int prism_sw_id_get(TERM); +int prism_sw_ins_id_get(TERM); +int prism_goal_count(void); +int prism_sw_id_count(void); +int prism_sw_ins_id_count(void); +TERM prism_goal_term(IDNUM); +TERM prism_sw_term(IDNUM); +TERM prism_sw_ins_term(IDNUM); +char * prism_goal_string(IDNUM); +char * prism_sw_string(IDNUM); +char * prism_sw_ins_string(IDNUM); +char * copy_prism_goal_string(IDNUM); +char * copy_prism_sw_string(IDNUM); +char * copy_prism_sw_ins_string(IDNUM); + +int pc_prism_id_table_init(void); +int pc_prism_goal_id_register(void); +int pc_prism_sw_id_register(void); +int pc_prism_sw_ins_id_register(void); +int pc_prism_goal_id_get(void); +int pc_prism_sw_id_get(void); +int pc_prism_sw_ins_id_get(void); +int pc_prism_goal_count(void); +int pc_prism_sw_count(void); +int 
pc_prism_sw_ins_count(void); +int pc_prism_goal_term(void); +int pc_prism_sw_term(void); +int pc_prism_sw_ins_term(void); + +/*--------------------------------------------------------------------*/ + +#endif /* IDTABLE_AUX_H */ diff --git a/packages/prism/src/c/core/random.c b/packages/prism/src/c/core/random.c new file mode 100644 index 000000000..97c30f232 --- /dev/null +++ b/packages/prism/src/c/core/random.c @@ -0,0 +1,360 @@ +/* + +This source module contains reduced (and slightly modified) version +of mt19937ar.c implemented by Makoto Matsumoto and Takuji Nishimura. +The original file is available in the following website: + + http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html + +Here is the original copyright notice. + +======================================================================== + + Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. The names of its contributors may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR + CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================== + +*/ + +/***********[ REDUCED VERSION OF MT19937AR.C STARTS HERE ]***********/ + +/* Period parameters */ +#define N 624 +#define M 397 +#define MATRIX_A 0x9908b0dfUL /* constant vector a */ +#define UPPER_MASK 0x80000000UL /* most significant w-r bits */ +#define LOWER_MASK 0x7fffffffUL /* least significant r bits */ + +static unsigned long mt[N]; /* the array for the state vector */ +static int mti=N+1; /* mti==N+1 means mt[N] is not initialized */ + +/* initializes mt[N] with a seed */ +static void init_genrand(unsigned long s) +{ + mt[0]= s & 0xffffffffUL; + for (mti=1; mti> 30)) + mti); + /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ + /* In the previous versions, MSBs of the seed affect */ + /* only MSBs of the array mt[]. */ + /* 2002/01/09 modified by Makoto Matsumoto */ + mt[mti] &= 0xffffffffUL; + /* for >32 bit machines */ + } +} + +/* initialize by an array with array-length */ +/* init_key is the array for initializing keys */ +/* key_length is its length */ +/* slight change for C++, 2004/2/26 */ +void init_by_array(unsigned long init_key[], int key_length) +{ + int i, j, k; + init_genrand(19650218UL); + i=1; + j=0; + k = (N>key_length ? 
N : key_length); + for (; k; k--) { + mt[i] = (mt[i] ^ ((mt[i-1] ^ (mt[i-1] >> 30)) * 1664525UL)) + + init_key[j] + j; /* non linear */ + mt[i] &= 0xffffffffUL; /* for WORDSIZE > 32 machines */ + i++; + j++; + if (i>=N) { + mt[0] = mt[N-1]; + i=1; + } + if (j>=key_length) j=0; + } + for (k=N-1; k; k--) { + mt[i] = (mt[i] ^ ((mt[i-1] ^ (mt[i-1] >> 30)) * 1566083941UL)) + - i; /* non linear */ + mt[i] &= 0xffffffffUL; /* for WORDSIZE > 32 machines */ + i++; + if (i>=N) { + mt[0] = mt[N-1]; + i=1; + } + } + + mt[0] = 0x80000000UL; /* MSB is 1; assuring non-zero initial array */ +} + +/* generates a random number on [0,0xffffffff]-interval */ +static unsigned long genrand_int32(void) +{ + unsigned long y; + static unsigned long mag01[2]={0x0UL, MATRIX_A}; + /* mag01[x] = x * MATRIX_A for x=0,1 */ + + if (mti >= N) { /* generate N words at one time */ + int kk; + + if (mti == N+1) /* if init_genrand() has not been called, */ + init_genrand(5489UL); /* a default initial seed is used */ + + for (kk=0;kk> 1) ^ mag01[y & 0x1UL]; + } + for (;kk> 1) ^ mag01[y & 0x1UL]; + } + y = (mt[N-1]&UPPER_MASK)|(mt[0]&LOWER_MASK); + mt[N-1] = mt[M-1] ^ (y >> 1) ^ mag01[y & 0x1UL]; + + mti = 0; + } + + y = mt[mti++]; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680UL; + y ^= (y << 15) & 0xefc60000UL; + y ^= (y >> 18); + + return y; +} + +/* generates a random number on [0,1) with 53-bit resolution */ +static double genrand_res53(void) +{ + unsigned long a=genrand_int32()>>5, b=genrand_int32()>>6; + return(a*67108864.0+b)*(1.0/9007199254740992.0); +} +/* These real versions are due to Isaku Wada, 2002/01/09 added */ + +/***********[ REDUCED VERSION OF MT19937AR.C ENDS HERE ]***********/ + +/*--------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include "core/bpx.h" +#include "core/random.h" +#include "core/vector.h" + +#ifndef M_PI +#define M_PI (3.14159265358979324) +#endif + +static int gauss_flag = 0; + 
+/*--------------------------------------------------------------------*/ + +int random_int(int n) +{ + unsigned long p, q, r; + + assert(n > 0); + + if (n == 1) { + return 0; + } + + p = 0xFFFFFFFFul - (0xFFFFFFFFul % n + 1) % n; + q = p / n + 1; + + while ((r = genrand_int32()) > p) ; + return (int)(r / q); +} + +double random_float(void) +{ + return genrand_res53(); +} + +/* Box-Muller method */ +double random_gaussian(double mu, double sigma) +{ + double u1, u2; + static double g1, g2; + + gauss_flag = !(gauss_flag); + + if (gauss_flag) { + u1 = genrand_res53(); + u2 = genrand_res53(); + g1 = sqrt(-2.0 * log(u1)) * cos(2.0 * M_PI * u2); + g2 = sqrt(-2.0 * log(u1)) * sin(2.0 * M_PI * u2); + return sigma * g1 + mu; + } + else { + return sigma * g2 + mu; + } +} + +/* N(0,1)-version: +double random_gaussian(void) +{ + double u1, u2; + static double next; + + gauss_flag = !(gauss_flag); + + if (gauss_flag) { + do { + u1 = genrand_res53(); + } + while (u1 == 0.0); + do { + u2 = genrand_res53(); + } + while (u2 == 0.0); + next = sqrt(-2.0 * log(u1)) * sin(2.0 * M_PI * u2); + return sqrt(-2.0 * log(u1)) * cos(2.0 * M_PI * u2); + } + else { + return next; + } +} +*/ + +/*--------------------------------------------------------------------*/ + +int pc_random_auto_seed_1(void) +{ + BPLONG seed = (BPLONG)(time(NULL)); + return bpx_unify(ARG(1,1), bpx_build_integer(seed)); +} + +int pc_random_init_by_seed_1(void) +{ + init_genrand((unsigned long)(bpx_get_integer(ARG(1,1)))); + return BP_TRUE; +} + +int pc_random_init_by_list_1(void) +{ + unsigned long *seed; + TERM t, u; + + VECTOR_INIT(seed); + + t = ARG(1,1); + + while (! 
bpx_is_nil(t)) { + u = bpx_get_car(t); + t = bpx_get_cdr(t); + VECTOR_PUSH(seed, (unsigned long)(bpx_get_integer(u))); + } + + init_by_array(seed, VECTOR_SIZE(seed)); + return BP_TRUE; +} + +int pc_random_float_1(void) +{ + return bpx_unify(ARG(1,1), bpx_build_float(random_float())); +} + +int pc_random_gaussian_1(void) +{ + return bpx_unify(ARG(1,1), bpx_build_float(random_gaussian(0.0,1.0))); +} + +int pc_random_int_2(void) +{ + int n_max = bpx_get_integer(ARG(1,2)); + int n_out = random_int(n_max); + return bpx_unify(ARG(2,2), bpx_build_integer((BPLONG)(n_out))); +} + +int pc_random_int_3(void) +{ + int n_min = bpx_get_integer(ARG(1,3)); + int n_max = bpx_get_integer(ARG(2,3)); + int n_out = random_int(n_max - n_min + 1) + n_min; + return bpx_unify(ARG(3,3), bpx_build_integer((BPLONG)(n_out))); +} + +/*--------------------------------------------------------------------*/ + +int pc_random_get_state_1(void) +{ + int i, j; + TERM t, u; + unsigned long temp; + + t = bpx_build_structure("$randstate", 4 * N / 3 + 1); + bpx_unify(bpx_get_arg(1, t), bpx_build_integer(mti)); + + for (i = 0; i < 4 * N / 3; i++) { + j = i / 4 * 3; + temp = 0; + + if (i % 4 > 0) { + temp |= mt[j + i % 4 - 1] << (8 * (3 - i % 4)); + } + if (i % 4 < 3) { + temp |= mt[j + i % 4 - 0] >> (8 * (1 + i % 4)); + } + + temp &= 0xFFFFFF; /* == 2^24 - 1 */ + u = bpx_get_arg(i + 2, t); + bpx_unify(u, bpx_build_integer(temp)); + } + + return bpx_unify(ARG(1,1), t); +} + +int pc_random_set_state_1(void) +{ + int i, j; + TERM term; + unsigned long temp; + + term = ARG(1,1); + + assert(strcmp(bpx_get_name(term), "$randstate") == 0); + assert(bpx_get_arity(term) == 4 * N / 3 + 1); + + mti = bpx_get_integer(bpx_get_arg(1, term)); + + for (i = 0; i < N; i++) { + j = i / 3 * 4; + mt[i] = 0; + temp = bpx_get_integer(bpx_get_arg(j + i % 3 + 2, term)); + mt[i] |= temp << (8 * (1 + i % 3)); + temp = bpx_get_integer(bpx_get_arg(j + i % 3 + 3, term)); + mt[i] |= temp >> (8 * (2 - i % 3)); + mt[i] &= 0xFFFFFFFF; + } 
+ + return BP_TRUE; +} + +/*--------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/core/random.h b/packages/prism/src/c/core/random.h new file mode 100644 index 000000000..c9ff6d13a --- /dev/null +++ b/packages/prism/src/c/core/random.h @@ -0,0 +1,14 @@ +#ifndef RANDOM_H +#define RANDOM_H + +#include + +/*--------------------------------------------------------------------*/ + +int random_int(int); +double random_float(void); +double random_gaussian(double, double); + +/*--------------------------------------------------------------------*/ + +#endif /* RANDOM_H */ diff --git a/packages/prism/src/c/core/stuff.h b/packages/prism/src/c/core/stuff.h new file mode 100644 index 000000000..365eea205 --- /dev/null +++ b/packages/prism/src/c/core/stuff.h @@ -0,0 +1,23 @@ +#ifndef STUFF_H +#define STUFF_H + +/*--------------------------------------------------------------------*/ + +typedef enum { false, true } bool; + +/*--------------------------------------------------------------------*/ + +#if defined _MSC_VER +#define NORET void __declspec(noreturn) +#define PRINTF_LIKE_FUNC(m, n) /* empty */ +#elif defined __GNUC__ +#define NORET void __attribute__((noreturn)) +#define PRINTF_LIKE_FUNC(m, n) __attribute__((format(printf, m, n))) +#else /* other */ +#define NORET void +#define PRINTF_LIKE_FUNC(m, n) /* empty */ +#endif + +/*--------------------------------------------------------------------*/ + +#endif /* STUFF_H */ diff --git a/packages/prism/src/c/core/termpool.c b/packages/prism/src/c/core/termpool.c new file mode 100644 index 000000000..565ef17ed --- /dev/null +++ b/packages/prism/src/c/core/termpool.c @@ -0,0 +1,424 @@ +#include +#include "core/termpool.h" +#include "core/xmalloc.h" +#include "core/vector.h" +#include "core/stuff.h" + +/* FIXME */ +#define prism_quit(msg) quit("*** {PRISM FATAL ERROR: " msg "}\n") +NORET quit(const char *); + 
+/*--------------------------------------------------------------------*/ + +/* [04 Apr 2009, by yuizumi] + * This value should be sufficiently large enough to have malloc(3) + * return an address with its top bit set on 32-bit Linux systems. + */ +#define BLOCK_SIZE 1048576 + +/*--------------------------------------------------------------------*/ + +/* [05 Apr 2009, by yuizumi] + * The area referred by this variable is shared by prism_hash_value() + * and term_pool_store(), under the assumption that BPLONG values and + * BPLONG_PTR values (i.e. pointers) are aligned in the same way even + * without cast operations. + */ +static BPLONG_PTR work; + +/*--------------------------------------------------------------------*/ + +struct term_pool { + BPLONG_PTR head; + BPLONG_PTR curr; + BPLONG_PTR tail; + struct hash_entry **bucks; + size_t nbucks; + size_t count; +}; + +struct hash_entry { + TERM term; + BPULONG hash; + struct hash_entry *next; +}; + +/*--------------------------------------------------------------------*/ +/* Functions from B-Prolog */ + +/* mic.c */ +void c_STATISTICS(void); + +/* table.c */ +void numberVarTermOpt(TERM); +TERM unnumberVarTerm(TERM, BPLONG_PTR, BPLONG_PTR); + +/* unify.c */ +int unifyNumberedTerms(TERM, TERM); + +/*--------------------------------------------------------------------*/ + +static ptrdiff_t trail_pos0 = 0; + +static void number_vars(TERM term) +{ + assert(trail_pos0 == 0); + + trail_pos0 = trail_up_addr - trail_top; + PRE_NUMBER_VAR(0); + numberVarTermOpt(term); + + if (number_var_exception != 0) { + prism_quit("suspension variables not supported in Prism"); + } +} + +static void revert_vars(void) +{ + BPLONG_PTR trail_top0; + + assert(trail_pos0 != 0); + + trail_top0 = trail_up_addr - trail_pos0; + UNDO_TRAILING; + trail_pos0 = 0; +} + +/* [29 Mar 2009, by yuizumi] + * See Also: "Algorithms in C, Third Edition," by Robert Sedgewick, + * Addison-Wesley, 1998. 
+ */ +static BPULONG prism_hash_value(TERM term) +{ + TERM t, *rest; + BPLONG i, n; + SYM_REC_PTR sym; + + BPULONG a = 2130563839ul; + BPULONG b = 1561772629ul; + BPULONG h = 0; + BPULONG u; + + rest = (TERM *)work; + + VECTOR_PUSH(rest, term); + + while (! VECTOR_EMPTY(rest)) { + t = VECTOR_POP(rest); + +nderef_loop: + switch (XTAG(t)) { + case REF0: + case REF1: + XNDEREF(t, nderef_loop); + assert(false); /* numbered by number_vars() */ + + case ATM: + case INT: + case NVAR: + u = (BPULONG)t; + break; + + case LST: + VECTOR_PUSH(rest, GET_CDR(t)); + VECTOR_PUSH(rest, GET_CAR(t)); + u = (BPULONG)LST; + break; + + case STR: + sym = GET_STR_SYM_REC(t); + n = GET_ARITY_STR(sym); + for (i = n; i >= 1; i--) { + VECTOR_PUSH(rest, GET_ARG(t, i)); + } + u = (BPULONG)ADDTAG(sym, STR); + break; + + case SUSP: + assert(false); /* rejected by number_vars() */ + + default: + assert(false); + } + h = (a * h) + (BPULONG)(u); + a *= b; + } + + work = (BPLONG *)rest; + return h; +} + +/*--------------------------------------------------------------------*/ + +static BPLONG_PTR term_pool_allocate(TERM_POOL *this, size_t size) +{ + BPLONG_PTR p_tmp; + + assert(size <= MAX_ARITY + 1); + + if (this->head == NULL || this->curr + size > this->tail) { + BP_MALLOC(p_tmp, BLOCK_SIZE, "(prism part)"); + *p_tmp = (BPLONG)(this->head); + this->head = p_tmp + 0; + this->curr = p_tmp + 1; + this->tail = p_tmp + BLOCK_SIZE; + } + + p_tmp = this->curr; + this->curr += size; + return p_tmp; +} + +/*--------------------------------------------------------------------*/ + +static TERM term_pool_store(TERM_POOL *this, TERM term) +{ + TERM *p, *q, **rest; + BPLONG i, n; + + SYM_REC_PTR sym; + + rest = (void *)(work); + + VECTOR_PUSH(rest, &term); + + while (! 
VECTOR_EMPTY(rest)) { + p = VECTOR_POP(rest); + +nderef_loop: + switch (XTAG(*p)) { + case REF0: + case REF1: + XNDEREF(*p, nderef_loop); + assert(false); /* numbered by number_vars() */ + + case ATM: + case INT: + case NVAR: + break; + + case LST: + q = term_pool_allocate(this, 2); + *(q + 1) = GET_CDR(*p); + VECTOR_PUSH(rest, q + 1); + *(q + 0) = GET_CAR(*p); + VECTOR_PUSH(rest, q + 0); + *p = ADDTAG(q, LST); + break; + + case STR: + sym = GET_STR_SYM_REC(*p); + n = GET_ARITY_STR(sym); + q = term_pool_allocate(this, n + 1); + *q = (TERM)(sym); + for (i = n; i >= 1; i--) { + *(q + i) = GET_ARG(*p, i); + VECTOR_PUSH(rest, q + i); + } + *p = ADDTAG(q, STR); + break; + + case SUSP: + assert(false); /* rejected by number_vars() */ + + default: + assert(false); + } + } + + work = (void *)(rest); + return term; +} + +/*--------------------------------------------------------------------*/ + +static void term_pool_rehash(TERM_POOL *this) +{ + struct hash_entry **bucks, *p, *q; + size_t nbucks, i; + + nbucks = 2 * this->nbucks + 1; + + /* find the next prime number */ + for (i = 3; i * i <= nbucks; ) { + if (nbucks % i == 0) { + nbucks += 2; + i = 3; + } + else { + i += 2; + } + } + + bucks = MALLOC(sizeof(struct hash_entry *) * nbucks); + + for (i = 0; i < nbucks; i++) + bucks[i] = NULL; + + for (i = 0; i < this->nbucks; i++) { + p = this->bucks[i]; + + while (p != NULL) { + q = p; + p = p->next; + q->next = bucks[q->hash % nbucks]; + bucks[q->hash % nbucks] = q; + } + } + + FREE(this->bucks); + + this->nbucks = nbucks; + this->bucks = bucks; +} + +/*--------------------------------------------------------------------*/ + +static TERM term_pool_search(const TERM_POOL *this, TERM term, BPULONG hash) +{ + struct hash_entry *p; + + p = this->bucks[hash % this->nbucks]; + + while (p != NULL) { + if (hash == p->hash) { + if (unifyNumberedTerms(term, p->term)) { + return p->term; + } + } + p = p->next; + } + + return NULL_TERM; +} + +static TERM term_pool_insert(TERM_POOL 
*this, TERM term, BPULONG hash) +{ + struct hash_entry *entry; + + if (++(this->count) >= this->nbucks) + term_pool_rehash(this); + + entry = MALLOC(sizeof(struct hash_entry)); + entry->term = term_pool_store(this, term); + entry->hash = hash; + entry->next = this->bucks[hash % this->nbucks]; + this->bucks[hash % this->nbucks] = entry; + + return entry->term; +} + +/*--------------------------------------------------------------------*/ + +static TERM term_pool_intern(const TERM_POOL *this1, TERM_POOL *this2, TERM term) +{ + BPULONG hash; + TERM rval; + + assert(this2 == NULL || this2 == this1); + +nderef_loop: + switch (XTAG(term)) { + case REF0: + case REF1: + XNDEREF(term, nderef_loop); + return MAKE_NVAR(0); + + case ATM: + case INT: + case NVAR: + return term; + + case LST: + case STR: + break; + + case SUSP: + prism_quit("suspension variables not supported in Prism"); + + default: + assert(false); + } + + number_vars(term); + + hash = prism_hash_value(term); + rval = term_pool_search(this1, term, hash); + + if (rval == NULL_TERM && this2 != NULL) { + rval = term_pool_insert(this2, term, hash); + } + + revert_vars(); + + return rval; +} + +/*--------------------------------------------------------------------*/ + +TERM_POOL * term_pool_create(void) +{ + TERM_POOL *this; + int i; + + this = MALLOC(sizeof(struct term_pool)); + + this->head = NULL; + this->curr = NULL; + this->tail = NULL; + this->nbucks = 17; + this->count = 0; + this->bucks = MALLOC(sizeof(struct hash_entry *) * this->nbucks); + + for (i = 0; i < this->nbucks; i++) + this->bucks[i] = NULL; + + if (work == NULL) { + VECTOR_INIT_CAPA(work, 4096); + } + + return this; +} + +/*--------------------------------------------------------------------*/ + +void term_pool_delete(TERM_POOL *this) +{ + BPLONG_PTR p1, p2; + struct hash_entry *q1, *q2; + int i; + + p1 = this->head; + + while (p1 != NULL) { + p2 = p1; + p1 = (BPLONG_PTR)(*p1); + FREE(p2); + } + + for (i = 0; i < this->nbucks; i++) { + q1 = 
this->bucks[i]; + while (q1 != NULL) { + q2 = q1; + q1 = q1->next; + FREE(q2); + } + } + + FREE(this->bucks); + FREE(this); +} + +/*--------------------------------------------------------------------*/ + +TERM term_pool_retrieve(const TERM_POOL *this, TERM term) +{ + return term_pool_intern(this, NULL, term); +} + +TERM term_pool_register(TERM_POOL *this, TERM term) +{ + return term_pool_intern(this, this, term); +} + +/*--------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/core/termpool.h b/packages/prism/src/c/core/termpool.h new file mode 100644 index 000000000..7deba3cdf --- /dev/null +++ b/packages/prism/src/c/core/termpool.h @@ -0,0 +1,20 @@ +#ifndef TERMPOOL_H +#define TERMPOOL_H + +#include "bpx.h" + +/*--------------------------------------------------------------------*/ + +typedef struct term_pool TERM_POOL; + +/*--------------------------------------------------------------------*/ + +TERM_POOL * term_pool_create(void); +void term_pool_delete(TERM_POOL *); + +TERM term_pool_retrieve(const TERM_POOL *, TERM); +TERM term_pool_register(TERM_POOL *, TERM); + +/*--------------------------------------------------------------------*/ + +#endif /* TERMPOOL_H */ diff --git a/packages/prism/src/c/core/vector.c b/packages/prism/src/c/core/vector.c new file mode 100644 index 000000000..3dad5c980 --- /dev/null +++ b/packages/prism/src/c/core/vector.c @@ -0,0 +1,87 @@ +#include "core/xmalloc.h" +#include "core/vector.h" +#include + +/*--------------------------------------------------------------------*/ + +#define INITIAL_CAPA 16 + +#undef VECTOR_SIZE +#undef VECTOR_CAPA + +/* allow these to be L-values */ +#define VECTOR_SIZE(v) (((size_t *)(v))[-1]) +#define VECTOR_CAPA(v) (((size_t *)(v))[-2]) + +/*--------------------------------------------------------------------*/ + +void * vector_create(size_t unit, size_t size, size_t capa) +{ + void *ptr, *vec; + ptr = MALLOC(sizeof(size_t) * 2 + unit * capa); + vec = 
((size_t *)(ptr)) + 2; + VECTOR_SIZE(vec) = size; + VECTOR_CAPA(vec) = capa; + return vec; +} + +void vector_delete(void *vec) +{ + free(((size_t *)(vec)) - 2); +} + +void * vector_expand(void *vec, size_t unit) +{ + size_t capa; + + if (VECTOR_SIZE(vec) >= VECTOR_CAPA(vec)) { + capa = VECTOR_CAPA(vec) * 2; + if (capa < INITIAL_CAPA) { + capa = INITIAL_CAPA; + } + vec = vector_realloc(vec, unit, capa); + } + + ++(VECTOR_SIZE(vec)); + return vec; +} + +void * vector_reduce(void *vec) +{ + assert(VECTOR_SIZE(vec) > 0); + --(VECTOR_SIZE(vec)); + return vec; +} + +void * vector_resize(void *vec, size_t unit, size_t size) +{ + vec = vector_reserve(vec, unit, size); + VECTOR_SIZE(vec) = size; + return vec; +} + +void * vector_reserve(void *vec, size_t unit, size_t capa) +{ + if (VECTOR_CAPA(vec) < capa) { + vec = vector_realloc(vec, unit, capa); + } + return vec; +} + +void * vector_realloc(void *vec, size_t unit, size_t capa) +{ + void *ptr; + + if (VECTOR_CAPA(vec) == capa) + return vec; + + assert(VECTOR_SIZE(vec) <= capa); + + ptr = ((size_t *)(vec)) - 2; + ptr = REALLOC(ptr, sizeof(size_t) * 2 + unit * capa); + vec = ((size_t *)(ptr)) + 2; + VECTOR_CAPA(vec) = capa; + return vec; +} + +/*--------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/core/vector.h b/packages/prism/src/c/core/vector.h new file mode 100644 index 000000000..7373f864f --- /dev/null +++ b/packages/prism/src/c/core/vector.h @@ -0,0 +1,59 @@ +#ifndef VECTOR_H +#define VECTOR_H + +#include "stddef.h" + +/*--------------------------------------------------------------------*/ + +#define VECTOR_INIT(v) \ + ((v) = vector_create(sizeof(*(v)), 0, 0)) +#define VECTOR_INIT_SIZE(v, n) \ + ((v) = vector_create(sizeof(*(v)), n, n)) +#define VECTOR_INIT_CAPA(v, m) \ + ((v) = vector_create(sizeof(*(v)), 0, m)) + +#define VECTOR_FREE(v) \ + ((v) = (vector_delete(v), NULL)) + +/*--------------------------------------------------------------------*/ + +#define 
VECTOR_SIZE(v) \ + ((size_t)(((const size_t *)(v))[-1])) +#define VECTOR_CAPA(v) \ + ((size_t)(((const size_t *)(v))[-2])) + +#define VECTOR_PUSH(v, x) \ + ((v) = vector_expand(v, sizeof(*(v))), (v)[VECTOR_SIZE(v) - 1] = (x)) +#define VECTOR_POP(v) \ + ((v) = vector_reduce(v), (v)[VECTOR_SIZE(v)]) + +#define VECTOR_PUSH_NONE(v) \ + ((v) = vector_expand(v, sizeof(*(v)))) + +#define VECTOR_RESIZE(v, n) \ + ((v) = vector_resize(v, sizeof(*(v)), n)) +#define VECTOR_RESERVE(v, m) \ + ((v) = vector_reserve(v, sizeof(*(v)), m)) +#define VECTOR_STRIP(v) \ + ((v) = vector_realloc(v, sizeof(*(v)), VECTOR_SIZE(v))) + +#define VECTOR_CLEAR(v) \ + ((void)(((const size_t *)(v))[-1] = 0)) +#define VECTOR_EMPTY(v) \ + (VECTOR_SIZE(v) == 0) + +/*--------------------------------------------------------------------*/ + +void * vector_create(size_t, size_t, size_t); +void vector_delete(void *); + +void * vector_expand(void *, size_t); +void * vector_reduce(void *); + +void * vector_resize(void *, size_t, size_t); +void * vector_reserve(void *, size_t, size_t); +void * vector_realloc(void *, size_t, size_t); + +/*--------------------------------------------------------------------*/ + +#endif /* VECTOR_H */ diff --git a/packages/prism/src/c/core/xmalloc.c b/packages/prism/src/c/core/xmalloc.c new file mode 100644 index 000000000..48fadd141 --- /dev/null +++ b/packages/prism/src/c/core/xmalloc.c @@ -0,0 +1,35 @@ +#include +#include +#include "core/xmalloc.h" + +/*--------------------------------------------------------------------*/ + +void * xmalloc +(size_t size, const char *file, unsigned int line) +{ + void *ptr; + ptr = malloc(size); + + if (ptr == NULL) { + fprintf(stderr, "Out of memory in %s(%u)\n", file, line); + exit(1); /* FIXME */ + } + + return ptr; +} + +void * xrealloc +(void *oldptr, size_t size, const char *file, unsigned int line) +{ + void *newptr; + newptr = realloc(oldptr, size); + + if (newptr == NULL && size > 0) { + fprintf(stderr, "Out of memory in %s(%u)\n", 
file, line); + exit(1); /* FIXME */ + } + + return newptr; +} + +/*--------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/core/xmalloc.h b/packages/prism/src/c/core/xmalloc.h new file mode 100644 index 000000000..95e0d97c8 --- /dev/null +++ b/packages/prism/src/c/core/xmalloc.h @@ -0,0 +1,25 @@ +#ifndef XMALLOC_H +#define XMALLOC_H + +#include + +/*--------------------------------------------------------------------*/ + +void * xmalloc(size_t, const char *, unsigned int); +void * xrealloc(void *, size_t, const char *, unsigned int); + +/*--------------------------------------------------------------------*/ + +#ifdef MALLOC_TRACE +# define MALLOC(size) malloc((size)) +# define REALLOC(oldptr,size) realloc((oldptr),(size)) +# define FREE(ptr) (free(ptr), (ptr) = NULL) +#else +# define MALLOC(size) xmalloc((size), __FILE__, __LINE__) +# define REALLOC(oldptr,size) xrealloc((oldptr), (size), __FILE__, __LINE__) +# define FREE(ptr) (free(ptr), (ptr) = NULL) +#endif + +/*--------------------------------------------------------------------*/ + +#endif /* XMALLOC_H */ diff --git a/packages/prism/src/c/makefiles/Makefile.files b/packages/prism/src/c/makefiles/Makefile.files new file mode 100644 index 000000000..8b99611f2 --- /dev/null +++ b/packages/prism/src/c/makefiles/Makefile.files @@ -0,0 +1,56 @@ +# -*- Makefile -*- + +##---------------------------------------------------------------------- + +CORE_OBJS = core$(S)glue.$(O) \ + core$(S)bpx.$(O) \ + core$(S)idtable.$(O) \ + core$(S)idtable_preds.$(O) \ + core$(S)termpool.$(O) \ + core$(S)vector.$(O) \ + core$(S)random.$(O) \ + core$(S)gamma.$(O) \ + core$(S)xmalloc.$(O) \ + core$(S)fputil.$(O) \ + core$(S)error.$(O) + +UP_OBJS = up$(S)graph.$(O) \ + up$(S)graph_aux.$(O) \ + up$(S)em_preds.$(O) \ + up$(S)em_ml.$(O) \ + up$(S)em_vb.$(O) \ + up$(S)em_aux.$(O) \ + up$(S)em_aux_ml.$(O) \ + up$(S)em_aux_vb.$(O) \ + up$(S)viterbi.$(O) \ + up$(S)hindsight.$(O) \ + 
up$(S)flags.$(O) \ + up$(S)util.$(O) + +MP_OBJS = mp$(S)mp_core.$(O) \ + mp$(S)mp_em_aux.$(O) \ + mp$(S)mp_em_ml.$(O) \ + mp$(S)mp_em_preds.$(O) \ + mp$(S)mp_em_vb.$(O) \ + mp$(S)mp_flags.$(O) \ + mp$(S)mp_preds.$(O) \ + mp$(S)mp_sw.$(O) + +OBJS = $(CORE_OBJS) $(UP_OBJS) + +##---------------------------------------------------------------------- + +INSTALLDIR = ..$(S)..$(S)bin + +CORE_DIR = core +UP_DIR = up +MP_DIR = mp + +SUBDIRS = $(CORE_DIR) $(UP_DIR) + +##---------------------------------------------------------------------- + +#BP4P_A = bp4prism$(S)lib$(S)bp4prism-$(PLATFORM).$(A) +BP4P_A = + +##---------------------------------------------------------------------- diff --git a/packages/prism/src/c/makefiles/README b/packages/prism/src/c/makefiles/README new file mode 100644 index 000000000..c1bcc0a0f --- /dev/null +++ b/packages/prism/src/c/makefiles/README @@ -0,0 +1,11 @@ +===================== README (src/c/makefiles) ===================== + +This directory contains the Makefiles which are included into the +Makefiles in the above directory: + + Makefile.opts.gmake ... settings for GNU make + Makefile.opts.nmake ... settings for nmake (MSVC++) + Makefile.files ... source file names + +If you would like to change the default settings, please modify +these Makefiles. 
diff --git a/packages/prism/src/c/mp/mp.h b/packages/prism/src/c/mp/mp.h new file mode 100644 index 000000000..cc297a476 --- /dev/null +++ b/packages/prism/src/c/mp/mp.h @@ -0,0 +1,21 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_H +#define MP_H + +/*-------------------------------------------------------------------------*/ + +#include + +/*-------------------------------------------------------------------------*/ + +#define TAG_GOAL_REQ (1) +#define TAG_GOAL_LEN (2) +#define TAG_GOAL_STR (3) + +#define TAG_SWITCH_REQ (4) +#define TAG_SWITCH_RES (5) + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_H */ diff --git a/packages/prism/src/c/mp/mp_core.c b/packages/prism/src/c/mp/mp_core.c new file mode 100644 index 000000000..049cf7754 --- /dev/null +++ b/packages/prism/src/c/mp/mp_core.c @@ -0,0 +1,101 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +/* [27 Aug 2007, by yuizumi] + * FIXME: mp_debug() is currently platform-dependent. + */ + +#ifdef MPI + +#include "up/up.h" +#include "mp/mp.h" +#include +#include +#include +#include +#include /* STDOUT_FILENO */ +#include + +/* Currently mpprism works only on Linux systems. 
*/ +#define DEV_NULL "/dev/null" + +/*-------------------------------------------------------------------------*/ + +int fd_dup_stdout = -1; + +int mp_size; +int mp_rank; + +/*-------------------------------------------------------------------------*/ + +static void close_stdout(void) +{ + fd_dup_stdout = dup(STDOUT_FILENO); + + if (fd_dup_stdout < 0) + return; + + if (freopen(DEV_NULL, "w", stdout) == NULL) { + close(fd_dup_stdout); + fd_dup_stdout = -1; + } +} + +/*-------------------------------------------------------------------------*/ + +void mp_init(int *argc, char **argv[]) +{ + MPI_Init(argc, argv); + + MPI_Comm_size(MPI_COMM_WORLD, &mp_size); + MPI_Comm_rank(MPI_COMM_WORLD, &mp_rank); + + if (mp_size < 2) { + printf("Two or more processes required to run mpprism.\n"); + MPI_Finalize(); + exit(1); + } + + if (mp_rank > 0) { + close_stdout(); + } +} + +void mp_done(void) +{ + MPI_Finalize(); +} + +NORET mp_quit(int status) +{ + fprintf(stderr, "The system is aborted by Rank #%d.\n", mp_rank); + MPI_Abort(MPI_COMM_WORLD, status); + exit(status); /* should not reach here */ +} + +/*-------------------------------------------------------------------------*/ + +void mp_debug(const char *fmt, ...) 
+{ +#ifdef MP_DEBUG + char str[1024]; + va_list ap; + struct timeval tv; + int s, u; + + va_start(ap, fmt); + vsnprintf(str, sizeof(str), fmt, ap); + va_end(ap); + + gettimeofday(&tv, NULL); + + s = tv.tv_sec; + u = tv.tv_usec; + + fprintf(stderr, "[RANK:%d] %02d:%02d:%02d.%03d -- %s\n", + mp_rank, (s / 3600) % 24, (s / 60) % 60, s % 60, u / 1000, str); +#endif +} + +/*-------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_core.h b/packages/prism/src/c/mp/mp_core.h new file mode 100644 index 000000000..962220dac --- /dev/null +++ b/packages/prism/src/c/mp/mp_core.h @@ -0,0 +1,19 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_CORE_H +#define MP_CORE_H + +/*-------------------------------------------------------------------------*/ + +extern int mp_size; +extern int mp_rank; +extern int fd_dup_stdout; + +/*-------------------------------------------------------------------------*/ + +void mp_debug(const char *, ...); +NORET mp_quit(int); + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_CORE_H */ diff --git a/packages/prism/src/c/mp/mp_em_aux.c b/packages/prism/src/c/mp/mp_em_aux.c new file mode 100644 index 000000000..10baa6a64 --- /dev/null +++ b/packages/prism/src/c/mp/mp_em_aux.c @@ -0,0 +1,256 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifdef MPI + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/up.h" +#include "up/em.h" +#include "up/graph.h" +#include "mp/mp.h" +#include "mp/mp_core.h" +#include "mp/mp_sw.h" +#include + +/*------------------------------------------------------------------------*/ + +int sw_msg_size = 0; +static void * sw_msg_send = NULL; +static void * sw_msg_recv = NULL; + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +NORET quit(const char *); + 
+/*------------------------------------------------------------------------*/ + +void alloc_sw_msg_buffers(void) +{ + sw_msg_send = MALLOC(sizeof(double) * sw_msg_size); + sw_msg_recv = MALLOC(sizeof(double) * sw_msg_size); +} + +void release_sw_msg_buffers(void) +{ + free(sw_msg_send); + sw_msg_send = NULL; + free(sw_msg_recv); + sw_msg_recv = NULL; +} + +/*------------------------------------------------------------------------*/ + +void mpm_bcast_fixed(void) +{ + SW_INS_PTR sw_ins_ptr; + char *meg_ptr; + int i; + + meg_ptr = sw_msg_send; + + for (i = 0; i < occ_switch_tab_size; i++) { + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + *(meg_ptr++) = (!!sw_ins_ptr->fixed) | ((!!sw_ins_ptr->fixed_h) << 1); + } + } + + MPI_Bcast(sw_msg_send, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD); + mp_debug("mpm_bcast_fixed"); +} + +void mps_bcast_fixed(void) +{ + SW_INS_PTR sw_ins_ptr; + char *meg_ptr; + int i; + + MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD); + mp_debug("mps_bcast_fixed"); + + for (i = 0; i < occ_switch_tab_size; i++) { + meg_ptr = sw_msg_recv; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->fixed = !!(*meg_ptr & 1); + sw_ins_ptr->fixed_h = !!(*meg_ptr & 2); + meg_ptr++; + } + } +} + +void mpm_bcast_inside(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + meg_ptr = sw_msg_send; + + for (i = 0; i < occ_switch_tab_size; i++) { + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + *(meg_ptr++) = sw_ins_ptr->inside; + } + } + + MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mpm_bcast_inside"); +} + +void mps_bcast_inside(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mps_bcast_inside"); + + for (i = 0; i < occ_switch_tab_size; i++) { + 
meg_ptr = sw_msg_recv; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->inside = *(meg_ptr++); + } + } +} + +void mpm_bcast_inside_h(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + meg_ptr = sw_msg_send; + + for (i = 0; i < occ_switch_tab_size; i++) { + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + *(meg_ptr++) = sw_ins_ptr->inside_h; + } + } + + MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mpm_bcast_inside_h"); +} + +void mps_bcast_inside_h(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mps_bcast_inside_h"); + + for (i = 0; i < occ_switch_tab_size; i++) { + meg_ptr = sw_msg_recv; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->inside_h = *(meg_ptr++); + } + } +} + +void mpm_bcast_smooth(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + meg_ptr = sw_msg_send; + + for (i = 0; i < occ_switch_tab_size; i++) { + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + *(meg_ptr++) = sw_ins_ptr->smooth; + } + } + + MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mpm_bcast_smooth"); +} + +void mps_bcast_smooth(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); + mp_debug("mps_bcast_smooth"); + + for (i = 0; i < occ_switch_tab_size; i++) { + meg_ptr = sw_msg_recv; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->smooth = *(meg_ptr++); + } + } +} + +/*------------------------------------------------------------------------*/ + +void clear_sw_msg_send(void) +{ + double 
*meg_ptr; + double *end_ptr; + + meg_ptr = sw_msg_send; + end_ptr = meg_ptr + sw_msg_size; + while (meg_ptr != end_ptr) { + *(meg_ptr++) = 0.0; + } +} + +void mpm_share_expectation(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + + meg_ptr = sw_msg_recv; + + for (i = 0; i < occ_switch_tab_size; i++) { + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->total_expect = *(meg_ptr++); + } + } +} + +void mps_share_expectation(void) +{ + SW_INS_PTR sw_ins_ptr; + double *meg_ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + meg_ptr = sw_msg_send; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + *(meg_ptr++) = sw_ins_ptr->total_expect; + } + } + + MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + + for (i = 0; i < occ_switch_tab_size; i++) { + meg_ptr = sw_msg_recv; + meg_ptr += occ_position[i]; + for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) { + sw_ins_ptr->total_expect = *(meg_ptr++); + } + } +} + +double mp_sum_value(double value) +{ + double g_value; + MPI_Allreduce(&value, &g_value, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + return g_value; +} + +/*------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_em_aux.h b/packages/prism/src/c/mp/mp_em_aux.h new file mode 100644 index 000000000..94992fb75 --- /dev/null +++ b/packages/prism/src/c/mp/mp_em_aux.h @@ -0,0 +1,29 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_EM_AUX_H +#define MP_EM_AUX_H + +/*-------------------------------------------------------------------------*/ + +extern int sw_msg_size; + +/*-------------------------------------------------------------------------*/ + +void 
alloc_sw_msg_buffers(void); +void release_sw_msg_buffers(void); +void mpm_bcast_fixed(void); +void mps_bcast_fixed(void); +void mpm_bcast_inside(void); +void mps_bcast_inside(void); +void mpm_bcast_inside_h(void); +void mps_bcast_inside_h(void); +void mpm_bcast_smooth(void); +void mps_bcast_smooth(void); +void clear_sw_msg_send(void); +void mpm_share_expectation(void); +void mps_share_expectation(void); +double mp_sum_value(double); + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_EM_AUX_H */ diff --git a/packages/prism/src/c/mp/mp_em_ml.c b/packages/prism/src/c/mp/mp_em_ml.c new file mode 100644 index 000000000..b6d5ab3de --- /dev/null +++ b/packages/prism/src/c/mp/mp_em_ml.c @@ -0,0 +1,265 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifdef MPI + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "core/error.h" +#include "up/up.h" +#include "up/em.h" +#include "up/em_aux.h" +#include "up/em_aux_ml.h" +#include "up/em_ml.h" +#include "up/graph.h" +#include "up/flags.h" +#include "up/util.h" +#include "mp/mp.h" +#include "mp/mp_core.h" +#include "mp/mp_em_aux.h" +#include + +/*------------------------------------------------------------------------*/ + +void mpm_share_preconds_em(int *smooth) +{ + int ivals[4]; + int ovals[4]; + + ivals[0] = sw_msg_size; + ivals[1] = 0; + ivals[2] = 0; + ivals[3] = *smooth; + + MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD); + + sw_msg_size = ovals[0]; + num_goals = ovals[1]; + failure_observed = ovals[2]; + *smooth = ovals[3]; + + mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s", + sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? 
"on" : "off"); + + alloc_sw_msg_buffers(); + mpm_bcast_fixed(); + if (*smooth) { + mpm_bcast_smooth(); + } +} + +void mps_share_preconds_em(int *smooth) +{ + int ivals[4]; + int ovals[4]; + + ivals[0] = 0; + ivals[1] = num_goals; + ivals[2] = failure_observed; + ivals[3] = 0; + + MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD); + + sw_msg_size = ovals[0]; + num_goals = ovals[1]; + failure_observed = ovals[2]; + *smooth = ovals[3]; + + mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s", + sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? "on" : "off"); + + alloc_sw_msg_buffers(); + mps_bcast_fixed(); + if (*smooth) { + mps_bcast_smooth(); + } +} + +/*------------------------------------------------------------------------*/ + +int mpm_run_em(EM_ENG_PTR emptr) +{ + int r, iterate, old_valid, converged, saved=0; + double likelihood, log_prior; + double lambda, old_lambda=0.0; + + config_em(emptr); + + for (r = 0; r < num_restart; r++) { + SHOW_PROGRESS_HEAD("#em-iters", r); + + initialize_params(); + mpm_bcast_inside(); + clear_sw_msg_send(); + + itemp = daem ? itemp_init : 1.0; + iterate = 0; + + while (1) { + if (daem) { + SHOW_PROGRESS_TEMP(itemp); + } + old_valid = 0; + + while (1) { + if (CTRLC_PRESSED) { + SHOW_PROGRESS_INTR(); + RET_ERR(err_ctrl_c_pressed); + } + + if (failure_observed) { + inside_failure = mp_sum_value(0.0); + } + + log_prior = emptr->smooth ? 
emptr->compute_log_prior() : 0.0; + lambda = mp_sum_value(log_prior); + likelihood = lambda - log_prior; + + mp_debug("local lambda = %.9f, lambda = %.9f", log_prior, lambda); + + if (verb_em) { + if (emptr->smooth) { + prism_printf("Iteration #%d:\tlog_likelihood=%.9f\tlog_prior=%.9f\tlog_post=%.9f\n", iterate, likelihood, log_prior, lambda); + } + else { + prism_printf("Iteration #%d:\tlog_likelihood=%.9f\n", iterate, likelihood); + } + } + + if (!isfinite(lambda)) { + emit_internal_error("invalid log likelihood or log post: %s (at iterateion #%d)", + isnan(lambda) ? "NaN" : "infinity", iterate); + RET_ERR(ierr_invalid_likelihood); + } + if (old_valid && old_lambda - lambda > prism_epsilon) { + emit_error("log likelihood or log post decreased [old: %.9f, new: %.9f] (at iteration #%d)", + old_lambda, lambda, iterate); + RET_ERR(err_invalid_likelihood); + } + if (itemp == 1.0 && likelihood > 0.0) { + emit_error("log likelihood greater than zero [value: %.9f] (at iteration #%d)", + likelihood, iterate); + RET_ERR(err_invalid_likelihood); + } + + converged = (old_valid && lambda - old_lambda <= prism_epsilon); + if (converged || REACHED_MAX_ITERATE(iterate)) { + break; + } + + old_lambda = lambda; + old_valid = 1; + + mpm_share_expectation(); + + SHOW_PROGRESS(iterate); + RET_ON_ERR(emptr->update_params()); + iterate++; + } + + if (itemp == 1.0) { + break; + } + itemp *= itemp_rate; + if (itemp >= 1.0) { + itemp = 1.0; + } + } + + SHOW_PROGRESS_TAIL(converged, iterate, lambda); + + if (r == 0 || lambda > emptr->lambda) { + emptr->lambda = lambda; + emptr->likelihood = likelihood; + emptr->iterate = iterate; + + saved = (r < num_restart - 1); + if (saved) { + save_params(); + } + } + } + + if (saved) { + restore_params(); + } + + emptr->bic = compute_bic(emptr->likelihood); + emptr->cs = emptr->smooth ? 
compute_cs(emptr->likelihood) : 0.0; + + return BP_TRUE; +} + +int mps_run_em(EM_ENG_PTR emptr) +{ + int r, iterate, old_valid, converged, saved=0; + double likelihood; + double lambda, old_lambda=0.0; + + config_em(emptr); + + for (r = 0; r < num_restart; r++) { + mps_bcast_inside(); + clear_sw_msg_send(); + itemp = daem ? itemp_init : 1.0; + iterate = 0; + + while (1) { + old_valid = 0; + + while (1) { + RET_ON_ERR(emptr->compute_inside()); + RET_ON_ERR(emptr->examine_inside()); + + if (failure_observed) { + inside_failure = mp_sum_value(inside_failure); + } + + likelihood = emptr->compute_likelihood(); + lambda = mp_sum_value(likelihood); + + mp_debug("local lambda = %.9f, lambda = %.9f", likelihood, lambda); + + converged = (old_valid && lambda - old_lambda <= prism_epsilon); + if (converged || REACHED_MAX_ITERATE(iterate)) { + break; + } + + old_lambda = lambda; + old_valid = 1; + + RET_ON_ERR(emptr->compute_expectation()); + mps_share_expectation(); + + RET_ON_ERR(emptr->update_params()); + iterate++; + } + + if (itemp == 1.0) { + break; + } + itemp *= itemp_rate; + if (itemp >= 1.0) { + itemp = 1.0; + } + } + + if (r == 0 || lambda > emptr->lambda) { + emptr->lambda = lambda; + saved = (r < num_restart - 1); + if (saved) { + save_params(); + } + } + } + + if (saved) { + restore_params(); + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_em_ml.h b/packages/prism/src/c/mp/mp_em_ml.h new file mode 100644 index 000000000..daa2efec2 --- /dev/null +++ b/packages/prism/src/c/mp/mp_em_ml.h @@ -0,0 +1,15 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_EM_ML_H +#define MP_EM_ML_H + +/*-------------------------------------------------------------------------*/ + +void mpm_share_preconds_em(int *); +void mps_share_preconds_em(int *); +int mpm_run_em(EM_ENG_PTR); +int mps_run_em(EM_ENG_PTR); + 
+/*-------------------------------------------------------------------------*/
+
+#endif /* MP_EM_ML_H */
diff --git a/packages/prism/src/c/mp/mp_em_preds.c b/packages/prism/src/c/mp/mp_em_preds.c
new file mode 100644
index 000000000..28d1b726a
--- /dev/null
+++ b/packages/prism/src/c/mp/mp_em_preds.c
@@ -0,0 +1,167 @@
+/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
+
+#ifdef MPI
+
+/*------------------------------------------------------------------------*/
+
+#include "bprolog.h"
+#include "up/up.h"
+#include "up/em.h"
+#include "up/em_aux.h"
+#include "up/em_aux_ml.h"
+#include "up/em_aux_vb.h"
+#include "up/graph.h"
+#include "up/flags.h"
+#include "mp/mp.h"
+#include "mp/mp_core.h"
+#include "mp/mp_em_aux.h"
+#include "mp/mp_em_ml.h"
+#include "mp/mp_em_vb.h"
+#include "mp/mp_sw.h"
+#include <mpi.h>
+
+/*------------------------------------------------------------------------*/
+
+/* mic.c (B-Prolog) */
+NORET myquit(int, const char *);
+
+/*------------------------------------------------------------------------*/
+/* Prolog predicate (arity 6), mpm_ side: runs EM and unifies iterate, lambda, likelihood, bic, cs and smooth with arguments 1..6. */
+int pc_mpm_prism_em_6(void)
+{
+    struct EM_Engine em_eng;
+
+    /* [28 Aug 2007, by yuizumi]
+     * occ_switches[] will be freed in pc_import_occ_switches/1.
+     * occ_position[] is not allocated.
+     */
+    RET_ON_ERR(check_smooth(&em_eng.smooth));
+    mpm_share_preconds_em(&em_eng.smooth);
+    RET_ON_ERR(mpm_run_em(&em_eng)); /* NOTE(review): an error return here skips the release_* calls below */
+    release_sw_msg_buffers();
+    release_num_sw_vals();
+
+    return
+        bpx_unify(bpx_get_call_arg(1,6), bpx_build_integer(em_eng.iterate)) &&
+        bpx_unify(bpx_get_call_arg(2,6), bpx_build_float(em_eng.lambda)) &&
+        bpx_unify(bpx_get_call_arg(3,6), bpx_build_float(em_eng.likelihood)) &&
+        bpx_unify(bpx_get_call_arg(4,6), bpx_build_float(em_eng.bic)) &&
+        bpx_unify(bpx_get_call_arg(5,6), bpx_build_float(em_eng.cs)) &&
+        bpx_unify(bpx_get_call_arg(6,6), bpx_build_integer(em_eng.smooth));
+}
+/* Prolog predicate (arity 0), mps_ side of EM: runs the loop, then releases the message buffers, occ_switches, num_sw_vals and occ_position. */
+int pc_mps_prism_em_0(void)
+{
+    struct EM_Engine em_eng;
+
+    mps_share_preconds_em(&em_eng.smooth);
+    RET_ON_ERR(mps_run_em(&em_eng));
+    release_sw_msg_buffers();
+    release_occ_switches();
+    release_num_sw_vals();
+    release_occ_position();
+
+    return BP_TRUE;
+}
+/* Prolog predicate (arity 2), mpm_ side of VB-EM: unifies iterate and free_energy with arguments 1..2. */
+int pc_mpm_prism_vbem_2(void)
+{
+    struct VBEM_Engine vb_eng;
+
+    RET_ON_ERR(check_smooth_vb());
+    mpm_share_preconds_vbem();
+    RET_ON_ERR(mpm_run_vbem(&vb_eng));
+    release_sw_msg_buffers();
+    release_num_sw_vals();
+
+    return
+        bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
+        bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
+}
+/* Prolog predicate (arity 0), mps_ side of VB-EM. */
+int pc_mps_prism_vbem_0(void)
+{
+    struct VBEM_Engine vb_eng;
+
+    mps_share_preconds_vbem();
+    RET_ON_ERR(mps_run_vbem(&vb_eng));
+    release_sw_msg_buffers();
+    release_occ_switches();
+    release_num_sw_vals();
+    release_occ_position();
+
+    return BP_TRUE;
+}
+/* Same as pc_mpm_prism_vbem_2, plus a get_param_means() call after the VB-EM run. */
+int pc_mpm_prism_both_em_2(void)
+{
+    struct VBEM_Engine vb_eng;
+
+    RET_ON_ERR(check_smooth_vb());
+    mpm_share_preconds_vbem();
+    RET_ON_ERR(mpm_run_vbem(&vb_eng));
+
+    get_param_means();
+
+    release_sw_msg_buffers();
+    release_num_sw_vals();
+
+    return
+        bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
+        bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
+}
+/* Same as pc_mps_prism_vbem_0, plus a get_param_means() call after the VB-EM run. */
+int pc_mps_prism_both_em_0(void)
+{
+    struct VBEM_Engine vb_eng;
+
+    mps_share_preconds_vbem();
+    RET_ON_ERR(mps_run_vbem(&vb_eng));
+
+    get_param_means();
+
+    release_sw_msg_buffers();
+    release_occ_switches();
+    release_num_sw_vals();
+    release_occ_position();
+
+    return BP_TRUE;
+}
+
+/*------------------------------------------------------------------------*/
+/* Rank 0 receives the MPI_SUM of every rank's 4 graph statistics (contributing zeros itself) and unifies them, with stats[3]/stats[0] as a float, with arguments 1..4. */
+int pc_mpm_import_graph_stats_4(void)
+{
+    int dummy[4] = { 0 };
+    int stats[4];
+    double avg_shared;
+
+    MPI_Reduce(dummy, stats, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
+    avg_shared = (stats[0] != 0) ? (double)(stats[3]) / stats[0] : 0.0; /* no subgoals: report 0.0 instead of NaN/inf */
+
+    return
+        bpx_unify(bpx_get_call_arg(1,4), bpx_build_integer(stats[0])) &&
+        bpx_unify(bpx_get_call_arg(2,4), bpx_build_integer(stats[1])) &&
+        bpx_unify(bpx_get_call_arg(3,4), bpx_build_integer(stats[2])) &&
+        bpx_unify(bpx_get_call_arg(4,4), bpx_build_float(avg_shared));
+}
+/* Contributes the local graph_stats() values to the reduction rooted at rank 0 and logs them. */
+int pc_mps_import_graph_stats_0(void)
+{
+    int dummy[4]; /* receive buffer; only significant at the root of MPI_Reduce */
+    int stats[4];
+
+    graph_stats(stats);
+    MPI_Reduce(stats, dummy, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
+
+    mp_debug("# subgoals = %d", stats[0]);
+    mp_debug("# goal nodes = %d", stats[1]);
+    mp_debug("# switch nodes = %d", stats[2]);
+    mp_debug("# sharings = %d", stats[3]);
+
+    return BP_TRUE;
+}
+
+/*------------------------------------------------------------------------*/
+
+#endif /* MPI */
diff --git a/packages/prism/src/c/mp/mp_em_preds.h b/packages/prism/src/c/mp/mp_em_preds.h
new file mode 100644
index 000000000..e09f27123
--- /dev/null
+++ b/packages/prism/src/c/mp/mp_em_preds.h
@@ -0,0 +1,19 @@
+/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
+
+#ifndef MP_EM_PREDS_H
+#define MP_EM_PREDS_H
+
+/*-------------------------------------------------------------------------*/
+
+int pc_mpm_prism_em_6(void);
+int pc_mps_prism_em_0(void);
+int pc_mpm_prism_vbem_2(void);
+int pc_mps_prism_vbem_0(void);
+int pc_mpm_prism_both_em_2(void);
+int pc_mps_prism_both_em_0(void);
+int pc_mpm_import_graph_stats_4(void);
+int pc_mps_import_graph_stats_0(void);
+
+/*-------------------------------------------------------------------------*/
+
+#endif /* MP_EM_PREDS_H */
diff --git a/packages/prism/src/c/mp/mp_em_vb.c b/packages/prism/src/c/mp/mp_em_vb.c
new file mode 100644
index 000000000..04ae09393
--- /dev/null
+++ b/packages/prism/src/c/mp/mp_em_vb.c
@@ -0,0 +1,256 @@
+/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
+
+#ifdef MPI
+
+/*------------------------------------------------------------------------*/
+
+#include "bprolog.h"
+#include "up/up.h"
+#include "up/em.h"
+#include "up/em_aux.h"
+#include "up/em_aux_vb.h"
+#include "up/em_vb.h"
+#include "up/graph.h"
+#include "up/flags.h"
+#include "up/util.h"
+#include "mp/mp.h"
+#include "mp/mp_core.h"
+#include "mp/mp_em_aux.h"
+#include <mpi.h>
+
+/*------------------------------------------------------------------------*/
+/* mpm_ side of the VB precondition exchange: contributes sw_msg_size (0 for the other slots) to an MPI_SUM all-reduce and adopts the totals. */
+void mpm_share_preconds_vbem(void)
+{
+    int ivals[3];
+    int ovals[3];
+
+    ivals[0] = sw_msg_size;
+    ivals[1] = 0;
+    ivals[2] = 0;
+
+    MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
+
+    sw_msg_size = ovals[0];
+    num_goals = ovals[1];
+    failure_observed = ovals[2];
+
+    mp_debug("msgsize=%d, #goals=%d, failure=%s",
+             sw_msg_size, num_goals, failure_observed ? "on" : "off");
+
+    alloc_sw_msg_buffers();
+    mpm_bcast_fixed();
+}
+/* mps_ side of the VB precondition exchange: contributes num_goals and failure_observed (0 for the message size) and adopts the totals. */
+void mps_share_preconds_vbem(void)
+{
+    int ivals[3];
+    int ovals[3];
+
+    ivals[0] = 0;
+    ivals[1] = num_goals;
+    ivals[2] = failure_observed;
+
+    MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
+
+    sw_msg_size = ovals[0];
+    num_goals = ovals[1];
+    failure_observed = ovals[2];
+
+    mp_debug("msgsize=%d, #goals=%d, failure=%s",
+             sw_msg_size, num_goals, failure_observed ? "on" : "off");
+
+    alloc_sw_msg_buffers();
+    mps_bcast_fixed();
+}
+
+/*------------------------------------------------------------------------*/
+/* mpm_ side of the VB-EM loop: num_restart restarts with an optional DAEM temperature schedule; fills vbptr->free_energy and vbptr->iterate. Returns BP_TRUE or leaves via RET_ERR/RET_ON_ERR. */
+int mpm_run_vbem(VBEM_ENG_PTR vbptr)
+{
+    int r, iterate, old_valid, converged, saved=0;
+    double free_energy, old_free_energy=0.0;
+    double l0, l1;
+
+    config_vbem(vbptr);
+
+    for (r = 0; r < num_restart; r++) {
+        SHOW_PROGRESS_HEAD("#vbem-iters", r);
+
+        initialize_hyperparams();
+        mpm_bcast_inside_h();
+        mpm_bcast_smooth();
+        clear_sw_msg_send();
+
+        itemp = daem ? itemp_init : 1.0; /* without DAEM the temperature loop below runs exactly once */
+        iterate = 0;
+
+        while (1) {
+            if (daem) {
+                SHOW_PROGRESS_TEMP(itemp);
+            }
+            old_valid = 0; /* no previous free energy to compare against at a new temperature */
+
+            while (1) {
+                if (CTRLC_PRESSED) {
+                    SHOW_PROGRESS_INTR();
+                    RET_ERR(err_ctrl_c_pressed);
+                }
+
+                RET_ON_ERR(vbptr->compute_pi());
+
+                if (failure_observed) {
+                    inside_failure = mp_sum_value(0.0);
+                }
+
+                l0 = vbptr->compute_free_energy_l0();
+                l1 = vbptr->compute_free_energy_l1();
+                free_energy = mp_sum_value(l0 - l1); /* presumably a sum over all ranks -- mp_sum_value is defined elsewhere */
+
+                mp_debug("local free_energy = %.9f, free_energy = %.9f", l0 - l1, free_energy);
+
+                if (verb_em) {
+                    prism_printf("Iteration #%d:\tfree_energy=%.9f\n", iterate, free_energy);
+                }
+
+                if (!isfinite(free_energy)) {
+                    emit_internal_error("invalid variational free energy: %s (at iteration #%d)",
+                                        isnan(free_energy) ? "NaN" : "infinity", iterate);
+                    RET_ERR(err_invalid_free_energy);
+                }
+                if (old_valid && old_free_energy - free_energy > prism_epsilon) {
+                    emit_error("variational free energy decreased [old: %.9f, new: %.9f] (at iteration #%d)",
+                               old_free_energy, free_energy, iterate);
+                    RET_ERR(err_invalid_free_energy);
+                }
+                if (itemp == 1.0 && free_energy > 0.0) {
+                    emit_error("variational free energy greater than zero [value: %.9f] (at iteration #%d)",
+                               free_energy, iterate);
+                    RET_ERR(err_invalid_free_energy);
+                }
+
+                converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
+                if (converged || REACHED_MAX_ITERATE(iterate)) {
+                    break;
+                }
+
+                old_free_energy = free_energy;
+                old_valid = 1;
+
+                mpm_share_expectation();
+
+                SHOW_PROGRESS(iterate);
+                RET_ON_ERR(vbptr->update_hyperparams());
+
+                iterate++;
+            }
+
+            if (itemp == 1.0) {
+                break;
+            }
+            itemp *= itemp_rate;
+            if (itemp >= 1.0) {
+                itemp = 1.0; /* clamp so the final pass runs at exactly 1.0 */
+            }
+        }
+
+        SHOW_PROGRESS_TAIL(converged, iterate, free_energy);
+
+        if (r == 0 || free_energy > vbptr->free_energy) {
+            vbptr->free_energy = free_energy;
+            vbptr->iterate = iterate;
+
+            saved = (r < num_restart - 1); /* the last restart's hyperparameters are already in place, so no snapshot is needed */
+            if (saved) {
+                save_hyperparams();
+            }
+        }
+    }
+
+    if (saved) {
+        restore_hyperparams();
+    }
+
+    transfer_hyperparams();
+
+    return BP_TRUE;
+}
+/* mps_ side of the VB-EM loop: same skeleton as mpm_run_vbem, but computes pi, inside values and expectations locally and contributes compute_likelihood()/itemp to the summed free energy. */
+int mps_run_vbem(VBEM_ENG_PTR vbptr)
+{
+    int r, iterate, old_valid, converged, saved=0;
+    double free_energy, old_free_energy=0.0;
+    double l2;
+
+    config_vbem(vbptr);
+
+    for (r = 0; r < num_restart; r++) {
+        mps_bcast_inside_h();
+        mps_bcast_smooth();
+        clear_sw_msg_send();
+
+        itemp = daem ? itemp_init : 1.0;
+        iterate = 0;
+
+        while (1) {
+            old_valid = 0;
+
+            while (1) {
+                RET_ON_ERR(vbptr->compute_pi());
+                RET_ON_ERR(vbptr->compute_inside());
+                RET_ON_ERR(vbptr->examine_inside());
+
+                if (failure_observed) {
+                    inside_failure = mp_sum_value(inside_failure);
+                }
+
+                l2 = vbptr->compute_likelihood() / itemp;
+                free_energy = mp_sum_value(l2);
+
+                mp_debug("local free_energy = %.9f, free_energy = %.9f", l2, free_energy);
+
+                converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
+                if (converged || REACHED_MAX_ITERATE(iterate)) {
+                    break;
+                }
+
+                old_free_energy = free_energy;
+                old_valid = 1;
+
+                RET_ON_ERR(vbptr->compute_expectation());
+                mps_share_expectation();
+
+                RET_ON_ERR(vbptr->update_hyperparams());
+                iterate++;
+            }
+
+            if (itemp == 1.0) {
+                break;
+            }
+            itemp *= itemp_rate;
+            if (itemp >= 1.0) {
+                itemp = 1.0;
+            }
+        }
+
+        if (r == 0 || free_energy > vbptr->free_energy) {
+            vbptr->free_energy = free_energy;
+            saved = (r < num_restart - 1);
+            if (saved) {
+                save_hyperparams();
+            }
+        }
+    }
+
+    if (saved) {
+        restore_hyperparams();
+    }
+
+    transfer_hyperparams();
+
+    return BP_TRUE;
+}
+
+/*------------------------------------------------------------------------*/
+
+#endif /* MPI */
diff --git a/packages/prism/src/c/mp/mp_em_vb.h b/packages/prism/src/c/mp/mp_em_vb.h
new file mode 100644
index 000000000..ea616829f
--- /dev/null
+++ b/packages/prism/src/c/mp/mp_em_vb.h
@@ -0,0 +1,15 @@
+/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
+
+#ifndef MP_EM_VB_H
+#define MP_EM_VB_H
+
+/*-------------------------------------------------------------------------*/
+
+void mpm_share_preconds_vbem(void);
+void mps_share_preconds_vbem(void);
+int mpm_run_vbem(VBEM_ENG_PTR);
+int mps_run_vbem(VBEM_ENG_PTR);
+
+/*-------------------------------------------------------------------------*/
+
+#endif /* MP_EM_VB_H */
diff --git a/packages/prism/src/c/mp/mp_flags.c b/packages/prism/src/c/mp/mp_flags.c
new file mode 100644
index 000000000..ba126f1c3
--- /dev/null +++ b/packages/prism/src/c/mp/mp_flags.c @@ -0,0 +1,77 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifdef MPI + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/flags.h" +#include + +/*------------------------------------------------------------------------*/ + +#define PUT(msg,pos,type,value) \ + MPI_Pack(&(value),1,(type),(msg),sizeof(msg),&(pos),MPI_COMM_WORLD) + +#define GET(msg,pos,type,value) \ + MPI_Unpack((msg),sizeof(msg),&(pos),&(value),1,(type),MPI_COMM_WORLD) + +/*------------------------------------------------------------------------*/ + +int pc_mpm_share_prism_flags_0(void) +{ + char msg[256]; + int pos = 0; + + PUT( msg , pos , MPI_INT , daem ); + PUT( msg , pos , MPI_INT , em_message ); + PUT( msg , pos , MPI_INT , em_progress ); + PUT( msg , pos , MPI_INT , error_on_cycle ); + PUT( msg , pos , MPI_INT , fix_init_order ); + PUT( msg , pos , MPI_INT , init_method ); + PUT( msg , pos , MPI_DOUBLE , itemp_init ); + PUT( msg , pos , MPI_DOUBLE , itemp_rate ); + PUT( msg , pos , MPI_INT , log_scale ); + PUT( msg , pos , MPI_INT , max_iterate ); + PUT( msg , pos , MPI_INT , num_restart ); + PUT( msg , pos , MPI_DOUBLE , prism_epsilon ); + PUT( msg , pos , MPI_DOUBLE , std_ratio ); + PUT( msg , pos , MPI_INT , verb_em ); + PUT( msg , pos , MPI_INT , verb_graph ); + PUT( msg , pos , MPI_INT , warn ); + + MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD); + + return BP_TRUE; +} + +int pc_mps_share_prism_flags_0(void) +{ + char msg[256]; + int pos = 0; + + MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD); + + GET( msg , pos , MPI_INT , daem ); + GET( msg , pos , MPI_INT , em_message ); + GET( msg , pos , MPI_INT , em_progress ); + GET( msg , pos , MPI_INT , error_on_cycle ); + GET( msg , pos , MPI_INT , fix_init_order ); + GET( msg , pos , MPI_INT , init_method ); + GET( msg , pos , MPI_DOUBLE , itemp_init ); + GET( msg , pos , MPI_DOUBLE , 
itemp_rate ); + GET( msg , pos , MPI_INT , log_scale ); + GET( msg , pos , MPI_INT , max_iterate ); + GET( msg , pos , MPI_INT , num_restart ); + GET( msg , pos , MPI_DOUBLE , prism_epsilon ); + GET( msg , pos , MPI_DOUBLE , std_ratio ); + GET( msg , pos , MPI_INT , verb_em ); + GET( msg , pos , MPI_INT , verb_graph ); + GET( msg , pos , MPI_INT , warn ); + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_flags.h b/packages/prism/src/c/mp/mp_flags.h new file mode 100644 index 000000000..bc819b209 --- /dev/null +++ b/packages/prism/src/c/mp/mp_flags.h @@ -0,0 +1,13 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_FLAGS_H +#define MP_FLAGS_H + +/*-------------------------------------------------------------------------*/ + +int pc_mpm_share_prism_flags_0(void); +int pc_mps_share_prism_flags_0(void); + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_FLAGS_H */ diff --git a/packages/prism/src/c/mp/mp_preds.c b/packages/prism/src/c/mp/mp_preds.c new file mode 100644 index 000000000..75178e05f --- /dev/null +++ b/packages/prism/src/c/mp/mp_preds.c @@ -0,0 +1,191 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifdef MPI + +#include "bprolog.h" +#include "core/error.h" +#include "up/up.h" +#include "mp/mp.h" +#include "mp/mp_core.h" +#include /* STDOUT_FILENO */ +#include +#include + +/*------------------------------------------------------------------------*/ + +/* cpred.c (B-Prolog) */ +int bp_string_2_term(const char *, TERM, TERM); + +/*------------------------------------------------------------------------*/ + +static char str_prealloc[65536]; + +/*------------------------------------------------------------------------*/ + +static int send_term(TERM arg, int mode, int rank) +{ + char *str; + int len; + + str = (char *)bpx_term_2_string(arg); + len = strlen(str); + + switch (mode) { + 
case 0: + MPI_Send (&len, 1 , MPI_INT , rank, TAG_GOAL_LEN, MPI_COMM_WORLD); + MPI_Send ( str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD); + break; + case 1: + MPI_Bcast(&len, 1 , MPI_INT , rank, MPI_COMM_WORLD); + MPI_Bcast( str, len, MPI_CHAR, rank, MPI_COMM_WORLD); + break; + } + + mp_debug("SEND(%d,%d): %s", mode, rank, str); + + return BP_TRUE; +} + +static int recv_term(TERM arg, int mode, int rank) +{ + char *str; + TERM op1, op2; + int len, res; + + switch (mode) { + case 0: + MPI_Recv (&len, 1, MPI_INT, rank, TAG_GOAL_LEN, MPI_COMM_WORLD, NULL); + break; + case 1: + MPI_Bcast(&len, 1, MPI_INT, rank, MPI_COMM_WORLD); + break; + } + + if (len < sizeof(str_prealloc)) + str = str_prealloc; + else { + str = MALLOC(len + 1); + } + + switch (mode) { + case 0: + MPI_Recv (str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD, NULL); + break; + case 1: + MPI_Bcast(str, len, MPI_CHAR, rank, MPI_COMM_WORLD); + break; + } + + *(str + len) = '\0'; + + mp_debug("RECV(%d,%d): %s", mode, rank, str); + + op1 = bpx_build_var(); + op2 = bpx_build_var(); + + res = bp_string_2_term(str,op1,op2); + if (str != str_prealloc) { + free(str); + } + if (res == BP_TRUE) { + return bpx_unify(arg, op1); + } + return res; +} + +/*------------------------------------------------------------------------*/ + +int pc_mp_size_1(void) +{ + return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_size)); +} + +int pc_mp_rank_1(void) +{ + return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_rank)); +} + +int pc_mp_master_0(void) +{ + return (mp_rank == 0) ? 
BP_TRUE : BP_FALSE; +} + +int pc_mp_abort_0(void) +{ + mp_quit(0); +} + +int pc_mp_wtime_1(void) +{ + return bpx_unify(bpx_get_call_arg(1,1), bpx_build_float(MPI_Wtime())); +} + +int pc_mp_sync_2(void) +{ + int args[2], amin[2], amax[2]; + + args[0] = bpx_get_integer(bpx_get_call_arg(1,2)); /* tag */ + args[1] = bpx_get_integer(bpx_get_call_arg(2,2)); /* sync-id */ + + mp_debug("SYNC(%d,%d): BGN", args[0], args[1]); + + MPI_Allreduce(args, amin, 2, MPI_INT, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(args, amax, 2, MPI_INT, MPI_MAX, MPI_COMM_WORLD); + + if (amin[0] != amax[0]) { + emit_internal_error("failure on sync (%d,%d)", args[0], args[1]); + RET_INTERNAL_ERR; + } + + if (amin[1] < 0) { + return BP_FALSE; + } + + if (amin[1] != amax[1]) { + emit_internal_error("failure on sync (%d,%d)", args[0], args[1]); + RET_INTERNAL_ERR; + } + + mp_debug("SYNC(%d,%d): END", args[0], args[1]); + + return BP_TRUE; +} + +int pc_mp_send_goal_1(void) +{ + MPI_Status status; + + MPI_Recv(NULL, 0, MPI_INT, MPI_ANY_SOURCE, TAG_GOAL_REQ, MPI_COMM_WORLD, &status); + return send_term(bpx_get_call_arg(1,1), 0, status.MPI_SOURCE); +} + +int pc_mp_recv_goal_1(void) +{ + MPI_Send(NULL, 0, MPI_INT, 0, TAG_GOAL_REQ, MPI_COMM_WORLD); + return recv_term(bpx_get_call_arg(1,1), 0, 0); +} + +int pc_mpm_bcast_command_1(void) +{ + return send_term(bpx_get_call_arg(1,1), 1, 0); +} + +int pc_mps_bcast_command_1(void) +{ + return recv_term(bpx_get_call_arg(1,1), 1, 0); +} + +int pc_mps_revert_stdout_0(void) +{ + if (fd_dup_stdout >= 0) { + dup2(fd_dup_stdout, STDOUT_FILENO); + close(fd_dup_stdout); + fd_dup_stdout = -1; + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_preds.h b/packages/prism/src/c/mp/mp_preds.h new file mode 100644 index 000000000..9535d6a07 --- /dev/null +++ b/packages/prism/src/c/mp/mp_preds.h @@ -0,0 +1,22 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + 
+#ifndef MP_PREDS_H +#define MP_PREDS_H + +/*-------------------------------------------------------------------------*/ + +int pc_mp_size_1(void); +int pc_mp_rank_1(void); +int pc_mp_master_0(void); +int pc_mp_abort_0(void); +int pc_mp_wtime_1(void); +int pc_mp_sync_2(void); +int pc_mp_send_goal_1(void); +int pc_mp_recv_goal_1(void); +int pc_mpm_bcast_command_1(void); +int pc_mps_bcast_command_1(void); +int pc_mps_revert_stdout_0(void); + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_PREDS_H */ diff --git a/packages/prism/src/c/mp/mp_sw.c b/packages/prism/src/c/mp/mp_sw.c new file mode 100644 index 000000000..dad3c3a01 --- /dev/null +++ b/packages/prism/src/c/mp/mp_sw.c @@ -0,0 +1,206 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifdef MPI + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "core/idtable.h" +#include "core/idtable_preds.h" +#include "up/up.h" +#include "up/em_aux.h" +#include "up/graph.h" +#include "up/flags.h" +#include "mp/mp.h" +#include "mp/mp_core.h" +#include "mp/mp_em_aux.h" +#include +#include +#include + +/*------------------------------------------------------------------------*/ + +int *occ_position = NULL; +static int * sizes = NULL; +static int ** swids = NULL; + +#define L(i) (sizes[i * 2 + 0]) /* length of the message from RANK #i */ +#define N(i) (sizes[i * 2 + 1]) /* number of switches in RANK #i*/ + +/*------------------------------------------------------------------------*/ + +/* cpred.c (B-Prolog) */ +int bp_string_2_term(const char *, TERM, TERM); + +/* mic.c (B-Prolog) */ +NORET quit(const char *); + +/*------------------------------------------------------------------------*/ + +static void parse_switch_req(const char *msg, int src) +{ + const char *p; + TERM op1, op2; + int i; + + swids[src] = MALLOC(sizeof(int) * N(src)); + + p = msg; + + for (i = 0; i < N(src); i++) { + op1 = bpx_build_var(); + op2 = 
bpx_build_var(); + bp_string_2_term(p, op1, op2); + swids[src][i] = prism_sw_id_register(op1); + while (*(p++) != '\0') ; + } +} + +/*------------------------------------------------------------------------*/ + +int pc_mp_send_switches_0(void) +{ + char *msg, *str; + TERM msw; + int msglen, msgsiz; + int vals[2]; + int i, n; + + msglen = 0; + msgsiz = 65536; + msg = MALLOC(msgsiz); + + for (i = 0; i < occ_switch_tab_size; i++) { + msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id)); + str = (char *)bpx_term_2_string(msw); + + n = strlen(str) + 1; + + if (msgsiz <= msglen + n) { + msgsiz = (msglen + n + 65536) & ~65535; + msg = REALLOC(msg, msgsiz); + } + + strcpy(msg + msglen, str); + msglen += n; + } + + msg[msglen++] = '\0'; /* this is safe */ + + vals[0] = msglen; + vals[1] = occ_switch_tab_size; + + MPI_Gather(vals, 2, MPI_INT, NULL, 0, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Send(msg, msglen, MPI_CHAR, 0, TAG_SWITCH_REQ, MPI_COMM_WORLD); + + free(msg); + + return BP_TRUE; +} + +int pc_mp_recv_switches_0(void) +{ + int i, lmax, vals[2]; + char *msg; + + sizes = MALLOC(sizeof(int) * 2 * mp_size); + swids = MALLOC(sizeof(int *) * mp_size); + + MPI_Gather(vals, 2, MPI_INT, sizes, 2, MPI_INT, 0, MPI_COMM_WORLD); + + lmax = 0; + + for (i = 1; i < mp_size; i++) { + if (lmax < L(i)) { + lmax = L(i); + } + } + + msg = MALLOC(lmax); + + for (i = 1; i < mp_size; i++) { + MPI_Recv(msg, L(i), MPI_CHAR, i, TAG_SWITCH_REQ, MPI_COMM_WORLD, NULL); + parse_switch_req(msg, i); + } + + free(msg); + + return BP_TRUE; +} + +int pc_mp_send_swlayout_0(void) +{ + int i, j, *msg, *pos; + + msg = MALLOC(sizeof(int) * sw_tab_size); + pos = MALLOC(sizeof(int) * sw_ins_tab_size); + + j = 0; + + for (i = 0; i < occ_switch_tab_size; i++) { + pos[occ_switches[i]->id] = j; + j += num_sw_vals[i]; + } + + sw_msg_size = j; + + for (i = 1; i < mp_size; i++) { + for (j = 0; j < N(i); j++) { + msg[j] = pos[switches[swids[i][j]]->id]; + } + + MPI_Send(msg, N(i), MPI_INT, i, TAG_SWITCH_RES, 
MPI_COMM_WORLD); + free(swids[i]); + } + + free(pos); + free(msg); + + free(sizes); + free(swids); + + return BP_TRUE; +} + +int pc_mp_recv_swlayout_0(void) +{ + occ_position = MALLOC(sizeof(int) * occ_switch_tab_size); + + MPI_Recv(occ_position, occ_switch_tab_size, MPI_INT, 0, TAG_SWITCH_RES, MPI_COMM_WORLD, NULL); + + /* debug */ + { + int i; + TERM msw; + for (i = 0; i < occ_switch_tab_size; i++) { + msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id)); + mp_debug("%s -> %d", bpx_term_2_string(msw), occ_position[i]); + } + } + + return BP_TRUE; +} + +int pc_mpm_alloc_occ_switches_0(void) +{ + occ_switches = MALLOC(sizeof(SW_INS_PTR) * sw_tab_size); + + occ_switch_tab_size = sw_tab_size; + memcpy(occ_switches, switches, sizeof(SW_INS_PTR) * sw_tab_size); + if (fix_init_order) { + sort_occ_switches(); + } + alloc_num_sw_vals(); + + return BP_TRUE; +} + +void release_occ_position(void) +{ + free(occ_position); + occ_position = NULL; +} + +/*------------------------------------------------------------------------*/ + +#endif /* MPI */ diff --git a/packages/prism/src/c/mp/mp_sw.h b/packages/prism/src/c/mp/mp_sw.h new file mode 100644 index 000000000..d57930f67 --- /dev/null +++ b/packages/prism/src/c/mp/mp_sw.h @@ -0,0 +1,22 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef MP_SW_H +#define MP_SW_H + +/*-------------------------------------------------------------------------*/ + +extern int *occ_position; + +/*-------------------------------------------------------------------------*/ + +int pc_mp_send_switches_0(void); +int pc_mp_recv_switches_0(void); +int pc_mp_send_swlayout_0(void); +int pc_mp_recv_swlayout_0(void); +int pc_mpm_alloc_occ_switches_0(void); + +void release_occ_position(void); + +/*-------------------------------------------------------------------------*/ + +#endif /* MP_SW_H */ diff --git a/packages/prism/src/c/up/em.h b/packages/prism/src/c/up/em.h new file mode 100644 index 000000000..f63fee8ea --- /dev/null +++ 
b/packages/prism/src/c/up/em.h @@ -0,0 +1,106 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +#ifndef __EM_H__ +#define __EM_H__ + +/*------------------------------------------------------------------------*/ + +#define DEFAULT_MAX_ITERATE (10000) + +/*------------------------------------------------------------------------*/ + +struct EM_Engine { + int smooth; /* [in ] flag: use MAP? */ + double lambda; /* [out] log post */ + double likelihood; /* [out] log likelihood */ + int iterate; /* [out] number of iterations */ + double bic; /* [out] BIC score */ + double cs; /* [out] CS score */ + + /* Functions called during computation. */ + int (* compute_inside )(void); + int (* examine_inside )(void); + int (* compute_expectation )(void); + double (* compute_likelihood )(void); + double (* compute_log_prior )(void); + int (* update_params )(void); +}; + +struct VBEM_Engine { + double free_energy; /* [out] free energy */ + int iterate; /* [out] number of iterations */ + + /* Functions called during computation. 
*/ + int (* compute_pi )(void); + int (* compute_inside )(void); + int (* examine_inside )(void); + int (* compute_expectation )(void); + double (* compute_free_energy_l0 )(void); + double (* compute_free_energy_l1 )(void); + double (* compute_likelihood )(void); + int (* update_hyperparams )(void); +}; + +typedef struct EM_Engine * EM_ENG_PTR; +typedef struct VBEM_Engine * VBEM_ENG_PTR; + +/*------------------------------------------------------------------------*/ + +#define SHOW_PROGRESS(n) \ + do { \ + if(!verb_em && em_message > 0 && (n) % em_progress == 0) { \ + if((n) % (em_progress * 10) == 0) \ + prism_printf("%d", n); \ + else \ + prism_printf("."); \ + } \ + } while (0) + +#define SHOW_PROGRESS_HEAD(str, r) \ + do { \ + if(num_restart > 1) { \ + if(verb_em) \ + prism_printf("<<<< RESTART #%d >>>>\n", r); \ + else if(em_message > 0) \ + prism_printf("[%d] ", r); \ + } \ + if(!verb_em && em_message > 0) \ + prism_printf("%s: ", str); \ + } while (0) + +#define SHOW_PROGRESS_TAIL(converged, n, x) \ + do { \ + const char *str = \ + converged ? 
"Converged" : "Stopped"; \ + \ + if(verb_em) \ + prism_printf("* %s (%.9f)\n", str, x); \ + else if(em_message > 0) \ + prism_printf("(%d) (%s: %.9f)\n", n, str, x); \ + } while (0) + +#define SHOW_PROGRESS_TEMP(x) \ + do { \ + if(verb_em) \ + prism_printf("* Temperature = %.3f\n", x); \ + else if(em_message > 0 && show_itemp) \ + prism_printf("<%.3f>", x); \ + else \ + prism_printf("*"); \ + } while (0) + +#define SHOW_PROGRESS_INTR() \ + do { \ + if(verb_em) \ + prism_printf("* Interrupted\n"); \ + else if(em_message > 0) \ + prism_printf("(Interrupted)\n"); \ + } while (0) + +#define REACHED_MAX_ITERATE(n) \ + ((max_iterate == -1 && (n) >= DEFAULT_MAX_ITERATE) || \ + (max_iterate >= +1 && (n) >= max_iterate)) + +/*------------------------------------------------------------------------*/ + +#endif /* __EM_H__ */ diff --git a/packages/prism/src/c/up/em_aux.c b/packages/prism/src/c/up/em_aux.c new file mode 100644 index 000000000..efff6f1d5 --- /dev/null +++ b/packages/prism/src/c/up/em_aux.c @@ -0,0 +1,151 @@ +/* -*- c-basic-offset: 2; tab-width: 8 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/up.h" +#include "up/graph.h" +#include "up/flags.h" + +/*------------------------------------------------------------------------*/ + +int * num_sw_vals = NULL; +double itemp; +double inside_failure; +int failure_observed; + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +int compare(TERM,TERM); +void quit(const char *); + +/*------------------------------------------------------------------------*/ + +/* for sort_occ_switches() */ +static int compare_sw_ins(const void *a, const void *b) +{ + SW_INS_PTR sw_ins_a, sw_ins_b; + TERM msw_a, msw_b; + + sw_ins_a = *(const SW_INS_PTR *)(a); + sw_ins_b = *(const SW_INS_PTR *)(b); + + msw_a = prism_sw_ins_term(sw_ins_a->id); + msw_b = prism_sw_ins_term(sw_ins_b->id); + + return 
compare(bpx_get_arg(1,msw_a), bpx_get_arg(1,msw_b)); +} + +/*------------------------------------------------------------------------*/ + +/* Set flags of switches appearing in the e-graphs and allocate an array + * of pointers to such switches (This routine is based on compute_inside()). + */ +void alloc_occ_switches(void) +{ + int i,j,k; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + SW_INS_PTR sw_ins_ptr; + int *occ_sw_flags; + int b; + + /* Initialize the `occ' counters in switch instances */ + for (i = 0; i < sw_ins_tab_size; i++) { + switch_instances[i]->occ = 0; + } + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + path_ptr = eg_ptr->path_ptr; + while (path_ptr != NULL) { + for (k = 0; k < path_ptr->sws_len; k++) { + path_ptr->sws[k]->occ = 1; + } + path_ptr = path_ptr->next; + } + } + + /* Temporarily make an array of flags each of which indicates whether + a switch (not switch instance) occurs in the e-graphs */ + occ_sw_flags = (int *)MALLOC(sizeof(int) * sw_tab_size); + occ_switch_tab_size = 0; + for (i = 0; i < sw_tab_size; i++) { + sw_ins_ptr = switches[i]; + b = 0; + while (sw_ins_ptr != NULL) { + b |= sw_ins_ptr->occ; + sw_ins_ptr = sw_ins_ptr->next; + } + occ_sw_flags[i] = b; + if (b) occ_switch_tab_size++; + } + + occ_switches = + (SW_INS_PTR *)MALLOC(sizeof(SW_INS_PTR) * occ_switch_tab_size); + + j = 0; + for (i = 0; i < sw_tab_size; i++) { + if (occ_sw_flags[i]) { + occ_switches[j] = switches[i]; /* Copy */ + j++; + } + } + + free(occ_sw_flags); +} + +void sort_occ_switches(void) +{ + qsort(occ_switches,occ_switch_tab_size,sizeof(SW_INS_PTR),compare_sw_ins); +} + +void release_occ_switches(void) +{ + free(occ_switches); + occ_switches = NULL; +} + +void alloc_num_sw_vals(void) +{ + int i,n; + SW_INS_PTR sw_ins_ptr; + + num_sw_vals = (int *)MALLOC(sizeof(int) * occ_switch_tab_size); + + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + n = 0; + while (sw_ins_ptr != NULL) { + n++; + 
sw_ins_ptr = sw_ins_ptr->next; + } + num_sw_vals[i] = n; + } +} + +void release_num_sw_vals(void) +{ + free(num_sw_vals); + num_sw_vals = NULL; +} + +/*------------------------------------------------------------------------*/ + +void transfer_hyperparams_prolog(void) +{ + int i; + SW_INS_PTR sw_ins_ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + while (sw_ins_ptr != NULL) { + sw_ins_ptr->smooth = sw_ins_ptr->smooth_prolog; + sw_ins_ptr->inside_h = sw_ins_ptr->smooth_prolog + 1.0; + sw_ins_ptr = sw_ins_ptr->next; + } + } +} + +/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/em_aux.h b/packages/prism/src/c/up/em_aux.h new file mode 100644 index 000000000..4790876ce --- /dev/null +++ b/packages/prism/src/c/up/em_aux.h @@ -0,0 +1,16 @@ +#ifndef EM_AUX_H +#define EM_AUX_H + +extern int * num_sw_vals; /* #-vals of switches that occur in e-graphs */ +extern double itemp; /* inversed temperature (for DAEM) */ +extern double inside_failure; /* inside prob. 
of failure */ +extern int failure_observed; /* flag: true if failure is observed */ + +void alloc_occ_switches(void); +void sort_occ_switches(void); +void release_occ_switches(void); +void alloc_num_sw_vals(void); +void release_num_sw_vals(void); +void transfer_hyperparams_prolog(void); + +#endif /* EM_AUX_H */ diff --git a/packages/prism/src/c/up/em_aux_ml.c b/packages/prism/src/c/up/em_aux_ml.c new file mode 100644 index 000000000..025264a7f --- /dev/null +++ b/packages/prism/src/c/up/em_aux_ml.c @@ -0,0 +1,777 @@ +/* -*- c-basic-offset: 2; tab-width: 8 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "core/random.h" +#include "core/gamma.h" +#include "up/up.h" +#include "up/graph.h" +#include "up/flags.h" +#include "up/em_aux.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +/* We check if all smoothing constants are positive (MAP), + * or all smoothing constants are zero. If some are positive, + * but the others are zero, die immediately. We also check + * if there exist parameters fixed at zero in MAP estimation. + */ +int check_smooth(int *smooth) +{ + /* + q = +4 : found non-zero smoothing constants + +2 : found zero-valued smoothing constants + +1 : found parameters fixed to zero + */ + int i, q = 0; + SW_INS_PTR sw_ins_ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + while (sw_ins_ptr != NULL) { + if (sw_ins_ptr->smooth_prolog < 0) { + emit_error("negative delta values in MAP estimation"); + RET_ERR(err_invalid_numeric_value); + } + + q |= (sw_ins_ptr->smooth_prolog < TINY_PROB) ? 2 : 4; + q |= (sw_ins_ptr->fixed && sw_ins_ptr->inside < TINY_PROB) ? 
1 : 0; + + sw_ins_ptr = sw_ins_ptr->next; + } + } + + switch (q) { + case 0: /* p.counts = (none), w/o 0-valued params */ + case 1: /* p.counts = (none), with 0-valued params */ + emit_internal_error("unexpected case in check_smooth()"); + RET_ERR(ierr_unmatched_branches); + case 2: /* p.counts = 0 only, w/o 0-valued params */ + case 3: /* p.counts = 0 only, with 0-valued params */ + *smooth = 0; + break; + case 4: /* p.counts = + only, w/o 0-valued params */ + *smooth = 1; + break; + case 5: /* p.counts = + only, with 0-valued params */ + emit_error("parameters fixed to zero in MAP estimation"); + RET_ERR(err_invalid_numeric_value); + case 6: /* p.counts = (both), w/o 0-valued params */ + case 7: /* p.counts = (both), with 0-valued params */ + emit_error("mixture of zero and non-zero pseudo counts"); + RET_ERR(err_invalid_numeric_value); + } + + transfer_hyperparams_prolog(); + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +static void initialize_params_noisy_uniform(void) +{ + int i; + SW_INS_PTR ptr; + double sum,p; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + + if (ptr->fixed > 0) continue; + + p = 1.0 / num_sw_vals[i]; + sum = 0.0; + while (ptr != NULL) { + ptr->inside = random_gaussian(p, std_ratio * p); + if (ptr->inside < INIT_PROB_THRESHOLD) + ptr->inside = INIT_PROB_THRESHOLD; + sum += ptr->inside; + ptr = ptr->next; + } + ptr = occ_switches[i]; + while (ptr != NULL) { /* normalize */ + ptr->inside = ptr->inside / sum; + ptr = ptr->next; + } + } +} + +static void initialize_params_random(void) +{ + int i; + SW_INS_PTR ptr; + double sum,p; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + + if (ptr->fixed > 0) continue; + + p = 1.0 / num_sw_vals[i]; + sum = 0.0; + while (ptr != NULL) { + sum += (ptr->inside = p + random_float()); + ptr = ptr->next; + } + ptr = occ_switches[i]; + while (ptr != NULL) { /* normalize */ + ptr->inside = ptr->inside / 
sum; + ptr = ptr->next; + } + } +} + +void initialize_params(void) +{ + if (init_method == 1) + initialize_params_noisy_uniform(); + if (init_method == 2) + initialize_params_random(); +} + +/*------------------------------------------------------------------------*/ + +int compute_inside_scaling_none(void) +{ + int i,k; + double sum,this_path_inside; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) + sum = 1.0; /* path_ptr should not be NULL; but it happens */ + while (path_ptr != NULL) { + this_path_inside = 1.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside *= path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside *= path_ptr->sws[k]->inside; + } + path_ptr->inside = this_path_inside; + sum += this_path_inside; + path_ptr = path_ptr->next; + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_inside_scaling_log_exp(void) +{ + int i,k,u; + double sum, this_path_inside, first_path_inside = 0.0, sum_rest; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) { + sum = 0.0; /* path_ptr should not be NULL; but it happens */ + } + else { + sum_rest = 0.0; + u = 0; + while (path_ptr != NULL) { + this_path_inside = 0.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside += path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside += log(path_ptr->sws[k]->inside); + } + path_ptr->inside = this_path_inside; + if (u == 0) { + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else if (this_path_inside - first_path_inside >= log(HUGE_PROB)) { + sum_rest *= exp(first_path_inside - this_path_inside); + first_path_inside = this_path_inside; + sum_rest += 1.0; 
/* maybe sum_rest gets 1.0 */ + } + else { + sum_rest += exp(this_path_inside - first_path_inside); + } + path_ptr = path_ptr->next; + u++; + } + sum = first_path_inside + log(sum_rest); + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_daem_inside_scaling_none(void) +{ + int i,k; + double sum,this_path_inside; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) + sum = 1.0; /* path_ptr should not be NULL; but it happens */ + while (path_ptr != NULL) { + this_path_inside = 1.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside *= path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside *= pow(path_ptr->sws[k]->inside, itemp); + } + path_ptr->inside = this_path_inside; + sum += this_path_inside; + path_ptr = path_ptr->next; + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_daem_inside_scaling_log_exp(void) +{ + int i,k,u; + double sum, this_path_inside, first_path_inside = 0.0, sum_rest; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) { + sum = 0.0; /* path_ptr should not be NULL; but it happens */ + } + else { + sum_rest = 0.0; + u = 0; + while (path_ptr != NULL) { + this_path_inside = 0.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside += path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside += itemp * log(path_ptr->sws[k]->inside); + } + path_ptr->inside = this_path_inside; + if (u == 0) { + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else if (this_path_inside - first_path_inside >= log(HUGE_PROB)) { + sum_rest *= exp(first_path_inside - this_path_inside); + first_path_inside = this_path_inside; + sum_rest += 
1.0; /* maybe sum_rest gets 1.0 */ + } + else { + sum_rest += exp(this_path_inside - first_path_inside); + } + path_ptr = path_ptr->next; + u++; + } + sum = first_path_inside + log(sum_rest); + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +int examine_inside_scaling_none(void) +{ + int i; + double inside; + + inside_failure = 0.0; + + for (i = 0; i < num_roots; i++) { + inside = expl_graph[roots[i]->id]->inside; + if (i == failure_root_index) { + inside_failure = inside; + if (!(1.0 - inside_failure > 0.0)) { + emit_error("Probability of failure being unity"); + RET_ERR(err_invalid_numeric_value); + } + } + else { + if (!(inside > 0.0)) { + emit_error("Probability of an observed goal being zero"); + RET_ERR(err_invalid_numeric_value); + } + } + } + + return BP_TRUE; +} + +int examine_inside_scaling_log_exp(void) +{ + int i; + double inside; + + /* [23 Aug 2007, by yuizumi] + * By the code below, inside_failure can take only a non-zero value + * when `failure' is observed. We can therefore safely use zero as + * an indicator of failure being not observed. Zero is chosen just + * for convenience in implementation of the parallel version. 
+ */ + inside_failure = 0.0; + + for (i = 0; i < num_roots; i++) { + inside = expl_graph[roots[i]->id]->inside; + if (i == failure_root_index) { + inside_failure = inside; /* log-scale */ + if (!(inside_failure < 0.0)) { + emit_error("Probability of failure being unity"); + RET_ERR(err_invalid_numeric_value); + } + } + else { + if (!isfinite(inside)) { + emit_error("Probability of an observed goal being zero"); + RET_ERR(err_invalid_numeric_value); + } + } + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +int compute_expectation_scaling_none(void) +{ + int i,k; + EG_PATH_PTR path_ptr; + EG_NODE_PTR eg_ptr,node_ptr; + SW_INS_PTR sw_ptr; + double q; + + for (i = 0; i < sw_ins_tab_size; i++) { + switch_instances[i]->total_expect = 0.0; + } + + for (i = 0; i < sorted_egraph_size; i++) { + sorted_expl_graph[i]->outside = 0.0; + } + + for (i = 0; i < num_roots; i++) { + eg_ptr = expl_graph[roots[i]->id]; + if (i == failure_root_index) { + eg_ptr->outside = num_goals / (1.0 - inside_failure); + } + else { + eg_ptr->outside = roots[i]->count / eg_ptr->inside; + } + } + + for (i = sorted_egraph_size - 1; i >= 0; i--) { + eg_ptr = sorted_expl_graph[i]; + path_ptr = eg_ptr->path_ptr; + while (path_ptr != NULL) { + q = eg_ptr->outside * path_ptr->inside; + if (q > 0.0) { + for (k = 0; k < path_ptr->children_len; k++) { + node_ptr = path_ptr->children[k]; + node_ptr->outside += q / node_ptr->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + sw_ptr = path_ptr->sws[k]; + sw_ptr->total_expect += q; + } + } + path_ptr = path_ptr->next; + } + } + + return BP_TRUE; +} + +int compute_expectation_scaling_log_exp(void) +{ + int i,k; + EG_PATH_PTR path_ptr; + EG_NODE_PTR eg_ptr,node_ptr; + SW_INS_PTR sw_ptr; + double q,r; + + for (i = 0; i < sw_ins_tab_size; i++) { + switch_instances[i]->total_expect = 0.0; + switch_instances[i]->has_first_expectation = 0; + switch_instances[i]->first_expectation = 0.0; + } + + for (i = 
0; i < sorted_egraph_size; i++) { + sorted_expl_graph[i]->outside = 0.0; + sorted_expl_graph[i]->has_first_outside = 0; + sorted_expl_graph[i]->first_outside = 0.0; + } + + for (i = 0; i < num_roots; i++) { + eg_ptr = expl_graph[roots[i]->id]; + if (i == failure_root_index) { + eg_ptr->first_outside = + log(num_goals / (1.0 - exp(inside_failure))); + } + else { + eg_ptr->first_outside = + log((double)(roots[i]->count)) - eg_ptr->inside; + } + eg_ptr->has_first_outside = 1; + eg_ptr->outside = 1.0; + } + + /* sorted_expl_graph[to] must be a root node */ + for (i = sorted_egraph_size - 1; i >= 0; i--) { + eg_ptr = sorted_expl_graph[i]; + + /* First accumulate log-scale outside probabilities: */ + if (!eg_ptr->has_first_outside) { + emit_internal_error("unexpected has_first_outside[%s]", + prism_goal_string(eg_ptr->id)); + RET_INTERNAL_ERR; + } + else if (!(eg_ptr->outside > 0.0)) { + emit_internal_error("unexpected outside[%s]", + prism_goal_string(eg_ptr->id)); + RET_INTERNAL_ERR; + } + else { + eg_ptr->outside = eg_ptr->first_outside + log(eg_ptr->outside); + } + + path_ptr = sorted_expl_graph[i]->path_ptr; + while (path_ptr != NULL) { + q = sorted_expl_graph[i]->outside + path_ptr->inside; + for (k = 0; k < path_ptr->children_len; k++) { + node_ptr = path_ptr->children[k]; + r = q - node_ptr->inside; + if (!node_ptr->has_first_outside) { + node_ptr->first_outside = r; + node_ptr->outside += 1.0; + node_ptr->has_first_outside = 1; + } + else if (r - node_ptr->first_outside >= log(HUGE_PROB)) { + node_ptr->outside *= exp(node_ptr->first_outside - r); + node_ptr->first_outside = r; + node_ptr->outside += 1.0; + } + else { + node_ptr->outside += exp(r - node_ptr->first_outside); + } + } + for (k = 0; k < path_ptr->sws_len; k++) { + sw_ptr = path_ptr->sws[k]; + if (!sw_ptr->has_first_expectation) { + sw_ptr->first_expectation = q; + sw_ptr->total_expect += 1.0; + sw_ptr->has_first_expectation = 1; + } + else if (q - sw_ptr->first_expectation >= log(HUGE_PROB)) { + 
sw_ptr->total_expect *= exp(sw_ptr->first_expectation - q); + sw_ptr->first_expectation = q; + sw_ptr->total_expect += 1.0; + } + else { + sw_ptr->total_expect += exp(q - sw_ptr->first_expectation); + } + } + path_ptr = path_ptr->next; + } + } + + /* unscale total_expect */ + for (i = 0; i < sw_ins_tab_size; i++) { + sw_ptr = switch_instances[i]; + if (!sw_ptr->has_first_expectation) continue; + if (!(sw_ptr->total_expect > 0.0)) { + emit_error("unexpected expectation for %s",prism_sw_ins_string(i)); + RET_ERR(err_invalid_numeric_value); + } + sw_ptr->total_expect = + exp(sw_ptr->first_expectation + log(sw_ptr->total_expect)); + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +double compute_likelihood_scaling_none(void) +{ + int i; + double likelihood,adjuster,inside; + + likelihood = 0.0; + adjuster = failure_observed ? log(1.0-inside_failure) : 0.0; + + for (i = 0; i < num_roots; i++) { + if (i == failure_root_index) continue; /* skip failure */ + inside = expl_graph[roots[i]->id]->inside; /* always positive */ + likelihood += roots[i]->count * (log(inside) - adjuster); + } + + return likelihood; +} + +double compute_likelihood_scaling_log_exp(void) +{ + int i; + double likelihood,adjuster,inside; + + likelihood = 0.0; + adjuster = failure_observed ? 
log(1.0-exp(inside_failure)) : 0.0; + + for (i = 0; i < num_roots; i++) { + if (i == failure_root_index) continue; /* skip failure */ + inside = expl_graph[roots[i]->id]->inside; /* log-scale */ + likelihood += roots[i]->count * (inside - adjuster); + } + + return likelihood; +} + +/*------------------------------------------------------------------------*/ + +double compute_log_prior(void) +{ + int i; + SW_INS_PTR sw_ins_ptr; + double lp; + + lp = 0.0; + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + while (sw_ins_ptr != NULL) { + lp += sw_ins_ptr->smooth * log(sw_ins_ptr->inside); + sw_ins_ptr = sw_ins_ptr->next; + } + } + + return lp; +} + +double compute_daem_log_prior(void) +{ + int i; + SW_INS_PTR sw_ins_ptr; + double lp; + + lp = 0.0; + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + while (sw_ins_ptr != NULL) { + lp += sw_ins_ptr->smooth * log(sw_ins_ptr->inside); + sw_ins_ptr = sw_ins_ptr->next; + } + } + + return itemp * lp; +} + +/*------------------------------------------------------------------------*/ + +int update_params(void) +{ + int i; + SW_INS_PTR ptr,next; + double sum,cur_prob_sum; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + sum = 0.0; + while (ptr != NULL) { + sum += ptr->total_expect; + ptr = ptr->next; + } + if (sum != 0.0) { + cur_prob_sum = 0.0; + ptr = occ_switches[i]; + if (ptr->fixed > 0) continue; + next = ptr->next; + while (next != NULL) { + if (ptr->fixed == 0) ptr->inside = ptr->total_expect / sum; + if (log_scale && ptr->inside < log(TINY_PROB)) { + emit_error("Parameter being zero (-inf in log scale) -- %s", + prism_sw_ins_string(ptr->id)); + RET_ERR(err_underflow); + } + cur_prob_sum += ptr->inside; + ptr = next; + next = ptr->next; + } + ptr->inside = 1.0-cur_prob_sum; /* Normalize */ + } + } + + return BP_TRUE; +} + +int update_params_smooth(void) +{ + int i; + SW_INS_PTR ptr,next; + double sum,cur_prob_sum; + double denom; + int n; + 
+ for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + n = num_sw_vals[i]; + sum = 0.0; + while (ptr != NULL) { + sum += ptr->total_expect + ptr->smooth; + ptr = ptr->next; + } + denom = sum; + if (sum != 0.0) { + cur_prob_sum = 0.0; + ptr = occ_switches[i]; + if (ptr->fixed > 0) continue; + next = ptr->next; + while (next != NULL) { + if (ptr->fixed == 0) + ptr->inside = (ptr->total_expect + ptr->smooth) / denom; + cur_prob_sum += ptr->inside; + ptr = next; + next = ptr->next; + } + ptr->inside = 1.0-cur_prob_sum; /* Normalize */ + } + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +void save_params(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed > 0) continue; + while (ptr != NULL) { + ptr->best_inside = ptr->inside; + ptr->best_total_expect = ptr->total_expect; + ptr = ptr->next; + } + } +} + +void restore_params(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed > 0) continue; + while (ptr != NULL) { + ptr->inside = ptr->best_inside; + ptr->total_expect = ptr->best_total_expect; + ptr = ptr->next; + } + } +} + +/*------------------------------------------------------------------------*/ + +double compute_bic(double likelihood) +{ + double bic = likelihood; + int i, num_sw_ins, num_params; + + num_sw_ins = 0; + for (i = 0; i < occ_switch_tab_size; i++) { + SW_INS_PTR ptr = occ_switches[i]; + while (ptr != NULL) { + num_sw_ins++; + ptr = ptr->next; + } + } + + /* Get the number of free parameters: */ + num_params = num_sw_ins - occ_switch_tab_size; + bic = likelihood - 0.5 * num_params * log(num_goals); + + return bic; +} + +double compute_cs(double likelihood) +{ + double cs; + double l0, l1, l2; + int i; + SW_INS_PTR ptr; + double smooth_sum; + + /* Compute BD score using the expectations: */ + l0 = 0.0; + for (i = 0; i < 
occ_switch_tab_size; i++) { + smooth_sum = 0.0; + ptr = occ_switches[i]; + while (ptr != NULL) { + smooth_sum += (ptr->smooth + 1.0); + ptr = ptr->next; + } + l0 += lngamma(smooth_sum); + + smooth_sum = 0.0; + ptr = occ_switches[i]; + while (ptr != NULL) { + smooth_sum += (ptr->total_expect + ptr->smooth + 1.0); + ptr = ptr->next; + } + l0 -= lngamma(smooth_sum); + + ptr = occ_switches[i]; + while (ptr != NULL) { + l0 += lngamma(ptr->total_expect + ptr->smooth + 1.0); + l0 -= lngamma(ptr->smooth + 1.0); + ptr = ptr->next; + } + } + + /* Compute the likelihood of complete data using the expectations: */ + l1 = 0.0; + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + l1 += ptr->total_expect * log(ptr->inside); + ptr = ptr->next; + } + } + + /* Get the log-likelihood: */ + l2 = likelihood; + + cs = l0 - l1 + l2; + + return cs; +} + +/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/em_aux_ml.h b/packages/prism/src/c/up/em_aux_ml.h new file mode 100644 index 000000000..0ec54d2a7 --- /dev/null +++ b/packages/prism/src/c/up/em_aux_ml.h @@ -0,0 +1,26 @@ +#ifndef EM_AUX_ML_H +#define EM_AUX_ML_H + +int check_smooth(int *); +void initialize_params(void); +int compute_inside_scaling_none(void); +int compute_inside_scaling_log_exp(void); +int compute_daem_inside_scaling_none(void); +int compute_daem_inside_scaling_log_exp(void); +int examine_inside_scaling_none(void); +int examine_inside_scaling_log_exp(void); +int compute_expectation_scaling_none(void); +int compute_expectation_scaling_log_exp(void); +double compute_likelihood_scaling_none(void); +double compute_likelihood_scaling_log_exp(void); +double compute_log_prior(void); +double compute_daem_log_prior(void); +int update_params(void); +int update_params_smooth(void); +void save_params(void); +void restore_params(void); +double compute_bic(double); +double compute_cs(double); + +#endif /* EM_AUX_ML_H */ + diff 
--git a/packages/prism/src/c/up/em_aux_vb.c b/packages/prism/src/c/up/em_aux_vb.c new file mode 100644 index 000000000..30b1315ad --- /dev/null +++ b/packages/prism/src/c/up/em_aux_vb.c @@ -0,0 +1,569 @@ +/* -*- c-basic-offset: 2; tab-width: 8 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "core/random.h" +#include "core/gamma.h" +#include "up/up.h" +#include "up/graph.h" +#include "up/em_aux.h" +#include "up/em_aux_ml.h" +#include "up/flags.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +/* Just check if there is any negative hyperparameter */ +int check_smooth_vb(void) +{ + int i; + SW_INS_PTR sw_ins_ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + sw_ins_ptr = occ_switches[i]; + while (sw_ins_ptr != NULL) { + if (sw_ins_ptr->smooth_prolog <= -1.0) { + emit_internal_error("illegal hyperparameters"); + RET_INTERNAL_ERR; + } + sw_ins_ptr = sw_ins_ptr->next; + } + } + + transfer_hyperparams_prolog(); + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +void initialize_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + double p,r; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + ptr->smooth = ptr->smooth_prolog; + ptr = ptr->next; + } + } + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + + if (ptr->fixed_h > 0) { + while (ptr != NULL) { + ptr->inside_h = ptr->smooth + 1.0; + ptr->total_expect = 0.0; + ptr = ptr->next; + } + } + else { + p = 1.0 / num_sw_vals[i]; + while (ptr != NULL) { + r = random_gaussian(0.0, std_ratio * p); + ptr->inside_h = + (ptr->smooth + 1.0 < EPS) ? 
EPS : ptr->smooth + 1.0; + ptr->inside_h *= (1.0 + fabs(r)); + ptr->smooth = ptr->inside_h - 1.0; + ptr->total_expect = 0.0; + ptr = ptr->next; + } + } + } +} + +/*------------------------------------------------------------------------*/ + +int compute_pi_scaling_none(void) +{ + int i; + SW_INS_PTR ptr; + double alpha_sum, psi0; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + + alpha_sum = 0.0; + while (ptr != NULL) { + alpha_sum += ptr->inside_h; + ptr = ptr->next; + } + psi0 = digamma(alpha_sum); + + ptr = occ_switches[i]; + while (ptr != NULL) { + ptr->pi = exp(digamma(ptr->inside_h) - psi0); + ptr = ptr->next; + } + } + + return BP_TRUE; +} + +int compute_pi_scaling_log_exp(void) +{ + int i; + SW_INS_PTR ptr; + double alpha_sum, psi0; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + + alpha_sum = 0.0; + while (ptr != NULL) { + alpha_sum += ptr->inside_h; + ptr = ptr->next; + } + psi0 = digamma(alpha_sum); + + ptr = occ_switches[i]; + while (ptr != NULL) { + ptr->pi = digamma(ptr->inside_h) - psi0; + ptr = ptr->next; + } + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +int compute_inside_vb_scaling_none(void) +{ + int i,k; + double sum,this_path_inside; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) sum = 1.0; + + while (path_ptr != NULL) { + this_path_inside = 1.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside *= path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside *= path_ptr->sws[k]->pi; + } + path_ptr->inside = this_path_inside; + sum += this_path_inside; + path_ptr = path_ptr->next; + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_inside_vb_scaling_log_exp(void) +{ + int i,k,u; + double sum, this_path_inside, 
first_path_inside = 0.0, sum_rest; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + + if (path_ptr == NULL) { + sum = 0.0; + } + else { + sum_rest = 0.0; + u = 0; + while (path_ptr != NULL) { + this_path_inside = 0.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside += path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside += path_ptr->sws[k]->pi; /* log-scale */ + } + path_ptr->inside = this_path_inside; + + if (u == 0) { + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else if (this_path_inside - first_path_inside >= log(HUGE_PROB)) { + sum_rest *= exp(first_path_inside - this_path_inside); + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else { + sum_rest += exp(this_path_inside - first_path_inside); + } + path_ptr = path_ptr->next; + u++; + } + sum = first_path_inside + log(sum_rest); + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_daem_inside_vb_scaling_none(void) +{ + int i,k; + double sum,this_path_inside; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + if (path_ptr == NULL) sum = 1.0; + + while (path_ptr != NULL) { + this_path_inside = 1.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside *= path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside *= pow(path_ptr->sws[k]->pi,itemp); + } + path_ptr->inside = this_path_inside; + sum += this_path_inside; + path_ptr = path_ptr->next; + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +int compute_daem_inside_vb_scaling_log_exp(void) +{ + int i,k,u; + double sum, this_path_inside, first_path_inside = 0.0, sum_rest; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < 
sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + sum = 0.0; + path_ptr = eg_ptr->path_ptr; + + if (path_ptr == NULL) { + sum = 0.0; + } + else { + sum_rest = 0.0; + u = 0; + while (path_ptr != NULL) { + this_path_inside = 0.0; + for (k = 0; k < path_ptr->children_len; k++) { + this_path_inside += path_ptr->children[k]->inside; + } + for (k = 0; k < path_ptr->sws_len; k++) { + this_path_inside += itemp * path_ptr->sws[k]->pi; + } + path_ptr->inside = this_path_inside; + + if (u == 0) { + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else if (this_path_inside - first_path_inside >= log(HUGE_PROB)) { + sum_rest *= exp(first_path_inside - this_path_inside); + first_path_inside = this_path_inside; + sum_rest += 1.0; + } + else { + sum_rest += exp(this_path_inside - first_path_inside); + } + path_ptr = path_ptr->next; + u++; + } + sum = first_path_inside + log(sum_rest); + } + + eg_ptr->inside = sum; + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +/* [27 Aug 2007, by yuizumi] + * A variational free energy F is given by: + * F = F0 - F1 + L' + * where: + * F0 = compute_[daem_]free_energy_l0() + * F1 = compute_[daem_]free_energy_l1_scaling_{none|log_exp}() + * L' = compute_likelihood() / itemp + */ + +double compute_free_energy_l0(void) +{ + double l0 = 0.0; + double smooth_sum; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + smooth_sum = 0.0; + ptr = occ_switches[i]; + + while (ptr != NULL) { + smooth_sum += (ptr->smooth + 1.0); + ptr = ptr->next; + } + l0 += lngamma(smooth_sum); + + smooth_sum = 0.0; + ptr = occ_switches[i]; + while (ptr != NULL) { + smooth_sum += (ptr->inside_h); + ptr = ptr->next; + } + l0 -= lngamma(smooth_sum); + + ptr = occ_switches[i]; + while (ptr != NULL) { + l0 += lngamma(ptr->inside_h); + l0 -= lngamma(ptr->smooth + 1.0); + ptr = ptr->next; + } + } + + return l0; +} + +double compute_daem_free_energy_l0(void) +{ + double l0 = 
0.0; + double smooth_sum; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + smooth_sum = 0.0; + ptr = occ_switches[i]; + + while (ptr != NULL) { + smooth_sum += (ptr->smooth + 1.0); + ptr = ptr->next; + } + l0 += lngamma(smooth_sum); + + smooth_sum = 0.0; + ptr = occ_switches[i]; + while (ptr != NULL) { + smooth_sum += (ptr->inside_h); + ptr = ptr->next; + } + l0 -= lngamma(smooth_sum) / itemp; + + ptr = occ_switches[i]; + while (ptr != NULL) { + l0 += lngamma(ptr->inside_h) / itemp; + l0 -= lngamma(ptr->smooth + 1.0); + ptr = ptr->next; + } + } + + return l0; +} + +double compute_free_energy_l1_scaling_none(void) +{ + double l1 = 0.0; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + l1 += ((ptr->inside_h - 1.0) - ptr->smooth) * log(ptr->pi); + ptr = ptr->next; + } + } + + return l1; +} + +double compute_free_energy_l1_scaling_log_exp(void) +{ + double l1 = 0.0; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + /* pi is in log-scale */ + l1 += (ptr->inside_h - (ptr->smooth + 1.0)) * ptr->pi; + ptr = ptr->next; + } + } + + return l1; +} + +double compute_daem_free_energy_l1_scaling_none(void) +{ + double l1 = 0.0; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + l1 += ((ptr->inside_h - 1.0) / itemp - ptr->smooth) * log(ptr->pi); + ptr = ptr->next; + } + } + + return l1; +} + +double compute_daem_free_energy_l1_scaling_log_exp(void) +{ + double l1 = 0.0; + SW_INS_PTR ptr; + int i; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + while (ptr != NULL) { + /* pi is in log-scale */ + l1 += ((ptr->inside_h - 1.0) / itemp - ptr->smooth) * ptr->pi; + ptr = ptr->next; + } + } + + return l1; +} + +/*------------------------------------------------------------------------*/ + +int 
update_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed_h > 0) continue; + + while (ptr != NULL) { + ptr->inside_h = ptr->total_expect + ptr->smooth + 1.0; + ptr = ptr->next; + } + } + + return BP_TRUE; +} + +int update_daem_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed_h > 0) continue; + + while (ptr != NULL) { + ptr->inside_h = itemp * (ptr->total_expect + ptr->smooth) + 1.0; + ptr = ptr->next; + } + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +void save_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed_h > 0) continue; + while (ptr != NULL) { + ptr->best_inside_h = ptr->inside_h; + ptr = ptr->next; + } + } +} + +void restore_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed_h > 0) continue; + while (ptr != NULL) { + ptr->inside_h = ptr->best_inside_h; + ptr = ptr->next; + } + } +} + +void transfer_hyperparams(void) +{ + int i; + SW_INS_PTR ptr; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed_h > 0) continue; + + while (ptr != NULL) { + ptr->smooth = ptr->inside_h - 1.0; + ptr = ptr->next; + } + } +} + +/*------------------------------------------------------------------------*/ + +void get_param_means(void) +{ + int i; + SW_INS_PTR ptr; + double sum; + + for (i = 0; i < occ_switch_tab_size; i++) { + ptr = occ_switches[i]; + if (ptr->fixed > 0) continue; + + sum = 0.0; + while (ptr != NULL) { + sum += ptr->inside_h; + ptr = ptr->next; + } + + ptr = occ_switches[i]; + while (ptr != NULL) { + ptr->inside = ptr->inside_h / sum; + ptr = ptr->next; + } + } +} + 
+/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/em_aux_vb.h b/packages/prism/src/c/up/em_aux_vb.h new file mode 100644 index 000000000..6d2b5c84f --- /dev/null +++ b/packages/prism/src/c/up/em_aux_vb.h @@ -0,0 +1,25 @@ +#ifndef EM_AUX_VB_H +#define EM_AUX_VB_H + +int check_smooth_vb(void); +void initialize_hyperparams(void); +int compute_pi_scaling_none(void); +int compute_pi_scaling_log_exp(void); +int compute_inside_vb_scaling_none(void); +int compute_inside_vb_scaling_log_exp(void); +int compute_daem_inside_vb_scaling_none(void); +int compute_daem_inside_vb_scaling_log_exp(void); +double compute_free_energy_l0(void); +double compute_daem_free_energy_l0(void); +double compute_free_energy_l1_scaling_none(void); +double compute_free_energy_l1_scaling_log_exp(void); +double compute_daem_free_energy_l1_scaling_none(void); +double compute_daem_free_energy_l1_scaling_log_exp(void); +int update_hyperparams(void); +int update_daem_hyperparams(void); +void save_hyperparams(void); +void restore_hyperparams(void); +void transfer_hyperparams(void); +void get_param_means(void); + +#endif /* EM_AUX_VB_H */ diff --git a/packages/prism/src/c/up/em_ml.c b/packages/prism/src/c/up/em_ml.c new file mode 100644 index 000000000..ed52a5202 --- /dev/null +++ b/packages/prism/src/c/up/em_ml.c @@ -0,0 +1,162 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/up.h" +#include "up/graph_aux.h" +#include "up/em.h" +#include "up/em_aux.h" +#include "up/em_aux_ml.h" +#include "up/flags.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +void config_em(EM_ENG_PTR em_ptr) +{ + if (log_scale) { + em_ptr->compute_inside = daem ? 
compute_daem_inside_scaling_log_exp : compute_inside_scaling_log_exp; + em_ptr->examine_inside = examine_inside_scaling_log_exp; + em_ptr->compute_expectation = compute_expectation_scaling_log_exp; + em_ptr->compute_likelihood = compute_likelihood_scaling_log_exp; + em_ptr->compute_log_prior = daem ? compute_daem_log_prior : compute_log_prior; + em_ptr->update_params = em_ptr->smooth ? update_params_smooth : update_params; + } + else { + em_ptr->compute_inside = daem ? compute_daem_inside_scaling_none : compute_inside_scaling_none; + em_ptr->examine_inside = examine_inside_scaling_none; + em_ptr->compute_expectation = compute_expectation_scaling_none; + em_ptr->compute_likelihood = compute_likelihood_scaling_none; + em_ptr->compute_log_prior = daem ? compute_daem_log_prior : compute_log_prior; + em_ptr->update_params = em_ptr->smooth ? update_params_smooth : update_params; + } +} + +/*------------------------------------------------------------------------*/ + +int run_em(EM_ENG_PTR em_ptr) +{ + int r, iterate, old_valid, converged, saved = 0; + double likelihood, log_prior; + double lambda, old_lambda = 0.0; + + config_em(em_ptr); + + for (r = 0; r < num_restart; r++) { + SHOW_PROGRESS_HEAD("#em-iters", r); + + initialize_params(); + itemp = daem ? itemp_init : 1.0; + iterate = 0; + + /* [21 Aug 2007, by yuizumi] + * while-loop for inversed temperature (DAEM). Note that this + * loop is evaluated only once for EM without annealing, since + * itemp initially set to 1.0 by the code above. + */ + while (1) { + if (daem) { + SHOW_PROGRESS_TEMP(itemp); + } + old_valid = 0; + + while (1) { + if (CTRLC_PRESSED) { + SHOW_PROGRESS_INTR(); + RET_ERR(err_ctrl_c_pressed); + } + + RET_ON_ERR(em_ptr->compute_inside()); + RET_ON_ERR(em_ptr->examine_inside()); + + likelihood = em_ptr->compute_likelihood(); + log_prior = em_ptr->smooth ? 
em_ptr->compute_log_prior() : 0.0; + lambda = likelihood + log_prior; + + if (verb_em) { + if (em_ptr->smooth) { + prism_printf("Iteration #%d:\tlog_likelihood=%.9f\tlog_prior=%.9f\tlog_post=%.9f\n", iterate, likelihood, log_prior, lambda); + } + else { + prism_printf("Iteration #%d:\tlog_likelihood=%.9f\n", iterate, likelihood); + } + } + + if (debug_level) { + prism_printf("After I-step[%d]:\n", iterate); + prism_printf("likelihood = %.9f\n", likelihood); + print_egraph(debug_level, PRINT_EM); + } + + if (!isfinite(lambda)) { + emit_internal_error("invalid log likelihood or log post: %s (at iteration #%d)", + isnan(lambda) ? "NaN" : "infinity", iterate); + RET_ERR(ierr_invalid_likelihood); + } + if (old_valid && old_lambda - lambda > prism_epsilon) { + emit_error("log likelihood or log post decreased [old: %.9f, new: %.9f] (at iteration #%d)", + old_lambda, lambda, iterate); + RET_ERR(err_invalid_likelihood); + } + if (itemp == 1.0 && likelihood > 0.0) { + emit_error("log likelihood greater than zero [value: %.9f] (at iteration #%d)", + likelihood, iterate); + RET_ERR(err_invalid_likelihood); + } + + converged = (old_valid && lambda - old_lambda <= prism_epsilon); + if (converged || REACHED_MAX_ITERATE(iterate)) { + break; + } + + old_lambda = lambda; + old_valid = 1; + + RET_ON_ERR(em_ptr->compute_expectation()); + + if (debug_level) { + prism_printf("After O-step[%d]:\n", iterate); + print_egraph(debug_level, PRINT_EM); + } + + SHOW_PROGRESS(iterate); + RET_ON_ERR(em_ptr->update_params()); + iterate++; + } + + /* [21 Aug 2007, by yuizumi] + * Note that 1.0 can be represented exactly in IEEE 754. 
+ */ + if (itemp == 1.0) { + break; + } + itemp *= itemp_rate; + if (itemp >= 1.0) { + itemp = 1.0; + } + } + + SHOW_PROGRESS_TAIL(converged, iterate, lambda); + + if (r == 0 || lambda > em_ptr->lambda) { + em_ptr->lambda = lambda; + em_ptr->likelihood = likelihood; + em_ptr->iterate = iterate; + + saved = (r < num_restart - 1); + if (saved) { + save_params(); + } + } + } + + if (saved) { + restore_params(); + } + + em_ptr->bic = compute_bic(em_ptr->likelihood); + em_ptr->cs = em_ptr->smooth ? compute_cs(em_ptr->likelihood) : 0.0; + + return BP_TRUE; +} diff --git a/packages/prism/src/c/up/em_ml.h b/packages/prism/src/c/up/em_ml.h new file mode 100644 index 000000000..292e30a24 --- /dev/null +++ b/packages/prism/src/c/up/em_ml.h @@ -0,0 +1,8 @@ +#ifndef EM_ML_H +#define EM_ML_H + +void config_em(EM_ENG_PTR); +int run_em(EM_ENG_PTR); + +#endif /* EM_ML_H */ + diff --git a/packages/prism/src/c/up/em_preds.c b/packages/prism/src/c/up/em_preds.c new file mode 100644 index 000000000..6b837fc01 --- /dev/null +++ b/packages/prism/src/c/up/em_preds.c @@ -0,0 +1,181 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/up.h" +#include "up/graph.h" +#include "up/graph_aux.h" +#include "up/em.h" +#include "up/em_ml.h" +#include "up/em_vb.h" +#include "up/em_aux.h" +#include "up/em_aux_ml.h" +#include "up/em_aux_vb.h" +#include "up/viterbi.h" +#include "up/hindsight.h" +#include "up/flags.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +NORET myquit(int, const char *); + +/*------------------------------------------------------------------------*/ + +int pc_prism_prepare_4(void) +{ + TERM p_fact_list; + int size; + + p_fact_list = bpx_get_call_arg(1,4); + size = bpx_get_integer(bpx_get_call_arg(2,4)); + num_goals = bpx_get_integer(bpx_get_call_arg(3,4)); + failure_root_index = 
bpx_get_integer(bpx_get_call_arg(4,4)); + + failure_observed = (failure_root_index != -1); + + if (failure_root_index != -1) { + failure_subgoal_id = prism_goal_id_get(failure_atom); + if (failure_subgoal_id == -1) { + emit_internal_error("no subgoal ID allocated to `failure'"); + RET_INTERNAL_ERR; + } + } + + initialize_egraph_index(); + alloc_sorted_egraph(size); + RET_ON_ERR(sort_egraphs(p_fact_list)); +#ifndef MPI + if (verb_graph) { + print_egraph(0, PRINT_NEUTRAL); + } +#endif /* !(MPI) */ + + alloc_occ_switches(); + if (fix_init_order) { + sort_occ_switches(); + } + alloc_num_sw_vals(); + + return BP_TRUE; +} + +int pc_prism_em_6(void) +{ + struct EM_Engine em_eng; + + RET_ON_ERR(check_smooth(&em_eng.smooth)); + RET_ON_ERR(run_em(&em_eng)); + release_num_sw_vals(); + + return + bpx_unify(bpx_get_call_arg(1,6), bpx_build_integer(em_eng.iterate )) && + bpx_unify(bpx_get_call_arg(2,6), bpx_build_float (em_eng.lambda )) && + bpx_unify(bpx_get_call_arg(3,6), bpx_build_float (em_eng.likelihood)) && + bpx_unify(bpx_get_call_arg(4,6), bpx_build_float (em_eng.bic )) && + bpx_unify(bpx_get_call_arg(5,6), bpx_build_float (em_eng.cs )) && + bpx_unify(bpx_get_call_arg(6,6), bpx_build_integer(em_eng.smooth )) ; +} + +int pc_prism_vbem_2(void) +{ + struct VBEM_Engine vb_eng; + + RET_ON_ERR(check_smooth_vb()); + RET_ON_ERR(run_vbem(&vb_eng)); + release_num_sw_vals(); + + return + bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) && + bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy)); +} + +int pc_prism_both_em_2(void) +{ + struct VBEM_Engine vb_eng; + + RET_ON_ERR(check_smooth_vb()); + RET_ON_ERR(run_vbem(&vb_eng)); + + get_param_means(); + + release_num_sw_vals(); + + return + bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) && + bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy)); +} + +int pc_compute_inside_2(void) +{ + int gid; + double prob; + EG_NODE_PTR eg_ptr; + + gid = 
bpx_get_integer(bpx_get_call_arg(1,2)); + + initialize_egraph_index(); + alloc_sorted_egraph(1); + RET_ON_ERR(sort_one_egraph(gid, 0, 1)); + + if (verb_graph) { + print_egraph(0, PRINT_NEUTRAL); + } + + eg_ptr = expl_graph[gid]; + + if (log_scale) { + RET_ON_ERR(compute_inside_scaling_log_exp()); + prob = eg_ptr->inside; + } + else { + RET_ON_ERR(compute_inside_scaling_none()); + prob = eg_ptr->inside; + } + + return bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(prob)); +} + +/*------------------------------------------------------------------------*/ + +int pc_compute_probf_1(void) +{ + EG_NODE_PTR eg_ptr; + int prmode; + + prmode = bpx_get_integer(bpx_get_call_arg(1,1)); + + if (prmode == 3) { + compute_max(); + return BP_TRUE; + } + + eg_ptr = expl_graph[roots[0]->id]; + failure_root_index = -1; + + /* [31 Mar 2008, by yuizumi] + * compute_outside_scaling_*() is needed to be called because + * eg_ptr->outside computed by compute_expectation_scaling_*() + * is different from the outside probability. 
+     */
+    if (log_scale) {
+        RET_ON_ERR(compute_inside_scaling_log_exp());
+        if (prmode != 1) {  /* prmode 1 stops after the inside pass */
+            RET_ON_ERR(compute_expectation_scaling_log_exp());
+            RET_ON_ERR(compute_outside_scaling_log_exp());
+        }
+    }
+    else {
+        RET_ON_ERR(compute_inside_scaling_none());
+        if (prmode != 1) {  /* prmode 1 stops after the inside pass */
+            RET_ON_ERR(compute_expectation_scaling_none());
+            RET_ON_ERR(compute_outside_scaling_none());
+        }
+    }
+
+    return BP_TRUE;
+}
+
+/*------------------------------------------------------------------------*/
diff --git a/packages/prism/src/c/up/em_preds.h b/packages/prism/src/c/up/em_preds.h
new file mode 100644
index 000000000..d12d8d25d
--- /dev/null
+++ b/packages/prism/src/c/up/em_preds.h
@@ -0,0 +1,11 @@
+#ifndef EM_PREDS_H
+#define EM_PREDS_H
+
+int pc_prism_prepare_4(void);
+int pc_prism_em_6(void);
+int pc_prism_vbem_2(void);
+int pc_prism_both_em_2(void);  /* was _7; em_preds.c defines the arity-2 name */
+int pc_compute_inside_2(void);
+int pc_compute_probf_1(void);
+
+#endif /* EM_PREDS_H */
diff --git a/packages/prism/src/c/up/em_vb.c b/packages/prism/src/c/up/em_vb.c
new file mode 100644
index 000000000..390f1e669
--- /dev/null
+++ b/packages/prism/src/c/up/em_vb.c
@@ -0,0 +1,170 @@
+/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
+
+/*------------------------------------------------------------------------*/
+
+#include "bprolog.h"
+#include "core/random.h"
+#include "up/up.h"
+#include "up/graph_aux.h"
+#include "up/em.h"
+#include "up/em_aux.h"
+#include "up/em_aux_ml.h"
+#include "up/em_aux_vb.h"
+#include "up/flags.h"
+#include "up/util.h"
+
+/*------------------------------------------------------------------------*/
+
+void config_vbem(VBEM_ENG_PTR vb_ptr)
+{
+    if (log_scale) {
+        vb_ptr->compute_pi = compute_pi_scaling_log_exp;
+        vb_ptr->compute_inside = daem ? compute_daem_inside_vb_scaling_log_exp : compute_inside_vb_scaling_log_exp;
+        vb_ptr->examine_inside = examine_inside_scaling_log_exp;
+        vb_ptr->compute_expectation = compute_expectation_scaling_log_exp;
+        vb_ptr->compute_free_energy_l0 = daem ?
compute_daem_free_energy_l0 : compute_free_energy_l0; + vb_ptr->compute_free_energy_l1 = daem ? compute_daem_free_energy_l1_scaling_log_exp : compute_free_energy_l1_scaling_log_exp; + vb_ptr->compute_likelihood = compute_likelihood_scaling_log_exp; + vb_ptr->update_hyperparams = daem ? update_daem_hyperparams : update_hyperparams; + } + else { + vb_ptr->compute_pi = compute_pi_scaling_none; + vb_ptr->compute_inside = daem ? compute_daem_inside_vb_scaling_none : compute_inside_vb_scaling_none; + vb_ptr->examine_inside = examine_inside_scaling_none; + vb_ptr->compute_expectation = compute_expectation_scaling_none; + vb_ptr->compute_free_energy_l0 = daem ? compute_daem_free_energy_l0 : compute_free_energy_l0; + vb_ptr->compute_free_energy_l1 = daem ? compute_daem_free_energy_l1_scaling_none : compute_free_energy_l1_scaling_none; + vb_ptr->compute_likelihood = compute_likelihood_scaling_none; + vb_ptr->update_hyperparams = daem ? update_daem_hyperparams : update_hyperparams; + } +} + +/*------------------------------------------------------------------------*/ + +int run_vbem(VBEM_ENG_PTR vb_ptr) +{ + int r, iterate, old_valid, converged, saved = 0; + double free_energy, old_free_energy = 0.0; + double l0, l1, l2; + + config_vbem(vb_ptr); + + for (r = 0; r < num_restart; r++) { + SHOW_PROGRESS_HEAD("#vbem-iters", r); + + initialize_hyperparams(); + itemp = daem ? itemp_init : 1.0; + iterate = 0; + + /* [21 Aug 2007, by yuizumi] + * while-loop for inversed temperature (DAEM). Note that this + * loop is evaluated only once for EM without annealing, since + * itemp initially set to 1.0 by the code above. 
+ */ + while (1) { + if (daem) { + SHOW_PROGRESS_TEMP(itemp); + } + old_valid = 0; + + while (1) { + if (CTRLC_PRESSED) { + SHOW_PROGRESS_INTR(); + RET_ERR(err_ctrl_c_pressed); + } + + RET_ON_ERR(vb_ptr->compute_pi()); + RET_ON_ERR(vb_ptr->compute_inside()); + RET_ON_ERR(vb_ptr->examine_inside()); + + /* compute free_energy */ + l0 = vb_ptr->compute_free_energy_l0(); + l1 = vb_ptr->compute_free_energy_l1(); + l2 = vb_ptr->compute_likelihood() / itemp; /* itemp == 1.0 for non-DAEM */ + free_energy = l0 - l1 + l2; + + if (verb_em) { + prism_printf("Iteration #%d:\tfree_energy=%.9f\n", iterate, free_energy); + } + + if (debug_level) { + prism_printf("After I-step[%d]:\n", iterate); + prism_printf("free_energy = %.9f\n", free_energy); + print_egraph(debug_level, PRINT_VBEM); + } + + if (!isfinite(free_energy)) { + emit_internal_error("invalid variational free energy: %s (at iteration #%d)", + isnan(free_energy) ? "NaN" : "infinity", iterate); + RET_ERR(err_invalid_free_energy); + } + if (old_valid && old_free_energy - free_energy > prism_epsilon) { + emit_error("variational free energy decreased [old: %.9f, new: %.9f] (at iteration #%d)", + old_free_energy, free_energy, iterate); + RET_ERR(err_invalid_free_energy); + } + if (itemp == 1.0 && free_energy > 0.0) { + emit_error("variational free energy exceeds zero [value: %.9f] (at iteration #%d)", + free_energy, iterate); + RET_ERR(err_invalid_free_energy); + } + + converged = (old_valid && free_energy - old_free_energy <= prism_epsilon); + if (converged || REACHED_MAX_ITERATE(iterate)) { + break; + } + + old_free_energy = free_energy; + old_valid = 1; + + RET_ON_ERR(vb_ptr->compute_expectation()); + + if (debug_level) { + prism_printf("After O-step[%d]:\n", iterate); + print_egraph(debug_level, PRINT_VBEM); + } + + SHOW_PROGRESS(iterate); + RET_ON_ERR(vb_ptr->update_hyperparams()); + + if (debug_level) { + prism_printf("After update[%d]:\n", iterate); + print_egraph(debug_level, PRINT_VBEM); + } + + iterate++; + } + + 
/* [21 Aug 2007, by yuizumi] + * Note that 1.0 can be represented exactly in IEEE 754. + */ + if (itemp == 1.0) { + break; + } + itemp *= itemp_rate; + if (itemp >= 1.0) { + itemp = 1.0; + } + } + + SHOW_PROGRESS_TAIL(converged, iterate, free_energy); + + if (r == 0 || free_energy > vb_ptr->free_energy) { + vb_ptr->free_energy = free_energy; + vb_ptr->iterate = iterate; + + saved = (r < num_restart - 1); + if (saved) { + save_hyperparams(); + } + } + } + + if (saved) { + restore_hyperparams(); + } + + transfer_hyperparams(); + + return BP_TRUE; +} diff --git a/packages/prism/src/c/up/em_vb.h b/packages/prism/src/c/up/em_vb.h new file mode 100644 index 000000000..196d9a724 --- /dev/null +++ b/packages/prism/src/c/up/em_vb.h @@ -0,0 +1,8 @@ +#ifndef EM_VB_H +#define EM_VB_H + +void config_vbem(VBEM_ENG_PTR); +int run_vbem(VBEM_ENG_PTR); + +#endif /* EM_VB_H */ + diff --git a/packages/prism/src/c/up/flags.c b/packages/prism/src/c/up/flags.c new file mode 100644 index 000000000..11236e95e --- /dev/null +++ b/packages/prism/src/c/up/flags.c @@ -0,0 +1,158 @@ +/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */ + +/*------------------------------------------------------------------------*/ + +#include "bprolog.h" +#include "up/up.h" + +/*------------------------------------------------------------------------*/ + +/* + * Since these variables are initialized on start-up by the predicate + * reset_prism_flags/0, the initial values below are not actually used. + * The values are just for reference. + * + * Also, don't forget to modify mp_flags.c when adding new flags. 
+ */ +int daem = 0; +int em_message = 1; +int em_progress = 10; +int error_on_cycle = 1; +int explicit_empty_expls = 1; +int fix_init_order = 1; +int init_method = 1; +double itemp_init = 0.1; +double itemp_rate = 1.2; +int log_scale = 0; +int max_iterate = -1; /* == DEFAULT_MAX_ITERATE */ +int num_restart = 1; +double prism_epsilon = 0.0001; +int show_itemp = 0; +double std_ratio = 0.1; +int verb_em = 0; +int verb_graph = 0; +static int warn = 0; + +/* + * This variable does not correspond to any prism flags, and hence is + * not initialized by reset_prism_flags/0. + */ +int debug_level = 0; + +/*------------------------------------------------------------------------*/ + +int pc_set_daem_1(void) +{ + daem = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_em_message_1(void) +{ + em_message = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_em_progress_1(void) +{ + em_progress = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_error_on_cycle_1(void) +{ + error_on_cycle = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_explicit_empty_expls_1(void) +{ + explicit_empty_expls = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_fix_init_order_1(void) +{ + fix_init_order = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_init_method_1(void) +{ + init_method = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_itemp_init_1(void) +{ + itemp_init = bpx_get_float(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_itemp_rate_1(void) +{ + itemp_rate = bpx_get_float(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_log_scale_1(void) +{ + log_scale = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_max_iterate_1(void) +{ + max_iterate = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_num_restart_1(void) +{ + num_restart = 
bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_prism_epsilon_1(void) +{ + prism_epsilon = bpx_get_float(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_show_itemp_1(void) +{ + show_itemp = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_std_ratio_1(void) +{ + std_ratio = bpx_get_float(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_verb_em_1(void) +{ + verb_em = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_verb_graph_1(void) +{ + verb_graph = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_warn_1(void) +{ + warn = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +int pc_set_debug_level_1(void) +{ + debug_level = bpx_get_integer(bpx_get_call_arg(1,1)); + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/flags.h b/packages/prism/src/c/up/flags.h new file mode 100644 index 000000000..9712c7db8 --- /dev/null +++ b/packages/prism/src/c/up/flags.h @@ -0,0 +1,48 @@ +#ifndef FLAGS_H +#define FLAGS_H + +/*========================================================================*/ + +int pc_set_daem_1(void); +int pc_set_em_message_1(void); +int pc_set_em_progress_1(void); +int pc_set_error_on_cycle_1(void); +int pc_set_explicit_empty_expls_1(void); +int pc_set_fix_init_order_1(void); +int pc_set_init_method_1(void); +int pc_set_itemp_init_1(void); +int pc_set_itemp_rate_1(void); +int pc_set_log_scale_1(void); +int pc_set_max_iterate_1(void); +int pc_set_num_restart_1(void); +int pc_set_prism_epsilon_1(void); +int pc_set_show_itemp_1(void); +int pc_set_std_ratio_1(void); +int pc_set_verb_em_1(void); +int pc_set_verb_graph_1(void); +int pc_set_warn_1(void); +int pc_set_debug_level_1(void); + +/*========================================================================*/ + +extern int daem; +extern int em_message; +extern int em_progress; +extern int 
error_on_cycle; +extern int explicit_empty_expls; +extern int fix_init_order; +extern int init_method; +extern double itemp_init; +extern double itemp_rate; +extern int log_scale; +extern int max_iterate; +extern int num_restart; +extern double prism_epsilon; +extern int show_itemp; +extern double std_ratio; +extern int verb_em; +extern int verb_graph; +extern int warn; +extern int debug_level; + +#endif /* FLAGS_H */ diff --git a/packages/prism/src/c/up/graph.c b/packages/prism/src/c/up/graph.c new file mode 100644 index 000000000..153316e5c --- /dev/null +++ b/packages/prism/src/c/up/graph.c @@ -0,0 +1,888 @@ +#include "up/up.h" +#include "up/flags.h" +#include "up/graph.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +NORET quit(const char *); +NORET myquit(int, const char *); + +/* univ.c (B-Prolog) */ +int list_length(BPLONG, BPLONG); + +/*------------------------------------------------------------------------*/ + +static int max_egraph_size = INIT_MAX_EGRAPH_SIZE; +static int max_sorted_egraph_size = INIT_MAX_EGRAPH_SIZE; +static int egraph_size = 0; + +static int max_sw_tab_size = INIT_MAX_SW_TABLE_SIZE; +static int max_sw_ins_tab_size = INIT_MAX_SW_INS_TABLE_SIZE; + +static int index_to_sort = 0; +static int suppress_init_flags = 0; /* flag: suppress INIT_VISITED_FLAGS? 
*/ + +int sorted_egraph_size = 0; +EG_NODE_PTR *expl_graph = NULL; +EG_NODE_PTR *sorted_expl_graph = NULL; +ROOT *roots = NULL; + +int num_roots; +int num_goals; + +int min_node_index; +int max_node_index; + +SW_INS_PTR *switches = NULL; +SW_INS_PTR *switch_instances = NULL; +SW_INS_PTR *occ_switches = NULL; /* subset of switches */ +int sw_tab_size = 0; +int sw_ins_tab_size = 0; +int occ_switch_tab_size = 0; + +int failure_subgoal_id; +int failure_root_index; + +/*------------------------------------------------------------------------*/ + +static void alloc_switch_table(void) +{ + int i; + sw_tab_size = 0; + switches = (SW_INS_PTR *)MALLOC(max_sw_tab_size * sizeof(SW_INS_PTR)); + + for (i = 0; i < max_sw_tab_size; i++) + switches[i] = NULL; +} + +static void expand_switch_table(int req_sw_tab_size) +{ + int old_size,i; + + if (req_sw_tab_size > max_sw_tab_size) { + old_size = max_sw_tab_size; + + while (req_sw_tab_size > max_sw_tab_size) + max_sw_tab_size *= 2; + + switches = (SW_INS_PTR *)REALLOC(switches, + max_sw_tab_size * sizeof(SW_INS_PTR)); + + for (i = old_size; i < max_sw_tab_size; i++) + switches[i] = NULL; + } +} + +static void clean_switch_table(void) +{ + if (switches != NULL) { + FREE(switches); + sw_tab_size = 0; + max_sw_tab_size = INIT_MAX_SW_TABLE_SIZE; + } +} + +/*------------------------------------------------------------------------*/ + +static SW_INS_PTR alloc_switch_instance(void) +{ + SW_INS_PTR sw_ptr = (SW_INS_PTR)MALLOC(sizeof(struct SwitchInstance)); + sw_ptr->inside = 0.5; + + return sw_ptr; +} + +static void alloc_switch_instance_table(void) +{ + int i; + sw_ins_tab_size = 0; + switch_instances = + (SW_INS_PTR *)MALLOC(max_sw_ins_tab_size * sizeof(SW_INS_PTR)); + + for (i = 0; i < max_sw_ins_tab_size; i++) + switch_instances[i] = NULL; +} + +static void expand_switch_instance_table(int req_sw_ins_tab_size) +{ + int old_size,i; + + if (req_sw_ins_tab_size > max_sw_ins_tab_size) { + old_size = max_sw_ins_tab_size; + + while 
(req_sw_ins_tab_size > max_sw_ins_tab_size) + max_sw_ins_tab_size *= 2; + + switch_instances = + (SW_INS_PTR *)REALLOC(switch_instances, + max_sw_ins_tab_size * sizeof(SW_INS_PTR)); + + for (i = old_size; i < max_sw_ins_tab_size; i++) + switch_instances[i] = NULL; + } +} + +static void clean_switch_instance_table(void) +{ + int i; + + if (switch_instances != NULL) { + for (i = 0; i < max_sw_ins_tab_size; i++) + FREE(switch_instances[i]); + FREE(switch_instances); + sw_ins_tab_size = 0; + max_sw_ins_tab_size = INIT_MAX_SW_INS_TABLE_SIZE; + } +} + +/*------------------------------------------------------------------------*/ + +static EG_NODE_PTR alloc_egraph_node(void) +{ + EG_NODE_PTR node_ptr = (EG_NODE_PTR)MALLOC(sizeof(struct ExplGraphNode)); + + node_ptr->inside = 1.0; + node_ptr->visited = 0; + node_ptr->path_ptr = NULL; + node_ptr->top_n = NULL; + node_ptr->top_n_len = 0; + node_ptr->shared = 0; + + return node_ptr; +} + +int pc_alloc_egraph_0(void) +{ + int i; + + alloc_switch_table(); + alloc_switch_instance_table(); + + egraph_size = 0; + expl_graph = (EG_NODE_PTR *)MALLOC(max_egraph_size * sizeof(EG_NODE_PTR)); + + for (i = 0; i < max_egraph_size; i++) { + expl_graph[i] = alloc_egraph_node(); + expl_graph[i]->id = i; + } + + return BP_TRUE; +} + +static void expand_egraph(int req_egraph_size) +{ + int old_size,i; + + if (req_egraph_size > max_egraph_size) { + old_size = max_egraph_size; + + while (req_egraph_size > max_egraph_size) { + if (max_egraph_size > MAX_EGRAPH_SIZE_EXPAND_LIMIT) { + max_egraph_size += MAX_EGRAPH_SIZE_EXPAND_LIMIT; + } + else { + max_egraph_size *= 2; + } + } + + expl_graph = + (EG_NODE_PTR *)REALLOC(expl_graph, + max_egraph_size * sizeof(EG_NODE_PTR)); + + for (i = old_size; i < max_egraph_size; i++) { + expl_graph[i] = alloc_egraph_node(); + expl_graph[i]->id = i; + } + } +} + +static void clean_sorted_egraph(void) +{ + FREE(sorted_expl_graph); +} + +/* Clean-up the base support graphs and switches */ +static void 
clean_base_egraph(void) +{ + int i,j; + EG_PATH_PTR path_ptr,next_path_ptr; + + clean_switch_table(); + clean_switch_instance_table(); + + if (expl_graph != NULL) { + for (i = 0; i < max_egraph_size; i++) { + if (expl_graph[i] == NULL) continue; + path_ptr = expl_graph[i]->path_ptr; + while (path_ptr != NULL) { + FREE(path_ptr->children); + FREE(path_ptr->sws); + next_path_ptr = path_ptr->next; + FREE(path_ptr); + path_ptr = next_path_ptr; + } + if (expl_graph[i]->top_n != NULL) { + for (j = 0; j < expl_graph[i]->top_n_len; j++) { + FREE(expl_graph[i]->top_n[j]->top_n_index); + FREE(expl_graph[i]->top_n[j]); + } + FREE(expl_graph[i]->top_n); + } + FREE(expl_graph[i]); + } + FREE(expl_graph); + egraph_size = 0; + max_egraph_size = INIT_MAX_EGRAPH_SIZE; + INIT_MIN_MAX_NODE_NOS; + } +} + +int pc_clean_base_egraph_0(void) +{ + clean_base_egraph(); + return BP_TRUE; +} + +int pc_clean_egraph_0(void) +{ + clean_sorted_egraph(); + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +int pc_export_switch_2(void) +{ + BPLONG sw,sw_ins_ids,sw_ins_id; + SW_INS_PTR *curr_ins_ptr; + + sw = bpx_get_integer(bpx_get_call_arg(1,2)); + sw_ins_ids = bpx_get_call_arg(2,2); + + if (sw >= max_sw_tab_size) expand_switch_table(sw + 1); + if (sw >= sw_tab_size) sw_tab_size = sw + 1; + + curr_ins_ptr = &switches[sw]; + while (bpx_is_list(sw_ins_ids)) { + sw_ins_id = bpx_get_integer(bpx_get_car(sw_ins_ids)); + sw_ins_ids = bpx_get_cdr(sw_ins_ids); + + if (sw_ins_id >= max_sw_ins_tab_size) + expand_switch_instance_table(sw_ins_id + 1); + if (sw_ins_id >= sw_ins_tab_size) sw_ins_tab_size = sw_ins_id + 1; + + switch_instances[sw_ins_id] = alloc_switch_instance(); + switch_instances[sw_ins_id]->id = sw_ins_id; + + *curr_ins_ptr = switch_instances[sw_ins_id]; + curr_ins_ptr = &switch_instances[sw_ins_id]->next; + } + *curr_ins_ptr = NULL; + + return BP_TRUE; +} + +static int add_egraph_path(int node_id, TERM children_prolog, TERM sws_prolog) +{ + 
EG_PATH_PTR path_ptr;
+    EG_NODE_PTR *children = NULL;
+    SW_INS_PTR *sws = NULL;
+    int len,k;
+    int child,sw;
+    TERM p_child,p_sw;
+
+    /*
+     * Prepend one path to expl_graph[node_id]: children_prolog is a list
+     * of goal ids, sws_prolog a list of switch-instance ids.  Returns
+     * BP_TRUE on success.  On an invalid id the partially built path is
+     * released before the error is raised, so nothing is leaked and the
+     * graph is left untouched.  list_length() is the file-scope prototype
+     * declared at the top of graph.c.
+     */
+    if (node_id < 0) RET_ERR(err_invalid_goal_id);
+
+    if (node_id >= max_egraph_size) expand_egraph(node_id + 1);
+    if (node_id >= egraph_size) egraph_size = node_id + 1;
+
+    path_ptr = (EG_PATH_PTR)MALLOC(sizeof(struct ExplGraphPath));
+
+    len = list_length(children_prolog, children_prolog);
+    if (len > 0) {
+        path_ptr->children_len = len;
+        children = (EG_NODE_PTR *)MALLOC(sizeof(EG_NODE_PTR) * len);
+        k = 0;
+        while (bpx_is_list(children_prolog)) {
+            p_child = bpx_get_car(children_prolog);
+            if (!bpx_is_integer(p_child) ||
+                (child = bpx_get_integer(p_child)) < 0) {
+                FREE(children);
+                FREE(path_ptr);
+                RET_ERR(err_invalid_goal_id);
+            }
+            /* A child may not have been exported yet; make sure its slot
+             * exists instead of reading past the end of expl_graph. */
+            if (child >= max_egraph_size) expand_egraph(child + 1);
+            children[k] = expl_graph[child];
+            k++;
+            children_prolog = bpx_get_cdr(children_prolog);
+        }
+        path_ptr->children = children;
+    }
+    else {
+        path_ptr->children_len = 0;
+        path_ptr->children = NULL;
+    }
+
+    len = list_length(sws_prolog, sws_prolog);
+    if (len > 0) {
+        path_ptr->sws_len = len;
+        sws = (SW_INS_PTR *)MALLOC(sizeof(SW_INS_PTR) * len);
+        k = 0;
+        while (bpx_is_list(sws_prolog)) {
+            p_sw = bpx_get_car(sws_prolog);
+            /* Ids outside the allocated instance table cannot be looked
+             * up safely, so they are rejected like non-integer ids. */
+            if (!bpx_is_integer(p_sw) ||
+                (sw = bpx_get_integer(p_sw)) < 0 ||
+                sw >= max_sw_ins_tab_size) {
+                FREE(sws);
+                FREE(children);
+                FREE(path_ptr);
+                RET_ERR(err_invalid_switch_instance_id);
+            }
+            sws[k] = switch_instances[sw];
+            k++;
+            sws_prolog = bpx_get_cdr(sws_prolog);
+        }
+        path_ptr->sws = sws;
+    }
+    else {
+        path_ptr->sws_len = 0;
+        path_ptr->sws = NULL;
+    }
+
+    path_ptr->next = expl_graph[node_id]->path_ptr;
+    expl_graph[node_id]->path_ptr = path_ptr;
+
+    return BP_TRUE;
+}
+
+/*
+ * Prolog entry point (3 call args: node id, list of child goal ids, list
+ * of switch-instance ids).  Validates that the node id is an integer,
+ * dereferences the two lists and delegates to add_egraph_path().
+ */
+int pc_add_egraph_path_3(void)
+{
+    TERM p_node_id,p_children,p_sws;
+    int node_id;
+
+    /* children_prolog and sws_prolog must be in the table area */
+    p_node_id  = bpx_get_call_arg(1,3);
+    p_children = bpx_get_call_arg(2,3);
+    p_sws      = bpx_get_call_arg(3,3);
+
+    if (!bpx_is_integer(p_node_id)) RET_ERR(err_invalid_goal_id);
+    node_id = bpx_get_integer(p_node_id);
+
+    XDEREF(p_children);
+    XDEREF(p_sws);
+
+    RET_ON_ERR(add_egraph_path(node_id,p_children,p_sws));
+
+
return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +void alloc_sorted_egraph(int n) +{ + int i; + + max_sorted_egraph_size = INIT_MAX_EGRAPH_SIZE; + sorted_expl_graph = + (EG_NODE_PTR *)MALLOC(sizeof(EG_NODE_PTR) * max_sorted_egraph_size); + roots = (ROOT *)MALLOC(sizeof(ROOT *) * n); + + for (i = 0; i < n; i++) + roots[i] = NULL; + + num_roots = n; +} + +static void expand_sorted_egraph(int req_sorted_egraph_size) +{ + if (req_sorted_egraph_size > max_sorted_egraph_size) { + while (req_sorted_egraph_size > max_sorted_egraph_size) { + if (max_sorted_egraph_size > MAX_EGRAPH_SIZE_EXPAND_LIMIT) + max_sorted_egraph_size += MAX_EGRAPH_SIZE_EXPAND_LIMIT; + else + max_sorted_egraph_size *= 2; + } + sorted_expl_graph = + (EG_NODE_PTR *) + REALLOC(sorted_expl_graph, + max_sorted_egraph_size * sizeof(EG_NODE_PTR)); + } +} + +/*------------------------------------------------------------------------*/ + +void initialize_egraph_index(void) +{ + index_to_sort = 0; +} + +static int topological_sort(int node_id) +{ + EG_PATH_PTR path_ptr; + EG_NODE_PTR *children; + int k,len; + EG_NODE_PTR child_ptr; + + expl_graph[node_id]->visited = 2; + UPDATE_MIN_MAX_NODE_NOS(node_id); + + path_ptr = expl_graph[node_id]->path_ptr; + while (path_ptr != NULL) { + children = path_ptr->children; + len = path_ptr->children_len; + for (k = 0; k < len; k++) { + child_ptr = children[k]; + + if (child_ptr->visited == 2 && error_on_cycle) + RET_ERR(err_cycle_detected); + + if (child_ptr->visited == 0) { + RET_ON_ERR(topological_sort(child_ptr->id)); + expand_sorted_egraph(index_to_sort + 1); + sorted_expl_graph[index_to_sort++] = child_ptr; + } + child_ptr->shared += 1; + } + + path_ptr = path_ptr->next; + } + expl_graph[node_id]->visited = 1; + return BP_TRUE; +} + +int sort_one_egraph(int root_id, int root_index, int count) +{ + roots[root_index] = (ROOT)MALLOC(sizeof(struct ObservedFactNode)); + roots[root_index]->id = root_id; + 
roots[root_index]->count = count; + + if (expl_graph[root_id]->visited == 1) { + /* + * This top-goal is also a sub-goal of another top-goal. This + * should occur only when INIT_VISITED_FLAGS is suppressed + * (i.e. we have more than one observed goal in learning). + */ + if (suppress_init_flags) return BP_TRUE; + } + + if (expl_graph[root_id]->visited != 0) RET_INTERNAL_ERR; + + RET_ON_ERR(topological_sort(root_id)); + + expand_sorted_egraph(index_to_sort + 1); + sorted_expl_graph[index_to_sort] = expl_graph[root_id]; + + index_to_sort++; + sorted_egraph_size = index_to_sort; + + /* initialize flags after use */ + if (!suppress_init_flags) INIT_VISITED_FLAGS; + + return BP_TRUE; +} + +int sort_egraphs(TERM p_fact_list) /* assumed to be dereferenced in advance */ +{ + TERM pair; + int root_index = 0, goal_id, count; + + sorted_egraph_size = 0; + suppress_init_flags = 1; + + while (bpx_is_list(p_fact_list)) { + pair = bpx_get_car(p_fact_list); + p_fact_list = bpx_get_cdr(p_fact_list); + + goal_id = bpx_get_integer(bpx_get_arg(1,pair)); + count = bpx_get_integer(bpx_get_arg(2,pair)); + + if (sort_one_egraph(goal_id,root_index,count) == BP_ERROR) { + INIT_VISITED_FLAGS; + return BP_ERROR; + } + root_index++; + } + + suppress_init_flags = 0; + + INIT_VISITED_FLAGS; + return BP_TRUE; +} + +/* + * Sort the explanation graph such that no node sorted_expl_graph[i] calls + * node sorted_expl_graph[j] if i < j. + * + * This function is used only for probf/1-2, so we don't have to consider + * about scaling here. 
+ */ +int pc_alloc_sort_egraph_1(void) +{ + int root_id; + + root_id = bpx_get_integer(bpx_get_call_arg(1,1)); + + index_to_sort = 0; + alloc_sorted_egraph(1); + RET_ON_ERR(sort_one_egraph(root_id,0,1)); + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +static void clean_root_tables(void) +{ + int i; + if (roots != NULL) { + for (i = 0; i < num_roots; i++) + FREE(roots[i]); + FREE(roots); + } +} + +int pc_clean_external_tables_0(void) +{ + clean_root_tables(); + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +/* + * Export probabilities of switches from Prolog to C. Switches is + * a list of switches, each of which takes the form: + * + * sw(Id,InstanceIds,Probs,SmoothCs,Fixed,FixedH), + * + * where + * Id: identifier of the switch + * InstanceIds: list of ids of the instances of the switch + * Probs: current probabilities assigned to the instance switches + * SmoothCs: current pseudo counts assigned to the instance switches + * Fixed: probabilities fixed? + * FixedH: pseudo counts fixed? + * + * The structures for switch instances have been allocated. This + * function only fills out the initial probabilities. 
+ */ +int pc_export_sw_info_1(void) +{ + int sw_id,instance_id,fixed,fixed_h; + double prob,smooth; + TERM p_switches, p_switch; + TERM p_instance_list,p_prob_list,p_smooth_list; + TERM p_prob,p_smooth; + + p_switches = bpx_get_call_arg(1,1); + + while (bpx_is_list(p_switches)) { + /* p_switch: sw(Id,InstList,ProbList,SmoothCList,FixedP,FixedH) */ + p_switch = bpx_get_car(p_switches); + + sw_id = bpx_get_integer(bpx_get_arg(1,p_switch)); + p_instance_list = bpx_get_arg(2,p_switch); + p_prob_list = bpx_get_arg(3,p_switch); + p_smooth_list = bpx_get_arg(4,p_switch); + fixed = bpx_get_integer(bpx_get_arg(5,p_switch)); + fixed_h = bpx_get_integer(bpx_get_arg(6,p_switch)); + + while (bpx_is_list(p_instance_list)) { + instance_id = bpx_get_integer(bpx_get_car(p_instance_list)); + p_prob = bpx_get_car(p_prob_list); + p_smooth = bpx_get_car(p_smooth_list); + + if (bpx_is_integer(p_prob)) { + prob = (double)bpx_get_integer(p_prob); + } + else if (bpx_is_float(p_prob)) { + prob = bpx_get_float(p_prob); + } + else { + RET_ERR(illegal_arguments); + } + + if (bpx_is_integer(p_smooth)) { + smooth = (double)bpx_get_integer(p_smooth); + } + else if (bpx_is_float(p_smooth)) { + smooth = bpx_get_float(p_smooth); + } + else { + RET_ERR(illegal_arguments); + } + + switch_instances[instance_id]->inside = prob; + switch_instances[instance_id]->fixed = fixed; + switch_instances[instance_id]->fixed_h = fixed_h; + switch_instances[instance_id]->smooth_prolog = smooth; + + p_instance_list = bpx_get_cdr(p_instance_list); + p_prob_list = bpx_get_cdr(p_prob_list); + p_smooth_list = bpx_get_cdr(p_smooth_list); + } + p_switches = bpx_get_cdr(p_switches); + } + + return BP_TRUE; +} + +/*------------------------------------------------------------------------*/ + +/* the following functions are needed by probf */ + +int pc_import_sorted_graph_size_1(void) +{ + return bpx_unify(bpx_get_call_arg(1,1), + bpx_build_integer(sorted_egraph_size)); +} + +int pc_import_sorted_graph_gid_2(void) +{ + int idx 
= bpx_get_integer(bpx_get_call_arg(1,2)); + return bpx_unify(bpx_get_call_arg(2,2), + bpx_build_integer(sorted_expl_graph[idx]->id)); +} + +int pc_import_sorted_graph_paths_2(void) +{ + TERM paths0,paths1,glist,slist,t0,t1,p_tmp; + EG_PATH_PTR path_ptr; + EG_NODE_PTR *children; + SW_INS_PTR *sws; + int node_id,k,len; + + node_id = bpx_get_integer(bpx_get_call_arg(1,2)); + + path_ptr = sorted_expl_graph[node_id]->path_ptr; + + if (path_ptr == NULL) { + if (explicit_empty_expls) { + t0 = bpx_build_list(); + t1 = bpx_build_list(); + bpx_unify(bpx_get_car(t0),bpx_build_nil()); + bpx_unify(bpx_get_cdr(t0),t1); + bpx_unify(bpx_get_car(t1),bpx_build_nil()); + bpx_unify(bpx_get_cdr(t1),bpx_build_nil()); + + paths0 = bpx_build_list(); + bpx_unify(bpx_get_car(paths0),t0); + bpx_unify(bpx_get_cdr(paths0),bpx_build_nil()); + } + else paths0 = bpx_build_nil(); + } + else { + paths0 = bpx_build_nil(); + while (path_ptr != NULL) { + + len = path_ptr->children_len; + children = path_ptr->children; + + if (len > 0) { + glist = bpx_build_list(); + p_tmp = glist; + for (k = 0; k < len; k++) { + bpx_unify(bpx_get_car(p_tmp), + bpx_build_integer(children[k]->id)); + if (k == len - 1) { + bpx_unify(bpx_get_cdr(p_tmp),bpx_build_nil()); + } + else { + bpx_unify(bpx_get_cdr(p_tmp),bpx_build_list()); + p_tmp = bpx_get_cdr(p_tmp); + } + } + } + else glist = bpx_build_nil(); + + len = path_ptr->sws_len; + sws = path_ptr->sws; + + if (len > 0) { + slist = bpx_build_list(); + p_tmp = slist; + for (k = 0; k < len; k++) { + bpx_unify(bpx_get_car(p_tmp),bpx_build_integer(sws[k]->id)); + if (k == len - 1) { + bpx_unify(bpx_get_cdr(p_tmp),bpx_build_nil()); + } + else { + bpx_unify(bpx_get_cdr(p_tmp),bpx_build_list()); + p_tmp = bpx_get_cdr(p_tmp); + } + } + } + else slist = bpx_build_nil(); + + if (explicit_empty_expls || + !bpx_is_nil(glist) || !bpx_is_nil(slist)) { + + t0 = bpx_build_list(); + t1 = bpx_build_list(); + bpx_unify(bpx_get_car(t0),glist); + bpx_unify(bpx_get_cdr(t0),t1); + 
bpx_unify(bpx_get_car(t1),slist); + bpx_unify(bpx_get_cdr(t1),bpx_build_nil()); + + paths1 = bpx_build_list(); + bpx_unify(bpx_get_car(paths1),t0); + bpx_unify(bpx_get_cdr(paths1),paths0); + + paths0 = paths1; + } + + path_ptr = path_ptr->next; + } + } + + return bpx_unify(bpx_get_call_arg(2,2),paths0); +} + +int pc_get_gnode_inside_2(void) +{ + int idx = bpx_get_integer(bpx_get_call_arg(1,2)); + return bpx_unify(bpx_get_call_arg(2,2), + bpx_build_float(expl_graph[idx]->inside)); +} + +int pc_get_gnode_outside_2(void) +{ + int idx = bpx_get_integer(bpx_get_call_arg(1,2)); + return bpx_unify(bpx_get_call_arg(2,2), + bpx_build_float(expl_graph[idx]->outside)); +} + +int pc_get_gnode_viterbi_2(void) +{ + int idx = bpx_get_integer(bpx_get_call_arg(1,2)); + return bpx_unify(bpx_get_call_arg(2,2), + bpx_build_float(expl_graph[idx]->max)); +} + +int pc_get_snode_inside_2(void) +{ + int idx = bpx_get_integer(bpx_get_call_arg(1,2)); + double val = switch_instances[idx]->inside; + + if (log_scale) val = log(val); + + return bpx_unify(bpx_get_call_arg(2,2),bpx_build_float(val)); +} + +int pc_get_snode_expectation_2(void) +{ + int idx = bpx_get_integer(bpx_get_call_arg(1,2)); + return bpx_unify(bpx_get_call_arg(2,2), + bpx_build_float(switch_instances[idx]->total_expect)); +} + +int pc_import_occ_switches_3(void) +{ + TERM p_sw_list,p_sw_list0,p_sw_list1; + TERM p_sw_ins_list0,p_sw_ins_list1,sw,sw_ins; + TERM p_num_sw, p_num_sw_ins; + int i; + int num_sw_ins; + void release_occ_switches(); + +#ifdef __YAP_PROLOG__ + TERM *hstart; + restart: + hstart = heap_top; +#endif + p_sw_list = bpx_get_call_arg(1,3); + p_num_sw = bpx_get_call_arg(2,3); + p_num_sw_ins = bpx_get_call_arg(3,3); + + p_sw_list0 = bpx_build_nil(); + num_sw_ins = 0; + for (i = 0; i < occ_switch_tab_size; i++) { + SW_INS_PTR ptr; + +#ifdef __YAP_PROLOG__ + if ( heap_top + 64*1024 >= local_top ) { + H = hstart; + /* running out of stack */ + extern int Yap_gcl(UInt gc_lim, Int predarity, CELL *current_env, yamop 
*nextop); + + Yap_gcl(4*64*1024, 3, ENV, P); + goto restart; + } +#endif + + sw = bpx_build_structure("sw",2); + bpx_unify(bpx_get_arg(1,sw), bpx_build_integer(i)); + + p_sw_ins_list0 = bpx_build_nil(); + ptr = occ_switches[i]; + while (ptr != NULL) { + num_sw_ins++; + + if (ptr->inside <= 0.0) ptr->inside = 0.0; /* FIXME: quick hack */ + + sw_ins = bpx_build_structure("sw_ins",4); + bpx_unify(bpx_get_arg(1,sw_ins),bpx_build_integer(ptr->id)); + bpx_unify(bpx_get_arg(2,sw_ins),bpx_build_float(ptr->inside)); + bpx_unify(bpx_get_arg(3,sw_ins),bpx_build_float(ptr->smooth)); + bpx_unify(bpx_get_arg(4,sw_ins),bpx_build_float(ptr->total_expect)); + + p_sw_ins_list1 = bpx_build_list(); + bpx_unify(bpx_get_car(p_sw_ins_list1),sw_ins); + bpx_unify(bpx_get_cdr(p_sw_ins_list1),p_sw_ins_list0); + p_sw_ins_list0 = p_sw_ins_list1; + + ptr = ptr->next; + } + + bpx_unify(bpx_get_arg(2,sw),p_sw_ins_list0); + + p_sw_list1 = bpx_build_list(); + bpx_unify(bpx_get_car(p_sw_list1),sw); + bpx_unify(bpx_get_cdr(p_sw_list1),p_sw_list0); + p_sw_list0 = p_sw_list1; + } + + release_occ_switches(); + + return + bpx_unify(p_sw_list, p_sw_list0) && + bpx_unify(p_num_sw, bpx_build_integer(occ_switch_tab_size)) && + bpx_unify(p_num_sw_ins, bpx_build_integer(num_sw_ins)); +} + +/*------------------------------------------------------------------------*/ + +void graph_stats(int stats[4]) +{ + int num_goal_nodes = 0; + int num_switch_nodes = 0; + int total_shared = 0; + int i; + EG_NODE_PTR eg_ptr; + EG_PATH_PTR path_ptr; + + for (i = 0; i < sorted_egraph_size; i++) { + eg_ptr = sorted_expl_graph[i]; + total_shared += eg_ptr->shared; + + path_ptr = eg_ptr->path_ptr; + + while (path_ptr != NULL) { + num_goal_nodes += path_ptr->children_len; + num_switch_nodes += path_ptr->sws_len; + path_ptr = path_ptr->next; + } + } + + stats[0] = sorted_egraph_size; + stats[1] = num_goal_nodes; + stats[2] = num_switch_nodes; + stats[3] = total_shared; +} + +int pc_import_graph_stats_4(void) +{ + int stats[4]; + 
double avg_shared; + + graph_stats(stats); + /* an empty sorted graph has no average; report 0.0 instead of 0/0 */ + avg_shared = (stats[0] > 0) ? (double)(stats[3]) / stats[0] : 0.0; + + return + bpx_unify(bpx_get_call_arg(1,4), bpx_build_integer(stats[0])) && + bpx_unify(bpx_get_call_arg(2,4), bpx_build_integer(stats[1])) && + bpx_unify(bpx_get_call_arg(3,4), bpx_build_integer(stats[2])) && + bpx_unify(bpx_get_call_arg(4,4), bpx_build_float(avg_shared)); +} diff --git a/packages/prism/src/c/up/graph.h b/packages/prism/src/c/up/graph.h new file mode 100644 index 000000000..ab362c971 --- /dev/null +++ b/packages/prism/src/c/up/graph.h @@ -0,0 +1,82 @@ +#ifndef GRAPH_H +#define GRAPH_H + +/*====================================================================*/ + +#define INIT_MAX_SW_TABLE_SIZE 16 +#define INIT_MAX_SW_INS_TABLE_SIZE 64 +#define INIT_MAX_EGRAPH_SIZE (1 << 8) +#define MAX_EGRAPH_SIZE_EXPAND_LIMIT (128 << 10) + +/* node_id should be non-negative */ +#define UPDATE_MIN_MAX_NODE_NOS(node_id) do { \ + if (min_node_index < 0 || (node_id) < min_node_index) \ + min_node_index = (node_id); \ + if ((node_id) > max_node_index) \ + max_node_index = (node_id); \ + } while (0) +#define INIT_MIN_MAX_NODE_NOS do { \ + min_node_index = -1; \ + max_node_index = -1; \ + } while (0) +/* + * Both indices are -1 while no node has been recorded; start from 0 in + * that case so that expl_graph[-1] is never touched. + */ +#define INIT_VISITED_FLAGS do { \ + int i; \ + for (i = (min_node_index < 0) ? 0 : min_node_index; \ + i <= max_node_index; i++) \ + expl_graph[i]->visited = 0; \ + } while (0) + +/*====================================================================*/ + +int pc_alloc_egraph_0(void); +int pc_clean_base_egraph_0(void); +int pc_clean_egraph_0(void); +int pc_export_switch_2(void); +int pc_add_egraph_path_3(void); +int pc_alloc_sort_egraph_1(void); +int pc_clean_external_tables_0(void); +int pc_export_sw_info_1(void); +int pc_import_sorted_graph_size_1(void); +int pc_import_sorted_graph_gid_2(void); +int pc_import_sorted_graph_paths_2(void); +int pc_get_gnode_inside_2(void); +int pc_get_gnode_outside_2(void); +int pc_get_gnode_viterbi_2(void); +int pc_get_snode_inside_2(void); +int pc_get_snode_expectation_2(void); +int 
pc_import_occ_switches_3(void); +void graph_stats(int[4]); + +/*--------------------------------------------------------------------*/ + +void alloc_sorted_egraph(int); +void initialize_egraph_index(void); +int sort_one_egraph(int, int, int); +int sort_egraphs(TERM); + +/*====================================================================*/ + +extern int sorted_egraph_size; +extern EG_NODE_PTR *expl_graph; +extern EG_NODE_PTR *sorted_expl_graph; +extern int num_roots; +extern int num_goals; + +extern ROOT *roots; + +extern int min_node_index; +extern int max_node_index; + +extern int sw_tab_size; +extern int sw_ins_tab_size; +extern int occ_switch_tab_size; + +extern SW_INS_PTR *switches; +extern SW_INS_PTR *switch_instances; +extern SW_INS_PTR *occ_switches; + +extern int failure_subgoal_id; +extern int failure_root_index; + +/*====================================================================*/ + +#endif /* GRAPH_H */ diff --git a/packages/prism/src/c/up/graph_aux.c b/packages/prism/src/c/up/graph_aux.c new file mode 100644 index 000000000..70fce917c --- /dev/null +++ b/packages/prism/src/c/up/graph_aux.c @@ -0,0 +1,299 @@ +#include +#include "bprolog.h" +#include "up/up.h" +#include "up/graph.h" +#include "up/graph_aux.h" +#include "up/flags.h" + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +void quit(const char *); + +/*------------------------------------------------------------------------*/ + +static EG_NODE_PTR *subgraph; +static int subgraph_size; +static int max_subgraph_size; + +/*------------------------------------------------------------------------*/ + +static void alloc_subgraph(void) +{ + max_subgraph_size = INIT_MAX_EGRAPH_SIZE; + subgraph = (EG_NODE_PTR *)MALLOC(sizeof(EG_NODE_PTR) * max_subgraph_size); +} + +static void expand_subgraph(int req_subgraph_size) +{ + if (req_subgraph_size > max_subgraph_size) { + while (req_subgraph_size > max_subgraph_size) { + if (max_subgraph_size > 
MAX_EGRAPH_SIZE_EXPAND_LIMIT) + max_subgraph_size += MAX_EGRAPH_SIZE_EXPAND_LIMIT; + else + max_subgraph_size *= 2; + } + + subgraph = REALLOC(subgraph, sizeof(EG_NODE_PTR) * max_subgraph_size); + } +} + +static void release_subgraph(void) +{ + free(subgraph); + subgraph = NULL; +} + +static void traverse_egraph(EG_NODE_PTR node_ptr) +{ + int i; + EG_NODE_PTR c_node_ptr; + EG_PATH_PTR path_ptr; + + node_ptr->visited = 1; + path_ptr = node_ptr->path_ptr; + + while (path_ptr != NULL) { + for (i = 0; i < path_ptr->children_len; i++) { + c_node_ptr = path_ptr->children[i]; + if (c_node_ptr->visited != 1) { + if (c_node_ptr->visited == 0) { + traverse_egraph(c_node_ptr); + } + expand_subgraph(subgraph_size + 1); + subgraph[subgraph_size] = c_node_ptr; + subgraph_size++; + } + } + path_ptr = path_ptr->next; + } +} + +/*------------------------------------------------------------------------*/ + +/* `mode' is a macro prefixed by `PRINT_' */ +void print_egraph(int level, int mode) +{ + ROOT root_ptr; + EG_NODE_PTR eg_ptr, node_ptr; + EG_PATH_PTR path_ptr; + SW_INS_PTR sw_ptr; + int log_scale1; + int r,u,e,i,k,len; + + /* disable scaling for non-learning */ + log_scale1 = (mode > 0) ? 
log_scale : 0; + + alloc_subgraph(); + + for (r = 0; r < num_roots; r++) { + root_ptr = roots[r]; + + if (level >= 1) { + fprintf(curr_out," <>\n", + r,prism_goal_string(root_ptr->id), + root_ptr->id,root_ptr->count); + } + else { + fprintf(curr_out," <>\n",r,root_ptr->count); + } + + subgraph_size = 0; + + traverse_egraph(expl_graph[root_ptr->id]); + expand_subgraph(subgraph_size + 1); + subgraph[subgraph_size] = expl_graph[root_ptr->id]; + + for (i = subgraph_size; i >= 0; i--) { + eg_ptr = subgraph[i]; + + if (eg_ptr->visited == 2) { + fprintf(curr_out," g[%d]:%s\n", + eg_ptr->id,prism_goal_string(eg_ptr->id)); + fprintf(curr_out," **** already shown ****\n"); + continue; + } + + eg_ptr->visited = 2; + + if (level == 0) { + fprintf(curr_out," g[%d]:%s\n", + eg_ptr->id,prism_goal_string(eg_ptr->id)); + } + if (level >= 3) { + fprintf(curr_out," g[%d]:%s.addr = <%p>\n", + eg_ptr->id,prism_goal_string(eg_ptr->id),eg_ptr); + } + if (level >= 1) { + if (log_scale1) { + fprintf(curr_out," g[%d]:%s.inside = %.9e (%.9e)\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->inside,exp(eg_ptr->inside)); + fprintf(curr_out," g[%d]:%s.outside = %.9e (%.9e)\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->outside,exp(eg_ptr->outside)); + fprintf(curr_out," g[%d]:%s.first_outside = %.9e (%.9e)\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->first_outside,exp(eg_ptr->first_outside)); + } + else { + fprintf(curr_out," g[%d]:%s.inside = %.9e\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->inside); + fprintf(curr_out," g[%d]:%s.outside = %.9e\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->outside); + } + if (mode == PRINT_VITERBI) { + fprintf(curr_out," g[%d]:%s.max = %.9e\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->max); + fprintf(curr_out," g[%d]:%s.top_n_len = %d\n", + eg_ptr->id,prism_goal_string(eg_ptr->id), + eg_ptr->top_n_len); + if (eg_ptr->top_n != NULL) { + for (e = 0; e < eg_ptr->top_n_len; e++) { + if 
(eg_ptr->top_n[e] == NULL) continue; + fprintf(curr_out," top_n[%d]->goal_id = %d\n", + e,eg_ptr->top_n[e]->goal_id); + fprintf(curr_out," top_n[%d]->path_ptr = %p\n", + e,eg_ptr->top_n[e]->path_ptr); + len = eg_ptr->top_n[e]->children_len; + for (k = 0; k < len; k++) { + fprintf(curr_out, + " top_n[%d]->goal[%d] = %s (%d)\n", + e,k,prism_goal_string(eg_ptr->top_n[e]->path_ptr->children[k]->id),eg_ptr->top_n[e]->path_ptr->children[k]->id); + fprintf(curr_out," top_n[%d]->top_n_index[%d] = %d\n", + e,k,eg_ptr->top_n[e]->top_n_index[k]); + } + fprintf(curr_out," top_n[%d]->max = %.9e\n", + e,eg_ptr->top_n[e]->max); + } + } + } + } + + path_ptr = eg_ptr->path_ptr; + u = 0; + while (path_ptr != NULL) { + if (level == 0) { + fprintf(curr_out," path[%d]:\n",u); + } + if (level >= 3) { + fprintf(curr_out," path[%d].chilren_len = %d\n", + u,path_ptr->children_len); + fprintf(curr_out," path[%d].sws_len = %d\n", + u,path_ptr->sws_len); + } + if (level >= 1) { + if (log_scale1) { + fprintf(curr_out," path[%d].inside = %.9e (%.9e)\n", + u,path_ptr->inside,exp(path_ptr->inside)); + } + else { + fprintf(curr_out," path[%d].inside = %.9e\n", + u,path_ptr->inside); + } + } + for (k = 0; k < path_ptr->children_len; k++) { + node_ptr = path_ptr->children[k]; + if (level == 0) { + fprintf(curr_out," g[%d]:%s\n", + node_ptr->id,prism_goal_string(node_ptr->id)); + } + if (level >= 3) { + fprintf(curr_out," g[%d]:%s.addr = <%p>\n", + node_ptr->id,prism_goal_string(node_ptr->id), + node_ptr); + } + if (level >= 1) { + if (log_scale1) { + fprintf(curr_out, + " g[%d]:%s.inside = %.9e (%.9e)\n", + node_ptr->id, + prism_goal_string(node_ptr->id), + node_ptr->inside,exp(node_ptr->inside)); + fprintf(curr_out, + " g[%d]:%s.outside = %.9e (%.9e)\n", + node_ptr->id, + prism_goal_string(node_ptr->id), + node_ptr->outside,exp(node_ptr->outside)); + fprintf(curr_out, + " g[%d]:%s.first_outside = %.9e (%.9e)\n", + node_ptr->id, + prism_goal_string(node_ptr->id), + node_ptr->first_outside, + 
exp(node_ptr->first_outside)); + } + else { + fprintf(curr_out," g[%d]:%s.inside = %.9e\n", + node_ptr->id, + prism_goal_string(node_ptr->id), + node_ptr->inside); + fprintf(curr_out," g[%d]:%s.outside = %.9e\n", + node_ptr->id, + prism_goal_string(node_ptr->id), + node_ptr->outside); + } + } + } + for (k = 0; k < path_ptr->sws_len; k++) { + sw_ptr = path_ptr->sws[k]; + if (level == 0) { + fprintf(curr_out," sw[%d]:%s\n", + sw_ptr->id,prism_sw_ins_string(sw_ptr->id)); + } + if (level >= 1) { + if (mode == PRINT_EM) { + fprintf(curr_out," sw[%d]:%s.inside = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->inside); + fprintf(curr_out," sw[%d]:%s.total_e = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->total_expect); + } + if (mode == PRINT_VBEM) { + fprintf(curr_out," sw[%d]:%s.pi = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->pi); + fprintf(curr_out," sw[%d]:%s.smooth = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->smooth); + fprintf(curr_out," sw[%d]:%s.inside = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->inside); + fprintf(curr_out, + " sw[%d]:%s.inside_h = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->inside_h); + fprintf(curr_out," sw[%d]:%s.total_e = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->total_expect); + } + if (mode == PRINT_VITERBI) { + fprintf(curr_out," sw[%d]:%s.inside = %.9e\n", + sw_ptr->id, + prism_sw_ins_string(sw_ptr->id), + sw_ptr->inside); + } + } + } + + path_ptr = path_ptr->next; + u++; + } + } + } + + INIT_VISITED_FLAGS; + release_subgraph(); +} + +/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/graph_aux.h b/packages/prism/src/c/up/graph_aux.h new file mode 100644 index 000000000..1a7fa1f59 --- /dev/null +++ b/packages/prism/src/c/up/graph_aux.h @@ -0,0 +1,15 @@ +#ifndef GRAPH_AUX_H +#define GRAPH_AUX_H + +/* + * mode for 
print_egraph + * (positive for EM learning; negative for other inferences) + */ +#define PRINT_NEUTRAL 0 +#define PRINT_EM 1 +#define PRINT_VBEM 2 +#define PRINT_VITERBI -1 + +void print_egraph(int, int); + +#endif /* GRAPH_AUX_H */ diff --git a/packages/prism/src/c/up/hindsight.c b/packages/prism/src/c/up/hindsight.c new file mode 100644 index 000000000..2a7d23941 --- /dev/null +++ b/packages/prism/src/c/up/hindsight.c @@ -0,0 +1,300 @@ +#include "up/up.h" +#include "up/graph.h" +#include "up/graph_aux.h" +#include "up/em_aux.h" +#include "up/em_aux_ml.h" +#include "up/flags.h" +#include "up/util.h" + +/*------------------------------------------------------------------------*/ + +#define INIT_MAX_HINDSIGHT_GOAL_SIZE 100 + +/*------------------------------------------------------------------------*/ + +/* mic.c (B-Prolog) */ +NORET quit(const char *); + +/*------------------------------------------------------------------------*/ + +static int * hindsight_goals = NULL; +static double * hindsight_probs = NULL; +static int max_hindsight_goal_size; +static int hindsight_goal_size; + +/*------------------------------------------------------------------------*/ + +static void alloc_hindsight_goals(void) +{ + int i; + + hindsight_goal_size = 0; + max_hindsight_goal_size = INIT_MAX_HINDSIGHT_GOAL_SIZE; + hindsight_goals = (int *)MALLOC(max_hindsight_goal_size * sizeof(TERM)); + hindsight_probs = + (double *)MALLOC(max_hindsight_goal_size * sizeof(double)); + + for (i = 0; i < max_hindsight_goal_size; i++) { + hindsight_goals[i] = -1; + hindsight_probs[i] = 0.0; + } +} + +static void expand_hindsight_goals(int req_hindsight_goal_size) +{ + int old_size,i; + + if (req_hindsight_goal_size > max_hindsight_goal_size) { + old_size = max_hindsight_goal_size; + + while (req_hindsight_goal_size > max_hindsight_goal_size) { + max_hindsight_goal_size *= 2; + } + + hindsight_goals = + (int *)REALLOC(hindsight_goals, + max_hindsight_goal_size * sizeof(TERM)); + hindsight_probs = + 
(double *)REALLOC(hindsight_probs, + max_hindsight_goal_size * sizeof(double)); + + for (i = old_size; i < max_hindsight_goal_size; i++) { + hindsight_goals[i] = -1; + hindsight_probs[i] = 0.0; + } + } +} + +/* + * Be warned that eg_ptr->outside will have a value different from that + * in the compute_expectation-family functions. + */ +int compute_outside_scaling_none(void) +{ + int i,k; + EG_PATH_PTR path_ptr; + EG_NODE_PTR eg_ptr,node_ptr; + double q; + + if (num_roots != 1) { + emit_internal_error("illegal call to compute_outside"); + RET_ERR(build_internal_error("no_observed_data")); + } + + for (i = 0; i < sorted_egraph_size; i++) { + sorted_expl_graph[i]->outside = 0.0; + } + + eg_ptr = expl_graph[roots[0]->id]; + eg_ptr->outside = roots[0]->count; + + for (i = (sorted_egraph_size - 1); i >= 0; i--) { + eg_ptr = sorted_expl_graph[i]; + path_ptr = eg_ptr->path_ptr; + while (path_ptr != NULL) { + q = eg_ptr->outside * path_ptr->inside; + if (q > 0.0) { + for (k = 0; k < path_ptr->children_len; k++) { + node_ptr = path_ptr->children[k]; + node_ptr->outside += q / node_ptr->inside; + } + } + path_ptr = path_ptr->next; + } + } + + return BP_TRUE; +} + +int compute_outside_scaling_log_exp(void) +{ + int i,k; + EG_PATH_PTR path_ptr; + EG_NODE_PTR eg_ptr,node_ptr; + double q,r; + + if (num_roots != 1) { + emit_internal_error("illegal call to compute_outside"); + RET_ERR(build_internal_error("no_observed_data")); + } + + for (i = 0; i < sorted_egraph_size; i++) { + sorted_expl_graph[i]->outside = 0.0; + sorted_expl_graph[i]->has_first_outside = 0; + sorted_expl_graph[i]->first_outside = 0.0; + } + + eg_ptr = expl_graph[roots[0]->id]; + eg_ptr->outside = 1.0; + eg_ptr->has_first_outside = 1; + eg_ptr->first_outside = log((double)(roots[0]->count)); + + /* sorted_expl_graph[to] must be a root node */ + for (i = sorted_egraph_size - 1; i >= 0; i--) { + eg_ptr = sorted_expl_graph[i]; + + /* First accumulate log-scale outside probabilities: */ + if 
(!eg_ptr->has_first_outside) { + emit_internal_error("unexpected has_first_outside[%s]",prism_goal_string(eg_ptr->id)); + RET_INTERNAL_ERR; + } + else if (!(eg_ptr->outside > 0.0)) { + emit_internal_error("unexpected outside[%s]", + prism_goal_string(eg_ptr->id)); + RET_INTERNAL_ERR; + } + else { + eg_ptr->outside = eg_ptr->first_outside + log(eg_ptr->outside); + } + + path_ptr = sorted_expl_graph[i]->path_ptr; + while (path_ptr != NULL) { + q = sorted_expl_graph[i]->outside + path_ptr->inside; + for (k = 0; k < path_ptr->children_len; k++) { + node_ptr = path_ptr->children[k]; + r = q - node_ptr->inside; + if (!node_ptr->has_first_outside) { + node_ptr->first_outside = r; + node_ptr->outside += 1.0; + node_ptr->has_first_outside = 1; + } + else if (r - node_ptr->first_outside >= log(HUGE_PROB)) { + node_ptr->outside *= exp(node_ptr->first_outside - r); + node_ptr->first_outside = r; + node_ptr->outside += 1.0; + } + else { + node_ptr->outside += exp(r - node_ptr->first_outside); + } + } + path_ptr = path_ptr->next; + } + } + + return BP_TRUE; +} + +static int get_hindsight_goals_scaling_none(TERM p_subgoal, int is_cond) +{ + int i,j; + EG_NODE_PTR eg_ptr; + TERM t; + double denom; + + if (is_cond) { + denom = expl_graph[roots[0]->id]->inside; + } + else { + denom = 1.0; + } + + j = 0; + for (i = 0; i < sorted_egraph_size - 1; i++) { + eg_ptr = sorted_expl_graph[i]; + t = prism_goal_term((IDNUM)(eg_ptr->id)); + if (bpx_is_unifiable(p_subgoal, t)) { + if (j >= max_hindsight_goal_size) expand_hindsight_goals(j + 1); + if (j >= hindsight_goal_size) hindsight_goal_size = j + 1; + hindsight_goals[j] = eg_ptr->id; + hindsight_probs[j] = eg_ptr->inside * eg_ptr->outside / denom; + j++; + } + } + + return BP_TRUE; +} + +static int get_hindsight_goals_scaling_log_exp(TERM p_subgoal, int is_cond) +{ + int i,j; + EG_NODE_PTR eg_ptr; + TERM t; + double denom; + + if (is_cond) { + denom = expl_graph[roots[0]->id]->inside; + } + else { + denom = 0.0; + } + + j = 0; + for (i = 0; 
i < sorted_egraph_size - 1; i++) { + eg_ptr = sorted_expl_graph[i]; + t = prism_goal_term(eg_ptr->id); + if (bpx_is_unifiable(p_subgoal, t)) { + if (j >= max_hindsight_goal_size) expand_hindsight_goals(j + 1); + if (j >= hindsight_goal_size) hindsight_goal_size = j + 1; + hindsight_goals[j] = eg_ptr->id; + hindsight_probs[j] = eg_ptr->inside + eg_ptr->outside - denom; + j++; + } + } + + return BP_TRUE; +} + +int pc_compute_hindsight_4(void) +{ + TERM p_subgoal,p_hindsight_pairs,t,t1,p_pair; + int goal_id,is_cond,j; + + goal_id = bpx_get_integer(bpx_get_call_arg(1,4)); + p_subgoal = bpx_get_call_arg(2,4); + is_cond = bpx_get_integer(bpx_get_call_arg(3,4)); + + initialize_egraph_index(); + alloc_sorted_egraph(1); + RET_ON_ERR(sort_one_egraph(goal_id,0,1)); + if (verb_graph) print_egraph(0,PRINT_NEUTRAL); + + alloc_hindsight_goals(); + + if (log_scale) { + RET_ON_ERR(compute_inside_scaling_log_exp()); + RET_ON_ERR(compute_outside_scaling_log_exp()); + RET_ON_ERR(get_hindsight_goals_scaling_log_exp(p_subgoal,is_cond)); + } + else { + RET_ON_ERR(compute_inside_scaling_none()); + RET_ON_ERR(compute_outside_scaling_none()); + RET_ON_ERR(get_hindsight_goals_scaling_none(p_subgoal,is_cond)); + } + + if (hindsight_goal_size > 0) { + /* Build the list of pairs of a subgoal and its hindsight probability */ + p_hindsight_pairs = bpx_build_list(); + t = p_hindsight_pairs; + + for (j = 0; j < hindsight_goal_size; j++) { + p_pair = bpx_build_list(); + + t1 = p_pair; + bpx_unify(bpx_get_car(t1), + bpx_build_integer(hindsight_goals[j])); + bpx_unify(bpx_get_cdr(t1),bpx_build_list()); + + t1 = bpx_get_cdr(t1); + bpx_unify(bpx_get_car(t1),bpx_build_float(hindsight_probs[j])); + bpx_unify(bpx_get_cdr(t1),bpx_build_nil()); + + bpx_unify(bpx_get_car(t),p_pair); + + if (j == hindsight_goal_size - 1) { + bpx_unify(bpx_get_cdr(t),bpx_build_nil()); + } + else { + bpx_unify(bpx_get_cdr(t),bpx_build_list()); + t = bpx_get_cdr(t); + } + } + } + else { + p_hindsight_pairs = bpx_build_nil(); + } 
+ + FREE(hindsight_goals); + FREE(hindsight_probs); + + return bpx_unify(bpx_get_call_arg(4,4),p_hindsight_pairs); +} diff --git a/packages/prism/src/c/up/hindsight.h b/packages/prism/src/c/up/hindsight.h new file mode 100644 index 000000000..53024dffc --- /dev/null +++ b/packages/prism/src/c/up/hindsight.h @@ -0,0 +1,15 @@ +#ifndef HINDSIGHT_H +#define HINDSIGHT_H + +/*============================================================================*/ + +int pc_compute_hindsight_4(void); + +/*----------------------------------------------------------------------------*/ + +int compute_outside_scaling_none(void); +int compute_outside_scaling_log_exp(void); + +/*============================================================================*/ + +#endif /* HINDSIGHT_H */ diff --git a/packages/prism/src/c/up/up.h b/packages/prism/src/c/up/up.h new file mode 100644 index 000000000..19a4659d2 --- /dev/null +++ b/packages/prism/src/c/up/up.h @@ -0,0 +1,118 @@ +#ifndef UP_H +#define UP_H + +#include "core/bpx.h" +#include "core/xmalloc.h" +#include "core/stuff.h" +#include "core/idtable.h" +#include "core/idtable_preds.h" +#include "core/error.h" + +#ifndef _MSC_VER +#include +#endif +#ifdef MALLOC_TRACE +#include +#endif + +/* core binary version */ +#define BINARY_VERSION "20070529" + +#define INIT_PROB_THRESHOLD 1e-9 +#define EPS 1e-12 + +#define NULL_TERM ((TERM)(0)) /* reference to null */ + +/* IEEE 64bit double: 4.94e-324 ... 
1.797e+308 (for positive) */ +#define HUGE_PROB 1.0e+280 +#define TINY_PROB 1.0e-300 + +/* Data structures for support graphs */ +typedef struct ExplGraphPath *EG_PATH_PTR; +struct ExplGraphPath { + int children_len; /* number of entries in children[] */ + int sws_len; /* number of entries in sws[] */ + struct ExplGraphNode **children; /* an array of pointers to children nodes */ + struct SwitchInstance **sws; /* an array of pointers to switches */ + double inside; /* Inside probability of this path */ + double max; /* Max probability of this path (for Viterbi) */ + struct ExplGraphPath *next; /* next path in a list */ +}; + +typedef struct ViterbiEntry *V_ENT_PTR; +struct ViterbiEntry { + int goal_id; + EG_PATH_PTR path_ptr; /* path for a node */ + int children_len; /* number of children in the path */ + int *top_n_index; /* indices of paths in the top-N lists for children */ + double max; /* max. prob of the path with the sub-paths indicated by top_n_index[] */ +}; + +typedef struct ExplGraphNode *EG_NODE_PTR; +struct ExplGraphNode { + int id; + double inside, outside; /* inside and outside probabilities */ + double max; /* max probabilities */ + EG_PATH_PTR max_path; /* pointer to the path with max prob. 
*/
+    V_ENT_PTR *top_n;                   /* top-N list (for top-N Viterbi) */
+    int top_n_len;                      /* size of top-N list (for top-N Viterbi) */
+    int shared;                         /* number of goals which call this subgoal */
+    EG_PATH_PTR path_ptr;               /* head of this node's list of paths (linked via ->next) */
+    double first_outside;
+    char has_first_outside;
+    char visited;                       /* flag: each node needs to occur at most once */
+};
+
+typedef struct ViterbiList *V_LIST_PTR;
+struct ViterbiList {
+    V_ENT_PTR entry;
+    V_LIST_PTR prev;
+    V_LIST_PTR next;
+};
+
+/* Data structures for switches (this data structure might have
+   a little bit redundancy due to `fixed' and `occ' flags) */
+typedef struct SwitchInstance *SW_INS_PTR;
+struct SwitchInstance {
+    int id;
+    char fixed;                 /* parameter is fixed or not */
+    char fixed_h;               /* hyperparameter is fixed or not */
+    char occ;                   /* occurring in the current expl graphs or not (temporarily used) */
+    double inside;              /* theta (parameter) in ML/MAP */
+    double inside_h;            /* alpha (hyperparameter) in VB */
+    double smooth;              /* pseudo count which equals alpha - 1.0 */
+    double smooth_prolog;       /* original pseudo count passed from the Prolog part */
+    double pi;
+    double best_inside;         /* best theta */
+    double best_inside_h;       /* best alpha */
+    double first_expectation;
+    char has_first_expectation;
+    double total_expect;        /* Sigma ru */
+    double best_total_expect;   /* best Sigma ru */
+    int count;                  /* number of occurrences in complete data */
+    SW_INS_PTR next;            /* connect next instance of the same switch */
+};
+
+typedef struct ObservedFactNode *ROOT;
+struct ObservedFactNode {
+    int id;
+    int count;                  /* number of occurrences */
+};
+
+#define CTRLC_PRESSED (toam_signal_vec & INTERRUPT)
+
+/* isfinite()/isnan() on non-C99-compliant compilers */
+
+#ifdef _MSC_VER
+#include <float.h>   /* declares _finite() and _isnan() */
+#define isfinite _finite
+#define isnan _isnan
+#endif
+
+#ifdef LINUX
+#ifndef isfinite
+#define isfinite finite
+#endif
+#endif
+
+#endif /* UP_H */
diff --git a/packages/prism/src/c/up/util.c b/packages/prism/src/c/up/util.c
new file mode 100644
index 000000000..93103afff
--- 
/dev/null
+++ b/packages/prism/src/c/up/util.c
@@ -0,0 +1,147 @@
+#include <stdarg.h>   /* va_list, va_start, va_end for prism_printf() */
+#include "bprolog.h"
+#include "up/up.h"
+#include "core/gamma.h"
+
+/*------------------------------------------------------------------------*/
+
+/* mic.c (B-Prolog) */
+int compare(TERM, TERM);
+
+/*------------------------------------------------------------------------*/
+
+/* printf-style output to curr_out, flushed after every call.
+ * Returns whatever vfprintf() returned.
+ */
+int prism_printf(const char *fmt, ...)
+{
+    va_list ap;
+    int rv;
+
+    va_start(ap, fmt);
+    rv = vfprintf(curr_out, fmt, ap);
+    va_end(ap);
+
+    fflush(curr_out);
+
+    return rv;
+}
+
+/*------------------------------------------------------------------------*/
+
+/* Succeeds (BP_TRUE) only when this file was compiled with MPI defined. */
+int pc_mp_mode_0(void)
+{
+#ifdef MPI
+    return BP_TRUE;
+#else
+    return BP_FALSE;
+#endif
+}
+
+/*------------------------------------------------------------------------*/
+
+/* qsort()-style comparator over elements of type SW_INS_PTR: the two
+ * switch instances are ordered by compare() applied to the first argument
+ * of the terms that prism_sw_ins_term() returns for their ids.
+ */
+int compare_sw_ins(const void *a, const void *b)
+{
+    SW_INS_PTR sw_ins_a, sw_ins_b;
+    TERM msw_a, msw_b;
+
+    sw_ins_a = *(const SW_INS_PTR *)(a);
+    sw_ins_b = *(const SW_INS_PTR *)(b);
+
+    msw_a = prism_sw_ins_term(sw_ins_a->id);
+    msw_b = prism_sw_ins_term(sw_ins_b->id);
+
+    return compare(bpx_get_arg(1,msw_a), bpx_get_arg(1,msw_b));
+}
+
+/*------------------------------------------------------------------------*/
+
+/* Recursively computes a depth measure of term t.
+ * NOTE(review): SWITCH_OP is a B-Prolog macro not shown here; the five
+ * brace blocks presumably correspond to the tag classes of t (the third
+ * handles list cells, the fourth structures) -- confirm against bprolog.h.
+ */
+int get_term_depth(TERM t)
+{
+    SYM_REC_PTR sym;
+    int i, n, d, di;
+
+    XDEREF(t);
+
+    SWITCH_OP(t, l_term_depth, { return 0; }, { return 0; }, {
+        if (IsNumberedVar(t)) return 0;
+
+        d = 0;
+        i = 0;
+
+        /* the i-th list element contributes its own depth plus i */
+        while (bpx_is_list(t)) {
+            di = get_term_depth(bpx_get_car(t)) + (++i);
+            d = d > di ? d : di;
+            t = bpx_get_cdr(t);
+        }
+
+        /* the tail left after the last cell counts with offset i */
+        di = get_term_depth(t) + i;
+        d = d > di ? d : di;
+
+        return d;
+    }, {
+        sym = GET_STR_SYM_REC(t);
+
+        if (sym == float_psc) return 0;
+
+        n = GET_ARITY_STR(sym);
+        d = 0;
+
+        for (i = 1; i <= n; i++) {
+            di = get_term_depth(bpx_get_arg(i, t));
+            d = d > di ? 
d : di;
+        }
+
+        return d + 1;
+    }, { return 0; });
+
+    return 0;  /* arbitrary */
+}
+
+/* Prolog-callable wrapper: unifies argument 2 with the depth of argument 1. */
+int pc_get_term_depth_2(void)
+{
+    return bpx_unify(bpx_build_integer(get_term_depth(bpx_get_call_arg(1,2))),
+                     bpx_get_call_arg(2,2));
+}
+
+/*------------------------------------------------------------------------*/
+
+/* Prolog-callable wrapper: unifies argument 2 with lngamma(argument 1). */
+int pc_lngamma_2(void)
+{
+    double x = bpx_get_float(bpx_get_call_arg(1,2));
+    TERM t = bpx_build_float(lngamma(x));
+
+    return bpx_unify(bpx_get_call_arg(2,2),t);
+}
+
+/*------------------------------------------------------------------------*/
+
+/* Calls mtrace() when built with MALLOC_TRACE; always succeeds. */
+int pc_mtrace_0(void)
+{
+#ifdef MALLOC_TRACE
+    mtrace();
+#endif
+    return BP_TRUE;
+}
+
+/* Calls muntrace() when built with MALLOC_TRACE; always succeeds. */
+int pc_muntrace_0(void)
+{
+#ifdef MALLOC_TRACE
+    muntrace();
+#endif
+    return BP_TRUE;
+}
+
+/*------------------------------------------------------------------------*/
+
+/* effective only for Linux and Mac OS X */
+void xsleep(unsigned int milliseconds)
+{
+#ifndef _MSC_VER
+    /* POSIX only requires usleep() to accept values below 1,000,000, and
+     * milliseconds * 1000 can wrap around in unsigned arithmetic.  Whole
+     * seconds therefore go to sleep() and only the remainder to usleep().
+     */
+    if (milliseconds >= 1000)
+        sleep(milliseconds / 1000);
+    usleep((milliseconds % 1000) * 1000);
+#endif
+}
+
+/* Prolog-callable wrapper: sleeps for the number of milliseconds given as
+ * argument 1; always succeeds.
+ */
+int pc_sleep_1(void)
+{
+    xsleep(bpx_get_integer(bpx_get_call_arg(1,1)));
+
+    return BP_TRUE;
+}
diff --git a/packages/prism/src/c/up/util.h b/packages/prism/src/c/up/util.h
new file mode 100644
index 000000000..7827744a6
--- /dev/null
+++ b/packages/prism/src/c/up/util.h
@@ -0,0 +1,23 @@
+#ifndef UTIL_H
+#define UTIL_H
+
+/*====================================================================*/
+
+int pc_mp_mode_0(void);
+int pc_get_term_depth_2(void);
+
+int prism_printf(const char *, ...);
+int compare_sw_ins(const void *, const void *);
+int get_term_depth(TERM);
+
+int pc_lngamma_2(void);
+
+int pc_mtrace_0(void);
+int pc_muntrace_0(void);
+
+void xsleep(unsigned int);
+int pc_sleep_1(void);
+
+/*====================================================================*/
+
+#endif /* UTIL_H */
diff --git a/packages/prism/src/c/up/viterbi.c b/packages/prism/src/c/up/viterbi.c
new file mode 100644
index 000000000..e9c66e86d
--- /dev/null
+++ b/packages/prism/src/c/up/viterbi.c
@@ -0,0 +1,1121 @@
+#include 
"core/gamma.h"
+#include "up/up.h"
+#include "up/graph.h"
+#include "up/graph_aux.h"
+#include "up/em_aux.h"
+#include "up/em_aux_vb.h"
+#include "up/flags.h"
+
+/*------------------------------------------------------------------------*/
+
+/* One candidate explanation kept while reranking. */
+typedef struct ViterbiRankEntry *V_RANK_PTR;
+struct ViterbiRankEntry {
+    int size;           /* number of entries in expl[] */
+    V_ENT_PTR *expl;    /* flattened explanation (an n_viterbi_egraphs array) */
+    double score;       /* value returned by compute_rerank_score() */
+};
+
+/*------------------------------------------------------------------------*/
+
+/* mic.c (B-Prolog) */
+NORET quit(const char *);
+NORET myquit(int, const char *);
+
+/*------------------------------------------------------------------------*/
+
+/* Flattened single-best Viterbi path and its capacity / used length. */
+static EG_NODE_PTR * viterbi_egraphs = NULL;
+static int max_viterbi_egraph_size;
+static int viterbi_egraph_size;
+
+/* Work queue and sorted top-N list used by compute_n_max(). */
+static V_LIST_PTR queue_first;
+static V_LIST_PTR queue_last;
+static int queue_len;
+static V_LIST_PTR top_n_first;
+static V_LIST_PTR top_n_last;
+static int top_n_len;
+/* Flattened top-N Viterbi path and its capacity / used length. */
+static V_ENT_PTR * n_viterbi_egraphs = NULL;
+static int max_n_viterbi_egraph_size;
+static int n_viterbi_egraph_size;
+
+static V_RANK_PTR viterbi_rank = NULL;
+
+/* Viterbi works on only one explanation graph */
+/* For every node of sorted_expl_graph[], stores in ->max the best score over
+ * its paths and in ->max_path the path achieving it; each path's own score
+ * is stored in path->max.  A path score combines the ->max of its children
+ * with the ->inside of its switches (sum of logs when log_scale is set,
+ * product otherwise).
+ */
+void compute_max(void)
+{
+    int i,k;
+    double max_p,this_path_max;
+    EG_PATH_PTR max_path = NULL;
+    EG_NODE_PTR eg_ptr;
+    EG_PATH_PTR path_ptr;
+
+    if (log_scale) {
+        for (i = 0; i < sorted_egraph_size; i++) {
+            max_p = 1.0;      /* any positive value is possible */
+            eg_ptr = sorted_expl_graph[i];
+            path_ptr = eg_ptr->path_ptr;
+
+            /* path_ptr should not be NULL; but it happens */
+            if (path_ptr == NULL) {
+                max_p = 0.0;  /* log-scale */
+                max_path = NULL;
+            }
+
+            /* [Note] we perform probability computations in log-scale */
+            while (path_ptr != NULL) {
+                this_path_max = 0.0;
+                for (k = 0; k < path_ptr->children_len; k++) {
+                    this_path_max += path_ptr->children[k]->max;
+                }
+                for (k = 0; k < path_ptr->sws_len; k++) {
+                    this_path_max += log(path_ptr->sws[k]->inside);
+                }
+                path_ptr->max = this_path_max;
+
+                /* max_p > 0 means "no path seen yet" (positive sentinel);
+                 * with <= a later path wins a tie */
+                if (max_p > 0 || max_p <= this_path_max) {
+                    max_p = 
this_path_max;
+                    max_path = path_ptr;
+                }
+
+                path_ptr = path_ptr->next;
+            }
+
+            sorted_expl_graph[i]->max = max_p;
+            sorted_expl_graph[i]->max_path = max_path;
+        }
+    }
+    else {
+        for (i = 0; i < sorted_egraph_size; i++) {
+            max_p = 0.0;
+            /* Forget the previous node's choice: otherwise a node whose
+             * paths all have probability 0 would inherit a path pointer
+             * that belongs to another node. */
+            max_path = NULL;
+            eg_ptr = sorted_expl_graph[i];
+            path_ptr = eg_ptr->path_ptr;
+
+            /* path_ptr should not be NULL; but it happens */
+            if (path_ptr == NULL) {
+                max_p = 1.0;
+                max_path = NULL;
+            }
+
+            while (path_ptr != NULL) {
+                this_path_max = 1.0;
+                for (k = 0; k < path_ptr->children_len; k++) {
+                    this_path_max *= path_ptr->children[k]->max;
+                }
+                for (k = 0; k < path_ptr->sws_len; k++) {
+                    this_path_max *= path_ptr->sws[k]->inside;
+                }
+                path_ptr->max = this_path_max;
+
+                /* The first path is always taken, so a node that has paths
+                 * always gets one (as in the log-scale branch); later paths
+                 * replace it only when strictly better. */
+                if (max_path == NULL || this_path_max > max_p) {
+                    max_p = this_path_max;
+                    max_path = path_ptr;
+                }
+
+                path_ptr = path_ptr->next;
+            }
+
+            sorted_expl_graph[i]->max = max_p;
+            sorted_expl_graph[i]->max_path = max_path;
+        }
+    }
+
+}
+
+/* Frees every node still linked from queue_first and resets the queue.
+ * The entries the nodes point to are not freed here.
+ */
+static void clean_queue(void)
+{
+    V_LIST_PTR ptr,next_ptr;
+
+    ptr = queue_first;
+    while (ptr != NULL) {
+        next_ptr = ptr->next;
+        free(ptr);
+        ptr = next_ptr;
+    }
+    queue_first = queue_last = NULL;
+    queue_len = 0;
+}
+
+/* Frees every node of the top-N list (header included) and resets it.
+ * The entries the nodes point to are not freed here.
+ */
+static void clean_top_n(void)
+{
+    V_LIST_PTR ptr,next_ptr;
+
+    ptr = top_n_first;
+    while (ptr != NULL) {
+        next_ptr = ptr->next;
+        free(ptr);
+        ptr = next_ptr;
+    }
+    top_n_first = top_n_last = NULL;
+    top_n_len = 0;
+}
+
+/* Fills ->top_n / ->top_n_len of every node in sorted_expl_graph[] with up
+ * to n ViterbiEntry records ordered by decreasing ->max.  Nodes without
+ * paths get top_n == NULL and top_n_len == 0.
+ */
+void compute_n_max(int n)
+{
+    int i,k,j,m;
+    EG_NODE_PTR eg_ptr;
+    EG_PATH_PTR path_ptr;
+    V_LIST_PTR queue_ptr;
+    V_LIST_PTR top_n_ptr,next_top_n_ptr,new_top_n_ptr,old_top_n_last;
+    V_ENT_PTR v_ent;
+    V_ENT_PTR v_ent_next;
+    double p;
+    int inserted;
+    int old_mth_index,new_mth_index;
+    EG_NODE_PTR mth_child;
+
+    /* Pass 1: reset every node and allocate an empty top-N array */
+    for (i = 0; i < sorted_egraph_size; i++) {
+        eg_ptr = sorted_expl_graph[i];
+        eg_ptr->inside = -1.0;
+        eg_ptr->outside = -1.0;
+
+        if (eg_ptr->path_ptr != NULL) {
+            eg_ptr->top_n = (V_ENT_PTR *)MALLOC(sizeof(V_ENT_PTR) * n);
+            for (j = 0; j < n; j++)
+                eg_ptr->top_n[j] = NULL;
+        }
+        else {
+            eg_ptr->top_n = NULL;
+        }
+        eg_ptr->top_n_len = 0;
+    }
+
+    /* Pass 2: build the top-N list of each node in sorted order */
+    for (i = 0; i < sorted_egraph_size; i++) {
+        eg_ptr = sorted_expl_graph[i];
+
+        queue_len = 0;
+        queue_first = queue_last = NULL;
+
+        path_ptr = eg_ptr->path_ptr;
+
+        if (path_ptr == NULL) continue;
+
+        /* Constructing the initial queue: */
+        while (path_ptr != NULL) {
+
+            /* Create an entry which is the most probable for the path */
+            v_ent = (V_ENT_PTR)MALLOC(sizeof(struct ViterbiEntry));
+            v_ent->goal_id = eg_ptr->id;
+            v_ent->path_ptr = path_ptr;
+            v_ent->children_len = path_ptr->children_len;
+            v_ent->top_n_index =
+                (int *)MALLOC(sizeof(int) * path_ptr->children_len);
+
+            /* index 0 selects the best entry of every child */
+            for (k = 0; k < path_ptr->children_len; k++) {
+                v_ent->top_n_index[k] = 0;
+            }
+            /* children whose top_n is NULL (no paths) are left out of the score */
+            if (log_scale) {
+                p = 0.0;
+                for (k = 0; k < path_ptr->children_len; k++) {
+                    if (path_ptr->children[k]->top_n != NULL)
+                        p += path_ptr->children[k]->top_n[0]->max;
+                }
+                for (k = 0; k < path_ptr->sws_len; k++) {
+                    p += log(path_ptr->sws[k]->inside);
+                }
+            }
+            else {
+                p = 1.0;
+                for (k = 0; k < path_ptr->children_len; k++) {
+                    if (path_ptr->children[k]->top_n != NULL)
+                        p *= path_ptr->children[k]->top_n[0]->max;
+                }
+                for (k = 0; k < path_ptr->sws_len; k++) {
+                    p *= path_ptr->sws[k]->inside;
+                }
+            }
+            v_ent->max = p;
+
+            /* Enqueue the entry */
+            queue_ptr = (V_LIST_PTR)MALLOC(sizeof(struct ViterbiList));
+            queue_ptr->entry = v_ent;
+            queue_ptr->prev = NULL;  /* Never use for the queue */
+            queue_ptr->next = NULL;
+            if (queue_first == NULL) {
+                queue_first = queue_last = queue_ptr;
+                queue_len = 1;
+            }
+            else {
+                queue_last->next = queue_ptr;
+                queue_last = queue_ptr;
+                queue_len++;
+            }
+
+            path_ptr = path_ptr->next;
+        }
+
+        /* Create the header of top-N list */
+        top_n_first = (V_LIST_PTR)MALLOC(sizeof(struct ViterbiList));
+        top_n_first->entry = NULL;  /* null entry */
+        top_n_first->prev = NULL;
+        top_n_first->next = NULL;
+        top_n_last = top_n_first;
+        top_n_len = 0;
+
+        /* Process the queue in FIFO order until it is empty */
+        while (queue_len > 0) {
+            /* Dequeue */
+            v_ent = queue_first->entry;
+            queue_ptr = queue_first;
+
            queue_first = queue_ptr->next;
+            free(queue_ptr);
+            queue_len--;
+
+            /** Add the element to the top-N list **/
+            /* The list is kept in decreasing order of ->max; the new entry
+             * goes in front of the first strictly smaller one. */
+            top_n_ptr = top_n_first;
+            next_top_n_ptr = top_n_first->next;
+            inserted = 0;
+            while (next_top_n_ptr != NULL) { /* compare the current entry with the ones in the top-N list */
+                if (v_ent->max > next_top_n_ptr->entry->max) {
+                    new_top_n_ptr =
+                        (V_LIST_PTR)MALLOC(sizeof(struct ViterbiList));
+
+                    new_top_n_ptr->entry = v_ent;
+                    new_top_n_ptr->prev = top_n_ptr;
+                    new_top_n_ptr->next = next_top_n_ptr;
+
+                    next_top_n_ptr->prev = new_top_n_ptr;
+                    top_n_ptr->next = new_top_n_ptr;
+                    top_n_len++;
+                    inserted = 1;
+                    break;
+                }
+                top_n_ptr = next_top_n_ptr;
+                next_top_n_ptr = next_top_n_ptr->next;
+            }
+
+            if (top_n_len < n) {
+                /* room left: an entry that beat nobody is appended at the tail */
+                if (!inserted) {
+                    new_top_n_ptr =
+                        (V_LIST_PTR)MALLOC(sizeof(struct ViterbiList));
+                    new_top_n_ptr->entry = v_ent;
+                    new_top_n_ptr->prev = top_n_ptr;
+                    new_top_n_ptr->next = NULL;
+
+                    top_n_ptr->next = new_top_n_ptr;
+                    top_n_last = new_top_n_ptr;
+                    top_n_len++;
+                    inserted = 1;
+                }
+            }
+            else if (top_n_len == n) {
+                /* list exactly full: an entry that was not inserted is dropped */
+                if (!inserted) {
+                    /* Erase the current entry */
+                    free(v_ent->top_n_index);
+                    free(v_ent);
+                    v_ent = NULL;
+                }
+            }
+            else { /* top_n_len > n */
+                if (!inserted) {
+                    /* Erase the current entry */
+                    free(v_ent->top_n_index);
+                    free(v_ent);
+                    v_ent = NULL;
+                }
+                else {
+                    /* Erase the last entry */
+                    /* (the insertion above pushed the list to n + 1 entries) */
+                    old_top_n_last = top_n_last;
+                    top_n_last = top_n_last->prev;
+                    top_n_last->next = NULL;
+                    free(old_top_n_last->entry->top_n_index);
+                    free(old_top_n_last->entry);
+                    free(old_top_n_last);
+                    top_n_len--;
+                }
+            }
+
+            /* If the current entry is not added to the top-N list, there is no
+             * need to pursue the entries that have lower probabilities than
+             * the current entry's probability.
+             */
+            if (!inserted) continue;
+
+            /* Otherwise, propose the further entries based on the current entry */
+            /* One successor per child m: same path, with child m moved to its
+             * next-best entry.
+             * NOTE(review): the same index vector can be reached from two
+             * different parents and no de-duplication is visible here --
+             * confirm whether duplicates in the top-N list are intended. */
+            for (m = 0; m < v_ent->children_len; m++) {
+
+                old_mth_index = v_ent->top_n_index[m];
+                new_mth_index = v_ent->top_n_index[m] + 1;
+                mth_child = v_ent->path_ptr->children[m];
+
+                /* child m has no further entry (also covers top_n == NULL,
+                 * for which top_n_len is 0) */
+                if (new_mth_index >= mth_child->top_n_len)
+                    continue;
+
+                v_ent_next = (V_ENT_PTR)MALLOC(sizeof(struct ViterbiEntry));
+                v_ent_next->goal_id = v_ent->goal_id;
+                v_ent_next->path_ptr = v_ent->path_ptr;
+                v_ent_next->children_len = v_ent->children_len;
+                v_ent_next->top_n_index =
+                    (int *)MALLOC(sizeof(int) * v_ent_next->children_len);
+
+                for (k = 0; k < v_ent_next->children_len; k++) {
+                    v_ent_next->top_n_index[k] =
+                        (k == m) ?
+                        (v_ent->top_n_index[k] + 1) : v_ent->top_n_index[k];
+                }
+
+                /* The score is updated incrementally: take out the old
+                 * child score and put in the new one.  When both scores are
+                 * equal the result is the parent's score; handling that case
+                 * first also avoids 0/0 (linear) and -inf - -inf (log-scale),
+                 * which would yield NaN for zero-probability children. */
+                if (mth_child->top_n[new_mth_index]->max ==
+                    mth_child->top_n[old_mth_index]->max) {
+                    v_ent_next->max = v_ent->max;
+                }
+                else if (log_scale) {
+                    v_ent_next->max =
+                        v_ent->max
+                        - mth_child->top_n[old_mth_index]->max
+                        + mth_child->top_n[new_mth_index]->max;
+                }
+                else {
+                    v_ent_next->max =
+                        v_ent->max
+                        * mth_child->top_n[new_mth_index]->max
+                        / mth_child->top_n[old_mth_index]->max;
+                }
+
+                /* Enqueue the derived entries */
+                queue_ptr = (V_LIST_PTR)MALLOC(sizeof(struct ViterbiList));
+                queue_ptr->entry = v_ent_next;
+                queue_ptr->prev = NULL;  /* Never use for the queue */
+                queue_ptr->next = NULL;
+                if (queue_first == NULL) {
+                    queue_first = queue_last = queue_ptr;
+                    queue_len = 1;
+                }
+                else {
+                    queue_last->next = queue_ptr;
+                    queue_last = queue_ptr;
+                    queue_len++;
+                }
+            }
+        }
+
+        /* Copy the finished list (header skipped) into the node's array */
+        j = 0;
+        top_n_ptr = top_n_first->next;
+        while (top_n_ptr != NULL) {
+            if (eg_ptr->top_n != NULL)
+                eg_ptr->top_n[j] = top_n_ptr->entry;  /* shallow copy */
+            j++;
+            top_n_ptr = top_n_ptr->next;
+        }
+        eg_ptr->top_n_len = j;
+
+        clean_queue();
+        clean_top_n();
+    }
+}
+
+/* Allocates viterbi_egraphs[] with sorted_egraph_size NULL slots and resets
+ * viterbi_egraph_size to 0.
+ */
+static void alloc_viterbi_egraphs(void)
+{
+    int i;
+
+    /* [Note] The size of viterbi path can exceed the number of subgoals in the
+     * explanation graph.  we will expand the array size on demand.
+     */
+    viterbi_egraph_size = 0;
+    max_viterbi_egraph_size = sorted_egraph_size;
+    viterbi_egraphs =
+        (EG_NODE_PTR *)MALLOC(max_viterbi_egraph_size * sizeof(EG_NODE_PTR));
+
+    /* Initialize to extra Ids */
+    for (i = 0; i < max_viterbi_egraph_size; i++)
+        viterbi_egraphs[i] = NULL;
+}
+
+/* Grows viterbi_egraphs[] by doubling until it has at least
+ * req_viterbi_egraph_size slots; new slots are set to NULL.
+ */
+static void expand_viterbi_egraphs(int req_viterbi_egraph_size)
+{
+    int old_size,i;
+
+    if (req_viterbi_egraph_size > max_viterbi_egraph_size) {
+        old_size = max_viterbi_egraph_size;
+
+        while (req_viterbi_egraph_size > max_viterbi_egraph_size) {
+            /* doubling 0 stays 0 forever, so start from 1 in that case */
+            max_viterbi_egraph_size =
+                (max_viterbi_egraph_size > 0) ?
+                (max_viterbi_egraph_size * 2) : 1;
+        }
+
+        viterbi_egraphs =
+            (EG_NODE_PTR *)
+            REALLOC(viterbi_egraphs,
+                    max_viterbi_egraph_size * sizeof(EG_NODE_PTR));
+
+        for (i = old_size; i < max_viterbi_egraph_size; i++) {
+            viterbi_egraphs[i] = NULL;
+        }
+    }
+}
+
+/* Allocates n_viterbi_egraphs[] with sorted_egraph_size NULL slots and
+ * resets n_viterbi_egraph_size to 0.
+ */
+static void alloc_n_viterbi_egraphs(void)
+{
+    int i;
+
+    n_viterbi_egraph_size = 0;
+    max_n_viterbi_egraph_size = sorted_egraph_size;
+    n_viterbi_egraphs =
+        (V_ENT_PTR *)MALLOC(max_n_viterbi_egraph_size * sizeof(V_ENT_PTR));
+
+    for (i = 0; i < max_n_viterbi_egraph_size; i++) {
+        n_viterbi_egraphs[i] = NULL;
+    }
+}
+
+/* Grows n_viterbi_egraphs[] by doubling until it has at least
+ * req_n_viterbi_egraph_size slots; new slots are set to NULL.
+ */
+static void expand_n_viterbi_egraphs(int req_n_viterbi_egraph_size)
+{
+    int old_size,i;
+
+    if (req_n_viterbi_egraph_size > max_n_viterbi_egraph_size) {
+        old_size = max_n_viterbi_egraph_size;
+
+        while (req_n_viterbi_egraph_size > max_n_viterbi_egraph_size) {
+            /* doubling 0 stays 0 forever, so start from 1 in that case */
+            max_n_viterbi_egraph_size =
+                (max_n_viterbi_egraph_size > 0) ?
+                (max_n_viterbi_egraph_size * 2) : 1;
+        }
+
+        n_viterbi_egraphs =
+            (V_ENT_PTR *)REALLOC(n_viterbi_egraphs,
+                                 max_n_viterbi_egraph_size * sizeof(V_ENT_PTR));
+
+        for (i = old_size; i < max_n_viterbi_egraph_size; i++) {
+            n_viterbi_egraphs[i] = NULL;
+        }
+    }
+}
+
+/* This function returns the last index of the current path */
+/* (more precisely: the index following the last slot it has written).
+ * Stores eg_ptr at viterbi_egraphs[start_vindex], then visits the children
+ * of eg_ptr->max_path depth-first, left to right.
+ */
+static int visit_most_likely_path(EG_NODE_PTR eg_ptr,
+                                  int start_vindex)
+{
+    int k;
+    int curr_vindex;
+    EG_PATH_PTR max_path;
+
+    curr_vindex = start_vindex;
+
+    if (curr_vindex >= max_viterbi_egraph_size)
+        expand_viterbi_egraphs(curr_vindex + 1);
+
+    if (curr_vindex >= viterbi_egraph_size)
+
        viterbi_egraph_size = curr_vindex + 1;
+
+    viterbi_egraphs[curr_vindex] = eg_ptr;
+    curr_vindex++;
+
+    max_path = eg_ptr->max_path;
+
+    /* a node without a chosen path ends this branch */
+    if (max_path == NULL) return curr_vindex;
+
+    for (k = 0; k < max_path->children_len; k++) {
+        if (max_path->children == NULL) quit("Internal error: visit_most_likely_path\n");
+        curr_vindex =
+            visit_most_likely_path(max_path->children[k],curr_vindex);
+    }
+
+    return curr_vindex;
+}
+
+/* Builds the Prolog view of the single best explanation of goal goal_id:
+ *   *p_goal_path_ptr    - list of node ids in visiting order
+ *   *p_subpath_goal_ptr - per node, the list of child ids on its max_path
+ *   *p_subpath_sw_ptr   - per node, the list of switch ids on its max_path
+ *   *viterbi_prob_ptr   - expl_graph[goal_id]->max
+ * Expects compute_max() to have filled ->max and ->max_path beforehand.
+ */
+static void get_most_likely_path(int goal_id,
+                                 TERM *p_goal_path_ptr,
+                                 TERM *p_subpath_goal_ptr,
+                                 TERM *p_subpath_sw_ptr,
+                                 double *viterbi_prob_ptr)
+{
+    TERM p_goal_path;
+    TERM p_subpath_goal, p_subpath_sw;
+    TERM p_tmp, p_tmp_g, p_tmp_g0, p_tmp_g1, p_tmp_sw, p_tmp_sw0, p_tmp_sw1;
+    int m,k;
+    EG_NODE_PTR eg_ptr = NULL;
+    EG_PATH_PTR path_ptr = NULL;
+    int viterbi_egraph_size;  /* shadows the file-scope variable of the same name */
+    int c_len, sw_len;
+
+    alloc_viterbi_egraphs();
+
+    viterbi_egraph_size = visit_most_likely_path(expl_graph[goal_id],0);
+
+    /* Build the Viterbi path as a Prolog list: */
+    /* each cell is created open and closed with nil only after the last item */
+    p_goal_path = bpx_build_list();
+    p_tmp = p_goal_path;
+    for (m = 0; m < viterbi_egraph_size; m++) {
+        bpx_unify(bpx_get_car(p_tmp),bpx_build_integer(viterbi_egraphs[m]->id));
+        if (m == viterbi_egraph_size - 1) {
+            bpx_unify(bpx_get_cdr(p_tmp),bpx_build_nil());
+        }
+        else {
+            bpx_unify(bpx_get_cdr(p_tmp),bpx_build_list());
+            p_tmp = bpx_get_cdr(p_tmp);
+        }
+    }
+
+    p_subpath_goal = bpx_build_list();
+    p_subpath_sw = bpx_build_list();
+
+    p_tmp_g = p_subpath_goal;
+    p_tmp_sw = p_subpath_sw;
+
+    /* one element per visited node, in the same order as p_goal_path */
+    for (m = 0; m < viterbi_egraph_size; m++) {
+        eg_ptr = viterbi_egraphs[m];
+
+        if (eg_ptr->max_path == NULL) {
+            p_tmp_g0 = bpx_build_nil();
+            p_tmp_sw0 = bpx_build_nil();
+        }
+        else {
+            path_ptr = eg_ptr->max_path;
+            c_len = path_ptr->children_len;
+            sw_len = path_ptr->sws_len;
+
+            if (c_len == 0) {
+                p_tmp_g0 = bpx_build_nil();
+            }
+            else {
+                p_tmp_g0 = bpx_build_list();
+                p_tmp_g1 = p_tmp_g0;
+                for (k = 0; k < c_len; k++) {
+                    bpx_unify(bpx_get_car(p_tmp_g1),
+                              bpx_build_integer(path_ptr->children[k]->id));
+                    if (k == c_len - 1) {
+                        bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_nil());
+                    }
+                    else {
+                        bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_list());
+                        p_tmp_g1 = bpx_get_cdr(p_tmp_g1);
+                    }
+                }
+            }
+
+            if (sw_len == 0) {
+                p_tmp_sw0 = bpx_build_nil();
+            }
+            else {
+                p_tmp_sw0 = bpx_build_list();
+                p_tmp_sw1 = p_tmp_sw0;
+                for (k = 0; k < sw_len; k++) {
+                    bpx_unify(bpx_get_car(p_tmp_sw1),bpx_build_integer(path_ptr->sws[k]->id));
+                    if (k == sw_len - 1) {
+                        bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_nil());
+                    }
+                    else {
+                        bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_list());
+                        p_tmp_sw1 = bpx_get_cdr(p_tmp_sw1);
+                    }
+                }
+            }
+        }
+
+        /* attach this node's child list and switch list to the outer lists */
+        bpx_unify(bpx_get_car(p_tmp_g),p_tmp_g0);
+        bpx_unify(bpx_get_car(p_tmp_sw),p_tmp_sw0);
+
+        if (m == viterbi_egraph_size - 1) {
+            bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_nil());
+            bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_nil());
+        }
+        else {
+            bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_list());
+            bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_list());
+            p_tmp_g = bpx_get_cdr(p_tmp_g);
+            p_tmp_sw = bpx_get_cdr(p_tmp_sw);
+        }
+    }
+
+    /* the flattened path is only needed while the lists are being built */
+    free(viterbi_egraphs);
+    viterbi_egraphs = NULL;
+
+    *p_goal_path_ptr = p_goal_path;
+    *p_subpath_goal_ptr = p_subpath_goal;
+    *p_subpath_sw_ptr = p_subpath_sw;
+    *viterbi_prob_ptr = expl_graph[goal_id]->max;  /* top goal's max prob */
+}
+
+/* This function returns the last index of the current path */
+/* (more precisely: the index following the last slot it has written).
+ * Top-N counterpart of visit_most_likely_path(): stores v_ent at
+ * n_viterbi_egraphs[start_vindex], then handles each child of
+ * v_ent->path_ptr.  A child without a top-N list gets a freshly allocated
+ * placeholder entry (path_ptr == NULL); any other child is visited
+ * recursively through the entry selected by v_ent->top_n_index[k].
+ */
+static int visit_n_most_likely_path(V_ENT_PTR v_ent, int start_vindex)
+{
+    int k,j;
+    int curr_vindex;
+    V_ENT_PTR new_v_ent = NULL;
+
+    curr_vindex = start_vindex;
+
+    if (curr_vindex >= max_n_viterbi_egraph_size)
+        expand_n_viterbi_egraphs(curr_vindex + 1);
+
+    if (curr_vindex >= n_viterbi_egraph_size)
+        n_viterbi_egraph_size = curr_vindex + 1;
+
+    n_viterbi_egraphs[curr_vindex] = v_ent;
+    curr_vindex++;
+
+    for (k = 0; k < v_ent->children_len; k++) {
+        if (v_ent->path_ptr->children == NULL)
+            quit("Internal error: visit_n_most_likely_path\n");
+
+        if (v_ent->path_ptr->children[k]->top_n == NULL) {
+            new_v_ent = 
(V_ENT_PTR)MALLOC(sizeof(struct ViterbiEntry));
+            /* placeholder: only goal_id and path_ptr are initialised; the
+             * NULL path_ptr is what marks it as a placeholder */
+            new_v_ent->goal_id = v_ent->path_ptr->children[k]->id;
+            new_v_ent->path_ptr = NULL;
+
+            if (curr_vindex >= max_n_viterbi_egraph_size)
+                expand_n_viterbi_egraphs(curr_vindex + 1);
+
+            if (curr_vindex >= n_viterbi_egraph_size)
+                n_viterbi_egraph_size = curr_vindex + 1;
+
+            n_viterbi_egraphs[curr_vindex] = new_v_ent;
+            curr_vindex++;
+        }
+        else {
+            j = v_ent->top_n_index[k];
+            curr_vindex =
+                visit_n_most_likely_path(v_ent->path_ptr->children[k]->top_n[j],
+                                         curr_vindex);
+        }
+    }
+
+    return curr_vindex;
+}
+
+/* Builds the Prolog list of the (up to) n best explanations of goal goal_id
+ * and stores it in *p_n_viterbi_list_ptr.  Each element is a term
+ *   v_expl(J, GoalPath, SubpathGoals, SubpathSwitches, Prob)
+ * where J is the 0-based position in the goal's top-N list.  Expects
+ * compute_n_max() to have filled the top-N lists.  The result is the empty
+ * list when there is nothing to report.
+ */
+static void get_n_most_likely_path(int n, int goal_id,
+                                   TERM *p_n_viterbi_list_ptr)
+{
+    TERM p_goal_path;
+    TERM p_subpath_goal, p_subpath_sw;
+    TERM p_tmp, p_tmp_g, p_tmp_g0, p_tmp_g1, p_tmp_sw, p_tmp_sw0, p_tmp_sw1;
+    TERM p_n_viterbi, p_n_viterbi_list, p_tmp_list;
+    TERM p_viterbi_prob;
+    int j,m,k;
+    EG_PATH_PTR path_ptr = NULL;
+    int c_len, sw_len;
+    V_ENT_PTR v_ent;
+
+    /* A goal without paths has top_n == NULL, which the loop below would
+     * dereference.  With no entry at all the list cell built below would
+     * also never be closed, so return [] in these cases. */
+    if (n <= 0 ||
+        expl_graph[goal_id]->top_n == NULL ||
+        expl_graph[goal_id]->top_n[0] == NULL) {
+        *p_n_viterbi_list_ptr = bpx_build_nil();
+        return;
+    }
+
+    p_n_viterbi_list = bpx_build_list();
+    p_tmp_list = p_n_viterbi_list;
+
+    for (j = 0; j < n; j++) {
+
+        if (expl_graph[goal_id]->top_n[j] == NULL) continue;
+
+        alloc_n_viterbi_egraphs();
+
+        n_viterbi_egraph_size =
+            visit_n_most_likely_path(expl_graph[goal_id]->top_n[j],0);
+
+        /* Build the Viterbi path as a Prolog list: */
+        p_goal_path = bpx_build_list();
+        p_tmp = p_goal_path;
+        for (m = 0; m < n_viterbi_egraph_size; m++) {
+            bpx_unify(bpx_get_car(p_tmp),bpx_build_integer(n_viterbi_egraphs[m]->goal_id));
+            if (m == n_viterbi_egraph_size - 1) {
+                bpx_unify(bpx_get_cdr(p_tmp),bpx_build_nil());
+            }
+            else {
+                bpx_unify(bpx_get_cdr(p_tmp),bpx_build_list());
+                p_tmp = bpx_get_cdr(p_tmp);
+            }
+        }
+
+        p_subpath_goal = bpx_build_list();
+        p_subpath_sw = bpx_build_list();
+
+        p_tmp_g = p_subpath_goal;
+        p_tmp_sw = p_subpath_sw;
+
+        for (m = 0; m < n_viterbi_egraph_size; m++) {
+            v_ent = n_viterbi_egraphs[m];
+
+            /* placeholder entries contribute empty sub-lists */
+            if (v_ent->path_ptr == NULL) {
+                p_tmp_g0 = bpx_build_nil();
+                p_tmp_sw0 = bpx_build_nil();
+            }
+            else {
+                path_ptr = 
v_ent->path_ptr;
+                c_len = path_ptr->children_len;
+                sw_len = path_ptr->sws_len;
+
+                /* list of the ids of this path's children */
+                if (c_len == 0) {
+                    p_tmp_g0 = bpx_build_nil();
+                }
+                else {
+                    p_tmp_g0 = bpx_build_list();
+                    p_tmp_g1 = p_tmp_g0;
+                    for (k = 0; k < c_len; k++) {
+                        bpx_unify(bpx_get_car(p_tmp_g1),bpx_build_integer(path_ptr->children[k]->id));
+                        if (k == c_len - 1) {
+                            bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_nil());
+                        }
+                        else {
+                            bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_list());
+                            p_tmp_g1 = bpx_get_cdr(p_tmp_g1);
+                        }
+                    }
+                }
+
+                /* list of the ids of this path's switches */
+                if (sw_len == 0) {
+                    p_tmp_sw0 = bpx_build_nil();
+                }
+                else {
+                    p_tmp_sw0 = bpx_build_list();
+                    p_tmp_sw1 = p_tmp_sw0;
+                    for (k = 0; k < sw_len; k++) {
+                        bpx_unify(bpx_get_car(p_tmp_sw1),bpx_build_integer(path_ptr->sws[k]->id));
+                        if (k == sw_len - 1) {
+                            bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_nil());
+                        }
+                        else {
+                            bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_list());
+                            p_tmp_sw1 = bpx_get_cdr(p_tmp_sw1);
+                        }
+                    }
+                }
+            }
+
+            bpx_unify(bpx_get_car(p_tmp_g),p_tmp_g0);
+            bpx_unify(bpx_get_car(p_tmp_sw),p_tmp_sw0);
+
+            if (m == n_viterbi_egraph_size - 1) {
+                bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_nil());
+                bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_nil());
+            }
+            else {
+                bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_list());
+                bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_list());
+                p_tmp_g = bpx_get_cdr(p_tmp_g);
+                p_tmp_sw = bpx_get_cdr(p_tmp_sw);
+            }
+        }
+
+        p_viterbi_prob = bpx_build_float(expl_graph[goal_id]->top_n[j]->max);
+
+        /* v_expl(J, GoalPath, SubpathGoals, SubpathSwitches, Prob) */
+        p_n_viterbi = bpx_build_structure("v_expl",5);
+        bpx_unify(bpx_get_arg(1,p_n_viterbi),bpx_build_integer(j));
+        bpx_unify(bpx_get_arg(2,p_n_viterbi),p_goal_path);
+        bpx_unify(bpx_get_arg(3,p_n_viterbi),p_subpath_goal);
+        bpx_unify(bpx_get_arg(4,p_n_viterbi),p_subpath_sw);
+        bpx_unify(bpx_get_arg(5,p_n_viterbi),p_viterbi_prob);
+
+        bpx_unify(bpx_get_car(p_tmp_list),p_n_viterbi);
+
+        /* close the result list after the n-th element or when the next
+         * top-N slot is empty */
+        if (j == n - 1 ||
+            (j < n - 1 && expl_graph[goal_id]->top_n[j + 1] == NULL)) {
+            bpx_unify(bpx_get_cdr(p_tmp_list),bpx_build_nil());
+        }
+        else {
+
            bpx_unify(bpx_get_cdr(p_tmp_list),bpx_build_list());
+            p_tmp_list = bpx_get_cdr(p_tmp_list);
+        }
+
+        for (m = 0; m < n_viterbi_egraph_size; m++) {
+            /* Release the entries newly added in visit_n_most_likely_path() */
+            if (n_viterbi_egraphs[m]->path_ptr == NULL) {
+                free(n_viterbi_egraphs[m]);
+            }
+        }
+
+        free(n_viterbi_egraphs);
+        n_viterbi_egraphs = NULL;
+    }
+
+    *p_n_viterbi_list_ptr = p_n_viterbi_list;
+}
+
+/* Scores the explanation currently held in n_viterbi_egraphs[]
+ * (n_viterbi_egraph_size entries).  The switch occurrence counts are reset
+ * and recounted from the explanation; then, for every switch i,
+ *   lngamma(sum alpha) - lngamma(sum (count + alpha))
+ *   + sum over instances of (lngamma(count + alpha) - lngamma(alpha))
+ * is accumulated, alpha being the instance's inside_h.
+ * Side effect: overwrites ->count of every instance linked from
+ * occ_switches[].
+ */
+static double compute_rerank_score(void)
+{
+    int i,s;
+    V_ENT_PTR v_ent;
+    EG_PATH_PTR path_ptr = NULL;
+    int k;
+    SW_INS_PTR sw_ins_ptr;
+    double score = 0.0;
+    double alpha_sum0,alpha_sum1;
+
+    /* reset the occurrence counts */
+    for (i = 0; i < occ_switch_tab_size; i++) {
+        sw_ins_ptr = occ_switches[i];
+        while (sw_ins_ptr != NULL) {
+            sw_ins_ptr->count = 0;
+            sw_ins_ptr = sw_ins_ptr->next;
+        }
+    }
+
+    /* count the switch instances used by the explanation */
+    for (s = 0; s < n_viterbi_egraph_size; s++) {
+        v_ent = n_viterbi_egraphs[s];
+        path_ptr = v_ent->path_ptr;
+
+        /* placeholder entries carry no path */
+        if (path_ptr == NULL) continue;
+
+        for (k = 0; k < path_ptr->sws_len; k++) {
+            path_ptr->sws[k]->count++;
+        }
+    }
+
+    score = 0.0;
+    for (i = 0; i < occ_switch_tab_size; i++) {
+
+        alpha_sum0 = 0.0;
+        alpha_sum1 = 0.0;
+        sw_ins_ptr = occ_switches[i];
+        while (sw_ins_ptr != NULL) {
+            alpha_sum0 += sw_ins_ptr->inside_h;
+            alpha_sum1 += sw_ins_ptr->count + sw_ins_ptr->inside_h;
+            sw_ins_ptr = sw_ins_ptr->next;
+        }
+        score += lngamma(alpha_sum0) - lngamma(alpha_sum1);
+
+        sw_ins_ptr = occ_switches[i];
+        while (sw_ins_ptr != NULL) {
+            score += lngamma(sw_ins_ptr->count + sw_ins_ptr->inside_h);
+            score -= lngamma(sw_ins_ptr->inside_h);
+            sw_ins_ptr = sw_ins_ptr->next;
+        }
+    }
+
+    return score;
+}
+
+/* qsort() comparator over struct ViterbiRankEntry: higher score first. */
+static int compare_viterbi_rank(const void *a, const void *b)
+{
+    double score_a = ((V_RANK_PTR)a)->score;
+    double score_b = ((V_RANK_PTR)b)->score;
+
+    if (score_a > score_b) return -1;
+    if (score_a < score_b) return 1;
+
+    return 0;
+}
+
+/* Reranking variant of get_n_most_likely_path(): the first l top-N entries
+ * of goal goal_id are scored with compute_rerank_score(), sorted by
+ * decreasing score, and the best n of them are returned through
+ * *p_n_viterbi_list_ptr as v_expl/5 terms whose last argument is the score.
+ */
+static void get_n_most_likely_path_rerank(int n, int l, int goal_id,
+                                          TERM *p_n_viterbi_list_ptr)
+{
+    TERM p_goal_path;
+    TERM p_subpath_goal, 
p_subpath_sw;
+    TERM p_tmp, p_tmp_g, p_tmp_g0, p_tmp_g1, p_tmp_sw, p_tmp_sw0, p_tmp_sw1;
+    TERM p_n_viterbi, p_n_viterbi_list, p_tmp_list;
+    TERM p_viterbi_prob;
+    int j,m,k;
+    EG_PATH_PTR path_ptr = NULL;
+    int c_len, sw_len;
+    V_ENT_PTR v_ent;
+    int l_used;
+    double n_viterbi_egraph_score;
+
+    /* Number of candidates actually available among the first l slots.
+     * A goal without paths has top_n == NULL and therefore none. */
+    l_used = 0;
+    if (expl_graph[goal_id]->top_n != NULL) {
+        for (j = 0; j < l; j++) {
+            if (expl_graph[goal_id]->top_n[j] != NULL) l_used++;
+        }
+    }
+
+    /* Nothing to rank or nothing requested: return [] rather than a list
+     * cell that is never closed, and skip the zero-size allocation. */
+    if (l_used == 0 || n <= 0) {
+        *p_n_viterbi_list_ptr = bpx_build_nil();
+        return;
+    }
+
+    p_n_viterbi_list = bpx_build_list();
+    p_tmp_list = p_n_viterbi_list;
+
+    viterbi_rank =
+        (V_RANK_PTR)MALLOC(sizeof(struct ViterbiRankEntry) * l_used);
+
+    /* Flatten and score every candidate; each one keeps its own array */
+    for (j = 0; j < l_used; j++) {
+        alloc_n_viterbi_egraphs();
+
+        n_viterbi_egraph_size =
+            visit_n_most_likely_path(expl_graph[goal_id]->top_n[j],0);
+
+        viterbi_rank[j].size = n_viterbi_egraph_size;
+        viterbi_rank[j].expl = n_viterbi_egraphs;
+        viterbi_rank[j].score = compute_rerank_score();
+    }
+
+    qsort(viterbi_rank, l_used, sizeof(struct ViterbiRankEntry),
+          compare_viterbi_rank);
+
+    for (j = 0; j < l_used && j < n; j++) {
+        /* make the j-th ranked candidate the current explanation */
+        n_viterbi_egraph_size = viterbi_rank[j].size;
+        n_viterbi_egraphs = viterbi_rank[j].expl;
+        n_viterbi_egraph_score = viterbi_rank[j].score;
+
+        /* Build the Viterbi path as a Prolog list: */
+        p_goal_path = bpx_build_list();
+        p_tmp = p_goal_path;
+        for (m = 0; m < n_viterbi_egraph_size; m++) {
+            bpx_unify(bpx_get_car(p_tmp),
+                      bpx_build_integer(n_viterbi_egraphs[m]->goal_id));
+
+            if (m == n_viterbi_egraph_size - 1) {
+                bpx_unify(bpx_get_cdr(p_tmp),bpx_build_nil());
+            }
+            else {
+                bpx_unify(bpx_get_cdr(p_tmp),bpx_build_list());
+                p_tmp = bpx_get_cdr(p_tmp);
+            }
+        }
+
+        p_subpath_goal = bpx_build_list();
+        p_subpath_sw = bpx_build_list();
+
+        p_tmp_g = p_subpath_goal;
+        p_tmp_sw = p_subpath_sw;
+
+        for (m = 0; m < n_viterbi_egraph_size; m++) {
+            v_ent = n_viterbi_egraphs[m];
+
+            /* placeholder entries contribute empty sub-lists */
+            if (v_ent->path_ptr == NULL) {
+                p_tmp_g0 = bpx_build_nil();
+                p_tmp_sw0 = bpx_build_nil();
+            }
+            else {
+                path_ptr = v_ent->path_ptr;
+                c_len = path_ptr->children_len;
+                sw_len = path_ptr->sws_len;
+
+                if (c_len == 0) 
{
+                    p_tmp_g0 = bpx_build_nil();
+                }
+                else {
+                    p_tmp_g0 = bpx_build_list();
+                    p_tmp_g1 = p_tmp_g0;
+                    for (k = 0; k < c_len; k++) {
+                        bpx_unify(bpx_get_car(p_tmp_g1),
+                                  bpx_build_integer(path_ptr->children[k]->id));
+                        if (k == c_len - 1) {
+                            bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_nil());
+                        }
+                        else {
+                            bpx_unify(bpx_get_cdr(p_tmp_g1),bpx_build_list());
+                            p_tmp_g1 = bpx_get_cdr(p_tmp_g1);
+                        }
+                    }
+                }
+
+                if (sw_len == 0) {
+                    p_tmp_sw0 = bpx_build_nil();
+                }
+                else {
+                    p_tmp_sw0 = bpx_build_list();
+                    p_tmp_sw1 = p_tmp_sw0;
+                    for (k = 0; k < sw_len; k++) {
+                        bpx_unify(bpx_get_car(p_tmp_sw1),bpx_build_integer(path_ptr->sws[k]->id));
+                        if (k == sw_len - 1) {
+                            bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_nil());
+                        }
+                        else {
+                            bpx_unify(bpx_get_cdr(p_tmp_sw1),bpx_build_list());
+                            p_tmp_sw1 = bpx_get_cdr(p_tmp_sw1);
+                        }
+                    }
+                }
+            }
+
+            bpx_unify(bpx_get_car(p_tmp_g),p_tmp_g0);
+            bpx_unify(bpx_get_car(p_tmp_sw),p_tmp_sw0);
+
+            if (m == n_viterbi_egraph_size - 1) {
+                bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_nil());
+                bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_nil());
+            }
+            else {
+                bpx_unify(bpx_get_cdr(p_tmp_g),bpx_build_list());
+                bpx_unify(bpx_get_cdr(p_tmp_sw),bpx_build_list());
+                p_tmp_g = bpx_get_cdr(p_tmp_g);
+                p_tmp_sw = bpx_get_cdr(p_tmp_sw);
+            }
+        }
+
+        /* here the fifth argument is the rerank score, not a path probability */
+        p_viterbi_prob = bpx_build_float(n_viterbi_egraph_score);
+
+        p_n_viterbi = bpx_build_structure("v_expl",5);
+        bpx_unify(bpx_get_arg(1,p_n_viterbi),bpx_build_integer(j));
+        bpx_unify(bpx_get_arg(2,p_n_viterbi),p_goal_path);
+        bpx_unify(bpx_get_arg(3,p_n_viterbi),p_subpath_goal);
+        bpx_unify(bpx_get_arg(4,p_n_viterbi),p_subpath_sw);
+        bpx_unify(bpx_get_arg(5,p_n_viterbi),p_viterbi_prob);
+
+        bpx_unify(bpx_get_car(p_tmp_list),p_n_viterbi);
+
+        if (j == (l_used - 1) || j == (n - 1)) {
+            bpx_unify(bpx_get_cdr(p_tmp_list),bpx_build_nil());
+        }
+        else {
+            bpx_unify(bpx_get_cdr(p_tmp_list),bpx_build_list());
+            p_tmp_list = bpx_get_cdr(p_tmp_list);
+        }
+    }
+
+    for (j = 0; j < l_used; j++) {
+        /* Release the placeholder entries allocated by
+         * visit_n_most_likely_path() (those with path_ptr == NULL), as
+         * get_n_most_likely_path() does; they are owned by this array only. */
+        for (m = 0; m < viterbi_rank[j].size; m++) {
+            if (viterbi_rank[j].expl[m]->path_ptr == NULL) {
+                free(viterbi_rank[j].expl[m]);
+            }
+        }
+        free(viterbi_rank[j].expl);
+
} + free(viterbi_rank); + viterbi_rank = NULL; + + *p_n_viterbi_list_ptr = p_n_viterbi_list; +} + +/*------------------------------------------------------------------------*/ + +/* [Note] node copying is not required here even in computation without + * inter-goal sharing, but we need to declare it explicitly. + */ +int pc_compute_viterbi_5(void) +{ + TERM p_goal_path,p_subpath_goal,p_subpath_sw; + int goal_id; + double viterbi_prob; + + goal_id = bpx_get_integer(bpx_get_call_arg(1,5)); + + initialize_egraph_index(); + alloc_sorted_egraph(1); + /* INIT_MIN_MAX_NODE_NOS; */ + RET_ON_ERR(sort_one_egraph(goal_id,0,1)); + if (verb_graph) print_egraph(0,PRINT_NEUTRAL); + + compute_max(); + + if (debug_level) print_egraph(1,PRINT_VITERBI); + + get_most_likely_path(goal_id,&p_goal_path,&p_subpath_goal, + &p_subpath_sw,&viterbi_prob); + + return + bpx_unify(bpx_get_call_arg(2,5), p_goal_path) && + bpx_unify(bpx_get_call_arg(3,5), p_subpath_goal) && + bpx_unify(bpx_get_call_arg(4,5), p_subpath_sw) && + bpx_unify(bpx_get_call_arg(5,5), bpx_build_float(viterbi_prob)); +} + +int pc_compute_n_viterbi_3(void) +{ + TERM p_n_viterbi_list; + int n,goal_id; + + n = bpx_get_integer(bpx_get_call_arg(1,3)); + goal_id = bpx_get_integer(bpx_get_call_arg(2,3)); + + initialize_egraph_index(); + alloc_sorted_egraph(1); + /* INIT_MIN_MAX_NODE_NOS; */ + RET_ON_ERR(sort_one_egraph(goal_id,0,1)); + if (verb_graph) print_egraph(0,PRINT_NEUTRAL); + + compute_n_max(n); + + if (debug_level) print_egraph(1,PRINT_VITERBI); + + get_n_most_likely_path(n,goal_id,&p_n_viterbi_list); + + return bpx_unify(bpx_get_call_arg(3,3),p_n_viterbi_list); +} + +/* + * Note: parameters are always refreshed in advance by $pc_export_sw_info/1, + * so it causes no problem to overwrite them temporarily + */ +int pc_compute_n_viterbi_rerank_4(void) +{ + TERM p_n_viterbi_list; + int n,l,goal_id; + + n = bpx_get_integer(bpx_get_call_arg(1,4)); + l = bpx_get_integer(bpx_get_call_arg(2,4)); + goal_id = 
bpx_get_integer(bpx_get_call_arg(3,4)); + + initialize_egraph_index(); + alloc_sorted_egraph(1); + /* INIT_MIN_MAX_NODE_NOS; */ + RET_ON_ERR(sort_one_egraph(goal_id,0,1)); + if (verb_graph) print_egraph(0,PRINT_NEUTRAL); + + alloc_occ_switches(); + transfer_hyperparams_prolog(); + get_param_means(); + + compute_n_max(l); + + get_n_most_likely_path_rerank(n,l,goal_id,&p_n_viterbi_list); + + release_occ_switches(); + + return bpx_unify(bpx_get_call_arg(4,4),p_n_viterbi_list); +} + +/*------------------------------------------------------------------------*/ diff --git a/packages/prism/src/c/up/viterbi.h b/packages/prism/src/c/up/viterbi.h new file mode 100644 index 000000000..845ffca4e --- /dev/null +++ b/packages/prism/src/c/up/viterbi.h @@ -0,0 +1,13 @@ +#ifndef VITERBI_H +#define VITERBI_H + +int pc_compute_viterbi_5(void); +int pc_compute_n_viterbi_3(void); +int pc_compute_n_viterbi_rerank_4(void); + +void compute_max(void); +void compute_n_max(int); + +#endif /* VITERBI_H */ + + diff --git a/packages/prism/src/prolog/Makefile.in b/packages/prism/src/prolog/Makefile.in new file mode 100644 index 000000000..6e7a88f90 --- /dev/null +++ b/packages/prism/src/prolog/Makefile.in @@ -0,0 +1,108 @@ +# +# default base directory for YAP installation +# +ROOTDIR = @prefix@ +# +# where the binary should be +# +BINDIR = $(ROOTDIR)/bin +# +# where YAP should look for binary libraries +# +LIBDIR=@libdir@/Yap +# +# where YAP should look for architecture-independent Prolog libraries +# +SHAREDIR=$(ROOTDIR)/share +# +# +# You shouldn't need to change what follows. 
+# +INSTALL=@INSTALL@ +INSTALL_DATA=@INSTALL_DATA@ +INSTALL_PROGRAM=@INSTALL_PROGRAM@ +srcdir=@srcdir@ +YAP_EXTRAS=@YAP_EXTRAS@ + +# -*- Makefile -*- + +##---------------------------------------------------------------------- + +TARGETS = prism.pl + +PRISM_VERSION = 2.0 + +PL_CORE = $(srcdir)/core/message.pl \ + $(srcdir)/core/error.pl \ + $(srcdir)/core/random.pl \ + $(srcdir)/core/format.pl + +PL_TRANS = $(srcdir)/trans/trans.pl \ + $(srcdir)/trans/dump.pl \ + $(srcdir)/trans/verify.pl \ + $(srcdir)/trans/bpif.pl + +PL_UP = $(srcdir)/up/dynamic.pl \ + $(srcdir)/up/main.pl \ + $(srcdir)/up/switch.pl \ + $(srcdir)/up/learn.pl \ + $(srcdir)/up/prob.pl \ + $(srcdir)/up/viterbi.pl \ + $(srcdir)/up/hindsight.pl \ + $(srcdir)/up/expl.pl \ + $(srcdir)/up/sample.pl \ + $(srcdir)/up/dist.pl \ + $(srcdir)/up/list.pl \ + $(srcdir)/up/hash.pl \ + $(srcdir)/up/flags.pl \ + $(srcdir)/up/util.pl \ + $(srcdir)/up/bigarray.pl + +PL_BAT = $(srcdir)/up/batch.pl + +PL_MP = $(srcdir)/mp/mp_main.pl \ + $(srcdir)/mp/mp_learn.pl + +PL_BP = $(srcdir)/bp/eval.pl + +PL_UP_ALL = $(PL_CORE) $(PL_UP) $(PL_TRANS) $(PL_BP) $(srcdir)/prism.yap +PL_MP_ALL = $(PL_MP) +PL_BAT_ALL = $(PL_BAT) + +##---------------------------------------------------------------------- + +all: $(TARGETS) + +install: $(TARGETS) + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism/bp + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism/core + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism/mp + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism/trans + mkdir -p $(DESTDIR)$(SHAREDIR)/Yap/prism/up + $(INSTALL_DATA) $(srcdir)/prism.yap $(DESTDIR)$(SHAREDIR)/Yap + for p in $(PL_BAT); do $(INSTALL_DATA) $$p $(DESTDIR)$(SHAREDIR)/Yap/prism/up; done + for p in $(PL_BP); do $(INSTALL_DATA) $$p $(DESTDIR)$(SHAREDIR)/Yap/prism/bp; done + for p in $(PL_CORE); do $(INSTALL_DATA) $$p $(DESTDIR)$(SHAREDIR)/Yap/prism/core; done + for p in $(PL_MP); do $(INSTALL_DATA) $$p 
$(DESTDIR)$(SHAREDIR)/Yap/prism/mp; done + for p in $(PL_TRANS); do $(INSTALL_DATA) $$p $(DESTDIR)$(SHAREDIR)/Yap/prism/trans; done + for p in $(PL_UP); do $(INSTALL_DATA) $$p $(DESTDIR)$(SHAREDIR)/Yap/prism/up; done + +clean: + rm -f prism.pl mpprism.pl batch.pl + +prism.pl: $(PL_UP_ALL) + cat $^ > $@ + +mpprism.pl: $(PL_MP_ALL) + cat $^ > $@ + +batch.pl: $(PL_BAT_ALL) + cat $^ > $@ + +%.out: %.pl $(COMPILER) + sh $(COMPILER) $< + +.PHONY: all install clean + diff --git a/packages/prism/src/prolog/README b/packages/prism/src/prolog/README new file mode 100644 index 000000000..625e18db8 --- /dev/null +++ b/packages/prism/src/prolog/README @@ -0,0 +1,40 @@ +======================= README (src/prolog) ====================== + +This directory contains the Prolog source files of the PRISM part, +along with a minimal set of source files from B-Prolog, required +to build the PRISM system. It is assumed that the compilation is +done on Linux, Mac OS X or Cygwin and that GNU make is used. + + Makefile ... Makefile + Compile.sh ... auxiliary shell script called by Makefile + core/ ... base components of PRISM's Prolog part + trans/ ... translator from PRISM to Prolog + up/ ... probabilistic inferences + mp/ ... parallel EM learning + bp/ ... source file(s) from B-Prolog + +`up' and `mp' stand for uni-processor and multi-processor, +respectively. The source code of the First Order Compiler is +not available, and currently we have no plan for releasing it +to the public. + + +[How to compile the Prolog part] + + Since the compiled code of the C part is used for the compilation + of the Prolog part, please compile and install the C part at + $(TOP)/src/c/ (for instructions, please see the README in that directory) + in advance.
+ + Then, just type at the shell: + + make + + The compiled byte code will be installed (copied) into $(TOP)/bin + by typing: + + make install + + You can clean up the compiled byte codes by: + + make clean diff --git a/packages/prism/src/prolog/bp/eval.pl b/packages/prism/src/prolog/bp/eval.pl new file mode 100644 index 000000000..58121d694 --- /dev/null +++ b/packages/prism/src/prolog/bp/eval.pl @@ -0,0 +1,388 @@ +/* tracer and debugger of B-Prolog, + Neng-Fa Zhou +*/ +/*********************** eval_call(Call) no trace ******************/ +eval_call(Goal,_CP), var(Goal) => + handle_exception(illegal_predicate, Goal). +/* +eval_call((A : B),CP) => + eval_call(A,CP), + '_$cutto'(CP), + eval_call(B,CP). +eval_call((A ? B),CP) => + eval_call(A,CP), + eval_call(B,CP). +*/ +eval_call(true,_CP) => true. +eval_call((A,B),CP) => + eval_call(A,CP), + eval_call(B,CP). +eval_call((A -> B ; C),CP) => + eval_if_then_else(C,CP,A,B). +eval_call((A;B),CP) => + eval_or(A,B,CP). +eval_call((A -> B),CP) => + eval_if_then(A,B,CP). +eval_call(not(A),_CP) => + '_$savecp'(CP1), + eval_not(A,CP1). +eval_call(\+(A),_CP) => + '_$savecp'(CP1), + eval_not(A,CP1). +eval_call('!',CP) => + '_$cutto'(CP). +eval_call(call(X),_CP) => + '_$savecp'(CP1), + eval_call(X,CP1). +eval_call(Xs,_CP), [_|_]<=Xs => + consult_list(Xs). +eval_call(Goal,_CP), b_IS_CONSULTED_c(Goal) => + '_$savecp'(CP1), + clause(Goal,Body), + eval_call(Body,CP1). +eval_call(Goal,_CP) => + call(Goal). + +%% Prism-specific part +eval_call('_$initialize_var'(_Vars),_CP) => true. +eval_call('_$if_then_else'(C,A,B),CP) => eval_call((C->A;B),CP). + +eval_if_then_else(_C,CP,A,B) ?=> + '_$savecp'(CP1), + eval_call(A,CP1),!, + eval_call(B,CP). +eval_if_then_else(C,CP,_A,_B) => + eval_call(C,CP). + +eval_or(A,_B,CP) ?=> + eval_call(A,CP). +eval_or(_A,B,CP) => + eval_call(B,CP). + +eval_if_then(A,B,CP) => + '_$savecp'(CP1), + eval_call(A,CP1),!, + eval_call(B,CP). + +eval_not(A,CP) ?=> + eval_call(A,CP),!, + fail. 
+eval_not(_A,_CP) => true. + +/*********************** eval_call(Call) ******************/ +$trace_call(Call), b_IS_DEBUG_MODE => + '_$savecp'(CP), + eval_debug_call(Call,0,CP). +$trace_call(Call) => + '_$savecp'(CP), + eval_call(Call,CP). + +eval_debug_call(Goal,_Depth,_CP), var(Goal) => + handle_exception(illegal_predicate, Goal). +/* +eval_debug_call((A : B),Depth,CP) => + eval_debug_call(A,Depth,CP), + '_$cutto'(CP), + eval_debug_call(B,Depth,CP). +eval_debug_call((A ? B),Depth,CP) => + eval_debug_call(A,Depth,CP), + eval_debug_call(B,Depth,CP). +*/ +eval_debug_call((A,B),Depth,CP) => + eval_debug_call(A,Depth,CP), + eval_debug_call(B,Depth,CP). +eval_debug_call((A -> B ; C),Depth,CP) => + eval_debug_if_then_else(C,Depth,CP,A,B). +eval_debug_call((A;B),Depth,CP) => + eval_debug_or(A,B,Depth,CP). +eval_debug_call((A -> B),Depth,CP) => + eval_debug_if_then(A,B,Depth,CP). +eval_debug_call(not(A),Depth,_CP) => + '_$savecp'(CP1), + eval_debug_not(A,Depth,CP1). +eval_debug_call(\+(A),Depth,_CP) => + '_$savecp'(CP1), + eval_debug_not(A,Depth,CP1). +eval_debug_call('!',_Depth,CP) => + '_$cutto'(CP). +eval_debug_call('_$cutto'(X),_Depth,_CP) => + '_$cutto'(X). +eval_debug_call($trace_call(X),_Depth,_CP) => + $trace_call(X). +eval_debug_call(call(X),Depth,_CP) => + '_$savecp'(CP1), + eval_debug_call(X,Depth,CP1). +eval_debug_call($query(X),Depth,CP) => + eval_debug_call(X,Depth,CP). +eval_debug_call(true,_Depth,_CP) => true. +eval_debug_call($internal_match(X,Y),_Depth,_CP) => + nonvar(Y),X=Y. +eval_debug_call(trace,_Depth,_CP) => trace. +eval_debug_call(op(Prec,Fix,Op),_Depth,_CP) => + op(Prec,Fix,Op). +eval_debug_call(dynamic(Calls),_Depth,_CP) => + dynamic(Calls). +eval_debug_call(nospy,_Depth,_CP) => + nospy. +eval_debug_call(nospy(X),_Depth,_CP) => + nospy(X). +eval_debug_call(notrace,_Depth,_CP) => + notrace. +eval_debug_call(spy(S),_Depth,_CP) => + spy(S). +eval_debug_call(nospy(S),_Depth,_CP) => + nospy(S). 
+eval_debug_call(Xs,_Depth,_CP), [_|_]<=Xs => + consult_list(Xs). +eval_debug_call(Goal,Depth,_CP) => + c_SAVE_AR(AR), + c_next_global_call_number(CallNo), + $eval_and_monitor_call(Goal,Depth,CallNo,AR). + +%% Prism-specific part +eval_debug_call(Goal,_Depth,_CP), var(Goal) => + handle_exception(illegal_predicate, Goal). +eval_debug_call('_$initialize_var'(_Vars),_Depth,_CP) => true. +eval_debug_call('_$if_then_else'(C,A,B),Depth,CP) => + eval_debug_call((C->A;B),Depth,CP). +eval_debug_call(msw(Sw,V),Depth,CP) => + $pp_require_ground(Sw,$msg(0101),msw/2), + c_SAVE_AR(AR), + c_next_global_call_number(CallNo), + $prism_sample_msw(Sw,V,Depth,CP,CallNo,AR). + +eval_debug_if_then_else(_C,Depth,CP,A,B) ?=> + '_$savecp'(NewCP), + eval_debug_call(A,Depth,NewCP),!, + eval_debug_call(B,Depth,CP). +eval_debug_if_then_else(C,Depth,CP,_A,_B) => + eval_debug_call(C,Depth,CP). + +eval_debug_or(A,_B,Depth,CP) ?=> + eval_debug_call(A,Depth,CP). +eval_debug_or(_A,B,Depth,CP) => + eval_debug_call(B,Depth,CP). + +eval_debug_if_then(A,B,Depth,CP) => + '_$savecp'(NewCP), + eval_debug_call(A,Depth,NewCP),!, + eval_debug_call(B,Depth,CP). + +eval_debug_not(A,Depth,CP) ?=> + eval_debug_call(A,Depth,CP),!, + fail. +eval_debug_not(_A,_Depth,_CP) => true. + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +$eval_and_monitor_call(Call,Depth,CallNo,AR) ?=> + c_get_dg_flag(Flag), + $print_call(Flag,' Call: ',Call,Depth,CallNo,AR), + Depth1 is 1+Depth, + $eval_single_call(Call,Depth1), + $switch_skip_off(AR), + $eval_call_exit(Call,Depth,CallNo,AR). +$eval_and_monitor_call(Call,Depth,CallNo,AR) => + c_get_dg_flag(Flag), + $print_call(Flag,' Fail: ',Call,Depth,CallNo,AR), + fail. + +$eval_call_exit(Call,Depth,CallNo,AR) ?=> + c_get_dg_flag(Flag), + $print_call(Flag,' Exit: ',Call,Depth,CallNo,AR). +$eval_call_exit(Call,Depth,CallNo,AR) => + c_get_dg_flag(Flag), + $print_call(Flag,' Redo: ',Call,Depth,CallNo,AR), + fail. 
+ +$eval_single_call(Call,Depth), b_IS_CONSULTED_c(Call) => + '_$savecp'(CP), + clause(Call,Body), + eval_debug_call(Body,Depth,CP). +$eval_single_call(Call,_Depth) => + call(Call). + +/* + --------------------------------------------- + |repeat | skip | leap | creep | spy | debug | + --------------------------------------------- +#define DG_FLAG_DEBUG 0x1 +#define DG_FLAG_SPY 0x2 +#define DG_FLAG_C 0x4 +#define DG_FLAG_L 0x8 +#define DG_FLAG_S 0x10 +#define DG_FLAG_R 0x20 +*/ + +%% Prism-specific part +$print_call(_F,_T,$pu_values(_,_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_is_prob_pred(_,_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_is_tabled_pred(_,_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_parameters(_,_,_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_hyperparameters(_,_,_,_),_D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_expectations(_,_,_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_hyperexpectations(_,_,_),_D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_fixed_parameters(_), _D,_CNo,_AR) => true. +$print_call(_F,_T,$pd_fixed_hyperparameters(_),_D,_CNo,_AR) => true. +$print_call(_Flag,_Type,write_call(_), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,write_call(_,_),_Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(?? _), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(??* _), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(??> _), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(??< _), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(??+ _), _Depth,_CallNo,_AR) => true. +$print_call(_Flag,_Type,(??- _), _Depth,_CallNo,_AR) => true. +$print_call(Flag,Type,$prism_expl_msw(I,V,_SwId),Depth,CallNo,AR) => + $print_call(Flag,Type,msw(I,V),Depth,CallNo,AR). + +$print_call(Flag,Type,Call,Depth,CallNo,_AR), + Flag /\ 2'100000 =:= 2'100000 => %repeat + '$readl_userio'(I,O), + tab(2*Depth),write(Type),write('('),write(CallNo),write(') '), + print(Call),nl, + '$readl_resetio'(I,O). 
+$print_call(Flag,Type,Call,Depth,CallNo,AR), + Flag /\ 2'1000 =:= 2'1000 ?=> %leap (DG_FLAG_L, 0x8): prompt only when Call is a spy point + c_is_spy_point(Call),!, + $real_print_call(Type,Call,Depth,CallNo), + $next_monitor_instruction(Type,Call,Depth,CallNo,AR). +$print_call(Flag,Type,Call,Depth,CallNo,AR), + Flag /\ 2'100 =:= 2'100 => %creep (DG_FLAG_C, 0x4): print the port and prompt + $real_print_call(Type,Call,Depth,CallNo), + $next_monitor_instruction(Type,Call,Depth,CallNo,AR). +$print_call(Flag,Type,Call,Depth,CallNo,AR), + Flag /\ 2'10000 =:= 2'10000 ?=> %skip (DG_FLAG_S, 0x10): prompt only when c_is_skip_ar(AR) holds + c_is_skip_ar(AR),!, + $real_print_call(Type,Call,Depth,CallNo), + $next_monitor_instruction(Type,Call,Depth,CallNo,AR). +$print_call(_Flag,_Type,_Call,_Depth,_AR,_CallNo) => true. % no earlier clause applied: print nothing + +$real_print_call(Type,Call,Depth,CallNo):- % writes the port line, indented by 2*Depth, followed by ' ?' + '$readl_userio'(I,O), + tab(2*Depth),write(Type),write('('),write(CallNo),write(') '), + print(Call),writename(' ?'), + '$readl_resetio'(I,O). + +$next_monitor_instruction(Type,Call,Depth,CallNo,AR):- % reads one command and dispatches on it + $get_monitor_instruction(Inst), + $process_monitor_instruction(Type,Call,Depth,CallNo,AR,Inst). + +/* +#define DG_FLAG_DEBUG 0x1 +#define DG_FLAG_SPY 0x2 +#define DG_FLAG_C 0x4 +#define DG_FLAG_L 0x8 +#define DG_FLAG_S 0x10 +#define DG_FLAG_R 0x20 +*/ +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,0'a) => + abort. % abort +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,0'r) => + c_set_dg_flag(2'100000). % repeat (DG_FLAG_R, 0x20) +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,0'c) => + c_set_dg_flag(2'100). % creep (DG_FLAG_C, 0x4) +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,10) => + c_set_dg_flag(2'100). % return (character code 10): same flag as creep +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,0'n) => + c_get_dg_flag(Flag), + NewFlag is Flag/\2'11, + c_init_dg_flag(NewFlag). % no trace: keep only the two low bits (DG_FLAG_DEBUG, DG_FLAG_SPY) +$process_monitor_instruction(_Type,_Call,_Depth,_CallNo,_AR,0'l) => + c_set_dg_flag(2'1000).
% leap +$process_monitor_instruction(Type,Call,Depth,CallNo,AR,0's) => + ((Type==' Fail: ';Type==' Exit: ')-> + write(user_output,'Option not applicable'),nl(user_output), + $real_print_call(Type,Call,Depth,CallNo), + $next_monitor_instruction(Type,Call,Depth,CallNo,AR); + c_set_dg_flag(2'10000), + c_set_skip_ar(AR)). % skip +$process_monitor_instruction(Type,Call,Depth,CallNo,AR,_) => % other ? + $print_help(Type), + $real_print_call(Type,Call,Depth,CallNo), + $next_monitor_instruction(Type,Call,Depth,CallNo,AR). + +$print_help(_Type):- + write(user,' a abort'),nl(user), + write(user,' ? help'),nl(user), + write(user,' h help'),nl(user), + write(user,' creep'),nl(user), + write(user,' c creep'),nl(user), + write(user,' h help'),nl(user), + write(user,' l leap'),nl(user), + write(user,' n nodebug'),nl(user), + write(user,' r repeat creep'),nl(user), + write(user,' s skip'),nl(user),nl(user). + +$get_monitor_instruction(Command):- + '$readl_userio'(I,O), + get0(Command), + $get_until_return(Command), + '$readl_resetio'(I,O). + +$get_until_return(10) => true. +$get_until_return(_Command) => + get0(X), + $get_until_return(X). + +$switch_skip_off(AR):- + c_is_skip_ar(AR),!, + c_set_skip_ar(0), + c_set_dg_flag(2'100). % creep +$switch_skip_off(_) => true. + + +/**************trace/1 spy/1******************/ +/* vsc: not supported in YAP yet +trace => + c_init_dg_flag(1). + +spy(S), var(S) => + c_get_spy_points(S). +spy([X|Xs]) => + spy(X), + spy(Xs). +spy([]) => true. +spy(Pred), F/N<=Pred, atom(F),integer(N) => + (c_CURRENT_PREDICATE(F,N)-> + '$readl_userio'(I,O), + write('Spy point '), write(Pred), write(' has been set.'),nl, + '$readl_resetio'(I,O), + c_add_spy_point(F,N); + handle_exception(predicate_not_exist, Pred)). +spy(F), atom(F) => + $search_preds(F,25,[],X), + (X\==[]->spy(X); handle_exception(predicate_not_exist, F)). +spy(F):- + handle_exception(illegal_argument, spy(F)). + +$search_preds(_X,N,P0,P), N<0 => + P=P0. 
+$search_preds(X,N,P0,P):- + c_CURRENT_PREDICATE(X,N),!, + N1 is N-1, + $search_preds(X,N1,[X/N|P0],P). +$search_preds(X,N,P0,P) => + N1 is N-1, + $search_preds(X,N1,P0,P). + +notrace => + c_init_dg_flag(0), + nospy. + +nospy([X|Xs]) => + nospy(X), + nospy(Xs). +nospy([]) => true. +nospy(F/N), atom(F), integer(N) => + c_remove_spy_point(F,N). +nospy(F), atom(F) => + $search_preds(F,25,[],X), + nospy(X). +nospy(F) => + handle_exception(illegal_predicate, nospy(F)). + +nospy:- + c_remove_spy_points. + +trace(Call) => + $trace_call(Call). +*/ \ No newline at end of file diff --git a/packages/prism/src/prolog/core/error.pl b/packages/prism/src/prolog/core/error.pl new file mode 100644 index 000000000..079cc8fd6 --- /dev/null +++ b/packages/prism/src/prolog/core/error.pl @@ -0,0 +1,909 @@ +%%---------------------------------------- + +$pp_emit_message(MsgID) :- + $pp_emit_message(MsgID,[]). + +$pp_emit_message(MsgID,Args) :- + $pp_assert($pp_message(MsgID,Type,Format)), + $pp_message_head(Type,Head), + format("*** ~w: ",[Head]), + $pp_format_message(Format,Args), + format("~n",[]). + +$pp_message_head(fatal,'PRISM FATAL ERROR'). +$pp_message_head(inter,'PRISM INTERNAL ERROR'). +$pp_message_head(error,'PRISM ERROR'). +$pp_message_head(fail ,'PRISM WARNING'). +$pp_message_head(warn ,'PRISM WARNING'). +$pp_message_head(obosl,'PRISM WARNING'). +$pp_message_head(info ,'PRISM INFO'). + +%%---------------------------------------- + +$pp_assert(Cond) :- + ( call(Cond) -> + true + ; $pp_emit_message($msg(9900),[Cond]), halt + ). + +%%---------------------------------------- + +% instantiation errors: emit message MsgID (formatted with Args, default []) and throw error(instantiation_error,Source); the predicate name keeps its historical spelling for existing callers, the thrown term uses the ISO spelling +$pp_raise_instanciation_error(MsgID,Source) :- + $pp_raise_instanciation_error(MsgID,[],Source). +$pp_raise_instanciation_error(MsgID,Args,Source) :- + $pp_emit_message(MsgID,Args), + throw(error(instantiation_error,Source)). + +% type errors +$pp_raise_type_error(MsgID,[Type,Culprit],Source) :- + $pp_raise_type_error(MsgID,[],[Type,Culprit],Source).
+$pp_raise_type_error(MsgID,Args,[Type,Culprit],Source) :- + $pp_emit_message(MsgID,Args), + throw(error(type_error(Type,Culprit),Source)). + +% domain errors +$pp_raise_domain_error(MsgID,[Domain,Culprit],Source) :- + $pp_raise_domain_error(MsgID,[],[Domain,Culprit],Source). +$pp_raise_domain_error(MsgID,Args,[Domain,Culprit],Source) :- + $pp_emit_message(MsgID,Args), + throw(error(domain_error(Domain,Culprit),Source)). + +% existence errors +$pp_raise_existence_error(MsgID,[ObjType,Culprit],Source) :- + $pp_raise_existence_error(MsgID,[],[ObjType,Culprit],Source). +$pp_raise_existence_error(MsgID,Args,[ObjType,Culprit],Source) :- + $pp_emit_message(MsgID,Args), + throw(error(existence_error(ObjType,Culprit),Source)). + +% permission errors +$pp_raise_permission_error(MsgID,[Operation,PermissionType,Culprit],Source) :- + $pp_raise_permission_error(MsgID,[], + [Operation,PermissionType,Culprit], + Source). +$pp_raise_permission_error(MsgID,Args, + [Operation,PermissionType,Culprit], + Source) :- + $pp_emit_message(MsgID,Args), + throw(error(permission_error(Operation,PermissionType,Culprit),Source)). + +% evaluation errors +$pp_raise_evaluation_error(MsgID,Error,Source) :- + $pp_raise_evaluation_error(MsgID,[],Error,Source). +$pp_raise_evaluation_error(MsgID,Args,Error,Source) :- + $pp_emit_message(MsgID,Args), + throw(error(evaluation_error(Error),Source)). + +% runtime errors +$pp_raise_runtime_error(MsgID,Error,Source) :- + $pp_raise_runtime_error(MsgID,[],Error,Source). +$pp_raise_runtime_error(MsgID,Args,Error,Source) :- + $pp_emit_message(MsgID,Args), + throw(error(prism_runtime_error(Error),Source)). + +% translation errors +$pp_raise_trans_error(MsgID,Error,Source) :- + $pp_raise_trans_error(MsgID,[],Error,Source). +$pp_raise_trans_error(MsgID,Args,Error,Source) :- + $pp_emit_message(MsgID,Args), + throw(error(prism_translation_error(Error),Source)). 
+ +% internal errors +$pp_raise_internal_error(MsgID,Error,Source) :- + $pp_raise_internal_error(MsgID,[],Error,Source). +$pp_raise_internal_error(MsgID,Args,Error,Source) :- + $pp_emit_message(MsgID,Args), + throw(error(prism_internal_error(Error),Source)). + +% warnings +$pp_raise_warning(MsgID) :- $pp_raise_warning(MsgID,[]). +$pp_raise_warning(MsgID,Args) :- + ( get_prism_flag(warn,on) -> $pp_emit_message(MsgID,Args) + ; true + ). + +%%---------------------------------------- +%% typical internal errors + +$pp_raise_unmatched_branches(Source) :- + $pp_raise_internal_error($msg(9803),unmatched_branches,Source). +$pp_raise_unmatched_branches(Source,Position) :- + $pp_raise_internal_error($msg(9803),unmatched_branches(Position),Source). + +$pp_raise_unexpected_failure(Source) :- + $pp_raise_internal_error($msg(9804),unexpected_failure,Source). + +$pp_raise_unexpected_failure(Source,Position) :- + $pp_raise_internal_error($msg(9804),unexpected_failure(Position),Source). + +%%---------------------------------------- + +$pp_raise_on_require(Xs,MsgID,Source,Pred) :- + $pp_emit_message(MsgID,Xs), + append(Xs,[Error],Args), + G =.. [Pred|Args], + ( call(G) -> + true + ; $pp_emit_message($msg(9800)), + Error = prism_internal_error(error_term_not_found) + ), + throw(error(Error,Source)). + +%%---------------------------------------- + +$pp_require_atom(X,MsgID,Source) :- + ( atom(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_atom) + ). + +$pp_error_atom(X,instantiation_error) :- + \+ ground(X), !. +$pp_error_atom(X,type_error(atom,X)) :- + \+ atom(X), !. + +%%---------------------------------------- + +$pp_require_nonvar(X,MsgID,Source) :- + ( nonvar(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_nonvar) + ). + +$pp_error_nonvar(X,instantiation_error) :- + var(X), !. 
+ +%%---------------------------------------- + +$pp_require_nonvars(Xs,MsgID,Source) :- + ( $pp_test_nonvars(Xs) -> true + ; $pp_raise_on_require([Xs],MsgID,Source,$pp_error_nonvars) + ). + +$pp_test_nonvars(Xs) :- + Xs = [_|_], + $pp_test_nonvars1(Xs). + +$pp_test_nonvars1([]). +$pp_test_nonvars1([X|Xs]) :- + nonvar(X),!, + $pp_test_nonvars1(Xs). + +$pp_error_nonvars(Xs,Error) :- + $pp_error_ground(Xs,Error), !. +$pp_error_nonvars(Xs,Error) :- + $pp_error_list(Xs,Error), !. +$pp_error_nonvars(Xs,domain_error(non_variables,Xs)) :- + member(X,Xs), + var(X), !. + +%%---------------------------------------- + +$pp_require_ground(X,MsgID,Source) :- + ( ground(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_ground) + ). + +$pp_error_ground(X,instantiation_error) :- + \+ ground(X), !. + +%%---------------------------------------- + +$pp_require_callable(X,MsgID,Source) :- + ( callable(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_callable) + ). + +$pp_error_callable(X,type_error(callable,X)) :- + \+ callable(X), !. + +%%---------------------------------------- + +$pp_require_integer(X,MsgID,Source) :- + ( integer(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_integer) + ). + +$pp_error_integer(X,instantiation_error) :- + var(X), !. +$pp_error_integer(X,type_error(integer,X)) :- + \+ integer(X), !. + +%%---------------------------------------- + +$pp_require_positive_integer(X,MsgID,Source) :- + ( integer(X), X > 0 -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_positive_integer) + ). + +$pp_error_positive_integer(X,Error) :- + $pp_error_integer(X,Error), !. +$pp_error_positive_integer(X,domain_error(greater_than_zero,X)) :- + X =< 0, !. + +%%---------------------------------------- + +$pp_require_non_negative_integer(X,MsgID,Source) :- + ( integer(X), X >= 0 -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_non_negative_integer) + ). 
+ +$pp_error_non_negative_integer(X,Error) :- + $pp_error_integer(X,Error), !. +$pp_error_non_negative_integer(X,domain_error(not_less_than_zero,X)) :- + X < 0, !. + +%%---------------------------------------- + +$pp_require_number(X,MsgID,Source) :- + ( number(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_number) + ). + +$pp_error_number(X,instantiation_error) :- + var(X), !. +$pp_error_number(X,type_error(number,X)) :- + \+ number(X), !. + +%%---------------------------------------- + +$pp_require_numbers(Xs,MsgID,Source) :- + ( $pp_test_numbers(Xs) -> true + ; $pp_raise_on_require([Xs],MsgID,Source, + $pp_error_numbers) + ). + +$pp_test_numbers(Xs) :- + Xs = [_|_], + $pp_test_numbers1(Xs). + +$pp_test_numbers1([]). +$pp_test_numbers1([X|Xs]) :- + number(X),!, + $pp_test_numbers1(Xs). + +$pp_error_numbers(Xs,Error) :- + $pp_error_ground(Xs,Error), !. +$pp_error_numbers(Xs,Error) :- + $pp_error_list(Xs,Error), !. +$pp_error_numbers(Xs,domain_error(numbers,Xs)) :- + member(X,Xs), + \+ number(X), !. + +%%---------------------------------------- + +$pp_require_positive_number(X,MsgID,Source) :- + ( number(X), X > 0 -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_positive_number) + ). + +$pp_error_positive_number(X,Error) :- + $pp_error_number(X,Error), !. +$pp_error_positive_number(X,domain_error(greater_than_zero,X)) :- + X =< 0, !. + +%%---------------------------------------- + +$pp_require_positive_numbers(Xs,MsgID,Source) :- + ( $pp_test_positive_numbers(Xs) -> true + ; $pp_raise_on_require([Xs],MsgID,Source, + $pp_error_positive_numbers) + ). + +$pp_test_positive_numbers(Xs) :- + Xs = [_|_], + $pp_test_positive_numbers1(Xs). + +$pp_test_positive_numbers1([]). +$pp_test_positive_numbers1([X|Xs]) :- + number(X), + X > 0,!, + $pp_test_positive_numbers1(Xs). + +$pp_error_positive_numbers(Xs,Error) :- + $pp_error_ground(Xs,Error), !. +$pp_error_positive_numbers(Xs,Error) :- + $pp_error_list(Xs,Error), !. 
+$pp_error_positive_numbers(Xs,domain_error(positive_numbers,Xs)) :- + member(X,Xs), + (\+ number(X) ; X =< 0), !. + +%%---------------------------------------- + +$pp_require_number_not_less_than(X,Min,MsgID,Source) :- + $pp_assert(number(Min)), + ( number(X), X >= Min -> + true + ; $pp_raise_on_require([X,Min],MsgID,Source,$pp_error_number_not_less_than) + ). + +$pp_error_number_not_less_than(X,_,Error) :- + $pp_error_number(X,Error), !. +$pp_error_number_not_less_than(X,Min,domain_error(not_less_than(Min),X)) :- + X < Min, !. + +%%---------------------------------------- + +$pp_require_numbers_not_less_than(Xs,Min,MsgID,Source) :- + $pp_assert(number(Min)), + ( $pp_test_numbers_not_less_than(Min,Xs) -> true + ; $pp_raise_on_require([Xs,Min],MsgID,Source, + $pp_error_numbers_not_less_than) + ). + +$pp_test_numbers_not_less_than(Min,Xs) :- + Xs = [_|_], + $pp_test_numbers_not_less_than1(Min,Xs). + +$pp_test_numbers_not_less_than1(_,[]). +$pp_test_numbers_not_less_than1(Min,[X|Xs]) :- + number(X), + X >= Min,!, + $pp_test_numbers_not_less_than1(Min,Xs). + +$pp_error_numbers_not_less_than(Xs,_,Error) :- + $pp_error_ground(Xs,Error), !. +$pp_error_numbers_not_less_than(Xs,_,Error) :- + $pp_error_list(Xs,Error), !. +$pp_error_numbers_not_less_than(Xs,Min, + domain_error(numbers_not_less_than(Min),Xs)) :- + member(X,Xs), + (\+ number(X) ; X < Min ), !. + +%%---------------------------------------- + +$pp_require_non_negative_number(X,MsgID,Source) :- + ( number(X), X >= 0 -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_non_negative_number) + ). + +$pp_error_non_negative_number(X,Error) :- + $pp_error_number(X,Error), !. +$pp_error_non_negative_number(X,domain_error(not_less_than_zero,X)) :- + X < 0, !. + +%%---------------------------------------- + +$pp_require_non_negative_numbers(Xs,MsgID,Source) :- + ( $pp_test_non_negative_numbers(Xs) -> true + ; $pp_raise_on_require([Xs],MsgID,Source,$pp_error_non_negative_numbers) + ). 
+ +$pp_test_non_negative_numbers(Xs) :- + Xs = [_|_], + $pp_test_non_negative_numbers1(Xs). + +$pp_test_non_negative_numbers1([]). +$pp_test_non_negative_numbers1([X|Xs]) :- + number(X), + X >= 0.0,!, + $pp_test_non_negative_numbers1(Xs). + +$pp_error_non_negative_numbers(Xs,Error) :- + $pp_error_ground(Xs,Error), !. +$pp_error_non_negative_numbers(Xs,Error) :- + $pp_error_list(Xs,Error), !. +$pp_error_non_negative_numbers(Xs,domain_error(non_negative_numbers,Xs)) :- + member(X,Xs), + (\+ number(X) ; X < 0 ), !. + +%%---------------------------------------- + +$pp_require_list(X,MsgID,Source) :- % succeeds for a non-empty list cell; otherwise raises the error chosen by $pp_error_list + ( nonvar(X), X = [_|_] -> true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_list) + ). + +$pp_error_list(X,instantiation_error) :- % ISO spelling, as in the other $pp_error_* helpers of this file + var(X), !. +$pp_error_list(X,type_error(list,X)) :- + X \= [_|_], !. + +%%---------------------------------------- + +$pp_require_list_or_nil(X,MsgID,Source) :- % like $pp_require_list but also accepts [] + ( nonvar(X), (X = [_|_] ; X = []) -> true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_list_or_nil) + ). + +$pp_error_list_or_nil(X,instantiation_error) :- % ISO spelling + var(X), !. +$pp_error_list_or_nil(X,type_error(list_or_nil,X)) :- + X \= [_|_], X \= [], !. + +%%---------------------------------------- + +$pp_require_list_not_shorter_than(X,Min,MsgID,Source) :- % Min must be a non-negative integer; X must be a list with at least Min elements + $pp_assert(integer(Min)), + $pp_assert(Min >= 0), + ( $pp_test_list_not_shorter_than(X,Min) -> true + ; $pp_raise_on_require([X,Min],MsgID,Source,$pp_error_list_not_shorter_than) + ). + +$pp_test_list_not_shorter_than(X,Min) :- + nonvar(X), + ( X = [_|_] ; X = [] ), + length(X,L), L >= Min. + +$pp_error_list_not_shorter_than(X,_Min,instantiation_error) :- % ISO spelling + var(X), !. +$pp_error_list_not_shorter_than(X,_Min,type_error(list,X)) :- + X \= [_|_], X \= [], !. +$pp_error_list_not_shorter_than(X,Min,type_error(list_not_shorter_than(Min),X)) :- + length(X,L), L < Min, !.
+ +%%---------------------------------------- + +$pp_require_compound(X,MsgID,Source) :- + ( compound(X) -> + true + ; $pp_raise_on_require([X],MsgID,Source,$pp_error_compound) + ). + +$pp_error_compound(X,instantiation_error) :- + var(X), !. +$pp_error_compound(X,type_error(compound,X)) :- + \+ compound(X), !. + +%%---------------------------------------- + +$pp_require_integer_range(Min,Max,MsgID,Source) :- + $pp_assert(integer(Min)), + $pp_assert(integer(Max)), + ( Min < Max -> + true + ; $pp_raise_on_require([Min,Max],MsgID,Source,$pp_error_integer_range) + ). + +$pp_error_integer_range(Min,Max,Error) :- + Min >= Max, + Error = domain_error(integer_range,[Min,Max]), !. + +%%---------------------------------------- + +$pp_require_integer_range_incl(Min,Max,MsgID,Source) :- + $pp_assert(integer(Min)), + $pp_assert(integer(Max)), + ( Min =< Max -> + true + ; $pp_raise_on_require([Min,Max],MsgID,Source,$pp_error_integer_range_incl) + ). + +$pp_error_integer_range_incl(Min,Max,Error) :- + Min > Max, + Error = domain_error(integer_range_inclusive,[Min,Max]), !. + +%%---------------------------------------- + +$pp_require_integer_range_excl(Min,Max,MsgID,Source) :- + $pp_assert(integer(Min)), + $pp_assert(integer(Max)), + ( Min + 1 > Min, Min + 1 < Max -> % (Min + 1 =< Min) -> overflow + true + ; $pp_raise_on_require([Min,Max],MsgID,Source,$pp_error_integer_range_excl) + ). + +$pp_error_integer_range_excl(Min,Max,Error) :- + ( Min + 1 =< Min ; Min + 1 >= Max ), + Error = domain_error(integer_range_exclusive,[Min,Max]), !. + +%%---------------------------------------- + +$pp_require_number_range_incl(Min,Max,MsgID,Source) :- + $pp_assert(number(Min)), + $pp_assert(number(Max)), + ( Min =< Max -> + true + ; $pp_raise_on_require([Min,Max],MsgID,Source,$pp_error_number_range_incl) + ). + +$pp_error_number_range_incl(Min,Max,Error) :- + Min > Max, + Error = domain_error(number_range_inclusive,[Min,Max]), !. 
+
+%%----------------------------------------
+% Require Min < Max for numbers Min and Max.
+$pp_require_number_range_excl(Min,Max,MsgID,Source) :-
+    $pp_assert(number(Min)),
+    $pp_assert(number(Max)),
+    ( Min < Max ->
+        true
+    ; $pp_raise_on_require([Min,Max],MsgID,Source,$pp_error_number_range_excl)
+    ).
+
+$pp_error_number_range_excl(Min,Max,Error) :-
+    Min >= Max,
+    Error = domain_error(number_range_exclusive,[Min,Max]), !.
+
+%%----------------------------------------
+% Require nonvar X to occur in the non-empty list Xs (checked with membchk/2).
+$pp_require_membership(X,Xs,MsgID,Source) :-
+    $pp_assert(Xs = [_|_]),
+    ( nonvar(X),membchk(X,Xs) -> true
+    ; $pp_raise_on_require([X,Xs],MsgID,Source,$pp_error_membership)
+    ).
+% The domain reported on failure is the list of admissible values itself.
+$pp_error_membership(X,_Xs,Error) :-
+    $pp_error_nonvar(X,Error), !.
+$pp_error_membership(X,Xs,domain_error(Xs,X)) :-
+    \+ membchk(X,Xs), !.
+
+%%----------------------------------------
+% Require X to have the form F/N with F an atom and N an integer >= 0.
+$pp_require_predicate_indicator(X,MsgID,Source) :-
+    ( $pp_test_predicate_indicator(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_predicate_indicator)
+    ).
+
+$pp_test_predicate_indicator(X) :-
+    X = F/N, atom(F), integer(N), N >= 0.
+
+$pp_error_predicate_indicator(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_predicate_indicator(X,type_error(predicate_indicator,X)) :-
+    \+ $pp_test_predicate_indicator(X), !.
+
+%%----------------------------------------
+% Require X to satisfy $pp_is_user_probabilistic_atom/1 (defined elsewhere).
+$pp_require_user_probabilistic_atom(X,MsgID,Source) :-
+    ( $pp_is_user_probabilistic_atom(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_user_probabilistic_atom)
+    ).
+% Checks are tried in order: nonvar, callable, then the specific type test.
+$pp_error_user_probabilistic_atom(X,Error) :-
+    $pp_error_nonvar(X,Error), !.
+$pp_error_user_probabilistic_atom(X,Error) :-
+    $pp_error_callable(X,Error), !.
+$pp_error_user_probabilistic_atom(X,type_error(user_probabilistic_atom,X)) :-
+    \+ $pp_is_user_probabilistic_atom(X), !.
+
+%%----------------------------------------
+% Require X to satisfy $pp_is_probabilistic_atom/1 (defined elsewhere).
+$pp_require_probabilistic_atom(X,MsgID,Source) :-
+    ( $pp_is_probabilistic_atom(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_probabilistic_atom)
+    ).
+% Error term for $pp_require_probabilistic_atom: nonvar, callable, then type test.
+$pp_error_probabilistic_atom(X,Error) :-
+    $pp_error_nonvar(X,Error), !.
+$pp_error_probabilistic_atom(X,Error) :-
+    $pp_error_callable(X,Error), !.
+$pp_error_probabilistic_atom(X,type_error(probabilistic_atom,X)) :-
+    \+ $pp_is_probabilistic_atom(X), !.
+
+%%----------------------------------------
+% Require X to satisfy $pp_is_extended_probabilistic_atom/1 (defined elsewhere).
+$pp_require_extended_probabilistic_atom(X,MsgID,Source) :-
+    ( $pp_is_extended_probabilistic_atom(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_extended_probabilistic_atom)
+    ).
+% Any error found by $pp_error_probabilistic_atom/2 takes precedence.
+$pp_error_extended_probabilistic_atom(X,Error) :-
+    $pp_error_probabilistic_atom(X,Error), !.
+$pp_error_extended_probabilistic_atom(X,type_error(extended_probabilistic_atom,X)) :-
+    \+ $pp_is_extended_probabilistic_atom(X), !.
+
+%%----------------------------------------
+% Require X to satisfy $pp_is_probabilistic_callable/1 (defined elsewhere).
+$pp_require_probabilistic_callable(X,MsgID,Source) :-
+    ( $pp_is_probabilistic_callable(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_probabilistic_callable)
+    ).
+
+$pp_error_probabilistic_callable(X,Error) :-
+    $pp_error_probabilistic_atom(X,Error), !.
+$pp_error_probabilistic_callable(X,type_error(probabilistic_callable,X)) :-
+    \+ $pp_is_probabilistic_callable(X), !.
+
+%%----------------------------------------
+% Require X to satisfy $pp_is_tabled_probabilistic_atom/1 (defined elsewhere).
+$pp_require_tabled_probabilistic_atom(X,MsgID,Source) :-
+    ( $pp_is_tabled_probabilistic_atom(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_tabled_probabilistic_atom)
+    ).
+
+$pp_error_tabled_probabilistic_atom(X,Error) :-
+    $pp_error_probabilistic_atom(X,Error), !.
+$pp_error_tabled_probabilistic_atom(X,type_error(tabled_probabilistic_atom,X)) :-
+    \+ $pp_is_tabled_probabilistic_atom(X), !.
+
+%%----------------------------------------
+% Require that the predicate $pu_values/2 exists (checked with current_predicate/1).
+$pp_require_msw_declaration(MsgID,Source) :-
+    ( current_predicate($pu_values/2) -> true
+    ; $pp_raise_on_require([],MsgID,Source,$pp_error_msw_declaration)
+    ).
+
+$pp_error_msw_declaration(msw_declaration_not_found) :-
+    \+ current_predicate($pu_values/2), !.
+
+%%----------------------------------------
+% Require ground X to have a matching $pu_values/2 entry.
+$pp_require_switch_outcomes(X,MsgID,Source) :-
+    $pp_assert(ground(X)),
+    ( current_predicate($pu_values/2),
+      $pu_values(X,_)
+        -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_switch_outcomes)
+    ).
+% A missing $pu_values/2 is reported before a missing entry for X.
+$pp_error_switch_outcomes(_X,Error) :-
+    $pp_error_msw_declaration(Error), !.
+$pp_error_switch_outcomes(X,existence_error(outcome,X)) :-
+    \+ $pu_values(X,_), !.
+
+%%----------------------------------------
+% Require Flag to be an atom naming a flag in $pp_prism_flag/4, directly or via a rename.
+$pp_require_prism_flag(Flag,MsgID,Source) :-
+    ( $pp_test_prism_flag(Flag) -> true
+    ; $pp_raise_on_require([Flag],MsgID,Source,$pp_error_prism_flag)
+    ).
+
+$pp_test_prism_flag(Flag) :-
+    atom(Flag),
+    ( $pp_prism_flag(Flag,_,_,_)
+    ; $pp_prism_flag_renamed(Flag,Flag1),
+      $pp_prism_flag(Flag1,_,_,_)
+    ).
+
+$pp_error_prism_flag(Flag,Error) :-
+    $pp_error_atom(Flag,Error), !.
+$pp_error_prism_flag(Flag,domain_error(prism_flag,Flag)) :-
+    \+ $pp_prism_flag(Flag,_,_,_), !.
+
+%%----------------------------------------
+% Require ground Value to pass $pp_check_prism_flag/4 for the type of (possibly renamed) Flag.
+$pp_require_prism_flag_value(Flag,Value,MsgID,Source) :-
+    $pp_assert($pp_test_prism_flag(Flag)),
+    ( $pp_test_prism_flag_value(Flag,Value) -> true
+    ; $pp_raise_on_require([Flag,Value],MsgID,Source,$pp_error_prism_flag_value)
+    ).
+
+$pp_test_prism_flag_value(Flag,Value) :-
+    ground(Value),
+    ( $pp_prism_flag(Flag,Type,_,_),
+      $pp_check_prism_flag(Type,Value,_,_)
+    ; $pp_prism_flag_renamed(Flag,Flag1),
+      $pp_prism_flag(Flag1,Type,_,_),
+      $pp_check_prism_flag(Type,Value,_,_)
+    ).
+
+$pp_error_prism_flag_value(_Flag,Value,Error) :-
+    $pp_error_ground(Value,Error), !.
+$pp_error_prism_flag_value(Flag,Value,
+                           domain_error(prism_flag_value(Flag),Value)) :-
+    \+ $pp_test_prism_flag_value(Flag,Value), !.
+
+%%----------------------------------------
+% Require X to be a fixed-size or variable-size distribution specification.
+$pp_require_distribution(X,MsgID,Source) :-
+    ( $pp_test_distribution(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_distribution)
+    ).
+
+% we do not check each element at this moment
+$pp_test_distribution(X) :-
+    ( $pp_test_fixed_size_distribution(X)
+    ; $pp_test_variable_size_distribution(X)
+    ).
+% Ground size-independent specs: uniform, f_geometric[(Base[,Type])] with Base > 1, random, default.
+$pp_test_variable_size_distribution(X) :-
+    ground(X),
+    ( X = uniform
+    ; X = f_geometric
+    ; X = f_geometric(Base) -> number(Base), Base > 1
+    ; X = f_geometric(Base,Type) -> number(Base), Base > 1, membchk(Type,[asc,desc])
+    ; X = random
+    ; X = default
+    ).
+
+$pp_error_distribution(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_distribution(X,domain_error(distribution,X)) :-
+    \+ $pp_test_distribution(X), !.
+
+%%----------------------------------------
+% Require X to be a probability list, a '+'-joined probability expression, or a ':'-joined ratio.
+$pp_require_fixed_size_distribution(X,MsgID,Source) :-
+    ( $pp_test_fixed_size_distribution(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_fixed_size_distribution)
+    ).
+
+% we do not check each element at this moment
+$pp_test_fixed_size_distribution(X) :-
+    ground(X),
+    ( $pp_test_probabilities(X)
+    ; $pp_test_probabilities_plus(X)
+    ; $pp_test_ratio(X)
+    ).
+% X is split on '+' by $pp_expr_to_list/3; at least two terms forming valid probabilities.
+$pp_test_probabilities_plus(X) :-
+    $pp_expr_to_list('+',X,Ps),
+    length(Ps,L),
+    L > 1,!,
+    $pp_test_probabilities(Ps).
+% X is split on ':'; at least two non-negative numbers, not all within 1.0e-15 of zero.
+$pp_test_ratio(X) :-
+    $pp_expr_to_list(':',X,Rs),
+    length(Rs,L),
+    L > 1,!,
+    $pp_test_non_negative_numbers(Rs),
+    \+ $pp_test_zeros(Rs).
+% Succeeds iff every element lies strictly between -1.0e-15 and 1.0e-15.
+$pp_test_zeros([]).
+$pp_test_zeros([Z|Zs]):-
+    -1.0e-15 < Z,
+    1.0e-15 > Z,!,
+    $pp_test_zeros(Zs).
+
+$pp_error_fixed_size_distribution(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_fixed_size_distribution(X,domain_error(fixed_size_distribution,X)) :-
+    \+ $pp_test_fixed_size_distribution(X), !.
+
+%%----------------------------------------
+% Require X to be a number in the closed interval [0,1].
+$pp_require_probability(X,MsgID,Source) :-
+    ( $pp_test_probability(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_probability)
+    ).
+
+$pp_test_probability(X) :-
+    number(X),
+    X >= 0.0,
+    X =< 1.0.
+
+$pp_error_probability(X,Error) :-
+    $pp_error_number(X,Error), !.
+$pp_error_probability(X,domain_error(probability,X)) :-
+    ( X < 0.0 ; X > 1.0 ), !.
+
+%%----------------------------------------
+% Require Ps to be a non-empty list of probabilities summing to 1 within 1.0e-12.
+$pp_require_probabilities(Ps,MsgID,Source) :-
+    ( $pp_test_probabilities(Ps) -> true
+    ; $pp_raise_on_require([Ps],MsgID,Source,$pp_error_probabilities)
+    ).
+
+$pp_test_probabilities(Ps) :-
+    Ps = [_|_],
+    $pp_test_probabilities1(Ps),
+    sumlist(Ps,Sum),
+    abs(Sum - 1.0) =< 1.0e-12.
+
+$pp_test_probabilities1([]).
+$pp_test_probabilities1([P|Ps]) :-
+    $pp_test_probability(P),!,
+    $pp_test_probabilities1(Ps).
+% Order of diagnosis: not a list, then the first invalid element, then a wrong sum.
+$pp_error_probabilities(Ps,Error) :-
+    $pp_error_list(Ps,Error), !.
+$pp_error_probabilities(Ps,Error) :-
+    member(P,Ps),
+    $pp_error_probability(P,Error), !.
+$pp_error_probabilities(Ps,domain_error(probabilities,Ps)) :-
+    sumlist(Ps,Sum),
+    abs(Sum - 1.0) > 1.0e-12, !.
+
+%%----------------------------------------
+% Require X to be a fixed-size or variable-size hyperparameter specification.
+$pp_require_hyperparameters(X,MsgID,Source) :-
+    ( $pp_test_hyperparameters(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_hyperparameters)
+    ).
+
+$pp_test_hyperparameters(X) :-
+    ( $pp_test_fixed_size_hyperparameters(X)
+    ; $pp_test_variable_size_hyperparameters(X)
+    ).
+% Ground specs: a number >= 0, uniform[(U)], f_geometric with 0-3 arguments, or default.
+$pp_test_variable_size_hyperparameters(X) :-
+    ground(X),
+    ( number(X) -> X >= 0.0
+    ; X = uniform
+    ; X = uniform(U) -> number(U), U >= 0
+    ; X = f_geometric
+    ; X = f_geometric(Base) ->
+        number(Base), Base > 1
+    ; X = f_geometric(Init,Base) ->
+        number(Init), Init >= 0,
+        number(Base), Base > 1
+    ; X = f_geometric(Init,Base,Type) ->
+        number(Init), Init >= 0,
+        number(Base), Base > 1,
+        membchk(Type,[asc,desc])
+    ; X = default
+    ).
+
+$pp_error_hyperparameters(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_hyperparameters(X,domain_error(hyperparameters,X)) :-
+    \+ $pp_test_hyperparameters(X), !.
+
+%%----------------------------------------
+% Require X to be a ground, non-empty list of non-negative numbers.
+$pp_require_fixed_size_hyperparameters(X,MsgID,Source) :-
+    ( $pp_test_fixed_size_hyperparameters(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_fixed_size_hyperparameters)
+    ).
+
+$pp_test_fixed_size_hyperparameters(X) :-
+    ground(X),
+    $pp_test_non_negative_numbers(X).
+
+$pp_error_fixed_size_hyperparameters(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_fixed_size_hyperparameters(X,domain_error(fixed_size_hyperparameters,X)) :-
+    \+ $pp_test_fixed_size_hyperparameters(X), !.
+
+%%----------------------------------------
+% Require X to be a PRISM option; failures are classified by $pp_error_prism_option/2.
+$pp_require_prism_option(X,MsgID,Source) :-
+    ( $pp_test_prism_option(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_prism_option)
+    ).
+
+% Accepted options: the atoms listed below, or any ground term of the
+% form Lhs=Rhs.
+$pp_test_prism_option(X) :-
+    ground(X),
+    ( X = dump
+    ; X = consult
+    ; X = compile
+    ; X = load
+    ; X = v
+    ; X = verb
+    ; X = nv
+    ; X = noverb
+    ; X = (_=_)
+    ).
+
+$pp_error_prism_option(X,Error) :-
+    $pp_error_ground(X,Error), !.
+$pp_error_prism_option(X,domain_error(prism_option,X)) :-
+    \+ $pp_test_prism_option(X), !.
+
+%%----------------------------------------
+
+% aggregate pattern is so flexible that we can only check if
+% X is callable or not
+
+$pp_require_hindsight_aggregate_pattern(X,MsgID,Source) :-
+    ( $pp_test_hindsight_aggregate_pattern(X) -> true
+    ; $pp_raise_on_require([X],MsgID,Source,
+                           $pp_error_hindsight_aggregate_pattern)
+    ).
+
+$pp_test_hindsight_aggregate_pattern(X) :-
+    callable(X).
+
+$pp_error_hindsight_aggregate_pattern(X,Error) :-
+    $pp_error_nonvar(X,Error), !.
+$pp_error_hindsight_aggregate_pattern(X,Error) :-
+    $pp_error_callable(X,Error), !.
+
+%%----------------------------------------
+% Require G to satisfy $pp_is_write_callable/1; failures go to $pp_error_write_callable/2.
+$pp_require_write_callable(G,MsgID,Source) :-
+    ( $pp_is_write_callable(G) -> true
+    ; $pp_raise_on_require([G],MsgID,Source,$pp_error_write_callable)
+    ).
+
+$pp_error_write_callable(G,Error) :-
+    $pp_error_nonvar(G,Error), !.
+$pp_error_write_callable(G,Error) :-
+    $pp_error_callable(G,Error), !.
+$pp_error_write_callable(G,domain_error(write_callable,G)) :-
+    \+ $pp_is_write_callable(G), !.
diff --git a/packages/prism/src/prolog/core/format.pl b/packages/prism/src/prolog/core/format.pl
new file mode 100644
index 000000000..789d46517
--- /dev/null
+++ b/packages/prism/src/prolog/core/format.pl
@@ -0,0 +1,55 @@
+%%--------------------------------
+%% Entry Point
+% Print Format (a list of character codes), expanding each {N} to the N-th element of Args.
+$pp_format_message(Format,Args) :-
+    $pp_format_message_loop(Format,Args).
+
+
+%%--------------------------------
+%% Main Loop
+% "{" opens a placeholder, "~" prints the following code verbatim, any other code is printed as is.
+$pp_format_message_loop([],_) :- !.
+$pp_format_message_loop(Format,Args) :-
+    Format = [Next|Format0],
+    ( Next == 0'{ ->                          % '
+        $pp_format_message_loop(Format0,Format1,Args)
+    ; Next == 0'~ ->                          % '
+        Format0 = [Code|Format1], $pp_format_message_char(Code)
+    ; %% else
+        Format0 = Format1, $pp_format_message_char(Next)
+    ), !,
+    $pp_format_message_loop(Format1,Args).
+% Placeholder step: print argument N, or fall back to a literal "{" leaving the input unconsumed.
+$pp_format_message_loop(Format0,Format1,Args) :-
+    $pp_format_message_spec(Format0,Format1,N),
+    nth(N,Args,Arg), !,
+    $pp_format_message_term(Arg).
+$pp_format_message_loop(Format0,Format0,_Args) :-
+    $pp_format_message_char(0'{).             % '
+
+
+%%--------------------------------
+%% Format Spec Extraction
+% Parse "<digits>}" into the integer N; an empty spec "{}" fails instead of reaching number_codes/2.
+$pp_format_message_spec(Format0,Format1,N) :-
+    $pp_format_message_spec(Format0,Format1,[],Codes),
+    Codes = [_|_], number_codes(N,Codes).
+% Digits are collected in reading order; the 3rd argument is the tail that closes the list.
+$pp_format_message_spec(Xs0,Xs1,Tail,Tail) :-
+    Xs0 = [0'}|Xs1], !.                       % '
+$pp_format_message_spec(Xs0,Xs1,Tail,Ys) :-
+    Xs0 = [X|Xs2],
+    integer(X),
+    X >= 48,
+    X =< 57,
+    Ys = [X|Ys1], !,
+    $pp_format_message_spec(Xs2,Xs1,Tail,Ys1).
+
+
+%%--------------------------------
+%% Output
+% Thin wrappers over format/2: one character code, or one term written with ~w.
+$pp_format_message_char(Code) :-
+    format("~c",[Code]).
+$pp_format_message_term(Term) :-
+    format("~w",[Term]).
diff --git a/packages/prism/src/prolog/core/message.pl b/packages/prism/src/prolog/core/message.pl
new file mode 100644
index 000000000..34bbb7392
--- /dev/null
+++ b/packages/prism/src/prolog/core/message.pl
@@ -0,0 +1,194 @@
+%% -*- Prolog -*-
+
+%%========================================================================
+%%
+%%  This module contains a set of error and warning messages displayed in
+%%  the Prism system.  Each message entry has the following form:
+%%
+%%      $pp_message(ID,Type,Message)
+%%
+%%  ID is a term $msg(NNNN) whose four-digit number identifies the message.
+%%
+%%  Type denotes the message type, which is one of the following:
+%%
+%%      * fatal
+%%      * inter(nal error)
+%%      * error
+%%      * fail
+%%      * warn
+%%      * obsol(ete)
+%%      * info
+%%
+%%  Message is the message text; {N} is presumably replaced by the N-th argument.
+%%
+%%========================================================================
+
+%%
+%% Errors related to probabilistic models
+%%
+
+% Errors related to probabilities
+$pp_message($msg(0000),error,"invalid probability -- {1}").
+$pp_message($msg(0001),error,"invalid probability list -- {1}").
+$pp_message($msg(0002),error,"invalid ratio list -- {1}").
+$pp_message($msg(0003),error,"invalid probabilistic atomic formula -- {1}").
+$pp_message($msg(0004),error,"invalid user-defined probabilistic atomic formula -- {1}").
+$pp_message($msg(0005),error,"invalid extended probabilistic atomic formula -- {1}").
+$pp_message($msg(0006),error,"invalid tabled probabilistic atomic formula -- {1}").
+$pp_message($msg(0007),error,"invalid probabilistic callable -- {1}").
+
+% Errors related to random switches
+$pp_message($msg(0100),error,"no multi-valued switch declarations given").
+$pp_message($msg(0101),error,"non-ground switch name -- {1}").
+$pp_message($msg(0102),error,"outcome space not given -- {1}").
+$pp_message($msg(0103),error,"probability distribution not given -- {1}").
+$pp_message($msg(0104),error,"hyperparameters not given -- {1}").
+%$pp_message($msg(0105),error,"").
+$pp_message($msg(0106),error,"modified outcome space; probabilities expected to be unfixed -- {1}"). +$pp_message($msg(0107),error,"modified outcome space; obsolete expectations -- {1}"). +$pp_message($msg(0108),error,"modified outcome space; hyperparameters expected to be unfixed -- {1}"). +$pp_message($msg(0109),warn, "distribution fixed -- {1}"). +$pp_message($msg(0110),warn, "hyperparameters fixed -- {1}"). + +% Errors related to distribution +$pp_message($msg(0200),error,"invalid distribution -- {1}"). +$pp_message($msg(0201),error,"invalid hyperparameters -- {1}"). +$pp_message($msg(0202),error,"default distribution unavailable"). +$pp_message($msg(0203),error,"default hyperparameters unavailable"). +$pp_message($msg(0204),error,"invalid number of outcomes -- {1}"). +$pp_message($msg(0205),error,"invalid switch configuration -- {1}"). +%$pp_message($msg(0206),error,""). +%$pp_message($msg(0207),error,""). +$pp_message($msg(0208),error,"invalid alpha values -- {1}"). +$pp_message($msg(0209),error,"invalid delta values -- {1}"). +$pp_message($msg(0210),error,"distribution does not match -- ({1},{2})"). +$pp_message($msg(0211),error,"size unmatched -- ({1},{2})"). + +%% +%% Errors related to built-ins for probabilistic inferences +%% + +% Errors in loading +$pp_message($msg(1000),error,"invalid filename -- {1}"). +$pp_message($msg(1001),error,"invalid PRISM option -- {1}"). +$pp_message($msg(1002),warn, "tabling disabled in the consultation mode"). +$pp_message($msg(1003),error,"batch file not specified"). +$pp_message($msg(1004),error,"prism_main/0-1 undefined -- {1}"). +$pp_message($msg(1005),error,"invalid module for prism or upprism"). + +% Errors in translation +$pp_message($msg(1100),fail ,"bad or duplicate predicate -- {1}"). +$pp_message($msg(1101),error,"co-existing p_table and p_not_table declarations"). +$pp_message($msg(1102),error,"invalid predicate indicator -- {1}"). +$pp_message($msg(1103),error,"invalid call in write_call"). 
+$pp_message($msg(1104),warn, "parameters left unset/unfixed; ground terms expected -- values_x({1},_,{2})").
+$pp_message($msg(1105),error,"invalid outcome space; ground list expected").
+
+% Errors in sampling
+$pp_message($msg(1201),error,"invalid goal; probabilistic goal expected -- {1}").
+$pp_message($msg(1202),error,"invalid constraint; callable term expected -- {1}").
+$pp_message($msg(1203),error,"invalid number of samples; positive integer expected -- {1}").
+$pp_message($msg(1204),error,"invalid number of trials; `inf' or positive integer expected -- {1}").
+
+% Errors in EM learning
+$pp_message($msg(1300),error,"no observed data; the data_source flag set to `none'").
+$pp_message($msg(1301),error,"no observed data; data/1 undefined").
+$pp_message($msg(1302),error,"invalid observed data -- {1}").
+$pp_message($msg(1303),error,"invalid observed goal; tabled probabilistic atomic formula expected -- {1}").
+$pp_message($msg(1304),error,"no explanations -- {1}").
+$pp_message($msg(1305),error,"DAEM not applicable to models with failure").
+$pp_message($msg(1306),error,"invalid goal count; positive integer expected -- {1}").
+
+% Errors in other probabilistic inferences
+$pp_message($msg(1400),error,"invalid number of top-ranked expls; positive integer expected -- {1}").
+$pp_message($msg(1401),error,"invalid number of intermediate candidate expls; positive integer expected -- {1}").
+$pp_message($msg(1402),error,"invalid subgoal aggregation pattern -- {1}").
+$pp_message($msg(1403),error,"invalid subgoal pattern -- {1}").
+$pp_message($msg(1404),warn, "subgoals unmatched").
+$pp_message($msg(1405),error,"invalid subgoal argument; integer expected -- {1}").
+$pp_message($msg(1406),error,"invalid subgoal argument; atom expected -- {1}").
+$pp_message($msg(1407),error,"invalid subgoal argument; ground compound expected -- {1}").
+$pp_message($msg(1408),error,"invalid subgoal argument; list expected -- {1}").
+$pp_message($msg(1409),error,"invalid subgoal argument; d-list expected -- {1}"). + +%% +%% Errors related to built-ins for auxiliary operations +%% + +% Errors in random operations +$pp_message($msg(2000),error,"invalid random seed -- {1}"). +$pp_message($msg(2001),error,"invalid random state -- {1}"). +$pp_message($msg(2002),error,"invalid max value; positive integer expected -- {1}"). +$pp_message($msg(2003),error,"invalid min value; integer expected -- {1}"). +$pp_message($msg(2004),error,"invalid max value; integer expected -- {1}"). +$pp_message($msg(2005),error,"invalid max value; positive number expected -- {1}"). +$pp_message($msg(2006),error,"invalid min value; number expected -- {1}"). +$pp_message($msg(2007),error,"invalid max value; number expected -- {1}"). +$pp_message($msg(2008),error,"invalid min/max pair -- ({1},{2})"). +$pp_message($msg(2009),error,"invalid mu; number expected -- {1}"). +$pp_message($msg(2010),error,"invalid sigma; positive number expected -- {1}"). +$pp_message($msg(2011),error,"invalid elements; list expected -- {1}"). +$pp_message($msg(2012),error,"invalid number of selections; integer expected -- {1}"). +$pp_message($msg(2013),error,"number of selections out of range -- {1}"). +$pp_message($msg(2014),error,"invalid number of groups; positive integer expected -- {1}"). + +% Errors in list handling +$pp_message($msg(2100),error,"invalid predicate name -- {1}"). +$pp_message($msg(2101),error,"invalid unary operator -- {1}"). +$pp_message($msg(2102),error,"invalid binary operator -- {1}"). +$pp_message($msg(2103),error,"invalid argument; list not shorter than {2} expected -- {1}"). +$pp_message($msg(2104),error,"invalid argument; list expected -- {1}"). +$pp_message($msg(2105),error,"invalid argument; non-negative integer expected -- {1}"). +$pp_message($msg(2106),error,"invalid argument; positive integer expected -- {1}"). +$pp_message($msg(2107),error,"invalid agglist operation -- {1}"). 
+$pp_message($msg(2108),error,"invalid argument; list of numbers expected -- {1}"). +$pp_message($msg(2109),error,"invalid argument; list or nil expected -- {1}"). +$pp_message($msg(2110),error,"invalid argument; list of non-variables expected -- {1}"). + +%% +%% Miscellaneous errors +%% + +% File I/Os +$pp_message($msg(3000),error,"invalid file specification -- {1}"). +$pp_message($msg(3001),error,"file not found -- {1}"). +$pp_message($msg(3002),error,"unknown or illegal option -- {1}"). +$pp_message($msg(3003),error,"duplicate option -- {1}"). +$pp_message($msg(3004),error,"no information on the last observation"). +$pp_message($msg(3005),error,"too few rows"). +$pp_message($msg(3006),error,"too few columns"). +$pp_message($msg(3007),error,"parsing failure in CSV format"). +$pp_message($msg(3008),warn, "too few rows compared to the specification"). + +% Execution flags +$pp_message($msg(3100),error,"invalid prism flag -- {1}"). +$pp_message($msg(3101),error,"invalid value for {1} -- {2}"). +$pp_message($msg(3102),warn, "prism flag replaced by {2} -- {1}"). +$pp_message($msg(3103),error,"prism flag deleted in version {2} -- {1}"). +$pp_message($msg(3104),error,"prism flag value deleted in version {2} -- {1}"). +$pp_message($msg(3105),warn, "prism flag value replaced by {2} -- {1}"). + +% Write calls +$pp_message($msg(3200),error,"control constructs (other than conjunction) disallowed -- {1}"). + +% Deprecated predicates +$pp_message($msg(3300),warn, "predicate replaced by {2} -- {1}"). +$pp_message($msg(3301),warn, "predicate deprecated -- {1}"). + +% Math predicates +$pp_message($msg(3400),error,"invalid argument -- {1}"). + +%% +%% System-related errors +%% + +% Internal errors +$pp_message($msg(9800),inter,"error term not found"). +$pp_message($msg(9801),inter,"error message not found"). +$pp_message($msg(9802),inter,"invalid internal representation"). +$pp_message($msg(9803),inter,"unmatched branches"). +$pp_message($msg(9804),inter,"unexpected failure"). 
+$pp_message($msg(9805),inter,"failure in hash-id registration -- {1}").
+
+% Fatal errors
+$pp_message($msg(9900),fatal,"assertion failure -- {1}").
diff --git a/packages/prism/src/prolog/core/random.pl b/packages/prism/src/prolog/core/random.pl
new file mode 100644
index 000000000..7be947e90
--- /dev/null
+++ b/packages/prism/src/prolog/core/random.pl
@@ -0,0 +1,286 @@
+%%----------------------------------------
+
+%
+% vsc: done in prism.yap
+% :- random_set_seed.
+
+%%----------------------------------------
+% Seed is read from the global variable $pg_random_seed.
+random_get_seed(Seed) :-
+    global_get($pg_random_seed,Seed),!.
+% Seed the generator with a value obtained from $pc_random_auto_seed/1.
+random_set_seed :-
+    $pc_random_auto_seed(Seed),
+    random_set_seed(Seed),!.
+% Seed with an integer or a list (validated first) and record it in $pg_random_seed.
+random_set_seed(Seed) :-
+    $pp_require_random_seed(Seed,$msg(2000),random_set_seed/1),
+    ( integer(Seed) -> $pc_random_init_by_seed(Seed)
+    ; Seed ?= [_|_] -> $pc_random_init_by_list(Seed)
+    ; %% else
+      $pp_assert(fail)
+    ), !,
+    global_set($pg_random_seed,Seed),!.
+
+random_get_state(State) :-
+    $pc_random_get_state(State),!.
+
+random_set_state(State) :-
+    $pp_require_random_state(State,$msg(2001),random_set_state/1),
+    $pc_random_set_state(State),!.
+
+% deprecated predicates:
+
+set_seed(Seed) :-
+    $pp_raise_warning($msg(3300),[set_seed/1,random_set_seed/1]),
+    random_set_seed(Seed).
+
+set_seed_time :-
+    $pp_raise_warning($msg(3300),[set_seed_time/0,random_set_seed/0]),
+    random_set_seed.
+
+set_seed_time(Seed) :-
+    $pp_raise_warning($msg(3301),[set_seed_time/1]),
+    random_set_seed,
+    random_get_seed(Seed).
+
+%%----------------------------------------
+% Max must be a positive integer; the range of Value is decided by $pc_random_int/2.
+random_int(Max,Value) :-
+    $pp_require_positive_integer(Max,$msg(2002),random_int/2),
+    $pc_random_int(Max,Value).
+% Requires integers Min < Max; Max - 1 is passed as the upper bound (presumably Value is in [Min,Max)).
+random_int(Min,Max,Value) :-
+    $pp_require_integer(Min,$msg(2003),random_int/3),
+    $pp_require_integer(Max,$msg(2004),random_int/3),
+    $pp_require_integer_range(Min,Max,$msg(2008),random_int/3),
+    Max1 is Max - 1,
+    $pc_random_int(Min,Max1,Value).
+% Requires integers Min =< Max; both bounds are passed unchanged to $pc_random_int/3.
+random_int_incl(Min,Max,Value) :-
+    $pp_require_integer(Min,$msg(2003),random_int_incl/3),
+    $pp_require_integer(Max,$msg(2004),random_int_incl/3),
+    $pp_require_integer_range_incl(Min,Max,$msg(2008),random_int_incl/3),
+    $pc_random_int(Min,Max,Value).
+
+% Requires Min + 1 < Max; the bounds passed on are Min + 1 and Max - 1.
+random_int_excl(Min,Max,Value) :-
+    $pp_require_integer(Min,$msg(2003),random_int_excl/3),
+    $pp_require_integer(Max,$msg(2004),random_int_excl/3),
+    $pp_require_integer_range_excl(Min,Max,$msg(2008),random_int_excl/3),
+    Min1 is Min + 1,
+    Max1 is Max - 1,
+    $pc_random_int(Min1,Max1,Value).
+
+%%----------------------------------------
+% Value comes straight from $pc_random_float/1.
+random_uniform(Value) :-
+    $pc_random_float(Value).
+% Value is a $pc_random_float/1 sample scaled by the positive number Max.
+random_uniform(Max,Value) :-
+    $pp_require_positive_number(Max,$msg(2005),random_uniform/2),
+    $pc_random_float(Value0),
+    Value is Value0 * Max.
+% Value is a $pc_random_float/1 sample mapped linearly onto Min..Max (Min < Max required).
+random_uniform(Min,Max,Value) :-
+    $pp_require_number(Min,$msg(2006),random_uniform/3),
+    $pp_require_number(Max,$msg(2007),random_uniform/3),
+    $pp_require_number_range_excl(Min,Max,$msg(2008),random_uniform/3),
+    $pc_random_float(Value0),
+    Value is Value0 * (Max - Min) + Min.
+
+random_gaussian(Value) :-
+    $pc_random_gaussian(Value).
+% Value is a $pc_random_gaussian/1 sample scaled by Sigma (> 0) and shifted by Mu.
+random_gaussian(Mu,Sigma,Value) :-
+    $pp_require_number(Mu,$msg(2009),random_gaussian/3),
+    $pp_require_positive_number(Sigma,$msg(2010),random_gaussian/3),
+    $pc_random_gaussian(Value0),
+    Value is Value0 * Sigma + Mu.
+
+%%----------------------------------------
+% Pick one element of List using the uniform distribution spec.
+random_select(List,Value) :-
+    random_select(List,uniform,Value).
+% Pick one element of the expanded List with weights derived from Dist by $pp_spec_to_ratio/4.
+random_select(List,Dist,Value) :-
+    $pp_require_list(List,$msg(2011),random_select/3),
+    $pp_require_distribution(Dist,$msg(0200),random_select/3),
+    expand_values(List,List1),
+    length(List1,L1),
+    $pp_spec_to_ratio(Dist,L1,Ratio,random_select/3),
+    length(Ratio,L2),
+    ( L1 =:= L2 -> true
+    ; $pp_raise_runtime_error($msg(0210),[List,Dist],unmatched_distribution,
+                              random_select/3)
+    ),
+    sumlist(Ratio,Sum),
+    random_uniform(Sum,Rand),!,
+    $pp_random_select(Ratio,List1,Rand,Value).
+% Walk the weights, subtracting each from R; stop at the first weight exceeding R or at the last element.
+$pp_random_select([X|Xs],[Y|Ys],R,Value) :-
+    ( R >= X, Xs ?= [_|_] ->
+        R1 is R - X, !, $pp_random_select(Xs,Ys,R1,Value)
+    ; Y = Value
+    ),!.
+
+% deprecated predicates:
+
+dice(List,Value) :-
+    $pp_raise_warning($msg(3300),[dice/2,random_select/2]),
+    random_select(List,Value).
+
+dice(List,Dist,Value) :-
+    $pp_raise_warning($msg(3300),[dice/3,random_select/3]),
+    random_select(List,Dist,Value).
+
+%%----------------------------------------
+% Select N elements of List (1 =< N =< length), returned in their original list order.
+random_multiselect(List,N,Result) :-
+    $pp_require_list(List,$msg(2011),random_multiselect/3),
+    $pp_require_integer(N,$msg(2012),random_multiselect/3),
+    length(List,L),
+    ( \+ ( 1 =< N, N =< L ) ->
+        $pp_raise_runtime_error($msg(2013),[N],
+                                invalid_argument,random_multiselect/3)
+    ; true
+    ), !,
+    new_bigarray(Elems,L),
+    new_bigarray(Flags,L),
+    $pp_random_multiselect1(1,L,Elems,Flags),
+    M is L - N,
+    ( N =< M ->    % flag the smaller side: the N chosen, or else the M = L - N left out
+        $pp_random_multiselect2(1,N,L,Elems,Flags),
+        $pp_random_multiselect3(1,1,Flags,List,Result)
+    ; $pp_random_multiselect2(1,M,L,Elems,Flags),
+      $pp_random_multiselect3(1,0,Flags,List,Result)
+    ).
+% Initialise Elems[K] = K and Flags[K] = 0 for K = 1..L.
+$pp_random_multiselect1(K,L,_,_), K > L =>
+    true.
+$pp_random_multiselect1(K,L,Elems,Flags), K =< L =>
+    bigarray_put(Elems,K,K),
+    bigarray_put(Flags,K,0),
+    K1 is K + 1, !,
+    $pp_random_multiselect1(K1,L,Elems,Flags).
+% For K = 1..N: swap Elems[K] with Elems[J], J drawn from K..L, and flag the index moved into slot K.
+$pp_random_multiselect2(K,N,_,_,_), K > N =>
+    true.
+$pp_random_multiselect2(K,N,L,Elems,Flags), K =< N =>
+    random_int_incl(K,L,J),
+    bigarray_get(Elems,K,VK),
+    bigarray_get(Elems,J,VJ),
+    bigarray_put(Elems,J,VK),
+    bigarray_put(Elems,K,VJ),
+    bigarray_put(Flags,VJ,1),
+    K1 is K + 1, !,
+    $pp_random_multiselect2(K1,N,L,Elems,Flags).
+% Keep the elements of Xs whose flag equals Query (1 = flagged, 0 = unflagged).
+$pp_random_multiselect3(_,_,_,Xs,Ys), Xs = [] =>
+    Ys = [].
+$pp_random_multiselect3(K,Query,Flags,Xs,Ys), Xs = [X|Xs1] =>
+    ( bigarray_get(Flags,K,Query) -> Ys = [X|Ys1] ; Ys = Ys1 ),
+    K1 is K + 1, !,
+    $pp_random_multiselect3(K1,Query,Flags,Xs1,Ys1).
+
+%%----------------------------------------
+% Distribute the elements of List over N groups; Result is the list of the N groups.
+random_group(List,N,Result) :-
+    $pp_require_list(List,$msg(2011),random_group/3),
+    $pp_require_positive_integer(N,$msg(2014),random_group/3),
+    List = List1,
+    new_bigarray(Array,N),
+    $pp_random_group1(1,N,Array),
+    $pp_random_group2(List1,N,Array),
+    $pp_random_group3(1,N,Array,Result).
+% Initialise every slot with an empty difference list Xs-Xs.
+$pp_random_group1(K,N,_), K > N =>
+    true.
+$pp_random_group1(K,N,Array), K =< N =>
+    bigarray_put(Array,K,Xs-Xs),
+    K1 is K + 1, !,
+    $pp_random_group1(K1,N,Array).
+% Append each element to the slot Z0 + 1, with Z0 from $pc_random_int(N,Z0) (presumably 0..N-1).
+$pp_random_group2(Xs,_,_), Xs = [] =>
+    true.
+$pp_random_group2(Xs,N,Array), Xs = [X|Xs1] =>
+    $pc_random_int(N,Z0),
+    Z is Z0 + 1,
+    bigarray_get(Array,Z,Ys0-Ys1),
+    Ys1 = [X|Ys2],
+    bigarray_put(Array,Z,Ys0-Ys2), !,
+    $pp_random_group2(Xs1,N,Array).
+% Close each difference list with [] and collect the groups in slot order.
+$pp_random_group3(K,N,_,Xs), K > N =>
+    Xs = [].
+$pp_random_group3(K,N,Array,Xs), K =< N =>
+    bigarray_get(Array,K,X-[]),
+    Xs = [X|Xs1],
+    K1 is K + 1, !,
+    $pp_random_group3(K1,N,Array,Xs1).
+
+%%----------------------------------------
+% List is a random permutation of the non-empty list List0.
+random_shuffle(List0,List) :-
+    $pp_require_list(List0,$msg(2011),random_shuffle/2),
+    list_to_bigarray(List0,Array),
+    bigarray_length(Array,Size),
+    $pp_random_shuffle(1,Size,Array),
+    bigarray_to_list(Array,List).
+% For K = 1..N: swap Array[K] with Array[J], J drawn by random_int_incl(K,N,J).
+$pp_random_shuffle(K,N,_), K > N =>
+    true.
+$pp_random_shuffle(K,N,Array), K =< N =>
+    random_int_incl(K,N,J),
+    bigarray_get(Array,K,VK),
+    bigarray_get(Array,J,VJ),
+    bigarray_put(Array,J,VK),
+    bigarray_put(Array,K,VJ),
+    K1 is K + 1, !,
+    $pp_random_shuffle(K1,N,Array).
+
+%%----------------------------------------
+% Require X to be an integer or a non-empty proper list of integers.
+$pp_require_random_seed(X,ID,Source) :-
+    ( $pp_test_random_seed(X) -> true
+    ; $pp_raise_on_require([X],ID,Source,$pp_error_random_seed)
+    ).
+
+$pp_test_random_seed(X), integer(X) => true.
+$pp_test_random_seed(X), X = [Y], integer(Y) => true.
+$pp_test_random_seed(X), X = [Y|Z], integer(Y) =>
+    Z ?= [_|_],
+    $pp_test_random_seed(Z).
+
+$pp_error_random_seed(X,instantiation_error) :-
+    \+ ground(X), !.
+$pp_error_random_seed(X,domain_error(random_seed,X)) :-
+    \+ $pp_test_random_seed(X), !.
+
+%%----------------------------------------
+% Require X to be a $randstate term with exactly 833 integer arguments.
+$pp_require_random_state(X,ID,Source) :-
+    ( $pp_test_random_state(X) ->
+        true
+    ; $pp_raise_on_require([X],ID,Source,$pp_error_random_state)
+    ).
+
+$pp_test_random_state(X) :-
+    functor(X,$randstate,833),
+    $pp_test_random_state(X,1).
+% Check arguments N..833 one by one; each must be an integer.
+$pp_test_random_state(_,N), N > 833 => true.
+$pp_test_random_state(X,N), N =< 833 =>
+    arg(N,X,Arg),
+    integer(Arg),
+    N1 is N + 1,
+    $pp_test_random_state(X,N1).
+
+$pp_error_random_state(X,instantiation_error) :-
+    \+ ground(X), !.
+$pp_error_random_state(X,type_error(compound,X)) :-
+    \+ compound(X), !.
+$pp_error_random_state(X,domain_error(random_state,X)) :-
+    \+ $pp_test_random_state(X), !.
+
+%%----------------------------------------
diff --git a/packages/prism/src/prolog/mp/mp_learn.pl b/packages/prism/src/prolog/mp/mp_learn.pl
new file mode 100644
index 000000000..a89110ef9
--- /dev/null
+++ b/packages/prism/src/prolog/mp/mp_learn.pl
@@ -0,0 +1,151 @@
+:- $pp_require_mp_mode.
+
+%%----------------------------------------
+% Only the process for which $pc_mp_master succeeds runs the learning driver; others succeed at once.
+$pp_learn_core(Mode) :-
+    ( $pc_mp_master -> $pp_mpm_learn_main(Mode) ; true ).
+$pp_learn_core(Mode,Goals) :-
+    ( $pc_mp_master -> $pp_mpm_learn_main(Mode,Goals) ; true ).
+% Load the goals from the learn_data_file/1 file, broadcast $pp_mps_learn_core(Mode), then run the master side.
+$pp_mpm_learn_main(Mode) :-
+    learn_data_file(FileName),
+    load_clauses(FileName,Goals,[]),
+    $pc_mpm_bcast_command($pp_mps_learn_core(Mode)),!,
+    $pp_mpm_learn_core(Mode,Goals).
+% Same as above, but the goals are given by the caller and checked first.
+$pp_mpm_learn_main(Mode,Goals) :-
+    $pp_learn_check_goals(Goals),
+    $pc_mpm_bcast_command($pp_mps_learn_core(Mode)),!,
+    $pp_mpm_learn_core(Mode,Goals).
+
+%%----------------------------------------
+
+% Master
+$pp_mpm_learn_core(Mode,Goals) :-   % master side of one learning run; Mode is one of params, hparams, both (see $pp_mpm_em/2)
+    $pc_mp_sync(2,1),
+    $pc_mp_wtime(Start),                % Start/End bracket the whole run
+    $pp_learn_clean_info,
+    $pp_learn_reset_hparams(Mode),
+    $pp_build_count_pairs(Goals,GoalEqCountPairs),
+    $pp_learn_message(MsgS,MsgE,MsgT,MsgM),
+    $pc_set_em_message(MsgE),
+    $pc_mp_wtime(StartExpl),            % StartExpl/EndExpl bracket the explanation phase
+    global_set($pg_num_goals,0),
+    $pc_mpm_share_prism_flags,
+    $pp_mpm_find_explanations(GoalEqCountPairs,GoalCountPairs),!,
+    global_set($pg_observed_facts,GoalCountPairs),
+    $pp_print_num_goals(MsgS),
+    $pc_mp_wtime(EndExpl),
+    TableSpace = 'N/A',                 % passed as-is to $pp_assert_learn_stats/12 below
+    ( MsgM == 0 -> true
+    ; format("Gathering and exporting switch information ...~n",[])
+    ),
+    $pc_mp_recv_switches,
+    $pp_mpm_export_switches,
+    $pc_mpm_alloc_occ_switches,
+    $pc_mp_send_swlayout,
+    $pp_collect_init_switches(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_mp_wtime(StartEM),              % StartEM/EndEM bracket the EM phase
+    $pp_mpm_em(Mode,Output),
+    $pc_mp_wtime(EndEM),
+    $pc_import_occ_switches(NewSws,NumSwitches,NumSwVals),
+    $pp_decode_update_switches(Mode,NewSws),
+    $pc_mpm_import_graph_stats(NumSubgraphs,NumGoalNodes,NumSwNodes,AvgShared),
+    $pc_mp_wtime(End),
+    $pp_assert_graph_stats(NumSubgraphs,NumGoalNodes,NumSwNodes,AvgShared),
+    $pp_assert_learn_stats(Mode,Output,NumSwitches,NumSwVals,TableSpace,
+                           Start,End,StartExpl,EndExpl,StartEM,EndEM,1),
+    ( MsgT == 0 -> true ; $pp_print_learn_stats_message ),   % a message flag equal to 0 silences that message
+    ( MsgM == 0 -> true ; $pp_print_learn_end_message(Mode) ),!.
+
+% Slave
+$pp_mps_learn_core(Mode) :-   % slave side; it is the command broadcast by $pp_mpm_learn_main/1,2
+    $pc_mp_sync(2,1),
+    $pp_learn_clean_info,
+    $pc_mps_share_prism_flags,
+    $pp_mps_find_explanations(GoalCountPairs),   % receives goals until $done arrives
+    global_set($pg_observed_facts,GoalCountPairs),
+    $pp_collect_init_switches(_Sws),
+    $pp_observed_facts(GoalCountPairs,GoalIdCountPairs,0,Len,0,NumOfGoals,-1,FailRootIndex),
+    $pc_prism_prepare(GoalIdCountPairs,Len,NumOfGoals,FailRootIndex),
+    $pc_mp_send_switches,               % presumably pairs with $pc_mp_recv_switches on the master -- confirm
+    $pc_mp_recv_swlayout,               % presumably pairs with $pc_mp_send_swlayout on the master -- confirm
+    $pp_mps_em(Mode),
+    $pc_mps_import_graph_stats,!.
+
+%%----------------------------------------
+
+$pp_mpm_em(params,Output) :-   % master-side EM; the shape of Output depends on the learning mode
+    $pc_mpm_prism_em(Iterate,LogPost,LogLike,BIC,CS,ModeSmooth),
+    Output = [Iterate,LogPost,LogLike,BIC,CS,ModeSmooth].
+$pp_mpm_em(hparams,Output) :-
+    $pc_mpm_prism_vbem(IterateVB,FreeEnergy),
+    Output = [IterateVB,FreeEnergy].
+$pp_mpm_em(both,Output) :-
+    $pc_mpm_prism_both_em(IterateVB,FreeEnergy),
+    Output = [IterateVB,FreeEnergy].
+
+$pp_mps_em(params) :-   % slave-side counterpart: same modes, no output
+    $pc_mps_prism_em.
+$pp_mps_em(hparams) :-
+    $pc_mps_prism_vbem.
+$pp_mps_em(both) :-
+    $pc_mps_prism_both_em.
+
+%%----------------------------------------
+
+$pp_mpm_find_explanations(GoalEqCountPairs,GoalCountPairs) :-   % send every goal out, then send the $done markers
+    $pp_learn_message(MsgS,_,_,_),
+    $pp_mpm_expl_goals(MsgS,GoalEqCountPairs,GoalCountPairs),
+    $pc_mp_size(N),
+    $pp_mpm_expl_complete(N).
+
+$pp_mpm_expl_goals(_,[],[]).   % maps each Goal=Count to goal(Goal,Count) while sending it
+$pp_mpm_expl_goals(MsgS,
+                   [Goal=Count|GoalEqCountPairs],
+                   [goal(Goal,Count)|GoalCountPairs]) :-
+    $pc_mp_send_goal(Goal=Count),
+    $pp_print_goal_message(MsgS),!,
+    $pp_mpm_expl_goals(MsgS,GoalEqCountPairs,GoalCountPairs).
+
+$pp_mpm_expl_complete(N) :-   % sends $done N-1 times -- presumably one per non-master process; confirm
+    N =< 1,!.
+$pp_mpm_expl_complete(N) :-
+    $pc_mp_send_goal($done),
+    N1 is N - 1,!,
+    $pp_mpm_expl_complete(N1).
+
+%%----------------------------------------
+
+$pp_mps_find_explanations(GoalCountPairs) :-   % collect the goals explained by this slave
+    $pp_mps_expl_goals([],GoalCountPairs).
+
+$pp_mps_expl_goals(GoalCountPairs0,GoalCountPairs) :-   % loop until $done is received; goals are accumulated newest-first
+    once($pc_mp_recv_goal(GoalEqCountPair)),
+    GoalEqCountPair \== $done,!,
+    GoalEqCountPair = (Goal=Count),
+    $pp_build_dummy_goal(Goal,DummyGoal),
+    ( $pp_expl_one_goal(DummyGoal) -> true
+    ; $pp_mps_err_msg("Failed to find solutions for ~w.",[Goal])   % helper defined in mp_main.pl (was mps_err_msg/2, not defined in this patch)
+    ),
+    GoalCountPairs1 = [goal(DummyGoal,Count)|GoalCountPairs0],
+    $pc_sleep(1),  % enable this for the stability in small-scale learning
+    !,
+    $pp_mps_expl_goals(GoalCountPairs1,GoalCountPairs).
+$pp_mps_expl_goals(GoalCountPairs,GoalCountPairs).   % reached when $done was received (first clause failed before its cut)
+
+%%----------------------------------------
+
+$pp_mpm_export_switches :-   % export every switch id 0..N-1, N being given by $pc_prism_sw_count/1
+    $pc_prism_sw_count(N),
+    $pp_mpm_export_switches(0,N).
+
+$pp_mpm_export_switches(Sid,N) :-
+    Sid >= N,!.
+$pp_mpm_export_switches(Sid,N) :-
+    $pc_prism_sw_term(Sid,Sw),
+    $pp_get_values(Sw,Values),
+    $pp_export_switch(Sid,Sw,Values),
+    Sid1 is Sid + 1,!,
+    $pp_mpm_export_switches(Sid1,N).
diff --git a/packages/prism/src/prolog/mp/mp_main.pl b/packages/prism/src/prolog/mp/mp_main.pl
new file mode 100644
index 000000000..ef1d57064
--- /dev/null
+++ b/packages/prism/src/prolog/mp/mp_main.pl
@@ -0,0 +1,112 @@
+:- $pp_require_mp_mode.
+:- $pc_mp_master -> print_copyright ; true.
+
+%%------------------------------------------------------------------------
+%% [[ Tags for $pc_mp_sync/2 ]]
+%%------------------------------------------------------------------------
+%% 01 : $pp_batch_call
+%% 02 : $pp_mp_call_s_core  (tag 2 is also used by $pp_mpm_learn_core/2 and $pp_mps_learn_core/1 in mp_learn.pl)
+%% 03 : $pp_compile_load
+%% 04 : $pp_foc
+%%------------------------------------------------------------------------
+
+%%----------------------------------------
+%% batch routines
+
+main :- $pp_batch.
+
+%$pp_batch_call(Goal) :-
+%    $pc_mp_master -> $pp_mpm_batch_call(Goal) ; $pp_mps_batch_call.
+
+$pp_batch_call(Goal) :-   % the master runs Goal; every other process runs the command loop
+    ( $pc_mp_master -> $pp_mpm_batch_call(Goal)
+    ; $pp_mps_batch_call
+    ).
+
+$pp_mpm_batch_call(Goal) :-   % run Goal, broadcast $stop, then print yes/no according to $pc_mp_sync(1,Sync)
+    ( call(Goal) -> Sync = 1 ; Sync = -1 ),   % -1 is sent when Goal failed
+    $pc_mpm_bcast_command($stop),!,
+    ( $pc_mp_sync(1,Sync) -> Res = yes ; Res = no ),
+    format("~n~w~n",[Res]).
+
+$pp_mps_batch_call :-   % always succeeds, whatever $pc_mp_sync/2 reports
+    ( $pp_slave_loop -> Sync = 1 ; Sync = -1 ),!,
+    ( $pc_mp_sync(1,Sync) ; true ).
+
+$pp_slave_loop :-   % execute broadcast commands until $stop; fails if a command fails
+    $pc_mps_bcast_command(Cmd),
+    ( Cmd \== $stop -> call(Cmd),!,$pp_slave_loop
+    ; true
+    ).
+
+%%----------------------------------------
+%% system predicates
+
+abort :- $pc_mp_abort.
+
+$pp_mps_err_msg(Msg) :-   % calls $pc_mps_revert_stdout before printing through err_msg/1,2
+    $pc_mps_revert_stdout, err_msg(Msg).
+$pp_mps_err_msg(Fmt,Args) :-
+    $pc_mps_revert_stdout, err_msg(Fmt,Args).
+
+$pp_load(File) :-   % load File on every process, one rank at a time (see $pp_mp_call_s_core/1)
+    $pp_mp_call_s_core(\+ \+ $myload(File)),   % double negation: no bindings from the load survive
+    $pp_init_tables.
+
+$pp_compile_load(File) :-   % the master compiles, then all processes sync on tag 3 and load the output
+    $pp_add_out_extension(File,OutFile),
+    ( $pc_mp_master -> $pp_compile(File,_DmpFile,OutFile) ; true ),!,
+    $pc_mp_sync(3,1),
+    $pp_load(OutFile).
+$pp_compile_load(_File) :-   % reached when the first clause failed before its cut: sync with -1 instead
+    $pc_mp_sync(3,-1).
+
+$pp_foc(File1,File2) :-   % same pattern as $pp_compile_load/1, using fo/2 and tag 4
+    ( $pc_mp_master ->
+        fo(File1,File2), format("Compilation done by FOC~n~n",[])
+    ; true
+    ),!,
+    $pc_mp_sync(4,1).
+$pp_foc(_,_) :-
+    $pc_mp_sync(4,-1).
+
+%%----------------------------------------
+%% user predicates
+
+mp_call(Goal) :-   % broadcast Goal as a command, then run it locally as well
+    $pc_mpm_bcast_command(Goal),call(Goal).
+mp_call_s(Goal) :-   % same, but Goal goes through the rank-by-rank wrapper below
+    $pc_mpm_bcast_command($pp_mp_call_s_core(Goal)),$pp_mp_call_s_core(Goal).
+
+$pp_mp_call_s_core(Goal) :-   % R is this process's rank, N the number of processes
+    $pc_mp_rank(R),
+    $pc_mp_size(N),
+    $pp_mp_call_s_core(Goal,R,N,0).
+
+$pp_mp_call_s_core(_,_,N,K) :-   % K counts the rounds 0..N-1
+    K >= N,!.
+$pp_mp_call_s_core(Goal,MyID,N,K) :-
+    ( K =:= MyID ->   % only the process whose rank equals K runs Goal in round K
+        ( call(Goal) -> Sync = K ; Sync = -1 )
+    ; % else
+        Sync = K
+    ),
+    $pc_mp_sync(2,Sync),   % every process syncs once per round; -1 reports that Goal failed
+    K1 is K + 1,!,
+    $pp_mp_call_s_core(Goal,MyID,N,K1).
+
+%%----------------------------------------
+%% debug predicates
+
+$pp_mp_debug(Format,Args) :-   % as the /3 version, writing to the current output stream
+    current_output(Stream),
+    $pp_mp_debug(Stream,Format,Args).
+
+$pp_mp_debug(Stream,Format,Args) :-   % format/3 with "[RANK:<rank>] " put in front of the message
+    $pc_mp_rank(R),
+    append("[RANK:~w] ",Format,NewFormat),   % Format must be a code list for append/3 to work
+    NewArgs = [R|Args],
+    format(Stream,NewFormat,NewArgs),!.
+
+%%----------------------------------------
+
diff --git a/packages/prism/src/prolog/prism.yap b/packages/prism/src/prolog/prism.yap
new file mode 100644
index 000000000..0bf337821
--- /dev/null
+++ b/packages/prism/src/prolog/prism.yap
@@ -0,0 +1,50 @@
+
+% interface to prism from YAP
+
+:- ensure_loaded(library(dialect/bprolog)).
+% :- set_prolog_flag(tabling_mode, local).
+:- load_foreign_files([prism], [], bp4p_register_preds). /* load prism stuff */
+:- style_check(-discontiguous). /* turn off the discontiguous-clauses style check */
+
+:- include('prism/core/message.pl').
+:- include('prism/core/error.pl').
+:- include('prism/core/random.pl').
+:- include('prism/core/format.pl').
+
+:- include('prism/up/dynamic.pl').
+:- include('prism/up/main.pl').
+:- include('prism/up/switch.pl').
+:- include('prism/up/learn.pl').
+:- include('prism/up/prob.pl').
+:- include('prism/up/viterbi.pl').
+:- include('prism/up/hindsight.pl').
+:- include('prism/up/expl.pl').
+:- include('prism/up/sample.pl').
+:- include('prism/up/dist.pl').
+:- include('prism/up/list.pl').
+:- include('prism/up/hash.pl').
+:- include('prism/up/flags.pl').
+:- include('prism/up/util.pl').
+:- include('prism/up/bigarray.pl').
+
+:- include('prism/trans/trans.pl').
+:- include('prism/trans/dump.pl').
+:- include('prism/trans/verify.pl').
+:- include('prism/trans/bpif.pl').
+
+%PL_BAT = up/batch.pl
+
+%PL_MP = mp/mp_main.pl \
+%    mp/mp_learn.pl
+
+:- include('prism/bp/eval.pl').
+
+:- initialization(init).
+
+init :-   % start-up hook: print the banner unless $pc_mp_mode holds, then seed the RNG and reset the flags
+    ( $pc_mp_mode -> true ; print_copyright ),
+    random_set_seed,
+    reset_prism_flags.
+
+
+
diff --git a/packages/prism/src/prolog/trans/bpif.pl b/packages/prism/src/prolog/trans/bpif.pl
new file mode 100644
index 000000000..e8c34b3e5
--- /dev/null
+++ b/packages/prism/src/prolog/trans/bpif.pl
@@ -0,0 +1,53 @@
+%% -*- Prolog -*-
+
+/*
+========================================================================
+
+This module provides a simple interface to the B-Prolog compiler.
+In the following description, <Prog> denotes a program represented in
+the B-Prolog internal form (i.e. a list of pred/6).
+
+$pp_bpif_read_program(-Prog,+File) :-
+    Loads <Prog> from <File>.
+
+$pp_bpif_compile_program(+Prog,+File) :-
+    Compiles <Prog> and saves the resultant byte-code into <File>.
+
+========================================================================
+*/
+
+%%--------------------------------
+%% Entry Point
+
+$pp_bpif_read_program(Prog,File) :-   % thin wrapper over getclauses1/3
+    getclauses1(File,Prog,0).
+
+$pp_bpif_compile_program(Prog0,File) :-   % preprocess, run phase_1_process/2, then write the result to File
+    $pp_preproc_program(Prog0,Prog1),
+    phase_1_process(Prog1,Prog2),
+    compileProgToFile(_,File,Prog2).
+
+
+%%--------------------------------
+%% Preprocessing
+
+$pp_preproc_program(Prog0,Prog1) :-   % run preprocess_cl/6 over every clause of every pred/6 entry
+    new_hashtable(AuxTable),
+    $pp_preproc_program(Prog0,Prog1,AuxTable,0).
+
+$pp_preproc_program(Prog0,Prog1,AuxTable,K),   % K is a counter threaded through preprocess_cl/6
+      Prog0 = [pred(F,N,M,D,T,Cls0)|Prog0R] =>
+    Prog1 = [pred(F,N,M,D,T,Cls1)|Prog1R],
+    $pp_preproc_clauses(Cls0,Cls1,AuxTable,K,NewK),
+    $pp_preproc_program(Prog0R,Prog1R,AuxTable,NewK).
+$pp_preproc_program(Prog0,Prog1,AuxTable,_),
+      Prog0 = [] =>
+    hashtable_values_to_list(AuxTable,Prog1).   % the values collected in AuxTable become the tail of the output program
+
+$pp_preproc_clauses(Cls0,Cls1,AuxTable,K,NewK), Cls0 = [Cl0|Cls0R] =>   % clause-level loop; returns the updated counter
+    Cls1 = [Cl1|Cls1R],
+    preprocess_cl(Cl0,Cl1,AuxTable,K,TmpK,1),
+    $pp_preproc_clauses(Cls0R,Cls1R,AuxTable,TmpK,NewK).
+$pp_preproc_clauses(Cls0,Cls1,_,K,NewK), Cls0 = [] =>
+    Cls1 = [],
+    K = NewK.
diff --git a/packages/prism/src/prolog/trans/dump.pl b/packages/prism/src/prolog/trans/dump.pl
new file mode 100644
index 000000000..4c82f083e
--- /dev/null
+++ b/packages/prism/src/prolog/trans/dump.pl
@@ -0,0 +1,150 @@
+%% -*- Prolog -*-
+
+%%======================================================================
+%%
+%% This module provides a pretty-printer for programs.  In the following
+%% predicates, <Prog> should be a valid program in the B-Prolog internal
+%% form; otherwise they would behave in an unexpected way.
+%%
+%% $pp_dump_program(Prog) :-
+%%     Writes <Prog> into the current output stream.
+%%
+%% $pp_dump_program(S,Prog) :-
+%%     Writes <Prog> into the stream <S>.
+%%
+%% $pp_save_program(Prog,File) :-
+%%     Writes <Prog> into <File>.
+%%
+%%======================================================================
+
+%%--------------------------------
+%% Entry Point
+
+$pp_dump_program(Prog) :-
+    current_output(S), $pp_dump_program(S,Prog).
+
+$pp_save_program(Prog,File) :-   % NOTE(review): S is not closed if the dump fails or throws
+    open(File,write,S), $pp_dump_program(S,Prog), close(S).
+
+$pp_dump_program(S,Prog) :-   % output order: start-up queries, mode/table declarations, then clauses
+    $pp_dump_split(Prog,Damon,Preds),
+    $pp_dump_damon(S,Damon),
+    $pp_dump_decls(S,Preds),
+    $pp_dump_preds(S,Preds).
+
+
+%%--------------------------------
+%% Separator
+
+$pp_dump_nl(S,L) :-   % L is a once-only flag: the first call writes a newline and binds L to 1
+    var(L), !,
+    nl(S),
+    L = 1.
+$pp_dump_nl(_,L) :-   % later calls with the same (now bound) L do nothing
+    nonvar(L), !.
+
+
+%%--------------------------------
+%% Split $damon_load/0
+
+$pp_dump_split(Prog,Damon,Preds) :-   % Damon is the body of the first $damon_load/0 clause; fails if Prog has no such predicate
+    Q = pred($damon_load,0,_,_,_,[($damon_load :- Damon)|_]),
+    select(Q,Prog,Preds), !.
+
+
+%%--------------------------------
+%% Start-up Queries
+
+$pp_dump_damon(S,Damon) :-
+    $pp_dump_damon(S,Damon,_).
+
+$pp_dump_damon(S,Damon,L) :-   % walk the conjunction; only $query/1 members are printed
+    Damon = (A,B), !,
+    $pp_dump_damon(S,A,L),
+    $pp_dump_damon(S,B,L).
+$pp_dump_damon(_,Damon,_) :-
+    Damon = true, !.
+$pp_dump_damon(S,Damon,L) :-   % any other goal makes the predicate fail: there is no catch-all clause
+    Damon = $query(Query), !,
+    $pp_dump_nl(S,L),
+    \+ \+ $pp_dump_query(S,Query).   % double negation undoes the bindings made by prettyvars/1
+
+$pp_dump_query(S,Query) :-
+    prettyvars(Query),
+    format(S,":- ~k.~n",[Query]).
+
+
+%%--------------------------------
+%% Declarations
+
+$pp_dump_decls(S,Preds) :-   % each kind of declaration gets its own fresh separator flag
+    $pp_dump_m_decls(S,Preds,_),
+    $pp_dump_t_decls(S,Preds,_).
+
+
+%%--------------------------------
+%% Mode Declarations
+
+$pp_dump_m_decls(_,Preds,_) :- Preds == [], !.
+$pp_dump_m_decls(S,Preds,L) :- Preds = [Pred|Preds1], !,
+    Pred = pred(F,N,M,_,_,_),
+    $pp_dump_m_decl(S,F,N,M,L),
+    $pp_dump_m_decls(S,Preds1,L).
+
+$pp_dump_m_decl(_,_,_,M,_) :- var(M), !.   % no mode information: print nothing
+$pp_dump_m_decl(S,F,N,M,L) :- M = [_|_], !,   % prints ":- mode F(m1,...,mN)."
+    $pp_dump_nl(S,L),
+    format(S,":- mode ~q(",[F]),
+    $pp_dump_m_spec(S,N,M),
+    format(S,").~n",[]).
+
+$pp_dump_m_spec(S,N,Mode) :- N == 1, !,   % last argument: no trailing comma
+    Mode = [M],
+    $pp_mode_symbol(M,Sym), !,   % M can be an unbound variable
+    write(S,Sym).
+$pp_dump_m_spec(S,N,Mode) :- N >= 2, !,
+    Mode = [M|Mode1],
+    $pp_mode_symbol(M,Sym), !,   % M can be an unbound variable
+    write(S,Sym),
+    write(S,','),
+    N1 is N - 1,
+    $pp_dump_m_spec(S,N1,Mode1).
+
+$pp_mode_symbol(d ,? ).   % an unbound mode matches this first clause and is printed as ?
+$pp_mode_symbol(? ,? ).
+$pp_mode_symbol(c ,+ ).
+$pp_mode_symbol(+ ,+ ).
+$pp_mode_symbol(f ,- ).
+$pp_mode_symbol(- ,- ).
+$pp_mode_symbol(nv,nv).
+
+
+%%--------------------------------
+%% Table Declarations
+
+$pp_dump_t_decls(_,Preds,_) :- Preds == [], !.
+$pp_dump_t_decls(S,Preds,L) :- Preds = [Pred|Preds1], !,
+    Pred = pred(F,N,_,_,T,_),
+    $pp_dump_t_decl(S,F,N,T,L),
+    $pp_dump_t_decls(S,Preds1,L).
+
+$pp_dump_t_decl(_,_,_,T,_) :- var(T), !.   % an unbound table field means "not tabled": print nothing
+$pp_dump_t_decl(S,F,N,T,L) :- nonvar(T), !,   % prints ":- table F/N."
+    $pp_dump_nl(S,L),
+    format(S,":- table ~q/~d.~n",[F,N]).
+
+
+%%--------------------------------
+%% Clauses
+
+$pp_dump_preds(_,Preds) :- Preds == [], !.
+$pp_dump_preds(S,Preds) :- Preds = [Pred|Preds1], !,
+    Pred = pred(_,_,_,_,_,Cls),
+    $pp_dump_clauses(S,Cls,_),   % fresh separator flag: one blank line before each predicate that has clauses
+    $pp_dump_preds(S,Preds1).
+
+$pp_dump_clauses(_,Cls,_) :- Cls == [], !.
+$pp_dump_clauses(S,Cls,L) :- Cls = [Cl|Cls1], !,
+    $pp_dump_nl(S,L),
+    portray_clause(S,Cl),
+    $pp_dump_clauses(S,Cls1,L).
diff --git a/packages/prism/src/prolog/trans/trans.pl b/packages/prism/src/prolog/trans/trans.pl
new file mode 100644
index 000000000..f3e28813e
--- /dev/null
+++ b/packages/prism/src/prolog/trans/trans.pl
@@ -0,0 +1,735 @@
+%% -*- Prolog -*-
+
+%%======================================================================
+%%
+%% [Notes on translation information]
+%%
+%% This translator uses a term containing the global information shared
+%% by the translation processes.  It takes the form:
+%%
+%%     $trans_info(DoTable,TPredTab,NoDebug,PPredTab)
+%%
+%% DoTable denotes whether probabilistic predicates should be tabled
+%% by default (i.e. unless declared in the source program); it takes
+%% 1 if the predicates should be tabled; 0 otherwise.  In case of an
+%% unbound variable, DoTable should be considered to be 1.
+%%
+%% TPredTab is a hashtable that contains tabled/non-tabled predicates
+%% which are compatible with the default (i.e. DoTable).  The key of
+%% each entry has the form P/N; the value is ignored.  In consultation
+%% mode where all probabilistic predicates are not tabled, TPredTab is
+%% just a free variable.
+%%
+%% NoDebug indicates whether "write_call" should be disabled; any non-
+%% variable disables the feature.
+%%
+%% PPredTab is a hashtable keyed by the predicates (P/N) found in the
+%% source program.  Phase #3 sets the value of each entry to 1 if the
+%% predicate is found to be probabilistic and to 0 otherwise.
+%%
+%%======================================================================
+
+%%----------------------------------------------------------------------
+%% Entry Point
+%%----------------------------------------------------------------------
+
+$pp_compile(PsmFile,DmpFile,OutFile) :-   % translate PsmFile and compile the result into OutFile
+    $pp_bpif_read_program(Prog0,PsmFile),
+    new_hashtable(TPredTab),
+    new_hashtable(PPredTab),
+    Info = $trans_info(_DoTable,TPredTab,_NoDebug,PPredTab),   % shared by all phases, see the notes above
+    $pp_trans_phase1(Prog0,Prog1,Info),
+    $pp_trans_phase2(Prog1,Prog2,Info),
+    $pp_trans_phase3(Prog2,Prog3,Info),
+    $pp_trans_phase4(Prog3,Prog4,Info),
+    $pp_trans_phase5(Prog4,Prog5,Info),
+    Prog = Prog5,
+    % $pp_dump_program(Prog),   % for debugging
+    ( $pp_valid_program(Prog) -> true   % if-then-else: a later failure must not backtrack into the error branch
+    ; $pp_raise_internal_error($msg(9802),invalid_compilation,$pp_compile/3)
+    ),
+    ( var(DmpFile) -> true ; $pp_save_program(Prog,DmpFile) ),   % the dump is written only when DmpFile is given
+    $pp_bpif_compile_program(Prog,OutFile),!.
+
+
+%%----------------------------------------------------------------------
+%% Phase #1: Scan the queries.
+%%----------------------------------------------------------------------
+
+$pp_trans_phase1(Prog0,Prog,Info) :-   % the program is unchanged; only Info is filled in
+    $pp_extract_decls(Prog0,Info),
+    Prog = Prog0.
+
+$pp_extract_decls([],_) => true.   % scan the first clause body of every $damon_load/0 entry
+$pp_extract_decls([Pred|Preds],Info),
+      Pred = pred($damon_load,0,_,_,_,[($damon_load:-Demon0)|_]) =>
+    $pp_extract_decls_from_demons(Demon0,Info),!,
+    $pp_extract_decls(Preds,Info).
+$pp_extract_decls([_Pred|Preds],Info) =>
+    $pp_extract_decls(Preds,Info).
+
+$pp_extract_decls_from_demons((D1,D2),Info) =>   % walk the conjunction of start-up queries
+    $pp_extract_decls_from_demons(D1,Info),!,
+    $pp_extract_decls_from_demons(D2,Info).
+$pp_extract_decls_from_demons($query((p_table Preds)),Info) =>   % p_table: record Preds and fix DoTable to 1
+    Info = $trans_info(DoTable,TPredTab,_,_),
+    ( var(TPredTab) -> true  % consult mode
+    ; DoTable == 1 ->
+        $pp_add_preds_to_hashtable(Preds,TPredTab)
+    ; var(DoTable) ->
+        $pp_add_preds_to_hashtable(Preds,TPredTab),
+        DoTable = 1
+    ; DoTable == 0 ->   % a p_not_table declaration was seen earlier
+        $pp_raise_trans_error($msg(1101),mixed_table_declarations,$pp_trans_phase1/3)
+    ; $pp_raise_unmatched_branches($pp_extract_decls_from_demons/2,
+                                   p_table)   % label names the declaration, like the sibling clauses (was: query)
+    ).
+$pp_extract_decls_from_demons($query((p_not_table Preds)),Info) =>   % p_not_table: record Preds and fix DoTable to 0
+    Info = $trans_info(DoTable,TPredTab,_,_),
+    ( var(TPredTab) -> true  % consult mode
+    ; DoTable == 0 ->
+        $pp_add_preds_to_hashtable(Preds,TPredTab)
+    ; var(DoTable) ->
+        $pp_add_preds_to_hashtable(Preds,TPredTab),
+        DoTable = 0
+    ; DoTable == 1 ->   % a p_table declaration was seen earlier
+        $pp_raise_trans_error($msg(1101),mixed_table_declarations,$pp_trans_phase1/3)
+    ; $pp_raise_unmatched_branches($pp_extract_decls_from_demons/2,
+                                   p_not_table)
+    ).
+$pp_extract_decls_from_demons($query(disable_write_call),Info) =>   % binds NoDebug to 1
+    Info = $trans_info(_,_,NoDebug,_),
+    ( NoDebug == 1 -> true
+    ; var(NoDebug) -> NoDebug = 1
+    ; $pp_raise_unmatched_branches($pp_extract_decls_from_demons/2,
+                                   disable_write_call)
+    ).
+$pp_extract_decls_from_demons(_,_Info) => true.   % every other query is ignored in this phase
+
+$pp_add_preds_to_hashtable((Pred,Preds),TPredTab) :- !,   % Preds is a comma-separated sequence of F/N indicators
+    $pp_add_one_pred_to_hashtable(Pred,TPredTab),!,
+    $pp_add_preds_to_hashtable(Preds,TPredTab).
+$pp_add_preds_to_hashtable(Pred,TPredTab) :-
+    $pp_add_one_pred_to_hashtable(Pred,TPredTab),!.
+
+$pp_add_one_pred_to_hashtable(Pred,TPredTab) :-   % register F/N once; the stored value (1) is only a placeholder
+    $pp_require_predicate_indicator(Pred,$msg(1102),$pp_trans_phase1/3),
+    Pred = F/N,
+    ( hashtable_get(TPredTab,F/N,_) -> true
+    ; hashtable_register(TPredTab,F/N,1)
+    ).
+
+%%----------------------------------------------------------------------
+%% Phase #2: Process values/2-3.
+%%----------------------------------------------------------------------
+
+% We do not refer to the information objects here.
+$pp_trans_phase2(Prog0,Prog,_Info) :-   % values/2-3 clauses become $pu_values/2; values/2 calls become get_values/2
+    $pp_trans_values(Prog0,Prog1),
+    $pp_replace_values(Prog1,Prog).
+
+% translate the "values" declarations
+$pp_trans_values(Preds0,Preds) :-   % output starts with a rebuilt $damon_load/0, then $pu_values/2, then the remaining predicates
+    $pp_trans_values(Preds0,Preds1,ValCls,Demon,DemonAux),
+    Preds2 = [pred($pu_values,2,_Mode,_Delay,_Tabled,ValCls)|Preds1],
+    DemonCl1 = ($damon_load:-Demon,DemonAux),   % original start-up queries, then those generated from values/3
+    DemonCl2 = ($damon_load:-true),
+    Preds = [pred($damon_load,0,_,_,_,[DemonCl1,DemonCl2])|Preds2].
+
+$pp_trans_values([],[],[],true,true).   % Demon and DemonAux default to true when nothing has bound them
+$pp_trans_values([pred(F,2,_,_,_,Cls0)|Preds0],   % values/2 and values_x/2 are removed from the program
+                 Preds,ValCls,Demon,DemonAux) :-
+    (F = values ; F = values_x),!,
+    $pp_trans_values_clauses(Cls0,Cls1),
+    append(Cls1,ValCls1,ValCls),!,
+    $pp_trans_values(Preds0,Preds,ValCls1,Demon,DemonAux).
+$pp_trans_values([pred(F,3,_,_,_,Cls0)|Preds0],   % values/3 and values_x/3 also yield start-up queries (DemonAux)
+                 Preds,ValCls,Demon,DemonAux) :-
+    (F = values ; F = values_x),!,
+    $pp_trans_values_demon_clauses(Cls0,Cls1,DemonAux),
+    append(Cls1,ValCls1,ValCls),!,
+    $pp_trans_values(Preds0,Preds,ValCls1,Demon,_).   % NOTE(review): `_` here appears to drop the demons of a second arity-3 predicate -- confirm
+$pp_trans_values([pred($damon_load,0,_,_,_,[($damon_load:-Demon)|_])|Preds0],   % the original $damon_load/0 is removed; its first body becomes Demon
+                 Preds,ValCls,Demon,DemonAux) :- !,
+    $pp_trans_values(Preds0,Preds,ValCls,_,DemonAux).
+$pp_trans_values([P|Preds0],[P|Preds],ValCls,Demon,DemonAux) :- !,   % every other predicate is kept as is
+    $pp_trans_values(Preds0,Preds,ValCls,Demon,DemonAux).
+
+$pp_trans_values_clauses([],[]).
+$pp_trans_values_clauses([Cl0|Cls0],[Cl|Cls]) :-
+    $pp_trans_values_one_clause(Cl0,Cl),!,
+    $pp_trans_values_clauses(Cls0,Cls).
+
+$pp_trans_values_one_clause(Cl0,Cl) :-   % values(Sw,Vals0) [:- Body]  ==>  $pu_values(Sw,Vals) :- Body,Expand
+    ( Cl0 = (values(Sw,Vals0):-Body)   -> true
+    ; Cl0 = (values_x(Sw,Vals0):-Body) -> true
+    ; Cl0 = values(Sw,Vals0)           -> Body = true
+    ; Cl0 = values_x(Sw,Vals0)         -> Body = true
+    ),
+    $pp_build_expand_values(Vals0,Vals,Expand),
+    Cl = ($pu_values(Sw,Vals):-Body,Expand).
+
+$pp_trans_values_demon_clauses([],[],true).   % arity-3 variant: also returns the conjunction of generated queries
+$pp_trans_values_demon_clauses([Cl0|Cls0],[Cl|Cls],Demon) :-   % values(Sw,Vals0,Demons) [:- Body]
+    ( Cl0 = (values(Sw,Vals0,Demons):-Body)   -> true
+    ; Cl0 = (values_x(Sw,Vals0,Demons):-Body) -> true
+    ; Cl0 = values(Sw,Vals0,Demons)           -> Body = true
+    ; Cl0 = values_x(Sw,Vals0,Demons)         -> Body = true
+    ),
+    $pp_build_expand_values(Vals0,Vals,Expand),
+    Cl = ($pu_values(Sw,Vals):-Body,Expand),
+    ( ground(Sw),ground(Demons)   % queries are generated only for fully ground declarations
+        -> $pp_trans_values_demons(Sw,Demons,Demon1), Demon = (Demon1,Demon2)
+    ; $pp_raise_warning($msg(1104),[Sw,Demons]), Demon = Demon2   % otherwise warn and skip the demons
+    ),!,
+    $pp_trans_values_demon_clauses(Cls0,Cls,Demon2).
+
+$pp_trans_values_demons(_Sw,true,true) :- !.   % map each demon specification to a $query(...) on switch Sw
+$pp_trans_values_demons(Sw,(Demon0,Demons),(Demon2,Demon1)) :- !,
+    $pp_trans_values_demons(Sw,Demon0,Demon2),!,
+    $pp_trans_values_demons(Sw,Demons,Demon1).
+$pp_trans_values_demons(Sw,Demon0,Demon) :-   % the h forms map to the same queries as the d forms
+    ( Demon0 = set@Params    -> Demon = $query(set_sw(Sw,Params))
+    ; Demon0 = fix@Params    -> Demon = $query(fix_sw(Sw,Params))
+    ; Demon0 = a@HParams     -> Demon = $query(set_sw_a(Sw,HParams))
+    ; Demon0 = d@HParams     -> Demon = $query(set_sw_d(Sw,HParams))
+    ; Demon0 = h@HParams     -> Demon = $query(set_sw_d(Sw,HParams))
+    ; Demon0 = set_a@HParams -> Demon = $query(set_sw_a(Sw,HParams))
+    ; Demon0 = set_d@HParams -> Demon = $query(set_sw_d(Sw,HParams))
+    ; Demon0 = set_h@HParams -> Demon = $query(set_sw_d(Sw,HParams))
+    ; Demon0 = fix_a@HParams -> Demon = $query(fix_sw_a(Sw,HParams))
+    ; Demon0 = fix_d@HParams -> Demon = $query(fix_sw_d(Sw,HParams))
+    ; Demon0 = fix_h@HParams -> Demon = $query(fix_sw_d(Sw,HParams))
+    ; Demon0 = Params        -> Demon = $query(set_sw(Sw,Params))   % anything else is treated as parameters for set_sw/2
+    ).
+
+$pp_build_expand_values(Vals0,Vals,Expand) :-   % Expand is the goal appended to the generated $pu_values/2 clause
+    ( $pp_unexpandable_values(Vals0) -> Expand = true, Vals = Vals0   % nothing to expand: use the list literally
+    ; Expand = expand_values1(Vals0,Vals) % use the no-exception version
+    ).
+
+% Checks if Vals only contains ground values that cannot be expanded by
+% expand_values{,1}/2:
+$pp_unexpandable_values(Vals) :-
+    is_list(Vals),
+    ground(Vals),
+    $pp_unexpandable_values1(Vals).
+
+$pp_unexpandable_values1([]).   % succeeds only if no element is a range specification
+$pp_unexpandable_values1([V|Vals]) :-
+    V \= _Start-_End@_Step, V \= _Start-_End, !,   % V must match NEITHER range form (was a disjunction, which let ranges through)
+    $pp_unexpandable_values1(Vals).
+
+
+% replace all appearances of values/2 in the clause bodies with get_values/2
+$pp_replace_values([],[]).
+$pp_replace_values([Pred0|Preds0],[Pred|Preds]) :-   % only the clause list of each pred/6 entry changes
+    Pred0 = pred(F,N,Mode,Delay,Tabled,Cls0),
+    Pred  = pred(F,N,Mode,Delay,Tabled,Cls),
+    $pp_replace_values_clauses(Cls0,Cls),!,
+    $pp_replace_values(Preds0,Preds).
+
+$pp_replace_values_clauses([],[]).
+$pp_replace_values_clauses([Cl0|Cls0],[Cl|Cls]) :-
+    $pp_replace_values_one_clause(Cl0,Cl),!,
+    $pp_replace_values_clauses(Cls0,Cls).
+
+$pp_replace_values_one_clause(Cl0,Cl) :-   % facts have no body and are copied unchanged
+    ( Cl0 = (Head:-Body0) ->
+        $pp_replace_values_body(Body0,Body), Cl = (Head:-Body)
+    ; Cl = Cl0
+    ).
+
+$pp_replace_values_body((G1,G2),(NG1,NG2)) :- !,   % descend through control constructs and write_call wrappers
+    $pp_replace_values_body(G1,NG1),
+    $pp_replace_values_body(G2,NG2).
+$pp_replace_values_body((G1;G2),(NG1;NG2)) :- !,
+    $pp_replace_values_body(G1,NG1),
+    $pp_replace_values_body(G2,NG2).
+$pp_replace_values_body(not(G),not(NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((\+ G),(\+ NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((C->G),(NC->NG)) :- !,
+    $pp_replace_values_body(C,NC),
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body(write_call(G),write_call(NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body(write_call(Opts,G),write_call(Opts,NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((?? G),(?? NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((??* G),(??* NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((??> G),(??> NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((??< G),(??< NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((??+ G),(??+ NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body((??- G),(??- NG)) :- !,
+    $pp_replace_values_body(G,NG).
+$pp_replace_values_body(values(Sw,Vals),get_values(Sw,Vals)) :- !.   % the actual rewrite
+$pp_replace_values_body(G,G).   % any other goal is kept as is
+
+
+%%----------------------------------------------------------------------
+%% Phase #3: Find probabilistic predicates.
+%%----------------------------------------------------------------------
+
+$pp_trans_phase3(Prog0,Prog,Info) :-   % the program is unchanged; results go to PPredTab and the $pd_* facts
+    $pp_analyze(Prog0,Info),
+    Prog = Prog0.
+
+$pp_analyze(Prog,Info) :-   % collect all predicates, iterate to a fixpoint, default the rest to 0, then assert
+    Info = $trans_info(_,_,_,PPredTab),
+    $pp_collect_preds(Prog,PPredTab),
+    $pp_infer_prob_preds_fixpoint(Prog,Info),
+    $pp_complete_prob_preds(Info),
+    $pp_assert_prob_preds(Prog,Info).
+
+% collect the predicates appearing in the program
+$pp_collect_preds([],_).   % every F/N is registered with an unbound value
+$pp_collect_preds([pred($damon_load,0,_,_,_,_)|Preds],PPredTab) :- !,
+    hashtable_register(PPredTab,$damon_load/0,_),!,
+    $pp_collect_preds(Preds,PPredTab).
+$pp_collect_preds([pred(values,2,_,_,_,_)|Preds],PPredTab) :- !,
+    hashtable_register(PPredTab,values/2,_),!,
+    $pp_collect_preds(Preds,PPredTab).
+$pp_collect_preds([pred(F,N,_Mode,_Delay,_Tabled,_Cls)|Preds],PPredTab) :-
+    hashtable_register(PPredTab,F/N,_),!,
+    $pp_collect_preds(Preds,PPredTab).
+
+$pp_infer_prob_preds_fixpoint(Prog,Info) :-   % repeat the inference pass until a pass marks no new predicate
+    Info = $trans_info(_,_,_,PPredTab),
+    global_set($pg_prob_tab_updated,0,0),   % cleared before each pass; set to 1 by $pp_infer_prob_preds/2
+    $pp_infer_prob_preds(Prog,PPredTab),
+    % if some probabilistic predicates have been newly found, try again:
+    ( global_get($pg_prob_tab_updated,0,1)
+        -> $pp_infer_prob_preds_fixpoint(Prog,Info)
+    ; true
+    ).
+
+$pp_infer_prob_preds([],_PPredTab) => true.
+$pp_infer_prob_preds([pred(values,2,_,_,_,_)|Preds],PPredTab) =>   % values/2 is never analysed
+    $pp_infer_prob_preds(Preds,PPredTab).
+$pp_infer_prob_preds([pred(F,N,_Mode,_Delay,_Tab,Cls)|Preds],PPredTab) =>
+    hashtable_get(PPredTab,F/N,IsProb),   % IsProb is the table's own value: binding it updates the entry
+    ( var(IsProb) -> $pp_infer_prob_cls(Cls,IsProb,PPredTab),   % already-marked predicates are skipped
+        ( nonvar(IsProb) -> global_set($pg_prob_tab_updated,0,1)   % newly marked: ask for another pass
+        ; true
+        )
+    ; true
+    ),!,
+    $pp_infer_prob_preds(Preds,PPredTab).
+
+$pp_infer_prob_cls([],_IsProb,_PPredTab) => true.   % stop scanning clauses as soon as IsProb gets bound
+$pp_infer_prob_cls([Cl|Cls],IsProb,PPredTab) =>
+    $pp_infer_prob_cl(Cl,IsProb,PPredTab),
+    ( var(IsProb) -> $pp_infer_prob_cls(Cls,IsProb,PPredTab)
+    ; true
+    ).
+
+$pp_infer_prob_cl((_H:-B),IsProb,PPredTab) =>
+    $pp_infer_prob_body(B,IsProb,PPredTab).
+$pp_infer_prob_cl(_H,_IsProb,_PPredTab) => true.   % a fact cannot make its predicate probabilistic
+
+$pp_infer_prob_body((G1,G2),IsProb,PPredTab) =>   % binds IsProb to 1 if the body calls msw or a predicate already marked
+    $pp_infer_prob_body(G1,IsProb,PPredTab),
+    ( var(IsProb) -> $pp_infer_prob_body(G2,IsProb,PPredTab)
+    ; true
+    ).
+$pp_infer_prob_body((C->G1;G2),IsProb,PPredTab) =>
+    $pp_infer_prob_body(C,IsProb,PPredTab),
+    ( var(IsProb) ->
+        $pp_infer_prob_body(G1,IsProb,PPredTab),
+        ( var(IsProb) -> $pp_infer_prob_body(G2,IsProb,PPredTab)
+        ; true
+        )
+    ; true
+    ).
+$pp_infer_prob_body((G1;G2),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab),
+    ( var(IsProb) -> $pp_infer_prob_body(G2,IsProb,PPredTab)
+    ; true
+    ).
+$pp_infer_prob_body(not(G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((\+ G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((C->G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(C,IsProb,PPredTab),
+    ( var(IsProb) -> $pp_infer_prob_body(G1,IsProb,PPredTab)
+    ; true
+    ).
+$pp_infer_prob_body(write_call(G1),IsProb,PPredTab) =>   % debugging wrappers are transparent
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body(write_call(_,G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((?? G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((??* G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((??> G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((??< G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((??+ G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body((??- G1),IsProb,PPredTab) =>
+    $pp_infer_prob_body(G1,IsProb,PPredTab).
+$pp_infer_prob_body(msw(_,_,_),IsProb,_PPredTab) => IsProb = 1.   % a direct msw call marks the predicate
+$pp_infer_prob_body(msw(_,_),IsProb,_PPredTab) => IsProb = 1.
+$pp_infer_prob_body(G,IsProb,PPredTab) :-   % a call to a program predicate: probabilistic iff the callee is already marked
+    functor(G,F,N),
+    hashtable_get(PPredTab,F/N,IsProb1),!,
+    ( nonvar(IsProb1) -> IsProb = 1
+    ; true
+    ).
+$pp_infer_prob_body(_G,_IsProb,_PPredTab). /* G: undefined predicates */
+
+$pp_complete_prob_preds(Info) :-   % after the fixpoint, every entry still unbound is set to 0
+    Info = $trans_info(_,_,_,PPredTab),
+    hashtable_keys_to_list(PPredTab,Preds),
+    $pp_complete_prob_preds(Preds,PPredTab).
+
+$pp_complete_prob_preds([],_).
+$pp_complete_prob_preds([F/N|Preds],PPredTab) :-
+    hashtable_get(PPredTab,F/N,IsProb),!,
+    ( var(IsProb) -> IsProb = 0
+    ; true
+    ),!,
+    $pp_complete_prob_preds(Preds,PPredTab).
+
+$pp_assert_prob_preds([],_).   % publish the result as $pd_is_prob_pred/2 and $pd_is_tabled_pred/2 facts
+$pp_assert_prob_preds([pred(F,N,_,_,_,_)|Preds],Info) :-
+    Info = $trans_info(DoTable,TPredTab,_,PPredTab),
+    hashtable_get(PPredTab,F/N,IsProb),!,
+    ( IsProb = 1 ->
+        $pp_abolish_compiled_pred(F,N),
+        ( $pd_is_prob_pred(F,N) -> true   % each fact is asserted at most once
+        ; assert($pd_is_prob_pred(F,N))
+        ),
+        ( $pp_is_tabled_prob_pred(F/N,DoTable,TPredTab)
+            -> ( $pd_is_tabled_pred(F,N) -> true
+               ; assert($pd_is_tabled_pred(F,N))
+               )
+        ; true
+        )
+    ; true
+    ),!,
+    $pp_assert_prob_preds(Preds,Info).
+
+$pp_abolish_compiled_pred(F,N) :-   % global_del/2 on the translated name given by $pp_trans_prob_pred_name/2
+    $pp_trans_prob_pred_name(F,NewF),
+    global_del(NewF,N),!.
+
+
+%%----------------------------------------------------------------------
+%% Phase #4: Translate the probabilistic predicates.
+%%----------------------------------------------------------------------
+
+% [Note] Mode indicators in B-Prolog:
+%     c (or +) : closed term
+%     f (or -) : free variable
+%     nv       : non-variable term
+%     d (or ?) : dont-know term
+
+$pp_trans_phase4(Prog0,Prog,Info) :-
+    $pp_trans_prob_preds(Prog0,Prog,Info).
+
+$pp_trans_prob_preds([],Prog,_Info) => Prog = [].   % each probabilistic predicate is kept and followed by a translated copy
+$pp_trans_prob_preds([Pred|Preds],Prog,Info),
+      Pred = pred(F,N,Mode,Delay,Tabled,Cls) =>
+    Info = $trans_info(_,_,NoDebug,_),
+    ( $pd_is_prob_pred(F,N) ->
+        Prog = [pred(F,N,Mode,Delay,_,Cls1),NewPred|Prog1],   % the original entry loses its table field
+        ( $pd_is_tabled_pred(F,N) ->
+            NewTabled = tabled(_,_,_,_),
+            ( nonvar(Mode) -> NewMode = [f|Mode] ; true),   % one extra leading argument (the goal id)
+            NewArity is N + 1
+        ; % \+ $is_tabled_pred(F,N)
+            ( nonvar(Mode) -> NewMode = [d,d,d,d|Mode]   % four extra leading arguments (Gids,GidsR,Sids,SidsR)
+            ; true
+            ),
+            NewArity is N + 4
+        ),
+        NewPred = pred(NewF,NewArity,NewMode,_,NewTabled,NewCls),
+        $pp_trans_prob_pred_name(F,NewF),
+        copy_term(Cls,ClsCp),  % Pred and NewPred do not share variables
+        $pp_trans_prob_cls(ClsCp,NewCls,NewF,NewTabled,Info)
+    ; % \+ $pd_is_prob_pred(F,N)
+        Prog = [pred(F,N,Mode,Delay,Tabled,Cls1)|Prog1]
+    ),
+    ( var(NoDebug) -> Cls1 = Cls
+    ; $pp_strip_write_call_cls(Cls,Cls1) % just strip the write_call predicates
+    ),!,
+    $pp_trans_prob_preds(Preds,Prog1,Info).
+
+$pp_trans_prob_cls([],Cls,_F,_Tabled,_Info) => Cls = [].   % translate the clauses; F is the translated predicate name
+$pp_trans_prob_cls([(Head0:-Body0)|Cls0],Cls,F,Tabled,Info) =>
+    Cls = [(Head:-Body)|Cls1],
+    Head0 =.. [_|Args],
+    ((nonvar(Tabled),Tabled = tabled(_,_,_,_)) ->
+        Head =.. [F,Gid0|Args],
+        $pp_trans_prob_body(Body0,Body1,Gids,[],Sids,[],Info),   % Gids/Sids are closed lists here (tails are [])
+        ( Gids == [], Sids == [] -> RegistPath = true   % nothing collected: no path registration
+        ; RegistPath =
+            catch($prism_eg_path(Gid0,Gids,Sids),
+                  Exception,
+                  ($pp_emit_message($msg(9805),[Head0]),throw(Exception)))
+          % FIXME: this translation may lead to some overhead
+        ),
+        Body = (Body1,
+                $pc_prism_goal_id_register(Head0,Gid0),
+                RegistPath)
+    ; % Non-tabled
+        Head =.. [F,Gids,GidsR,Sids,SidsR|Args],   % the head carries the lists and their open tails
+        $pp_trans_prob_body(Body0,Body1,Gids,GidsR,Sids,SidsR,Info),
+        Body = Body1
+    ),!,
+    $pp_trans_prob_cls(Cls0,Cls1,F,Tabled,Info).
+$pp_trans_prob_cls([Head|Cls0],Cls,F,Tabled,Info) =>   % a fact is handled as Head :- true
+    $pp_trans_prob_cls([(Head:-true)|Cls0],Cls,F,Tabled,Info).
+
+$pp_trans_prob_body((G1,G2),NewGoal,Gids,GidsR,Sids,SidsR,Info) =>   % the lists are threaded left to right through the conjunction
+    NewGoal = (NG1,NG2),
+    $pp_trans_prob_body(G1,NG1,Gids,Gids1,Sids,Sids1,Info),
+    $pp_trans_prob_body(G2,NG2,Gids1,GidsR,Sids1,SidsR,Info).
+$pp_trans_prob_body((C->A;B),NewGoal,Gids,GidsR,Sids,SidsR,Info) =>   % each branch works on its own copies, unified with the outer lists when taken
+    NewGoal = (InitVars,
+               (NC->
+                 (NA,Gids=GidsCp1,Sids=SidsCp1,GidsR=GidsRCp1,SidsR=SidsRCp1)
+               ;(NB,Gids=GidsCp2,Sids=SidsCp2,GidsR=GidsRCp2,SidsR=SidsRCp2))),
+    $pp_trans_prob_body(C,NC,GidsCp1,GidsCp3,SidsCp1,SidsCp3,Info),
+    $pp_trans_prob_body(A,NA,GidsCp3,GidsRCp1,SidsCp3,SidsRCp1,Info),
+    $pp_trans_prob_body(B,NB,GidsCp2,GidsRCp2,SidsCp2,SidsRCp2,Info),
+    vars_set((NA;NB),Vars),
+    $pp_gen_initialize_var([Vars,Gids,Sids,GidsR,SidsR,
+                            GidsCp1,SidsCp1,GidsRCp1,SidsRCp1,
+                            GidsCp2,SidsCp2,GidsRCp2,SidsRCp2,
+                            GidsCp3,SidsCp3],InitVars).
+$pp_trans_prob_body((A;B),NewGoal,Gids,GidsR,Sids,SidsR,Info) =>
+    NewGoal = (InitVars,
+               ((NA,Gids=GidsCp1,Sids=SidsCp1,GidsR=GidsRCp1,SidsR=SidsRCp1)
+               ;(NB,Gids=GidsCp2,Sids=SidsCp2,GidsR=GidsRCp2,SidsR=SidsRCp2))),
+    $pp_trans_prob_body(A,NA,GidsCp1,GidsRCp1,SidsCp1,SidsRCp1,Info),
+    $pp_trans_prob_body(B,NB,GidsCp2,GidsRCp2,SidsCp2,SidsRCp2,Info),
+    vars_set((NA;NB),Vars),
+    $pp_gen_initialize_var([Vars,Gids,Sids,GidsR,SidsR,
+                            GidsCp1,SidsCp1,GidsRCp1,SidsRCp1,
+                            GidsCp2,SidsCp2,GidsRCp2,SidsRCp2],InitVars).
+$pp_trans_prob_body(not(G),NewGoal,Gids,GidsR,Sids,SidsR,Info) =>   % a negated goal contributes nothing: the lists pass through
+    NewGoal = not(NG),
+    Gids = GidsR,
+    Sids = SidsR,
+    $pp_trans_prob_body(G,NG,Gids,_,Sids,_,Info).
+$pp_trans_prob_body(\+(G),NewGoal,Gids,GidsR,Sids,SidsR,Info) =>
+    NewGoal = \+(NG),
+    Gids = GidsR,
+    Sids = SidsR,
+    $pp_trans_prob_body(G,NG,Gids,_,Sids,_,Info).
+$pp_trans_prob_body((C->A),NewGoal,Gids,GidsR,Sids,SidsR,Info) => + NewGoal = (NC->NA), + $pp_trans_prob_body(C,NC,Gids,Gids1,Sids,Sids1,Info), + $pp_trans_prob_body(A,NA,Gids1,GidsR,Sids1,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = write_call(Goal1) => + $pp_trans_prob_body(write_call([],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = write_call(Opts,Goal1) => + Info = $trans_info(_,_,NoDebug,_), + ( $pp_is_write_callable(Goal1) -> true + ; $pp_raise_trans_error($msg(1103),not_write_callable,$pp_trans_phase4/3) + ), + ( var(NoDebug) -> $pp_write_call_build(Opts,Goal1,NewGoal1,NewGoal) + ; NewGoal1 = NewGoal + ),!, + $pp_trans_prob_body(Goal1,NewGoal1,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (?? Goal1) => + $pp_trans_prob_body(write_call([],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (??* Goal1) => + $pp_trans_prob_body(write_call([all],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (??> Goal1) => + $pp_trans_prob_body(write_call([call],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (??< Goal1) => + $pp_trans_prob_body(write_call([exit+fail],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (??+ Goal1) => + $pp_trans_prob_body(write_call([exit],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info), + Goal = (??- Goal1) => + $pp_trans_prob_body(write_call([fail],Goal1), + NewGoal,Gids,GidsR,Sids,SidsR,Info). 
+$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,_Info), + Goal = msw(I,V) => + Gids = GidsR, + Sids = [Sid|SidsR], + NewGoal = $prism_expl_msw(I,V,Sid). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,Info) :- + Info = $trans_info(DoTable,TPredTab,_,_), + functor(Goal,F,N), + $pd_is_prob_pred(F,N),!, + Goal =.. [_|Args], + $pp_trans_prob_pred_name(F,NewF), + ( $pp_is_tabled_prob_pred(F/N,DoTable,TPredTab) -> + NewGoal =.. [NewF,Gid|Args], + Gids = [Gid|GidsR], + Sids = SidsR + ; NewGoal =.. [NewF,Gids,GidsR,Sids,SidsR|Args] + ). +$pp_trans_prob_body(Goal,NewGoal,Gids,GidsR,Sids,SidsR,_Info) :- + Sids = SidsR, + Gids = GidsR, + Goal = NewGoal. + +$pp_strip_write_call_cls([],Cls)=> Cls = []. +$pp_strip_write_call_cls([(Head:-Body0)|Cls0],Cls) => + Cls = [(Head:-Body)|Cls1], + $pp_strip_write_call_body(Body0,Body),!, + $pp_strip_write_call_cls(Cls0,Cls1). +$pp_strip_write_call_cls([Head|Cls0],Cls) => + Cls = [Head|Cls1],!, + $pp_strip_write_call_cls(Cls0,Cls1). + +$pp_strip_write_call_body((A0,B0),Goal) => + Goal = (A1,B1), + $pp_strip_write_call_body(A0,A1), + $pp_strip_write_call_body(B0,B1). +$pp_strip_write_call_body((A0->B0;C0),Goal) => + Goal = (A1->B1;C1), + $pp_strip_write_call_body(A0,A1), + $pp_strip_write_call_body(B0,B1), + $pp_strip_write_call_body(C0,C1). +$pp_strip_write_call_body((A0;B0),Goal) => + Goal = (A1;B1), + $pp_strip_write_call_body(A0,A1), + $pp_strip_write_call_body(B0,B1). +$pp_strip_write_call_body(not(A0),Goal) => + Goal = not(A1), + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body(\+(A0),Goal) => + Goal = \+(A1), + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((A0->B0),Goal) => + Goal = (A1->B1), + $pp_strip_write_call_body(A0,A1), + $pp_strip_write_call_body(B0,B1). +$pp_strip_write_call_body(write_call(A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body(write_call(_,A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((?? 
A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((??* A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((??> A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((??< A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((??+ A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body((??- A0),Goal) => Goal = A1, + $pp_strip_write_call_body(A0,A1). +$pp_strip_write_call_body(Goal0,Goal) => Goal = Goal0. + +$pp_gen_initialize_var(VarsL,InitVars):- + flatten(VarsL,Vars0), + sort(Vars0,Vars), + $pp_gen_initialize_var_aux(Vars,InitVarsL), + list_to_and(InitVarsL,InitVars). + +$pp_gen_initialize_var_aux([],[]). +$pp_gen_initialize_var_aux([Var|Vars],InitVars):- + ( var(Var) -> InitVars = ['_$initialize_var'(Var)|InitVars1] + ; InitVars = InitVars1 + ),!, + $pp_gen_initialize_var_aux(Vars,InitVars1). + +%%---------------------------------------------------------------------- +%% Phase #5: Add assert calls to the first demon call. +%%---------------------------------------------------------------------- + +$pp_trans_phase5(Prog0,Prog,Info) :- + $pp_add_assert_calls(Prog0,Prog,Info). + +$pp_add_assert_calls([],[],_). +$pp_add_assert_calls([Pred|Preds],[Pred1|Preds1],Info) :- + Pred = pred($damon_load,0,_,_,_,[($damon_load:-Demon)|DemonCls]), + $pp_build_assert_calls(Info,AssertCalls), + Demon1 = ($query(retractall($pd_is_prob_pred(_,_))), + $query(retractall($pd_is_tabled_pred(_,_))), + $query(call(AssertCalls)), + Demon), + Pred1 = pred($damon_load,0,_,_,_,[($damon_load:-Demon1)|DemonCls]),!, + $pp_add_assert_calls(Preds,Preds1,Info). +$pp_add_assert_calls([Pred|Preds],[Pred|Preds1],Info) :- !, + $pp_add_assert_calls(Preds,Preds1,Info). 
+ +$pp_build_assert_calls(Info,AssertCalls) :- + Info = $trans_info(_,_,_,PPredTab), + hashtable_to_list(PPredTab,Pairs), + $pp_build_assert_calls1(Pairs,Info,AssertGs), + list_to_and(AssertGs,AssertCalls). + +$pp_build_assert_calls1([],_,[]). +$pp_build_assert_calls1([Pair|Pairs],Info,AssertGs) :- + Info = $trans_info(DoTable,TPredTab,_,_), + ( Pair = (F/N=V) -> + ( V == 1 -> + AssertGs = [assert($pd_is_prob_pred(F,N))|AssertGs2], + ( $pp_is_tabled_prob_pred(F/N,DoTable,TPredTab) -> + AssertGs2 = [assert($pd_is_tabled_pred(F,N))|AssertGs1] + ; AssertGs2 = AssertGs1 + ) + ; V == 0 -> AssertGs = AssertGs1 + ; $pp_raise_unmatched_branches($pp_build_assert_calls1/3,value) + ) + ; $pp_raise_unmatched_branches($pp_build_assert_calls1/3,pair) + ),!, + $pp_build_assert_calls1(Pairs,Info,AssertGs1). + + +%%---------------------------------------- +%% Auxiliary predicates for translation + +'_$initialize_var'(_). +'_$if_then_else'(C,A,B) :- (C->A;B). + +%%---------------------------------------- +%% Miscellaneous routines + +$pp_trans_prob_pred_name(F,NewF) :- + name(F,FString), + append("$pu_expl_",FString,NewFString), + name(NewF,NewFString). + + +$pp_is_tabled_prob_pred(F/N,DoTable,TPredTab) :- + ( var(TPredTab) -> fail % consult mode + ; true + ),!, + ( DoTable == 1 -> hashtable_get(TPredTab,F/N,_) + ; DoTable == 0 -> + ( hashtable_get(TPredTab,F/N,_) -> fail + ; true + ) + ; var(DoTable) -> true + ),!. + + +$pp_add_conj_to_list((A,B),List) => + $pp_add_conj_to_list(A,List),!, + $pp_add_conj_to_list(B,List). +$pp_add_conj_to_list(A,List) => + $member1(A,List). 
diff --git a/packages/prism/src/prolog/trans/verify.pl b/packages/prism/src/prolog/trans/verify.pl new file mode 100644 index 000000000..e06f0ed70 --- /dev/null +++ b/packages/prism/src/prolog/trans/verify.pl @@ -0,0 +1,130 @@ +%% -*- Prolog -*- + +%%====================================================================== +%% +%% This module provides a quick validator for programs represented in the +%% B-Prolog internal form. +%% +%% $pp_valid_program(Prog) :- +%% Succeeds if and only if Prog is a valid program. +%% +%%====================================================================== + +%%-------------------------------- +%% Entry Point (Done records the F/N of every element already checked) + +$pp_valid_program(Prog) :- + new_hashtable(Done), + $pp_valid_program_aux(Prog,Done). + +$pp_valid_program_aux(Prog,_), Prog == [] => + true. +$pp_valid_program_aux(Prog,Done), Prog = [Pred|Prog1] => + ( $pp_valid_prog_elem(Pred,Done) -> + true + ; $pp_emit_message($msg(1100),[Pred]), fail + ), + arg(1,Pred,F), + arg(2,Pred,N), + hashtable_register(Done,F/N,1), + $pp_valid_program_aux(Prog1,Done). + + +%%-------------------------------- +%% Predicate: a pred(F,N,Mode,Delay,Table,Clauses) element, not ':-'/2, not seen before + +$pp_illegal_pred(':-',2). + +$pp_valid_prog_elem(Pred,Done) :- + Pred = pred(F,N,_,_,_,_), + atom(F), integer(N), N >= 0, + \+ ( $pp_illegal_pred(F,N) ; hashtable_get(Done,F/N,_) ), + $pp_valid_prog_pred(Pred). + +$pp_valid_prog_pred(Pred), + Pred = pred(F,N,M,D,T,Cls), + F == $damon_load, + N == 0 => + var(M), + var(D), + var(T), + Cls = [Cl0,Cl1], + Cl0 = ($damon_load :- Body), + Cl1 = ($damon_load :- true), + $pp_valid_damon(Body). +$pp_valid_prog_pred(Pred), + Pred = pred(F,N,M,D,T,Cls) => + $pp_valid_mspec(N,M), + $pp_valid_delay(D), + $pp_valid_table(T), + $pp_valid_clauses(F,N,D,Cls). + + +%%-------------------------------- +%% $damon_load/0: the body is a conjunction of true, $query(_) or callable goals + +$pp_valid_damon(G) :- G = (A,B), !, + $pp_valid_damon(A), + $pp_valid_damon(B). +$pp_valid_damon(G) :- G == true, !, + true. +$pp_valid_damon(G) :- G = $query(_), !, + true. +$pp_valid_damon(G) :- callable(G), !, + true.
+ +%%-------------------------------- +%% Mode Spec + +$pp_valid_mspec(_,M), var(M) => true. +$pp_valid_mspec(N,M), nonvar(M) => + $pp_valid_mspec_loop(N,M). + +$pp_valid_mspec_loop(N,ModeL), N == 0 => ModeL == []. +$pp_valid_mspec_loop(N,ModeL), N >= 1 => + ModeL = [Mode|ModeL1], + $pp_valid_mode(Mode), + N1 is N - 1, + $pp_valid_mspec_loop(N1,ModeL1). + +$pp_valid_mode(M), M == c => true. +$pp_valid_mode(M), M == f => true. +$pp_valid_mode(M), M == nv => true. +$pp_valid_mode(M), M == d => true. + +%%-------------------------------- +%% Delay + +$pp_valid_delay(D), var(D) => true. +$pp_valid_delay(D), D == 1 => true. + + +%%-------------------------------- +%% Table + +$pp_valid_table(T), var(T) => true. +$pp_valid_table(T), + T = tabled(U1,U2,U3,U4), + var(U1), + var(U2), + var(U3), + var(U4) => true. + + +%%-------------------------------- +%% Clauses + +$pp_valid_clauses(_,_,_,Cls), Cls == [] => true. +$pp_valid_clauses(F,N,D,Cls), Cls = [Cl|Cls1] => + $pp_valid_clause(F,N,D,Cl), + $pp_valid_clauses(F,N,D,Cls1). + +$pp_valid_clause(F,N,_,Cl), Cl = (H :- _) => + nonvar(H), + functor(H,F,N). +$pp_valid_clause(F,N,D,Cl), Cl = delay(Cl1) => + D == 1, + $pp_valid_clause(F,N,_,Cl1). +$pp_valid_clause(F,N,_,Cl) => + nonvar(Cl), + functor(Cl,F,N). diff --git a/packages/prism/src/prolog/up/batch.pl b/packages/prism/src/prolog/up/batch.pl new file mode 100644 index 000000000..9cff8ac9d --- /dev/null +++ b/packages/prism/src/prolog/up/batch.pl @@ -0,0 +1,5 @@ +main :- $pp_batch. + +$pp_batch_call(Goal) :- + ( call(Goal) -> Res = yes ; Res = no ), + format("~n~w~n",[Res]). diff --git a/packages/prism/src/prolog/up/bigarray.pl b/packages/prism/src/prolog/up/bigarray.pl new file mode 100644 index 000000000..443f34f22 --- /dev/null +++ b/packages/prism/src/prolog/up/bigarray.pl @@ -0,0 +1,154 @@ +%%%% +%%%% bigarray.pl -- A large one-dimensional array for B-Prolog +%%%% + +%%---------------------------------------- + +$pp_bigarray_unit(65535). 
% max_arity + +%%---------------------------------------- + +new_bigarray(Array,N), var(Array), integer(N), N > 0 => + $pp_bigarray_unit(M), + Array = $bigarray(N,Body), + $pp_new_bigarray(Body,N,M). + +new_bigarray(Array,N) => + $pp_new_bigarray_throw(Array,N). + +$pp_new_bigarray_throw(Array,N) :- + ( var(Array) -> true + ; throw(error(type_error(variable,Array),new_bigarray/2)) + ), + ( nonvar(N) -> true + ; throw(error(instantiation_error,new_bigarray/2)) + ), + ( integer(N) -> true + ; throw(error(type_error(integer,N),new_bigarray/2)) + ), + ( N > 0 -> true + ; throw(error(domain_error(greater_than_zero,N),new_bigarray/2)) + ), !, + fail. % should not reach here + +$pp_new_bigarray(Body,N,M), N =< M => + functor(Body,array,N). + +$pp_new_bigarray(Body,N,M), N > M => + L is (N - 1) // M + 1, + functor(Body,outer,L), + $pp_new_bigarray(Body,1,N,M). + +$pp_new_bigarray(Body,K,N,M), N =< M => + arg(K,Body,SubBody), + functor(SubBody,array,N). + +$pp_new_bigarray(Body,K,N,M), N > M => + arg(K,Body,SubBody), + functor(SubBody,array,M), + K1 is K + 1, + N1 is N - M, !, + $pp_new_bigarray(Body,K1,N1,M). + +%%---------------------------------------- + +is_bigarray(Array), Array = $bigarray(_,_) => true. + +bigarray_length(Array,L), Array = $bigarray(N,_) => L = N. +bigarray_length(Array,_) => + $pp_bigarray_length_throw(Array). + +$pp_bigarray_length_throw(Array) :- + ( nonvar(Array) -> true + ; throw(error(instantiation_error,bigarray_length/2)) + ), + ( Array ?= $bigarray(_,_) -> true + ; throw(error(domain_error(bigarray,Array),bigarray_length/2)) + ), !, + fail. % should not reach here + +%%---------------------------------------- + +bigarray_get(Array,I,Value), + Array = $bigarray(N,Body), + integer(I), + I >= 1, + I =< N => + $pp_bigarray_get(Body,I,Value). + +bigarray_get(Array,I,_Value) => + $pp_bigarray_access_throw(Array,I,bigarray_get/3). 
+ +bigarray_put(Array,I,Value), + Array = $bigarray(N,Body), + integer(I), + I >= 1, + I =< N => + $pp_bigarray_put(Body,I,Value). + +bigarray_put(Array,I,_Value) => + $pp_bigarray_access_throw(Array,I,bigarray_put/3). + +$pp_bigarray_access_throw(Array,I,Source) :- + ( nonvar(Array) -> true + ; throw(error(instantiation_error,Source)) + ), + ( Array = $bigarray(N,_) -> true + ; throw(error(domain_error(bigarray,Array),Source)) + ), + ( nonvar(I) -> true + ; throw(error(instantiation_error,Source)) + ), + ( integer(I) -> true + ; throw(error(type_error(integer,I),Source)) + ), + ( I >= 1, I =< N -> true + ; throw(error(domain_error(bigarray_index,I),Source)) + ), !, + fail. % should not reach here + +$pp_bigarray_get(Body,I,Elem), functor(Body,array,_) => + arg(I,Body,Elem). +$pp_bigarray_get(Body,I,Elem), functor(Body,outer,_) => + $pp_bigarray_unit(M), + Q is (I - 1) // M + 1, + R is (I - 1) mod M + 1, + arg(Q,Body,SubBody), + arg(R,SubBody,Elem). + +$pp_bigarray_put(Body,I,Elem), functor(Body,array,_) => + setarg(I,Body,Elem). +$pp_bigarray_put(Body,I,Elem), functor(Body,outer,_) => + $pp_bigarray_unit(M), + Q is (I - 1) // M + 1, + R is (I - 1) mod M + 1, + arg(Q,Body,SubBody), + setarg(R,SubBody,Elem). + +%%---------------------------------------- + +list_to_bigarray(List,Array) :- + $pp_bigarray_unit(M), + length(List,N), + Array = $bigarray(N,Body), + $pp_new_bigarray(Body,N,M), + $pp_list_to_bigarray(List,1,Body). + +$pp_list_to_bigarray(Xs,_,_), Xs = [] => true. +$pp_list_to_bigarray(Xs,K,Body), Xs = [X|Xs1] => + $pp_bigarray_put(Body,K,X), + K1 is K + 1, !, + $pp_list_to_bigarray(Xs1,K1,Body). + +bigarray_to_list(Array,List), Array = $bigarray(N,Body) => + $pp_bigarray_to_list(Body,1,N,List). + +$pp_bigarray_to_list(_,K,N,Xs), K > N => + Xs = []. +$pp_bigarray_to_list(Body,K,N,Xs), K =< N => + $pp_bigarray_get(Body,K,X), + Xs = [X|Xs1], + K1 is K + 1, !, + $pp_bigarray_to_list(Body,K1,N,Xs1). 
+ +%%---------------------------------------- diff --git a/packages/prism/src/prolog/up/dist.pl b/packages/prism/src/prolog/up/dist.pl new file mode 100644 index 000000000..d93c8de45 --- /dev/null +++ b/packages/prism/src/prolog/up/dist.pl @@ -0,0 +1,193 @@ +%%---------------------------------------- + +expand_probs(Dist,Probs) :- + $pp_expand_probs(Dist,Probs,expand_probs/2). + +expand_probs(Dist,N,Probs) :- + $pp_expand_probs(Dist,N,Probs,expand_probs/3). + +$pp_expand_probs(Dist,Probs,Source) :- + $pp_require_fixed_size_distribution(Dist,$msg(0200),Source), + $pp_spec_to_ratio(Dist,_,Ratio,Source), + $pp_normalize_ratio(Ratio,Probs). + +$pp_expand_probs(Dist,N,Probs,Source) :- + $pp_require_distribution(Dist,$msg(0200),Source), + $pp_require_positive_integer(N,$msg(0204),Source), + $pp_spec_to_ratio(Dist,N,Ratio,Source), + $pp_check_expanded_prob_size(Ratio,N,Source), + $pp_normalize_ratio(Ratio,Probs). + +$pp_normalize_ratio(Ratio,Probs) :- + sumlist(Ratio,Denom), + $pp_ratio_to_probs(Ratio,Denom,Probs). + +$pp_ratio_to_probs([],_,[]) :- !. +$pp_ratio_to_probs([X|Xs],Denom,[Y|Ys]) :- + Y is X / Denom,!, + $pp_ratio_to_probs(Xs,Denom,Ys). + +$pp_check_expanded_prob_size(List,N,Source) :- + length(List,N1), + ( N = N1 -> true + ; $pp_raise_runtime_error($msg(0211),[List,N],unmatched_distribution, + Source) + ),!. + +%%---------------------------------------- + +$pp_spec_to_ratio(Dist,N,Ratio,Source) :- + ( Dist = default, + get_prism_flag(default_sw,none) + -> $pp_raise_runtime_error($msg(0202), + default_distribution_unavailable, + Source) + ; true + ), + $pp_spec_to_ratio1(Dist,N,Ratio,Source). + +$pp_spec_to_ratio1(Dist,_N,Ps,_Source), Dist = [_|_] => Ps = Dist. + +$pp_spec_to_ratio1(Dist,_N,Ps,_Source), Dist = (_+_) => + $pp_expr_to_list('+',Dist,Ps). + +$pp_spec_to_ratio1(Dist,_N,Ratio,_Source), Dist = (_:_) => + $pp_expr_to_list(':',Dist,Ratio). + +$pp_spec_to_ratio1(uniform,N,Ratio,_Source) => + $pp_gen_geom_list(N,1,1,Ratio). 
+ +$pp_spec_to_ratio1(f_geometric,N,Ratio,_Source) => + $pp_spec_to_ratio_fgeom(2,desc,N,Ratio). + +$pp_spec_to_ratio1(f_geometric(Base),N,Ratio,_Source) => + $pp_spec_to_ratio_fgeom(Base,desc,N,Ratio). + +$pp_spec_to_ratio1(f_geometric(Base,Type),N,Ratio,_Source) => + $pp_spec_to_ratio_fgeom(Base,Type,N,Ratio). + +$pp_spec_to_ratio1(random,N,Ratio,_Source) => + $pp_gen_rand_list(N,Ratio). + +$pp_spec_to_ratio1(default,N,Ratio,Source) => + get_prism_flag(default_sw,Flag), + $pp_require_distribution(Flag,$msg(0200),Source),!, + $pp_spec_to_ratio1(Flag,N,Ratio,Source). + +%%---------------------------------------- + +expand_pseudo_counts(Spec,Cs) :- + $pp_require_fixed_size_hyperparameters(Spec,$msg(0201), + expand_pseudo_counts/2), + $pp_expand_pseudo_counts(Spec,_,Cs,expand_pseudo_counts/2). + +expand_pseudo_counts(Spec,N,Cs) :- + Source = expand_pseudo_counts/3, + $pp_require_hyperparameters(Spec,$msg(0201),Source), + $pp_require_positive_integer(N,$msg(0204),Source), + $pp_expand_pseudo_counts(Spec,N,Cs,Source), + $pp_check_expanded_pseudo_count_size(Cs,N,Source). + +$pp_expand_pseudo_counts(Spec,N,Cs,Source) :- + ( Spec = default, + $pp_get_default_pseudo_counts(none) + -> $pp_raise_runtime_error($msg(0202), + default_hyperparameters_unavailable, + Source) + ; true + ), + $pp_spec_to_pseudo_counts(Spec,N,Cs,Source). + +$pp_spec_to_pseudo_counts(Spec,_N,Cs,_Source), Spec = [_|_] => Cs = Spec. + +$pp_spec_to_pseudo_counts(Spec,N,Cs,_Source), number(Spec) => + C = Spec, + $pp_gen_dup_list(N,C,Cs). + +$pp_spec_to_pseudo_counts(uniform,N,Cs,Source) => + $pp_spec_to_pseudo_counts(uniform(1.0),N,Cs,Source). + +$pp_spec_to_pseudo_counts(uniform(U),N,Cs,_Source) => + C is U / N, + $pp_gen_dup_list(N,C,Cs). + +$pp_spec_to_pseudo_counts(f_geometric,N,Cs,Source) => + $pp_spec_to_pseudo_counts(f_geometric(1.0,2.0,desc),N,Cs,Source). + +$pp_spec_to_pseudo_counts(f_geometric(Base),N,Cs,Source) => + $pp_spec_to_pseudo_counts(f_geometric(1.0,Base,desc),N,Cs,Source). 
+ +$pp_spec_to_pseudo_counts(f_geometric(Init,Base),N,Cs,Source) => + $pp_spec_to_pseudo_counts(f_geometric(Init,Base,desc),N,Cs,Source). + +$pp_spec_to_pseudo_counts(f_geometric(Init,Base,Type),N,Cs,_Source) => + $pp_spec_to_ratio_fgeom(Init,Base,Type,N,Cs). + +$pp_spec_to_pseudo_counts(default,N,Cs,Source) => + $pp_get_default_pseudo_counts(Spec), % get hyperparameters anyway + $pp_require_hyperparameters(Spec,$msg(0201),Source),!, + $pp_spec_to_pseudo_counts(Spec,N,Cs,Source). + +$pp_get_default_pseudo_counts(Spec) :- + ( get_prism_flag(default_sw_a,$disabled) -> + get_prism_flag(default_sw_d,Spec) + ; get_prism_flag(default_sw_a,Spec) + ). + +$pp_check_expanded_pseudo_count_size(List,N,Source) :- + length(List,N1), + ( N = N1 -> true + ; $pp_raise_runtime_error($msg(0211),[List,N],unmatched_pseudo_counts, + Source) + ),!. + +%%---------------------------------------- + +$pp_spec_to_ratio_fgeom(Base,Type,N,Ratio) :- + $pp_spec_to_ratio_fgeom(1.0,Base,Type,N,Ratio). + +$pp_spec_to_ratio_fgeom(Init,Base,Type,N,Ratio) :- + $pp_gen_geom_list(N,Init,Base,Ratio0), + ( Type == asc -> Ratio0 = Ratio ; reverse(Ratio0,Ratio) ). + +%%---------------------------------------- + +$pp_expr_to_list(Op,Expr,List) :- + current_op(_,yfx,Op),!, + $pp_expr_to_list_yfx(Op,Expr,List,[]). +$pp_expr_to_list(Op,Expr,List) :- + current_op(_,xfy,Op),!, + $pp_expr_to_list_xfy(Op,Expr,List,[]). + +$pp_expr_to_list_yfx(Op,Expr,L0,L1), functor(Expr,Op,2) => + Expr =.. [Op,Expr1,X], + L2 = [X|L1], !, + $pp_expr_to_list_yfx(Op,Expr1,L0,L2). +$pp_expr_to_list_yfx(_ ,Expr,L0,L1) => + L0 = [Expr|L1]. + +$pp_expr_to_list_xfy(Op,Expr,L0,L1), functor(Expr,Op,2) => + Expr =.. [Op,X,Expr1], + L0 = [X|L2], !, + $pp_expr_to_list_xfy(Op,Expr1,L2,L1). +$pp_expr_to_list_xfy(_ ,Expr,L0,L1) => + L0 = [Expr|L1]. + +%%---------------------------------------- + +$pp_gen_geom_list(0,_,_,[]) :- !. +$pp_gen_geom_list(N,X,Base,[X|Xs1]) :- + X1 is X * Base, + N1 is N - 1,!, + $pp_gen_geom_list(N1,X1,Base,Xs1). 
+ +$pp_gen_rand_list(0,[]) :- !. +$pp_gen_rand_list(N,[X|Xs1]) :- + random_uniform(X), + N1 is N - 1,!, + $pp_gen_rand_list(N1,Xs1). + +$pp_gen_dup_list(0,_,[]) :- !. +$pp_gen_dup_list(N,C,[C|Cs]) :- + N1 is N - 1,!, + $pp_gen_dup_list(N1,C,Cs). diff --git a/packages/prism/src/prolog/up/dynamic.pl b/packages/prism/src/prolog/up/dynamic.pl new file mode 100644 index 000000000..78e56c8f9 --- /dev/null +++ b/packages/prism/src/prolog/up/dynamic.pl @@ -0,0 +1,41 @@ +% predicate_info +:- dynamic $pd_is_prob_pred/2. +:- dynamic $pd_is_tabled_pred/2. + +% switch_info +:- dynamic $pd_parameters/3. +:- dynamic $pd_hyperparameters/4. +:- dynamic $pd_expectations/3. +:- dynamic $pd_hyperexpectations/3. +:- dynamic $pd_fixed_parameters/1. +:- dynamic $pd_fixed_hyperparameters/1. + +% dummy_goal_table +:- dynamic $pd_dummy_goal_table/2. + +% learn_stats +:- dynamic $ps_log_likelihood/1. +:- dynamic $ps_log_post/1. +:- dynamic $ps_num_switches/1. +:- dynamic $ps_num_switch_values/1. +:- dynamic $ps_num_iterations/1. +:- dynamic $ps_num_iterations_vb/1. +:- dynamic $ps_bic_score/1. +:- dynamic $ps_cs_score/1. +:- dynamic $ps_free_energy/1. +:- dynamic $ps_learn_time/1. +:- dynamic $ps_learn_search_time/1. +:- dynamic $ps_em_time/1. +:- dynamic $ps_learn_table_space/1. + +% graph_stats +:- dynamic $ps_num_subgraphs/1. +:- dynamic $ps_num_nodes/1. +:- dynamic $ps_num_goal_nodes/1. +:- dynamic $ps_num_switch_nodes/1. +:- dynamic $ps_avg_shared/1. + +% infer_stats +:- dynamic $ps_infer_time/1. +:- dynamic $ps_infer_search_time/1. +:- dynamic $ps_infer_calc_time/1. diff --git a/packages/prism/src/prolog/up/expl.pl b/packages/prism/src/prolog/up/expl.pl new file mode 100644 index 000000000..7c054c1dc --- /dev/null +++ b/packages/prism/src/prolog/up/expl.pl @@ -0,0 +1,410 @@ +%% +%% expl.pl: routines for explanation search +%% +%% $pp_find_explanations(Goals) constructs the explanation graphs for Goals. 
+%% An explanation graph is a directed hypergraph where each hyperarc takes +%% the form of: +%% +%% $prism_eg_path(GoalId,Children,SWs) +%% +%% where +%% GoalId: +%% the id of the source node (all variant subgoals have the same ID) +%% Children: +%% the list of nodes that are connected by the hyperarc with GoalId +%% SWs: +%% the list of switches associated with the arc. +%% +%% Consider the following PRISM program: +%% +%% values(init,[s0,s1]). +%% values(out(_),[a,b]). +%% values(tr(_),[s0,s1]). +%% +%% hmm(L) :- +%% msw(init,Si), +%% hmm(1,Si,L). +%% +%% hmm(T,S,[]) :- T>3. +%% hmm(T,S,[C|L]) :- +%% T=<3, +%% msw(out(S),C), +%% msw(tr(S),NextS), +%% T1 is T + 1, +%% hmm(T1,NextS,L). +%% +%% +%% The relations for the goal hmm([a,b,a]) are as follows (where goals +%% rather than their ids are shown for descriptive purposes): +%% +%% goal_id(hmm([a,b,a]),0), +%% goal_id(hmm(1,s0,[a,b,a]),1) +%% goal_id(hmm(2,s0,[b,a]),4)] +%% goal_id(hmm(2,s1,[b,a]),11) +%% goal_id(hmm(3,s0,[a]),7) +%% goal_id(hmm(3,s1,[a]),9) +%% goal_id(hmm(3,s2,[a]),14) +%% goal_id(observe(1,s0,a),2) +%% goal_id(observe(2,s0,b),5) +%% goal_id(observe(2,s1,b),12) +%% goal_id(observe(3,s0,a),8) +%% goal_id(observe(3,s1,a),10) +%% goal_id(observe(3,s2,a),15) +%% goal_id(trans(1,s0,_5b0400),3) +%% goal_id(trans(2,s0,_5b0480),6) +%% goal_id(trans(2,s1,_5b04f0),13) +%% +%% $prism_eg_path(3,[],[msw(trans(s0),1,s0)]), +%% $prism_eg_path(6,[],[msw(trans(s0),2,s0)]), +%% $prism_eg_path(12,[],[msw(obs(s1),2,b)]), +%% $prism_eg_path(3,[],[msw(trans(s0),1,s1)]), +%% $prism_eg_path(6,[],[msw(trans(s0),2,s1)]), +%% $prism_eg_path(13,[],[msw(trans(s1),2,s1)]), +%% $prism_eg_path(0,[1],[]), +%% $prism_eg_path(7,[8],[]), +%% $prism_eg_path(1,[4,3,2],[]), +%% $prism_eg_path(13,[],[msw(trans(s1),2,s2)]), +%% $prism_eg_path(4,[7,6,5],[]), +%% $prism_eg_path(2,[],[msw(obs(s0),1,a)]), +%% $prism_eg_path(8,[],[msw(obs(s0),3,a)]), +%% $prism_eg_path(5,[],[msw(obs(s0),2,b)])] +%% +%% One of the explanations for
hmm([a,b,a]) is: +%% +%% [msw(init,once,s0),msw(out(s0),1,a),msw(tr(s0),1,s0),msw(out(s0),2,b),...] +%% + +$pp_find_explanations(Goals) :- + $pp_expl_goals_all(Goals). + +$pp_expl_failure :- + $pp_trans_one_goal(failure,CompGoal),!, + call(CompGoal). +$pp_expl_failure :- + savecp(CP), + Depth = 0, + $pp_expl_interp_goal(failure,Depth,CP,[],_,[],_,[],_,[],_). + +$pp_expl_goals_all(Goals) :- + $pp_expl_goals(Goals). + +$pp_expl_goals([]) => true. +$pp_expl_goals([Goal|Goals]) => + $pp_learn_message(MsgS,_,_,_), + $pp_print_goal_message(MsgS), + ( $pp_expl_one_goal(Goal) -> true + ; $pp_raise_runtime_error($msg(1304),[Goal],explanation_not_found, + $pp_find_explanations/1) + ),!, + $pp_expl_goals(Goals). +$pp_expl_goals(Goal) => + $pp_expl_one_goal(Goal). + +$pp_expl_one_goal(msw(Sw,V)) :- !, + $prism_expl_msw(Sw,V,_Id). +$pp_expl_one_goal(failure) :- !, + $pp_expl_failure. +$pp_expl_one_goal(Goal) :- + $pp_is_dummy_goal(Goal),!, + call(Goal). +$pp_expl_one_goal(Goal) :- + % FIXME: handling non-tabled probabilistic predicate is future work + $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),$pp_expl_one_goal/1), + ( ground(Goal) -> GoalCp = Goal + ; copy_term(Goal,GoalCp) + ), + ( $pp_trans_one_goal(GoalCp,CompGoal) -> +( % vsc: make this give all solutions!! + call(CompGoal) , fail ; true) +% old code was just: call(CompGoal) + ; savecp(CP), + Depth = 0, + $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_) + ). + +% [Note] this predicate fails if Goal is not probabilistic +$pp_trans_one_goal(Goal,CompGoal) :- + functor(Goal,F,N), + name(F,FString), + append("$pu_expl_",FString,NewFString), + name(NewF,NewFString), + N1 is N + 1, + current_predicate(NewF/N1),!, + Goal =.. [_|Args], + CompGoal =.. [NewF,_|Args]. 
+ +%%---------------------------------------------------------------------------- + +$pp_expl_interp_goal('!',_Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + cutto(CP), + CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0. +$pp_expl_interp_goal('_$savecp'(X),_Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + savecp(X), + CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0. +$pp_expl_interp_goal('_$savepcp'(X),_Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savepcp'(X), + CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0. +$pp_expl_interp_goal('_$cutto'(X),_Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + cutto(X), + CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0. +$pp_expl_interp_goal('_$initialize_var'(_Vars),_Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0. +$pp_expl_interp_goal(Goal,Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs), Goal = msw(I,V) => + CIDs = CIDs0, + SWs = [SwId|SWs0], + SimCIDs = SimCIDs0, + SimSWs = [Goal|SimSWs0], + c_SAVE_AR(AR), + c_next_global_call_number(CallNo), + $eval_and_monitor_call($prism_expl_msw(I,V,SwId),Depth,CallNo,AR). +$pp_expl_interp_goal((G1,G2),Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + $pp_expl_interp_goal(G1,Depth,CP, + CIDs0,CIDs1,SWs0,SWs1, + SimCIDs0,SimCIDs1,SimSWs0,SimSWs1), + $pp_expl_interp_goal(G2,Depth,CP, + CIDs1,CIDs,SWs1,SWs, + SimCIDs1,SimCIDs,SimSWs1,SimSWs). 
+$pp_expl_interp_goal((C->A;B),Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savecp'(NewCP), + ( eval_debug_call(C,Depth,NewCP) -> + $pp_expl_interp_goal(A,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ; $pp_expl_interp_goal(B,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ). +$pp_expl_interp_goal((C->A),Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savecp'(NewCP), + ( eval_debug_call(C,Depth,NewCP) -> + $pp_expl_interp_goal(A,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ). +$pp_expl_interp_goal((A;B),Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + ( $pp_expl_interp_goal(A,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ; $pp_expl_interp_goal(B,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ). +$pp_expl_interp_goal(not(A),Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savecp'(NewCP), + ( $pp_expl_interp_goal(A,Depth,NewCP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) -> fail + ; CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0 + ). +$pp_expl_interp_goal((\+ A),Depth,_CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savecp'(NewCP), + ( $pp_expl_interp_goal(A,Depth,NewCP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) -> fail + ; CIDs = CIDs0, + SWs = SWs0, + SimCIDs = SimCIDs0, + SimSWs = SimSWs0 + ). +$pp_expl_interp_goal('_$if_then_else'(C,A,B),Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) => + '_$savecp'(NewCP), + ( eval_debug_call(C,Depth,NewCP) -> + $pp_expl_interp_goal(A,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ; $pp_expl_interp_goal(B,Depth,CP, + CIDs0,CIDs,SWs0,SWs, + SimCIDs0,SimCIDs,SimSWs0,SimSWs) + ). 
+%% write_call(A): shorthand for write_call([],A).
+$pp_expl_interp_goal(write_call(A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+%% write_call(Opts,A): build the interpreter call for A as a term (B) and
+%% hand it to $pp_write_call_core/3 together with the options and the
+%% original goal A.
+$pp_expl_interp_goal(write_call(Opts,A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    B = $pp_expl_interp_goal(A,Depth,CP,
+                             CIDs0,CIDs,SWs0,SWs,
+                             SimCIDs0,SimCIDs,SimSWs0,SimSWs),
+    $pp_write_call_core(Opts,A,B).
+%% The prefix operators below are rewritten to write_call/2 with a fixed
+%% option list:
+%%   ??  -> []      ??* -> [all]     ??> -> [call]
+%%   ??< -> [exit+fail]   ??+ -> [exit]   ??- -> [fail]
+$pp_expl_interp_goal((?? A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+$pp_expl_interp_goal((??* A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([all],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+$pp_expl_interp_goal((??> A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([call],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+$pp_expl_interp_goal((??< A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([exit+fail],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+$pp_expl_interp_goal((??+ A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([exit],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+$pp_expl_interp_goal((??- A),Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) =>
+    $pp_expl_interp_goal(write_call([fail],A),Depth,CP,
+                         CIDs0,CIDs,SWs0,SWs,
+                         SimCIDs0,SimCIDs,SimSWs0,SimSWs).
+%% Probabilistic predicate call (as decided by $pd_is_prob_pred/2): record
+%% its goal id and term, then interpret it under the call monitor.  Gid is
+%% bound by that call, after it has already been placed in CIDs.
+$pp_expl_interp_goal(Goal,Depth,_CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) :-
+    functor(Goal,F,N),
+    $pd_is_prob_pred(F,N),!,
+    CIDs = [Gid|CIDs0],
+    SWs = SWs0,
+    SimCIDs = [Goal|SimCIDs0],
+    SimSWs = SimSWs0,
+    c_SAVE_AR(AR),
+    c_next_global_call_number(CallNo),
+    $expl_interp_and_monitor_prob_goal(Goal,Depth,Gid,CallNo,AR).
+%% Any other goal: evaluated directly (through the debugger when debug mode
+%% is on); nothing is recorded.
+$pp_expl_interp_goal(Goal,Depth,CP,
+                     CIDs0,CIDs,SWs0,SWs,
+                     SimCIDs0,SimCIDs,SimSWs0,SimSWs) :-
+    CIDs = CIDs0,
+    SWs = SWs0,
+    SimCIDs = SimCIDs0,
+    SimSWs = SimSWs0,
+    ( c_is_debug_mode ->
+        eval_debug_call(Goal,Depth,CP)
+    ; eval_call(Goal,CP)
+    ).
+
+%%----------------------------------------------------------------------------
+
+%% $expl_interp_and_monitor_prob_goal(Call,Depth,Gid,CallNo,AR)
+%% First clause: print the "Call:" event, interpret Call one level deeper and
+%% report its exit.  Second clause: print the "Fail:" event and fail; it is
+%% tried once the first clause has no (more) solutions.
+$expl_interp_and_monitor_prob_goal(Call,Depth,Gid,CallNo,AR) ?=>
+    c_get_dg_flag(Flag),
+    $print_call(Flag,' Call: ',Call,Depth,CallNo,AR),
+    Depth1 is Depth + 1,
+    $expl_interp_single_call(Call,Depth1,Gid),
+    $switch_skip_off(AR),
+    $eval_call_exit(Call,Depth,CallNo,AR).
+$expl_interp_and_monitor_prob_goal(Call,Depth,_Gid,CallNo,AR) =>
+    c_get_dg_flag(Flag),
+    $print_call(Flag,' Fail: ',Call,Depth,CallNo,AR),
+    fail.
+
+%% $expl_interp_single_call(Goal,Depth,Gid)
+%% Picks a clause of Goal with clause/2, interprets its body starting from
+%% empty accumulators, registers Goal to obtain Gid and, unless the body
+%% used neither subgoals nor switches, prints an "Add:" event and adds the
+%% path to the explanation graph.
+$expl_interp_single_call(Goal,Depth,Gid) :- % suppress re-computation
+    savecp(CP1),
+    clause(Goal,Body),
+    $pp_expl_interp_goal(Body,Depth,CP1,
+                         [],BodyCIDs,[],BodySWs,
+                         [],SimCIDs,[],SimSWs),
+    % BodyCIDs is a list of children in Body
+    % BodySWs is a list of switches in Body
+    $pc_prism_goal_id_register(Goal,Gid),
+    ( (BodyCIDs == [], BodySWs == []) -> true
+    ; c_get_dg_flag(Flag),
+      c_next_global_call_number(CallNo),
+      $print_call(Flag,' Add: ',path(Goal,SimCIDs,SimSWs),Depth,CallNo,0),
+      $prism_eg_path(Gid,BodyCIDs,BodySWs)
+    ).
+
+%%----------------------------------------------------------------------------
+
+%% Thin wrapper around the C-side $pc_add_egraph_path/3.
+$prism_eg_path(Pid,CIDs,SWs) :- $pc_add_egraph_path(Pid,CIDs,SWs).
+
+%% $prism_expl_msw(Sw,V,SwInsId)
+%% Looks up the outcome space of Sw, registers and exports the switch the
+%% first time it is seen, then enumerates each outcome V on backtracking and
+%% returns the id of the instance msw(Sw,V).
+$prism_expl_msw(Sw,V,SwInsId) :-
+    get_values1(Sw,Values),
+    ( $pc_prism_sw_id_get(Sw,SwId) -> true
+    ; $pc_prism_sw_id_register(Sw,SwId),
+      $pp_export_switch(SwId,Sw,Values)
+    ), % vsc !,
+    member(V,Values),
+    $pc_prism_sw_ins_id_get(msw(Sw,V),SwInsId).
+
+%%----------------------------------------------------------------------------
+
+%% Registers every instance of Sw and passes the resulting id list to the
+%% C-side $pc_export_switch/2.
+$pp_export_switch(SwId,Sw,Values) :-
+    $pp_encode_switch_instances(Sw,Values,SwInsIds),
+    $pc_export_switch(SwId,SwInsIds).
+
+%% Maps each value V of Sw to the id registered for msw(Sw,V).
+$pp_encode_switch_instances(_Sw,[],[]).
+$pp_encode_switch_instances(Sw,[V|Vs],[Id|Ids]) :-
+    $pc_prism_sw_ins_id_register(msw(Sw,V),Id),!,
+    $pp_encode_switch_instances(Sw,Vs,Ids).
+
+%%----------------------------------------------------------------------------
+
+%% $pp_print_goal_message(MsgS)
+%% Progress output driven by the global counter $pg_num_goals (N) and the
+%% search_progress flag (Ival): prints "#goals: 0" when N is 0, the value of
+%% N every Ival*10 calls and a dot every Ival calls, then increments N.
+%% Does nothing when MsgS =< 0, when Ival =< 0 or when N is negative.
+$pp_print_goal_message(MsgS) :-
+    MsgS > 0, !,
+    get_prism_flag(search_progress,Ival),
+    Ival > 0, !,
+    global_get($pg_num_goals,N),
+    ( N =:= 0 ->
+        format("#goals: 0",[]),flush_output,
+        N1 is N + 1,
+        global_set($pg_num_goals,N1)
+    ; N > 0 ->
+        ( N mod (Ival * 10) =:= 0 -> format("~w",[N]),flush_output
+        ; N mod Ival =:= 0 -> format(".",[]),flush_output
+        ; true
+        ),
+        N1 is N + 1,
+        global_set($pg_num_goals,N1)
+    ; true
+    ).
+$pp_print_goal_message(_).
diff --git a/packages/prism/src/prolog/up/flags.pl b/packages/prism/src/prolog/up/flags.pl
new file mode 100644
index 000000000..43099c22f
--- /dev/null
+++ b/packages/prism/src/prolog/up/flags.pl
@@ -0,0 +1,291 @@
+%% -*- Prolog -*-
+
+%% prism_flag(Name,Type,Init,Pred) defines a new Prism flag where each
+%% argument indicates:
+%%
+%%     Name : the flag name
+%%     Type : the domain of possible values
+%%     Init : the default value
+%%     Pred : the auxiliary predicate (see below) or `$none'.
+%%
+%% Type should be one of the following:
+%%
+%%     bool:
+%%         boolean value taking either `on' or `off'
+%%
+%%     enum(Cands):
+%%         atom occurring in Cands
+%%
+%%     term(Cands):
+%%         term matching one of patterns in Cands
+%%
+%%     integer(Min,Max):
+%%         integral value from Min to Max (Min/Max can be -inf/+inf)
+%%
+%%     float(Min,Max):
+%%         floating value from Min to Max (Min/Max can be -inf/+inf)
+%%
+%%
+%% Declaring Auxiliary Predicates
+%% ------------------------------
+%%
+%% An auxiliary predicate is called just after a new value is set to
+%% the corresponding flag.  A typical purpose of auxiliary predicates
+%% is to have the new value notified to the C routines.
+%%
+%% Auxiliary predicates must be of arity one, and are called with
+%% the argument indicating the new value set to the flag as described
+%% below (depending on Type):
+%%
+%%     bool:
+%%         an integer 1 (on) or 0 (off).
+%%
+%%     enum(Cands):
+%%         an integer representing the index (starting at 0) at which the
+%%         specified atom exists in Cands
+%%
+%%     term(Cands):
+%%         the specified term
+%%
+%%     integer(Min,Max):
+%%         the specified integral value
+%%
+%%     float(Min,Max):
+%%         the specified floating value
+%%
+%% [TODO: describe open/half-open ranges of floating values]
+%% [TODO: describe special(PredName)]
+%%
+%% [Note] Make sure to declare flags in alphabetical order.
+
+$pp_prism_flag(clean_table,bool,on,$none).
+$pp_prism_flag(daem,bool,off,$pc_set_daem).
+$pp_prism_flag(data_source,term([none,data/1,file(_)]),data/1,$none).
+$pp_prism_flag(default_sw,special($pp_check_default_sw),uniform,$none).
+$pp_prism_flag(default_sw_a,special($pp_check_default_sw_a),1,$none).
+$pp_prism_flag(default_sw_d,special($pp_check_default_sw_d),0,$none).
+$pp_prism_flag(em_message,bool,on,$none).
+$pp_prism_flag(em_progress,integer(1,+inf),10,$pc_set_em_progress).
+$pp_prism_flag(epsilon,float(@0,+inf),0.0001,$pc_set_prism_epsilon).
+$pp_prism_flag(error_on_cycle,bool,on,$pc_set_error_on_cycle).
+%% Flag declarations $pp_prism_flag(Name,Type,Init,Pred), kept in
+%% alphabetical order of Name: clause order is the order in which
+%% get_prism_flag/2 enumerates flags, and hence the order in which
+%% show_prism_flags/0 lists them.
+$pp_prism_flag(explicit_empty_expls,bool,on,$pc_set_explicit_empty_expls).
+$pp_prism_flag(fix_init_order,bool,on,$pc_set_fix_init_order).
+$pp_prism_flag(init,enum([none,noisy_u,random]),random,$pc_set_init_method).
+$pp_prism_flag(itemp_init,float(@0,1),0.1,$pc_set_itemp_init).
+$pp_prism_flag(itemp_rate,float(@1,+inf),1.5,$pc_set_itemp_rate).
+$pp_prism_flag(learn_message,special($pp_check_learn_message),all,$none).
+$pp_prism_flag(learn_mode,enum([params,hparams,both]),params,$none).
+$pp_prism_flag(log_scale,bool,off,$pc_set_log_scale).
+$pp_prism_flag(max_iterate,special($pp_check_max_iterate),
+               default,$pc_set_max_iterate).
+$pp_prism_flag(rerank,integer(1,+inf),10,$none).
+$pp_prism_flag(reset_hparams,bool,on,$none).
+$pp_prism_flag(restart,integer(1,+inf),1,$pc_set_num_restart).
+$pp_prism_flag(search_progress,integer(1,+inf),10,$none).
+$pp_prism_flag(show_itemp,bool,off,$pc_set_show_itemp).
+$pp_prism_flag(sort_hindsight,enum([by_goal,by_prob]),by_goal,$none).
+$pp_prism_flag(std_ratio,float(@0,+inf),0.2,$pc_set_std_ratio).
+$pp_prism_flag(verb,special($pp_check_verb),none,$pp_set_verb).
+$pp_prism_flag(viterbi_mode,enum([params,hparams]),params,$none).
+$pp_prism_flag(warn,bool,off,$pc_set_warn).
+$pp_prism_flag(write_call_events,special($pp_check_write_call_events),all,$none).
+
+%% Groups of mutually exclusive flags: setting one member disables the
+%% others (see $pp_disable_prism_flag/1).
+% first flag is enabled by default
+$pp_prism_flag_exclusive([default_sw_d,default_sw_a]).
+
+%% $pp_prism_flag_renamed(OldName,NewName): OldName is still accepted, with
+%% a warning, and mapped to NewName.
+$pp_prism_flag_renamed(default_sw_h,default_sw_d).
+
+%% $pp_prism_flag_deleted(Name,Version): set_prism_flag/2 raises a domain
+%% error for these names, reporting Version.
+$pp_prism_flag_deleted(avg_branch,'1.11').
+$pp_prism_flag_deleted(layer_check,'1.11').
+$pp_prism_flag_deleted(log_viterbi,'2.0').
+$pp_prism_flag_deleted(dynamic_default_sw,'2.0').
+$pp_prism_flag_deleted(dynamic_default_sw_h,'2.0').
+$pp_prism_flag_deleted(params_after_vbem,'2.0').
+$pp_prism_flag_deleted(reduce_copy,'2.0').
+$pp_prism_flag_deleted(scaling,'2.0').
+$pp_prism_flag_deleted(scaling_factor,'2.0').
+$pp_prism_flag_deleted(smooth,'2.0').
+
+%%----------------------------------------
+
+%% get_prism_flag(?Name,?Value)
+%% Reads the flag value from the global variable derived from Name.  With
+%% Name unbound, enumerates all declared flags.  A renamed (old) flag name
+%% is accepted only when Name is bound; a warning is raised and the value of
+%% the new flag is returned.
+get_prism_flag(Name,Value) :-
+    $pp_prism_flag(Name,_,_,_),
+    $pp_variable_prism_flag(Name,VarName),
+    global_get(VarName,Value).
+get_prism_flag(Name,Value) :-
+    $pp_prism_flag_renamed(Name0,Name1),
+    Name == Name0,!,
+    $pp_raise_warning($msg(3102),[Name,Name1]),
+    $pp_variable_prism_flag(Name1,VarName),
+    global_get(VarName,Value).
+
+%%----------------------------------------
+
+%% set_prism_flag(+Name,+Value)
+%% Steps: validate Name and Value; reject deleted flags (or deleted flag
+%% values, if $pp_prism_flag_deleted/3 exists); resolve a renamed flag to its
+%% new name with a warning; convert Value into the stored form (SValue) and
+%% the form passed to the auxiliary predicate (IValue); disable the flags
+%% that are exclusive with this one; store SValue; finally call the
+%% auxiliary predicate unless it is $none.
+set_prism_flag(Name,Value) :-
+    $pp_require_prism_flag(Name,$msg(3100),set_prism_flag/2),
+    $pp_require_prism_flag_value(Name,Value,$msg(3101),set_prism_flag/2),
+    ( current_predicate($pp_prism_flag_deleted/2),
+      $pp_prism_flag_deleted(Name,Version)
+        -> $pp_raise_domain_error($msg(3103),[Name,Version],[prism_flag,Name],
+                                  set_prism_flag/2)
+    ; current_predicate($pp_prism_flag_deleted/3),
+      $pp_prism_flag_deleted(Name,Value,Version)
+        -> $pp_raise_domain_error($msg(3104),[Name,Value,Version],
+                                  [prism_flag_value(Name),Value],set_prism_flag/2)
+    ; true ),
+    ( $pp_prism_flag(Name,Type,_,Pred) ->
+        Name1 = Name
+    ; $pp_prism_flag_renamed(Name,Name1),$pp_prism_flag(Name1,Type,_,Pred) ->
+        $pp_raise_warning($msg(3102),[Name,Name1])
+    ),
+    $pp_check_prism_flag(Type,Value,SValue,IValue),
+    $pp_disable_prism_flag(Name1),
+    $pp_variable_prism_flag(Name1,VarName),
+    global_set(VarName,SValue),
+    ( Pred == $none -> true
+    ; Term =.. [Pred,IValue], call(Term)
+    ),!.
+
+%%----------------------------------------
+
+%% Restores every flag to its declared default, then disables all but the
+%% first member of each exclusive group.
+reset_prism_flags :-
+    set_default_prism_flags,
+    disable_exclusive_prism_flags.
+
+%% Failure-driven loop over all declared flags.
+set_default_prism_flags :-
+    $pp_prism_flag(Name,_,Value,_),
+    set_prism_flag(Name,Value),
+    fail.
+set_default_prism_flags.
+
+%% Failure-driven loop over the exclusive groups; the head of each group is
+%% left enabled.
+disable_exclusive_prism_flags :-
+    ( current_predicate($pp_prism_flag_exclusive/1),
+      $pp_prism_flag_exclusive([_|Names]),
+      $pp_disable_prism_flag1(Names),
+      fail
+    ; true
+    ).
+
+%%----------------------------------------
+
+%% Prints every flag with its value; a disabled flag is shown as
+%% '(disabled)'.
+show_prism_flags :-
+    get_prism_flag(Name,Value),
+    ( Value = $disabled -> Value1 = '(disabled)'
+    ; Value1 = Value
+    ),
+    format("~w~22|: ~w~n",[Name,Value1]),
+    fail.
+show_prism_flags.
+
+%%----------------------------------------
+%% aliases
+
+current_prism_flag(Name,Value) :- get_prism_flag(Name,Value).
+
+show_prism_flag :- show_prism_flags.
+show_flags :- show_prism_flags.
+show_flag :- show_prism_flags.
+
+%% Maps a flag name to the name of the global variable holding its value,
+%% i.e. the flag name prefixed with `$pg_flag_'.
+$pp_variable_prism_flag(Name,VarName) :-
+    atom_chars(Name,Name1),
+    VarName1 = [$,p,g,'_',f,l,a,g,'_'|Name1],
+    atom_chars(VarName,VarName1).
+
+%%----------------------------------------
+
+%% $pp_check_prism_flag(Type,Value,SValue,IValue)
+%% Validates Value against Type and returns the value to store (SValue) and
+%% the value passed to the auxiliary predicate (IValue).  Fails if Value is
+%% not admissible.
+%%   bool / enum : IValue is the 0-based index of Value among the candidates
+%%   term        : Value must unify with one of the patterns
+%%   integer     : Value must be an integer within the bounds
+%%   float       : Value may be any number within the bounds; IValue is it
+%%                 converted to a float
+%%   special(P)  : the decision is delegated to P(Value,SValue,IValue)
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = bool =>
+    nth0(IValue,[off,on],Value),!,
+    SValue = Value.
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = enum(Cands) =>
+    nth0(IValue,Cands,Value),!,
+    SValue = Value.
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = term(Patts) =>
+    member(Value,Patts),!,
+    SValue = Value,
+    IValue = Value.
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = integer(Min,Max) =>
+    integer(Value),
+    $pp_check_min_max(Value,Min,Max),!,
+    SValue = Value,
+    IValue = Value.
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = float(Min,Max) =>
+    number(Value),
+    $pp_check_min_max(Value,Min,Max),!,
+    SValue = Value,
+    IValue is float(Value).
+$pp_check_prism_flag(Type,Value,SValue,IValue), Type = special(PredName) =>
+    call(PredName,Value,SValue,IValue). % B-Prolog extension
+
+%% $pp_check_min_max(Value,Min,Max)
+%% Range test.  -inf / +inf mean "unbounded"; a bound written as @X is
+%% exclusive (X < Value, resp. X > Value); a plain number is inclusive.
+$pp_check_min_max(Value,Min,Max) :-
+    ( Min = -inf -> true
+    ; Min = @Min0 -> Min0 < Value
+    ; Min =< Value
+    ),!,
+    ( Max = +inf -> true
+    ; Max = @Max0 -> Max0 > Value
+    ; Max >= Value
+    ),!.
+
+%% Checker for the max_iterate flag: accepts inf, default or a positive
+%% integer; 0 is accepted as inf, with a warning.
+$pp_check_max_iterate(0,inf,0) :- $pp_raise_warning($msg(3105),[0,inf]).
+$pp_check_max_iterate(inf,inf,0).
+$pp_check_max_iterate(default,default,-1).
+$pp_check_max_iterate(N,N,N) :- integer(N), N > 0.
+
+%% Checker for default_sw: f_geometric(B) or f_geometric(B,T) with B > 1.0
+%% and T one of asc/desc, or one of none, uniform, f_geometric.
+$pp_check_default_sw(V,V,V) :-
+    ( V = f_geometric(B), number(B), B > 1.0
+    ; V = f_geometric(B,T), number(B), B > 1.0, member(T,[asc,desc])
+    ; member(V,[none,uniform,f_geometric])
+    ).
+
+%% Checker for default_sw_a: a positive number, uniform(U) with U > 0.0, or
+%% one of none, uniform.
+$pp_check_default_sw_a(V,V,V) :-
+    ( number(V), V > 0.0
+    ; V = uniform(U), number(U), U > 0.0
+    ; member(V,[none,uniform])
+    ).
+
+%% Checker for default_sw_d: like default_sw_a, except that zero is allowed.
+$pp_check_default_sw_d(V,V,V) :-
+    ( number(V), V >= 0.0
+    ; V = uniform(U), number(U), U >= 0.0
+    ; member(V,[none,uniform])
+    ).
+
+%% Checker for verb: the third argument is the pair [EM,Graph] consumed by
+%% $pp_set_verb/1.  off and on are accepted as none and full, with a warning.
+$pp_check_verb(none,none,[0,0]).
+$pp_check_verb(em,em,[1,0]).
+$pp_check_verb(graph,graph,[0,1]).
+$pp_check_verb(full,full,[1,1]).
+$pp_check_verb(off,none,[0,0]) :- $pp_raise_warning($msg(3105),[off,none]).
+$pp_check_verb(on,full,[1,1])  :- $pp_raise_warning($msg(3105),[on,full]).
+
+%% Checkers for write_call_events and learn_message: the value is normalized
+%% by $pp_write_call_events/2 resp. $pp_learn_message_events/2, with `off'
+%% accepted as a fallback.
+$pp_check_write_call_events(X,Y,Y) :- $pp_write_call_events(X,Y),!.
+$pp_check_write_call_events(off,off,off) :- !.
+
+$pp_check_learn_message(X,Y,Y) :- $pp_learn_message_events(X,Y),!.
+$pp_check_learn_message(off,off,off) :- !.
+
+%% disable competitors
+
+%% For every exclusive group containing Name, marks the other members of
+%% the group as disabled.  Always succeeds.
+$pp_disable_prism_flag(Name) :-
+    ( current_predicate($pp_prism_flag_exclusive/1),
+      $pp_prism_flag_exclusive(Competitors),
+      select(Name,Competitors,Names), % B-Prolog's built-in
+      $pp_disable_prism_flag1(Names),
+      fail
+    ; true
+    ).
+
+%% Stores the marker $disabled as the value of each listed flag.
+$pp_disable_prism_flag1([]).
+$pp_disable_prism_flag1([Name|Names]) :-
+    $pp_variable_prism_flag(Name,VarName),
+    global_set(VarName,$disabled),!,
+    $pp_disable_prism_flag1(Names).
+
+%% check the availability of the flag (Note: Name must be ground)
+$pp_is_enabled_flag(Name) :-
+    \+ get_prism_flag(Name,$disabled).
+
+%%----------------------------------------
+%% auxiliary predicates
+
+%% Auxiliary predicate of the verb flag: forwards the two components to the
+%% C side.
+$pp_set_verb([EM,Graph]) :-
+    $pc_set_verb_em(EM),
+    $pc_set_verb_graph(Graph).
diff --git a/packages/prism/src/prolog/up/hash.pl b/packages/prism/src/prolog/up/hash.pl
new file mode 100644
index 000000000..997451f66
--- /dev/null
+++ b/packages/prism/src/prolog/up/hash.pl
@@ -0,0 +1,42 @@
+%% Assumption:
+%%   h(F) = h(G) where F and G are variants and h is the hash function
+
+% In YAP use the standard routines:
+
+$pp_hashtable_get(T,K,V) :- hashtable_get(T,K,V).
+$pp_hashtable_put(T,K,V) :- hashtable_put(T,K,V).
+
+/****** vsc: commented out for YAP
+
+$pp_hashtable_get(T,K,V), T = $hshtb(_,_) => hashtable_get(T,K,V).
+$pp_hashtable_get(T,_,_) => $pp_hashtable_throw(T,$pp_hashtable_get/3).
+
+$pp_hashtable_put(T,K,V), T = $hshtb(N0,A) =>
+    hash_code(K,H),
+    functor(A,_,M),
+    I is (H mod M) + 1,
+    arg(I,A,L),
+    member(KV,L),
+    ( var(KV) ->
+        KV = (K = V),
+        N1 is N0 + 1,
+        setarg(1,T,N1),
+        ( N1 > 2 * M + 1, M < 32700 -> $hashtable_expand_buckets(T)
+        ; true % #buckets should not exceed 65536
+        )
+    ; KV = (Key = _),
+      variant(Key,K) -> setarg(2,KV,V)
+    ), !.
+$pp_hashtable_put(T,_,_) =>
+    $pp_hashtable_throw(T,$pp_hashtable_put/3).
+
+*/
+
+%% $pp_hashtable_throw(T,Source)
+%% Throws instantiation_error if T is unbound and type_error(hashtable,T) if
+%% T does not match $hshtb(_,_); for a well-formed table it simply fails.
+%% Its only visible callers are inside the commented-out block above.
+$pp_hashtable_throw(T,Source) :-
+    ( nonvar(T) -> true
+    ; throw(error(instantiation_error,Source))
+    ),
+    ( T ?= $hshtb(_,_) -> true
+    ; throw(error(type_error(hashtable,T),Source))
+    ), !,
+    fail. % should not reach here
diff --git a/packages/prism/src/prolog/up/hindsight.pl b/packages/prism/src/prolog/up/hindsight.pl
new file mode 100644
index 000000000..da6348585
--- /dev/null
+++ b/packages/prism/src/prolog/up/hindsight.pl
@@ -0,0 +1,497 @@
+%%%%
+%%%% Hindsight routine with C interface
+%%%%
+
+%%
+%% hindsight(G,SubG,HProbs) :-
+%%     output hindsight probs of subgoals that match SubG given G
+%%
+%% hindsight(G,SubG) :- print hindsight probs of SubG given G
+%%
+
+hindsight(G) :- hindsight(G,_).
+
+%% Prints the result of hindsight/3, or raises a warning when it is empty.
+hindsight(G,SubG) :-
+    hindsight(G,SubG,HProbs),
+    ( HProbs == [] -> $pp_raise_warning($msg(1404))
+    ; format("hindsight probabilities:~n",[]),
+      $pp_print_hindsight_probs(HProbs)
+    ).
+
+%% hindsight(G,SubG,HProbs)
+%% Validates G (and SubG when bound), computes the hindsight entries with
+%% $pp_hindsight_core/3, sorts them according to the sort_hindsight flag and
+%% records the elapsed time as inference statistics.
+hindsight(G,SubG,HProbs) :-
+    $pp_require_tabled_probabilistic_atom(G,$msg(0006),hindsight/3),
+    ( nonvar(SubG) -> $pp_require_callable(SubG,$msg(1403),hindsight/3)
+    ; true
+    ),
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_hindsight_core(G,SubG,HProbs0),
+    $pp_sort_hindsight_probs(HProbs0,HProbs),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_hindsight_stats1(InfTime),!.
+
+%% Prints the result of hindsight_agg/3, or raises a warning when it is
+%% empty.
+hindsight_agg(G,Agg) :-
+    hindsight_agg(G,Agg,HProbs),
+    ( HProbs == [] -> $pp_raise_warning($msg(1404))
+    ; format("hindsight probabilities:~n",[]),
+      $pp_print_hindsight_probs_agg(HProbs)
+    ).
+
+%% hindsight_agg(G,Agg,HProbs)
+%% Like hindsight/3, but the entries are grouped and summed according to
+%% the aggregate pattern Agg; HProbs is a list of groups.
+hindsight_agg(G,Agg,HProbs) :-
+    $pp_require_tabled_probabilistic_atom(G,$msg(0006),hindsight_agg/3),
+    $pp_require_hindsight_aggregate_pattern(Agg,$msg(1402),hindsight_agg/3),
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_get_subgoal_from_agg(Agg,SubG),!,
+    $pp_hindsight_core(G,SubG,HProbs0),
+    $pp_aggregate_hindsight_probs(Agg,HProbs0,HProbs1),
+    $pp_sort_hindsight_probs_agg(HProbs1,HProbs),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_hindsight_stats1(InfTime),!.
+
+%% $pp_hindsight_core(G,SubG,HProbs), ground G: run the explanation search
+%% on G itself, compute the hindsight entries and record the graph
+%% statistics together with the search and computation times.
+$pp_hindsight_core(G,SubG,HProbs) :-
+    ground(G),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T0),
+    $pp_find_explanations(G),!,
+    cputime(T1),
+    $pp_compute_hindsight(G,SubG,HProbs),
+    cputime(T2),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T1 - T0,
+    NumCompTime is T2 - T1,
+    $pp_assert_hindsight_stats2(SearchTime,NumCompTime),!.
+
+%% $pp_hindsight_core(G,SubG,HProbs), non-ground G: a copy of G becomes the
+%% body of a freshly created tabled dummy goal, whose clause also registers
+%% both goals and adds the path dummy -> copy to the explanation graph.  The
+%% search and the hindsight computation are then run on the dummy goal.
+$pp_hindsight_core(G,SubG,HProbs) :-
+    copy_term(G,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    Clause = (DummyGoal:-BodyGoal,
+                         $pc_prism_goal_id_register(GoalCp,GId),
+                         $pc_prism_goal_id_register(DummyGoal,HId),
+                         $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),[Clause]),
+            pred('$damon_load',0,_,_,_,[('$damon_load':-true)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T0),
+    $pp_find_explanations(DummyGoal),!,
+    cputime(T1),
+    $pp_compute_hindsight(DummyGoal,SubG,HProbs),
+    cputime(T2),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T1 - T0,
+    NumCompTime is T2 - T1,
+    $pp_assert_hindsight_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% Exports the switch information, then asks the C routine for the
+%% hindsight entries of Goal and decodes the goal ids back into terms.
+% Sws = [sw(Id,Instances,Probs,PseudoCs,Fixed,FixedH),...]
+$pp_compute_hindsight(Goal,SubG,HProbs) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    garbage_collect,
+    $pc_compute_hindsight(Gid,SubG,0,HProbs0), % "0" indicates "unconditional"
+    $pp_decode_hindsight(HProbs0,HProbs),!.
+
+%%
+%% Conditional version of hindsight computation:
+%%
+
+chindsight(G) :- chindsight(G,_).
+
+%% Prints the result of chindsight/3, or raises a warning when it is empty.
+chindsight(G,SubG) :-
+    chindsight(G,SubG,HProbs),
+    ( HProbs == [] -> $pp_raise_warning($msg(1404))
+    ; format("conditional hindsight probabilities:~n",[]),
+      $pp_print_hindsight_probs(HProbs)
+    ).
+
+%% chindsight(G,SubG,HProbs)
+%% Same steps as hindsight/3, using the conditional computation
+%% ($pp_chindsight_core/3).
+chindsight(G,SubG,HProbs) :-
+    $pp_require_tabled_probabilistic_atom(G,$msg(0006),chindsight/3),
+    ( nonvar(SubG) -> $pp_require_callable(SubG,$msg(1403),chindsight/3)
+    ; true
+    ),
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_chindsight_core(G,SubG,HProbs0),
+    $pp_sort_hindsight_probs(HProbs0,HProbs),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_hindsight_stats1(InfTime),!.
+
+%% Prints the result of chindsight_agg/3, or raises a warning when it is
+%% empty.
+chindsight_agg(G,Agg) :-
+    chindsight_agg(G,Agg,HProbs),
+    ( HProbs == [] -> $pp_raise_warning($msg(1404))
+    ; format("conditional hindsight probabilities:~n",[]),
+      $pp_print_hindsight_probs_agg(HProbs)
+    ).
+
+%% chindsight_agg(G,Agg,HProbs)
+%% Same steps as hindsight_agg/3, using the conditional computation.
+chindsight_agg(G,Agg,HProbs) :-
+    $pp_require_tabled_probabilistic_atom(G,$msg(0006),chindsight_agg/3),
+    $pp_require_hindsight_aggregate_pattern(Agg,$msg(1402),chindsight_agg/3),
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_get_subgoal_from_agg(Agg,SubG),!,
+    $pp_chindsight_core(G,SubG,HProbs0),
+    $pp_aggregate_hindsight_probs(Agg,HProbs0,HProbs1),
+    $pp_sort_hindsight_probs_agg(HProbs1,HProbs),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_hindsight_stats1(InfTime),!.
+
+%% $pp_chindsight_core(G,SubG,HProbs), ground G: identical to the ground
+%% clause of $pp_hindsight_core/3 except that $pp_compute_chindsight/3 is
+%% used.
+$pp_chindsight_core(G,SubG,HProbs) :-
+    ground(G),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T0),
+    $pp_find_explanations(G),!,
+    cputime(T1),
+    $pp_compute_chindsight(G,SubG,HProbs),
+    cputime(T2),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T1 - T0,
+    NumCompTime is T2 - T1,
+    $pp_assert_hindsight_stats2(SearchTime,NumCompTime),!.
+
+%% $pp_chindsight_core(G,SubG,HProbs), non-ground G: wraps a copy of G in a
+%% freshly created tabled dummy goal (see the clause built below) and runs
+%% the search and the conditional hindsight computation on that dummy goal.
+$pp_chindsight_core(G,SubG,HProbs) :-
+    copy_term(G,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    Clause = (DummyGoal:-BodyGoal,
+                         $pc_prism_goal_id_register(GoalCp,GId),
+                         $pc_prism_goal_id_register(DummyGoal,HId),
+                         $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),[Clause]),
+            pred('$damon_load',0,_,_,_,[('$damon_load':-true)])],
+    consult_preds([],Prog), % B-Prolog built-in
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T0),
+    $pp_find_explanations(DummyGoal),!,
+    cputime(T1),
+    $pp_compute_chindsight(DummyGoal,SubG,HProbs),
+    cputime(T2),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T1 - T0,
+    NumCompTime is T2 - T1,
+    $pp_assert_hindsight_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% Same as $pp_compute_hindsight/3 but requests the conditional variant
+%% from the C routine.
+$pp_compute_chindsight(Goal,SubG,HProbs) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    garbage_collect,
+    $pc_compute_hindsight(Gid,SubG,1,HProbs0), % "1" indicates "conditional"
+    $pp_decode_hindsight(HProbs0,HProbs),!.
+
+%% Replaces the goal id in each [Gid,P] pair by the goal term.
+$pp_decode_hindsight([],[]).
+$pp_decode_hindsight([[Gid,P]|HProbs0],[[G,P]|HProbs]) :-
+    $pc_prism_goal_term(Gid,G),!,
+    $pp_decode_hindsight(HProbs0,HProbs).
+
+%% Builds the subgoal pattern for an aggregate pattern: every argument that
+%% is an aggregation keyword (or a variable) becomes a fresh variable, all
+%% other arguments are kept.
+$pp_get_subgoal_from_agg(Agg,SubG) :-
+    Agg =.. [F|Args0],
+    $pp_get_subgoal_from_agg1(Args0,Args1),
+    SubG =.. [F|Args1].
+
+$pp_get_subgoal_from_agg1([],[]).
+$pp_get_subgoal_from_agg1([A0|Args0],[A1|Args1]) :-
+    ( $pp_is_agg_patt(A0) -> A1 = _
+    ; A1 = A0
+    ),!,
+    $pp_get_subgoal_from_agg1(Args0,Args1).
+
+%% True for a variable or one of the aggregation keywords.
+$pp_is_agg_patt(A) :-
+    ( var(A) -> true
+    ; member(A,[integer,atom,compound,length,d_length,depth,query,ignore])
+    ).
+
+%% $pp_aggregate_hindsight_probs(Agg,HProbs0,HProbs)
+%% Two stages: partition the [G,P] entries into groups by their group
+%% pattern, then, inside each group, replace goals by their aggregate
+%% pattern and sum the probabilities of equal patterns.
+$pp_aggregate_hindsight_probs(Agg,HProbs0,HProbs) :-
+    $pp_group_hindsight_probs(Agg,HProbs0,HProbs1),!,
+    $pp_aggregate_hindsight_probs1(Agg,HProbs1,HProbs).
+
+%% Tags each entry with its group pattern, groups entries sharing a pattern
+%% and strips the tags again.
+$pp_group_hindsight_probs(Agg,HProbs0,HProbs) :-
+    $pp_insert_group_patt(Agg,HProbs0,HProbs1),
+    $pp_group_hindsight_probs1(HProbs1,HProbs2),
+    $delete_group_patt(HProbs2,HProbs).
+
+%% [G,P] -> [GPatt,G,P]
+$pp_insert_group_patt(_,[],[]).
+$pp_insert_group_patt(Agg,[[G,P]|HProbs0],[[GPatt,G,P]|HProbs]) :-
+    $pp_get_group_patt(Agg,G,GPatt),!,
+    $pp_insert_group_patt(Agg,HProbs0,HProbs).
+
+%% Strips the pattern tag from every entry of every group.
+$delete_group_patt([],[]).
+$delete_group_patt([Gr0|Groups0],[Gr|Groups]) :-
+    $delete_group_patt1(Gr0,Gr),!,
+    $delete_group_patt(Groups0,Groups).
+
+%% [GPatt,G,P] -> [G,P]
+$delete_group_patt1([],[]).
+$delete_group_patt1([[_GPatt,G,P]|HProbs0],[[G,P]|HProbs]) :- !,
+    $delete_group_patt1(HProbs0,HProbs).
+
+%% Computes the group pattern of goal G argument by argument; Agg and G
+%% must have the same functor and the same number of arguments.
+$pp_get_group_patt(Agg,G,GPatt) :-
+    Agg =.. [F|AggArgs],
+    G   =.. [F|Args],
+    $pp_get_group_patt_args(AggArgs,Args,GPattArgs),
+    GPatt =.. [F|GPattArgs].
+
+$pp_get_group_patt_args([],[],[]).
+%% Group-pattern argument GPA for goal argument A under keyword AggA:
+%%   integer / atom / compound : A itself, after a type check
+%%   length                    : length-L, L the length of the list A
+%%   d_length                  : d_length-L, L = |D0| - |D1| for A = D0-D1
+%%   depth                     : depth-D, D from $pc_get_term_depth/2
+%%   query / ignore / variable : *
+%%   anything else             : A itself
+%% NOTE(review): the helper name $pp_raise_instanciation_error is spelled
+%% with a `c'; its definition is not visible here -- confirm it matches the
+%% name defined in the error module.
+$pp_get_group_patt_args([AggA|AggArgs],[A|Args],[GPA|GPattArgs]) :-
+    ( nonvar(AggA) ->
+        ( AggA = integer ->
+            ( integer(A) -> GPA = A
+            ; $pp_raise_domain_error($msg(1405),[A],[integer,A],
+                                     $pp_group_hindsight_probs/3)
+            )
+        ; AggA = atom ->
+            ( atom(A) -> GPA = A
+            ; $pp_raise_domain_error($msg(1406),[A],[atom,A],
+                                     $pp_group_hindsight_probs/3)
+            )
+        ; AggA = compound ->
+            ( A = [] -> GPA = A
+            ; \+ ground(A) ->
+                $pp_raise_instanciation_error($msg(1407),[A],
+                                              $pp_group_hindsight_probs/3)
+            ; compound(A) -> GPA = A
+            ; $pp_raise_domain_error($msg(1407),[A],[compound,A],
+                                     $pp_group_hindsight_probs/3)
+            )
+        ; AggA = length ->
+            ( (A = [] ; is_list(A)) -> length(A,L), GPA = length-L
+            ; $pp_raise_domain_error($msg(1408),[A],[list,A],
+                                     $pp_group_hindsight_probs/3)
+            )
+        ; AggA = d_length ->
+            ( A = (D0-D1), is_list(D0), is_list(D1)
+                -> length(D0,L0), length(D1,L1), L is L0 - L1, GPA = d_length-L
+            ; $pp_raise_domain_error($msg(1409),[A],[d_list,A],
+                                     $pp_group_hindsight_probs/3)
+            )
+        ; AggA = depth -> $pc_get_term_depth(A,D), GPA = depth-D
+        ; AggA = query -> GPA = *
+        ; AggA = ignore -> GPA = *
+        ; GPA = A
+        )
+    ; GPA = *
+    ),!,
+    $pp_get_group_patt_args(AggArgs,Args,GPattArgs).
+
+%% Sorts the tagged entries (presumably keeping duplicates, judging by the
+%% name $pp_sort_remain_dup -- TODO confirm) and splits them into runs of
+%% consecutive entries sharing the same pattern.
+$pp_group_hindsight_probs1(HProbs0,HProbs) :-
+    $pp_sort_remain_dup(HProbs0,HProbs1),!,
+    $pp_group_hindsight_probs2(HProbs1,HProbs).
+
+$pp_group_hindsight_probs2([],[]).
+$pp_group_hindsight_probs2([U],[[U]]).
+$pp_group_hindsight_probs2([U0|Us0],Us) :- !,
+    $pp_group_hindsight_probs2(U0,[U0],Us0,Us).
+
+%% $pp_group_hindsight_probs2(Prev,CurrentGroup,Rest,Groups): entries are
+%% consed onto the current group, so each group comes out in reverse order.
+$pp_group_hindsight_probs2(_,Us,[],[Us]).
+$pp_group_hindsight_probs2(U0,Us0,[U1|Us1],Us) :-
+    ( U0 = [GPatt,_,_], U1 = [GPatt,_,_] ->
+        Us2 = [U1|Us0],!,
+        $pp_group_hindsight_probs2(U1,Us2,Us1,Us)
+    ; Us = [Us0|Us3],!,
+      $pp_group_hindsight_probs2(U1,[U1],Us1,Us3)
+    ).
+
+%% Second stage of the aggregation: map goals to aggregate patterns, then
+%% sum within each group.
+$pp_aggregate_hindsight_probs1(Agg,HProbs0,HProbs) :-
+    $pp_replace_agg_patt(Agg,HProbs0,HProbs1),!,
+    $pp_aggregate_hindsight_probs2(HProbs1,HProbs).
+
+$pp_replace_agg_patt(_,[],[]).
+%% Applies $pp_replace_agg_patt1/3 to every group.
+$pp_replace_agg_patt(Agg,[Gr0|Groups0],[Gr|Groups]) :-
+    $pp_replace_agg_patt1(Agg,Gr0,Gr),!,
+    $pp_replace_agg_patt(Agg,Groups0,Groups).
+
+%% [G,P] -> [APatt,P]
+$pp_replace_agg_patt1(_,[],[]).
+$pp_replace_agg_patt1(Agg,[[G,P]|HProbs0],[[APatt,P]|HProbs]) :-
+    $pp_get_agg_patt(Agg,G,APatt),!,
+    $pp_replace_agg_patt1(Agg,HProbs0,HProbs).
+
+%% Computes the aggregate pattern of goal G argument by argument; Agg and G
+%% must have the same functor and the same number of arguments.
+$pp_get_agg_patt(Agg,G,APatt) :-
+    Agg =.. [F|AggArgs],
+    G   =.. [F|Args],
+    $pp_get_agg_patt_args(AggArgs,Args,APattArgs),
+    APatt =.. [F|APattArgs].
+
+%% Aggregate-pattern argument APA for goal argument A under keyword AggA:
+%%   integer / atom / compound : A itself, after a type check
+%%   length                    : 'L'-L, L the length of the list A
+%%   d_length                  : 'DL'-L, L = |D0| - |D1| for A = D0-D1
+%%   depth                     : 'D'-D, D from $pc_get_term_depth/2
+%%   query                     : A itself (kept distinct)
+%%   ignore / variable         : *
+%%   anything else             : A itself
+%% NOTE(review): the helper name $pp_raise_instanciation_error is spelled
+%% with a `c'; its definition is not visible here -- confirm it matches the
+%% name defined in the error module.
+$pp_get_agg_patt_args([],[],[]).
+$pp_get_agg_patt_args([AggA|AggArgs],[A|Args],[APA|APattArgs]) :-
+    ( nonvar(AggA) ->
+        ( AggA = integer ->
+            ( integer(A) -> APA = A
+            ; $pp_raise_domain_error($msg(1405),[A],[integer,A],
+                                     $pp_aggregate_hindsight_probs/3)
+            )
+        ; AggA = atom ->
+            ( atom(A) -> APA = A
+            ; $pp_raise_domain_error($msg(1406),[A],[atom,A],
+                                     $pp_aggregate_hindsight_probs/3)
+            )
+        ; AggA = compound ->
+            ( A = [] -> APA = A
+            ; \+ ground(A) ->
+                $pp_raise_instanciation_error($msg(1407),[A],
+                                              $pp_aggregate_hindsight_probs/3)
+            ; compound(A) -> APA = A
+            ; $pp_raise_domain_error($msg(1407),[A],[compound,A],
+                                     $pp_aggregate_hindsight_probs/3)
+            )
+        ; AggA = length ->
+            ( (A = [] ; is_list(A)) -> length(A,L), APA = 'L'-L
+            ; $pp_raise_domain_error($msg(1408),[A],[list,A],
+                                     $pp_aggregate_hindsight_probs/3)
+            )
+        ; AggA = d_length ->
+            ( A = (D0-D1), is_list(D0), is_list(D1)
+                -> length(D0,L0), length(D1,L1), L is L0 - L1, APA = 'DL'-L
+            ; $pp_raise_domain_error($msg(1409),[A],[d_list,A],
+                                     $pp_aggregate_hindsight_probs/3)
+            )
+        ; AggA = depth -> $pc_get_term_depth(A,D), APA = 'D'-D
+        ; AggA = query -> APA = A
+        ; AggA = ignore -> APA = *
+        ; APA = A
+        )
+    ; APA = *
+    ),!,
+    $pp_get_agg_patt_args(AggArgs,Args,APattArgs).
+
+%% Sums the probabilities inside every group.
+$pp_aggregate_hindsight_probs2([],[]).
+$pp_aggregate_hindsight_probs2([Gr0|Groups0],[Gr|Groups]) :- !,
+    $pp_aggregate_hindsight_probs3(Gr0,Gr),!,
+    $pp_aggregate_hindsight_probs2(Groups0,Groups).
+
+%% Sorts one group of [APatt,P] entries so that equal patterns become
+%% adjacent, then merges them.
+$pp_aggregate_hindsight_probs3(HProbs0,HProbs) :-
+    $pp_sort_remain_dup(HProbs0,HProbs1),
+    $pp_aggregate_hindsight_probs4(HProbs1,HProbs).
+
+%% Chooses plain or log-scale summation according to the log_scale flag.
+$pp_aggregate_hindsight_probs4(HProbs0,HProbs) :-
+    ( get_prism_flag(log_scale,off) ->
+        $pp_aggregate_hindsight_probs5(HProbs0,HProbs)
+    ; $pp_aggregate_hindsight_probs5_log(HProbs0,HProbs)
+    ).
+
+%% $pp_aggregate_hindsight_probs5(Entries,Merged)
+%% Merges runs of adjacent entries with the same pattern by adding their
+%% probabilities.  The singleton clause commits with a cut so that a
+%% one-element list yields exactly one solution and no choice point.
+$pp_aggregate_hindsight_probs5([],[]).
+$pp_aggregate_hindsight_probs5([U],[U]) :- !.
+$pp_aggregate_hindsight_probs5([[APatt,P]|Us0],Us) :- !,
+    $pp_aggregate_hindsight_probs5(APatt,P,Us0,Us).
+
+%% $pp_aggregate_hindsight_probs5(Patt,Sum,Rest,Merged): Sum is the running
+%% total for the current pattern Patt.
+$pp_aggregate_hindsight_probs5(APatt,P,[],[[APatt,P]]).
+$pp_aggregate_hindsight_probs5(APatt,P0,[[APatt1,P1]|Us1],Us) :-
+    ( APatt = APatt1 ->
+        P2 is P0 + P1,!,
+        $pp_aggregate_hindsight_probs5(APatt,P2,Us1,Us)
+    ; Us = [[APatt,P0]|Us2],!,
+      $pp_aggregate_hindsight_probs5(APatt1,P1,Us1,Us2)
+    ).
+
+% log-scale computation for tiny probabilities
+%% Same contract as $pp_aggregate_hindsight_probs5/2 with P values being
+%% log-probabilities.  Cuts mirror the linear-scale version.
+$pp_aggregate_hindsight_probs5_log([],[]).
+$pp_aggregate_hindsight_probs5_log([U],[U]) :- !.
+$pp_aggregate_hindsight_probs5_log([[APatt,P]|Us0],Us) :- !,
+    $pp_aggregate_hindsight_probs5_log(APatt,P,1.0,Us0,Us).
+
+%% $pp_aggregate_hindsight_probs5_log(Patt,Base,Q,Rest,Merged): the running
+%% total for Patt is kept as Base + log(Q), i.e. exp(Base) * Q.
+$pp_aggregate_hindsight_probs5_log(APatt,P0,Q,[],[[APatt,P]]) :-
+    P is P0 + log(Q),!.
+$pp_aggregate_hindsight_probs5_log(APatt,P0,Q0,[[APatt1,P1]|Us1],Us) :-
+    ( APatt = APatt1 ->
+        ( P1 < -4096.0 ->          % P1 == -Inf, i.e. exp(P1) == 0
+            Q is Q0,               % Note: exp(-4096) << Double.MIN_VALUE
+            P2 = P0
+        ; P0 < -4096.0 ->          % P0 == -Inf, i.e. exp(P0) == 0
+            Q is 1.0,
+            P2 = P1
+        ; P1 - P0 > log(1.0e+280) ->
+            % re-base on P1 so that exp/1 below cannot overflow
+            Q is Q0 * exp(P0 - P1) + 1.0,
+            P2 = P1
+        ; Q is Q0 + exp(P1 - P0),
+          P2 = P0
+        ),!,
+        $pp_aggregate_hindsight_probs5_log(APatt,P2,Q,Us1,Us)
+    ; P is P0 + log(Q0),
+      Us = [[APatt,P]|Us2],!,
+      $pp_aggregate_hindsight_probs5_log(APatt1,P1,1.0,Us1,Us2)
+    ).
+
+%% $pp_sum_log_list(LogPs,LogSum)
+%% Log of the sum of exp(LP) over the list, computed relative to the first
+%% element.  An empty list yields 0.0.
+$pp_sum_log_list([],0.0) :- !.
+$pp_sum_log_list([LP],LP) :- !.
+$pp_sum_log_list([LP|LPs],Sum) :-
+    $pp_sum_log_list(LPs,LP,1.0,SumRest),!,
+    Sum is LP + log(SumRest).
+
+$pp_sum_log_list([],_,SumRest,SumRest).
+%% Accumulates exp(LP - FirstLP) for each remaining element.
+$pp_sum_log_list([LP|LPs],FirstLP,SumRest0,SumRest) :-
+    SumRest1 is SumRest0 + exp(LP - FirstLP),!,
+    $pp_sum_log_list(LPs,FirstLP,SumRest1,SumRest).
+
+%%%%
+%%%% Sort hindsight probabilities
+%%%%
+
+%% Orders the [Goal,Prob] entries by goal when sort_hindsight is by_goal,
+%% otherwise by probability.
+$pp_sort_hindsight_probs(HProbs0,HProbs) :-
+    ( get_prism_flag(sort_hindsight,by_goal) ->
+        $pp_sort_remain_dup(HProbs0,HProbs)
+    ; $pp_sort_hindsight_probs_by_prob(HProbs0,HProbs)
+    ).
+
+%% Swaps each pair to [Prob,Goal], sorts, reverses the sorted list and
+%% swaps back.
+$pp_sort_hindsight_probs_by_prob(HProbs0,HProbs) :-
+    $pp_swap_hindsight_pair(HProbs0,HProbs1),
+    $pp_sort_remain_dup(HProbs1,HProbs2),
+    reverse(HProbs2,HProbs3),
+    $pp_swap_hindsight_pair(HProbs3,HProbs).
+
+%% [X,Y] -> [Y,X] for every element.
+$pp_swap_hindsight_pair([],[]) :- !.
+$pp_swap_hindsight_pair([[X,Y]|XYs],[[Y,X]|YXs]) :- !,
+    $pp_swap_hindsight_pair(XYs,YXs).
+
+%% Sorts the entries inside each group; the order of the groups is kept.
+$pp_sort_hindsight_probs_agg([],[]) :- !.
+$pp_sort_hindsight_probs_agg([Gr0|Groups0],[Gr|Groups]) :-
+    $pp_sort_hindsight_probs(Gr0,Gr),!,
+    $pp_sort_hindsight_probs_agg(Groups0,Groups).
+
+%%%%
+%%%% Print hindsight probabilities
+%%%%
+
+%% One line per entry: the goal followed by the probability with 15
+%% decimals.
+$pp_print_hindsight_probs([]).
+$pp_print_hindsight_probs([[G,P]|HProbs]) :-
+    format(" ~w: ~15f~n",[G,P]),!,
+    $pp_print_hindsight_probs(HProbs).
+
+%% Prints all groups one after another.
+$pp_print_hindsight_probs_agg([]).
+$pp_print_hindsight_probs_agg([Gr|Groups]) :-
+    $pp_print_hindsight_probs(Gr),!,
+    $pp_print_hindsight_probs_agg(Groups).
+
+%%%% Statistics
+
+%% The times are divided by 1000.0 before being asserted (presumably
+%% milliseconds to seconds -- TODO confirm the unit of cputime/1).
+$pp_assert_hindsight_stats1(InfTime0) :-
+    InfTime is InfTime0 / 1000.0,
+    assertz($ps_infer_time(InfTime)),!.
+
+$pp_assert_hindsight_stats2(SearchTime0,NumCompTime0) :-
+    SearchTime is SearchTime0 / 1000.0,
+    NumCompTime is NumCompTime0 / 1000.0,
+    assertz($ps_infer_search_time(SearchTime)),
+    assertz($ps_infer_calc_time(NumCompTime)),!.
diff --git a/packages/prism/src/prolog/up/learn.pl b/packages/prism/src/prolog/up/learn.pl
new file mode 100644
index 000000000..5145eda28
--- /dev/null
+++ b/packages/prism/src/prolog/up/learn.pl
@@ -0,0 +1,435 @@
+%% learn/0: learn in the mode given by the learn_mode flag, taking the
+%% observed goals from the configured data source.
+learn :-
+    get_prism_flag(learn_mode,Mode),
+    $pp_learn_main(Mode).
+%% learn/1: learns from the given list of observed goals, in the mode
+%% given by the learn_mode flag.
+learn(Goals) :-
+    get_prism_flag(learn_mode,Mode),
+    $pp_learn_main(Mode,Goals).
+
+%% Mode-specific entry points: learn_p* (params), learn_h* (hparams),
+%% learn_b* (both).  The /0 versions read the goals from the data source.
+learn_p :-
+    $pp_learn_main(params).
+learn_p(Goals) :-
+    $pp_learn_main(params,Goals).
+learn_h :-
+    $pp_learn_main(hparams).
+learn_h(Goals) :-
+    $pp_learn_main(hparams,Goals).
+learn_b :-
+    $pp_learn_main(both).
+learn_b(Goals) :-
+    $pp_learn_main(both,Goals).
+
+%% for the parallel version
+$pp_learn_main(Mode) :- call($pp_learn_core(Mode)).
+$pp_learn_main(Mode,Goals) :- call($pp_learn_core(Mode,Goals)).
+
+%% $pp_learn_data_file(-FileName)
+%% Resolves the data file from the data_source flag:
+%%   none          -> runtime error;
+%%   data/1        -> calls the user predicate data/1 (error if undefined);
+%%   file(File)    -> FileName = File;
+%%   anything else -> unmatched-branches error.
+$pp_learn_data_file(FileName) :-
+    get_prism_flag(data_source,Source),
+    ( Source == none ->
+        $pp_raise_runtime_error($msg(1300),data_source_not_found,
+                                $pp_learn_data_file/1)
+    ; Source == data/1 ->
+        ( current_predicate(data/1) -> data(FileName)
+        ; $pp_raise_runtime_error($msg(1301),data_source_not_found,
+                                  $pp_learn_data_file/1)
+        )
+    ; Source = file(FileName)
+    ; $pp_raise_unmatched_branches($pp_learn_data_file/1)
+    ),!.
+
+%% Validates the observed goals before learning.  Besides the per-goal
+%% checks, a `failure' goal is rejected when the daem flag is on.
+$pp_learn_check_goals(Goals) :-
+    $pp_require_observed_data(Goals,$msg(1302),$pp_learn_core/1),
+    $pp_learn_check_goals1(Goals),
+    ( get_prism_flag(daem,on),
+      membchk(failure,Goals)
+        -> $pp_raise_runtime_error($msg(1305),daem_with_failure,
+                                   $pp_learn_core/1)
+    ; true
+    ).
+
+%% Each element is either a bare goal or goal(G,Count) / count(G,Count) /
+%% (Count times G); Count must be a positive integer and G a tabled
+%% probabilistic atom.
+$pp_learn_check_goals1([]).
+$pp_learn_check_goals1([G0|Gs]) :-
+    ( (G0 = goal(G,Count) ; G0 = count(G,Count) ; G0 = (Count times G) ) ->
+        $pp_require_positive_integer(Count,$msg(1306),$pp_learn_core/1)
+    ; G = G0
+    ),
+    $pp_require_tabled_probabilistic_atom(G,$msg(1303),$pp_learn_core/1),!,
+    $pp_learn_check_goals1(Gs).
+
+%% $pp_learn_core(+Mode): loads the goals from the data file, then
+%% delegates to $pp_learn_core/2.
+$pp_learn_core(Mode) :-
+    $pp_learn_data_file(FileName),
+    load_clauses(FileName,Goals,[]),!,
+    $pp_learn_core(Mode,Goals).
+
+%% $pp_learn_core(+Mode,+Goals)
+%% Main learning driver.  In order: validates the goals, clears old
+%% statistics, turns the goals into goal/count pairs, runs the
+%% explanation search, exports the switch information to the C-side EM
+%% routine, runs EM, imports the updated switches and the graph
+%% statistics, and records/prints the learning statistics.
+%% The cputime/1 differences are passed to $pp_assert_learn_stats/12
+%% together with 1000 as the number of time units per second.
+$pp_learn_core(Mode,Goals) :-
+    $pp_learn_check_goals(Goals),
+    $pp_learn_message(MsgS,MsgE,MsgT,MsgM),
+    $pc_set_em_message(MsgE),
+    cputime(Start),
+    $pp_learn_clean_info,
+    $pp_learn_reset_hparams(Mode),
+    $pp_trans_goals(Goals,GoalCountPairs,AllGoals),!,
+    global_set($pg_observed_facts,GoalCountPairs),
+    cputime(StartExpl),
+    global_set($pg_num_goals,0),
+    $pp_find_explanations(AllGoals),!,
+    $pp_print_num_goals(MsgS),
+    cputime(EndExpl),
+% vsc    statistics(table,[TableSpace,_]),
+TableSpace = 0, % not supported in YAP (it should be).
+    ( MsgM == 0 -> true
+    ; format("Exporting switch information to the EM routine ... ",[])
+    ),
+    flush_output,
+    $pp_collect_init_switches(Sws),
+    $pc_export_sw_info(Sws),
+    ( MsgM == 0 -> true ; format("done~n",[]) ),
+    $pp_observed_facts(GoalCountPairs,GidCountPairs,
+                       0,Len,0,NGoals,-1,FailRootIndex),
+    $pc_prism_prepare(GidCountPairs,Len,NGoals,FailRootIndex),
+    cputime(StartEM),
+    $pp_em(Mode,Output),
+    cputime(EndEM),
+    $pc_import_occ_switches(NewSws,NSwitches,NSwVals),
+    $pp_decode_update_switches(Mode,NewSws),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_delete_tmp_out,
+    cputime(End),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_learn_stats(Mode,Output,NSwitches,NSwVals,TableSpace,
+                           Start,End,StartExpl,EndExpl,StartEM,EndEM,1000),
+    ( MsgT == 0 -> true ; $pp_print_learn_stats_message ),
+    ( MsgM == 0 -> true ; $pp_print_learn_end_message(Mode) ),!.
+
+%% Clears the dummy-goal table and the statistics left by a previous run,
+%% then (re)initializes the tables.
+$pp_learn_clean_info :-
+    $pp_clean_dummy_goal_table,
+    $pp_clean_graph_stats,
+    $pp_clean_learn_stats,
+    $pp_init_tables_aux,
+    $pp_init_tables_if_necessary,!.
+
+%% Calls set_sw_all_a(_) when hyperparameters are being learned (Mode is
+%% not params) and the reset_hparams flag is on; otherwise does nothing.
+$pp_learn_reset_hparams(Mode) :-
+    ( Mode == params -> true
+    ; get_prism_flag(reset_hparams,on) -> set_sw_all_a(_)
+    ; true
+    ).
+
+%% Prints the value of the global $pg_num_goals unless MsgS is 0.
+$pp_print_num_goals(MsgS) :-
+    ( MsgS == 0 -> true
+    ; global_get($pg_num_goals,N),format("(~w)~n",[N]),flush_output
+    ).
+
+%% $pp_em(+Mode,-Output)
+%% Runs the C-side learning routine for Mode and packs its results into
+%% a list: six values for params, two (iterations, free energy) for
+%% hparams and both.
+$pp_em(params,Output) :-
+    $pc_prism_em(Iterate,LogPost,LogLike,BIC,CS,ModeSmooth),
+    Output = [Iterate,LogPost,LogLike,BIC,CS,ModeSmooth].
+$pp_em(hparams,Output) :-
+    $pc_prism_vbem(IterateVB,FreeEnergy),
+    Output = [IterateVB,FreeEnergy].
+$pp_em(both,Output) :-
+    $pc_prism_both_em(IterateVB,FreeEnergy),
+    Output = [IterateVB,FreeEnergy].
+
+%% Stores the explanation-graph statistics as $ps_* facts; the total node
+%% count is the sum of goal nodes and switch nodes.
+$pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared) :-
+    NNodes is NGoalNodes + NSwNodes,
+    assertz($ps_num_subgraphs(NSubgraphs)),
+    assertz($ps_num_nodes(NNodes)),
+    assertz($ps_num_goal_nodes(NGoalNodes)),
+    assertz($ps_num_switch_nodes(NSwNodes)),
+    assertz($ps_avg_shared(AvgShared)),!.
+
+%% Stores the learning statistics as $ps_* facts.  The three time spans
+%% (total, explanation search, EM) are divided by UnitsPerSec; the table
+%% space is recorded only when it is an integer.
+$pp_assert_learn_stats(Mode,Output,NSwitches,NSwVals,TableSpace,
+                       Start,End,StartExpl,EndExpl,StartEM,EndEM,UnitsPerSec) :-
+    assertz($ps_num_switches(NSwitches)),
+    assertz($ps_num_switch_values(NSwVals)),
+    ( integer(TableSpace) -> assertz($ps_learn_table_space(TableSpace)) ; true ),
+    Time is (End - Start) / UnitsPerSec,
+    assertz($ps_learn_time(Time)),
+    TimeExpl is (EndExpl - StartExpl) / UnitsPerSec,
+    assertz($ps_learn_search_time(TimeExpl)),
+    TimeEM is (EndEM - StartEM) / UnitsPerSec,
+    assertz($ps_em_time(TimeEM)),
+    $pp_assert_learn_stats_sub(Mode,Output),!.
+
+%% Mode-specific part: unpacks the Output list built by $pp_em/2.
+%% For params, the log-posterior and the CS score are recorded only when
+%% ModeSmooth > 0.
+$pp_assert_learn_stats_sub(params,Output) :-
+    Output = [Iterate,LogPost,LogLike,BIC,CS,ModeSmooth],
+    assertz($ps_num_iterations(Iterate)),
+    ( ModeSmooth > 0 -> assertz($ps_log_post(LogPost)) ; true ),
+    assertz($ps_log_likelihood(LogLike)),
+    assertz($ps_bic_score(BIC)),
+    ( ModeSmooth > 0 -> assertz($ps_cs_score(CS)) ; true ),!.
+
+$pp_assert_learn_stats_sub(hparams,Output) :-
+    Output = [IterateVB,FreeEnergy],
+    assertz($ps_num_iterations_vb(IterateVB)),
+    assertz($ps_free_energy(FreeEnergy)),!.
+
+$pp_assert_learn_stats_sub(both,Output) :-
+    Output = [IterateVB,FreeEnergy],
+    assertz($ps_num_iterations_vb(IterateVB)),
+    assertz($ps_free_energy(FreeEnergy)),!.
+
+%% Prints every recorded learning statistic.  The failure-driven loop
+%% visits each alternative of $pp_print_learn_stats_message_sub/0 in turn.
+$pp_print_learn_stats_message :-
+    format("Statistics on learning:~n",[]),
+    ( $pp_print_learn_stats_message_sub,fail ; true ),!.
+
+%% One alternative per statistic; an alternative prints nothing when the
+%% corresponding $ps_* fact is absent.  The log-likelihood is printed
+%% only when no log-posterior has been recorded.
+$pp_print_learn_stats_message_sub :-
+    ( $ps_num_nodes(L),
+        format("~tGraph size: ~w~n",[L])
+    ; $ps_num_switches(L),
+        format("~tNumber of switches: ~w~n",[L])
+    ; $ps_num_switch_values(L),
+        format("~tNumber of switch instances: ~w~n",[L])
+    ; $ps_num_iterations_vb(L),
+        format("~tNumber of iterations: ~w~n",[L])
+    ; $ps_num_iterations(L),
+        format("~tNumber of iterations: ~w~n",[L])
+    ; $ps_free_energy(L),
+        format("~tFinal variational free energy: ~9f~n",[L])
+    ; $ps_log_post(L),
+        format("~tFinal log of a posteriori prob: ~9f~n",[L])
+    ; $ps_log_likelihood(L), \+ $ps_log_post(_),
+        format("~tFinal log likelihood: ~9f~n",[L])
+    ; $ps_learn_time(L),
+        format("~tTotal learning time: ~3f seconds~n",[L])
+    ; $ps_learn_search_time(L),
+        format("~tExplanation search time: ~3f seconds~n",[L])
+    ; $ps_learn_table_space(L),
+        format("~tTotal table space used: ~w bytes~n",[L])
+    ).
+
+%% Prints a hint naming the show_sw* predicates relevant to Mode.
+%% Fails for any mode other than params, hparams or both.
+$pp_print_learn_end_message(Mode) :-
+    ( Mode == params ->
+        format("Type show_sw to show the probability distributions.~n",[])
+    ; Mode == hparams ->
+        format("Type show_sw_a/show_sw_d to show the probability distributions.~n",[])
+    ; Mode == both ->
+        format("Type show_sw_pa/show_sw_pd to show the probability distributions.~n",[])
+    ).
+
+%% Removes the graph statistics asserted by $pp_assert_graph_stats/4.
+$pp_clean_graph_stats :-
+    retractall($ps_num_subgraphs(_)),
+    retractall($ps_num_nodes(_)),
+    retractall($ps_num_goal_nodes(_)),
+    retractall($ps_num_switch_nodes(_)),
+    retractall($ps_avg_shared(_)),!.
+
+%% Removes the learning statistics asserted by $pp_assert_learn_stats/12.
+$pp_clean_learn_stats :-
+    retractall($ps_log_likelihood(_)),
+    retractall($ps_log_post(_)),
+    retractall($ps_num_switches(_)),
+    retractall($ps_num_switch_values(_)),
+    retractall($ps_num_iterations(_)),
+    retractall($ps_num_iterations_vb(_)),
+    retractall($ps_bic_score(_)),
+    retractall($ps_cs_score(_)),
+    retractall($ps_free_energy(_)),
+    retractall($ps_learn_time(_)),
+    retractall($ps_learn_search_time(_)),
+    retractall($ps_em_time(_)),
+    retractall($ps_learn_table_space(_)),!.
+
+%% $pp_collect_init_switches(-Sws)
+%% Builds, for every switch id 0..N-1 registered on the C side, a term
+%% sw(Sid,InstanceIds,Probs,Deltas,FixedP,FixedH) carrying the current
+%% parameters and hyperparameters; FixedP/FixedH are 1 when the switch is
+%% marked as fixed, 0 otherwise.
+$pp_collect_init_switches(Sws) :-
+    $pc_prism_sw_count(N),
+    $pp_collect_init_switches(0,N,Sws).
+
+$pp_collect_init_switches(Sid,N,SwInsList) :- Sid >= N,!,
+    SwInsList = [].
+$pp_collect_init_switches(Sid,N,SwInsList) :-
+    $pc_prism_sw_term(Sid,Sw),
+    SwInsList = [sw(Sid,Instances,Pbs,Deltas,FixedP,FixedH)|SwInsList1],
+    $pp_get_parameters(Sw,Values,Pbs),!,
+    $pp_get_hyperparameters(Sw,Values,_,Deltas),!,
+    ( $pd_fixed_parameters(Sw) -> FixedP = 1 ; FixedP = 0 ),
+    ( $pd_fixed_hyperparameters(Sw) -> FixedH = 1 ; FixedH = 0 ),
+    $pp_collect_sw_ins_ids(Sw,Values,Instances),
+    Sid1 is Sid + 1,!,
+    $pp_collect_init_switches(Sid1,N,SwInsList1).
+
+%% Maps each value V of switch Sw to the id of the instance msw(Sw,V).
+$pp_collect_sw_ins_ids(_Sw,[],[]).
+$pp_collect_sw_ins_ids(Sw,[V|Vs],[I|Is]) :-
+    $pc_prism_sw_ins_id_get(msw(Sw,V),I),!,
+    $pp_collect_sw_ins_ids(Sw,Vs,Is).
+
+%% Dispatches on the learning mode to the matching decoder.
+$pp_decode_update_switches(params,Sws) :-
+    $pp_decode_update_switches_p(Sws).
+$pp_decode_update_switches(hparams,Sws) :-
+    $pp_decode_update_switches_h(Sws).
+$pp_decode_update_switches(both,Sws) :-
+    $pp_decode_update_switches_b(Sws).
+
+$pp_decode_update_switches_p([]).
+%% params mode: replaces the stored parameters and expectations of each
+%% switch with the learned values; any stored hyperexpectations are
+%% dropped.
+$pp_decode_update_switches_p([sw(_,SwInstances)|Sws]) :-
+    $pp_decode_switch_name(SwInstances,Sw),
+    $pp_decode_switch_instances(SwInstances,Updates),
+    get_values1(Sw,Values),
+    $pp_separate_updates(Values,Probs,_Deltas,Es,Updates),
+    ( retract($pd_parameters(Sw,_,_)) -> true ; true ),
+    assert($pd_parameters(Sw,Values,Probs)),
+    ( retract($pd_expectations(Sw,_,_)) -> true ; true),
+    ( retract($pd_hyperexpectations(Sw,_,_)) -> true ; true),
+    assert($pd_expectations(Sw,Values,Es)),!,
+    $pp_decode_update_switches_p(Sws).
+
+%% hparams mode: replaces the stored hyperparameters (alphas derived from
+%% the learned deltas) and stores the expectations as hyperexpectations;
+%% any stored plain expectations are dropped.
+$pp_decode_update_switches_h([]).
+$pp_decode_update_switches_h([sw(_,SwInstances)|Sws]) :-
+    $pp_decode_switch_name(SwInstances,Sw),
+    $pp_decode_switch_instances(SwInstances,Updates),
+    get_values1(Sw,Values),
+    $pp_separate_updates(Values,_Probs,Deltas,Es,Updates),
+    ( retract($pd_hyperparameters(Sw,_,_,_)) -> true ; true ),
+    $pp_delta_to_alpha(Deltas,Alphas),
+    assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)),
+    ( retract($pd_expectations(Sw,_,_)) -> true ; true),
+    ( retract($pd_hyperexpectations(Sw,_,_)) -> true ; true),
+    assert($pd_hyperexpectations(Sw,Values,Es)),!,
+    $pp_decode_update_switches_h(Sws).
+
+%% both mode: updates parameters and hyperparameters, and stores the
+%% expectations as hyperexpectations.
+$pp_decode_update_switches_b([]).
+$pp_decode_update_switches_b([sw(_,SwInstances)|Sws]) :-
+    $pp_decode_switch_name(SwInstances,Sw),
+    $pp_decode_switch_instances(SwInstances,Updates),
+    get_values1(Sw,Values),
+    $pp_separate_updates(Values,Probs,Deltas,Es,Updates),
+    ( retract($pd_parameters(Sw,_,_)) -> true ; true ),
+    assert($pd_parameters(Sw,Values,Probs)),
+    ( retract($pd_hyperparameters(Sw,_,_,_)) -> true ; true ),
+    $pp_delta_to_alpha(Deltas,Alphas),
+    assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)),
+    ( retract($pd_hyperexpectations(Sw,_,_)) -> true ; true),
+    ( retract($pd_expectations(Sw,_,_)) -> true ; true),
+    assert($pd_hyperexpectations(Sw,Values,Es)),!,
+    $pp_decode_update_switches_b(Sws).
+
+%% Recovers the switch name from the id of the first instance.
+$pp_decode_switch_name([sw_ins(Sid,_,_,_)|_SwInstances],Sw) :-
+    $pc_prism_sw_ins_term(Sid,msw(Sw,_)).  % only uses the first element
+
+%% Turns sw_ins(Id,Prob,Delta,Expect) terms into (Value,Prob,Delta,Expect)
+%% tuples by looking up the value that belongs to each instance id.
+$pp_decode_switch_instances([],[]).
+$pp_decode_switch_instances([sw_ins(Sid,Prob,Delta,Expect)|SwInstances],
+                            [(V,Prob,Delta,Expect)|Updates]) :-
+    $pc_prism_sw_ins_term(Sid,msw(_,V)),!,
+    $pp_decode_switch_instances(SwInstances,Updates).
+
+%% Reorders the update tuples to follow the order of Values and splits
+%% them into three parallel lists.  Fails if some value has no tuple.
+$pp_separate_updates([],[],[],[],_Updates).
+$pp_separate_updates([V|Vs],[Prob|Probs],[Delta|Deltas],[E|Es],Updates) :-
+    member((V,Prob,Delta,E),Updates),!,
+    $pp_separate_updates(Vs,Probs,Deltas,Es,Updates).
+
+%% [NOTE] Non-ground goals have already been replaced by dummy goals, so
+%%        all goals are ground here.
+
+%% $pp_observed_facts(+GoalCountPairs,-GidCountPairs,
+%%                    +Len0,-Len,+NGoals0,-NGoals,+FailIdx0,-FailIdx)
+%% Replaces each goal by its goal id.  Len counts the pairs, NGoals sums
+%% the counts of all goals other than `failure', and FailIdx is the
+%% position of the `failure' goal in the list (unchanged if absent).
+$pp_observed_facts([],[],Len,Len,NGoals,NGoals,FailRootIndex,FailRootIndex).
+$pp_observed_facts([goal(Goal,Count)|GoalCountPairs],GidCountPairs,
+                   Len0,Len,NGoals0,NGoals,FailRootIndex0,FailRootIndex) :-
+    % fails if the goal is ground but has no proof
+    ( $pc_prism_goal_id_get(Goal,Gid) ->
+        ( Goal == failure ->
+            NGoals1 = NGoals0,
+            FailRootIndex1 = Len0
+        ; NGoals1 is NGoals0 + Count,
+            FailRootIndex1 = FailRootIndex0
+        ),
+        GidCountPairs = [goal(Gid,Count)|GidCountPairs1],
+        Len1 is Len0 + 1
+    ; $pp_raise_unexpected_failure($pp_observed_facts/8)
+    ),!,
+    $pp_observed_facts(GoalCountPairs,GidCountPairs1,
+                       Len1,Len,NGoals1,NGoals,FailRootIndex1,FailRootIndex).
+
+%% Assumption: for any pair of terms F and F' (a variant of F), the hash
+%% codes of F and F' are equal.
+%%
+%% For convenience in implementing parallel learning, $pp_trans_goals/3
+%% is (internally) split into two predicates, $pp_build_count_pairs/2 and
+%% $pp_trans_count_pairs/3.
+%%
+%% The order of goal-count pairs may differ from run to run because of
+%% the way hashtables are implemented.
+
+$pp_trans_goals(Goals,GoalCountPairs,AllGoals) :-
+    $pp_build_count_pairs(Goals,Pairs),
+    $pp_trans_count_pairs(Pairs,GoalCountPairs,AllGoals).
+
+%% Counts the goals in a hashtable and returns the sorted list of its
+%% entries.
+$pp_build_count_pairs(Goals,Pairs) :-
+    new_hashtable(Table),
+    $pp_count_goals(Goals,Table),
+    hashtable_to_list(Table,Pairs0),
+    sort(Pairs0,Pairs).
+
+$pp_count_goals([],_).
+%% Adds one observed goal to the count table.  Accepts a bare goal
+%% (count 1) or goal(G,C) / count(G,C) / (C times G).  Non-ground goals
+%% are copied before being used as hashtable keys.
+$pp_count_goals([G0|Goals],Table) :-
+    ( G0 = goal(Goal,Count)   -> true
+    ; G0 = count(Goal,Count)  -> true
+    ; G0 = (Count times Goal) -> true
+    ; Goal = G0, Count = 1
+    ),
+    ( ground(Goal) -> GoalCp = Goal
+    ; copy_term(Goal,GoalCp)
+    ),
+    ( $pp_hashtable_get(Table,GoalCp,Count0) ->
+        Count1 is Count0 + Count,
+        $pp_hashtable_put(Table,GoalCp,Count1)
+    ; $pp_hashtable_put(Table,GoalCp,Count)
+    ),!,
+    $pp_count_goals(Goals,Table).
+
+%% Converts Goal=Count entries into goal(DummyGoal,Count) terms and
+%% collects the (dummy) goals to be explained.
+$pp_trans_count_pairs([],[],[]).
+$pp_trans_count_pairs([Goal=Count|Pairs],GoalCountPairs,AllGoals) :-
+    $pp_build_dummy_goal(Goal,DummyGoal),
+    GoalCountPairs = [goal(DummyGoal,Count)|GoalCountPairs1],
+    AllGoals = [DummyGoal|AllGoals1],!,
+    $pp_trans_count_pairs(Pairs,GoalCountPairs1,AllGoals1).
+
+%% $pp_build_dummy_goal(+Goal,-DummyGoal)
+%% Three cases:
+%%   msw(I,V)     - a fresh dummy goal is created, recorded in the dummy
+%%                  goal table, and defined by a tabled clause that
+%%                  explains the switch instance directly;
+%%   ground goal  - used as is (no wrapper);
+%%   otherwise    - a fresh dummy goal is created whose tabled clause
+%%                  runs the translated goal (or, when no translation
+%%                  exists, the explanation interpreter) and links the
+%%                  dummy goal to the original one in the graph.
+%% The generated clause is loaded with consult_preds/2.
+$pp_build_dummy_goal(Goal,DummyGoal) :-
+    ( Goal = msw(I,V) ->
+        ( ground(I) -> I = ICp ; copy_term(I,ICp) ),
+        ( ground(V) -> V = VCp ; copy_term(V,VCp) ),
+        $pp_create_dummy_goal(DummyGoal),
+        $pp_assert_dummy_goal(DummyGoal,Goal),
+        Clause = (DummyGoal :- $prism_expl_msw(ICp,VCp,Sid),
+                               $pc_prism_goal_id_register(DummyGoal,Hid),
+                               $prism_eg_path(Hid,[],[Sid])),
+        Prog = [pred(DummyGoal,0,_,_,tabled(_,_,_,_),[Clause]),
+                pred($damon_load,0,_,_,_,[($damon_load:-true)])],
+        consult_preds([],Prog)
+    ; ground(Goal) ->
+        DummyGoal = Goal      % don't create dummy goals (wrappers) for
+    ;                         % ground goals to save memory.
+        $pp_create_dummy_goal(DummyGoal),
+        $pp_assert_dummy_goal(DummyGoal,Goal),
+        ( $pp_trans_one_goal(Goal,CompGoal) -> BodyGoal = CompGoal
+        ; BodyGoal = (savecp(CP),Depth=0,
+                      $pp_expl_interp_goal(Goal,Depth,CP,[],_,[],_,[],_,[],_))
+        ),
+        Clause = (DummyGoal:-BodyGoal,
+                             $pc_prism_goal_id_register(Goal,GId),
+                             $pc_prism_goal_id_register(DummyGoal,HId),
+                             $prism_eg_path(HId,[GId],[])),
+        Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),[Clause]),
+                pred($damon_load,0,_,_,_,[($damon_load:-true)])],
+        consult_preds([],Prog)
+    ),!.
+
+%% Records the original goal that a dummy goal stands for.
+$pp_assert_dummy_goal(DummyGoal,OrigGoal) :-
+    assertz($pd_dummy_goal_table(DummyGoal,OrigGoal)),!.
+
+$pp_clean_dummy_goal_table :-
+    retractall($pd_dummy_goal_table(_,_)).
+
+%%----------------------------------------
+
+% just make a simple check
+$pp_require_observed_data(Gs,MsgID,Source) :-
+    ( $pp_test_observed_data(Gs) -> true
+    ; $pp_raise_on_require([Gs],MsgID,Source,$pp_error_observed_data)
+    ).
+
+%% Succeeds when Gs is a non-empty list other than exactly [failure].
+$pp_test_observed_data(Gs) :-
+    nonvar(Gs),
+    ( Gs = [failure] -> fail
+    ; Gs = [_|_]
+    ).
+
+%% Chooses the error term for rejected data: whatever $pp_error_nonvar/2
+%% reports, or a domain error for [failure] and for non-list terms.
+$pp_error_observed_data(Gs,Error) :-
+    $pp_error_nonvar(Gs,Error), !.
+$pp_error_observed_data(Gs,domain_error(observed_data,Gs)) :-
+    ( Gs = [failure] ; Gs \= [_|_] ), !.
+
diff --git a/packages/prism/src/prolog/up/list.pl b/packages/prism/src/prolog/up/list.pl
new file mode 100644
index 000000000..f7683335a
--- /dev/null
+++ b/packages/prism/src/prolog/up/list.pl
@@ -0,0 +1,882 @@
+%%--------------------------------
+%%  Temporary Clauses
+
+%% The higher-order list predicates in this file compile their goal
+%% argument into a temporary clause $pd_temp_clause(ID,...) :- Body,
+%% where ID comes from the global counter $pg_temp_clause_num.
+
+:- dynamic $pd_temp_clause/2.
+:- dynamic $pd_temp_clause/3.
+:- dynamic $pd_temp_clause/4.
+
+:- global_set($pg_temp_clause_num,0).
+
+$pp_create_temp_clause_1(ID,X,Body) :-
+    $pp_create_temp_clause_num(ID),
+    assert(($pd_temp_clause(ID,X) :- Body)), !.
+
+$pp_create_temp_clause_2(ID,X,Y,Body) :-
+    $pp_create_temp_clause_num(ID),
+    assert(($pd_temp_clause(ID,X,Y) :- Body)), !.
+
+$pp_create_temp_clause_3(ID,X,Y,Z,Body) :-
+    $pp_create_temp_clause_num(ID),
+    assert(($pd_temp_clause(ID,X,Y,Z) :- Body)), !.
+
+$pp_delete_temp_clause_1(ID) :-
+    retractall($pd_temp_clause(ID,_)),
+    $pp_delete_temp_clause_num(ID), !.
+
+$pp_delete_temp_clause_2(ID) :-
+    retractall($pd_temp_clause(ID,_,_)),
+    $pp_delete_temp_clause_num(ID), !.
+
+$pp_delete_temp_clause_3(ID) :-
+    retractall($pd_temp_clause(ID,_,_,_)),
+    $pp_delete_temp_clause_num(ID), !.
+
+%% Allocates the next clause id by incrementing the global counter.
+$pp_create_temp_clause_num(N) :-
+    global_get($pg_temp_clause_num,M),
+    N is M + 1,
+    global_set($pg_temp_clause_num,N), !.
+
+%% Releases a clause id: the counter is decremented only when N is the
+%% most recently allocated id; otherwise nothing happens.
+$pp_delete_temp_clause_num(N) :-
+    global_get($pg_temp_clause_num,N),
+    M is N - 1,
+    global_set($pg_temp_clause_num,M), !.
+$pp_delete_temp_clause_num(_).
+
+
+%%--------------------------------
+%%  Base Predicates
+
+%% $pp_length(+List,-N): length of a list, using matching clauses.
+$pp_length(Xs,N) :-
+    $pp_length(Xs,0,N).
+
+$pp_length(Xs0,N0,N), Xs0 = [] =>
+    N0 = N.
+$pp_length(Xs0,N0,N), Xs0 = [_|Xs1] =>
+    N1 is N0 + 1,
+    $pp_length(Xs1,N1,N).
+
+%% $pp_match(+Patt,+X): tests Patt against X without leaving bindings
+%% (double negation); the variables of X are numbered first.
+$pp_match(Patt,X) :-
+    \+ \+ ( number_vars(X,0,_), Patt ?= X ).
+
+%% Like copy_term/2, but skips the copy for ground terms.
+$pp_copy_term(X0,X) :-
+    ground(X0) -> X0 = X ; copy_term(X0,X).
+
+%% Increments the counter stored under Key (starting from 1) and returns
+%% the new value in N.
+$pp_count(Table,Key,N) :-
+    ( $pp_hashtable_get(Table,Key,N0) -> N is N0 + 1 ; N is 1 ),
+    $pp_hashtable_put(Table,Key,N).
+
+%%--------------------------------
+%%  Stat: Means
+
+avglist(List,Mean) :-
+    $pp_meanlist(List,_,Mean,avglist/2).
+
+meanlist(List,Mean) :-
+    $pp_meanlist(List,_,Mean,meanlist/2).
+
+gmeanlist(List,Mean) :-
+    $pp_gmeanlist(List,_,Mean,gmeanlist/2).
+
+hmeanlist(List,Mean) :-
+    $pp_hmeanlist(List,_,Mean,hmeanlist/2).
+
+%% $pp_meanlist(+List,-N,-Mean,+Source)
+%% Arithmetic mean of a non-empty list of numbers; N is the list length.
+%% Source is the predicate indicator used in error terms.
+$pp_meanlist(List,N,M,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    ( $pp_meanlist(List,0,N0,0,M0) ->
+        N0 = N,
+        M0 = M
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+%% Running mean: M1 = M0 + (X - M0) / N1.
+$pp_meanlist(Xs,N0,N,M0,M), Xs = [] =>
+    N0 = N,
+    M0 = M.
+$pp_meanlist(Xs,N0,N,M0,M), Xs = [X|Xs1] =>
+    N1 is N0 + 1,
+    M1 is M0 + (X - M0) / N1,
+    $pp_meanlist(Xs1,N1,N,M1,M).
+
+%% Geometric mean, computed as exp of the running mean of log(X).
+$pp_gmeanlist(List,N,M,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    ( $pp_gmeanlist(List,0,N0,0,M0) ->
+        N0 = N,
+        M0 = M
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+$pp_gmeanlist(Xs,N0,N,M0,M), Xs = [] =>
+    N = N0, M is exp(M0).
+$pp_gmeanlist(Xs,N0,N,M0,M), Xs = [X|Xs1] =>
+    N1 is N0 + 1,
+    M1 is M0 + (log(X) - M0) / N1,
+    $pp_gmeanlist(Xs1,N1,N,M1,M).
+
+%% Harmonic mean, computed as the reciprocal of the running mean of 1/X.
+$pp_hmeanlist(List,N,M,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    ( $pp_hmeanlist(List,0,N0,0,M0) ->
+        N0 = N, M0 = M
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+$pp_hmeanlist(Xs,N0,N,M0,M), Xs = [] =>
+    N = N0, M is 1 / M0.
+$pp_hmeanlist(Xs,N0,N,M0,M), Xs = [X|Xs1] =>
+    N1 is N0 + 1,
+    M1 is M0 + (1 / X - M0) / N1,
+    $pp_hmeanlist(Xs1,N1,N,M1,M).
+
+
+%%--------------------------------
+%%  Stat: Variance etc.
+
+%% In the predicates below, M2, M3 and M4 are the sums of the 2nd, 3rd
+%% and 4th powers of the deviations from the mean.  The *p versions are
+%% the population statistics; the others are the sample statistics and
+%% require more elements (second argument of $pp_moment*).
+
+varlistp(List,Var) :-
+    $pp_moment2(List,1,N,_,M2,varlistp/2),
+    Var is M2 / N.
+
+varlist(List,Var) :-
+    $pp_moment2(List,2,N,_,M2,varlist/2),
+    Var is M2 / (N - 1).
+
+stdlistp(List,Std) :-
+    $pp_moment2(List,1,N,_,M2,stdlistp/2),
+    Std is sqrt(M2 / N).
+
+stdlist(List,Std) :-
+    $pp_moment2(List,2,N,_,M2,stdlist/2),
+    Std is sqrt(M2 / (N - 1)).
+
+semlistp(List,Sem) :-
+    $pp_moment2(List,1,N,_,M2,semlistp/2),
+    Sem is sqrt(M2) / N.
+
+semlist(List,Sem) :-
+    $pp_moment2(List,2,N,_,M2,semlist/2),
+    Sem is sqrt(M2 / (N - 1) / N).
+
+skewlistp(List,Skew) :-
+    $pp_moment3(List,1,N,_,M2,M3,skewlistp/2),
+    $pp_compute_skew0(Skew,N,M2,M3).
+
+skewlist(List,Skew) :-
+    $pp_moment3(List,3,N,_,M2,M3,skewlist/2),
+    $pp_compute_skew1(Skew,N,M2,M3).
+
+kurtlistp(List,Kurt) :-
+    $pp_moment4(List,1,N,_,M2,_,M4,kurtlistp/2),
+    $pp_compute_kurt0(Kurt,N,M2,M4).
+
+kurtlist(List,Kurt) :-
+    $pp_moment4(List,4,N,_,M2,_,M4,kurtlist/2),
+    $pp_compute_kurt1(Kurt,N,M2,M4).
+
+%% $pp_moment2(+List,+MinN,-N,-M,-M2,+Source)
+%% Validates List, computes its length N, mean M and M2, and re-runs the
+%% length check with MinN when the list is shorter than MinN (which
+%% presumably raises the error; the checker is defined elsewhere).
+$pp_moment2(List,MinN,N,M,M2,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    $pp_moment2(List,0,N0,0,TmpM,0,TmpM2),
+    ( N0 >= MinN -> true
+    ; $pp_require_list_not_shorter_than(List,MinN,$msg(2103),Source)
+    ),
+    N0 = N, TmpM = M, TmpM2 = M2.
+
+%% Same as $pp_moment2/6, additionally returning M3.
+$pp_moment3(List,MinN,N,M,M2,M3,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    $pp_moment3(List,0,N0,0,TmpM,0,TmpM2,0,TmpM3),
+    ( N0 >= MinN -> true
+    ; $pp_require_list_not_shorter_than(List,MinN,$msg(2103),Source)
+    ),
+    N0 = N, TmpM = M, TmpM2 = M2, TmpM3 = M3.
+
+%% Same as $pp_moment2/6, additionally returning M3 and M4.
+$pp_moment4(List,MinN,N,M,M2,M3,M4,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    $pp_moment4(List,0,N0,0,TmpM,0,TmpM2,0,TmpM3,0,TmpM4),
+    ( N0 >= MinN -> true
+    ; $pp_require_list_not_shorter_than(List,MinN,$msg(2103),Source)
+    ),
+    N0 = N, TmpM = M, TmpM2 = M2, TmpM3 = M3, TmpM4 = M4.
+
+%% One-pass (incremental) update of the count, the mean and M2.
+$pp_moment2(Xs,TmpN,N,TmpM,M,TmpM2,M2), Xs = [] =>
+    TmpN = N,
+    TmpM = M,
+    TmpM2 = M2.
+$pp_moment2(Xs,OldN,N,OldM,M,OldM2,M2), Xs = [X|Xs1] =>
+    NewN is OldN + 1,
+    D is X - OldM,
+    E is D / NewN,
+    F is D * E * OldN,                  % == (X - OldM) * (X - NewM)
+    NewM is OldM + E,
+    NewM2 is OldM2 + F,
+    $pp_moment2(Xs1,NewN,N,NewM,M,NewM2,M2).
+
+%% One-pass update of the count, the mean, M2 and M3.
+$pp_moment3(Xs,TmpN,N,TmpM,M,TmpM2,M2,TmpM3,M3), Xs = [] =>
+    TmpN = N,
+    TmpM = M,
+    TmpM2 = M2,
+    TmpM3 = M3.
+$pp_moment3(Xs,OldN,N,OldM,M,OldM2,M2,OldM3,M3), Xs = [X|Xs1] =>
+    NewN is OldN + 1,
+    D is X - OldM,
+    E is D / NewN,
+    F is D * E * OldN,                  % == (X - OldM) * (X - NewM)
+    NewM is OldM + E,
+    NewM2 is OldM2 + F,
+    NewM3 is OldM3 + E * (F * (NewN - 2) - 3 * OldM2),
+    $pp_moment3(Xs1,NewN,N,NewM,M,NewM2,M2,NewM3,M3).
+
+%% One-pass update of the count, the mean, M2, M3 and M4.
+$pp_moment4(Xs,TmpN,N,TmpM,M,TmpM2,M2,TmpM3,M3,TmpM4,M4), Xs = [] =>
+    TmpN = N,
+    TmpM = M,
+    TmpM2 = M2,
+    TmpM3 = M3,
+    TmpM4 = M4.
+%% Update step for the 4th-moment accumulation.  The M4 increment uses
+%% both the old and the new M3.
+$pp_moment4(Xs,OldN,N,OldM,M,OldM2,M2,OldM3,M3,OldM4,M4), Xs = [X|Xs1] =>
+    NewN is OldN + 1,
+    D is X - OldM,
+    E is D / NewN,
+    F is D * E * OldN,                  % == (X - OldM) * (X - NewM)
+    NewM is OldM + E,
+    NewM2 is OldM2 + F,
+    NewM3 is OldM3 + E * (F * (NewN - 2) - 3 * OldM2),
+    NewM4 is OldM4 + E * (E * F * (NewN ** 2 - (NewN + 1)) - 2 * (OldM3 + NewM3)),
+    $pp_moment4(Xs1,NewN,N,NewM,M,NewM2,M2,NewM3,M3,NewM4,M4).
+
+%% Population skewness from N, M2 and M3.
+$pp_compute_skew0(Skew,N,M2,M3) :-
+    Skew is M3 / M2 * sqrt(N / M2).
+
+%% Sample skewness (population value with the small-sample correction).
+$pp_compute_skew1(Skew,N,M2,M3) :-
+    Skew is M3 / M2 * sqrt((N - 1) / M2) * N / (N - 2).
+
+%% Population excess kurtosis from N, M2 and M4.
+$pp_compute_kurt0(Kurt,N,M2,M4) :-
+    Kurt is M4 / (M2 * M2) * N - 3.
+
+%% Sample excess kurtosis.
+$pp_compute_kurt1(Kurt,N,M2,M4) :-
+    F is M4 / (M2 * M2) * N * (N + 1),
+    G is 3 * (N - 1),
+    H is (N - 1) / (float(N - 2) * (N - 3)),  % float(*) avoids overflow
+    Kurt is (F - G) * H.
+
+
+%%--------------------------------
+%%  Stat: Mode
+
+%% modelist/2  - a single mode (ties are resolved by standard order);
+%% amodelist/2 - the sorted list of all modes;
+%% rmodelist/2 - one of the modes, picked at random;
+%% pmodelist/2 - a random element of the list.
+
+modelist(List,Mode) :-
+    $pp_modelist(List,Mode,modelist/2).
+
+amodelist(List,Modes) :-
+    $pp_amodelist(List,Modes,amodelist/2).
+
+rmodelist(List,Mode) :-
+    $pp_amodelist(List,Modes,rmodelist/2),
+    $pp_pmodelist(Modes,Mode).
+
+pmodelist(List,Mode) :-
+    $pp_pmodelist(List,Mode,pmodelist/2).
+
+%% Counts occurrences in a hashtable while tracking the best element.
+$pp_modelist(List,Mode,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_nonvars(List,$msg(2110),Source),
+    new_hashtable(Table),
+    ( $pp_modelist(List,Table,_,0,Mode0) ->
+        $pp_copy_term(Mode0,Mode)
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+%% Y0/N0 are the current best element and its count.
+$pp_modelist(Xs,_,Y,_,Mode), Xs = [] =>
+    Y = Mode.
+$pp_modelist(Xs,Table,Y0,N0,Mode), Xs = [X|Xs1] =>
+    $pp_count(Table,X,N),
+    ( $pp_modelist_cmp(N0,N,Y0,X) -> Y1 = X, N1 = N ; Y1 = Y0, N1 = N0 ),
+    $pp_modelist(Xs1,Table,Y1,N1,Mode).
+
+%% Succeeds when X (seen N times) should replace X0 (seen N0 times):
+%% a higher count wins; on equal counts the smaller term in the standard
+%% order wins.
+$pp_modelist_cmp(N0,N,_,_), N0 < N => true.
+$pp_modelist_cmp(N0,N,_,_), N0 > N => fail.
+$pp_modelist_cmp(_,_,X0,X) =>
+    X0 @> X.
+
+%% $pp_amodelist(+List,-Modes,+Source)
+%% Sorted list of all modes of a non-empty list of non-variable terms.
+$pp_amodelist(List,Modes,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_nonvars(List,$msg(2110),Source),
+    new_hashtable(Table),
+    ( $pp_amodelist(List,Table,_,0,Modes0) ->
+        $pp_copy_term(Modes0,Modes1),
+        sort(Modes1,Modes)
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+%% Ys0/N0 are the elements sharing the highest count so far and that
+%% count.  A strictly higher count restarts the list; an equal count
+%% extends it.
+$pp_amodelist(Xs,_,Ys,_,Modes), Xs = [] =>
+    Ys = Modes.
+$pp_amodelist(Xs,Table,Ys0,N0,Modes), Xs = [X|Xs1] =>
+    $pp_count(Table,X,N),
+    ( N0 < N ->
+        Ys1 = [X], N1 = N
+    ; N0 > N ->
+        Ys1 = Ys0, N1 = N0
+    ; %% else
+        Ys1 = [X|Ys0], N1 = N0
+    ),
+    $pp_amodelist(Xs1,Table,Ys1,N1,Modes).
+
+%% $pp_pmodelist(+List,-Elem,+Source): validated random pick.
+$pp_pmodelist(List,Mode,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_nonvars(List,$msg(2110),Source),
+    ( $pp_pmodelist(List,Mode0) ->
+        Mode0 = Mode
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+%% Picks the element at a random 0-based index drawn by $pc_random_int/2.
+$pp_pmodelist(List,Mode) :-
+    $pp_length(List,L), $pc_random_int(L,I), nth0(I,List,Mode).
+
+
+%%--------------------------------
+%%  Stat: Median
+
+medianlist(List,Median) :-
+    $pp_medianlist(List,Median,medianlist/2).
+
+%% $pp_medianlist(+List,-Median,+Source)
+%% Median of a non-empty list of numbers.  Source is the predicate
+%% indicator used in error terms.
+$pp_medianlist(List,Median,Source) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),Source),
+    $pp_require_numbers(List,$msg(2108),Source),
+    ( $pp_medianlist(List,Median0) ->
+        Median0 = Median
+    ; throw(error(type_error(list,List),Source))
+    ).
+
+%% Sorts the list and takes the middle element, or the midpoint of the
+%% two middle elements when the length is even.
+%% The parity test must use =:=/2: `L mod 2 is 0' unifies the
+%% *unevaluated* term L mod 2 with 0 and therefore never succeeds with a
+%% standard is/2, which made even-length lists return the upper middle
+%% element instead of the average.
+$pp_medianlist(List,Median) :-
+    $pp_length(List,L),
+    N is L // 2,
+    $pp_mergesort(0,L,List,_,Temp),
+    ( L mod 2 =:= 0 ->
+        nth1(N,Temp,A),                 % lower middle element
+        nth0(N,Temp,B),                 % upper middle element
+        Median is A + (B - A) / 2       % avoids overflow
+    ; nth0(N,Temp,Median)
+    ).
+
+
+%%--------------------------------
+%%  Stat: Min/Max
+
+%% Minimum / maximum of a non-empty list of numbers, using the list form
+%% of the min/max arithmetic functions.
+minlist(List,Min) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),minlist/2),
+    $pp_require_numbers(List,$msg(2108),minlist/2),
+    Min is min(List).
+
+maxlist(List,Max) :-
+    $pp_require_list_not_shorter_than(List,1,$msg(2103),maxlist/2),
+    $pp_require_numbers(List,$msg(2108),maxlist/2),
+    Max is max(List).
+
+
+%%--------------------------------
+%%  Stat: agglist/2
+
+%% agglist(+List,+Dest)
+%% Dest is a non-empty list of Op=Var requests.  Three passes:
+%%   1. scan the requested operations to find out which shared
+%%      quantities are needed (kept in the mutable term $aggop(X,Y,N));
+%%   2. compute those quantities once;
+%%   3. bind each Var to its result.
+agglist(List,Dest) :-
+    $pp_require_list_not_shorter_than(Dest,1,$msg(2103),agglist/2),
+    Flag = $aggop(0,0,0),
+    $pp_agglist_1(Dest,Flag),
+    $pp_agglist_2(List,Flag,N,M,M2,M3,M4,Modes),
+    $pp_agglist_3(List,Dest,N,M,M2,M3,M4,Modes).
+
+%% Pass 1.  For every operation, X is the highest moment it needs, Y the
+%% kind of mode computation it needs, and N the minimum list length.
+%% The maxima over all requests are stored in Flag with setarg/3.
+$pp_agglist_1(Dest,_), Dest = [] => true.
+$pp_agglist_1(Dest,Flag), Dest = [Op=_|Dest1] =>
+    $pp_require_agglist_operation(Op,$msg(2107),agglist/2),
+    %% X = none(0)/len(1)/mean(2)/var(3)/skew(4)/kurt(5)
+    %% Y = none(0)/mode(1)/amode(2)
+    ( Op == sum    -> X = 0, Y = 0, N = 0
+    ; Op == avg    -> X = 2, Y = 0, N = 1
+    ; Op == mean   -> X = 2, Y = 0, N = 1
+    ; Op == gmean  -> X = 0, Y = 0, N = 1
+    ; Op == hmean  -> X = 0, Y = 0, N = 1
+    ; Op == varp   -> X = 3, Y = 0, N = 1
+    ; Op == var    -> X = 3, Y = 0, N = 2
+    ; Op == stdp   -> X = 3, Y = 0, N = 1
+    ; Op == std    -> X = 3, Y = 0, N = 2
+    ; Op == semp   -> X = 3, Y = 0, N = 1
+    ; Op == sem    -> X = 3, Y = 0, N = 2
+    ; Op == skewp  -> X = 4, Y = 0, N = 1
+    ; Op == skew   -> X = 4, Y = 0, N = 3
+    ; Op == kurtp  -> X = 5, Y = 0, N = 1
+    ; Op == kurt   -> X = 5, Y = 0, N = 4
+    ; Op == mode   -> X = 0, Y = 1, N = 1
+    ; Op == amode  -> X = 0, Y = 2, N = 1
+    ; Op == rmode  -> X = 0, Y = 2, N = 1
+    ; Op == pmode  -> X = 0, Y = 0, N = 1
+    ; Op == median -> X = 0, Y = 0, N = 1
+    ; Op == min    -> X = 0, Y = 0, N = 1
+    ; Op == max    -> X = 0, Y = 0, N = 1
+    ; Op == len    -> X = 1, Y = 0, N = 0
+    ),
+    Flag = $aggop(X0,Y0,N0),
+    ( X0 < X -> setarg(1,Flag,X) ; true ),
+    ( Y0 < Y -> setarg(2,Flag,Y) ; true ),
+    ( N0 < N -> setarg(3,Flag,N) ; true ), !,
+    $pp_agglist_1(Dest1,Flag).
+
+%% $pp_agglist_2(+List,+Flag,-N,-M,-M2,-M3,-M4,-Modes)
+%% Pass 2 of agglist/2: computes the shared quantities selected by
+%% Flag = $aggop(X,Y,MinN).  X selects how many moments are needed
+%% (0 none, 1 length only, 2 mean, 3..5 up to M2/M3/M4); Y selects the
+%% mode computation (0 none, 1 single mode, 2 all modes).  Quantities
+%% that are not requested stay unbound.
+%% The fallback branches raise through $pp_raise_unmatched_branches/2,
+%% the same reporter used by $pp_agglist_3/8.
+$pp_agglist_2(List,Flag,N,M,M2,M3,M4,Modes) :-
+    Flag = $aggop(X,Y,MinN),
+    ( X == 0 ->
+        true
+    ; X == 1 -> $pp_length(List,N)
+    ; X == 2 ->
+        $pp_meanlist(List,N,M,agglist/2)
+    ; X == 3 ->
+        $pp_moment2(List,MinN,N,M,M2,agglist/2)
+    ; X == 4 ->
+        $pp_moment3(List,MinN,N,M,M2,M3,agglist/2)
+    ; X == 5 ->
+        $pp_moment4(List,MinN,N,M,M2,M3,M4,agglist/2)
+    ; %% else
+        $pp_raise_unmatched_branches($pp_agglist_2/8,first_arg)
+    ),
+    ( Y == 0 ->
+        true
+    ; Y == 1 ->
+        $pp_modelist(List,Mode,agglist/2), Modes = [Mode]
+    ; Y == 2 ->
+        $pp_amodelist(List,Modes,agglist/2)
+    ; %% else
+        $pp_raise_unmatched_branches($pp_agglist_2/8,second_arg)
+    ).
+
+%% $pp_agglist_3(+List,+Dest,+N,+M,+M2,+M3,+M4,+Modes)
+%% Pass 3 of agglist/2: binds the variable of every Op=Var request,
+%% using the shared quantities where possible and the dedicated list
+%% predicates otherwise.
+%% The final catch-all clause must have arity 8; written with six
+%% arguments it defined an unrelated $pp_agglist_3/6 and was never
+%% reached.
+$pp_agglist_3(_,Dest,_,_,_,_,_,_), Dest = [] => true.
+$pp_agglist_3(List,Dest,N,M,M2,M3,M4,Mode), Dest = [Op=Y|Dest1] =>
+    ( Op == sum    -> Y is sum(List)
+    ; Op == avg    -> Y = M
+    ; Op == mean   -> Y = M
+    ; Op == gmean  -> $pp_gmeanlist(List,_,Y,agglist/2)
+    ; Op == hmean  -> $pp_hmeanlist(List,_,Y,agglist/2)
+    ; Op == varp   -> Y is M2 / N
+    ; Op == var    -> Y is M2 / (N - 1)
+    ; Op == stdp   -> Y is sqrt(M2 / N)
+    ; Op == std    -> Y is sqrt(M2 / (N - 1))
+    ; Op == semp   -> Y is sqrt(M2) / N
+    ; Op == sem    -> Y is sqrt(M2 / (N - 1) / N)
+    ; Op == skewp  -> $pp_compute_skew0(Y,N,M2,M3)
+    ; Op == skew   -> $pp_compute_skew1(Y,N,M2,M3)
+    ; Op == kurtp  -> $pp_compute_kurt0(Y,N,M2,M4)
+    ; Op == kurt   -> $pp_compute_kurt1(Y,N,M2,M4)
+    ; Op == mode   -> [Y|_] = Mode
+    ; Op == amode  -> Y = Mode
+    ; Op == rmode  -> $pp_pmodelist(Mode,Y)
+    ; Op == pmode  -> $pp_pmodelist(List,Y,agglist/2)
+    ; Op == median -> $pp_medianlist(List,Y,agglist/2)
+    ; Op == min    -> Y is min(List)
+    ; Op == max    -> Y is max(List)
+    ; Op == len    -> Y = N
+    ; $pp_raise_unmatched_branches($pp_agglist_3/8,operation)
+    ), !,
+    $pp_agglist_3(List,Dest1,N,M,M2,M3,M4,Mode).
+$pp_agglist_3(_,_,_,_,_,_,_,_) =>
+    $pp_raise_unmatched_branches($pp_agglist_3/8,list).
+
+
+%%--------------------------------
+%%  Map
+
+%% maplist(X,Clause,Xs) and its 5- and 7-argument variants:
+%% Clause is a goal over the template variables X (,Y (,Z)).  It is
+%% asserted as a temporary clause, called once per element (tuple), and
+%% the temporary clause is removed afterwards whether or not the
+%% traversal succeeded.  The final call to R makes the predicate succeed
+%% or fail accordingly.
+
+maplist(X,Clause,Xs) :-
+    $pp_create_temp_clause_1(ID,X,Clause),
+    ( $pp_maplist(ID,Xs) -> R = true ; R = fail ),
+    $pp_delete_temp_clause_1(ID), R.
+
+maplist(X,Y,Clause,Xs,Ys) :-
+    $pp_create_temp_clause_2(ID,X,Y,Clause),
+    ( $pp_maplist(ID,Xs,Ys) -> R = true ; R = fail ),
+    $pp_delete_temp_clause_2(ID), R.
+
+maplist(X,Y,Z,Clause,Xs,Ys,Zs) :-
+    $pp_create_temp_clause_3(ID,X,Y,Z,Clause),
+    ( $pp_maplist(ID,Xs,Ys,Zs) -> R = true ; R = fail ),
+    $pp_delete_temp_clause_3(ID), R.
+
+%% Workers: commit to the first solution of the temporary clause for
+%% each element.
+$pp_maplist(_,[]).
+$pp_maplist(ID,[X|Xs]) :-
+    $pd_temp_clause(ID,X), !, $pp_maplist(ID,Xs).
+
+$pp_maplist(_,[],[]).
+$pp_maplist(ID,[X|Xs],[Y|Ys]) :-
+    $pd_temp_clause(ID,X,Y), !, $pp_maplist(ID,Xs,Ys).
+
+$pp_maplist(_,[],[],[]).
+$pp_maplist(ID,[X|Xs],[Y|Ys],[Z|Zs]) :-
+    $pd_temp_clause(ID,X,Y,Z), !, $pp_maplist(ID,Xs,Ys,Zs).
+
+%% maplist_func(F,...): like maplist, but F is the name (an atom) of an
+%% existing predicate that is called with call/N.
+maplist_func(F,Xs) :-
+    $pp_require_atom(F,$msg(2100),maplist_func/2),
+    $pp_maplist_func(F,Xs).
+
+maplist_func(F,Xs,Ys) :-
+    $pp_require_atom(F,$msg(2100),maplist_func/3),
+    $pp_maplist_func(F,Xs,Ys).
+
+maplist_func(F,Xs,Ys,Zs) :-
+    $pp_require_atom(F,$msg(2100),maplist_func/4),
+    $pp_maplist_func(F,Xs,Ys,Zs).
+
+$pp_maplist_func(_,[]).
+$pp_maplist_func(F,[X|Xs]) :-
+    call(F,X), !, $pp_maplist_func(F,Xs).
+
+$pp_maplist_func(_,[],[]).
+$pp_maplist_func(F,[X|Xs],[Y|Ys]) :-
+    call(F,X,Y), !, $pp_maplist_func(F,Xs,Ys).
+
+$pp_maplist_func(_,[],[],[]).
+$pp_maplist_func(F,[X|Xs],[Y|Ys],[Z|Zs]) :-
+    call(F,X,Y,Z), !, $pp_maplist_func(F,Xs,Ys,Zs).
+
+%% maplist_math(Op,...): Op is the name of a unary (/3) or binary (/4)
+%% arithmetic function.  A single expression term is built once with
+%% functor/3 and its arguments are overwritten with setarg/3 for every
+%% element before it is evaluated with is/2.
+maplist_math(Op,Xs,Ys) :-
+    $pp_require_atom(Op,$msg(2101),maplist_math/3),
+    functor(Expr,Op,1),
+    $pp_maplist_math(Expr,Xs,Ys).
+
+maplist_math(Op,Xs,Ys,Zs) :-
+    $pp_require_atom(Op,$msg(2102),maplist_math/4),
+    functor(Expr,Op,2),
+    $pp_maplist_math(Expr,Xs,Ys,Zs).
+
+$pp_maplist_math(_,[],[]).
+$pp_maplist_math(Expr,[X|Xs],[Y|Ys]) :-
+    setarg(1,Expr,X),
+    Y is Expr,
+    $pp_maplist_math(Expr,Xs,Ys).
+
+$pp_maplist_math(_,[],[],[]).
+%% Binary case: Z is Op(X,Y) for each pair of elements.
+$pp_maplist_math(Expr,[X|Xs],[Y|Ys],[Z|Zs]) :-
+    setarg(1,Expr,X),
+    setarg(2,Expr,Y),
+    Z is Expr,
+    $pp_maplist_math(Expr,Xs,Ys,Zs).
+
+
+%%--------------------------------
+%%  Reduction
+
+%% reducelist(A,B,C,Body,Xs,Y0,Y): left fold.  Body relates the
+%% accumulator A and the element B to the new accumulator C; it is
+%% installed as a temporary clause for the duration of the fold.
+reducelist(A,B,C,Body,Xs,Y0,Y) :-
+    $pp_create_temp_clause_3(ID,A,B,C,Body),
+    ( $pp_reducelist(ID,Xs,Y0,Y) -> R = true ; R = fail ),
+    $pp_delete_temp_clause_3(ID), R.
+
+$pp_reducelist(_,[],Y,Y).
+$pp_reducelist(ID,[X|Xs],Y0,Y) :-
+    $pd_temp_clause(ID,Y0,X,Y1), !, $pp_reducelist(ID,Xs,Y1,Y).
+
+%% Left fold with a named predicate F, called as F(Acc,Elem,NewAcc).
+reducelist_func(F,Xs,Y0,Y) :-
+    $pp_require_atom(F,$msg(2100),reducelist_func/4),
+    $pp_reducelist_func(F,Xs,Y0,Y).
+
+$pp_reducelist_func(_,[],Y,Y).
+$pp_reducelist_func(F,[X|Xs],Y0,Y) :-
+    call(F,Y0,X,Y1), !, $pp_reducelist_func(F,Xs,Y1,Y).
+
+%% Left fold with a binary arithmetic function Op: Acc1 is Op(Acc0,Elem).
+reducelist_math(Op,Xs,Y0,Y) :-
+    $pp_require_atom(Op,$msg(2102),reducelist_math/4),
+    functor(Expr,Op,2),
+    $pp_reducelist_math(Expr,Xs,Y0,Y).
+
+$pp_reducelist_math(_,[],Y,Y).
+$pp_reducelist_math(Expr,[X|Xs],Y0,Y) :-
+    setarg(1,Expr,Y0),
+    setarg(2,Expr, X),
+    Y1 is Expr,
+    $pp_reducelist_math(Expr,Xs,Y1,Y).
+
+%%--------------------------------
+%%  Sublists
+
+/* vsc: not needed in YAP */
+% sublist(Sub,Lst) :-
+%     $pp_sublist1(I,_,Lst,Tmp),
+%     $pp_sublist2(I,_,Tmp,Sub).
+
+%% sublist(?Sub,?Lst,?I,?J)
+%% Sub is the part of Lst between positions I and J: $pp_sublist1/4
+%% drops the first I elements, $pp_sublist2/4 takes the next J-I.
+sublist(Sub,Lst,I,J) :-
+    $pp_require_non_negative_integer(I,$msg(2105),sublist/4),
+    $pp_require_non_negative_integer(J,$msg(2105),sublist/4),
+    $pp_sublist1(I,J,Lst,Tmp),
+    $pp_sublist2(I,J,Tmp,Sub).
+
+%% Skips the prefix.  With I unbound, start positions are enumerated;
+%% with I bound, the skip is deterministic and requires I =< J when J is
+%% bound.
+$pp_sublist1(I,J,Xs,Ys) :- var(I), !,
+    $pp_sublist1_var(0,I,J,Xs,Ys).
+$pp_sublist1(I,J,Xs,Ys) :- var(J), !,
+    $pp_sublist1_det(I,Xs,Ys).
+$pp_sublist1(I,J,Xs,Ys) :- I =< J, !,
+    $pp_sublist1_det(I,Xs,Ys).
+
+%% [03 Dec 2008, by yuizumi]
+%% This predicate would cause infinite loops without (I0 < J) for queries
+%% such as ( sublist(_,_,I,0), I > 0 ).
+
+$pp_sublist1_var(I0,I,_,Xs,Ys) :-
+    I0 = I,
+    Xs = Ys.
+$pp_sublist1_var(I0,I,J,Xs,Ys) :- var(J),!,
+    I1 is I0 + 1,
+    Xs = [_|Xs1],
+    $pp_sublist1_var(I1,I,J,Xs1,Ys).
+$pp_sublist1_var(I0,I,J,Xs,Ys) :- I0 < J, !,        % J bound: stop enumerating offsets once I0 reaches J
+    I1 is I0 + 1,
+    Xs = [_|Xs1],
+    $pp_sublist1_var(I1,I,J,Xs1,Ys).
+%% $pp_sublist1_det(I,Xs,Ys): Ys is Xs without its first I elements (I bound).
+$pp_sublist1_det(I,Xs,Ys) :- I =:= 0, !,
+    Xs = Ys.
+$pp_sublist1_det(I,Xs,Ys) :- I > 0, !,
+    I1 is I - 1,
+    Xs = [_|Xs1],
+    $pp_sublist1_det(I1,Xs1,Ys).
+%% $pp_sublist2(I,J,Xs,Ys): Ys is the prefix of Xs of length J-I; J is enumerated from I when unbound.
+$pp_sublist2(I,J,Xs,Ys) :- var(J), !,
+    $pp_sublist2_var(I,J,Xs,Ys).
+$pp_sublist2(I,J,Xs,Ys) :- nonvar(J), !,
+    N is J - I,
+    $pp_sublist2_det(N,Xs,Ys).
+
+$pp_sublist2_var(J0,J,_ ,Ys) :-
+    J0 = J,
+    Ys = [].
+$pp_sublist2_var(J0,J,Xs,Ys) :-
+    J1 is J0 + 1,
+    Xs = [X|Xs1],
+    Ys = [X|Ys1],
+    $pp_sublist2_var(J1,J,Xs1,Ys1).
+
+$pp_sublist2_det(N,_ ,Ys) :- N =:= 0, !,
+    Ys = [].
+$pp_sublist2_det(N,Xs,Ys) :- N > 0, !,
+    N1 is N - 1,
+    Xs = [X|Xs1],
+    Ys = [X|Ys1],
+    $pp_sublist2_det(N1,Xs1,Ys1).
+
+
+%%--------------------------------
+%%  Splitting
+%%  splitlist/4: List = Prefix ++ Suffix with length(Prefix) = N (N is enumerated when unbound).
+splitlist(Prefix,Suffix,List,N) :-
+    $pp_splitlist(N,List,Prefix,Suffix,splitlist/4).
+%% grouplist/4: split List into N consecutive groups whose lengths are given by Sizes.
+grouplist(List,N,Sizes,Dest) :-
+    $pp_require_positive_integer(N,$msg(2106),grouplist/4),
+    $pp_grouplist(N,Sizes,List,Dest).
+%% egrouplist/3: split List into N groups (sizes computed by $pp_egrouplist); errors are reported as egrouplist/3.
+egrouplist(List,N,Dest) :-
+    ( $pp_length(List,L) -> true
+    ; $pp_raise_type_error($msg(2104),[List],[list,List],egrouplist/3)  % was egrouplist/4, which does not exist
+    ),
+    $pp_require_positive_integer(N,$msg(2106),egrouplist/3),!,
+    $pp_egrouplist(N,L,List,Dest).
+
+$pp_splitlist(N,Xs,Ys,Zs,_), var(N) =>
+    $pp_splitlist_var(0,N,Xs,Ys,Zs).
+$pp_splitlist(N,Xs,Ys,Zs,Source) :-                 % N bound: validate it on behalf of the public predicate Source
+    $pp_require_non_negative_integer(N,$msg(2105),Source),
+    $pp_splitlist_det(0,N,Xs,Ys,Zs).
+
+$pp_splitlist_var(N0,N,Xs,Ys,Zs) ?=>
+    N0 = N,
+    Xs = Zs,
+    Ys = [].
+$pp_splitlist_var(N0,N,Xs,Ys,Zs) =>
+    N1 is N0 + 1,
+    Xs = [X|Xs1],
+    Ys = [X|Ys1],
+    $pp_splitlist_var(N1,N,Xs1,Ys1,Zs).
+
+$pp_splitlist_det(N0,N,Xs,Ys,Zs), N0 =:= N =>
+    Xs = Zs,
+    Ys = [].
+$pp_splitlist_det(N0,N,Xs,Ys,Zs), N0 < N =>
+    N1 is N0 + 1,
+    Xs = [X|Xs1],
+    Ys = [X|Ys1],
+    $pp_splitlist_det(N1,N,Xs1,Ys1,Zs).
+
+$pp_grouplist(N,Ls,Xs,Ys), N =:= 0 =>
+    Ls = [],
+    Xs = [],
+    Ys = [].
+$pp_grouplist(N,Ls,Xs,Ys), N > 0 =>
+    Ls = [L|Ls1],
+    Ys = [Y|Ys1],
+    $pp_splitlist(L,Xs,Y,Xs1,grouplist/4),          % group size L is validated (or enumerated) by $pp_splitlist
+    N1 is N - 1,
+    $pp_grouplist(N1,Ls1,Xs1,Ys1).
+%% $pp_egrouplist(N,L,Xs,Ys): split the L elements of Xs into N groups, each of size ceiling(remaining/groups left).
+$pp_egrouplist(N,_,_ ,Ys), N =:= 0 =>
+    Ys = [].
+$pp_egrouplist(N,L,Xs,Ys), N > 0 =>
+    M is (L + N - 1) // N,                          % ceiling(L / N)
+    Ys = [Y|Ys1],
+    $pp_splitlist_det(0,M,Xs,Y,Xs1),
+    N1 is N - 1,
+    L1 is L - M,
+    $pp_egrouplist(N1,L1,Xs1,Ys1).
+
+
+%%--------------------------------
+%%  Filtering
+%%  filter/3,4 keep the elements for which $pp_match(Patt,X) succeeds; filter_not/3,4 keep the others.
+filter(Patt,Xs,Ys) :-
+    ( $pp_filter(Patt,Xs,Ys) -> true
+    ; $pp_raise_type_error($msg(2104),[Xs],[list,Xs],filter/3)
+    ).
+
+filter(Patt,Xs,Ys,Count) :-
+    ( $pp_filter(Patt,Xs,Ys) -> true
+    ; $pp_raise_type_error($msg(2104),[Xs],[list,Xs],filter/4)
+    ),
+    length(Ys,Count).
+
+$pp_filter(_,Xs,Ys), Xs = [] =>
+    Ys = [].
+$pp_filter(Patt,Xs,Ys), Xs = [X|Xs1] =>
+    ( $pp_match(Patt,X) -> Ys = [X|Ys1] ; Ys = Ys1 ),
+    $pp_filter(Patt,Xs1,Ys1).
+
+filter_not(Patt,Xs,Ys) :-
+    ( $pp_filter_not(Patt,Xs,Ys) -> true
+    ; $pp_raise_type_error($msg(2104),[Xs],[list,Xs],filter_not/3)  % was filter/4: report the predicate actually called
+    ).
+
+filter_not(Patt,Xs,Ys,Count) :-
+    ( $pp_filter_not(Patt,Xs,Ys) -> true
+    ; $pp_raise_type_error($msg(2104),[Xs],[list,Xs],filter_not/4)
+    ),
+    length(Ys,Count).
+
+$pp_filter_not(_,Xs,Ys), Xs = [] =>
+    Ys = [].
+$pp_filter_not(Patt,Xs,Ys), Xs = [X|Xs1] =>
+    ( $pp_match(Patt,X) -> Ys = Ys1 ; Ys = [X|Ys1] ),
+    $pp_filter_not(Patt,Xs1,Ys1).
+
+
+%%--------------------------------
+%%  Counting
+%%  countlist/2: Counts is a list of Key=Count pairs ordered by $pp_compare_eqpair.
+countlist(List,Counts) :-
+    new_hashtable(Table),
+    ( $pp_countlist(List,Table) -> true
+    ; $pp_raise_type_error($msg(2104),[List],[list,List],countlist/2)
+    ),
+    hashtable_to_list(Table,Counts1),
+    $pp_countlist_copy(Counts1,0,N),                % N = number of pairs, needed by $pp_mergesort
+    $pp_mergesort($pp_compare_eqpair(_,_),N,Counts1,_,Counts).
+
+$pp_countlist(Xs,_), Xs = [] => true.
+$pp_countlist(Xs,Table), Xs = [X|Xs1] =>
+    $pp_count(Table,X,_), $pp_countlist(Xs1,Table).
+%% countlist/3: Count is the number of elements of List that are variants of Patt.
+countlist(Patt,List,Count) :-
+    ( $pp_countlist(Patt,List,0,Count) -> true
+    ; $pp_raise_type_error($msg(2104),[List],[list,List],countlist/3)
+    ).
+
+$pp_countlist(_,Xs,N0,N), Xs = [] => N0 = N.
+$pp_countlist(Patt,Xs,N0,N), Xs = [X|Xs1] =>
+    ( variant(X,Patt) -> N1 is N0 + 1 ; N1 is N0 ),
+    $pp_countlist(Patt,Xs1,N1,N).
+%% $pp_countlist_copy(KVs,N0,N): count the Key=Value pairs and replace each non-ground key by a fresh copy.
+$pp_countlist_copy(KVs,N0,N), KVs = [] => N0 = N.
+$pp_countlist_copy(KVs,N0,N), KVs = [KV|KVs1] =>
+    KV = (Key=_),
+    ( ground(Key) ->
+        true
+    ; copy_term(Key,KeyCp), setarg(1,KV,KeyCp)      % overwrite
+    ),
+    N1 is N0 + 1,
+    $pp_countlist_copy(KVs1,N1,N).
+%% $pp_compare_eqpair(A,B): A precedes B when its count is larger, or counts are equal and its key is @< B's key.
+$pp_compare_eqpair((_=A2),(_=B2)), A2 > B2 => true.
+$pp_compare_eqpair((A1=A2),(B1=B2)), A2 =:= B2 => A1 @< B1.
+
+
+%%--------------------------------
+%%  Sorting
+%%  The first argument of $pp_custom_sort selects the order: 0 is </2, 1 is @</2, otherwise a user comparison.
+number_sort(Xs,Ys) :-
+    $pp_custom_sort(0,Xs,Ys,number_sort/2).
+
+custom_sort(Op,Xs,Ys), Op == '<' => $pp_custom_sort(0,Xs,Ys,custom_sort/3).
+custom_sort(Op,Xs,Ys), Op == '@<' => $pp_custom_sort(1,Xs,Ys,custom_sort/3).
+custom_sort(Op,Xs,Ys), atom(Op) =>
+    functor(Term,Op,2),                             % Term = Op(_,_), filled in by $pp_mergelist via setarg/3
+    $pp_custom_sort(Term,Xs,Ys,custom_sort/3).
+custom_sort(Op,_,_) =>                              % Op is not an atom: let $pp_require_atom report it
+    $pp_require_atom(Op,$msg(2102),custom_sort/3).
+
+custom_sort(A,B,Body,Xs,Ys) :-
+    $pp_custom_sort($cmp(A,B,Body),Xs,Ys,custom_sort/5).
+
+$pp_custom_sort(Cmp,Xs,Ys,Source) :-
+    ( $pp_length(Xs,L) -> true
+    ; $pp_raise_type_error($msg(2104),[Xs],[list,Xs],Source)
+    ),
+    $pp_mergesort(Cmp,L,Xs,_,Ys).
+%% $pp_mergesort(Cmp,N,Xs0,Xs1,Ys): Ys is the sorted first N elements of Xs0; Xs1 is the remainder of Xs0.
+$pp_mergesort(_,N,Xs0,Xs1,Ys), N == 0 => Xs0 = Xs1, Ys = [].
+$pp_mergesort(_,N,Xs0,Xs1,Ys), N == 1 => Xs0 = [X|Xs1], Ys = [X].
+$pp_mergesort(Cmp,N,Xs0,Xs1,Ys) =>
+    NL is N // 2,
+    NR is N - NL,
+    $pp_mergesort(Cmp,NL,Xs0,Xs2,Ys0),
+    $pp_mergesort(Cmp,NR,Xs2,Xs1,Ys1),
+    $pp_mergelist(Cmp,Ys0,Ys1,Ys).
+%% $pp_mergelist(Cmp,Xs,Ys,Zs): merge two sorted lists; the right element is taken only if it strictly precedes the left.
+$pp_mergelist(_,Xs,Ys,Zs), Xs == [] => Ys = Zs.
+$pp_mergelist(_,Xs,Ys,Zs), Ys == [] => Xs = Zs.
+$pp_mergelist(Cmp,Xs0,Ys0,Zs0), Cmp == 0 =>
+    Xs0 = [X|Xs1],
+    Ys0 = [Y|Ys1],
+    ( Y < X ->
+        Zs0 = [Y|Zs1], $pp_mergelist(Cmp,Xs0,Ys1,Zs1)
+    ; Zs0 = [X|Zs1], $pp_mergelist(Cmp,Xs1,Ys0,Zs1)
+    ).
+$pp_mergelist(Cmp,Xs0,Ys0,Zs0), Cmp == 1 =>        % standard order of terms
+    Xs0 = [X|Xs1],
+    Ys0 = [Y|Ys1],
+    ( Y @< X ->
+        Zs0 = [Y|Zs1], $pp_mergelist(Cmp,Xs0,Ys1,Zs1)
+    ; Zs0 = [X|Zs1], $pp_mergelist(Cmp,Xs1,Ys0,Zs1)
+    ).
+$pp_mergelist(Cmp,Xs0,Ys0,Zs0), functor(Cmp,_,2) => % user-named binary predicate Op(_,_)
+    Xs0 = [X|Xs1],
+    Ys0 = [Y|Ys1],
+    setarg(1,Cmp,Y),                                % destructively build Op(Y,X) and call it
+    setarg(2,Cmp,X),
+    ( Cmp ->
+        Zs0 = [Y|Zs1], $pp_mergelist(Cmp,Xs0,Ys1,Zs1)
+    ; Zs0 = [X|Zs1], $pp_mergelist(Cmp,Xs1,Ys0,Zs1)
+    ).
+$pp_mergelist(Cmp,Xs0,Ys0,Zs0) =>                   % $cmp(A,B,Body) template from custom_sort/5
+    Xs0 = [X|Xs1],
+    Ys0 = [Y|Ys1],
+    ( \+ \+ ( Cmp = $cmp(Y,X,Body), Body ) ->       % double negation: test Body without keeping its bindings
+        Zs0 = [Y|Zs1], $pp_mergelist(Cmp,Xs0,Ys1,Zs1)
+    ; Zs0 = [X|Zs1], $pp_mergelist(Cmp,Xs1,Ys0,Zs1)
+    ).
+
+
+%%--------------------------------
+%%  Validation of aggregate-operation names: raise through $pp_raise_on_require unless Op is a known operation.
+$pp_require_agglist_operation(Op,MsgID,Source) :-
+    ( $pp_test_agglist_operation(Op) -> true
+    ; $pp_raise_on_require([Op],MsgID,Source,$pp_error_agglist_operation)
+    ).
+
+$pp_test_agglist_operation(Op) :-
+    atom(Op),
+    membchk(Op,[sum,avg,mean,gmean,hmean,varp,var,
+                stdp,std,semp,sem,skewp,skew,kurtp,kurt,
+                mode,amode,rmode,pmode,median,min,max,len]).
+%% $pp_error_agglist_operation(Op,Error): Error is the error term describing why Op is rejected.
+$pp_error_agglist_operation(Op,instantiation_error) :-  % spelling fixed (was instanciation_error)
+    var(Op), !.
+$pp_error_agglist_operation(Op,Error) :-
+    $pp_error_atom(Op,Error), !.                    % was negated with \+, which can never bind Error
+$pp_error_agglist_operation(Op,domain_error(agglist_operation,Op)) :-
+    \+ $pp_test_agglist_operation(Op), !.
diff --git a/packages/prism/src/prolog/up/main.pl b/packages/prism/src/prolog/up/main.pl
new file mode 100644
index 000000000..98f8c5cd5
--- /dev/null
+++ b/packages/prism/src/prolog/up/main.pl
@@ -0,0 +1,338 @@
+%% -*- Prolog -*-
+
+%%----------------------------------------
+%%  Version and copyright statement
+
+$pp_version('2.0').
+$pp_copyright('PRISM 2.0, (C) Sato Lab, Tokyo Institute of Technology, July, 2010').
+
+get_version(V)  :- $pp_version(V).
+print_version   :- $pp_version(V),     !, format("~w~n",[V]).
+print_copyright :- $pp_copyright(Msg), !, format("~w~n",[Msg]).
+
+%%----------------------------------------
+%%  Operators
+
+:- op(1160,xfx,times).
+%% Prefix operators so that the inference commands can be written without parentheses.
+:- op(1150,fx,sample).
+:- op(1150,fx,prob).
+:- op(1150,fx,probf).
+:- op(1150,fx,probfi).
+:- op(1150,fx,probfo).
+:- op(1150,fx,probfv).
+:- op(1150,fx,probfio).
+:- op(1150,fx,viterbi).
+:- op(1150,fx,viterbif).
+:- op(1150,fx,viterbig).
+:- op(1150,fx,hindsight).
+:- op(1150,fx,chindsight).
+
+:- op(1150,fy,p_table).
+:- op(1150,fy,p_not_table).
+
+:- op(600,xfx,@).
+
+:- op(950,fx,?? ).
+:- op(950,fx,??*).
+:- op(950,fx,??>).
+:- op(950,fx,??<).
+:- op(950,fx,??+).
+:- op(950,fx,??-).
+
+%%----------------------------------------
+%%  Declarations
+
+% only declarations. no effect when executed
+p_table(_).
+p_not_table(_).
+
+:- table $prism_eg_path/3.
+:- table $prism_expl_msw/3.
+:- table $expl_interp_single_call/3.
+
+%%----------------------------------------
+%%  Initializations
+
+%
+% vsc: delay until end in YAP
+%
+%:- ( $pc_mp_mode -> true ; print_copyright ).
+%:- random_set_seed.
+%:- reset_prism_flags.
+
+%%----------------------------------------
+%%  Help messages
+
+$help_mess("~nType 'prism_help' for usage.~n").
% Hook for B-Prolog
+
+prism_help :-
+    format(" prism(File) -- compile and load a program~n",[]),
+    format(" prism(Opts,File) -- compile and load a program~n",[]),
+    nl,
+    format(" msw(I,V) -- the switch I randomly outputs the value V~n",[]),
+    nl,
+    format(" learn(Gs) -- learn the parameters~n",[]),
+    format(" learn -- learn the parameters from data_source~n",[]),
+    format(" sample(Goal) -- get a sampled instance of Goal~n",[]),
+    format(" prob(Goal,P) -- compute a probability~n",[]),
+    format(" probf(Goal,F) -- compute an explanation graph~n",[]),
+    format(" viterbi(Goal,P) -- compute a Viterbi probability~n",[]),
+    format(" viterbif(Goal,P,F) -- compute a Viterbi probability with its explanation~n",[]),
+    format(" hindsight(Goal,Patt,Ps) -- compute hindsight probabilities~n",[]),
+    nl,
+    format(" set_sw(Sw,Params) -- set parameters of a switch~n",[]),
+    format(" get_sw(Sw,SwInfo) -- get information of a switch~n",[]),
+    format(" set_prism_flag(Flg,Val) -- set a new value to a flag~n",[]),
+    format(" get_prism_flag(Flg,Val) -- get the current value of a flag~n",[]),
+    nl,
+    format(" please consult the user's manual for details.~n",[]).
+
+%%----------------------------------------
+%%  Loading a program
+%%  prism(File) loads File with no options; prism(Opts,File) consults, loads or compiles-and-loads it.
+prism(File) :-
+    prism([],File).
+
+prism(Opts,File) :-
+    $pp_require_atom(File,$msg(3000),prism/2),
+    $pp_set_options(Opts),   % also aiming at the error check of options
+    ( member(consult,Opts) ->                       % consult: search File.psm, then File itself
+        $pp_search_file(File,File1,[".psm",""]),
+        Pred = $pp_consult(File1)
+    ; member(load,Opts) ->                          % load: search File.psm.out, File.out, then File
+        $pp_search_file(File,File1,[".psm.out",".out",""]),
+        Pred = $pp_load(File1)
+    ; ( member(dump,Opts) -> D = 1 ; D = 0 ),       % default: compile, then load the compiled file
+      global_set($pg_dump_compiled,D),
+      $pp_search_file(File,File1,[".psm",""]),
+      Pred = $pp_compile_load(File1)
+    ),!,
+    reset_prism_flags,                              % NOTE(review): runs after $pp_set_options/1 -- confirm flag options survive it
+    global_del(failure,0),
+    global_set($pg_dummy_goal_count,0),
+    call(Pred),!.
+prism(_Opts,File) :-                                % reached when the clause above fails before its first cut
+    $pp_raise_existence_error($msg(3001),[File],
+                              [prism_file,File],prism/2).
+%% $pp_compile_load(File): compile File into File.out, then load that file.
+$pp_compile_load(File) :-
+    $pp_add_out_extension(File,OutFile),
+    $pp_clean_dynamic_preds,
+    $pp_compile(File,_DmpFile,OutFile),
+    $pp_load(OutFile).
+
+$pp_load(File) :-
+    not(not($myload(File))),                        % double negation: load without keeping variable bindings
+    $pp_init_tables_aux,
+    $pp_init_tables,!.
+% We do not perform translation
+% -- the explanation search will be done by meta-interpreters
+$pp_consult(File) :-
+    $pp_clean_dynamic_preds,
+    new_hashtable(PPredTab),
+    Info = $trans_info(_DoTable,_TPredTab,_NoDebug,PPredTab),
+    $pp_bpif_read_program(Prog,File),
+    $pp_extract_decls(Prog,Info),
+    $pp_trans_values(Prog,Prog1),
+    $pp_analyze(Prog1,Info),
+    $pp_tabled_to_nontabled(Prog1,Prog2),
+    assert($pd_is_tabled_pred($disabled_by_consult,0)),  % marker checked by show_tabled_preds/0
+    $pp_separate_demon_load(Prog2,Prog3,Prog4),
+    % $damon_load/0 should be consulted after loading the entire program
+    consult_preds(Prog4,_ProgCompiled),
+    consult_preds(Prog3,_ProgCompiled),             % NOTE(review): same variable as in the call above -- confirm this is intended
+    $pp_init_tables_aux,
+    $pp_init_tables.
+
+%% $pp_set_options(Opts): validate every option and apply the ones that set a flag.
+$pp_set_options([]) => true.
+$pp_set_options([O|Options]) =>
+    $pp_require_prism_option(O,$msg(1001),prism/2),
+    $pp_set_one_option(O),!,
+    $pp_set_options(Options).
+
+$pp_set_one_option(dump)    => true.
+$pp_set_one_option(consult) => true.
+$pp_set_one_option(compile) => true.
+$pp_set_one_option(load)    => true.
+$pp_set_one_option(v)       :- set_prism_flag(verb,full).
+$pp_set_one_option(verb)    :- set_prism_flag(verb,full).
+$pp_set_one_option(nv)      :- set_prism_flag(verb,none).
+$pp_set_one_option(noverb)  :- set_prism_flag(verb,none).
+$pp_set_one_option(Att=Val) :- set_prism_flag(Att,Val).
+
+
+%%----------------------------------------
+%%  Clean up databases
+
+$pp_clean_dynamic_preds :-
+    $pp_clean_predicate_info,
+    $pp_clean_switch_info,
+    $pp_clean_dummy_goal_table,
+    $pp_clean_graph_stats,
+    $pp_clean_learn_stats,
+    $pp_clean_infer_stats.
+
+$pp_clean_predicate_info :-
+    retractall($pd_is_prob_pred(_,_)),
+    retractall($pd_is_tabled_pred(_,_)),!.
+%% $pp_clean_switch_info: forget all stored switch parameters, expectations and fixed-switch marks.
+$pp_clean_switch_info :-
+    retractall($pd_parameters(_,_,_)),
+    retractall($pd_hyperparameters(_,_,_,_)),
+    retractall($pd_expectations(_,_,_)),
+    retractall($pd_hyperexpectations(_,_,_)),
+    retractall($pd_fixed_parameters(_)),
+    retractall($pd_fixed_hyperparameters(_)),!.
+
+$pp_init_tables :-
+    initialize_table,
+    $pc_prism_id_table_init,
+    $pc_clean_base_egraph,   % base support graph and switches
+    $pc_alloc_egraph,!.      % get ready for the following steps
+
+$pp_init_tables_if_necessary :-
+    ( get_prism_flag(clean_table, on) -> $pp_init_tables
+    ; true
+    ),!.
+
+$pp_init_tables_aux :-
+    $pc_clean_egraph,        % derived support graphs
+    $pc_clean_external_tables,!.
+
+
+%%----------------------------------------
+%%  Show the program information
+%%  show_values/0 prints each registered switch with its outcome space, sorted by switch.
+show_values :-
+    format("Outcome spaces:~n",[]),!,
+    findall([Sw,Vals],($pp_registered_sw(Sw),get_values1(Sw,Vals)),SwVals0),
+    sort(SwVals0,SwVals1),
+    $pp_show_values_list(SwVals1),!.
+
+$pp_show_values_list([]).
+$pp_show_values_list([[Sw,Vals]|SwVals]) :-
+    format(" ~q: ~q~n",[Sw,Vals]),!,
+    $pp_show_values_list(SwVals).
+
+%% (Note) $pd_is_{prob,tabled}_pred/2 are dynamic, so we don't have to call
+%% current_predicate/1. We don't check the input rigorously either
+%% for flexibility.
+
+is_prob_pred(F/N) :- is_prob_pred(F,N).
+is_prob_pred(F,N) :- $pd_is_prob_pred(F,N).
+
+is_tabled_pred(F/N) :- is_tabled_pred(F,N).
+is_tabled_pred(F,N) :- $pd_is_tabled_pred(F,N).
+
+show_prob_preds :-
+    format("Probabilistic predicates:~n",[]),!,
+    findall(F0/N0,is_prob_pred(F0,N0),Preds0),
+    sort(Preds0,Preds),
+    ( member(F/N,Preds),                            % failure-driven loop over the sorted indicators
+      format(" ~q/~w~n",[F,N]),
+      fail
+    ; true
+    ),!.
+
+show_tabled_preds :-                                % after a consult, tabling is disabled: only warn
+    $pd_is_tabled_pred($disabled_by_consult,_),!,
+    $pp_raise_warning($msg(1002)).
+
+show_tabled_preds :-
+    format("Tabled probabilistic predicates:~n",[]),!,
+    findall(F0/N0,is_tabled_pred(F0,N0),Preds0),
+    sort(Preds0,Preds),
+    ( member(F/N,Preds),
+      format(" ~q/~w~n",[F,N]),
+      fail
+    ; true
+    ),!.
+
+%% aliases
+show_prob_pred   :- show_prob_preds.
+show_table_pred  :- show_tabled_preds.
+show_table_preds :- show_tabled_preds.
+show_tabled_pred :- show_tabled_preds.
+
+%%----------------------------------------
+%%  Predicates for batch (non-interactive) execution
+%%  $pp_batch/0 runs $pp_batch_core/0 and turns any exception into abort/0.
+$pp_batch :-
+    catch($pp_batch_core,Err,$pp_batch_error(Err)).
+
+$pp_batch_error(Err) :-
+    Err == abort,!.                                 % an abort is swallowed silently
+$pp_batch_error(Err) :-
+    Err == interrupt,!,
+    format("Aborted by interruption~n",[]),
+    abort.
+$pp_batch_error(Err) :-
+    format("Aborted by exception -- ~w~n",[Err]),
+    abort.
+
+$pp_batch_core :-
+    get_main_args([Arg|Args]),!,                    % first argument names the program, the rest go to prism_main
+    $pp_batch_load(Arg,File),
+    $pp_batch_main(Args,File).
+$pp_batch_core :-                                   % NOTE(review): 3 arguments here vs 4 in other calls, and $pp_batch has arity 0 -- confirm
+    $pp_raise_existence_error($msg(1003),[prism_file,unknown],$pp_batch/1).
+%% $pp_batch_load(Arg,File): Arg is "prism:File", "prismn:File", "load:File" or a bare file name.
+$pp_batch_load(Arg,File) :-
+    ( atom_chars(Arg,[p,r,i,s,m, ':'|FileChars]) ->
+        atom_chars(File,FileChars), FileChars \== [], prism(File)
+    ; atom_chars(Arg,[p,r,i,s,m,n,':'|FileChars]) ->
+        atom_chars(File,FileChars), FileChars \== [], prismn(File)
+    ; atom_chars(Arg,[l,o,a,d, ':'|FileChars]) ->
+        atom_chars(File,FileChars), FileChars \== [], prism([load],File)
+    ; prism(Arg), File = Arg
+    ),!.
+
+$pp_batch_main(Args,File) :-
+    ( current_predicate(prism_main/1) -> Goal = prism_main(Args)   % prism_main/1 is preferred over prism_main/0
+    ; current_predicate(prism_main/0) -> Goal = prism_main
+    ; $pp_raise_existence_error($msg(1004),[File],[batch_predicate,File],
+                                $pp_batch_main/2)
+    ),!,
+    %% use of call/1 is for the parallel version
+    call($pp_batch_call(Goal)).
+
+%%----------------------------------------
+%%  Miscellaneous routines
+%%  $pp_tabled_to_nontabled/2 replaces the tabling field of every pred/6 entry by a fresh variable.
+$pp_tabled_to_nontabled([],Prog) => Prog = [].
+$pp_tabled_to_nontabled([pred(F,N,M,Delay,_Tabled,Cls)|Preds],Prog) =>
+    Prog = [pred(F,N,M,Delay,_,Cls)|Prog1], !,
+    $pp_tabled_to_nontabled(Preds,Prog1).
+
+%% $pp_separate_demon_load(Prog0,Prog1,Prog2): Prog1 gets the $damon_load/0 entries, Prog2 all the others.
+$pp_separate_demon_load([],[],[]).
+$pp_separate_demon_load([pred($damon_load,0,X0,X1,X2,X3)|Prog0],
+                        [pred($damon_load,0,X0,X1,X2,X3)|Prog1],
+                        Prog2) :- !,
+    $pp_separate_demon_load(Prog0,Prog1,Prog2).
+$pp_separate_demon_load([P|Prog0],Prog1,[P|Prog2]) :- !,
+    $pp_separate_demon_load(Prog0,Prog1,Prog2).
+
+%% $pp_search_file(File,File1,Suffixes): File1 is File plus the first suffix for which the file exists.
+$pp_search_file(File,File1,Suffixes) :-
+    member(Suffix,Suffixes),
+    $pp_add_extension(File,File1,Suffix),
+    exists(File1),!.
+
+
+$pp_add_psm_extension(File,PsmFile) :-
+    $pp_add_extension(File,PsmFile,".psm").
+
+$pp_add_out_extension(File,OutFile) :-
+    $pp_add_extension(File,OutFile,".out").
+
+$pp_add_extension(File,File1,Extension) :-
+    ( atom(File) -> name(File,FileString)
+    ; File ?= [_|_] -> File = FileString            % already a non-empty code list
+    ; $pp_raise_domain_error($msg(1000),[File],[filename,File],
+                             $pp_add_extension/3)
+    ),
+    append(FileString,Extension,FileString1),
+    name(File1,FileString1).
diff --git a/packages/prism/src/prolog/up/prob.pl b/packages/prism/src/prolog/up/prob.pl
new file mode 100644
index 000000000..fb018d50d
--- /dev/null
+++ b/packages/prism/src/prolog/up/prob.pl
@@ -0,0 +1,412 @@
+prob(Goal) :-
+    prob(Goal,P),
+    ( $pp_in_log_scale -> Text = 'Log-probability' ; Text = 'Probability' ),
+    format("~w of ~w is: ~15f~n",[Text,Goal,P]).
+
+prob(Goal,Prob) :-
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),prob/2),
+    $pp_prob(Goal,Prob).
+
+$pp_prob(msw(Sw,V),Prob) :-                         % a switch instance is answered directly from its parameters
+    $pp_require_ground(Sw,$msg(0101),prob/2),
+    $pp_require_switch_outcomes(Sw,$msg(0102),prob/2),
+    $pp_clean_infer_stats,
+    ( var(V) ->                                     % unbound outcome: probability 1 (0.0 in log scale)
+        cputime(T0),
+        ( $pp_in_log_scale -> Prob = 0.0 ; Prob = 1.0 ),
+        cputime(T1),
+        InfTime is T1 - T0,
+        $pp_assert_prob_stats1(InfTime)
+    ; % else
+        cputime(T0),
+        $pp_get_value_prob(Sw,V,Prob0),
+        ( $pp_in_log_scale -> Prob is log(Prob0) ; Prob = Prob0 ),
+        cputime(T1),
+        InfTime is T1 - T0,
+        $pp_assert_prob_stats1(InfTime)
+    ),
+    $pp_assert_prob_stats2(0.0,0.0),!.
+
+$pp_prob(Goal,Prob) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_prob_core(Goal,Prob),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_prob_stats1(InfTime),!.
+
+log_prob(Goal) :-
+    log_prob(Goal,P),format("Log-probability of ~w is: ~15f~n",[Goal,P]).
+log_prob(Goal,P) :-                                 % takes the log only when $pp_prob/2 did not already work in log scale
+    $pp_prob(Goal,P0),( $pp_in_log_scale -> P = P0 ; P is log(P0) ).
+
+$pp_in_log_scale :-
+    get_prism_flag(log_scale,on).
+%% $pp_prob_core(Goal,Prob): first clause handles ground tabled goals, second wraps any other goal in a dummy goal.
+$pp_prob_core(Goal,Prob) :-
+    ground(Goal),
+    $pp_is_tabled_probabilistic_atom(Goal),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(Goal),
+    cputime(T2),
+    $pp_compute_inside(Goal,Prob),!,
+    cputime(T3),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime  is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_prob_stats2(SearchTime,NumCompTime),!.
+
+$pp_prob_core(Goal,Prob) :-
+    copy_term(Goal,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,               % no translation available: fall back to the meta-interpreter
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    Clause = (DummyGoal:-BodyGoal,
+                         $pc_prism_goal_id_register(GoalCp,GId),
+                         $pc_prism_goal_id_register(DummyGoal,HId),
+                         $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),[Clause])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_inside(DummyGoal,Prob),
+    cputime(T3),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime  is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_prob_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+% Sws = [sw(Id,Instances,Probs,Deltas,FixedP,FixedH),...]
+$pp_compute_inside(Goal,Prob) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    $pc_compute_inside(Gid,Prob),!.
+
+$pp_get_value_prob(Sw,V,Prob) :-
+    $pp_get_parameters(Sw,Values,Probs),
+    $pp_get_value_prob(Values,Probs,V,Prob).
+%% $pp_get_value_prob(Values,Probs,V,Prob): Prob is the entry of Probs at the position of the first value unifying with V.
+$pp_get_value_prob([V|_],[Prob0|_],V,Prob) :- !, Prob = Prob0.
+$pp_get_value_prob([_|Vs],[_|Probs],V,Prob) :- !,
+    $pp_get_value_prob(Vs,Probs,V,Prob).
+
+$pp_collect_sw_info(Sws) :-
+    $pc_prism_sw_count(N),
+    $pp_collect_sw_info(0,N,Sws).
+
+$pp_collect_sw_info(Sid,N,[]) :- Sid >= N,!.
+$pp_collect_sw_info(Sid,N,SwInsList) :-
+    $pc_prism_sw_term(Sid,Sw),
+    $pp_get_parameters(Sw,Values,Pbs),
+    $pp_get_hyperparameters(Sw,Values,_,Deltas),
+    ( $pd_fixed_parameters(Sw) -> FixedP = 1 ; FixedP = 0 ),
+    ( $pd_fixed_hyperparameters(Sw) -> FixedH = 1 ; FixedH = 0 ),
+    SwInsList = [sw(Sid,Iids,Pbs,Deltas,FixedP,FixedH)|SwInsList1],!,
+    $pp_collect_sw_ins_ids(Sw,Values,Iids),         % Iids is bound here, after the sw/6 term has been built
+    Sid1 is Sid + 1,!,
+    $pp_collect_sw_info(Sid1,N,SwInsList1).
+%% get_subgoal_hashtable(GTab): GTab maps every registered goal id to its goal term.
+get_subgoal_hashtable(GTab) :-
+    $pp_get_subgoal_hashtable(GTab).
+
+$pp_get_subgoal_hashtable(GTab) :-
+    $pc_prism_goal_count(GC),
+    new_hashtable(GTab,GC),
+    $pp_get_subgoal_hashtable(0,GC,GTab).
+
+$pp_get_subgoal_hashtable(Gid,N,_) :- Gid >= N,!.
+$pp_get_subgoal_hashtable(Gid,N,GTab) :-
+    $pc_prism_goal_term(Gid,G),
+    hashtable_put(GTab,Gid,G),
+    Gid1 is Gid + 1,!,
+    $pp_get_subgoal_hashtable(Gid1,N,GTab).
+%% get_switch_hashtable(SwTab): SwTab maps every switch-instance id to its term.
+get_switch_hashtable(SwTab) :-
+    $pp_get_switch_hashtable(SwTab).
+
+$pp_get_switch_hashtable(SwTab) :-
+    $pc_prism_sw_ins_count(IC),
+    new_hashtable(SwTab,IC),
+    $pp_get_switch_hashtable(0,IC,SwTab).
+
+$pp_get_switch_hashtable(Sid,N,_) :- Sid >= N,!.
+$pp_get_switch_hashtable(Sid,N,SwTab) :-
+    $pc_prism_sw_ins_term(Sid,S),
+    hashtable_put(SwTab,Sid,S),
+    Sid1 is Sid + 1,!,
+    $pp_get_switch_hashtable(Sid1,N,SwTab).
+%% probf* family: $pp_probf(Goal,Expls,Decode,PrMode); the one-argument forms print the graph.
+probf(Goal) :-
+    $pp_probf(Goal,Expls,1,0), \+ \+ print_graph(Expls,[lr('<=>')]).
+probfi(Goal) :-
+    $pp_probf(Goal,Expls,1,1), \+ \+ print_graph(Expls,[lr('<=>')]).
+probfo(Goal) :-
+    $pp_probf(Goal,Expls,1,2), \+ \+ print_graph(Expls,[lr('<=>')]).
+probfv(Goal) :-
+    $pp_probf(Goal,Expls,1,3), \+ \+ print_graph(Expls,[lr('<=>')]).
+probfio(Goal) :-
+    $pp_probf(Goal,Expls,1,4), \+ \+ print_graph(Expls,[lr('<=>')]).
+%% Two-argument forms return the decoded graph (Decode = 1) instead of printing it.
+probf(Goal,Expls) :-
+    $pp_probf(Goal,Expls,1,0).
+probfi(Goal,Expls) :-
+    $pp_probf(Goal,Expls,1,1).
+probfo(Goal,Expls) :-
+    $pp_probf(Goal,Expls,1,2).
+probfv(Goal,Expls) :-
+    $pp_probf(Goal,Expls,1,3).
+probfio(Goal,Expls) :-
+    $pp_probf(Goal,Expls,1,4).
+%% probef* family: same as probf* but with Decode = 0, so nodes are labelled by goal and switch ids.
+probef(Goal) :-
+    $pp_probf(Goal,Expls,0,0), \+ \+ print_graph(Expls,[lr('<=>')]).
+probefi(Goal) :-
+    $pp_probf(Goal,Expls,0,1), \+ \+ print_graph(Expls,[lr('<=>')]).
+probefo(Goal) :-
+    $pp_probf(Goal,Expls,0,2), \+ \+ print_graph(Expls,[lr('<=>')]).
+probefv(Goal) :-
+    $pp_probf(Goal,Expls,0,3), \+ \+ print_graph(Expls,[lr('<=>')]).
+probefio(Goal) :-
+    $pp_probf(Goal,Expls,0,4), \+ \+ print_graph(Expls,[lr('<=>')]).
+
+probef(Goal,Expls) :-
+    $pp_probf(Goal,Expls,0,0).
+probefi(Goal,Expls) :-
+    $pp_probf(Goal,Expls,0,1).
+probefo(Goal,Expls) :-
+    $pp_probf(Goal,Expls,0,2).
+probefv(Goal,Expls) :-
+    $pp_probf(Goal,Expls,0,3).
+probefio(Goal,Expls) :-
+    $pp_probf(Goal,Expls,0,4).
+%% Four-argument forms also return the id-to-term hashtables needed to decode the encoded graph.
+probef(Goal,Expls,GoalHashTab,SwHashTab) :-
+    $pp_probf(Goal,Expls,0,0),
+    $pp_get_subgoal_hashtable(GoalHashTab),
+    $pp_get_switch_hashtable(SwHashTab).
+probefi(Goal,Expls,GoalHashTab,SwHashTab) :-
+    $pp_probf(Goal,Expls,0,1),
+    $pp_get_subgoal_hashtable(GoalHashTab),
+    $pp_get_switch_hashtable(SwHashTab).
+probefo(Goal,Expls,GoalHashTab,SwHashTab) :-
+    $pp_probf(Goal,Expls,0,2),
+    $pp_get_subgoal_hashtable(GoalHashTab),
+    $pp_get_switch_hashtable(SwHashTab).
+probefv(Goal,Expls,GoalHashTab,SwHashTab) :-
+    $pp_probf(Goal,Expls,0,3),
+    $pp_get_subgoal_hashtable(GoalHashTab),
+    $pp_get_switch_hashtable(SwHashTab).
+probefio(Goal,Expls,GoalHashTab,SwHashTab) :-
+    $pp_probf(Goal,Expls,0,4),
+    $pp_get_subgoal_hashtable(GoalHashTab),
+    $pp_get_switch_hashtable(SwHashTab).
+
+%% PrMode is one of 0 (none), 1 (inside), 2 (outside), 3 (viterbi) and
+%% 4 (inside-outside)
+
+$pp_probf(Goal,Expls,Decode,PrMode) :-
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),$pp_probf/4),
+    $pp_compute_expls(Goal,Expls,Decode,PrMode).
+%% $pp_compute_expls/4: three cases -- an msw/2 goal, a ground tabled goal, and any other goal.
+$pp_compute_expls(Goal,Expls,Decode,PrMode) :-
+    Goal = msw(I,V),!,
+    $pp_require_ground(I,$msg(0101),$pp_probf/4),
+    $pp_require_switch_outcomes(I,$msg(0102),$pp_probf/4),
+    $pp_clean_infer_stats,
+    ( ground(V) -> V = VCp ; copy_term(V,VCp) ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = ($prism_expl_msw(I,VCp,Sid),
+                 $pc_prism_goal_id_register(DummyGoal,Hid),
+                 $prism_eg_path(Hid,[],[Sid])),
+    Prog = [pred(DummyGoal,0,_,_,tabled(_,_,_,_),[(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    cputime(T0),
+    $pp_compute_expls(DummyGoal,Goal,Expls,Decode,PrMode,T0),!.   % Goal is the label shown for the dummy root
+
+$pp_compute_expls(Goal,Expls,Decode,PrMode) :-
+    $pp_is_tabled_probabilistic_atom(Goal),
+    ground(Goal),!,
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_compute_expls(Goal,_,Expls,Decode,PrMode,T0),!.
+
+$pp_compute_expls(Goal,Expls,Decode,PrMode) :-
+    $pp_clean_infer_stats,
+    copy_term(Goal,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) ->
+      BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,               % no translation available: fall back to the meta-interpreter
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = (BodyGoal,
+                 $pc_prism_goal_id_register(GoalCp,GId),
+                 $pc_prism_goal_id_register(DummyGoal,HId),
+                 $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_,_,tabled(_,_,_,_),[(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    cputime(T0),
+    $pp_compute_expls(DummyGoal,Goal,Expls,Decode,PrMode,T0),!.
+%% $pp_compute_expls/6: search explanations for Goal, sort the graph, optionally compute probabilities, build Expls.
+$pp_compute_expls(Goal,GLabel,Expls,Decode,PrMode,T0) :-
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    garbage_collect,
+    cputime(T1),
+    $pp_find_explanations(Goal),
+    cputime(T2),
+    $pc_prism_goal_id_get(Goal,Gid),
+    $pc_alloc_sort_egraph(Gid),
+    cputime(T3),
+    ( PrMode == 0 -> true                           % plain probf: no numeric computation
+    ; $pp_collect_sw_info(Sws),
+      $pc_export_sw_info(Sws),
+      $pc_compute_probf(PrMode)
+    ),
+    cputime(T4),
+    $pc_import_sorted_graph_size(Size),
+    $pp_build_expls(Size,Decode,PrMode,GLabel,Expls),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    cputime(T5),
+    SearchTime  is T2 - T1,
+    NumCompTime is T4 - T3,
+    InfTime     is T5 - T0,
+    ( PrMode == 0 -> $pp_assert_prob_stats2(SearchTime)
+    ; $pp_assert_prob_stats2(SearchTime,NumCompTime)
+    ),
+    $pp_assert_prob_stats1(InfTime),
+    $pp_delete_tmp_out,!.
+%% $pp_build_expls(I0,...): build one node/2,3 per sorted-graph entry, counting the index down from I0-1.
+$pp_build_expls(I0,_,_,_,Expls), I0 =< 0 =>
+    Expls = [].
+$pp_build_expls(I0,Decode,PrMode,GLabel,Expls), I0 > 0 =>
+    I is I0 - 1,
+    $pc_import_sorted_graph_gid(I,Gid),
+    $pc_import_sorted_graph_paths(I,Paths0),
+    ( Decode == 0    -> Label = Gid
+    ; nonvar(GLabel) -> Label = GLabel              % caller-supplied label for this node
+    ; $pc_prism_goal_term(Gid,Label)
+    ),
+    ( PrMode == 0 -> Node = node(Label,Paths)  % probf
+    ; PrMode == 4 ->                           % probfio
+      $pp_get_gnode_probs(PrMode,Gid,Value),
+      Node  = node(Label,Paths,Value),
+      Value = [_,Vo]
+    ; $pp_get_gnode_probs(PrMode,Gid,Value),
+      Node  = node(Label,Paths,Value),
+      Value = Vo
+    ),
+    $pp_decode_paths(Paths0,Paths,Decode,PrMode,Vo),
+    Expls = [Node|Expls1],!,
+    $pp_build_expls(I,Decode,PrMode,_,Expls1).      % GLabel is not passed on: it applies to the first node only
+
+
+
+$pp_decode_paths([],[],_Decode,_PrMode,_Vo).
+$pp_decode_paths([Pair|Pairs],[Path|Paths],Decode,PrMode,Vo) :-
+    Pair = [Gids,Sids],
+    $pp_decode_gnodes(Gids,GNodes,Decode,PrMode,Vg),
+    $pp_decode_snodes(Sids,SNodes,Decode,PrMode,Vs),
+    get_prism_flag(log_scale,LogScale),
+    ( PrMode == 0 ->
+      Path = path(GNodes,SNodes)
+    ; PrMode == 1 -> ( LogScale == on -> Vi is Vg + Vs ; Vi is Vg * Vs),   % path value combines goal and switch values
+      Path = path(GNodes,SNodes,Vi)
+    ; PrMode == 2 ->                                % outside mode: the path carries the parent node's value Vo
+      Path = path(GNodes,SNodes,Vo)
+    ; PrMode == 3 -> ( LogScale == on -> Vi is Vg + Vs ; Vi is Vg * Vs),
+      Path = path(GNodes,SNodes,Vi)
+    ; PrMode == 4 -> ( LogScale == on -> Vi is Vg + Vs ; Vi is Vg * Vs),
+      Path = path(GNodes,SNodes,[Vi,Vo])
+    ),!,
+    $pp_decode_paths(Pairs,Paths,Decode,PrMode,Vo).
+%% $pp_decode_gnodes/5: decode goal ids; V accumulates their values (sum in log scale, product otherwise).
+$pp_decode_gnodes(Gids,GNodes,Decode,PrMode,V) :-
+    get_prism_flag(log_scale,LogScale),
+    ( LogScale == on -> V0 = 0.0 ; V0 = 1.0 ),      % neutral element of the accumulation
+    $pp_decode_gnodes(Gids,GNodes,Decode,PrMode,LogScale,V0,V).
+
+$pp_decode_gnodes([],[],_Decode,_PrMode,_LogScale,V,V) :- !.
+$pp_decode_gnodes([Gid|Gids],[GNode|GNodes],Decode,PrMode,LogScale,V0,V) :-
+    ( Decode == 0 -> Gid = Label
+    ; $pc_prism_goal_term(Gid,Label)
+    ),
+    ( PrMode == 0 -> GNode = Label
+    ; $pp_get_gnode_probs(PrMode,Gid,Value),
+      GNode = gnode(Label,Value),
+      ( LogScale == on ->
+        V1 is Value + V0
+      ; V1 is Value * V0
+      )
+    ),!,
+    $pp_decode_gnodes(Gids,GNodes,Decode,PrMode,LogScale,V1,V).
+%% $pp_decode_snodes/5: same as $pp_decode_gnodes/5 for switch-instance ids.
+$pp_decode_snodes(Sids,SNodes,Decode,PrMode,V) :-
+    get_prism_flag(log_scale,LogScale),
+    ( LogScale == on -> V0 = 0.0 ; V0 = 1.0 ),
+    $pp_decode_snodes(Sids,SNodes,Decode,PrMode,LogScale,V0,V).
+
+$pp_decode_snodes([],[],_Decode,_PrMode,_LogScale,V,V) :- !.
+$pp_decode_snodes([Sid|Sids],[SNode|SNodes],Decode,PrMode,LogScale,V0,V) :-
+    ( Decode == 0 -> Sid = Label
+    ; $pc_prism_sw_ins_term(Sid,Label)
+    ),
+    ( PrMode == 0 -> SNode = Label
+    ; $pp_get_snode_probs(PrMode,Sid,Value),
+      SNode = snode(Label,Value),
+      ( LogScale == on ->
+        V1 is Value + V0
+      ; V1 is Value * V0
+      )
+    ),!,
+    $pp_decode_snodes(Sids,SNodes,Decode,PrMode,LogScale,V1,V).
+%% $pp_get_gnode_probs(PrMode,Gid,Value): fetch the value of a goal node for modes 1 to 4.
+$pp_get_gnode_probs(1,Gid,Pi) :- $pc_get_gnode_inside(Gid,Pi),!.
+$pp_get_gnode_probs(2,Gid,Po) :- $pc_get_gnode_outside(Gid,Po),!.
+$pp_get_gnode_probs(3,Gid,Pv) :- $pc_get_gnode_viterbi(Gid,Pv),!.
+$pp_get_gnode_probs(4,Gid,[Pi,Po]) :-
+    $pc_get_gnode_inside(Gid,Pi),
+    $pc_get_gnode_outside(Gid,Po),!.
+%% $pp_get_snode_probs(PrMode,Sid,Value): same for switch nodes; modes 2 and 4 read the expectation.
+$pp_get_snode_probs(1,Sid,Pi) :- $pc_get_snode_inside(Sid,Pi),!.
+$pp_get_snode_probs(2,Sid,E)  :- $pc_get_snode_expectation(Sid,E),!.
+$pp_get_snode_probs(3,Sid,Pi) :- $pc_get_snode_inside(Sid,Pi),!.
+$pp_get_snode_probs(4,Sid,[Pi,Po]) :-
+    $pc_get_snode_inside(Sid,Pi),
+    $pc_get_snode_expectation(Sid,Po),!.
+
+%%%% Statistics
+%% Times are divided by 1000.0 before being stored (presumably milliseconds to seconds -- confirm cputime/1 units).
+$pp_assert_prob_stats1(InfTime0) :-
+    InfTime is InfTime0 / 1000.0,
+    assertz($ps_infer_time(InfTime)),!.
+
+$pp_assert_prob_stats2(SearchTime0) :-
+    SearchTime is SearchTime0 / 1000.0,
+    assertz($ps_infer_search_time(SearchTime)),!.
+
+$pp_assert_prob_stats2(SearchTime0,NumCompTime0) :-
+    SearchTime  is SearchTime0 / 1000.0,
+    NumCompTime is NumCompTime0 / 1000.0,
+    assertz($ps_infer_search_time(SearchTime)),
+    assertz($ps_infer_calc_time(NumCompTime)),!.
+
+$pp_clean_infer_stats :-
+    retractall($ps_infer_time(_)),
+    retractall($ps_infer_search_time(_)),
+    retractall($ps_infer_calc_time(_)),!.
diff --git a/packages/prism/src/prolog/up/sample.pl b/packages/prism/src/prolog/up/sample.pl
new file mode 100644
index 000000000..0374dcff3
--- /dev/null
+++ b/packages/prism/src/prolog/up/sample.pl
@@ -0,0 +1,113 @@
+%%
+%%  sample.pl: routines for sampling execution
+%%
+%%
+%%    | ?- sample(bloodtype(X)).
+%%
+%%    X = a ?
+%%
+%%  Also available for Utility program.
+%%
+%%    go(Loc,Dir) :-
+%%        ( is_wall(forward,Loc),
+%%          sample(coin(X)),
+%%          ( X = head,!,Dir = right
+%%          ; Dir = left
+%%          )
+%%        ; Dir = forward
+%%        ).
+
+sample(Goal) :-
+    $pp_require_probabilistic_atom(Goal,$msg(1201),sample/1),
+    $trace_call(Goal).  % just calls call(Goal) if not in debug mode
+
+%%----------------------------------------------------------------------------
+
+msw(Sw,V) :-
+    $pp_require_ground(Sw,$msg(0101),msw/2),
+    $prism_sample_msw(Sw,V).
+
+% Sw is assumed to be ground in $prism_sample_msw/{2,6}.
+
+$prism_sample_msw(Sw,V) :-
+    $pp_get_parameters(Sw,Values,Pbs),!,
+    sumlist(Pbs,Sum),
+    random_uniform(Sum,R),                          % R is compared against cumulative sums of Pbs by $pp_choose
+    $pp_choose(Pbs,R,Values,V,_P).
+%% Debugger variant: same sampling, with Call/Exit/Fail trace lines printed through $print_call/6.
+$prism_sample_msw(Sw,V,Depth,_CP,CallNo,AR) :-
+    $pp_get_parameters(Sw,Values,Pbs),!,
+    c_get_dg_flag(Flag),
+    $print_call(Flag,' Call: ',(msw(Sw,V):P),Depth,CallNo,AR),   % P is still unbound at this point
+    sumlist(Pbs,Sum),
+    random_uniform(Sum,R),
+    ( $pp_choose(Pbs,R,Values,V,P) ->
+        $print_call(Flag,' Exit: ',(msw(Sw,V):P),Depth,CallNo,AR)
+    ; $print_call(Flag,' Fail: ',msw(Sw,V),Depth,CallNo,AR),
+      fail
+    ).
+%% $pp_choose(Pbs,R,Vs,X,P): X is the first value whose cumulative probability exceeds R; P is its probability.
+$pp_choose(Pbs,R,Vs,X,P) :- $pp_choose(0,Pbs,R,Vs,X,P).
+$pp_choose(CPb,[Pb|Pbs],R,[V|Vs],X,P) :-
+    CPb1 is CPb+Pb,
+    ( R < CPb1 -> X = V, P = Pb
+    ; Pbs = [] -> X = V, P = Pb                     % last value is the fallback when R is not below the total
+    ; $pp_choose(CPb1,Pbs,R,Vs,X,P)
+    ).
+
+%%----------------------------------------
+%%  sampling utils
+
+get_samples(N,G,Gs) :-          % G assumed to never fail
+    $pp_require_positive_integer(N,$msg(1203),get_samples/3),
+    $pp_require_probabilistic_atom(G,$msg(1201),get_samples/3),
+    $pp_get_samples(0,N,G,Gs).
+
+$pp_get_samples(N,N,_,[]) :- !.
+$pp_get_samples(N0,N,G,[G1|Gs]) :-
+    copy_term(G,G1),!,                              % sample a fresh copy so G itself stays unbound
+    sample(G1),
+    N1 is N0 + 1,!,
+    $pp_get_samples(N1,N,G,Gs).
+
+get_samples_c(N,G,Gs) :- get_samples_c(N,G,true,Gs).
+
+get_samples_c(N,G,C,Gs) :-
+    get_samples_c(N,G,C,Gs,[NS,NF]),
+    format("sampling -- #success = ~w~n",[NS]),
+    format("sampling -- #failure = ~w~n",[NF]).
+% get_samples_c/5: PairN is [N,M] or one number used for both; PairG is [S,G] or one goal used for both.
+get_samples_c(PairN,PairG,C,Gs,[NS,NF]) :-
+    ( [N,M] = PairN -> true ; N = PairN, M = PairN ),
+    ( [S,G] = PairG -> true ; S = PairG, G = PairG ),
+    $pp_require_positive_integer_or_infinity(N,$msg(1204),get_samples_c/5),
+    $pp_require_positive_integer(M,$msg(1203),get_samples_c/5),
+    $pp_require_probabilistic_atom(S,$msg(1201),get_samples_c/5),
+    $pp_require_callable(C,$msg(1202),get_samples_c/5),
+    $pp_get_samples_c(0,N,M,S,G,C,Gs,0,NS,0,NF).
+% Stop when the trial counter reaches N (1st clause) or the success count reaches M (2nd clause).
+$pp_get_samples_c(N,N,_ ,_,_,_,[],NS,NS,NF,NF) :- !.
+$pp_get_samples_c(_,_,NS,_,_,_,[],NS,NS,NF,NF) :- !.
+% One trial: sample a copy of S and call the copied C; on success collect the copied G, else count a failure.
+$pp_get_samples_c(N0,N,M,S,G,C,Gs,NS0,NS,NF0,NF) :-
+    copy_term([S,G,C],[S1,G1,C1]),!,
+    ( sample(S1),!,call(C1) ->
+        Gs = [G1|Gs1], NS1 is NS0 + 1, NF1 is NF0
+    ; Gs = Gs1, NS1 is NS0, NF1 is NF0 + 1
+    ),
+    N1 is N0 + 1,!,
+    $pp_get_samples_c(N1,N,M,S,G,C,Gs1,NS1,NS,NF1,NF).
+
+%%----------------------------------------
+% Accepts the atom inf or an integer greater than zero; anything else goes to $pp_raise_on_require/4.
+$pp_require_positive_integer_or_infinity(X,MsgID,Source) :-
+    ( ( X == inf ; integer(X), X > 0 ) ->
+        true
+    ; $pp_raise_on_require([X],MsgID,Source,$pp_error_positive_integer_or_infinity)
+    ).
+% Builds the error term for the check above; its name is what is handed to $pp_raise_on_require/4.
+$pp_error_positive_integer_or_infinity(X,Error) :-
+    X \== inf,
+    ( $pp_error_integer(X,Error)
+    ; X =< 0 -> Error = domain_error(infinity_or_greater_than_zero,X)
+    ).
diff --git a/packages/prism/src/prolog/up/switch.pl b/packages/prism/src/prolog/up/switch.pl
new file mode 100644
index 000000000..f79a6052b
--- /dev/null
+++ b/packages/prism/src/prolog/up/switch.pl
@@ -0,0 +1,844 @@
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%% set_sw/1,set_sw/2: initialize the prob. of MSW
+% set_sw(Sw): set_sw/2 with Dist = default.
+set_sw(Sw) :- set_sw(Sw,default).
+% set_sw(Sw,Dist): validate the switch name, its outcomes and Dist, then delegate to $pp_set_sw/2.
+set_sw(Sw,Dist) :-
+    $pp_require_ground(Sw,$msg(0101),set_sw/2),
+    $pp_require_switch_outcomes(Sw,$msg(0102),set_sw/2),
+    $pp_require_distribution(Dist,$msg(0200),set_sw/2),
+    $pp_set_sw(Sw,Dist).
+ +$pp_set_sw(Sw,Dist) :- + ( $pd_fixed_parameters(Sw) -> $pp_raise_warning($msg(0109),[Sw]) + ; $pp_get_values(Sw,Values), + length(Values,N), + expand_probs(Dist,N,Probs), + ( retract($pd_parameters(Sw,_,_)) -> true ; true), + assert($pd_parameters(Sw,Values,Probs)) + ),!. + +%% set_sw_all(Sw): set parameters to all switches that matches with Sw. + +set_sw_all :- $pp_set_sw_all(_,default). +set_sw_all(Sw) :- $pp_set_sw_all(Sw,default). +set_sw_all(Sw,Dist) :- $pp_set_sw_all(Sw,Dist). + +$pp_set_sw_all(Sw,Dist) :- + findall(Sw,$pp_registered_sw(Sw),Sws), + $pp_set_sw_list(Sws,Dist),!. + +$pp_set_sw_list([],_). +$pp_set_sw_list([Sw|Sws],Dist) :- + set_sw(Sw,Dist),!, + $pp_set_sw_list(Sws,Dist). + +% fix switches +fix_sw(Sw,Dist) :- + $pp_require_ground(Sw,$msg(0101),fix_sw/2), + $pp_require_switch_outcomes(Sw,$msg(0102),fix_sw/2), + $pp_require_distribution(Dist,$msg(0200),fix_sw/2), + $pp_unfix_sw(Sw), + $pp_set_sw(Sw,Dist), + $pp_fix_sw(Sw),!. + +fix_sw(Sw) :- var(Sw),!, + ( get_sw(switch(Sw1,_,_,_)), + fix_sw(Sw1), + fail + ; true + ). +fix_sw(Sw) :- Sw = [_|_],!, + $pp_fix_sw_list(Sw). +fix_sw(Sw) :- + ( $pd_parameters(Sw,_,_), + $pp_fix_sw(Sw), + fail + ; true + ). + +$pp_fix_sw_list([]). +$pp_fix_sw_list([Sw|Sws]) :- + fix_sw(Sw),!, + $pp_fix_sw_list(Sws). + +$pp_fix_sw(Sw) :- + ( $pd_fixed_parameters(Sw) -> true + ; assert($pd_fixed_parameters(Sw)) + ). + +unfix_sw(Sw) :- var(Sw),!, + ( get_sw(switch(Sw1,_,_,_)), + unfix_sw(Sw1), + fail + ; true + ). +unfix_sw(SwList) :- SwList = [_|_],!,$pp_unfix_sw_list(SwList). +unfix_sw(Sw) :- + ( $pd_parameters(Sw,_,_), + $pp_unfix_sw(Sw), + fail + ; true + ). + +$pp_unfix_sw_list([]). +$pp_unfix_sw_list([Sw|Sws]) :- + $pp_unfix_sw(Sw),!, + $pp_unfix_sw_list(Sws). + +$pp_unfix_sw(Sw) :- + ( retract($pd_fixed_parameters(Sw)) -> true ; true). + +% show msw +show_sw :- show_sw(_). + +show_sw(Sw) :- + findall(Sw,$pp_registered_sw(Sw),Sws0), + sort(Sws0,Sws), + $pp_show_sw_list(Sws). + +$pp_show_sw_list([]) :- !. 
+$pp_show_sw_list([Sw|Sws]) :-!, + $pp_show_sw1(Sw),!, + $pp_show_sw_list(Sws). + +% We can assume Sw is ground +$pp_show_sw1(Sw) :- + $pp_get_parameters(Sw,Values,Probs), + format("Switch ~w: ",[Sw]), + ( $pd_fixed_parameters(Sw) -> write('fixed_p:') ; write('unfixed_p:') ), + $pp_show_sw_values(Values,Probs), + nl. + +$pp_show_sw_values([],_Ps). +$pp_show_sw_values([V|Vs],[P|Ps]) :- + format(" ~w (p: ~9f)",[V,P]),!, + $pp_show_sw_values(Vs,Ps). + +get_sw(Sw) :- + get_sw(SwName,Status,Values,Probs), + Sw = switch(SwName,Status,Values,Probs). + +get_sw(Sw,[Status,Values,Probs]) :- + get_sw(Sw,Status,Values,Probs). + +% - Inconsitency of outcome spaces are checked in advance in +% $pp_get_parameters/3 and $pp_get_expectations/3. + +get_sw(Sw,Status,Values,Probs) :- + $pp_get_parameters(Sw,Values,Probs), + ( $pd_fixed_parameters(Sw) -> Status = fixed ; Status = unfixed ). + +get_sw(Sw,Status,Values,Probs,Es) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_expectations(Sw,_,Es), + ( $pd_fixed_parameters(Sw) -> Status = fixed ; Status = unfixed ). + +%% save/restore switch information + +save_sw :- save_sw('Saved_SW'). + +save_sw(File) :- + open(File,write,OutStream), + ( get_sw(SwName,Status,Values,Probs), + format(OutStream,"switch(~q,~q,~q,",[SwName,Status,Values]), + $pp_write_distribution(OutStream,Probs,'['), + format(OutStream,"]).~n",[]), + fail + ; true + ), + close(OutStream),!. + +$pp_write_distribution(_,[],_). +$pp_write_distribution(OutStream,[Prob|Probs],C) :- + format(OutStream,"~w~15e",[C,Prob]),!, + $pp_write_distribution(OutStream,Probs,','). + +restore_sw :- restore_sw('Saved_SW'). + +restore_sw(File) :- + open(File,read,InStream), + repeat, + read(InStream,Switch), + ( Switch == end_of_file + ; Switch = switch(ID,_,_,Params), + set_sw(ID,Params), + fail + ), + close(InStream),!. 
+ +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%% set_sw_{a,d}/1-2: initialize the hyperparameters of MSW + +set_sw_a(Sw) :- set_sw_a(Sw,default). + +set_sw_a(Sw,Spec) :- + $pp_require_ground(Sw,$msg(0101),set_sw_a/2), + $pp_require_switch_outcomes(Sw,$msg(0102),set_sw_a/2), + $pp_require_hyperparameters(Spec,$msg(0208),set_sw_a/2), + $pp_set_sw_a(Sw,Spec). + +$pp_set_sw_a(Sw,Spec) :- + ( $pd_fixed_hyperparameters(Sw) -> $pp_raise_warning($msg(0110),[Sw]) + ; $pp_get_values(Sw,Values), + length(Values,N), + $pp_expand_pseudo_counts(set_sw_a/2,Spec,N,Alphas,Deltas), + ( retract($pd_hyperparameters(Sw,_,_,_)) -> true ; true), + assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)) + ),!. + +set_sw_d(Sw) :- set_sw_d(Sw,default). + +set_sw_d(Sw,Spec) :- + $pp_require_ground(Sw,$msg(0101),set_sw_d/2), + $pp_require_switch_outcomes(Sw,$msg(0102),set_sw_d/2), + $pp_require_hyperparameters(Spec,$msg(0209),set_sw_d/2), + $pp_set_sw_d(Sw,Spec). + +$pp_set_sw_d(Sw,Spec) :- + ( $pd_fixed_hyperparameters(Sw) -> $pp_raise_warning($msg(0110),[Sw]) + ; $pp_get_values(Sw,Values), + length(Values,N), + $pp_expand_pseudo_counts(set_sw_d/2,Spec,N,Alphas,Deltas), + ( retract($pd_hyperparameters(Sw,_,_,_)) -> true ; true), + assert($pd_hyperparameters(Sw,Values,Alphas,Deltas)) + ),!. + +% wrapper for getting alphas and deltas +$pp_expand_pseudo_counts(Caller,Spec,N,Alphas,Deltas) :- + expand_pseudo_counts(Spec,N,Hs), + ( Spec = default -> + ( get_prism_flag(default_sw_a,$disabled) -> Mode = delta + ; Mode = alpha + ) + ; Caller = set_sw_a/2 -> Mode = alpha + ; Mode = delta + ), + ( Mode = alpha -> + Alphas = Hs, + ( $pp_test_positive_numbers(Alphas) -> true + ; $pp_raise_domain_error($msg(0208),[Spec],[alphas,Spec],Caller) + ), % a bit dirty + $pp_alpha_to_delta(Alphas,Deltas) + ; % Mode = delta + Deltas = Hs, + $pp_delta_to_alpha(Deltas,Alphas) + ). + +%% aliases for backward compatibility +set_sw_h(Sw) :- set_sw_d(Sw). 
+set_sw_h(Sw,Spec) :- set_sw_d(Sw,Spec). + +%%% set_sw_all_{a,d}(Sw): +%%% set hyperparameters to all switches that matches with Sw. + +set_sw_all_a :- set_sw_all_a(_). + +set_sw_all_a(Sw) :- set_sw_all_a(Sw,default). + +set_sw_all_a(Sw,Spec) :- + findall(Sw,$pp_registered_sw(Sw),Sws), + $pp_set_sw_a_list(Sws,Spec),!. + +$pp_set_sw_a_list([],_). +$pp_set_sw_a_list([Sw|Sws],Spec) :- + set_sw_a(Sw,Spec),!, + $pp_set_sw_a_list(Sws,Spec). + + +set_sw_all_d :- set_sw_all_d(_). + +set_sw_all_d(Sw) :- set_sw_all_d(Sw,default). + +set_sw_all_d(Sw,Spec) :- + findall(Sw,$pp_registered_sw(Sw),Sws), + $pp_set_sw_d_list(Sws,Spec),!. + +$pp_set_sw_d_list([],_). +$pp_set_sw_d_list([Sw|Sws],Spec) :- + set_sw_d(Sw,Spec),!, + $pp_set_sw_d_list(Sws,Spec). + +%% aliases for backward compatibility + +set_sw_all_h :- set_sw_all_d. +set_sw_all_h(Sw) :- set_sw_all_d(Sw). +set_sw_all_h(Sw,Spec) :- set_sw_all_d(Sw,Spec). + +set_sw_a_all :- set_sw_all_a. +set_sw_a_all(Sw) :- set_sw_all_a(Sw). +set_sw_a_all(Sw,Spec) :- set_sw_all_a(Sw,Spec). + +set_sw_d_all :- set_sw_all_d. +set_sw_d_all(Sw) :- set_sw_all_d(Sw). +set_sw_d_all(Sw,Spec) :- set_sw_all_d(Sw,Spec). + +set_sw_h_all :- set_sw_all_h. +set_sw_h_all(Sw) :- set_sw_all_h(Sw). +set_sw_h_all(Sw,Spec) :- set_sw_all_h(Sw,Spec). + +%%% fix_sw_h(Sw,Spec) :- fix the hyperparameters of Sw at Spec + +fix_sw_a(Sw,Spec) :- + $pp_require_ground(Sw,$msg(0101),fix_sw_a/2), + $pp_require_switch_outcomes(Sw,$msg(0102),fix_sw_a/2), + $pp_require_hyperparameters(Spec,$msg(0208),fix_sw_a/2), + $pp_unfix_sw_h(Sw), + $pp_set_sw_a(Sw,Spec), + $pp_fix_sw_h(Sw),!. + +fix_sw_a(Sw) :- var(Sw),!, + ( get_sw_a(switch(Sw1,_,_,_)), + fix_sw_a(Sw1), + fail + ; true + ). +fix_sw_a(Sw) :- Sw = [_|_],!, + $pp_fix_sw_a_list(Sw). +fix_sw_a(Sw) :- + ( $pd_hyperparameters(Sw,_,_,_), + $pp_fix_sw_h(Sw), + fail + ; true + ),!. + +$pp_fix_sw_a_list([]). +$pp_fix_sw_a_list([Sw|Sws]) :- + fix_sw_a(Sw),!, + $pp_fix_sw_a_list(Sws). 
+ +fix_sw_d(Sw,Spec) :- + $pp_require_ground(Sw,$msg(0101),fix_sw_d/2), + $pp_require_switch_outcomes(Sw,$msg(0102),fix_sw_d/2), + $pp_require_hyperparameters(Spec,$msg(0209),fix_sw_d/2), + $pp_unfix_sw_h(Sw), + $pp_set_sw_d(Sw,Spec), + $pp_fix_sw_h(Sw),!. + +fix_sw_d(Sw) :- var(Sw),!, + ( get_sw_d(switch(Sw1,_,_,_)), + fix_sw_d(Sw1), + fail + ; true + ). +fix_sw_d(Sw) :- Sw = [_|_],!, + $pp_fix_sw_d_list(Sw). +fix_sw_d(Sw) :- + ( $pd_hyperparameters(Sw,_,_,_), + $pp_fix_sw_h(Sw), + fail + ; true + ),!. + +$pp_fix_sw_d_list([]). +$pp_fix_sw_d_list([Sw|Sws]) :- + fix_sw_d(Sw),!, + $pp_fix_sw_d_list(Sws). + +$pp_fix_sw_h(Sw) :- + ( clause($pd_fixed_hyperparameters(Sw),_) -> true + ; assert($pd_fixed_hyperparameters(Sw)) + ). + +%% aliases for backward compatibility + +fix_sw_h(Sw,Spec) :- fix_sw_d(Sw,Spec). +fix_sw_h(Sw) :- fix_sw_d(Sw). + +%%% unfix_sw_{a,d}(Sw) :- unfix the hyperparameters of Sw + +unfix_sw_d(Sw) :- var(Sw),!, + ( get_sw_d(switch(Sw1,_,_,_)), + unfix_sw_d(Sw1), + fail + ; true + ). +unfix_sw_d(SwList) :- SwList = [_|_],!, + $pp_unfix_sw_d_list(SwList). +unfix_sw_d(Sw) :- + ( $pd_hyperparameters(Sw,_,_,_), + $pp_unfix_sw_h(Sw), + fail + ; true + ),!. + +$pp_unfix_sw_d_list([]). +$pp_unfix_sw_d_list([Sw|Sws]) :- + unfix_sw_d(Sw),!, + $pp_unfix_sw_d_list(Sws). + +$pp_unfix_sw_h(Sw) :- + ( retract($pd_fixed_hyperparameters(Sw)) -> true + ; true + ). + +%% aliases + +unfix_sw_a(Sw) :- unfix_sw_d(Sw). +unfix_sw_h(Sw) :- unfix_sw_d(Sw). + +%%% show hyperparameters + +show_sw_a :- show_sw_a(_). + +show_sw_a(Sw) :- + findall(Sw,$pp_registered_sw(Sw),Sws0), + sort(Sws0,Sws), + $pp_show_sw_list_a(Sws). + +$pp_show_sw_list_a([]) :- !. +$pp_show_sw_list_a([Sw|Sws]) :- !, + $pp_show_sw1_a(Sw),!, + $pp_show_sw_list_a(Sws). + +$pp_show_sw1_a(Sw) :- + $pp_get_hyperparameters(Sw,Values,Alphas,_), + format("Switch ~w: ",[Sw]), + ( $pd_fixed_hyperparameters(Sw) -> write('fixed_h:') ; write('unfixed_h:') ), + $pp_show_sw_a_values(Values,Alphas), + nl. 
+ +$pp_show_sw_a_values([],_). +$pp_show_sw_a_values([V|Vs],[A|As]) :- + format(" ~w (a: ~9f)",[V,A]),!, + $pp_show_sw_a_values(Vs,As). + +show_sw_d :- show_sw_d(_). + +show_sw_d(Sw) :- + findall(Sw,$pp_registered_sw(Sw),Sws0), + sort(Sws0,Sws), + $pp_show_sw_list_d(Sws). + +$pp_show_sw_list_d([]) :- !. +$pp_show_sw_list_d([Sw|Sws]) :- !, + $pp_show_sw1_d(Sw),!, + $pp_show_sw_list_d(Sws). + +$pp_show_sw1_d(Sw) :- + $pp_get_hyperparameters(Sw,Values,_,Deltas), + format("Switch ~w: ",[Sw]), + ( $pd_fixed_hyperparameters(Sw) -> write('fixed_h:') ; write('unfixed_h:') ), + $pp_show_sw_d_values(Values,Deltas), + nl. + +$pp_show_sw_d_values([],_). +$pp_show_sw_d_values([V|Vs],[D|Ds]) :- + format(" ~w (d: ~9f)",[V,D]),!, + $pp_show_sw_d_values(Vs,Ds). + +%% aliases + +show_sw_h :- show_sw_d. +show_sw_h(Sw) :- show_sw_d(Sw). + +%%% show both parameters and hyperparameters + +show_sw_pa :- show_sw_pa(_). + +show_sw_pa(Sw) :- + findall(Sw,$pp_registered_sw(Sw),Sws0), + sort(Sws0,Sws), + $pp_show_sw_list_pa(Sws). + +$pp_show_sw_list_pa([]) :- !. +$pp_show_sw_list_pa([Sw|Sws]) :- !, + $pp_show_sw1_pa(Sw),!, + $pp_show_sw_list_pa(Sws). + +$pp_show_sw1_pa(Sw) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,Alphas,_), + format("Switch ~w: ",[Sw]), + ( $pd_fixed_parameters(Sw) -> write('fixed_p,') ; write('unfixed_p,') ), + ( $pd_fixed_hyperparameters(Sw) -> write('fixed_h:') ; write('unfixed_h:') ), + $pp_show_sw_pa_values(Values,Probs,Alphas), + nl,!. + +$pp_show_sw_pa_values([],_,_). + +$pp_show_sw_pa_values([V|Vs],[P|Ps],[A|As]) :- + format(" ~w (p: ~9f, a: ~9f)",[V,P,A]),!, + $pp_show_sw_pa_values(Vs,Ps,As). + +$pp_show_sw_pa_values([V|Vs],[P|Ps],$not_assigned) :- + format(" ~w (p: ~9f, a: n/a)",[V,P]),!, + $pp_show_sw_pa_values(Vs,Ps,$not_assigned). + +$pp_show_sw_pa_values([V|Vs],$not_assigned,[A|As]) :- + format(" ~w (p: n/a, a: ~9f)",[V,A]),!, + $pp_show_sw_pa_values(Vs,$not_assigned,As). + +show_sw_pd :- show_sw_pd(_). 
+ +show_sw_pd(Sw) :- + findall(Sw,$pp_registered_sw(Sw),Sws0), + sort(Sws0,Sws), + $pp_show_sw_list_pd(Sws). + +$pp_show_sw_list_pd([]) :- !. +$pp_show_sw_list_pd([Sw|Sws]) :- !, + $pp_show_sw1_pd(Sw),!, + $pp_show_sw_list_pd(Sws). + +$pp_show_sw1_pd(Sw) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,_,Deltas), + format("Switch ~w: ",[Sw]), + ( $pd_fixed_parameters(Sw) -> write('fixed_p,') ; write('unfixed_p,') ), + ( $pd_fixed_hyperparameters(Sw) -> write('fixed_h:') ; write('unfixed_h:') ), + $pp_show_sw_pd_values(Values,Probs,Deltas), + nl,!. + +$pp_show_sw_pd_values([],_,_). + +$pp_show_sw_pd_values([V|Vs],[P|Ps],[D|Ds]) :- + format(" ~w (p: ~9f, d: ~9f)",[V,P,D]),!, + $pp_show_sw_pd_values(Vs,Ps,Ds). + +$pp_show_sw_pd_values([V|Vs],[P|Ps],$not_assigned) :- + format(" ~w (p: ~9f, d: n/a)",[V,P]),!, + $pp_show_sw_pd_values(Vs,Ps,$not_assigned). + +$pp_show_sw_pd_values([V|Vs],$not_assigned,[D|Ds]) :- + format(" ~w (p: n/a, d: ~9f)",[V,D]),!, + $pp_show_sw_pd_values(Vs,$not_assigned,Ds). + +%% aliases + +show_sw_b :- show_sw_pd. +show_sw_b(Sw) :- show_sw_pd(Sw). + +%%% get switch information including hyperparameters + +get_sw_a(Sw) :- + get_sw_a(SwName,Status,Values,Alphas), + Sw = switch(SwName,Status,Values,Alphas). + +get_sw_a(Sw,[Status,Values,Alphas]) :- get_sw_a(Sw,Status,Values,Alphas). + +get_sw_a(Sw,Status,Values,Alphas) :- + $pp_get_hyperparameters(Sw,Values,Alphas,_), + ( $pd_fixed_hyperparameters(Sw) -> Status = fixed ; Status = unfixed ). + +get_sw_a(Sw,Status,Values,Alphas,Es) :- + $pp_get_hyperparameters(Sw,Values,Alphas,_), + $pp_get_hyperexpectations(Sw,_,Es), + ( $pd_fixed_hyperparameters(Sw) -> Status = fixed ; Status = unfixed ). + +get_sw_d(Sw) :- + get_sw_d(SwName,Status,Values,Deltas), + Sw = switch(SwName,Status,Values,Deltas). + +get_sw_d(Sw,[Status,Values,Deltas]) :- get_sw_d(Sw,Status,Values,Deltas). 
+ +get_sw_d(Sw,Status,Values,Deltas) :- + $pp_get_hyperparameters(Sw,Values,_,Deltas), + ( $pd_fixed_hyperparameters(Sw) -> Status = fixed ; Status = unfixed ). + +get_sw_d(Sw,Status,Values,Deltas,Es) :- + $pp_get_hyperparameters(Sw,Values,_,Deltas), + $pp_get_expectations(Sw,_,Es), + ( $pd_fixed_hyperparameters(Sw) -> Status = fixed ; Status = unfixed ). + +%% aliases + +get_sw_h(Sw) :- get_sw_d(Sw). +get_sw_h(Sw,Info) :- get_sw_d(Sw,Info). +get_sw_h(Sw,Status,Vs,Ds) :- get_sw_d(Sw,Status,Vs,Ds). + +%%% get switch information including both parameters and hyperparameters + +get_sw_pa(Sw) :- + get_sw_pa(SwName,StatusPair,Values,Probs,Alphas), + Sw = switch(SwName,StatusPair,Values,Probs,Alphas). + +get_sw_pa(Sw,[StatusPair,Values,Probs,Alphas]) :- + get_sw_pa(Sw,StatusPair,Values,Probs,Alphas). + +get_sw_pa(Sw,[StatusP,StatusH],Values,Probs,Alphas) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,Alphas,_), + ( $pd_fixed_parameters(Sw) -> StatusP = fixed ; StatusP = unfixed ), + ( $pd_fixed_hyperparameters(Sw) -> StatusH = fixed ; StatusH = unfixed ). + +get_sw_pa(Sw,[StatusP,StatusH],Values,Probs,Alphas,Es) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,Alphas,_), + $pp_get_hyperexpectations(Sw,_,Es), + ( $pd_fixed_parameters(Sw) -> StatusP = fixed ; StatusP = unfixed ), + ( $pd_fixed_hyperparameters(Sw) -> StatusH = fixed ; StatusH = unfixed ). + +get_sw_pd(Sw) :- + get_sw_pd(SwName,StatusPair,Values,Probs,Deltas), + Sw = switch(SwName,StatusPair,Values,Probs,Deltas). + +get_sw_pd(Sw,[StatusPair,Values,Probs,Deltas]) :- + get_sw_pd(Sw,StatusPair,Values,Probs,Deltas). + +get_sw_pd(Sw,[StatusP,StatusH],Values,Probs,Deltas) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,_,Deltas), + ( $pd_fixed_parameters(Sw) -> StatusP = fixed ; StatusP = unfixed ), + ( $pd_fixed_hyperparameters(Sw) -> StatusH = fixed ; StatusH = unfixed ). 
+ +get_sw_pd(Sw,[StatusP,StatusH],Values,Probs,Deltas,Es) :- + $pp_get_parameters(Sw,Values,Probs), + $pp_get_hyperparameters(Sw,_,_,Deltas), + $pp_get_expectations(Sw,_,Es), + ( $pd_fixed_parameters(Sw) -> StatusP = fixed ; StatusP = unfixed ), + ( $pd_fixed_hyperparameters(Sw) -> StatusH = fixed ; StatusH = unfixed ). + +%% aliases + +get_sw_b(Sw) :- get_sw_pd(Sw). +get_sw_b(Sw,Info) :- get_sw_pd(Sw,Info). +get_sw_b(Sw,StatusPH,Vs,Ps,Ds) :- get_sw_pd(Sw,StatusPH,Vs,Ps,Ds). + +%%%% save hyperparameters + +save_sw_a :- save_sw_a('Saved_SW_A'). + +save_sw_a(File) :- + open(File,write,OutStream), + ( get_sw_a(SwName,Status,Values,Alphas), + format(OutStream,"switch(~q,~q,~q,",[SwName,Status,Values]), + $pp_write_hyperparameters(OutStream,Alphas,'['), + format(OutStream,"]).~n",[]), + fail + ; true + ), + close(OutStream),!. + +save_sw_d :- save_sw_d('Saved_SW_D'). + +save_sw_d(File) :- + open(File,write,OutStream), + ( get_sw_d(SwName,Status,Values,Deltas), + format(OutStream,"switch(~q,~q,~q,",[SwName,Status,Values]), + $pp_write_hyperparameters(OutStream,Deltas,'['), + format(OutStream,"]).~n",[]), + fail + ; true + ), + close(OutStream),!. + +$pp_write_hyperparameters(_,[],_). +$pp_write_hyperparameters(OutStream,[H|Hs],C) :- + format(OutStream,"~w~15e",[C,H]),!, + $pp_write_hyperparameters(OutStream,Hs,','). + +%% aliases + +save_sw_h :- save_sw_d. +save_sw_h(File) :- save_sw_d(File). + +%%%% restore hyperparameters + +restore_sw_a :- restore_sw_a('Saved_SW_A'). + +restore_sw_a(File) :- + open(File,read,InStream), + repeat, + read(InStream,Switch), + ( Switch == end_of_file + ; Switch = switch(ID,_,_,Alphas), + set_sw_a(ID,Alphas), + fail + ), + close(InStream),!. + +restore_sw_d :- restore_sw_d('Saved_SW_D'). + +restore_sw_d(File) :- + open(File,read,InStream), + repeat, + read(InStream,Switch), + ( Switch == end_of_file + ; Switch = switch(ID,_,_,Deltas), + set_sw_d(ID,Deltas), + fail + ), + close(InStream),!. + +%% aliases + +restore_sw_h :- restore_sw_d. 
+restore_sw_h(File) :- restore_sw_d(File).
+
+%%%% save both parameters and hyperparameters
+% The zero-argument forms rely on the default file names of save_sw/0, save_sw_a/0 and save_sw_d/0.
+save_sw_pa :- save_sw, save_sw_a.
+% save_sw_pa(FileP,FileA): parameters go to FileP, alpha hyperparameters to FileA.
+save_sw_pa(FileP,FileA) :-
+    save_sw(FileP),
+    save_sw_a(FileA),!.
+
+save_sw_pd :- save_sw, save_sw_d.
+% save_sw_pd(FileP,FileD): parameters go to FileP, delta hyperparameters to FileD.
+save_sw_pd(FileP,FileD) :-
+    save_sw(FileP),
+    save_sw_d(FileD),!.
+
+%% aliases
+
+save_sw_b :- save_sw_pd.
+save_sw_b(FileP,FileD) :- save_sw_pd(FileP,FileD).
+
+%%%% restore both parameters and hyperparameters
+% The zero-argument forms rely on the default file names of restore_sw/0, restore_sw_a/0 and restore_sw_d/0.
+restore_sw_pa :- restore_sw, restore_sw_a.
+% restore_sw_pa(FileP,FileA): parameters are read from FileP, alpha hyperparameters from FileA.
+restore_sw_pa(FileP,FileA) :-
+    restore_sw(FileP),
+    restore_sw_a(FileA),!.
+
+restore_sw_pd :- restore_sw, restore_sw_d.
+% restore_sw_pd(FileP,FileD): parameters are read from FileP, delta hyperparameters from FileD.
+restore_sw_pd(FileP,FileD) :-
+    restore_sw(FileP),
+    restore_sw_d(FileD),!.
+
+%% aliases
+
+restore_sw_b :- restore_sw_pd.
+restore_sw_b(FileP,FileD) :- restore_sw_pd(FileP,FileD).
+
+%%----------------------------------------
+%% [Note]
+%% $pp_get_{values,parameters,expectations,hyperparameters}/2 do not check
+%% the groundness of switch names.
+
+% raises an exception when there are no msw declarations
+% (and can be a replacement of values/2 called in the clause bodies)
+get_values(Sw,Values) :-
+    $pp_require_msw_declaration($msg(0100),get_values/2), % report this predicate, not get_values0/2, as the error source
+    $pp_get_values(Sw,Values).
+
+% provides a simple access to value declarations
+get_values0(Sw,Values) :-
+    current_predicate($pu_values/2),
+    $pp_get_values(Sw,Values).
+
+% deterministically behaves and raises an exception also when there is no msw
+% declaration that matches with Sw
+get_values1(Sw,Values) :-
+    $pp_require_ground(Sw,$msg(0101),get_values1/2),
+    $pp_require_switch_outcomes(Sw,$msg(0102),get_values1/2),
+    $pp_get_values(Sw,Values),!.
+
+% $pu_values/2 = translated values declarations
+$pp_get_values(Sw,Values) :- $pu_values(Sw,Values).
+
+%%----------------------------------------
+%% Wrappers to the switch database
+% $pp_get_parameters(Sw,Values,Probs): parameters of Sw, checked against the declared outcomes when Sw is ground.
+$pp_get_parameters(Sw,Values,Probs) :-
+    ( ground(Sw) ->
+        get_values1(Sw,Values),
+        ( $pd_parameters(Sw,Values0,Probs0) ->
+            ( Values0 = Values -> Probs = Probs0 % stored outcomes still match the declaration
+            ; $pd_fixed_parameters(Sw) -> % outcomes changed but the switch is fixed: cannot reset
+                $pp_raise_runtime_error($msg(0106),[Sw],
+                                        modified_switch_outcomes,
+                                        $pp_get_parameters/3)
+            ; set_sw(Sw,default),!, % outcomes changed: re-register with Dist = default
+              $pd_parameters(Sw,Values,Probs)
+            )
+        ; set_sw(Sw,default),!, % nothing stored yet: register with Dist = default
+          $pd_parameters(Sw,Values,Probs)
+        )
+    ; $pd_parameters(Sw,Values,Probs)
+        % if Sw is not ground, we do not assign the default distribution
+    ).
+% $pp_get_hyperparameters/4: same scheme as $pp_get_parameters/3, on $pd_hyperparameters/4 with set_sw_a/2 as the reset.
+% [Note] set_sw_a(Sw,default) and set_sw_d(Sw,default) behave in the same way
+$pp_get_hyperparameters(Sw,Values,Alphas,Deltas) :-
+    ( ground(Sw) ->
+        get_values1(Sw,Values),
+        ( $pd_hyperparameters(Sw,Values0,Alphas0,Deltas0) ->
+            ( Values0 = Values ->
+                Alphas = Alphas0,
+                Deltas = Deltas0
+            ; $pd_fixed_hyperparameters(Sw) ->
+                $pp_raise_runtime_error($msg(0108),[Sw],
+                                        modified_switch_outcomes,
+                                        $pp_get_hyperparameters/4)
+            ; set_sw_a(Sw,default),!,
+              $pd_hyperparameters(Sw,Values,Alphas,Deltas)
+            )
+        ; set_sw_a(Sw,default),!,
+          $pd_hyperparameters(Sw,Values,Alphas,Deltas)
+        )
+    ; $pd_hyperparameters(Sw,Values,Alphas,Deltas)
+    ).
+% $pp_get_expectations/3: for ground Sw the stored outcome list must equal the declared one, else a runtime error is raised.
+$pp_get_expectations(Sw,Values,Es) :-
+    ( ground(Sw) ->
+        get_values1(Sw,Values),
+        $pd_expectations(Sw,Values0,Es0),
+        ( Values0 = Values -> Es = Es0
+        ; $pp_raise_runtime_error($msg(0107),[Sw],modified_switch_outcomes,
+                                  $pp_get_expectations/3)
+        )
+    ; $pd_expectations(Sw,Values,Es)
+    ).
+% $pp_get_hyperexpectations/3: same check as $pp_get_expectations/3, reading $pd_hyperexpectations/3.
+$pp_get_hyperexpectations(Sw,Values,Es) :-
+    ( ground(Sw) ->
+        get_values1(Sw,Values),
+        $pd_hyperexpectations(Sw,Values0,Es0),
+        ( Values0 = Values -> Es = Es0
+        ; $pp_raise_runtime_error($msg(0107),[Sw],modified_switch_outcomes,
+                                  $pp_get_hyperexpectations/3)
+        )
+    ; $pd_hyperexpectations(Sw,Values,Es)
+    ).
+
+%%----------------------------------------
+% A switch counts as registered when it has an entry in $pd_parameters/3 or in $pd_hyperparameters/4.
+$pp_registered_sw(Sw) :- % ground switch name will be returned
+    ( $pd_parameters(Sw,_,_)
+    ; $pd_hyperparameters(Sw,_,_,_)
+    ).
+% show_reg_sw/0: print the sorted list of registered switches.
+show_reg_sw :-
+    get_reg_sw_list(Sws),
+    $pp_show_reg_sw(Sws).
+
+$pp_show_reg_sw(Sws) :-
+    format("Registered random switches:~n",[]),
+    $pp_show_reg_sw1(Sws).
+
+$pp_show_reg_sw1([]).
+$pp_show_reg_sw1([Sw|Sws]) :-
+    format(" ~w~n",[Sw]),!,
+    $pp_show_reg_sw1(Sws).
+% get_reg_sw(Sw): enumerates the members of the list built by get_reg_sw_list/1.
+get_reg_sw(Sw) :-
+    get_reg_sw_list(Sws),!,
+    member(Sw,Sws).
+% get_reg_sw_list(Sws): all solutions of $pp_registered_sw/1, passed through sort/2.
+get_reg_sw_list(Sws) :-
+    findall(Sw,$pp_registered_sw(Sw),Sws0),
+    sort(Sws0,Sws).
+
+%%----------------------------------------
+% alpha_to_delta(Alphas,Deltas): element-wise Delta = Alpha - 1, after the non-negative-numbers check.
+alpha_to_delta(Alphas,Deltas) :-
+    $pp_require_non_negative_numbers(Alphas,$msg(0208),alpha_to_delta/2),
+    $pp_alpha_to_delta(Alphas,Deltas).
+
+$pp_alpha_to_delta([],[]).
+$pp_alpha_to_delta([Alpha|Alphas],[Delta|Deltas]) :-
+    Delta is Alpha - 1,!,
+    $pp_alpha_to_delta(Alphas,Deltas).
+% delta_to_alpha(Deltas,Alphas): element-wise Alpha = Delta + 1, after the non-negative-numbers check.
+delta_to_alpha(Deltas,Alphas) :-
+    $pp_require_non_negative_numbers(Deltas,$msg(0209),delta_to_alpha/2),
+    $pp_delta_to_alpha(Deltas,Alphas).
+
+$pp_delta_to_alpha([],[]).
+$pp_delta_to_alpha([Delta|Deltas],[Alpha|Alphas]) :-
+    Alpha is Delta + 1,!,
+    $pp_delta_to_alpha(Deltas,Alphas).
diff --git a/packages/prism/src/prolog/up/util.pl b/packages/prism/src/prolog/up/util.pl
new file mode 100644
index 000000000..256e779bb
--- /dev/null
+++ b/packages/prism/src/prolog/up/util.pl
@@ -0,0 +1,923 @@
+%%----------------------------------------
+%% error/warning (obsolete)
+% err_msg/1,2: print "{PRISM ERROR: ...}" (err_msg/2 formats Msg with Vars) and then call abort.
+err_msg(Msg) :-
+    format("{PRISM ERROR: ",[]),write(Msg),format("}~n",[]),!,
+    abort.
+err_msg(Msg,Vars) :-
+    format("{PRISM ERROR: ",[]),format(Msg,Vars),format("}~n",[]),!,
+    abort.
+% warn_msg/1,2: print "{PRISM WARNING: ...}" only when the prism flag warn is on; succeed silently otherwise.
+warn_msg(Msg) :-
+    ( get_prism_flag(warn,on) ->
+        format("{PRISM WARNING: ",[]),write(Msg),format("}~n",[])
+    ; true
+    ).
+warn_msg(Msg,Vars) :-
+    ( get_prism_flag(warn,on) ->
+        format("{PRISM WARNING: ",[]),format(Msg,Vars),format("}~n",[])
+    ; true
+    ).
+ + +%%---------------------------------------- +%% internal utils + +%% probabilistic formulas + +$pp_is_user_probabilistic_atom(Goal) :- + callable(Goal), + functor(Goal,F,N), + $pd_is_prob_pred(F,N),!. + +$pp_is_probabilistic_atom(Goal) :- + ( nonvar(Goal), Goal ?= msw(_,_) + ; $pp_is_user_probabilistic_atom(Goal) + ),!. + +$pp_is_extended_probabilistic_atom(Goal) :- + ( $pp_is_probabilistic_atom(Goal) + ; $pp_is_dummy_goal(Goal) + ),!. + +$pp_is_probabilistic_callable(Goal) :- + callable(Goal), + $pp_is_probabilistic_callable_aux(Goal),!. + +$pp_is_probabilistic_callable_aux((G1,G2)) => + ( $pp_is_probabilistic_callable_aux(G1),callable(G2) + ; callable(G1),$pp_is_probabilistic_callable_aux(G2) + ). +$pp_is_probabilistic_callable_aux((G1;G2)) => + ( $pp_is_probabilistic_callable_aux(G1),callable(G2) + ; callable(G1),$pp_is_probabilistic_callable_aux(G2) + ). +$pp_is_probabilistic_callable_aux((C->A;B)) => + ( $pp_is_probabilistic_callable_aux(C),callable(A),callable(B) + ; callable(C),$pp_is_probabilistic_callable_aux(A),callable(B) + ; callable(C),callable(A),$pp_is_probabilistic_callable_aux(B) + ). +$pp_is_probabilistic_callable_aux(not(G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux(\+(G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((C->A)) => + ( $pp_is_probabilistic_callable_aux(C),callable(A) + ; callable(C),$pp_is_probabilistic_callable_aux(A) + ). +$pp_is_probabilistic_callable_aux(write_call(G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux(write_call(_Opts,G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((?? G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((??* G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((??> G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((??< G)) => + $pp_is_probabilistic_callable_aux(G). 
+$pp_is_probabilistic_callable_aux((??+ G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux((??- G)) => + $pp_is_probabilistic_callable_aux(G). +$pp_is_probabilistic_callable_aux(G) :- + $pp_is_extended_probabilistic_atom(G). + +%% tabled probabilistic formulas + +$pp_is_tabled_probabilistic_atom(Goal) :- + callable(Goal), + functor(Goal,F,N), + $pd_is_tabled_pred(F,N),!. + +%% goals that can be handled with the write_call predicates + +$pp_is_write_callable(Goal) :- + ( Goal = '!' -> fail + ; Goal = (A,B) -> $pp_is_write_callable(A), $pp_is_write_callable(B) + ; Goal = (_;_) -> fail + ; Goal = \+(_) -> fail + ; Goal = not(_) -> fail + ; Goal = (_->_) -> fail + ; true + ). + +%% dummy goals + +$pp_create_dummy_goal(DummyGoal) :- + global_get($pg_dummy_goal_count,N0), + N1 is N0 + 1, + global_set($pg_dummy_goal_count,N1),!, + $pp_create_dummy_goal(N0,DummyGoal),!. + +$pp_create_dummy_goal(N,DummyGoal) :- + number_chars(N,NChars), + append(['$',p,d,'_',d,u,m,m,y],NChars,DummyGoalChars), + atom_chars(DummyGoal,DummyGoalChars). + +$pp_is_dummy_goal(G) :- + atom(G), + atom_chars(G,GChars), + GChars = ['$',p,d,'_',d,u,m,m,y|_]. + +%% option analyzer + +$pp_proc_opts(Opts,Pred,Vars,Defaults,Source) :- + $pp_require_list_or_nil(Opts,$msg(2109),Source), + $pp_proc_opts_core(Opts,Pred,Vars,Defaults,Source). + +$pp_proc_opts_core([],_,[],[],_Source) :- !. +$pp_proc_opts_core([],Pred,[Var|Vars],[Default|Defaults],Source) :- + ( Var = Default ; true ),!, + $pp_proc_opts_core([],Pred,Vars,Defaults,Source). +$pp_proc_opts_core([Opt|Opts],Pred,Vars,Defaults,Source) :- + nonvar(Opt), + Clause =.. [Pred,Opt,Pos,Val], + call(Clause), + nth1(Pos,Vars,Var), + ( var(Var) -> Var = Val + ; $pp_raise_runtime_error($msg(3003),[Opt],duplicate_option, + Source) + ),!, + $pp_proc_opts_core(Opts,Pred,Vars,Defaults,Source). +$pp_proc_opts_core([Opt|_],_,_,_,Source) :- + $pp_raise_runtime_error($msg(3002),[Opt],unknown_option,Source). 
+
+%% sorting with duplicate elements remained
+
+$pp_sort_remain_dup(L0,L) :- sort('=<',L0,L).
+
+
+%%----------------------------------------
+%% statistics
+% show_goals/0: print every observed goal with its count and frequency; raises if no observations are stored.
+show_goals :-
+    global_get($pg_observed_facts,GoalCountPairs0),!,
+    sort(GoalCountPairs0,GoalCountPairs),
+    $pp_find_total_count(GoalCountPairs,0,Total),
+    $pp_show_goals(GoalCountPairs,Total).
+show_goals :-
+    $pp_raise_runtime_error($msg(3004),observation_not_found,show_goals/0).
+% $pp_find_total_count(Pairs,Total0,Total): Total is Total0 plus the counts of all goal(_,Count) entries.
+$pp_find_total_count([],Total,Total).
+$pp_find_total_count([goal(_Goal,Count)|GoalCountPairs],Total0,Total) :-
+    Total1 is Total0 + Count,!,
+    $pp_find_total_count(GoalCountPairs,Total1,Total).
+% Prints one line per goal; a name found in $pd_dummy_goal_table/2 is replaced by the goal stored there.
+$pp_show_goals([],Total) :- format("Total_count=~w~n",[Total]).
+$pp_show_goals([goal(DummyGoal,Count)|GoalCountPairs],Total) :-
+    P is Count / Total * 100,
+    ( current_predicate($pd_dummy_goal_table/2),
+      $pd_dummy_goal_table(DummyGoal,Goal)
+        -> true
+    ; Goal = DummyGoal
+    ),
+    format("Goal ~w (count=~w, freq=~3f%)~n",[Goal,Count,P]),
+    $pp_show_goals(GoalCountPairs,Total).
+% get_goals(Gs): sorted list of the observed goals.
+get_goals(Gs) :-
+    findall(Goal,$pp_get_one_goal(Goal),Gs0),
+    sort(Gs0,Gs).
+% Enumerates observed goals for get_goals/1, which is therefore named as the source of the error.
+$pp_get_one_goal(Goal) :-
+    ( global_get($pg_observed_facts,GoalCountPairs) ->
+        $pp_get_one_goal(Goal,GoalCountPairs)
+    ; $pp_raise_runtime_error($msg(3004),observation_not_found,get_goals/1)
+    ).
+% NOTE(review): the 2nd clause also yields the stored name when the 1st clause translated it -- confirm intended.
+$pp_get_one_goal(Goal,[goal(DummyGoal,_Count)|_]) :-
+    current_predicate($pd_dummy_goal_table/2),
+    $pd_dummy_goal_table(DummyGoal,Goal).
+$pp_get_one_goal(Goal,[goal(Goal,_Count)|_]).
+$pp_get_one_goal(Goal,[_|Pairs]) :- $pp_get_one_goal(Goal,Pairs).
+% get_goal_counts(GCounts): sorted list of [Goal,Count,Freq] triples, Freq being a percentage of the total count.
+get_goal_counts(GCounts) :-
+    findall([Goal,Count,Freq],$pp_get_one_goal_count(Goal,Count,Freq),GCounts0),
+    sort(GCounts0,GCounts).
+% Enumerates the triples for get_goal_counts/1, which is therefore named as the source of the error.
+$pp_get_one_goal_count(Goal,Count,Freq) :-
+    ( global_get($pg_observed_facts,GoalCountPairs) ->
+        $pp_find_total_count(GoalCountPairs,0,Total),
+        $pp_get_one_goal_count(Goal,Count,Freq,GoalCountPairs,Total)
+    ; $pp_raise_runtime_error($msg(3004),observation_not_found,get_goal_counts/1)
+    ).
+
+%% $pp_get_one_goal_count(Goal,Count,Freq,GoalCountPairs,Total):
+%% enumerates on backtracking the entries of GoalCountPairs; Freq is
+%% Count / Total * 100.  The first clause translates a dummy goal through
+%% $pd_dummy_goal_table/2 when that predicate exists.
+$pp_get_one_goal_count(Goal,Count,Freq,[goal(DummyGoal,Count)|_],Total) :-
+    current_predicate($pd_dummy_goal_table/2),
+    $pd_dummy_goal_table(DummyGoal,Goal),
+    Freq is Count / Total * 100.
+$pp_get_one_goal_count(Goal,Count,Freq,[goal(Goal,Count)|_],Total) :-
+    Freq is Count / Total * 100.
+$pp_get_one_goal_count(Goal,Count,Freq,[_|Pairs],Total) :-
+    $pp_get_one_goal_count(Goal,Count,Freq,Pairs,Total).
+
+%% prism_statistics(Name,L): enumerates the graph, learning and inference
+%% statistics in that order.
+prism_statistics(Name,L) :-
+    ( graph_statistics(Name,L)
+    ; learn_statistics(Name,L)
+    ; infer_statistics(Name,L)
+    ).
+
+%% graph_statistics(Name,L): fails altogether when no $ps_num_subgraphs/1
+%% fact is available; otherwise enumerates the statistics named below.
+graph_statistics(Name,L) :-
+    ( \+ $ps_num_subgraphs(_) -> fail
+    ; Name = num_subgraphs,
+        ( $ps_num_subgraphs(L) -> true )
+    ; Name = num_nodes,
+        ( $ps_num_nodes(L) -> true )
+    ; Name = num_goal_nodes,
+        ( $ps_num_goal_nodes(L) -> true )
+    ; Name = num_switch_nodes,
+        ( $ps_num_switch_nodes(L) -> true )
+    ; Name = avg_shared,
+        ( $ps_avg_shared(L) -> true )
+    ).
+
+%% learn_statistics(Name,L): fails altogether when no $ps_learn_time/1 fact
+%% is available.  Each alternative is skipped when its underlying fact is
+%% missing, because an if-then without an else fails.
+learn_statistics(Name,L) :-
+    ( \+ $ps_learn_time(_) -> fail
+    ; Name = log_likelihood,
+        ( $ps_log_likelihood(L) -> true )
+    ; Name = log_post,
+        ( $ps_log_post(L) -> true )
+    ; Name = log_prior,   % derived: log_post - log_likelihood
+        ( $ps_log_post(LPost), $ps_log_likelihood(LogLike) -> L is LPost - LogLike )
+    ; Name = lambda,      % log_post if present, otherwise log_likelihood
+        ( ( $ps_log_post(L) ; $ps_log_likelihood(L) ) -> true )
+    ; Name = num_switches,
+        ( $ps_num_switches(L) -> true )
+    ; Name = num_switch_values,
+        ( $ps_num_switch_values(L) -> true )
+    ; Name = num_parameters,   % derived: switch values minus switches
+        ( $ps_num_switches(N0), $ps_num_switch_values(N1) -> L is N1 - N0 )
+    ; Name = num_iterations,
+        ( $ps_num_iterations(L) -> true )
+    ; Name = num_iterations_vb,
+        ( $ps_num_iterations_vb(L) -> true )
+    ; Name = goals,
+        ( is_global($pg_observed_facts) -> get_goals(L) )
+    ; Name = goal_counts,
+        ( is_global($pg_observed_facts) -> get_goal_counts(L) )
+    ; Name = bic,
+        ( $ps_bic_score(L) -> true )
+    ; Name = cs,
+        ( $ps_cs_score(L) -> true )
+    ; Name = free_energy,
+        ( $ps_free_energy(L) -> true )
+    ; Name = learn_time,
+        ( $ps_learn_time(L) -> true )
+    ; Name = learn_search_time,
+        ( $ps_learn_search_time(L) -> true )
+    ; Name = em_time,
+        ( $ps_em_time(L) -> true )
+    ).
+
+%% infer_statistics(Name,L): fails altogether when no $ps_infer_time/1 fact
+%% is available.
+infer_statistics(Name,L) :-
+    ( \+ $ps_infer_time(_) -> fail
+    ; Name = infer_time,
+        ( $ps_infer_time(L) -> true )
+    ; Name = infer_search_time,
+        ( $ps_infer_search_time(L) -> true )
+    ; Name = infer_calc_time,
+        ( $ps_infer_calc_time(L) -> true )
+    ).
+
+%% The four printers below share one shape: print a header, then run a
+%% failure-driven loop over the matching */2 predicate; they always succeed.
+prism_statistics :-
+    format("Statistics in PRISM:~n",[]),!,
+    ( prism_statistics(Name,L),
+      $pp_print_one_statistic(Name,L),
+      fail
+    ; true
+    ),!.
+
+learn_statistics :-
+    format("Statistics on learning:~n",[]),!,
+    ( learn_statistics(Name,L),
+      $pp_print_one_statistic(Name,L),
+      fail
+    ; true
+    ),!.
+
+graph_statistics :-
+    format("Statistics on the size of the explanation graphs:~n",[]),!,
+    ( graph_statistics(Name,L),
+      $pp_print_one_statistic(Name,L),
+      fail
+    ; true
+    ),!.
+
+infer_statistics :-
+    format("Statistics on inference:~n",[]),!,
+    ( infer_statistics(Name,L),
+      $pp_print_one_statistic(Name,L),
+      fail
+    ; true
+    ),!.
+
+%% $pp_print_one_statistic(Name,L): goals and goal_counts are not printed
+%% inline (the user is pointed to show_goals/0); floats use ~9g.
+$pp_print_one_statistic(Name,L) :-
+    ( Name = goals -> format(" ~w~24|: (run show_goals/0)~n",[Name])
+    ; Name = goal_counts -> format(" ~w~24|: (run show_goals/0)~n",[Name])
+    ; float(L) -> format(" ~w~24|: ~9g~n",[Name,L])
+    ; format(" ~w~24|: ~w~n",[Name,L])
+    ).
+
+%%----------------------------------------
+%% clause list reader/writer
+
+%% load_clauses(FileName,Clauses): load_clauses/3 with no options.
+load_clauses(FileName,Clauses) :-
+    load_clauses(FileName,Clauses,[]).
+
+%% load_clauses/4 raises warning 3300 naming load_clauses/3, then
+%% delegates to load_clauses/3.
+load_clauses(FileName,Clauses,From,Size) :-
+    $pp_raise_warning($msg(3300),[load_clauses/4,load_clauses/3]),
+    load_clauses(FileName,Clauses,[from(From),size(Size)]).
+
+%% load_clauses(FileName,Clauses,Opts): reads terms from FileName.
+%% Defaults: From = 0, Size = max.
+%% NOTE(review): the stream is not closed if the core predicate fails or
+%% raises -- confirm whether that matters for callers.
+load_clauses(FileName,Clauses,Opts) :-
+    $pp_require_atom(FileName,$msg(3000),load_clauses/3),
+    $pp_proc_opts(Opts,$load_clauses_option,
+                  [From,Size],
+                  [0 ,max ],
+                  load_clauses/3),
+    open(FileName,read,Stream),
+    $pp_load_clauses_core(Stream,Clauses,From,Size),
+    close(Stream),!.
+
+%% Option table: $load_clauses_option(Option,SlotIndex,Value).
+$load_clauses_option(from(N),1,N) :-
+    integer(N),N >= 0.
+%% skip(N) fills option slot 1; size(N) accepts a non-negative integer or
+%% the atom max.
+$load_clauses_option(skip(N),1,N) :-
+    integer(N),N >= 0.
+$load_clauses_option(size(N),2,N) :-
+    integer(N),N >= 0 ; N == max.
+
+%% $pp_load_clauses_core(Stream,Clauses,K,N): K counts the terms still to
+%% be skipped, N the terms still to be collected (or max).  When reading
+%% stops before K and N are exhausted, warning 3008 is raised.
+$pp_load_clauses_core(_,[],_,0).
+$pp_load_clauses_core(S,Xs,K,N) :-
+    $pp_load_clauses_read(S,X),!,
+    ( K > 0 -> Xs = Xs1, K1 is K - 1, N1 = N
+    ; N == max -> Xs = [X|Xs1], K1 = K, N1 = N
+    ; Xs = [X|Xs1], K1 = K, N1 is N - 1
+    ),!,
+    $pp_load_clauses_core(S,Xs1,K1,N1).
+$pp_load_clauses_core(_,[],K,N) :-
+    ( K =< 0, N == max -> true
+    ; $pp_raise_warning($msg(3008))
+    ).
+
+%% Reads one term; fails at end of file.
+$pp_load_clauses_read(S,X) :-
+    read(S,X),!,X \== end_of_file.
+
+%% save_clauses(FileName,Clauses): save_clauses/3 with no options.
+save_clauses(FileName,Clauses) :-
+    save_clauses(FileName,Clauses,[]).
+
+%% save_clauses/4 raises warning 3300 naming save_clauses/3, then
+%% delegates to save_clauses/3.
+save_clauses(FileName,Clauses,From,Size) :-
+    $pp_raise_warning($msg(3300),[save_clauses/4,save_clauses/3]),
+    save_clauses(FileName,Clauses,[from(From),size(Size)]).
+
+%% save_clauses(FileName,Clauses,Opts): writes the selected clauses with
+%% ~q followed by a period.  Shares its option table with load_clauses/3.
+save_clauses(FileName,Clauses,Opts) :-
+    $pp_require_atom(FileName,$msg(3000),save_clauses/3),
+    $pp_require_list_or_nil(Clauses,$msg(2109),save_clauses/3),
+    $pp_proc_opts(Opts,$load_clauses_option,
+                  [From,Size],
+                  [0 ,max ],
+                  save_clauses/3),
+    open(FileName,write,Stream),
+    $pp_save_clauses_core(Stream,Clauses,From,Size),
+    close(Stream),!.
+
+%% $pp_save_clauses_core(Stream,Clauses,K,N): same K/N bookkeeping as the
+%% loader; warning 3008 when the list ends before K and N are exhausted.
+$pp_save_clauses_core(_,_,_,0) :- !.
+$pp_save_clauses_core(S,[X|Xs1],K,N) :-
+    ( K > 0 -> K1 is K-1, N1 = N
+    ; N == max -> format(S,"~q.~n",[X]), K1 = K, N1 = N
+    ; format(S,"~q.~n",[X]), K1 = K, N1 is N-1
+    ),!,
+    $pp_save_clauses_core(S,Xs1,K1,N1).
+$pp_save_clauses_core(_,[],K,N) :-
+    ( K =< 0, N == max -> true
+    ; $pp_raise_warning($msg(3008))
+    ),!.
+
+%%----------------------------------------
+%% csv loader [RFC 4180]
+
+%% load_csv(FileName,Rows): load_csv/3 with no options.
+load_csv(FileName,Rows) :-
+    load_csv(FileName,Rows,[]).
+
+%% load_csv(FileName,Rows,Opts): reads FileName as CSV into Rows.
+%% Option slots and defaults:
+%%   1 row_from/row_skip (0)    2 row_size (max)
+%%   3 col_from/col_skip (0)    4 col_size (max)
+%%   5 pred (csvrow/1)          6 parse_number (1 = yes)
+%%   7 double_quote (34 = '"')  8 comment (none)
+%%   9 missing (unbound, i.e. no missing-value marker)
+load_csv(FileName,Rows,Opts) :-
+    $pp_require_atom(FileName,$msg(3000),load_csv/3),
+    $pp_proc_opts(Opts,$pp_load_csv_option,
+                  [RFrom,RSize,CFrom,CSize,Pred,Conv,Quot,Cmnt,Miss],
+                  [0,max,0,max,csvrow/1,1,34,none,_],
+                  load_csv/3),
+    open(FileName,read,Stream),
+    $pp_load_csv_core(Stream,Rows,RFrom,RSize,CFrom,CSize,Pred,Conv,Quot,Cmnt,Miss),
+    close(Stream),!.
+
+$pp_load_csv_option(row_from(N),1,N) :-
+    integer(N),N >= 0.
+$pp_load_csv_option(row_skip(N),1,N) :-
+    integer(N),N >= 0.
+$pp_load_csv_option(row_size(N),2,N) :-
+    integer(N),N >= 0 ; N == max.
+$pp_load_csv_option(col_from(N),3,N) :-
+    integer(N),N >= 0.
+$pp_load_csv_option(col_skip(N),3,N) :-
+    integer(N),N >= 0.
+$pp_load_csv_option(col_size(N),4,N) :-
+    integer(N),N >= 0 ; N == max.
+
+%% pred([]) keeps each row as a plain list (style 0); pred(P) or pred(P/1)
+%% wraps the row list as P(Row); pred(P/n) spreads the fields as arguments.
+$pp_load_csv_option(pred(X),5,Pred) :-
+    ( X == [] -> Pred = []/0
+    ; atom(X) -> Pred = X/1
+    ; X = P/N -> atom(P),(N == 1;N == n),Pred = P/N
+    ).
+
+$pp_load_csv_option(parse_number(X),6,Flag) :-
+    ( X == yes -> Flag = 1 ; X == no -> Flag = 0 ).
+
+$pp_load_csv_option(double_quote(X),7,Code) :-
+    ( X == yes -> Code = 34 ; X == no -> Code = none ).
+
+%% comment(C) takes a one-character atom; bare comment means '#' (35).
+$pp_load_csv_option(comment(X),8,Code) :-
+    atom(X),atom_length(X,1),char_code(X,Code).
+$pp_load_csv_option(comment,8,35).
+
+%% The missing-value marker is kept as a code list, because
+%% $pp_load_csv_done/4 unifies it with the code list of each field.
+%% Bare missing therefore has to be the empty code list []; the atom ''
+%% used before could never unify with a field, so the option had no effect.
+$pp_load_csv_option(missing(X),9,Codes) :-
+    atom(X),atom_codes(X,Codes).
+$pp_load_csv_option(missing,9,[]).
+
+%% $pp_load_csv_core(Stream,Rows,K,N,J,M,Pred,Conv,Quot,Cmnt,Miss):
+%% K rows are skipped first, then N rows (or max) are collected; J and M
+%% select the columns of each row.
+$pp_load_csv_core(_,[],_,0,_,_,_,_,_,_,_).
+$pp_load_csv_core(S,Xs,K,N,J,M,Pred,Conv,Quot,Cmnt,Miss) :-
+    $pp_load_csv_read(S,Row0,Conv,Quot,Cmnt,Miss),!,
+    $pp_load_csv_extract(Row0,Row,J,M),
+    Pred = Name/Style,
+    ( Style == 0 -> X = Row
+    ; Style == 1 -> X =.. [Name,Row]
+    ; Style == n -> X =.. [Name|Row]
+    ),
+    ( K > 0 -> Xs = Xs1, K1 is K - 1, N1 = N
+    ; N == max -> Xs = [X|Xs1], K1 = K, N1 = N
+    ; Xs = [X|Xs1], K1 = K, N1 is N-1
+    ),!,
+    $pp_load_csv_core(S,Xs1,K1,N1,J,M,Pred,Conv,Quot,Cmnt,Miss).
+$pp_load_csv_core(_,[],K,N,_,_,_,_,_,_,_) :- + ( K =< 0, N == max -> true + ; $pp_raise_runtime_error($msg(3005),invalid_csv_format,load_csv/3) + ). + +$pp_load_csv_extract(Row0,Row1,J,M), M == max => + $pp_load_csv_extract_step1(Row0,Row1,J). +$pp_load_csv_extract(Row0,Row2,J,M), M \== max => + $pp_load_csv_extract_step1(Row0,Row1,J), + $pp_load_csv_extract_step2(Row1,Row2,M). + +$pp_load_csv_extract_step1(Xs,Xs,0). +$pp_load_csv_extract_step1([_|Xs],Ys,J) :- + J1 is J-1,!,$pp_load_csv_extract_step1(Xs,Ys,J1). +$pp_load_csv_extract_step1(_,_,_) :- + $pp_raise_runtime_error($msg(3006),invalid_csv_format,load_csv/3). + +$pp_load_csv_extract_step2(_,[],0). +$pp_load_csv_extract_step2([Z|Xs],[Z|Ys],M) :- + M1 is M-1,!,$pp_load_csv_extract_step2(Xs,Ys,M1). +$pp_load_csv_extract_step2(_,_,_) :- + $pp_raise_runtime_error($msg(3006),invalid_csv_format,load_csv/3). + +$pp_load_csv_read(S,Row,Conv,Quot,Cmnt,Miss) :- + $pp_load_csv_skip(S,Cmnt),!,$pp_load_csv_q0(S,Conv,Miss,Quot,Row-[],Any-Any). + +$pp_load_csv_skip(S,Cm) :- + peek_code(S,Code), + ( Code == -1 -> fail + ; Code == Cm -> $pp_load_csv_skip(S),!,$pp_load_csv_skip(S,Cm) + ; true + ). + +$pp_load_csv_skip(S) :- + get_code(S,Code), + ( Code =:= -1 -> fail + ; Code =:= 10 -> true + ; Code =:= 13 -> $pp_load_csv_crlf(S) + ; $pp_load_csv_skip(S) + ). + +$pp_load_csv_crlf(S) :- + ( peek_code(S,10) -> get_code(S,10) ; true ). + +%% 3rd arg. = parse numeric values? +%% 4th arg. = missing value + +$pp_load_csv_done(_,Codes-[],_,M) :- + nonvar(M),Codes = M,!. +$pp_load_csv_done(Value,Codes-[],1,_) :- + forall(member(Code,Codes),(32= % EOF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! + ; Code == 10 -> % LF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! 
+ ; Code == 13 -> % CR + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],!,$pp_load_csv_crlf(S) + ; Code == 44 -> % , + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs1],!, + $pp_load_csv_q0(S,Cv,Ms,Dq,Xs1-Xs0,Any-Any) + ; Code == Dq -> % " + !,$pp_load_csv_q2(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys0) + ; % ELSE + Ys0 = [Code|Ys1],!,$pp_load_csv_q1(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys1) + ). + +$pp_load_csv_q1(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys0) :- + get_code(S,Code), + ( Code == -1 -> % EOF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! + ; Code == 10 -> % LF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! + ; Code == 13 -> % CR + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],!,$pp_load_csv_crlf(S) + ; Code == 44 -> % , + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs1],!, + $pp_load_csv_q0(S,Cv,Ms,Dq,Xs1-Xs0,Any-Any) + ; Code == Dq -> % " + close(S),!, + $pp_raise_runtime_error($msg(3007),invalid_csv_format,load_csv/3) + ; % ELSE + Ys0 = [Code|Ys1],!,$pp_load_csv_q1(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys1) + ). + +$pp_load_csv_q2(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys0) :- + get_code(S,Code), + ( Code == -1 -> % EOF + close(S),!, + $pp_raise_runtime_error($msg(3007),invalid_csv_format,load_csv/3) + ; Code == Dq -> % " + !,$pp_load_csv_q3(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys0) + ; % ELSE + Ys0 = [Code|Ys1],!,$pp_load_csv_q2(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys1) + ). + +$pp_load_csv_q3(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys0) :- + get_code(S,Code), + ( Code == -1 -> % EOF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! + ; Code == 10 -> % LF + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],! + ; Code == 13 -> % CR + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs0],!,$pp_load_csv_crlf(S) + ; Code == 44 -> % , + $pp_load_csv_done(X,Ys-Ys0,Cv,Ms),Xs = [X|Xs1],!, + $pp_load_csv_q0(S,Cv,Ms,Dq,Xs1-Xs0,Any-Any) + ; Code == Dq -> % " + Ys0 = [Code|Ys1],!,$pp_load_csv_q2(S,Cv,Ms,Dq,Xs-Xs0,Ys-Ys1) + ; % ELSE + close(S),!, + $pp_raise_runtime_error($msg(3007),invalid_csv_format,load_csv/3) + ). 
+
+
+%%----------------------------------------
+%% pretty e-graph printer
+
+%% print_graph(G): prints G on the current output with default options.
+print_graph(G) :-
+    current_output(S),print_graph(S,G,[]).
+
+%% print_graph(G,Opts): prints G on the current output.
+print_graph(G,Opts) :-
+    current_output(S),print_graph(S,G,Opts).
+
+%% print_graph(S,G,Opts): prints the explanation graph G (a list of node/2
+%% or node/3 terms) on stream S.  Options: lr(T), and(T), or(T) set the
+%% connectives (defaults "", "&", "v").  With the default lr, a ":" is
+%% appended to the head of each node that has paths.  The lr and or
+%% connectives are left-padded with blanks to the same width.
+%% An invalid option is now reported against print_graph/3 (the context
+%% term used to be misspelled pring_graph/3).
+print_graph(S,G,Opts) :-
+    $pp_require_list(G,$msg(2104),print_graph/3),
+    $pp_proc_opts(Opts,$pp_print_graph_option,
+                  [Lr0,And,Or0],
+                  ["" ,"&","v"],
+                  print_graph/3),!,
+    ( Lr0 == "" -> Colon = ":" ; Colon = "" ),
+    length(Lr0,LenLr),
+    length(Or0,LenOr),
+    PadLr is LenOr-LenLr,$pp_print_graph_pad(Lr0,Lr,PadLr),
+    PadOr is LenLr-LenOr,$pp_print_graph_pad(Or0,Or,PadOr),!,
+    $pp_print_graph_roots(S,G,Colon,Lr,And,Or).
+
+$pp_print_graph_option(lr(T) ,1,S) :- $pp_print_graph_optarg(T,S).
+$pp_print_graph_option(and(T),2,S) :- $pp_print_graph_optarg(T,S).
+$pp_print_graph_option(or(T) ,3,S) :- $pp_print_graph_optarg(T,S).
+
+%% An option argument is an atom (converted to codes) or a code list.
+%% NOTE(review): in this copy of the patch the text between "0=" and
+%% "T = S" was lost (markup stripping ate "<...>"); the range test
+%% 0 =< X =< 255 is a reconstruction -- confirm the upper bound upstream.
+$pp_print_graph_optarg(T,S) :-
+    ( atom(T) -> atom_codes(T,S)
+    ; length(T,_),forall(member(X,T),(integer(X),0=<X,X=<255)) -> T = S
+    ).
+
+%% $pp_print_graph_pad(Xs,Ys,N): Ys is Xs preceded by N blanks (N =< 0
+%% leaves Xs unchanged).
+$pp_print_graph_pad(Xs,Ys,N), N =< 0 => Xs = Ys.
+$pp_print_graph_pad(Xs,Ys,N), N > 0 =>
+    Ys = [32|Ys1], N1 is N-1, !, $pp_print_graph_pad(Xs,Ys1,N1).
+
+%% $pp_print_graph_roots(S,Nodes,Colon,Lr,And,Or): prints each node.
+%% node(L,Paths) is the plain form; node(L,Paths,V) carries a value V
+%% (a number or a pair [V1,V2]) printed in brackets after the label.
+$pp_print_graph_roots(_,[],_,_,_,_).
+$pp_print_graph_roots(S,[node(L,[])|Nodes],Colon,Lr,And,Or) :-
+    format(S,"~w~n",[L]),!,
+    $pp_print_graph_roots(S,Nodes,Colon,Lr,And,Or).
+$pp_print_graph_roots(S,[node(L,Paths)|Nodes],Colon,Lr,And,Or) :-
+    format(S,"~w~s~n",[L,Colon]),
+    $pp_print_graph_paths(S,Paths,Lr,And,Or),!,
+    $pp_print_graph_roots(S,Nodes,Colon,Lr,And,Or).
+$pp_print_graph_roots(S,[node(L,[],V)|Nodes],Colon,Lr,And,Or) :-
+    ( V = [V1,V2] ->
+        format(S,"~w [~6g,~6g]~n",[L,V1,V2])
+    ; format(S,"~w [~6g]~n",[L,V])
+    ),!,
+    $pp_print_graph_roots(S,Nodes,Colon,Lr,And,Or).
+$pp_print_graph_roots(S,[node(L,Paths,V)|Nodes],Colon,Lr,And,Or) :-
+    ( V = [V1,V2] ->
+        format(S,"~w [~6g,~6g]~s~n",[L,V1,V2,Colon])
+    ; format(S,"~w [~6g]~s~n",[L,V,Colon])
+    ),!,
+    $pp_print_graph_paths_aux(S,Paths,Lr,And,Or),!,
+    $pp_print_graph_roots(S,Nodes,Colon,Lr,And,Or).
+
+%% $pp_print_graph_paths(S,Paths,Conn,And,Or): prints one line per
+%% path(TNodes,SNodes).  The first line is introduced by Conn, every
+%% later one by Or.  A single empty path prints nothing.
+$pp_print_graph_paths(_,[],_,_,_).
+$pp_print_graph_paths(_,[path([],[])],_,_,_) :- !.
+$pp_print_graph_paths(S,[path(TNodes,SNodes)|Paths],Conn,And,Or) :-
+    write(S,' '),
+    append(TNodes,SNodes,Nodes),
+    $pp_print_graph_nodes(S,Nodes,Conn,And),
+    nl(S),!,
+    $pp_print_graph_paths(S,Paths,Or,And,Or).
+
+%% $pp_print_graph_nodes(S,Nodes,Conn,And): the first node is preceded by
+%% Conn, every later one by And.
+$pp_print_graph_nodes(_,[],_,_).
+$pp_print_graph_nodes(S,[Node|Nodes],Conn,And) :-
+    format(S," ~s ~w",[Conn,Node]),!,
+    $pp_print_graph_nodes(S,Nodes,And,And).
+
+%% Same as $pp_print_graph_paths/5 for path(TNodes,SNodes,V): the value V
+%% (a number or a pair [V1,V2]) is printed in braces at the end of the line.
+$pp_print_graph_paths_aux(_,[],_,_,_).
+$pp_print_graph_paths_aux(_,[path([],[],_)],_,_,_) :- !.
+$pp_print_graph_paths_aux(S,[path(TNodes,SNodes,V)|Paths],Conn,And,Or) :-
+    write(S,' '),
+    append(TNodes,SNodes,Nodes),
+    $pp_print_graph_nodes_aux(S,Nodes,Conn,And),
+    write(S,' '),
+    ( V = [V1,V2] ->
+        format(S,"{~6g,~6g}",[V1,V2])
+    ; format(S,"{~6g}",[V])
+    ),
+    nl(S),!,
+    $pp_print_graph_paths_aux(S,Paths,Or,And,Or).
+
+%% Nodes are gnode(Label,Value) or snode(Label,Value); the value is
+%% printed in brackets after the label.
+$pp_print_graph_nodes_aux(_,[],_,_).
+$pp_print_graph_nodes_aux(S,[Node|Nodes],Conn,And) :-
+    ( Node = gnode(Label,Value) ; Node = snode(Label,Value) ),
+    ( Value = [Value1,Value2] ->
+        format(S," ~s ~w [~6g,~6g]",[Conn,Label,Value1,Value2])
+    ; format(S," ~s ~w [~6g]",[Conn,Label,Value])
+    ),!,
+    $pp_print_graph_nodes_aux(S,Nodes,And,And).
+
+
+%%----------------------------------------
+%% pretty tree printer
+
+%% print_tree(T): prints T on the current output with default options.
+print_tree(T) :-
+    current_output(S),print_tree(S,T,[]).
+
+%% print_tree(T,Opts): prints T on the current output.
+print_tree(T,Opts) :-
+    current_output(S),print_tree(S,T,Opts).
+
+%% print_tree(S,T,Opts): prints the tree T (a list [Label|Children]) on
+%% stream S.  Option indent(N), N >= 1, default 3, is turned into the
+%% format string "~Ns" that is applied to "|" once per depth level.
+print_tree(S,T,Opts) :-
+    $pp_require_list(T,$msg(2104),print_tree/3),
+    $pp_proc_opts(Opts,$pp_opts_print_tree,[Indent],[3],print_tree/3),
+    number_codes(Indent,Format0),
+    append("~",Format0,Format1),
+    append(Format1,"s",Format2),
+    $pp_print_tree_root(S,T,0,Format2).
+
+$pp_opts_print_tree(indent(N),1,N) :-
+    integer(N), N >= 1.
+
+%% Prints the label at depth K, then the children at depth K+1.
+$pp_print_tree_root(S,[L|Sibs],K,Format) :-
+    $pp_print_tree_node(S,L,K,Format),
+    K1 is K + 1, !,
+    $pp_print_tree_sibs(S,Sibs,K1,Format).
+
+$pp_print_tree_sibs(_,Xs,_,_), Xs = [] =>
+    true.
+%% A child that can be a non-empty list is printed as a subtree,
+%% anything else as a leaf.
+$pp_print_tree_sibs(S,Xs,K,Format), Xs = [X|Xs1] =>
+    ( X ?= [_|_] ->
+        $pp_print_tree_root(S,X,K,Format)
+    ; $pp_print_tree_node(S,X,K,Format)
+    ), !,
+    $pp_print_tree_sibs(S,Xs1,K,Format).
+
+%% $pp_print_tree_node(S,L,K,Format): emits "|" through Format K times,
+%% then the label and a newline.
+$pp_print_tree_node(S,L,K,_), K == 0 =>
+    write(S,L), nl(S).
+$pp_print_tree_node(S,L,K,Format), K > 0 =>
+    format(S,Format,["|"]),
+    K1 is K - 1, !,
+    $pp_print_tree_node(S,L,K1,Format).
+
+
+%%----------------------------------------
+%% e-graph manipulator
+
+%% strip_switches(G0,G1): G1 is G0 with every path(Gs,_) replaced by Gs,
+%% i.e. with the second path component dropped.  Only node/2 terms are
+%% handled.
+strip_switches(G0,G1) :-
+    $pp_require_list(G0,$msg(2104),strip_switches/2),
+    $pp_strip_switches(G0,G1).
+
+$pp_strip_switches([],[]).
+$pp_strip_switches([node(L,Ps0)|Ns0],[node(L,Ps1)|Ns1]) :-
+    $pp_strip_switches_sub(Ps0,Ps1),!,
+    $pp_strip_switches(Ns0,Ns1).
+
+$pp_strip_switches_sub([],[]).
+$pp_strip_switches_sub([path(Gs,_)|Ps0],[Gs|Ps1]) :- !,
+    $pp_strip_switches_sub(Ps0,Ps1).
+
+%%----------------------------------------
+%% debugging aid
+
+%% write_call(Goal): write_call/2 with no options.
+write_call(Goal) :-
+    write_call([],Goal).
+
+%% write_call(Opts,Goal): runs Goal and prints the Call/Exit/Redo/Fail
+%% events selected by Opts.
+write_call(Opts,Goal) :-
+    $pp_write_call_core(Opts,Goal,Goal).
+
+%% Shorthands for write_call/2 with a fixed event selection.
+??(Goal) :- write_call([],Goal).
+??*(Goal) :- write_call([all],Goal).
+??>(Goal) :- write_call([call],Goal).
+??<(Goal) :- write_call([exit+fail],Goal).
+??+(Goal) :- write_call([exit],Goal).
+??-(Goal) :- write_call([fail],Goal).
+
+%% Turns the write_call_events flag off.
+disable_write_call :-
+    set_prism_flag(write_call_events,off).
+
+%% $pp_write_call_core(Opts,Source,Goal): Source is the term shown in the
+%% messages, Goal the term actually called.  'Redo' is printed when Goal
+%% is re-entered on backtracking, 'Fail' when it has no (more) solutions;
+%% both branches then fail so that backtracking continues.
+$pp_write_call_core(Opts,Source,Goal) :-
+    $pp_require_write_callable(Goal,$msg(3200),write_call/2),
+    $pp_write_call_proc_opts(Opts,Call,Exit,Redo,Fail,Indent,Marker),
+    $pp_write_call_print(Call,'Call',Indent,Marker,Source),
+    ( Goal, ( $pp_write_call_print(Exit,'Exit',Indent,Marker,Source)
+            ; $pp_write_call_print(Redo,'Redo',Indent,Marker,Source), fail
+            )
+    ; $pp_write_call_print(Fail,'Fail',Indent,Marker,Source), fail
+    ).
+
+%% $pp_write_call_build(Opts,Source,Goal,Body): builds, without running
+%% it, the same goal sequence that $pp_write_call_core/3 executes (minus
+%% the callable check).
+$pp_write_call_build(Opts,Source,Goal,Body) :-
+    Body = ( $pp_write_call_proc_opts(Opts,Call,Exit,Redo,Fail,Indent,Marker),
+             $pp_write_call_print(Call,'Call',Indent,Marker,Source),
+             ( Goal,( $pp_write_call_print(Exit,'Exit',Indent,Marker,Source)
+                    ; $pp_write_call_print(Redo,'Redo',Indent,Marker,Source), fail
+                    )
+             ; $pp_write_call_print(Fail,'Fail',Indent,Marker,Source), fail
+             )
+           ),!.
+
+%% Decodes Opts into four 0/1 event switches plus Indent and Marker.
+%% The default event set is the value of the write_call_events flag; when
+%% that flag is off, every event is disabled regardless of Opts.
+$pp_write_call_proc_opts(Opts,Call,Exit,Redo,Fail,Indent,Marker) :-
+    get_prism_flag(write_call_events,FlagValue),
+    $pp_proc_opts(Opts,$pp_write_call_option,
+                  [Events,Indent,Marker],[FlagValue,0,_],
+                  write_call/2),
+    ( FlagValue == off ->
+        Call = 0, Exit = 0, Redo = 0, Fail = 0
+    ; $pp_write_call_decomp(Events,Call,Exit,Redo,Fail)
+    ), !.
+
+%% Option table: an event expression (slot 1, must not be none),
+%% indent(Integer) (slot 2), marker(Term) (slot 3).
+$pp_write_call_option(X,1,Y) :-
+    $pp_write_call_events(X,Y), !, Y \== none.
+$pp_write_call_option(indent(X),2,X) :- !, integer(X).
+$pp_write_call_option(marker(X),3,X) :- !.
+
+%% $pp_write_call_events(X,Y): normalises an event expression such as
+%% fail+call into the canonical form listed in $pp_write_call_decomp/5.
+$pp_write_call_events(all,all) :- !.
+$pp_write_call_events(none,none) :- !.
+$pp_write_call_events(X,Y) :-
+    $pp_expr_to_list('+',X,Xs),
+    $pp_write_call_events(Xs,Y,0,0,0,0),!.
+
+%% Each event may occur at most once; an unknown or repeated event makes
+%% the conversion fail.
+$pp_write_call_events(Xs0,Y,C,E,R,F), Xs0 == [] =>
+    $pp_write_call_decomp(Y,C,E,R,F), Y \== none.
+$pp_write_call_events(Xs0,Y,C,E,R,F), Xs0 = [X|Xs1] =>
+    ( X == call, C == 0 -> $pp_write_call_events(Xs1,Y,1,E,R,F)
+    ; X == exit, E == 0 -> $pp_write_call_events(Xs1,Y,C,1,R,F)
+    ; X == redo, R == 0 -> $pp_write_call_events(Xs1,Y,C,E,1,F)
+    ; X == fail, F == 0 -> $pp_write_call_events(Xs1,Y,C,E,R,1)
+    ).
+
+%% $pp_write_call_decomp(Events,Call,Exit,Redo,Fail): canonical event
+%% expression <-> four 0/1 switches.
+$pp_write_call_decomp(none,0,0,0,0).
+$pp_write_call_decomp(call,1,0,0,0).
+$pp_write_call_decomp(exit,0,1,0,0).
+$pp_write_call_decomp(call+exit,1,1,0,0).
+$pp_write_call_decomp(redo,0,0,1,0).
+$pp_write_call_decomp(call+redo,1,0,1,0).
+$pp_write_call_decomp(exit+redo,0,1,1,0).
+$pp_write_call_decomp(call+exit+redo,1,1,1,0).
+$pp_write_call_decomp(fail,0,0,0,1).
+$pp_write_call_decomp(call+fail,1,0,0,1).
+$pp_write_call_decomp(exit+fail,0,1,0,1).
+$pp_write_call_decomp(call+exit+fail,1,1,0,1).
+$pp_write_call_decomp(redo+fail,0,0,1,1).
+$pp_write_call_decomp(call+redo+fail,1,0,1,1).
+$pp_write_call_decomp(exit+redo+fail,0,1,1,1).
+$pp_write_call_decomp(all,1,1,1,1).
+
+%% $pp_write_call_print(Switch,Head,Indent,Marker,Goal): prints the event
+%% line when Switch is 1 (with the marker when one is bound), and does
+%% nothing when Switch is 0.
+$pp_write_call_print(1,Head,Indent,Marker,Goal), var(Marker) =>
+    tab(Indent), format("[~w] ~q~n",[Head,Goal]).
+$pp_write_call_print(1,Head,Indent,Marker,Goal), nonvar(Marker) =>
+    tab(Indent), format("[~w:~w] ~q~n",[Head,Marker,Goal]).
+$pp_write_call_print(0,_,_,_,_).
+
+%%----------------------------------------
+
+%% $pp_learn_message(S,E,T,M): decodes the learn_message flag into four
+%% 0/1 switches (search, em, stats, misc).
+$pp_learn_message(S,E,T,M) :-
+    get_prism_flag(learn_message,LM),
+    $pp_learn_message_decomp(LM,S,E,T,M),!.
+
+%%----------------------------------------
+
+%% $pp_learn_message_events(X,Y): normalises an expression such as
+%% em+search into the canonical form listed in $pp_learn_message_decomp/5.
+$pp_learn_message_events(all,all) :- !.
+$pp_learn_message_events(none,none) :- !.
+$pp_learn_message_events(X,Y) :-
+    $pp_expr_to_list('+',X,Xs),
+    $pp_learn_message_events(Xs,Y,0,0,0,0).
+
+%% Each keyword may occur at most once; an unknown or repeated keyword
+%% makes the conversion fail.
+$pp_learn_message_events(Xs0,Y,S,E,T,M), Xs0 == [] =>
+    $pp_learn_message_decomp(Y,S,E,T,M), Y \== none.
+$pp_learn_message_events(Xs0,Y,S,E,T,M), Xs0 = [X|Xs1] =>
+    ( X == search, S = 0 -> $pp_learn_message_events(Xs1,Y,1,E,T,M)
+    ; X == em, E = 0 -> $pp_learn_message_events(Xs1,Y,S,1,T,M)
+    ; X == stats, T = 0 -> $pp_learn_message_events(Xs1,Y,S,E,1,M)
+    ; X == misc, M = 0 -> $pp_learn_message_events(Xs1,Y,S,E,T,1)
+    ).
+
+%% $pp_learn_message_decomp(Events,Search,Em,Stats,Misc): canonical
+%% expression <-> four 0/1 switches.
+$pp_learn_message_decomp(none, 0,0,0,0).
+$pp_learn_message_decomp(search, 1,0,0,0).
+$pp_learn_message_decomp(em, 0,1,0,0).
+$pp_learn_message_decomp(search+em, 1,1,0,0).
+$pp_learn_message_decomp(stats, 0,0,1,0).
+$pp_learn_message_decomp(search+stats, 1,0,1,0).
+$pp_learn_message_decomp(em+stats, 0,1,1,0).
+$pp_learn_message_decomp(search+em+stats, 1,1,1,0).
+$pp_learn_message_decomp(misc, 0,0,0,1).
+$pp_learn_message_decomp(search+misc, 1,0,0,1).
+$pp_learn_message_decomp(em+misc, 0,1,0,1).
+$pp_learn_message_decomp(search+em+misc, 1,1,0,1).
+$pp_learn_message_decomp(stats+misc, 0,0,1,1).
+$pp_learn_message_decomp(search+stats+misc,1,0,1,1).
+$pp_learn_message_decomp(em+stats+misc,    0,1,1,1).
+$pp_learn_message_decomp(all,              1,1,1,1).
+
+%%----------------------------------------
+%% for parallel mode
+
+%% Raises an internal error unless $pc_mp_mode holds.
+$pp_require_mp_mode :-
+    ( $pc_mp_mode -> true
+    ; $pp_raise_internal_error($msg(1005),invalid_module,$damon_load/0)
+    ).
+
+%%----------------------------------------
+%% expand the outcome space
+
+% ?- expand_values([3,2-5@2,1-3,t],X).
+% X = [3,2,4,1,2,3,t]
+
+%% expand_values(Ns,ExpandedNs): replaces every Start-End@Step (Step > 0)
+%% and Start-End between integers by the corresponding integer sequence;
+%% any other element is kept as is.  Ns must be a ground list.
+%% Errors are now reported against expand_values/2 (the first context
+%% term used to be misspelled expland_values/2).
+expand_values(Ns,ExpandedNs) :-
+    $pp_require_list_or_nil(Ns,$msg(2109),expand_values/2),
+    $pp_require_ground(Ns,$msg(1105),expand_values/2),
+    $pp_expand_values1(Ns,ExpandedNs).
+
+% just fails for erroneous inputs
+expand_values1(Ns,ExpandedNs) :-
+    is_list(Ns),
+    ground(Ns),
+    $pp_expand_values1(Ns,ExpandedNs).
+
+%% ENs0 is the expansion of one range; it is spliced in front of the
+%% expansion ENs1 of the remaining elements.
+$pp_expand_values1([],[]).
+$pp_expand_values1([N|Ns],ENs) :-
+    ( N = Start-End@Step,
+      integer(Start),integer(End),integer(Step),Step>0 ->
+        $pp_require_integer_range_incl(Start,End,$msg(2008),expand_values/2),
+        $pp_expand_values2(Start,End,Step,ENs0),
+        append(ENs0,ENs1,ENs)
+    ; N = Start-End,integer(Start),integer(End) ->
+        $pp_require_integer_range_incl(Start,End,$msg(2008),expand_values/2),
+        $pp_expand_values2(Start,End,1,ENs0),
+        append(ENs0,ENs1,ENs)
+    ; ENs = [N|ENs1]
+    ),!,
+    $pp_expand_values1(Ns,ENs1).
+
+%% $pp_expand_values2(Start,End,Step,Ns): Ns = [Start,Start+Step,...] up
+%% to and including End.  The first clause has no cut, so callers commit
+%% to the first solution themselves (as $pp_expand_values1/2 does).
+$pp_expand_values2(Start,End,_,[]) :- Start > End.
+$pp_expand_values2(Start,End,Step,[Start|Ns]) :-
+    Start1 is Start + Step,!,
+    $pp_expand_values2(Start1,End,Step,Ns).
+
+
+%%----------------------------------------
+%% delete temporary file
+
+%% Deletes '__tmp.out' if it exists; always succeeds once.
+$pp_delete_tmp_out :-
+    Tmp = '__tmp.out',
+    ( file_exists(Tmp) -> delete_file(Tmp)
+    ; true
+    ),!.
+
+
+%%----------------------------------------
+%% log-gamma function
+
+%% lngamma(X,G): X must be a positive number; G is computed by
+%% $pc_lngamma/2.
+lngamma(X,G) :-
+    $pp_require_positive_number(X,$msg(3400),lngamma/2),
+    $pc_lngamma(X,G).
diff --git a/packages/prism/src/prolog/up/viterbi.pl b/packages/prism/src/prolog/up/viterbi.pl
new file mode 100644
index 000000000..bf42e2833
--- /dev/null
+++ b/packages/prism/src/prolog/up/viterbi.pl
@@ -0,0 +1,785 @@
+%%%% Viterbi wrappers: every public predicate dispatches through $pp_viterbi_wrapper/1
+
+viterbi(G) :-
+    $pp_viterbi_wrapper(viterbi(G)).
+viterbi(G,P) :-
+    $pp_viterbi_wrapper(viterbi(G,P)).
+viterbif(G) :-
+    $pp_viterbi_wrapper(viterbif(G)).
+viterbif(G,P,V) :-
+    $pp_viterbi_wrapper(viterbif(G,P,V)).
+viterbit(G) :-
+    $pp_viterbi_wrapper(viterbit(G)).
+viterbit(G,P,T) :-
+    $pp_viterbi_wrapper(viterbit(G,P,T)).
+n_viterbi(N,G) :-
+    $pp_viterbi_wrapper(n_viterbi(N,G)).
+n_viterbi(N,G,P) :-
+    $pp_viterbi_wrapper(n_viterbi(N,G,P)).
+n_viterbif(N,G) :-
+    $pp_viterbi_wrapper(n_viterbif(N,G)).
+n_viterbif(N,G,V) :-
+    $pp_viterbi_wrapper(n_viterbif(N,G,V)).
+n_viterbit(N,G) :-
+    $pp_viterbi_wrapper(n_viterbit(N,G)).
+n_viterbit(N,G,T) :-
+    $pp_viterbi_wrapper(n_viterbit(N,G,T)).
+viterbig(G) :-
+    $pp_viterbi_wrapper(viterbig(G)).
+viterbig(G,P) :-
+    $pp_viterbi_wrapper(viterbig(G,P)).
+viterbig(G,P,V) :-
+    $pp_viterbi_wrapper(viterbig(G,P,V)).
+n_viterbig(N,G) :-
+    $pp_viterbi_wrapper(n_viterbig(N,G)).
+n_viterbig(N,G,P) :-
+    $pp_viterbi_wrapper(n_viterbig(N,G,P)).
+n_viterbig(N,G,P,V) :-
+    $pp_viterbi_wrapper(n_viterbig(N,G,P,V)).
+
+$pp_viterbi_wrapper(Pred0) :-
+    get_prism_flag(viterbi_mode,Mode),   % selects the implementation family
+    ( Mode == params -> Suffix = '_p' ; Mode == hparams -> Suffix = '_h' ),!,
+    Pred0 =.. [Name0|Args],
+    atom_concat(Name0,Suffix,Name1),     % e.g. viterbif -> viterbif_p
+    Pred1 =.. [Name1|Args],!,            % same arguments, suffixed name
+    call(Pred1).
% do not add cut here (n_viterbig is non-deterministic)
+
+%%%% Viterbi routine with C interface
+%%
+%% viterbi_p(G) :- print the Viterbi prob
+%% viterbi_p(G,P) :- output the Viterbi prob
+%% viterbif_p(G) :- print the Viterbi path and the Viterbi prob
+%% viterbif_p(G,P,VPath) :- output the Viterbi path and the Viterbi prob
+%%
+%% VPath is a list of node(G,Paths), where Paths is a list of
+%% path(Gs,Sws), where Gs are subgoals of G and Sws are switches.
+%%
+%% Usually in VPath, node(msw(Sw,V),[]) is omitted, but optionally
+%% it can be included in VPath.
+
+% Main routine:
+
+% viterbi family:
+
+%% Both arities run viterbif_p/3 and discard the path.
+viterbi_p(Goal) :-
+    viterbif_p(Goal,Pmax,_),
+    $pp_print_viterbi_prob(Pmax).
+
+viterbi_p(Goal,Pmax) :-
+    viterbif_p(Goal,Pmax,_).
+
+% viterbif family:
+
+%% Prints the Viterbi path with '<=' as the head/body connective, then
+%% the probability.
+viterbif_p(Goal) :-
+    viterbif_p(Goal,Pmax,VNodeL),
+    format("~n",[]),
+    print_graph(VNodeL,[lr('<=')]),
+    $pp_print_viterbi_prob(Pmax).
+
+%% Validates Goal (for msw(I,_), I must be ground and have registered
+%% outcomes), then delegates to $pp_viterbif_p/3.
+viterbif_p(Goal,Pmax,VNodeL) :-
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),viterbif_p/3),
+    ( Goal = msw(I,_) ->
+        $pp_require_ground(I,$msg(0101),viterbif_p/3),
+        $pp_require_switch_outcomes(I,$msg(0102),viterbif_p/3)
+    ; true
+    ),
+    $pp_viterbif_p(Goal,Pmax,VNodeL).
+
+%% Runs the core computation and passes the elapsed CPU time to
+%% $pp_assert_viterbi_stats1/1.
+$pp_viterbif_p(Goal,Pmax,VNodeL) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_viterbi_core(Goal,Pmax,VNodeL),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!.
+
+% viterbit family:
+
+%% Prints the Viterbi path as a tree, then the probability.
+viterbit_p(Goal) :-
+    viterbit_p(Goal,Pmax,VTreeL),
+    format("~n",[]),
+    print_tree(VTreeL),
+    $pp_print_viterbi_prob(Pmax).
+
+%% Computes the Viterbi path and converts it with viterbi_tree/2.
+viterbit_p(Goal,Pmax,VTreeL) :-
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),viterbit_p/3),
+    $pp_viterbif_p(Goal,Pmax,VNodeL),
+    viterbi_tree(VNodeL,VTreeL).
+
+% viterbig family:
+
+%% For a ground goal these behave like viterbi_p; otherwise they call
+%% viterbig_p/3, which instantiates Goal.
+viterbig_p(Goal) :-
+    ( ground(Goal) -> viterbi_p(Goal)
+    ; viterbig_p(Goal,_,_)
+    ).
+
+viterbig_p(Goal,Pmax) :-
+    ( ground(Goal) -> viterbi_p(Goal,Pmax)
+    ; viterbig_p(Goal,Pmax,_)
+    ).
+
+%% viterbig_p(Goal,Pmax,VNodeL): like viterbif_p/3, but a non-ground Goal
+%% is additionally unified with the goal (or switch instance) found at
+%% the top of the Viterbi path.  All validation errors are reported
+%% against viterbig_p/3 (the groundness check used to name viterbif_p/3).
+viterbig_p(Goal,Pmax,VNodeL) :-
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),viterbig_p/3),
+    ( Goal = msw(I,_) ->
+        $pp_require_ground(I,$msg(0101),viterbig_p/3),
+        $pp_require_switch_outcomes(I,$msg(0102),viterbig_p/3)
+    ; true
+    ),
+    $pp_viterbig_p(Goal,Pmax,VNodeL).
+
+%% The top node of VNodeL has a single path: either one subgoal (general
+%% goal) or one switch instance (msw goal); that term instantiates Goal.
+$pp_viterbig_p(Goal,Pmax,VNodeL) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_viterbi_core(Goal,Pmax,VNodeL),
+    ( ground(Goal) -> true
+    ; VNodeL = [node(_,[path([Goal1],[])])|_] -> Goal = Goal1
+    ; VNodeL = [node(_,[path([],[SwIns])])|_] -> Goal = SwIns
+    ),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!.
+
+%% Common routine:
+
+%% Prints Pmax, labelled as a log value unless the log_scale flag is off.
+$pp_print_viterbi_prob(Pmax) :-
+    ( get_prism_flag(log_scale,off) -> format("~nViterbi_P = ~15f~n",[Pmax])
+    ; format("~nlog(Viterbi_P) = ~15f~n",[Pmax])
+    ).
+
+%% $pp_viterbi_core(Goal,Pmax,VNodeL), msw case: a temporary tabled dummy
+%% goal whose only explanation is the switch instance is consulted, the
+%% explanation search and the Viterbi computation are run on it, and the
+%% dummy goal at the top of the result is replaced by msw(I,V).
+%% A non-ground V is copied so that the search does not bind the caller's
+%% variables.
+$pp_viterbi_core(Goal,Pmax,VNodeL) :-
+    Goal = msw(I,V),!,
+    $pp_require_ground(I,$msg(0101),$pp_viterbi_core/3),
+    $pp_require_switch_outcomes(I,$msg(0102),$pp_viterbi_core/3),
+    ( ground(V) -> V = VCp ; copy_term(V,VCp) ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = ($prism_expl_msw(I,VCp,Sid),
+                 $pc_prism_goal_id_register(DummyGoal,Hid),
+                 $prism_eg_path(Hid,[],[Sid])),
+    Prog = [pred(DummyGoal,0,_,_,tabled(_,_,_,_),[(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_viterbi_p(DummyGoal,Pmax,[node(DummyGoal,Paths)|VNodeL0]),!,
+    cputime(T3),
+    VNodeL = [node(msw(I,V),Paths)|VNodeL0],
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% Ground, non-msw goal: explanation search and Viterbi computation run
+%% directly on Goal.  T2-T1 is the search time, T3-T2 the numerical
+%% computation time.
+$pp_viterbi_core(Goal,Pmax,VNodeL) :-
+    ground(Goal),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(Goal),
+    cputime(T2),
+    $pp_compute_viterbi_p(Goal,Pmax,VNodeL),!,
+    cputime(T3),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),!.
+
+%% Non-ground goal: a copy of Goal is wrapped in a temporary tabled dummy
+%% goal whose explanation paths each consist of one instance of the copy.
+%% The body is the translated goal when $pp_trans_one_goal/2 succeeds,
+%% otherwise a call to $pp_expl_interp_goal/11.  The dummy goal at the top
+%% of the result is replaced by the original Goal.
+$pp_viterbi_core(Goal,Pmax,VNodeL) :-
+    copy_term(Goal,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = (BodyGoal,
+                 $pc_prism_goal_id_register(GoalCp,GId),
+                 $pc_prism_goal_id_register(DummyGoal,HId),
+                 $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),
+                 [(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_viterbi_p(DummyGoal,Pmax,[node(DummyGoal,Paths)|VNodeL0]),!,
+    cputime(T3),
+    VNodeL = [node(Goal,Paths)|VNodeL0],
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% $pp_compute_viterbi_p(Goal,Pmax,VNodeL): exports the switch information
+%% to the C side, runs $pc_compute_viterbi/5 on the goal id and decodes
+%% the returned id lists into node/2 terms.
+% Sws = [sw(Id,Instances,Probs,PseudoCs,Fixed,FixedH),...]
+$pp_compute_viterbi_p(Goal,Pmax,VNodeL) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    garbage_collect,
+    $pc_compute_viterbi(Gid,EGs,EGPaths,ESwPaths,Pmax),
+    $pp_decode_viterbi_path(EGs,EGPaths,ESwPaths,VNodeL),!.
+
+$pp_decode_viterbi_path([],[],[],[]) :- !.
+%% One node per goal id.  A goal with neither subgoals nor switches gets
+%% an empty path list, or a single empty path when the
+%% explicit_empty_expls flag is not off.
+$pp_decode_viterbi_path([Gid|Gids],[GPath|GPaths],[SPath|SPaths],[Node|Nodes]) :-
+    $pc_prism_goal_term(Gid,G),
+    ( GPath == [], SPath == [] ->
+        get_prism_flag(explicit_empty_expls,V),
+        ( V == off -> Node = node(G,[])
+        ; Node = node(G,[path([],[])])
+        )
+    ; $pp_decode_gnodes(GPath,GPathDec,1,0,_Vg),
+      $pp_decode_snodes(SPath,SPathDec,1,0,_Vs),
+      Node = node(G,[path(GPathDec,SPathDec)])
+    ),!,
+    $pp_decode_viterbi_path(Gids,GPaths,SPaths,Nodes).
+
+
+%%%%
+%%%% Top-N Viterbi
+%%%%
+%%%% n_viterbi_p(N,G) :- print the top-N Viterbi probs
+%%%% n_viterbi_p(N,G,Ps) :- output the top-N Viterbi probs
+%%%% n_viterbif_p(N,G) :- print the top-N Viterbi paths and the corresponding
+%%%% Viterbi probs
+%%%% n_viterbif_p(N,G,VPathL) :- output the list of top-N Viterbi paths and
+%%%% the corresponding Viterbi probs
+%%%%
+
+% n_viterbi family
+
+%% Failure-driven loop over the v_expl(J,Pmax,_) entries.
+n_viterbi_p(N,Goal) :-
+    n_viterbif_p(N,Goal,VPathL),
+    ( member(v_expl(J,Pmax,_),VPathL),
+      $pp_print_n_viterbi(J,Pmax),
+      fail
+    ; true
+    ).
+
+%% Ps is the list of probabilities, in the order of VPathL.
+n_viterbi_p(N,Goal,Ps) :-
+    n_viterbif_p(N,Goal,VPathL),!,
+    findall(Pmax,member(v_expl(_,Pmax,_),VPathL),Ps).
+
+% n_viterbif family
+
+n_viterbif_p(N,Goal) :-
+    n_viterbif_p(N,Goal,VPathL),!,
+    $pp_print_n_viterbif(VPathL).
+
+%% N must be a positive integer and Goal a tabled probabilistic atom.
+n_viterbif_p(N,Goal,VPathL) :-
+    $pp_require_positive_integer(N,$msg(1400),n_viterbif_p/3),
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),n_viterbif_p/3),
+    $pp_n_viterbif_p(N,Goal,VPathL).
+
+%% Runs the core computation and passes the elapsed CPU time to
+%% $pp_assert_viterbi_stats1/1.
+$pp_n_viterbif_p(N,Goal,VPathL) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_n_viterbi_p_core(N,Goal,VPathL),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!.
+
+% n_viterbit family
+
+n_viterbit_p(N,Goal) :-
+    n_viterbif_p(N,Goal,VPathL),!,
+    $pp_print_n_viterbit(VPathL).
+
+%% VPathL is a list of v_tree(J,Pmax,VTreeL) built from the v_expl/3
+%% entries.
+n_viterbit_p(N,Goal,VPathL) :-
+    n_viterbif_p(N,Goal,VPathL0),!,
+    $pp_build_n_viterbit(VPathL0,VPathL).
+
+%%%%
+%%%% n_viterbig_p(N,Goal) :- the same as n_viterbig_p(N,Goal,_,_)
+%%%% n_viterbig_p(N,Goal,Pmax) :- the same as n_viterbig_p(N,Goal,Pmax,_)
+%%%% n_viterbig_p(N,Goal,Pmax,VNodeL) :-
+%%%%   if Goal is not ground, unify Goal with the first element in the K-th
+%%%%   Viterbi path VNodeL (K=0,1,2,...,(N-1) on backtracking). Pmax is the
+%%%%   probability of VNodeL.
+%%%%
+
+n_viterbig_p(N,Goal) :-
+    ( ground(Goal) -> n_viterbi_p(N,Goal)
+    ; n_viterbig_p(N,Goal,_,_)
+    ).
+
+n_viterbig_p(N,Goal,Pmax) :-
+    ( ground(Goal) -> n_viterbi_p(N,Goal,Ps),!,member(Pmax,Ps)
+    ; n_viterbig_p(N,Goal,Pmax,_)
+    ).
+
+%% Argument errors are reported against n_viterbig_p/4, the predicate
+%% actually called (they used to name n_viterbi_p/3).
+n_viterbig_p(N,Goal,Pmax,VNodeL) :-
+    $pp_require_positive_integer(N,$msg(1400),n_viterbig_p/4),
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),n_viterbig_p/4),
+    $pp_n_viterbig_p(N,Goal,Pmax,VNodeL).
+
+%% The explanations are computed once; member/2 then enumerates them on
+%% backtracking, so no cut may follow it.  For a non-ground goal the top
+%% node of each explanation holds a single path with either one switch
+%% instance (msw goal) or one subgoal; that term instantiates Goal.
+$pp_n_viterbig_p(N,Goal,Pmax,VNodeL) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_n_viterbi_p_core(N,Goal,VPathL),!,
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!,
+    ( ground(Goal) -> member(v_expl(_J,Pmax,VNodeL),VPathL)
+    ; Goal = msw(_,_) ->
+        member(v_expl(_J,Pmax,VNodeL),VPathL),
+        VNodeL = [node(_,[path([],[SwIns])])|_],
+        Goal = SwIns
+    ; % else
+        member(v_expl(_J,Pmax,VNodeL),VPathL),
+        VNodeL = [node(_,[path([Goal1],[])])|_],
+        Goal = Goal1
+    ).
+
+%% Common routines:
+
+%% Prints the J-th probability, labelled as a log value unless the
+%% log_scale flag is off.
+$pp_print_n_viterbi(J,Pmax) :-
+    ( get_prism_flag(log_scale,off) ->
+          format("#~w: Viterbi_P = ~15f~n",[J,Pmax])
+    ; format("#~w: log(Viterbi_P) = ~15f~n",[J,Pmax])
+    ).
+
+%% Prints every explanation as a graph followed by its probability.
+$pp_print_n_viterbif([]).
+$pp_print_n_viterbif([v_expl(J,Pmax,VNodeL)|VPathL]) :-
+    format("~n#~w~n",[J]),
+    print_graph(VNodeL,[lr('<=')]),
+    ( get_prism_flag(log_scale,off) -> format("~nViterbi_P = ~15f~n",[Pmax])
+    ; format("~nlog(Viterbi_P) = ~15f~n",[Pmax])
+    ),!,
+    $pp_print_n_viterbif(VPathL).
+
+$pp_print_n_viterbit([]).
+%% Prints every explanation as a tree followed by its probability.
+$pp_print_n_viterbit([v_expl(J,Pmax,VNodeL)|VPathL]) :-
+    format("~n#~w~n",[J]),
+    viterbi_tree(VNodeL,VTreeL),
+    print_tree(VTreeL),
+    $pp_print_viterbi_prob(Pmax),!,
+    $pp_print_n_viterbit(VPathL).
+
+%% Converts v_expl(J,Pmax,VNodeL) entries into v_tree(J,Pmax,VTreeL).
+$pp_build_n_viterbit([],[]).
+$pp_build_n_viterbit([v_expl(J,Pmax,VNodeL)|VPathL0],
+                     [v_tree(J,Pmax,VTreeL)|VPathL1]) :-
+    viterbi_tree(VNodeL,VTreeL),!,
+    $pp_build_n_viterbit(VPathL0,VPathL1).
+
+%% $pp_n_viterbi_p_core(N,Goal,VPathL), msw case: a temporary tabled dummy
+%% goal whose only explanation is the switch instance is consulted and
+%% the top-N computation is run on it; the dummy goal is then replaced by
+%% Goal in every explanation.  Internal argument errors are reported
+%% against $pp_n_viterbi_p_core/3 (they used to name $pp_viterbi_core/3).
+$pp_n_viterbi_p_core(N,Goal,VPathL) :-
+    Goal = msw(I,V),!,
+    $pp_require_ground(I,$msg(0101),$pp_n_viterbi_p_core/3),
+    $pp_require_switch_outcomes(I,$msg(0102),$pp_n_viterbi_p_core/3),
+    ( ground(V) -> V = VCp ; copy_term(V,VCp) ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = ($prism_expl_msw(I,VCp,Sid),
+                 $pc_prism_goal_id_register(DummyGoal,Hid),
+                 $prism_eg_path(Hid,[],[Sid])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),
+                 [(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_n_viterbi_p(N,DummyGoal,VPathL0),!,
+    cputime(T3),
+    $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% Ground, non-msw goal: the search and the top-N computation run
+%% directly on Goal.
+$pp_n_viterbi_p_core(N,Goal,VPathL) :-
+    ground(Goal),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(Goal),
+    cputime(T2),
+    $pp_compute_n_viterbi_p(N,Goal,VPathL),!,
+    cputime(T3),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),!.
+
+%% Case 3 of $pp_n_viterbi_p_core/3: non-ground Goal (reached only when the
+%% earlier clauses did not commit).  A copy GoalCp of Goal is wrapped in a
+%% temporary tabled predicate DummyGoal whose body runs the translated goal
+%% (or, if $pp_trans_one_goal/2 fails, the explanation interpreter) and
+%% registers a path with the single subgoal GoalCp and no switch.  After the
+%% computation the dummy head is replaced by the original Goal.
+$pp_n_viterbi_p_core(N,Goal,VPathL) :-
+    copy_term(Goal,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = (BodyGoal,
+                 $pc_prism_goal_id_register(GoalCp,GId),
+                 $pc_prism_goal_id_register(DummyGoal,HId),
+                 $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),
+                 [(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_n_viterbi_p(N,DummyGoal,VPathL0),!,
+    cputime(T3),
+    $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    % T2-T1 brackets the explanation search, T3-T2 the Viterbi computation
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+
+%% $pp_compute_n_viterbi_p(N,Goal,VPathL) :-
+%%     collects and exports the switch information, looks up the goal id of
+%%     Goal, calls $pc_compute_n_viterbi/3 and converts its v_expl/5 results
+%%     into v_expl/3 terms with $pp_build_n_viterbi_path/2.
+%%     garbage_collect/0 is called right before the $pc_ computation.
+$pp_compute_n_viterbi_p(N,Goal,VPathL) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    garbage_collect,
+    $pc_compute_n_viterbi(N,Gid,VPathL0),
+    $pp_build_n_viterbi_path(VPathL0,VPathL),!.
+
+%% $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL) :-
+%%     in every v_expl(J,Pmax,VNodeL0) of VPathL0, the first node
+%%     node(DummyGoal,Paths) becomes node(Goal,Paths); J, Pmax and the
+%%     remaining nodes are kept as they are.
+$pp_replace_dummy_goal(_,_,[],[]).
+$pp_replace_dummy_goal(Goal,DummyGoal,
+                       [v_expl(J,Pmax,VNodeL0)|VPathL0],
+                       [v_expl(J,Pmax,VNodeL)|VPathL]) :-
+    VNodeL0 = [node(DummyGoal,Paths)|VNodeL1],
+    VNodeL  = [node(Goal,Paths)|VNodeL1],!,
+    $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL).
+
+%% $pp_build_n_viterbi_path(VPathL0,VPathL) :-
+%%     maps each v_expl(J,EGs,EGPaths,ESwPaths,Pmax) to v_expl(J,Pmax,VNodeL),
+%%     where VNodeL is produced by $pp_decode_viterbi_path/4.
+$pp_build_n_viterbi_path([],[]).
+$pp_build_n_viterbi_path([v_expl(J,EGs,EGPaths,ESwPaths,Pmax)|VPathL0],
+                         [v_expl(J,Pmax,VNodeL)|VPathL]) :-
+    $pp_decode_viterbi_path(EGs,EGPaths,ESwPaths,VNodeL),
+    $pp_build_n_viterbi_path(VPathL0,VPathL).
+
+%% Viterbi with reranking based on VB
+%%
+%% A first argument of the form [N,M] asks for the top-N Viterbi
+%% explanations selected from the top-M Viterbi explanations based on
+%% ML/MAP.  If M is `default', the value of the `rerank' flag is used for M.
+%% If N > M, N is replaced by M.
+%%
+%% viterbi_h(G)          :- the same as n_viterbi_h([1,default],G)
+%% viterbi_h(G,P)        :- the same as n_viterbi_h([1,default],G,[P])
+%% viterbif_h(G)         :- the same as n_viterbif_h([1,default],G)
+%% viterbif_h(G,P,VPath) :- the same as
+%%                            n_viterbif_h([1,default],G,[v_expl(0,P,VPath)])
+%% viterbit_h(G)         :- the same as n_viterbit_h([1,default],G)
+%% viterbit_h(G,P,VTree) :- the same as viterbif_h(G,P,VPath) followed by
+%%                            viterbi_tree(VPath,VTree)
+%%
+%% n_viterbi_h(N,G)      :- the same as n_viterbi_h([N,default],G)
+%% n_viterbi_h(N,G,Ps)   :- the same as n_viterbi_h([N,default],G,Ps)
+%% n_viterbi_h([N,M],G)  :- print top-N Viterbi probs selected from top-M
+%%                            Viterbi probs based on ML/MAP
+%% n_viterbi_h([N,M],G,Ps) :-
+%%                          output top-N Viterbi probs selected from top-M
+%%                            Viterbi probs based on ML/MAP
+%% n_viterbif_h(N,G)     :- the same as n_viterbif_h([N,default],G)
+%% n_viterbif_h(N,G,VPathL) :-
+%%                          the same as n_viterbif_h([N,default],G,VPathL)
+%% n_viterbif_h([N,M],G) :- print the top-N Viterbi paths and the
+%%                            corresponding Viterbi probs selected from the
+%%                            top-M Viterbi paths based on ML/MAP
+%% n_viterbif_h([N,M],G,VPathL) :-
+%%     output the list of the top-N Viterbi paths and the corresponding
+%%     Viterbi probs (as v_expl/3 terms) selected from the top-M Viterbi
+%%     paths based on ML/MAP
+%% n_viterbit_h(N,G)     :- the same as n_viterbit_h([N,default],G)
+%% n_viterbit_h(N,G,VPathL) :-
+%%                          the same as n_viterbit_h([N,default],G,VPathL)
+%% n_viterbit_h([N,M],G) :- the same as n_viterbif_h([N,M],G), but each
+%%                            path is printed as a tree
+%% n_viterbit_h([N,M],G,VPathL) :-
+%%     the same as n_viterbif_h([N,M],G,VPathL0), where VPathL is VPathL0
+%%     with each v_expl/3 converted into a v_tree/3 by viterbi_tree/2
+%%
+%% viterbig_h(Goal)      :- the same as n_viterbig_h(1,Goal)
+%% viterbig_h(Goal,Pmax) :- the same as n_viterbig_h(1,Goal,Pmax)
+%% viterbig_h(Goal,Pmax,VNodeL) :-
+%%                          the same as n_viterbig_h(1,Goal,Pmax,VNodeL)
+%%
+%% n_viterbig_h(N,Goal)  :- for ground Goal, the same as n_viterbi_h(N,Goal);
+%%                            otherwise n_viterbig_h(N,Goal,_,_)
+%% n_viterbig_h([N,M],Goal) :-
+%%     for ground Goal, the same as n_viterbi_h([N,M],Goal);
+%%     otherwise n_viterbig_h([N,M],Goal,_,_)
+%% n_viterbig_h(N,Goal,Pmax) :-
+%%     for ground Goal, Pmax is a member of the Ps of n_viterbi_h(N,Goal,Ps);
+%%     otherwise n_viterbig_h(N,Goal,Pmax,_)
+%% n_viterbig_h([N,M],Goal,Pmax) :-
+%%     for ground Goal, Pmax is a member of the Ps of
+%%     n_viterbi_h([N,M],Goal,Ps); otherwise n_viterbig_h([N,M],Goal,Pmax,_)
+%% n_viterbig_h(N,Goal,Pmax,VNodeL) :-
+%%     the same as n_viterbig_h([N,default],Goal,Pmax,VNodeL)
+%% n_viterbig_h([N,M],Goal,Pmax,VNodeL) :-
+%%     If Goal is not ground, unify Goal with the first element in the K-th
+%%     Viterbi path VNodeL (one path per solution on backtracking).  Pmax is
+%%     the probability of VNodeL.
+
+viterbi_h(G) :- n_viterbi_h([1,default],G).
+viterbi_h(G,P) :- n_viterbi_h([1,default],G,[P]).
+viterbif_h(G) :- n_viterbif_h([1,default],G).
+viterbif_h(G,P,VPath) :- n_viterbif_h([1,default],G,[v_expl(0,P,VPath)]).
+viterbit_h(G) :- n_viterbit_h([1,default],G).
+viterbit_h(G,P,VTree) :-
+    n_viterbif_h([1,default],G,[v_expl(0,P,VPath)]),!,
+    viterbi_tree(VPath,VTree).
+
+n_viterbi_h([N,M],G) :- !,
+    n_viterbif_h([N,M],G,VPathL),!,
+    % failure-driven loop: print one line per explanation
+    ( member(v_expl(J,Pmax,_),VPathL),
+      $pp_print_n_viterbi(J,Pmax),
+      fail
+    ; true
+    ).
+n_viterbi_h(N,G) :- n_viterbi_h([N,default],G).
+
+n_viterbi_h([N,M],G,Ps) :- !,
+    n_viterbif_h([N,M],G,VPathL),!,
+    findall(Pmax,member(v_expl(_,Pmax,_),VPathL),Ps).
+n_viterbi_h(N,G,Ps) :- n_viterbi_h([N,default],G,Ps).
+
+n_viterbif_h([N,M],G) :- !,
+    n_viterbif_h([N,M],G,VPathL),!,
+    $pp_print_n_viterbif(VPathL).
+n_viterbif_h(N,G) :-
+    n_viterbif_h([N,default],G).
+
+n_viterbif_h([N,M],Goal,VPathL) :- !,
+    ( M == default ->
+        % take M from the rerank flag and validate on the second round
+        get_prism_flag(rerank,M1),!,
+        n_viterbif_h([N,M1],Goal,VPathL)
+    ; % M \== default
+        $pp_require_positive_integer(N,$msg(1400),n_viterbif_h/3),
+        $pp_require_positive_integer(M,$msg(1401),n_viterbif_h/3),
+        $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),n_viterbif_h/3),
+        % never ask for more results than there are candidates
+        ( N > M -> N1 = M ; N1 = N ),!,
+        $pp_n_viterbif_h([N1,M],Goal,VPathL)
+    ).
+
+n_viterbif_h(N,G,VPathL) :-
+    n_viterbif_h([N,default],G,VPathL).
+
+%% $pp_n_viterbif_h([N,M],Goal,VPathL) :-
+%%     runs $pp_n_viterbi_h_core/4 and records the cputime difference
+%%     through $pp_assert_viterbi_stats1/1.
+$pp_n_viterbif_h([N,M],Goal,VPathL) :-
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_n_viterbi_h_core(N,M,Goal,VPathL),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!.
+
+n_viterbit_h([N,M],G) :- !,
+    n_viterbif_h([N,M],G,VPathL),!,
+    $pp_print_n_viterbit(VPathL).
+n_viterbit_h(N,G) :-
+    n_viterbit_h([N,default],G).
+
+n_viterbit_h([N,M],G,VPathL) :- !,
+    n_viterbif_h([N,M],G,VPathL0),!,
+    $pp_build_n_viterbit(VPathL0,VPathL).
+n_viterbit_h(N,G,VPathL) :-
+    n_viterbit_h([N,default],G,VPathL).
+
+%% viterbig_h/1-3: the 1-best versions of n_viterbig_h/2-4.
+viterbig_h(Goal) :- n_viterbig_h(1,Goal).
+viterbig_h(Goal,Pmax) :- n_viterbig_h(1,Goal,Pmax).
+viterbig_h(Goal,Pmax,VNodeL) :- n_viterbig_h(1,Goal,Pmax,VNodeL).
+
+%% n_viterbig_h(NM,Goal) :-
+%%     NM is either N or [N,M].  A ground Goal is handed over to
+%%     n_viterbi_h/2; otherwise n_viterbig_h/4 is used.
+n_viterbig_h([N,M],Goal) :- !,
+    ( ground(Goal) -> n_viterbi_h([N,M],Goal)
+    ; n_viterbig_h([N,M],Goal,_,_)
+    ).
+n_viterbig_h(N,Goal) :-
+    ( ground(Goal) -> n_viterbi_h(N,Goal)
+    ; n_viterbig_h(N,Goal,_,_)
+    ).
+
+%% n_viterbig_h(NM,Goal,Pmax) :-
+%%     for a ground Goal, Pmax is enumerated from the list returned by
+%%     n_viterbi_h/3; otherwise n_viterbig_h/4 is used.
+n_viterbig_h([N,M],Goal,Pmax) :- !,
+    ( ground(Goal) -> n_viterbi_h([N,M],Goal,Ps),!,member(Pmax,Ps)
+    ; n_viterbig_h([N,M],Goal,Pmax,_)
+    ).
+n_viterbig_h(N,Goal,Pmax) :-
+    ( ground(Goal) -> n_viterbi_h(N,Goal,Ps),!,member(Pmax,Ps)
+    ; n_viterbig_h(N,Goal,Pmax,_)
+    ).
+
+%% n_viterbig_h(NM,Goal,Pmax,VNodeL) :-
+%%     [N,default] takes M from the rerank flag; a plain N is treated as
+%%     [N,default].  N is replaced by M when N > M.  Argument errors are
+%%     reported against n_viterbig_h/4 itself, so that the message names the
+%%     predicate the user actually called.
+n_viterbig_h([N,default],Goal,Pmax,VNodeL) :- !,
+    get_prism_flag(rerank,M),!,
+    n_viterbig_h([N,M],Goal,Pmax,VNodeL).
+n_viterbig_h([N,M],Goal,Pmax,VNodeL) :- !,
+    $pp_require_positive_integer(N,$msg(1400),n_viterbig_h/4),
+    $pp_require_positive_integer(M,$msg(1401),n_viterbig_h/4),
+    $pp_require_tabled_probabilistic_atom(Goal,$msg(0006),n_viterbig_h/4),
+    ( N > M -> N1 = M ; N1 = N ),!,
+    $pp_n_viterbig_h([N1,M],Goal,Pmax,VNodeL).
+n_viterbig_h(N,Goal,Pmax,VNodeL) :-
+    n_viterbig_h([N,default],Goal,Pmax,VNodeL).
+
+%% $pp_n_viterbig_h([N,M],Goal,Pmax,VNodeL) :-
+%%     runs $pp_n_viterbi_h_core/4, records the cputime difference through
+%%     $pp_assert_viterbi_stats1/1, and then enumerates the v_expl/3 entries
+%%     of the result with member/2.
+$pp_n_viterbig_h([N,M],Goal,Pmax,VNodeL) :- !,
+    $pp_clean_infer_stats,
+    cputime(T0),
+    $pp_n_viterbi_h_core(N,M,Goal,VPathL),
+    cputime(T1),
+    InfTime is T1 - T0,
+    $pp_assert_viterbi_stats1(InfTime),!,
+    ( ground(Goal) -> member(v_expl(J,Pmax,VNodeL),VPathL)
+    ; Goal = msw(_,_) ->
+        % the top node must have a single path made of one switch
+        % instance and no subgoal; Goal is unified with that instance
+        member(v_expl(J,Pmax,VNodeL),VPathL),
+        VNodeL = [node(_,[path([],[SwIns])])|_],
+        Goal = SwIns
+    ; % else: the top node must have a single path made of one subgoal
+      % and no switch; Goal is unified with that subgoal
+        member(v_expl(J,Pmax,VNodeL),VPathL),
+        VNodeL = [node(_,[path([Goal1],[])])|_],
+        Goal = Goal1
+    ).
+
+%% Common routines:
+
+%% $pp_n_viterbi_h_core(N,M,Goal,VPathL) :-
+%%     core of the reranking Viterbi computation.  Three cases are
+%%     distinguished by the clauses: Goal = msw(I,V), ground Goal, and
+%%     non-ground Goal.  The structure mirrors $pp_n_viterbi_p_core/3, with
+%%     $pp_compute_n_viterbi_h/4 as the numerical step.
+%%
+%% Case 1: Goal = msw(I,V).  A temporary tabled predicate DummyGoal is
+%% consulted whose body explains the switch instance and registers a path
+%% with no subgoal and one switch; the search is run on DummyGoal and the
+%% dummy head is replaced by Goal in the result.
+%% NOTE(review): the error source passed below is $pp_viterbi_core/3, not
+%% this predicate -- looks like a leftover from the 1-best version; confirm
+%% before changing.
+$pp_n_viterbi_h_core(N,M,Goal,VPathL) :-
+    Goal = msw(I,V),!,
+    $pp_require_ground(I,$msg(0101),$pp_viterbi_core/3),
+    $pp_require_switch_outcomes(I,$msg(0102),$pp_viterbi_core/3),
+    % work on a copy of a non-ground outcome so V itself is not bound here
+    ( ground(V) -> V = VCp ; copy_term(V,VCp) ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = ($prism_expl_msw(I,VCp,Sid),
+                 $pc_prism_goal_id_register(DummyGoal,Hid),
+                 $prism_eg_path(Hid,[],[Sid])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),
+                 [(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_n_viterbi_h(N,M,DummyGoal,VPathL0),!,
+    cputime(T3),
+    $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    % T2-T1 brackets the explanation search, T3-T2 the Viterbi computation
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% Case 2: ground Goal.  No dummy predicate is needed; explanations are
+%% searched for Goal directly and VPathL is returned as computed.
+$pp_n_viterbi_h_core(N,M,Goal,VPathL) :-
+    ground(Goal),!,
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(Goal),
+    cputime(T2),
+    $pp_compute_n_viterbi_h(N,M,Goal,VPathL),!,
+    cputime(T3),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),!.
+
+%% Case 3 of $pp_n_viterbi_h_core/4: non-ground Goal (reached only when the
+%% earlier clauses did not commit).  A copy GoalCp of Goal is wrapped in a
+%% temporary tabled predicate DummyGoal whose body runs the translated goal
+%% (or, if $pp_trans_one_goal/2 fails, the explanation interpreter) and
+%% registers a path with the single subgoal GoalCp and no switch.  After the
+%% computation the dummy head is replaced by the original Goal.
+$pp_n_viterbi_h_core(N,M,Goal,VPathL) :-
+    copy_term(Goal,GoalCp),
+    ( $pp_trans_one_goal(GoalCp,CompGoal) -> BodyGoal = CompGoal
+    ; BodyGoal = (savecp(CP),Depth=0,
+                  $pp_expl_interp_goal(GoalCp,Depth,CP,[],_,[],_,[],_,[],_))
+    ),
+    $pp_create_dummy_goal(DummyGoal),
+    DummyBody = (BodyGoal,
+                 $pc_prism_goal_id_register(GoalCp,GId),
+                 $pc_prism_goal_id_register(DummyGoal,HId),
+                 $prism_eg_path(HId,[GId],[])),
+    Prog = [pred(DummyGoal,0,_Mode,_Delay,tabled(_,_,_,_),
+                 [(DummyGoal:-DummyBody)])],
+    consult_preds([],Prog),
+    $pp_init_tables_aux,
+    $pp_clean_graph_stats,
+    $pp_init_tables_if_necessary,!,
+    cputime(T1),
+    $pp_find_explanations(DummyGoal),
+    cputime(T2),
+    $pp_compute_n_viterbi_h(N,M,DummyGoal,VPathL0),!,
+    cputime(T3),
+    $pp_replace_dummy_goal(Goal,DummyGoal,VPathL0,VPathL),
+    $pc_import_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    $pp_assert_graph_stats(NSubgraphs,NGoalNodes,NSwNodes,AvgShared),
+    % T2-T1 brackets the explanation search, T3-T2 the Viterbi computation
+    SearchTime is T2 - T1,
+    NumCompTime is T3 - T2,
+    $pp_assert_viterbi_stats2(SearchTime,NumCompTime),
+    $pp_delete_tmp_out,!.
+
+%% $pp_compute_n_viterbi_h(N,M,Goal,VPathL) :-
+%%     collects and exports the switch information, looks up the goal id of
+%%     Goal, calls $pc_compute_n_viterbi_rerank/4 and converts its v_expl/5
+%%     results into v_expl/3 terms with $pp_build_n_viterbi_path/2.
+%%     garbage_collect/0 is called right before the $pc_ computation.
+$pp_compute_n_viterbi_h(N,M,Goal,VPathL) :-
+    $pp_collect_sw_info(Sws),
+    $pc_export_sw_info(Sws),
+    $pc_prism_goal_id_get(Goal,Gid),
+    garbage_collect,
+    $pc_compute_n_viterbi_rerank(N,M,Gid,VPathL0),
+    $pp_build_n_viterbi_path(VPathL0,VPathL),!.
+
+%% Statistics
+
+%% $pp_assert_viterbi_stats1(InfTime0) :-
+%%     asserts $ps_infer_time/1 with InfTime0 divided by 1000.0
+%%     (presumably cputime/1 milliseconds converted to seconds -- confirm).
+$pp_assert_viterbi_stats1(InfTime0) :-
+    InfTime is InfTime0 / 1000.0,
+    assertz($ps_infer_time(InfTime)),!.
+
+%% $pp_assert_viterbi_stats2(SearchTime0,NumCompTime0) :-
+%%     asserts $ps_infer_search_time/1 and $ps_infer_calc_time/1 with the
+%%     given values divided by 1000.0 (same unit assumption as above).
+$pp_assert_viterbi_stats2(SearchTime0,NumCompTime0) :-
+    SearchTime is SearchTime0 / 1000.0,
+    NumCompTime is NumCompTime0 / 1000.0,
+    assertz($ps_infer_search_time(SearchTime)),
+    assertz($ps_infer_calc_time(NumCompTime)),!.
+
+%%----------------------------------------
+%%  e-graph -> tree
+
+%% viterbi_tree(EG,Tree) :-
+%%     EG must be a list (checked by $pp_require_list/3) of node/2 terms.
+%%     Tree is the tree built for the first node of EG; a fresh hashtable
+%%     maps node names to their (possibly still unbound) subtrees while the
+%%     list is traversed.
+viterbi_tree(EG,Tree) :-
+    $pp_require_list(EG,$msg(2104),viterbi_tree/2),
+    new_hashtable(HT),
+    $pp_viterbi_tree(EG,Tree,HT).
+
+$pp_viterbi_tree([],[],_).
+%% Remaining clauses of $pp_viterbi_tree/3.  The tree of the first node is
+%% returned to the caller; the trees of the following nodes are only
+%% registered in the hashtable, where they get unified with the placeholders
+%% left by their parents.
+$pp_viterbi_tree([Nd|Rest],Tree,Tab), Nd = node(Label,[]) =>
+    % a node without any path is represented by its name alone
+    Tree = Label,
+    $pp_viterbi_tree_register(Label,Tree,Tab),!,
+    $pp_viterbi_tree(Rest,_,Tab).
+$pp_viterbi_tree([Nd|Rest],Tree,Tab), Nd = node(Label,[path(SubGs,Sws)]) =>
+    % a node with one path becomes [Name,Subgoal1,...,Switch1,...]
+    Tree = [Label|Kids],
+    $pp_viterbi_tree_goals(SubGs,Kids,KidsTail,Tab),
+    $pp_viterbi_tree_swits(Sws,KidsTail,[],Tab),
+    $pp_viterbi_tree_register(Label,Tree,Tab),!,
+    $pp_viterbi_tree(Rest,_,Tab).
+
+%% $pp_viterbi_tree_goals(Gs,L0,L1,HT) :-
+%%     L0-L1 is a difference list holding one placeholder per subgoal in Gs;
+%%     each placeholder is registered under the subgoal's name.
+$pp_viterbi_tree_goals([],Tail,Tail,_).
+$pp_viterbi_tree_goals([SubG|SubGs],[Slot|Slots],Tail,Tab) :-
+    $pp_viterbi_tree_register(SubG,Slot,Tab),!,   % Slot = free var.
+    $pp_viterbi_tree_goals(SubGs,Slots,Tail,Tab).
+
+%% $pp_viterbi_tree_swits(Ss,L0,L1,HT) :-
+%%     L0-L1 is a difference list holding the switch instances in Ss as they
+%%     are.
+$pp_viterbi_tree_swits([],Tail,Tail,_).
+$pp_viterbi_tree_swits([Sw|Sws],[Sw|Slots],Tail,Tab) :- !,
+    $pp_viterbi_tree_swits(Sws,Slots,Tail,Tab).
+
+%% $pp_viterbi_tree_register(Name,Node,HT) :-
+%%     if Name is already in the hashtable, its stored value must unify with
+%%     Node (otherwise $pp_raise_unmatched_branches/1 is called); if not,
+%%     Node is stored under Name.
+$pp_viterbi_tree_register(Key,Subtree,Tab) :-
+    ( hashtable_get(Tab,Key,Known) ->
+        ( Known = Subtree -> true
+        ; $pp_raise_unmatched_branches($pp_viterbi_tree_register/3)
+        )
+    ; hashtable_put(Tab,Key,Subtree)
+    ).
+
+%%----------------------------------------
+%%  e-graph -> list of subgoals, list of switches
+
+%% viterbi_subgoals(VNodes,Goals) :-
+%%     Goals is the concatenation, in node order, of the subgoal lists found
+%%     in the paths of VNodes.  VNodes must be a list.
+viterbi_subgoals(VNodes,Goals) :-
+    $pp_require_list(VNodes,$msg(2104),viterbi_subgoals/2),
+    $pp_viterbi_subgoals(VNodes,Goals).
+
+$pp_viterbi_subgoals([],[]).
+$pp_viterbi_subgoals([node(_,Paths)|Rest],Acc) :-
+    ( Paths = [] -> !,
+        $pp_viterbi_subgoals(Rest,Acc)
+    ; Paths = [path(SubGs,_)],
+        append(SubGs,Acc1,Acc),!,
+        $pp_viterbi_subgoals(Rest,Acc1)
+    ).
+
+%% viterbi_switches(VNodes,Goals) :-
+%%     Goals is the concatenation, in node order, of the switch lists found
+%%     in the paths of VNodes.  VNodes must be a list.
+viterbi_switches(VNodes,Goals) :-
+    $pp_require_list(VNodes,$msg(2104),viterbi_switches/2),
+    $pp_viterbi_switches(VNodes,Goals).
+
+$pp_viterbi_switches([],[]).
+$pp_viterbi_switches([node(_,Paths)|Rest],Acc) :-
+    ( Paths = [] -> !,
+        $pp_viterbi_switches(Rest,Acc)
+    ; Paths = [path(_,Sws)],
+        append(Sws,Acc1,Acc),!,
+        $pp_viterbi_switches(Rest,Acc1)
+    ).