ningshuxia
2022-11-17 112f6bf8aa6c76b055d19627ccef21fb59515436
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
// <copyright file="HybridMCGeneric.cs" company="Math.NET">
// Math.NET Numerics, part of the Math.NET Project
// http://numerics.mathdotnet.com
// http://github.com/mathnet/mathnet-numerics
//
// Copyright (c) 2009-2010 Math.NET
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
// files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use,
// copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
// </copyright>
 
namespace IStation.Numerics.Statistics.Mcmc
{
    using System;
    using Distributions;
 
    /// <summary>
    /// The Hybrid (also called Hamiltonian) Monte Carlo produces samples from distribution P using a set
    /// of Hamiltonian equations to guide the sampling process. It uses the negative of the log density as
    /// a potential energy, and a randomly generated momentum to set up a Hamiltonian system, which is then used
    /// to sample the distribution. This can result in a faster convergence than the random walk Metropolis sampler
    /// (<seealso cref="MetropolisSampler{T}"/>).
    /// </summary>
    /// <typeparam name="T">The type of samples this sampler produces.</typeparam>
    public abstract class HybridMCGeneric<T> : McmcSampler<T>
    {
        /// <summary>
        /// The delegate type that defines a derivative evaluated at a certain point.
        /// </summary>
        /// <param name="f">Function to be differentiated.</param>
        /// <param name="x">Value where the derivative is computed.</param>
        public delegate T DiffMethod(DensityLn<T> f, T x);
 
        /// <summary>
        /// Evaluates the energy function of the target distribution.
        /// </summary>
        readonly DensityLn<T> _energy;
 
        /// <summary>
        /// The current location of the sampler.
        /// </summary>
        protected T Current;
 
        /// <summary>
        /// The number of burn iterations between two samples.
        /// </summary>
        int _burnInterval;
 
        /// <summary>
        /// The size of each step in the Hamiltonian equation.
        /// </summary>
        double _stepSize;
 
        /// <summary>
        /// The number of iterations in the Hamiltonian equation.
        /// </summary>
        int _frogLeapSteps;
 
        /// <summary>
        /// The algorithm used for differentiation.
        /// </summary>
        readonly DiffMethod _diff;
 
        /// <summary>
        /// Gets or sets the number of iterations in between returning samples.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">When burn interval is negative.</exception>
        public int BurnInterval
        {
            get => _burnInterval;
            set => _burnInterval = SetNonNegative(value);
        }
 
        /// <summary>
        /// Gets or sets the number of iterations in the Hamiltonian equation.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">When frog leap steps is negative or zero.</exception>
        public int FrogLeapSteps
        {
            get => _frogLeapSteps;
            set => _frogLeapSteps = SetPositive(value);
        }
 
        /// <summary>
        /// Gets or sets the size of each step in the Hamiltonian equation.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">When step size is negative or zero.</exception>
        public double StepSize
        {
            get => _stepSize;
            set => _stepSize = SetPositive(value);
        }
 
        /// <summary>
        /// Constructs a new Hybrid Monte Carlo sampler.
        /// </summary>
        /// <param name="x0">The initial sample.</param>
        /// <param name="pdfLnP">The log density of the distribution we want to sample from.</param>
        /// <param name="frogLeapSteps">Number frog leap simulation steps.</param>
        /// <param name="stepSize">Size of the frog leap simulation steps.</param>
        /// <param name="burnInterval">The number of iterations in between returning samples.</param>
        /// <param name="randomSource">Random number generator used for sampling the momentum.</param>
        /// <param name="diff">The method used for differentiation.</param>
        /// <exception cref="ArgumentOutOfRangeException">When the number of burnInterval iteration is negative.</exception>
        /// <exception cref="ArgumentNullException">When either x0, pdfLnP or diff is null.</exception>
        protected HybridMCGeneric(T x0, DensityLn<T> pdfLnP, int frogLeapSteps, double stepSize, int burnInterval, Random randomSource, DiffMethod diff)
        {
            _energy = x => -pdfLnP(x);
            FrogLeapSteps = frogLeapSteps;
            StepSize = stepSize;
            BurnInterval = burnInterval;
            Current = x0;
            _diff = diff;
            RandomSource = randomSource;
        }
 
        /// <summary>
        /// Returns a sample from the distribution P.
        /// </summary>
        public override T Sample()
        {
            Burn(_burnInterval + 1);
 
            return Current;
        }
 
        /// <summary>
        /// This method runs the sampler for a number of iterations without returning a sample
        /// </summary>
        protected void Burn(int n)
        {
            T p = Create();
            double e = _energy(Current);
            T gradient = _diff(_energy, Current);
            for (int i = 0; i < n; i++)
            {
                RandomizeMomentum(ref p);
                double h = Hamiltonian(p, e);
 
 
                T mNew = Copy(Current);
                T gNew = Copy(gradient);
 
                for (int j = 0; j < _frogLeapSteps; j++)
                {
                    HamiltonianEquations(ref gNew, ref mNew, ref p);
                }
 
                double enew = _energy(mNew);
                double hnew = Hamiltonian(p, enew);
 
                double dh = hnew - h;
 
                Update(ref e, ref gradient, mNew, gNew, enew, dh);
                Samples++;
            }
        }
 
        /// <summary>
        /// Method used to update the sample location. Used in the end of the loop.
        /// </summary>
        /// <param name="e">The old energy.</param>
        /// <param name="gradient">The old gradient/derivative of the energy.</param>
        /// <param name="mNew">The new sample.</param>
        /// <param name="gNew">The new gradient/derivative of the energy.</param>
        /// <param name="enew">The new energy.</param>
        /// <param name="dh">The difference between the old Hamiltonian and new Hamiltonian. Use to determine
        /// if an update should take place. </param>
        protected void Update(ref double e, ref T gradient, T mNew, T gNew, double enew, double dh)
        {
            if (dh <= 0)
            {
                Current = mNew; gradient = gNew; e = enew; Accepts++;
            }
            else if (Bernoulli.Sample(RandomSource, Math.Exp(-dh)) == 1)
            {
                Current = mNew; gradient = gNew; e = enew; Accepts++;
            }
        }
 
        /// <summary>
        /// Use for creating temporary objects in the Burn method.
        /// </summary>
        /// <returns>An object of type T.</returns>
        protected abstract T Create();
 
        /// <summary>
        /// Use for copying objects in the Burn method.
        /// </summary>
        /// <param name="source">The source of copying.</param>
        /// <returns>A copy of the source object.</returns>
        protected abstract T Copy(T source);
 
        /// <summary>
        /// Method for doing dot product.
        /// </summary>
        /// <param name="first">First vector/scalar in the product.</param>
        /// <param name="second">Second vector/scalar in the product.</param>
        protected abstract double DoProduct(T first, T second);
 
        /// <summary>
        /// Method for adding, multiply the second vector/scalar by factor and then
        /// add it to the first vector/scalar.
        /// </summary>
        /// <param name="first">First vector/scalar.</param>
        /// <param name="factor">Scalar factor multiplying by the second vector/scalar.</param>
        /// <param name="second">Second vector/scalar.</param>
        protected abstract void DoAdd(ref T first, double factor, T second);
 
        /// <summary>
        /// Multiplying the second vector/scalar by factor and then subtract it from
        /// the first vector/scalar.
        /// </summary>
        /// <param name="first">First vector/scalar.</param>
        /// <param name="factor">Scalar factor to be multiplied to the second vector/scalar.</param>
        /// <param name="second">Second vector/scalar.</param>
        protected abstract void DoSubtract(ref T first, double factor, T second);
 
        /// <summary>
        /// Method for sampling a random momentum.
        /// </summary>
        /// <param name="p">Momentum to be randomized.</param>
        protected abstract void RandomizeMomentum(ref T p);
 
        /// <summary>
        /// The Hamiltonian equations that is used to produce the new sample.
        /// </summary>
        protected void HamiltonianEquations(ref T gNew, ref T mNew, ref T p)
        {
            DoSubtract(ref p, _stepSize / 2, gNew);
            DoAdd(ref mNew, _stepSize, p);
            gNew = _diff(_energy, mNew);
            DoSubtract(ref p, _stepSize / 2, gNew);
        }
 
        /// <summary>
        /// Method to compute the Hamiltonian used in the method.
        /// </summary>
        /// <param name="momentum">The momentum.</param>
        /// <param name="e">The energy.</param>
        /// <returns>Hamiltonian=E+p.p/2</returns>
        protected double Hamiltonian(T momentum, double e)
        {
            return e + DoProduct(momentum, momentum) / 2;
        }
 
        /// <summary>
        /// Method to check and set a quantity to a non-negative value.
        /// </summary>
        /// <param name="value">Proposed value to be checked.</param>
        /// <returns>Returns value if it is greater than or equal to zero.</returns>
        /// <exception cref="ArgumentOutOfRangeException">Throws when value is negative.</exception>
        protected int SetNonNegative(int value)
        {
            if (value < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(value), "Value must not be negative (zero is ok).");
            }
            return value;
        }
 
        /// <summary>
        /// Method to check and set a quantity to a non-negative value.
        /// </summary>
        /// <param name="value">Proposed value to be checked.</param>
        /// <returns>Returns value if it is greater than to zero.</returns>
        /// <exception cref="ArgumentOutOfRangeException">Throws when value is negative or zero.</exception>
        protected int SetPositive(int value)
        {
            if (value <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(value), "Value must not be negative (zero is ok).");
            }
            return value;
        }
 
        /// <summary>
        /// Method to check and set a quantity to a non-negative value.
        /// </summary>
        /// <param name="value">Proposed value to be checked.</param>
        /// <returns>Returns value if it is greater than zero.</returns>
        /// <exception cref="ArgumentOutOfRangeException">Throws when value is negative or zero.</exception>
        protected double SetPositive(double value)
        {
            if (value <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(value), "Value must not be negative (zero is ok).");
            }
            return value;
        }
    }
}