<p>wesselb.github.io: Thoughts on machine learning and other topics (Jekyll feed, https://wessel.ai/feed.xml, updated 2024-01-23)</p>

<h1>A Short Note on Uniform Integrability (2021-08-05, https://wessel.ai/2021/08/05/uniform-integrability)</h1>

<h2 id="introduction">Introduction</h2>
<p>A sequence of random variables $(X_n)_{n \ge 1} \sub L^1$ is called $L^1$-convergent if there exists some limit $X \in L^1$ such that $\E|X_n - X| \to 0$ as $n \to \infty$.
In this post, we briefly discuss a necessary and sufficient condition for $L^1$-convergence called <em>uniform integrability</em>.</p>
<h2 id="uniform-integrability">Uniform Integrability</h2>
<p><strong>Definition.</strong> A collection of random variables $\mathcal{F}$ is called <em>uniformly integrable</em> if</p>
<p>\begin{equation}
\lim_{K \to \infty} \sup\,\{ \E[\ind_{|X| \ge K} |X|] : X \in \mathcal{F} \} = 0.
\end{equation}</p>
<p>Noting that $\E[\ind_{\abs{X} \ge K} \abs{X}] = \E\abs{X - \ind_{\abs{X} < K} X}$, this condition can also be written as</p>
<p>\begin{equation}
\lim_{K \to \infty} \sup\,\{ \E\abs{X - \ind_{\abs{X} < K} X} : X \in \mathcal{F} \} = 0.
\end{equation}</p>
<p>In other words, if $\mathcal{F}$ is uniformly integrable, then you can choose a single value of $K > 0$ such that, uniformly over $X \in \mathcal{F}$, the random variable $\ind_{\abs{X} < K} X$ is a good approximation of $X$ in terms of the $L^1$-norm.
Crucially, every $\ind_{\abs{X} < K} X$ is a bounded random variable, which is often a desirable property.
Therefore, you could aptly describe a uniformly integrable family as one which allows a <em>uniform bounded approximation</em>.</p>
<p>But what about the name <em>uniform integrability</em>?
For a single variable $X$, it is true that
\begin{equation}
\E\abs{X} < \infty
\iff
\lim_{K \to \infty} \E[\ind_{|X| \ge K} |X|] = 0.
\end{equation}
Hence, you could call a family of random variables <em>uniformly</em> integrable if the limit on the RHS, which is equivalent to integrability, converges uniformly over the family.</p>
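<p>This equivalence is easy to see numerically. The following Monte Carlo sketch (our own illustration, not part of the original argument; the choice of an exponential variable is arbitrary) estimates the tail expectation $\E[\ind_{\abs{X} \ge K} \abs{X}]$ for growing $K$ and watches it vanish:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# X ~ Exponential(1) is integrable: E|X| = 1 < infinity.
x = rng.exponential(1.0, size=1_000_000)

# Monte Carlo estimates of E[1_{|X| >= K} |X|] for growing K.
tails = [np.mean(np.where(np.abs(x) >= K, np.abs(x), 0.0)) for K in [1, 2, 5, 10]]
print(tails)  # decreasing towards zero; analytically (K + 1) e^{-K} here
```

<p>For a <em>family</em> of random variables, uniform integrability asks that this decay happen at a single rate for all members simultaneously.</p>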
<p>The bounded approximation given by uniform integrability can be made a bit nicer.
Instead of bounding $X$ by applying the function $f_K(x) = \ind_{\abs{x} < K} x$, which exhibits a discontinuity at $\abs{x} = K$, uniform integrability allows us to bound $X$ by applying the nicer function $g_K(x) = \max(\min(x, K), -K)$, which is continuous everywhere:
\begin{equation}
\E\abs{g_K(X) - X}
= \E[\ind_{\abs{X} \ge K}\abs{\abs{X} - K}]
\le \E[\ind_{\abs{X} \ge K}\abs{X}] + \E[\ind_{\abs{X} \ge K} K]
\le 2 \E[\ind_{\abs{X} \ge K} \abs{X}],
\end{equation}
which uniformly converges to zero as $K \to \infty$.
Henceforth, for any random variable $X$, denote by $X^K = \max(\min(X, K), -K)$ the <em>truncation of $X$ at level $K$</em>.
Since $g_K$ is continuous, truncating in this way preserves limits.</p>
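<p>As a quick sanity check (our own sketch; the Student-$t$ distribution is just an arbitrary integrable example with heavy tails), the truncation $X^K$ is a one-liner with <code>np.clip</code>, and its $L^1$ error is indeed controlled by the tail expectation:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# A Student-t with 3 degrees of freedom: heavy tails, but still integrable.
x = rng.standard_t(df=3, size=1_000_000)

K = 5.0
x_K = np.clip(x, -K, K)  # the truncation X^K = max(min(X, K), -K)

l1_error = np.mean(np.abs(x - x_K))
tail = np.mean(np.where(np.abs(x) >= K, np.abs(x), 0.0))
print(l1_error, tail)  # l1_error <= 2 * tail, as in the derivation above
```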
<p>Finally, to check that a family of random variables is uniformly integrable, the following two facts are very useful:</p>
<ol>
<li>
<p>If $\sup\,\{ \E[\abs{X}^{p}] : X \in \mathcal{F}\} < \infty$ for some $p > 1$, then $\mathcal{F}$ is uniformly integrable.</p>
</li>
<li>
<p>For every $X \in L^1$, the family $\{ \E[X \cond \mathcal{G}] : \mathcal{G} \text{ is a sub-}\sigma\text{-algebra}\}$ is uniformly integrable.</p>
</li>
</ol>
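<p>A numerical illustration of fact 1 (the two families below are standard textbook examples, not from the text): $X_n = \sqrt{n}\,\mathrm{Bern}(1/n)$ is bounded in $L^2$ and hence uniformly integrable, whereas $X_n = n\,\mathrm{Bern}(1/n)$ is bounded in $L^1$ only and fails uniform integrability:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 500_000
K = 10.0

def tail_expectation(samples, K):
    """Monte Carlo estimate of E[1_{|X| >= K} |X|]."""
    a = np.abs(samples)
    return np.mean(np.where(a >= K, a, 0.0))

# X_n = sqrt(n) * Bern(1/n): E[X_n^2] = 1 for all n, so this family is
# bounded in L^2 and hence uniformly integrable by fact 1.
ui_tails = [
    tail_expectation(np.sqrt(n) * (rng.random(n_samples) < 1 / n), K)
    for n in [100, 1_000, 10_000]
]

# X_n = n * Bern(1/n): E|X_n| = 1 for all n, but E[1_{|X_n| >= K} |X_n|] = 1
# whenever n >= K, so this family is not uniformly integrable.
non_ui_tails = [
    tail_expectation(n * (rng.random(n_samples) < 1 / n), K)
    for n in [100, 1_000, 10_000]
]

print(ui_tails)      # small, and shrinking in n
print(non_ui_tails)  # all approximately 1, no matter how large K is
```

<p>The second family shows that boundedness in $L^1$ alone is not enough: all the mass can escape to infinity.</p>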
<h2 id="a-necessary-and-sufficient-condition-for-l1-convergence">A Necessary and Sufficient Condition for $L^1$-Convergence</h2>
<p>A standard way to prove that a sequence of random variables $(X_n)_{n \ge 1}$ is $L^1$-convergent to some limit is to use <em>bounded convergence</em>, an instance of the dominated convergence theorem.
Recall that a sequence of random variables $(X_n)_{n \ge 1}$ is called <em>convergent in probability</em> if there exists a limit $X$ such that $\P(\abs{X - X_n} \ge \e) \to 0$ as $n \to \infty$ for every $\e > 0$.</p>
<p><strong>Theorem (bounded convergence).</strong>
If $(X_n)_{n \ge 1}$ and $X$ are bounded by some $K > 0$ and $X_n \to X$ in probability, then $X_n \to X$ in $L^1$.</p>
<p><strong>Proof.</strong>
Without loss of generality, assume that $X = 0$, so it remains to demonstrate that $\E\abs{X_n} \to 0$.
Let $\e > 0$.
Using the assumption that $\abs{X_n} \le K$, the idea is to consider the cases $\abs{X_n} \in [0, \e]$ and $\abs{X_n} \in (\e, K]$:</p>
<p>\begin{equation}
\E\abs{X_n}
= \E[\abs{X_n} \ind_{\abs{X_n} \in [0, \e]}] + \E[\abs{X_n} \ind_{\abs{X_n} \in (\e, K]}]
\le \e + K\, \E[\ind_{\abs{X_n} \in (\e, K]}]
\le \e + K\, \P(\abs{X_n} \ge \e).
\end{equation}</p>
<p>Using the assumption that $\P(\abs{X_n} \ge \e) \to 0$ as $n \to \infty$, it follows that $\limsup_{n \to \infty} \E\abs{X_n} \le \e$.
Since $\e > 0$ was arbitrary, this proves that $\lim_{n \to \infty} \E\abs{X_n}=0$. <span style="float:right">\(\blacksquare\)</span></p>
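<p>As a sanity check of the theorem (our own Monte Carlo sketch; the particular sequence is illustrative), take $X_n = K\,\mathrm{Bern}(1/n)$, which is bounded by $K$ and converges to $0$ in probability, so bounded convergence gives $\E\abs{X_n} = K/n \to 0$:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 500_000
K = 7.0  # a common bound for the whole sequence

# X_n = K * Bern(1/n): bounded by K, and converges to 0 in probability,
# so bounded convergence gives E|X_n| = K / n -> 0.
l1_norms = [
    np.mean(np.abs(K * (rng.random(n_samples) < 1 / n)))
    for n in [10, 100, 1_000]
]
print(l1_norms)  # approximately [0.7, 0.07, 0.007]
```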
<p>Bounded convergence is an incredibly useful tool, but the assumption that $(X_n)_{n \ge 1}$ and $X$ are bounded can be too strong.
A looser assumption is that $(X_n)_{n \ge 1}$ and $X$ uniformly allow a <em>bounded approximation</em>, <em>i.e.</em> that $(X_n)_{n \ge 1}$ (and therefore the union of $(X_n)_{n \ge 1}$ and $X$) are <em>uniformly integrable</em>.
This looser condition turns out to not just be sufficient but also necessary.</p>
<p><strong>Theorem (Vitali’s convergence theorem).</strong>
Let $(X_n)_{n \ge 1}$ be a sequence of random variables and let $X$ be a random variable.
Then (a) $(X_n)_{n \ge 1} \sub L^1$, $X \in L^1$, and $X_n \to X$ in $L^1$ if and only if (b) $(X_n)_{n \ge 1} \sub L^1$ is uniformly integrable and $X_n \to X$ in probability.</p>
<p><strong>Proof.</strong>
We only show the hard direction, which is that (b) implies (a).
Assume that $(X_n)_{n \ge 1} \sub L^1$ is uniformly integrable and $X_n \to X$ in probability.
To begin with, it is true<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> that $X \in L^1$.
Since $X \in L^1$, the sequence $(X_n - X)_{n \ge 1}$ is then also uniformly integrable and $X_n - X \to 0$ in probability, so without loss of generality we may assume that $X = 0$.</p>
<p>Uniform integrability gives a uniform bounded approximation of the sequence:</p>
<p>\begin{equation} \label{eq:uniform-approx}
\lim_{K \to \infty} \sup_{n \ge 1}\, \E\abs{X_n - \ind_{\abs{X_n} < K} X_n} = 0.
\end{equation}</p>
<p>For every $K>0$, the sequence $(\ind_{\abs{X_n} < K} X_n)_{n \ge 1}$ is bounded and $\ind_{\abs{X_n} < K} X_n \to 0$ in probability, so $\ind_{\abs{X_n} < K} X_n \to 0$ in $L^1$ by bounded convergence.
The idea is to then take $K \to \infty$ to show that also $X_n \to 0$ in $L^1$.
To wit, by the triangle inequality,</p>
<p>\begin{equation}
\limsup_{n \to \infty} \E\abs{X_n}
\le \sup_{n \ge 1}\, \E\abs{X_n - \ind_{\abs{X_n} < K} X_n} + \limsup_{n \to \infty} \E\abs{\ind_{\abs{X_n} < K} X_n}
\overset{\text{(i)}}{=} \sup_{n \ge 1}\, \E\abs{X_n - \ind_{\abs{X_n} < K} X_n}
\end{equation}</p>
<p>where (i) follows from the fact that $\ind_{\abs{X_n} < K} X_n \to 0$ in $L^1$ by bounded convergence.
Taking $K \to \infty$ and using \eqref{eq:uniform-approx} then shows the result.
<span style="float:right">\(\blacksquare\)</span></p>
<h2 id="application-strengthening-of-convergence-in-distribution">Application: Strengthening of Convergence in Distribution</h2>
<p>A sequence of random variables $(X_n)_{n \ge 1}$ is called <em>weakly convergent</em> if there exists a limit $X$ such that, for every $f \colon \R \to \R$ continuous and bounded, it is true that $\E[f(X_n)] \to \E[f(X)]$.
A limitation of weak convergence is that it only handles <em>bounded</em> $f$;
for example, weak convergence does not imply that $\E[X_n] \to \E[X]$.
As we illustrate now, the assumption of uniform integrability can be used to strengthen the conclusion of weak convergence to include $\E[X_n] \to \E[X]$.</p>
<p>The key observation is as follows: if $(X_n)_{n \ge 1}$ and $X$ were bounded by some $K > 0$, then we can apply the truncation function $g_K$, which is a continuous and bounded function, to conclude that
\begin{equation}
\E[X_n] = \E[g_K(X_n)] \to \E[g_K(X)] = \E[X].
\end{equation}
Instead of assuming boundedness, now assume only that $(X_n)_{n \ge 1}$ is uniformly integrable.
For all $K > 0$, consider the uniform bounded approximations $(X^K_n)_{n \ge 1}$ and $X^K$.
Because the truncation operation is continuous, the sequence $(X_n^K)_{n \ge 1}$ is still weakly convergent to $X^K$.
Moreover, $(X_n^K)_{n \ge 1}$ and $X^K$ are bounded by $K > 0$.
The foregoing argument then shows that
$
\lim_{n \to \infty} \E[X_n^K] = \E[X^K].
$
Therefore,
\begin{equation}
\lim_{n \to \infty} \E[X_n]
= \lim_{n \to \infty} \lim_{K \to \infty} \E[X_n^K]
= \lim_{K \to \infty} \lim_{n \to \infty} \E[X_n^K]
= \lim_{K \to \infty} \E[X^K]
= \E[X],
\end{equation}
where the interchange of limits is allowed by uniformity of the bounded approximation.</p>
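<p>The argument can be seen at work numerically. In the following sketch (our own example, not from the text), $X \sim \Normal(1, 1)$ and $X_n = (1 + 1/n) X$: the sequence converges weakly to $X$, is bounded in $L^2$ and hence uniformly integrable, and indeed $\E[X_n] \to \E[X] = 1$, with the truncated means $\E[X_n^K]$ uniformly close to $\E[X_n]$:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(1.0, 1.0, size=1_000_000)  # samples of X ~ N(1, 1)

# X_n = (1 + 1/n) X converges weakly to X and is bounded in L^2, hence
# uniformly integrable, so E[X_n] -> E[X] = 1.
K = 10.0
means, trunc_means = [], []
for n in [1, 10, 100]:
    x_n = (1 + 1 / n) * z
    means.append(np.mean(x_n))
    trunc_means.append(np.mean(np.clip(x_n, -K, K)))  # E[X_n^K]

print(means)        # approximately [2.0, 1.1, 1.01]
print(trunc_means)  # nearly identical: the truncation barely bites here
```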
<h2 id="summary">Summary</h2>
<p>A family of random variables is called <em>uniformly integrable</em> if it allows a <em>uniform bounded approximation</em>.
Allowing a uniform bounded approximation turns out to be the right characterisation of $L^1$-convergence:
a sequence that converges in probability is $L^1$-convergent if and only if it is uniformly integrable.
Uniform integrability is a generally useful tool:
if you can prove a result for bounded random variables, then you might be able to prove the result for the greater class of uniformly integrable random variables by considering a uniform bounded approximation.</p>
<p>Thanks to <a href="https://sites.google.com/view/jirihron">Jiri Hron</a> for helpful comments on a draft of this post.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>Since $X_n \to X$ in probability, $X_{n_k} \to X$ almost surely along some subsequence $(X_{n_k})_{k \ge 0}$.
Therefore, using Fatou’s lemma,</p>
<p>\begin{equation}
\E\abs{X}
= \E[\lim_{k \to \infty} \abs{X_{n_k}}]
\le \liminf_{k \to \infty} \E[\abs{X_{n_k}}]
< \infty,
\end{equation}</p>
<p>where the right hand side is bounded because any uniformly integrable family is uniformly bounded in $L^1$. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>

<h1>What Keeps a Bayesian Awake at Night (2021-04-07, https://wessel.ai/2021/04/07/what-keeps-a-bayesian-awake-at-night)</h1>

<p>The <a href="http://mlg.eng.cam.ac.uk/">Cambridge Machine Learning Group</a> is launching <a href="https://mlg.eng.cam.ac.uk/blog">a blog</a>, featuring a first <a href="https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-1.html">two</a>-<a href="https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-2.html">part</a> post about what keeps a Bayesian awake at night.
In the <a href="https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-1.html">first part</a>, during the day, we lay out the standard arguments that many use to support Bayesian inference, ranging from more fundamental theorems, like Cox’s theorem, to unit tests, like Wald’s theorem.
In the <a href="https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-2.html">second part</a>, at night, we take a closer look at these standard arguments and identify the weaknesses which cause Bayesians to lose sleep: the standard justifications have problems, modelling is hard and sensitive to innocuous details, and—worst of all—one typically must resort to approximate inference.
Check it out!</p>

<h1>Linear Models from a Gaussian Process Point of View with Stheno and JAX (2021-01-19, https://wessel.ai/2021/01/19/linear-models-with-stheno-and-jax)</h1>

<p>By Wessel Bruinsma, <a href="https://jamesr.info/">James Requeima</a>, and <a href="https://scholar.google.com/citations?user=EIpfkw4AAAAJ">Eric Perim Martins</a></p>
<p class="pretitle">Cross-posted on the <a href="https://invenia.github.io/blog/2021/01/19/linear-models-with-stheno-and-jax/">Invenia blog</a>.</p>
<h2 id="introduction">Introduction</h2>
<p>A linear model prescribes a linear relationship between inputs and outputs.
Linear models are amongst the simplest of models, but they are ubiquitous across science.
A linear model with Gaussian distributions on the coefficients forms one of the simplest instances of a <em><a href="https://en.wikipedia.org/wiki/Gaussian_process">Gaussian process</a></em>.
In this post, we will give a brief introduction to linear models from a Gaussian process point of view.
We will see how a linear model can be implemented with <em>Gaussian process probabilistic programming</em> using <a href="https://github.com/wesselb/stheno">Stheno</a>, and how this model can be used to denoise noisy observations.
(Disclosure: <a href="https://willtebbutt.github.io/">Will Tebbutt</a> and Wessel are the authors of Stheno;
Will maintains a <a href="https://github.com/willtebbutt/Stheno.jl">Julia version</a>.)
In short, <a href="https://en.wikipedia.org/wiki/Probabilistic_programming">probabilistic programming</a> is a programming paradigm that brings powerful probabilistic models to the comfort of your programming language, which often comes with tools to automatically perform inference (make predictions).
We will also use <a href="https://github.com/google/jax">JAX</a>’s just-in-time compiler to make our implementation extremely efficient.</p>
<h2 id="linear-models-from-a-gaussian-process-point-of-view">Linear Models from a Gaussian Process Point of View</h2>
<p>Consider a data set \((x_i, y_i)_{i=1}^n \subseteq \R \times \R\) consisting of \(n\) real-valued input–output pairs.
Suppose that we wish to estimate a linear relationship between the inputs and outputs:</p>
\[\label{eq:ax_b}
y_i = a \cdot x_i + b + \e_i,\]
<p>where \(a\) is an unknown slope, \(b\) is an unknown offset, and \(\e_i\) is some error/noise associated with the observation \(y_i\).
To implement this model with Gaussian process probabilistic programming, we need to cast the problem into a <em>functional form</em>.
This means that we will assume that there is some underlying, random function \(y \colon \R \to \R\) such that the observations are evaluations of this function: \(y_i = y(x_i)\).
The model for the random function \(y\) will embody the structure of the linear model \eqref{eq:ax_b}.
This may sound hard, but it is not difficult at all.
We let the random function \(y\) be of the following form:</p>
\[\label{eq:ax_b_functional}
y(x) = a(x) \cdot x + b(x) + \e(x)\]
<p>where \(a\colon \R \to \R\) is a <em>random constant function</em>.
An example of a <em>constant function</em> \(f\) is \(f(x) = 5\).
<em>Random</em> means that the value \(5\) is not fixed, but modelled with a random value drawn from some probability distribution, because we don’t know the true value.
We let \(b\colon \R \to \R\) also be a random <em>constant function</em>, and \(\e\colon \R \to \R\) a random <em>noise function</em>.
Do you see the similarities between \eqref{eq:ax_b} and \eqref{eq:ax_b_functional}?
If all that doesn’t fully make sense, don’t worry; things should become more clear as we implement the model.</p>
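<p>To build intuition before bringing in Stheno, here is a plain NumPy sketch (our own illustration, not code from the post) of one draw from the generative model \eqref{eq:ax_b_functional}, using the same prior scales as the Stheno program later in the post: $a \sim \Normal(0, 1)$, $b \sim \Normal(0, 10)$, and noise scaled by \(0.5\):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 100)

# One draw from the model: a and b are random constants, and eps gives
# independent noise at every input.
a = rng.normal(0.0, 1.0)              # slope: a ~ N(0, 1)
b = rng.normal(0.0, np.sqrt(10.0))    # offset: b ~ N(0, 10)
eps = 0.5 * rng.normal(size=x.shape)  # noise: eps(x_i) ~ N(0, 0.25), i.i.d.

y = a * x + b + eps  # one sample of the random function y
print(y.shape)
```

<p>Stheno packages exactly this kind of construction, but additionally performs inference (conditioning on observations) automatically.</p>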
<p>To model random constant functions and random noise functions, we will use <a href="https://github.com/wesselb/stheno">Stheno</a>, which is a Python library for Gaussian process modelling.
We also have a <a href="https://github.com/willtebbutt/Stheno.jl">Julia version</a>, but in this post we’ll use the Python version.
To install Stheno, run the command</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip <span class="nb">install</span> <span class="nt">--upgrade</span> <span class="nt">--upgrade-strategy</span> eager stheno
</code></pre></div></div>
<p>In Stheno, a Gaussian process can be created with <code class="language-plaintext highlighter-rouge">GP(kernel)</code>, where <code class="language-plaintext highlighter-rouge">kernel</code> is the so-called <a href="https://en.wikipedia.org/wiki/Gaussian_process#Covariance_functions"><em>kernel</em> or <em>covariance function</em> of the Gaussian process</a>.
The kernel determines the properties of the function that the Gaussian process models.
For example, the kernel <code class="language-plaintext highlighter-rouge">EQ()</code> models smooth functions, and the kernel <code class="language-plaintext highlighter-rouge">Matern12()</code> models functions that look jagged.
See the <a href="https://www.cs.toronto.edu/~duvenaud/cookbook/">kernel cookbook</a> for an overview of commonly used kernels and the <a href="https://wesselb.github.io/stheno/docs/_build/html/readme.html#available-kernels">documentation of Stheno</a> for the corresponding classes.
For constant functions, you can set the kernel to simply a constant, for example <code class="language-plaintext highlighter-rouge">1</code>, which then models the constant function with a value drawn from \(\Normal(0, 1)\). (By default, in Stheno, all means are zero; but, if you like, <a href="https://wesselb.github.io/stheno/docs/_build/html/readme.html#available-means">you can also set a mean</a>.)</p>
<p>Let’s start out by creating a Gaussian process for the random constant function \(a(x)\) that models the slope.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">from</span> <span class="nn">stheno</span> <span class="kn">import</span> <span class="n">GP</span>
<span class="o">>>></span> <span class="n">a</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">a</span>
<span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
<p>You can see how the Gaussian process looks by simply sampling from it.
To sample from the Gaussian process <code class="language-plaintext highlighter-rouge">a</code> at some inputs <code class="language-plaintext highlighter-rouge">x</code>, evaluate it at those inputs, <code class="language-plaintext highlighter-rouge">a(x)</code>, and call the method <code class="language-plaintext highlighter-rouge">sample</code>: <code class="language-plaintext highlighter-rouge">a(x).sample()</code>.
This shows that you can really think of a Gaussian process just like you think of a function:
pass it some inputs to get (the model for) the corresponding outputs.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">20</span><span class="p">));</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-constant-functions.png" alt="Samples of a Gaussian process that models a constant function" id="figure-constant-functions" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 1: Samples of a Gaussian process that models a constant function
</p>
</div>
<p>We’ve sampled a bunch of constant functions.
Sweet!
The next step in the model \eqref{eq:ax_b_functional} is to multiply the slope function \(a(x)\) by \(x\).
To multiply <code class="language-plaintext highlighter-rouge">a</code> by \(x\), we multiply <code class="language-plaintext highlighter-rouge">a</code> by the function <code class="language-plaintext highlighter-rouge">lambda x: x</code>, which also casts \(x\) as a function:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">f</span>
<span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o"><</span><span class="k">lambda</span><span class="o">></span><span class="p">)</span>
</code></pre></div></div>
<p>This will give rise to functions like \(x \mapsto 0.1x\) and \(x \mapsto -0.4x\), depending on the value that \(a(x)\) takes.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">20</span><span class="p">));</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-slope-functions.png" alt="Samples of a Gaussian process that models functions with a random slope" id="figure-slope-functions" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 2: Samples of a Gaussian process that models functions with a random slope
</p>
</div>
<p>This is starting to look good!
The only ingredient that is missing is an offset.
We model the offset just like the slope, but here we set the kernel to <code class="language-plaintext highlighter-rouge">10</code> instead of <code class="language-plaintext highlighter-rouge">1</code>, which models the offset with a value drawn from \(\Normal(0, 10)\).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">b</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span>
<span class="nb">AssertionError</span><span class="p">:</span> <span class="n">Processes</span> <span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o"><</span><span class="k">lambda</span><span class="o">></span><span class="p">)</span> <span class="ow">and</span> <span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">10</span> <span class="o">*</span> <span class="mi">1</span><span class="p">)</span> <span class="n">are</span> <span class="n">associated</span> <span class="n">to</span> <span class="n">different</span> <span class="n">measures</span><span class="p">.</span>
</code></pre></div></div>
<p>Something went wrong.
Stheno has an abstraction called <em>measures</em>, where only <code class="language-plaintext highlighter-rouge">GP</code>s that are part of the same measure can be combined into new <code class="language-plaintext highlighter-rouge">GP</code>s;
the abstraction of measures is there to keep things safe and tidy.
What goes wrong here is that <code class="language-plaintext highlighter-rouge">a</code> and <code class="language-plaintext highlighter-rouge">b</code> are not part of the same measure.
Let’s explicitly create a new measure and attach <code class="language-plaintext highlighter-rouge">a</code> and <code class="language-plaintext highlighter-rouge">b</code> to it.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">from</span> <span class="nn">stheno</span> <span class="kn">import</span> <span class="n">Measure</span>
<span class="o">>>></span> <span class="n">prior</span> <span class="o">=</span> <span class="n">Measure</span><span class="p">()</span>
<span class="o">>>></span> <span class="n">a</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">b</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span>
<span class="o">>>></span> <span class="n">f</span>
<span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o"><</span><span class="k">lambda</span><span class="o">></span> <span class="o">+</span> <span class="mi">10</span> <span class="o">*</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
<p>Let’s see what samples from <code class="language-plaintext highlighter-rouge">f</code> look like.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">20</span><span class="p">));</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-linear-functions.png" alt="Samples of a Gaussian process that models linear functions" id="figure-linear-functions" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 3: Samples of a Gaussian process that models linear functions
</p>
</div>
<p>Perfect!
We will use <code class="language-plaintext highlighter-rouge">f</code> as our linear model.</p>
<p>In practice, observations are corrupted with noise.
We can add some noise to the lines in <a href="#figure-linear-functions">Figure 3</a> by adding a Gaussian process that models noise.
You can construct such a Gaussian process by using the kernel <code class="language-plaintext highlighter-rouge">Delta()</code>, which models the noise with independent \(\Normal(0, 1)\) variables.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">from</span> <span class="nn">stheno</span> <span class="kn">import</span> <span class="n">Delta</span>
<span class="o">>>></span> <span class="n">noise</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="n">Delta</span><span class="p">(),</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">y</span> <span class="o">=</span> <span class="n">f</span> <span class="o">+</span> <span class="n">noise</span>
<span class="o">>>></span> <span class="n">y</span>
<span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o"><</span><span class="k">lambda</span><span class="o">></span> <span class="o">+</span> <span class="mi">10</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">Delta</span><span class="p">())</span>
<span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">20</span><span class="p">));</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-noisy-linear-functions.png" alt="Samples of a Gaussian process that models noisy linear functions" id="figure-noisy-linear-functions" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 4: Samples of a Gaussian process that models noisy linear functions
</p>
</div>
<p>That looks more realistic, but perhaps that’s a bit too much noise.
We can tune down the amount of noise, for example, by scaling <code class="language-plaintext highlighter-rouge">noise</code> by <code class="language-plaintext highlighter-rouge">0.5</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">y</span> <span class="o">=</span> <span class="n">f</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noise</span>
<span class="o">>>></span> <span class="n">y</span>
<span class="n">GP</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o"><</span><span class="k">lambda</span><span class="o">></span> <span class="o">+</span> <span class="mi">10</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="mf">0.25</span> <span class="o">*</span> <span class="n">Delta</span><span class="p">())</span>
<span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">20</span><span class="p">));</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-noisy-linear-functions-2.png" alt="Samples of a Gaussian process that models noisy linear functions" id="figure-noisy-linear-functions-2" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 5: Samples of a Gaussian process that models noisy linear functions
</p>
</div>
<p>Much better.</p>
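<p>Note the <code class="language-plaintext highlighter-rouge">0.25 * Delta()</code> term in the kernel printed above: scaling <code class="language-plaintext highlighter-rouge">noise</code> by a factor of \(c\) scales its variance by \(c^2\). A quick NumPy check of this fact (illustrative only, not part of Stheno):</p>

```python
import numpy as np

# Scaling a random variable by c scales its variance by c ** 2. This is why
# scaling `noise` by 0.5 turns the Delta() term into 0.25 * Delta() in the
# kernel printed above.
rng = np.random.default_rng(0)
z = rng.standard_normal(1_000_000)  # unit-variance noise samples

print(np.var(0.5 * z))  # close to 0.25
```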
<p>To summarise, our linear model is given by</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">prior</span> <span class="o">=</span> <span class="n">Measure</span><span class="p">()</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for slope
</span><span class="n">b</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for offset
</span><span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span> <span class="c1"># Noiseless linear model
</span>
<span class="n">noise</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="n">Delta</span><span class="p">(),</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for noise
</span><span class="n">y</span> <span class="o">=</span> <span class="n">f</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noise</span> <span class="c1"># Noisy linear model
</span></code></pre></div></div>
<p>We call a program like this a <em>Gaussian process probabilistic program</em> (GPPP).
Let’s generate some noisy synthetic data, <code class="language-plaintext highlighter-rouge">(x_obs, y_obs)</code>, that will make up an example data set \((x_i, y_i)_{i=1}^n\).
We also save the observations without noise added — <code class="language-plaintext highlighter-rouge">f_obs</code> — so we can later check how good our predictions really are.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">x_obs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">50_000</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">f_obs</span> <span class="o">=</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">x_obs</span> <span class="o">-</span> <span class="mf">2.5</span>
<span class="o">>>></span> <span class="n">y_obs</span> <span class="o">=</span> <span class="n">f_obs</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">50_000</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">);</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-observations.png" alt="Some observations" id="figure-observations" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 6: Some observations
</p>
</div>
<p>We will see next how we can fit our model to this data.</p>
<h2 id="inference-in-linear-models">Inference in Linear Models</h2>
<p>Suppose that we wish to remove the noise from the observations in <a href="#figure-observations">Figure 6</a>.
We carefully phrase this problem in terms of our GPPP:
the observations <code class="language-plaintext highlighter-rouge">y_obs</code> are realisations of the <em>noisy</em> linear model <code class="language-plaintext highlighter-rouge">y</code> at <code class="language-plaintext highlighter-rouge">x_obs</code> — realisations of <code class="language-plaintext highlighter-rouge">y(x_obs)</code> — and we wish to make predictions for the <em>noiseless</em> linear model <code class="language-plaintext highlighter-rouge">f</code> at <code class="language-plaintext highlighter-rouge">x_obs</code> — predictions for <code class="language-plaintext highlighter-rouge">f(x_obs)</code>.</p>
<p>In Stheno, we can make predictions based on observations by <em>conditioning</em> the measure of the model on the observations.
In our GPPP, the measure is given by <code class="language-plaintext highlighter-rouge">prior</code>, so we aim to condition <code class="language-plaintext highlighter-rouge">prior</code> on the observations <code class="language-plaintext highlighter-rouge">y_obs</code> for <code class="language-plaintext highlighter-rouge">y(x_obs)</code>.
Mathematically, this process of incorporating information by conditioning happens through <a href="https://en.wikipedia.org/wiki/Bayes%27_theorem">Bayes’ rule</a>.
Programmatically, we first make an <code class="language-plaintext highlighter-rouge">Observations</code> object, which represents the information — the observations — that we want to incorporate, and then condition <code class="language-plaintext highlighter-rouge">prior</code> on this object:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">from</span> <span class="nn">stheno</span> <span class="kn">import</span> <span class="n">Observations</span>
<span class="o">>>></span> <span class="n">obs</span> <span class="o">=</span> <span class="n">Observations</span><span class="p">(</span><span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">y_obs</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">post</span> <span class="o">=</span> <span class="n">prior</span><span class="p">.</span><span class="n">condition</span><span class="p">(</span><span class="n">obs</span><span class="p">)</span>
</code></pre></div></div>
<p>You can also more concisely perform these two steps at once, as follows:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">post</span> <span class="o">=</span> <span class="n">prior</span> <span class="o">|</span> <span class="p">(</span><span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">y_obs</span><span class="p">)</span>
</code></pre></div></div>
<p>This mimics the mathematical notation used for conditioning.</p>
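<p>Mathematically, the conditioning performed here is ordinary Gaussian conditioning. The sketch below redoes it with dense NumPy linear algebra for the model of this post, whose noiseless kernel is \(k(x, x') = x x' + 10\) and whose observations carry <code class="language-plaintext highlighter-rouge">0.25 * Delta()</code> of noise; this is an illustrative reimplementation, not Stheno’s internals:</p>

```python
import numpy as np

# Standard Gaussian conditioning: for zero-mean jointly Gaussian f and y,
#   mean_post = K_fy @ inv(K_yy) @ y_obs,
#   var_post  = K_ff - K_fy @ inv(K_yy) @ K_yf.
# The kernels below mirror the model of this post: k(x, x') = x * x' + 10
# for the noiseless model f, plus 0.25 * I of noise for y, so K_fy = K_ff.
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 200)

K_ff = np.outer(x, x) + 10.0          # covariance of f(x)
K_yy = K_ff + 0.25 * np.eye(len(x))   # covariance of y(x) = f(x) + noise
y_obs = 0.8 * x - 2.5 + 0.5 * rng.standard_normal(len(x))

mean_post = K_ff @ np.linalg.solve(K_yy, y_obs)
var_post = K_ff - K_ff @ np.linalg.solve(K_yy, K_ff)

# The posterior mean should lie close to the noiseless line 0.8 * x - 2.5.
print(np.max(np.abs(mean_post - (0.8 * x - 2.5))))
```

This dense approach costs \(\mathcal{O}(n^3)\), which is why Stheno’s structured representations matter at 50k observations.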
<p>With our updated measure <code class="language-plaintext highlighter-rouge">post</code>, which is often called the <em>posterior</em> measure, we can make a prediction for <code class="language-plaintext highlighter-rouge">f(x_obs)</code> by passing <code class="language-plaintext highlighter-rouge">f(x_obs)</code> to <code class="language-plaintext highlighter-rouge">post</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">pred</span> <span class="o">=</span> <span class="n">post</span><span class="p">(</span><span class="n">f</span><span class="p">(</span><span class="n">x_obs</span><span class="p">))</span>
<span class="o">>>></span> <span class="n">pred</span><span class="p">.</span><span class="n">mean</span>
<span class="o"><</span><span class="n">dense</span> <span class="n">matrix</span><span class="p">:</span> <span class="n">shape</span><span class="o">=</span><span class="mi">50000</span><span class="n">x1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span>
<span class="n">mat</span><span class="o">=</span><span class="p">[[</span><span class="o">-</span><span class="mf">2.498</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">2.498</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">2.498</span><span class="p">]</span>
<span class="p">...</span>
<span class="p">[</span> <span class="mf">5.501</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">5.502</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">5.502</span><span class="p">]]</span><span class="o">></span>
<span class="o">>>></span> <span class="n">pred</span><span class="p">.</span><span class="n">var</span>
<span class="o"><</span><span class="n">low</span><span class="o">-</span><span class="n">rank</span> <span class="n">matrix</span><span class="p">:</span> <span class="n">shape</span><span class="o">=</span><span class="mi">50000</span><span class="n">x50000</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">2</span>
<span class="n">left</span><span class="o">=</span><span class="p">[[</span><span class="mf">1.e+00</span> <span class="mf">0.e+00</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">2.e-04</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">4.e-04</span><span class="p">]</span>
<span class="p">...</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]]</span>
<span class="n">middle</span><span class="o">=</span><span class="p">[[</span> <span class="mf">2.001e-05</span> <span class="o">-</span><span class="mf">2.995e-06</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">2.997e-06</span> <span class="mf">6.011e-07</span><span class="p">]]</span>
<span class="n">right</span><span class="o">=</span><span class="p">[[</span><span class="mf">1.e+00</span> <span class="mf">0.e+00</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">2.e-04</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">4.e-04</span><span class="p">]</span>
<span class="p">...</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]]</span><span class="o">></span>
</code></pre></div></div>
<p>The prediction <code class="language-plaintext highlighter-rouge">pred</code> is a <a href="https://en.wikipedia.org/wiki/Multivariate_Gaussian_distribution">multivariate Gaussian distribution</a> with a particular mean and variance, which are displayed above.
You should view <code class="language-plaintext highlighter-rouge">post</code> as a function that assigns a probability distribution — the prediction — to every part of our GPPP, like <code class="language-plaintext highlighter-rouge">f(x_obs)</code>.
Note that the variance of the prediction is a <em>massive</em> matrix of size 50k \(\times\) 50k.
Under the hood, Stheno uses <a href="https://github.com/wesselb/matrix">structured representations for matrices</a> to compute and store matrices in an efficient way.</p>
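<p>The savings of such structured representations are easy to quantify: a dense 50k \(\times\) 50k matrix of 64-bit floats occupies 20 GB, while the rank-2 representation above only needs to store its factors:</p>

```python
# Memory needed to store the 50,000 x 50,000 posterior variance in float64,
# dense versus the rank-2 factored form shown above (left, middle, right).
n, r = 50_000, 2
bytes_per_float = 8  # float64

dense = n * n * bytes_per_float                     # full n x n matrix
factored = (2 * n * r + r * r) * bytes_per_float    # left + right factors and middle

print(dense / 1e9)     # 20.0 (GB)
print(factored / 1e6)  # 1.600032 (MB)
```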
<p>Let’s see what the prediction <code class="language-plaintext highlighter-rouge">pred</code> for <code class="language-plaintext highlighter-rouge">f(x_obs)</code> looks like.
The prediction <code class="language-plaintext highlighter-rouge">pred</code> exposes the method <code class="language-plaintext highlighter-rouge">marginal_credible_bounds()</code> that conveniently computes the mean and associated lower and upper error bounds for you.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">mean</span><span class="p">,</span> <span class="n">error_lower</span><span class="p">,</span> <span class="n">error_upper</span> <span class="o">=</span> <span class="n">pred</span><span class="p">.</span><span class="n">marginal_credible_bounds</span><span class="p">()</span>
<span class="o">>>></span> <span class="n">mean</span>
<span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.49818708</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.49802708</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.49786708</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.50148996</span><span class="p">,</span>
<span class="mf">5.50164996</span><span class="p">,</span> <span class="mf">5.50180997</span><span class="p">])</span>
<span class="o">>>></span> <span class="n">error_upper</span> <span class="o">-</span> <span class="n">error_lower</span>
<span class="n">array</span><span class="p">([</span><span class="mf">0.01753381</span><span class="p">,</span> <span class="mf">0.01753329</span><span class="p">,</span> <span class="mf">0.01753276</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">0.01761883</span><span class="p">,</span> <span class="mf">0.01761935</span><span class="p">,</span>
<span class="mf">0.01761988</span><span class="p">])</span>
</code></pre></div></div>
<p>The error is very small — on the order of \(10^{-2}\) — which means that Stheno predicted <code class="language-plaintext highlighter-rouge">f(x_obs)</code> with high confidence.</p>
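<p>The widths of these bounds relate directly to the marginal posterior standard deviations. Assuming the bounds form a central 95% credible interval (an assumption about <code class="language-plaintext highlighter-rouge">marginal_credible_bounds()</code>, not something stated above), each width equals \(2 \cdot 1.96 \cdot \sigma\), so we can back out \(\sigma\):</p>

```python
# Back out the marginal posterior standard deviation from the bound width,
# assuming a central 95% credible interval (width = 2 * 1.96 * sigma).
# The 0.0175 below is a typical width from the output above.
width = 0.0175
sigma = width / (2 * 1.96)

print(round(sigma, 4))  # 0.0045
```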
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">);</span> <span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">mean</span><span class="p">);</span> <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="image-container">
<img src="/assets/images/posts/linear-models-denoised-observations.png" alt="Mean of the prediction (blue line) for the denoised observations" id="figure-denoised-observations" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 7: Mean of the prediction (blue line) for the denoised observations
</p>
</div>
<p>The blue line in <a href="#figure-denoised-observations">Figure 7</a> shows the mean of the predictions.
This line appears to nicely pass through the observations with the noise removed.
But let’s see how good the predictions really are by comparing to <code class="language-plaintext highlighter-rouge">f_obs</code>, which we previously saved.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">f_obs</span> <span class="o">-</span> <span class="n">mean</span>
<span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">0.00181292</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.00181292</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.00181292</span><span class="p">,</span> <span class="p">...,</span> <span class="o">-</span><span class="mf">0.00180997</span><span class="p">,</span>
<span class="o">-</span><span class="mf">0.00180997</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.00180997</span><span class="p">])</span>
<span class="o">>>></span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">f_obs</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># Compute the mean square error.
</span><span class="mf">3.281323087544209e-06</span>
</code></pre></div></div>
<p>That’s pretty close!
Not bad at all.</p>
<p>We wrap up this section by encapsulating everything that we’ve done so far in a function <code class="language-plaintext highlighter-rouge">linear_model_denoise</code>, which denoises noisy observations from a linear model:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">linear_model_denoise</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">):</span>
<span class="n">prior</span> <span class="o">=</span> <span class="n">Measure</span><span class="p">()</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for slope
</span> <span class="n">b</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for offset
</span> <span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span> <span class="c1"># Noiseless linear model
</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="n">Delta</span><span class="p">(),</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for noise
</span> <span class="n">y</span> <span class="o">=</span> <span class="n">f</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noise</span> <span class="c1"># Noisy linear model
</span>
<span class="n">post</span> <span class="o">=</span> <span class="n">prior</span> <span class="o">|</span> <span class="p">(</span><span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">y_obs</span><span class="p">)</span> <span class="c1"># Condition on observations.
</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">post</span><span class="p">(</span><span class="n">f</span><span class="p">(</span><span class="n">x_obs</span><span class="p">))</span> <span class="c1"># Make predictions.
</span> <span class="k">return</span> <span class="n">pred</span><span class="p">.</span><span class="n">marginal_credible_bounds</span><span class="p">()</span> <span class="c1"># Return the mean and associated error bounds.
</span></code></pre></div></div>
<p></p>
<p><!-- Prevent tabs. --></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">linear_model_denoise</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">)</span>
<span class="p">(</span><span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.49818708</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.49802708</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.49786708</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.50148996</span><span class="p">,</span>
<span class="mf">5.50164996</span><span class="p">,</span> <span class="mf">5.50180997</span><span class="p">]),</span> <span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.50695399</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.50679372</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.50663346</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.49268055</span><span class="p">,</span>
<span class="mf">5.49284029</span><span class="p">,</span> <span class="mf">5.49300003</span><span class="p">]),</span> <span class="n">array</span><span class="p">([</span><span class="o">-</span><span class="mf">2.48942018</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.48926044</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.4891007</span> <span class="p">,</span> <span class="p">...,</span> <span class="mf">5.51029937</span><span class="p">,</span>
<span class="mf">5.51045964</span><span class="p">,</span> <span class="mf">5.51061991</span><span class="p">]))</span>
<span class="o">>>></span> <span class="o">%</span><span class="n">timeit</span> <span class="n">linear_model_denoise</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">)</span>
<span class="mi">233</span> <span class="n">ms</span> <span class="err">±</span> <span class="mf">12.6</span> <span class="n">ms</span> <span class="n">per</span> <span class="n">loop</span> <span class="p">(</span><span class="n">mean</span> <span class="err">±</span> <span class="n">std</span><span class="p">.</span> <span class="n">dev</span><span class="p">.</span> <span class="n">of</span> <span class="mi">7</span> <span class="n">runs</span><span class="p">,</span> <span class="mi">1</span> <span class="n">loop</span> <span class="n">each</span><span class="p">)</span>
</code></pre></div></div>
<p>To denoise 50k observations, <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> takes about 250 ms.
Not terrible, but we can do much better, which is important if we want to scale to larger numbers of observations.
In the next section, we will make this function really fast.</p>
<h2 id="making-inference-fast">Making Inference Fast</h2>
<p>To make <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> fast, firstly, the linear algebra that happens under the hood when <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> is called should be simplified as much as possible.
Fortunately, this happens automatically, due to <a href="https://github.com/wesselb/matrix">the structured representation of matrices</a> that Stheno uses.
For example, when making predictions with Gaussian processes, the main computational bottleneck is usually the construction and inversion of <code class="language-plaintext highlighter-rouge">y(x_obs).var</code>, the variance associated with the observations:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">).</span><span class="n">var</span>
<span class="o"><</span><span class="n">Woodbury</span> <span class="n">matrix</span><span class="p">:</span> <span class="n">shape</span><span class="o">=</span><span class="mi">50000</span><span class="n">x50000</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span>
<span class="n">diag</span><span class="o">=<</span><span class="n">diagonal</span> <span class="n">matrix</span><span class="p">:</span> <span class="n">shape</span><span class="o">=</span><span class="mi">50000</span><span class="n">x50000</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span>
<span class="n">diag</span><span class="o">=</span><span class="p">[</span><span class="mf">0.25</span> <span class="mf">0.25</span> <span class="mf">0.25</span> <span class="p">...</span> <span class="mf">0.25</span> <span class="mf">0.25</span> <span class="mf">0.25</span><span class="p">]</span><span class="o">></span>
<span class="n">lr</span><span class="o">=<</span><span class="n">low</span><span class="o">-</span><span class="n">rank</span> <span class="n">matrix</span><span class="p">:</span> <span class="n">shape</span><span class="o">=</span><span class="mi">50000</span><span class="n">x50000</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">2</span>
<span class="n">left</span><span class="o">=</span><span class="p">[[</span><span class="mf">1.e+00</span> <span class="mf">0.e+00</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">2.e-04</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">4.e-04</span><span class="p">]</span>
<span class="p">...</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]]</span>
<span class="n">middle</span><span class="o">=</span><span class="p">[[</span><span class="mf">10.</span> <span class="mf">0.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.</span> <span class="mf">1.</span><span class="p">]]</span>
<span class="n">right</span><span class="o">=</span><span class="p">[[</span><span class="mf">1.e+00</span> <span class="mf">0.e+00</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">2.e-04</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">4.e-04</span><span class="p">]</span>
<span class="p">...</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]</span>
<span class="p">[</span><span class="mf">1.e+00</span> <span class="mf">1.e+01</span><span class="p">]]</span><span class="o">>></span>
</code></pre></div></div>
<p>Indeed, observe that this matrix has a particular structure:
it is a sum of a diagonal and a low-rank matrix.
In Stheno, the sum of a diagonal and a low-rank matrix is called a <em>Woodbury</em> matrix, because the <a href="https://en.wikipedia.org/wiki/Woodbury_matrix_identity">Sherman–Morrison–Woodbury formula</a> can be used to efficiently invert it.
Let’s see how long it takes to construct <code class="language-plaintext highlighter-rouge">y(x_obs).var</code> and then invert it.
We invert <code class="language-plaintext highlighter-rouge">y(x_obs).var</code> using <a href="https://github.com/wesselb/lab">LAB</a>, which is automatically installed alongside Stheno and exposes the API to efficiently work with structured matrices.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">import</span> <span class="nn">lab</span> <span class="k">as</span> <span class="n">B</span>
<span class="o">>>></span> <span class="o">%</span><span class="n">timeit</span> <span class="n">B</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">).</span><span class="n">var</span><span class="p">)</span>
<span class="mf">28.5</span> <span class="n">ms</span> <span class="err">±</span> <span class="mf">1.69</span> <span class="n">ms</span> <span class="n">per</span> <span class="n">loop</span> <span class="p">(</span><span class="n">mean</span> <span class="err">±</span> <span class="n">std</span><span class="p">.</span> <span class="n">dev</span><span class="p">.</span> <span class="n">of</span> <span class="mi">7</span> <span class="n">runs</span><span class="p">,</span> <span class="mi">10</span> <span class="n">loops</span> <span class="n">each</span><span class="p">)</span>
</code></pre></div></div>
<p>That’s only 30 ms! Not bad for such a big matrix. Without exploiting structure, a 50k \(\times\) 50k matrix takes 20 GB of memory to store and about an hour to invert.</p>
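<p>The speed comes from never forming the dense matrix. A minimal NumPy sketch of the Sherman–Morrison–Woodbury identity (illustrative; not the implementation Stheno uses):</p>

```python
import numpy as np

# Sherman–Morrison–Woodbury:
#   (D + U M U^T)^{-1} = D^{-1} - D^{-1} U (M^{-1} + U^T D^{-1} U)^{-1} U^T D^{-1}.
# The left-hand side inverts an n x n matrix (O(n^3) cost); the right-hand
# side only inverts r x r matrices (here r = 2), costing O(n r^2).
n = 1_000
d = np.full(n, 0.25)                                       # diagonal part
U = np.stack([np.ones(n), np.linspace(0, 10, n)], axis=1)  # low-rank factors
M = np.diag([10.0, 1.0])                                   # middle matrix

D_inv_U = U / d[:, None]                                   # D^{-1} U
small = np.linalg.inv(M) + U.T @ D_inv_U                   # only r x r
woodbury_inv = np.diag(1 / d) - D_inv_U @ np.linalg.solve(small, D_inv_U.T)

dense_inv = np.linalg.inv(np.diag(d) + U @ M @ U.T)        # naive dense check
print(np.allclose(woodbury_inv, dense_inv))  # True
```

The diagonal, factors, and middle matrix above mirror the structure of <code class="language-plaintext highlighter-rouge">y(x_obs).var</code> printed earlier, just at a smaller size.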
<p>Secondly, we would like the code implemented by <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> to be as efficient as possible.
To achieve this, we will use <a href="https://github.com/google/jax">JAX</a> to compile <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> with <a href="https://www.tensorflow.org/xla">XLA</a>, which generates blazingly fast code.
We start out by importing JAX and loading the JAX extension of Stheno.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">jnp</span>
<span class="o">>>></span> <span class="kn">import</span> <span class="nn">stheno.jax</span> <span class="c1"># JAX extension for Stheno
</span></code></pre></div></div>
<p>We use JAX’s just-in-time (JIT) compiler to compile <code class="language-plaintext highlighter-rouge">linear_model_denoise</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">import</span> <span class="nn">lab</span> <span class="k">as</span> <span class="n">B</span>
<span class="o">>>></span> <span class="n">linear_model_denoise_jitted</span> <span class="o">=</span> <span class="n">B</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">linear_model_denoise</span><span class="p">)</span>
</code></pre></div></div>
<p>Let’s see what happens when we run <code class="language-plaintext highlighter-rouge">linear_model_denoise_jitted</code>.
We must pass <code class="language-plaintext highlighter-rouge">x_obs</code> and <code class="language-plaintext highlighter-rouge">y_obs</code> as JAX arrays to use the compiled version.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">linear_model_denoise_jitted</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">y_obs</span><span class="p">))</span>
<span class="p">(</span><span class="n">DeviceArray</span><span class="p">([</span><span class="o">-</span><span class="mf">2.4981871</span> <span class="p">,</span> <span class="o">-</span><span class="mf">2.4980271</span> <span class="p">,</span> <span class="o">-</span><span class="mf">2.49786709</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.50149004</span><span class="p">,</span>
<span class="mf">5.50165005</span><span class="p">,</span> <span class="mf">5.50181005</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span><span class="p">),</span> <span class="n">DeviceArray</span><span class="p">([</span><span class="o">-</span><span class="mf">2.5069514</span> <span class="p">,</span> <span class="o">-</span><span class="mf">2.50679114</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.50663087</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.4927699</span> <span class="p">,</span>
<span class="mf">5.49292964</span><span class="p">,</span> <span class="mf">5.49308938</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span><span class="p">),</span> <span class="n">DeviceArray</span><span class="p">([</span><span class="o">-</span><span class="mf">2.4894228</span> <span class="p">,</span> <span class="o">-</span><span class="mf">2.48926306</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.48910332</span><span class="p">,</span> <span class="p">...,</span> <span class="mf">5.51021019</span><span class="p">,</span>
<span class="mf">5.51037046</span><span class="p">,</span> <span class="mf">5.51053072</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float64</span><span class="p">))</span>
</code></pre></div></div>
<p>Nice!
Let’s see how much faster <code class="language-plaintext highlighter-rouge">linear_model_denoise_jitted</code> is:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="o">%</span><span class="n">timeit</span> <span class="n">linear_model_denoise</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">)</span>
<span class="mi">233</span> <span class="n">ms</span> <span class="err">±</span> <span class="mf">12.6</span> <span class="n">ms</span> <span class="n">per</span> <span class="n">loop</span> <span class="p">(</span><span class="n">mean</span> <span class="err">±</span> <span class="n">std</span><span class="p">.</span> <span class="n">dev</span><span class="p">.</span> <span class="n">of</span> <span class="mi">7</span> <span class="n">runs</span><span class="p">,</span> <span class="mi">1</span> <span class="n">loop</span> <span class="n">each</span><span class="p">)</span>
<span class="o">>>></span> <span class="o">%</span><span class="n">timeit</span> <span class="n">linear_model_denoise_jitted</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">y_obs</span><span class="p">))</span>
<span class="mf">1.63</span> <span class="n">ms</span> <span class="err">±</span> <span class="mf">16.5</span> <span class="n">µs</span> <span class="n">per</span> <span class="n">loop</span> <span class="p">(</span><span class="n">mean</span> <span class="err">±</span> <span class="n">std</span><span class="p">.</span> <span class="n">dev</span><span class="p">.</span> <span class="n">of</span> <span class="mi">7</span> <span class="n">runs</span><span class="p">,</span> <span class="mi">1000</span> <span class="n">loops</span> <span class="n">each</span><span class="p">)</span>
</code></pre></div></div>
<p>The compiled function <code class="language-plaintext highlighter-rouge">linear_model_denoise_jitted</code> takes only about 2 ms to denoise 50k observations!
Compared to <code class="language-plaintext highlighter-rouge">linear_model_denoise</code>, that’s a speed-up of two orders of magnitude.</p>
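<p>The same pattern applies to any pure function of arrays. Here is a minimal, self-contained sketch using a hypothetical toy function instead of <code class="language-plaintext highlighter-rouge">linear_model_denoise</code> (which requires the Stheno model constructed earlier in the post):</p>

```python
# A minimal sketch of the JIT pattern with a hypothetical toy function.
# This is not the post's model; it only illustrates jax.jit itself.
import jax
import jax.numpy as jnp

def toy_model(x):
    # Stand-in for a model: an elementwise computation followed by a reduction.
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

# Compile the function with XLA via JAX's JIT compiler.
toy_model_jitted = jax.jit(toy_model)

x = jnp.linspace(0.0, 1.0, 1000)
result = float(toy_model_jitted(x))  # The first call triggers compilation.
```

<p>Subsequent calls reuse the compiled XLA executable, which is where the speed-up in the timings above comes from; the first call pays the compilation cost.</p>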
<h2 id="conclusion">Conclusion</h2>
<p>We’ve seen how a linear model can be implemented with a Gaussian process probabilistic program (GPPP) using <a href="https://github.com/wesselb/stheno">Stheno</a>.
Stheno allows us to focus on model construction, and takes away the distraction of the technicalities that come with making predictions.
This flexibility, however, comes at the cost of some complicated machinery that happens in the background, such as structured representations of matrices.
Fortunately, we’ve seen that this overhead can be completely avoided by compiling your program using <a href="https://github.com/google/jax">JAX</a>, which can result in extremely efficient implementations.
To close this post, and to warm you up for <a href="https://github.com/wesselb/stheno#examples">what’s further possible with Gaussian process probabilistic programming using Stheno</a>, note that the linear model we’ve built can easily be extended to include, for example, a <em>quadratic</em> term:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">quadratic_model_denoise</span><span class="p">(</span><span class="n">x_obs</span><span class="p">,</span> <span class="n">y_obs</span><span class="p">):</span>
<span class="n">prior</span> <span class="o">=</span> <span class="n">Measure</span><span class="p">()</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for slope
</span> <span class="n">b</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for coefficient of quadratic term
</span> <span class="n">c</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for offset
</span> <span class="c1"># Noiseless quadratic model
</span> <span class="n">f</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="n">c</span>
<span class="n">noise</span> <span class="o">=</span> <span class="n">GP</span><span class="p">(</span><span class="n">Delta</span><span class="p">(),</span> <span class="n">measure</span><span class="o">=</span><span class="n">prior</span><span class="p">)</span> <span class="c1"># Model for noise
</span> <span class="n">y</span> <span class="o">=</span> <span class="n">f</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">noise</span> <span class="c1"># Noisy quadratic model
</span>
<span class="n">post</span> <span class="o">=</span> <span class="n">prior</span> <span class="o">|</span> <span class="p">(</span><span class="n">y</span><span class="p">(</span><span class="n">x_obs</span><span class="p">),</span> <span class="n">y_obs</span><span class="p">)</span> <span class="c1"># Condition on observations.
</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">post</span><span class="p">(</span><span class="n">f</span><span class="p">(</span><span class="n">x_obs</span><span class="p">))</span> <span class="c1"># Make predictions.
</span> <span class="k">return</span> <span class="n">pred</span><span class="p">.</span><span class="n">marginal_credible_bounds</span><span class="p">()</span> <span class="c1"># Return the mean and associated error bounds.
</span></code></pre></div></div>
<p>To use Gaussian process probabilistic programming for your specific problem, the main challenge is to figure out which model you need to use.
Do you need a quadratic term?
Maybe you need an exponential term!
But, using Stheno, implementing the model and making predictions should then be simple.</p>By Wessel Bruinsma, James Requeima, and Eric Perim MartinsJulia Learning Circle: Generated Functions2020-12-13T00:00:00+00:002020-12-13T00:00:00+00:00https://wessel.ai/2020/12/13/julia-learning-circle-meeting-3<p>A normal function returns the result of its computation.
In contrast, a <a href="https://docs.julialang.org/en/v1.6-dev/manual/metaprogramming/#Generated-functions">generated function</a> outputs <em>the code that implements the function</em>.
While generating this code, the generated function can only make use of the <em>types</em> of the arguments, not their <em>values</em>.
In a sense, generated functions offer <a href="https://discourse.julialang.org/t/understanding-generated-functions/10092/4">“on-demand code generation”</a>.
This mechanism is quite powerful and can be used when normal functions in combination with multiple dispatch cannot give you what you need.</p>
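<p>As a minimal illustration before the matrix example, consider the following hypothetical generated function (not from the original post), which unrolls the summation of a tuple. While generating the code, only the <em>type</em> of the argument is available, which is exactly enough to know the length <code class="language-plaintext highlighter-rouge">N</code>:</p>

```julia
# Hypothetical minimal example: a generated function that unrolls the
# summation of an `NTuple`. At generation time, `x` stands for the *type*
# of the argument, so `N` is known, but the values are not.
@generated function tuple_sum(x::NTuple{N, T}) where {N, T}
    terms = [:(x[$n]) for n = 1:N]
    return :(+($(terms...)))  # Emit the unrolled expression.
end

tuple_sum((1, 2, 3))  # Generates `x[1] + x[2] + x[3]` and returns 6.
```

<p>For <code class="language-plaintext highlighter-rouge">N = 3</code>, the generated code contains no loop at all: the compiler specialises the unrolled expression for every tuple length it encounters.</p>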
<p>To illustrate generated functions, we will build on the example of <a href="/2020/11/23/julia-learning-circle-meeting-2.html#case-study-stack-allocated-vectors-aka-a-very-brief-introduction-to-staticarraysjl">stack-allocated vectors from the previous post</a>.
We will extend our stack-allocated vector to a stack-allocated <em>matrix</em>, and we will use a generated function to implement matrix multiplication.
Let’s start out by defining a stack-allocated vector and matrix.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">struct</span><span class="nc"> StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">M</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">L</span><span class="x">}</span>
<span class="n">data</span><span class="o">::</span><span class="kt">NTuple</span><span class="x">{</span><span class="n">L</span><span class="x">,</span> <span class="n">T</span><span class="x">}</span>
<span class="k">end</span>
<span class="k">function</span><span class="nf"> StackVector</span><span class="x">(</span><span class="n">data</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
<span class="k">return</span> <span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">data</span><span class="x">),</span> <span class="mi">1</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">data</span><span class="x">)}(</span><span class="kt">Tuple</span><span class="x">(</span><span class="n">data</span><span class="x">))</span>
<span class="k">end</span>
<span class="k">function</span><span class="nf"> StackMatrix</span><span class="x">(</span><span class="n">data</span><span class="o">::</span><span class="kt">Matrix</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
<span class="n">M</span><span class="x">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">data</span><span class="x">)</span>
<span class="k">return</span> <span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">M</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">data</span><span class="x">)}(</span><span class="kt">Tuple</span><span class="x">(</span><span class="n">data</span><span class="x">[</span><span class="o">:</span><span class="x">]))</span>
<span class="k">end</span>
</code></pre></div></div>
<p>The type signature is <code class="language-plaintext highlighter-rouge">StackMatrix{T, M, N, L}</code> where <code class="language-plaintext highlighter-rouge">T</code> is the type of the elements of the matrix, <code class="language-plaintext highlighter-rouge">M</code> is the number of rows of the matrix, <code class="language-plaintext highlighter-rouge">N</code> is the number of columns of the matrix, and <code class="language-plaintext highlighter-rouge">L = M * N</code> is the total number of elements in the matrix;
even though <code class="language-plaintext highlighter-rouge">L</code> can always be computed from <code class="language-plaintext highlighter-rouge">M</code> and <code class="language-plaintext highlighter-rouge">N</code>, we need <code class="language-plaintext highlighter-rouge">L</code> in the type signature, because it specifies the length of the <code class="language-plaintext highlighter-rouge">NTuple</code>.</p>
<p>Before we implement multiplication of general <code class="language-plaintext highlighter-rouge">StackMatrix{T, M, N, L}</code>s, we first consider the case of <code class="language-plaintext highlighter-rouge">StackMatrix{T, 2, 2, 4}</code>s.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">import</span> <span class="n">Base</span><span class="o">:</span> <span class="o">*</span>
<span class="k">function</span><span class="nf"> </span><span class="o">*(</span><span class="n">x</span><span class="o">::</span><span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">},</span> <span class="n">y</span><span class="o">::</span><span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
<span class="n">x11</span><span class="x">,</span> <span class="n">x21</span><span class="x">,</span> <span class="n">x12</span><span class="x">,</span> <span class="n">x22</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span>
<span class="n">y11</span><span class="x">,</span> <span class="n">y21</span><span class="x">,</span> <span class="n">y12</span><span class="x">,</span> <span class="n">y22</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">data</span>
<span class="n">z11</span> <span class="o">=</span> <span class="n">x11</span> <span class="o">*</span> <span class="n">y11</span> <span class="o">+</span> <span class="n">x12</span> <span class="o">*</span> <span class="n">y21</span>
<span class="n">z21</span> <span class="o">=</span> <span class="n">x21</span> <span class="o">*</span> <span class="n">y11</span> <span class="o">+</span> <span class="n">x22</span> <span class="o">*</span> <span class="n">y21</span>
<span class="n">z12</span> <span class="o">=</span> <span class="n">x11</span> <span class="o">*</span> <span class="n">y12</span> <span class="o">+</span> <span class="n">x12</span> <span class="o">*</span> <span class="n">y22</span>
<span class="n">z22</span> <span class="o">=</span> <span class="n">x21</span> <span class="o">*</span> <span class="n">y12</span> <span class="o">+</span> <span class="n">x22</span> <span class="o">*</span> <span class="n">y22</span>
<span class="k">return</span> <span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">}((</span><span class="n">z11</span><span class="x">,</span> <span class="n">z21</span><span class="x">,</span> <span class="n">z12</span><span class="x">,</span> <span class="n">z22</span><span class="x">))</span>
<span class="k">end</span>
</code></pre></div></div>
<p>Let’s check that the implementation is correct.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">y</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x_stack</span> <span class="o">=</span> <span class="n">StackMatrix</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">y_stack</span> <span class="o">=</span> <span class="n">StackMatrix</span><span class="x">(</span><span class="n">y</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span>
<span class="mi">2</span><span class="n">×2</span> <span class="kt">Matrix</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="o">-</span><span class="mf">1.16361</span> <span class="mf">0.848159</span>
<span class="mf">0.355827</span> <span class="o">-</span><span class="mf">0.441428</span>
<span class="n">julia</span><span class="o">></span> <span class="n">reshape</span><span class="x">(</span><span class="n">collect</span><span class="x">((</span><span class="n">x_stack</span> <span class="o">*</span> <span class="n">y_stack</span><span class="x">)</span><span class="o">.</span><span class="n">data</span><span class="x">),</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
<span class="mi">2</span><span class="n">×2</span> <span class="kt">Matrix</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="o">-</span><span class="mf">1.16361</span> <span class="mf">0.848159</span>
<span class="mf">0.355827</span> <span class="o">-</span><span class="mf">0.441428</span>
</code></pre></div></div>
<p>Nice!
And it is quite a bit faster, too.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="n">x</span> <span class="o">*</span> <span class="o">$</span><span class="n">y</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">112</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">1</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">54.112</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">59.993</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">62.529</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">1.27</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">486.884</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">83.35</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">973</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">x_stack</span><span class="x">))[]</span> <span class="o">*</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">y_stack</span><span class="x">))[]</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">3.015</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">3.033</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">3.077</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">16.869</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">1000</span>
</code></pre></div></div>
<p>The problem with multiplication of general <code class="language-plaintext highlighter-rouge">StackMatrix{T, M, N, L}</code>s is that the implementation depends on the particular values of <code class="language-plaintext highlighter-rouge">M</code> and <code class="language-plaintext highlighter-rouge">N</code>: for example, the number of intermediate variables <code class="language-plaintext highlighter-rouge">z11</code>, <code class="language-plaintext highlighter-rouge">z21</code>, <em>et cetera</em> changes with the dimensions.
We will use a generated function to <em>automatically generate the implementation of the corresponding matrix multiplication</em>.
This code-generation procedure depends on the values of <code class="language-plaintext highlighter-rouge">M</code> and <code class="language-plaintext highlighter-rouge">N</code> and will adapt the implementation accordingly.
Generated functions are defined with the macro <code class="language-plaintext highlighter-rouge">@generated</code>.
The implementation of multiplication of general <code class="language-plaintext highlighter-rouge">StackMatrix{T, M, N, L}</code>s is as follows:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">import</span> <span class="n">Base</span><span class="o">:</span> <span class="o">*</span>
<span class="nd">@generated</span> <span class="k">function</span><span class="nf"> </span><span class="o">*(</span>
<span class="n">x</span><span class="o">::</span><span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">M</span><span class="x">,</span> <span class="n">L₁</span><span class="x">},</span>
<span class="n">y</span><span class="o">::</span><span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">M</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">L₂</span><span class="x">}</span>
<span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">M</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">L₁</span><span class="x">,</span> <span class="n">L₂</span><span class="x">}</span>
<span class="c"># Unpack `x`.</span>
<span class="n">tuple_x</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"x_</span><span class="si">$(k)</span><span class="s">_</span><span class="si">$(m)</span><span class="s">"</span><span class="x">)</span> <span class="k">for</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">M</span> <span class="k">for</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">K</span><span class="x">]</span><span class="o">...</span><span class="x">)</span>
<span class="n">unpack_x</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">tuple_x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
<span class="c"># Unpack `y`.</span>
<span class="n">tuple_y</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"y_</span><span class="si">$(m)</span><span class="s">_</span><span class="si">$(n)</span><span class="s">"</span><span class="x">)</span> <span class="k">for</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">N</span> <span class="k">for</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">M</span><span class="x">]</span><span class="o">...</span><span class="x">)</span>
<span class="n">unpack_y</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">tuple_y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
<span class="c"># Perform multiplication.</span>
<span class="n">mults</span> <span class="o">=</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Expr</span><span class="x">}()</span>
<span class="k">for</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">K</span><span class="x">,</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">N</span>
<span class="n">expr</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span>
<span class="o">:</span><span class="n">call</span><span class="x">,</span>
<span class="o">:+</span><span class="x">,</span>
<span class="x">[</span><span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="x">(</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"x_</span><span class="si">$(k)</span><span class="s">_</span><span class="si">$(m)</span><span class="s">"</span><span class="x">))</span> <span class="o">*</span> <span class="o">$</span><span class="x">(</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"y_</span><span class="si">$(m)</span><span class="s">_</span><span class="si">$(n)</span><span class="s">"</span><span class="x">)))</span> <span class="k">for</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">M</span><span class="x">]</span><span class="o">...</span>
<span class="x">)</span>
<span class="n">push!</span><span class="x">(</span><span class="n">mults</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="x">(</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"z_</span><span class="si">$(k)</span><span class="s">_</span><span class="si">$(n)</span><span class="s">"</span><span class="x">))</span> <span class="o">=</span> <span class="o">$</span><span class="n">expr</span><span class="x">))</span>
<span class="k">end</span>
<span class="c"># Pack `z`.</span>
<span class="n">tuple_z</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"z_</span><span class="si">$(k)</span><span class="s">_</span><span class="si">$(n)</span><span class="s">"</span><span class="x">)</span> <span class="k">for</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">N</span> <span class="k">for</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">K</span><span class="x">]</span><span class="o">...</span><span class="x">)</span>
<span class="n">pack_z</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">L₃</span><span class="x">}(</span><span class="o">$</span><span class="n">tuple_z</span><span class="x">))</span>
<span class="k">return</span> <span class="kt">Expr</span><span class="x">(</span>
<span class="o">:</span><span class="n">block</span><span class="x">,</span>
<span class="n">unpack_x</span><span class="x">,</span>
<span class="n">unpack_y</span><span class="x">,</span>
<span class="n">mults</span><span class="o">...</span><span class="x">,</span>
<span class="o">:</span><span class="x">(</span><span class="n">L₃</span> <span class="o">=</span> <span class="n">K</span> <span class="o">*</span> <span class="n">N</span><span class="x">),</span>
<span class="o">:</span><span class="x">(</span><span class="k">return</span> <span class="o">$</span><span class="n">pack_z</span><span class="x">)</span>
<span class="x">)</span>
<span class="k">end</span>
</code></pre></div></div>
<p>If we omit the macro <code class="language-plaintext highlighter-rouge">@generated</code>, we can call the implementation to inspect the generated code:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">x_stack</span> <span class="o">*</span> <span class="n">y_stack</span>
<span class="k">quote</span>
<span class="x">(</span><span class="n">x_1_1</span><span class="x">,</span> <span class="n">x_2_1</span><span class="x">,</span> <span class="n">x_1_2</span><span class="x">,</span> <span class="n">x_2_2</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span>
<span class="x">(</span><span class="n">y_1_1</span><span class="x">,</span> <span class="n">y_2_1</span><span class="x">,</span> <span class="n">y_1_2</span><span class="x">,</span> <span class="n">y_2_2</span><span class="x">)</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">data</span>
<span class="n">z_1_1</span> <span class="o">=</span> <span class="n">x_1_1</span> <span class="o">*</span> <span class="n">y_1_1</span> <span class="o">+</span> <span class="n">x_1_2</span> <span class="o">*</span> <span class="n">y_2_1</span>
<span class="n">z_1_2</span> <span class="o">=</span> <span class="n">x_1_1</span> <span class="o">*</span> <span class="n">y_1_2</span> <span class="o">+</span> <span class="n">x_1_2</span> <span class="o">*</span> <span class="n">y_2_2</span>
<span class="n">z_2_1</span> <span class="o">=</span> <span class="n">x_2_1</span> <span class="o">*</span> <span class="n">y_1_1</span> <span class="o">+</span> <span class="n">x_2_2</span> <span class="o">*</span> <span class="n">y_2_1</span>
<span class="n">z_2_2</span> <span class="o">=</span> <span class="n">x_2_1</span> <span class="o">*</span> <span class="n">y_1_2</span> <span class="o">+</span> <span class="n">x_2_2</span> <span class="o">*</span> <span class="n">y_2_2</span>
<span class="n">L₃</span> <span class="o">=</span> <span class="n">K</span> <span class="o">*</span> <span class="n">N</span>
<span class="k">return</span> <span class="n">StackMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">L₃</span><span class="x">}((</span><span class="n">z_1_1</span><span class="x">,</span> <span class="n">z_2_1</span><span class="x">,</span> <span class="n">z_1_2</span><span class="x">,</span> <span class="n">z_2_2</span><span class="x">))</span>
<span class="k">end</span>
</code></pre></div></div>
<p>Sweet!
This looks very much like our earlier implementation of the two-by-two case.
Let’s again check that the implementation is correct.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="mi">2</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">y</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x_stack</span> <span class="o">=</span> <span class="n">StackMatrix</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">y_stack</span> <span class="o">=</span> <span class="n">StackMatrix</span><span class="x">(</span><span class="n">y</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span>
<span class="mi">4</span><span class="n">×3</span> <span class="kt">Matrix</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="mf">0.125514</span> <span class="o">-</span><span class="mf">0.0135978</span> <span class="o">-</span><span class="mf">0.0283178</span>
<span class="o">-</span><span class="mf">1.93756</span> <span class="mf">0.450559</span> <span class="mf">1.17303</span>
<span class="mf">2.56769</span> <span class="o">-</span><span class="mf">0.365378</span> <span class="o">-</span><span class="mf">0.845966</span>
<span class="mf">3.22549</span> <span class="o">-</span><span class="mf">0.602203</span> <span class="o">-</span><span class="mf">1.50065</span>
<span class="n">julia</span><span class="o">></span> <span class="n">reshape</span><span class="x">(</span><span class="n">collect</span><span class="x">((</span><span class="n">x_stack</span> <span class="o">*</span> <span class="n">y_stack</span><span class="x">)</span><span class="o">.</span><span class="n">data</span><span class="x">),</span> <span class="mi">4</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
<span class="mi">4</span><span class="n">×3</span> <span class="kt">Matrix</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="mf">0.125514</span> <span class="o">-</span><span class="mf">0.0135978</span> <span class="o">-</span><span class="mf">0.0283178</span>
<span class="o">-</span><span class="mf">1.93756</span> <span class="mf">0.450559</span> <span class="mf">1.17303</span>
<span class="mf">2.56769</span> <span class="o">-</span><span class="mf">0.365378</span> <span class="o">-</span><span class="mf">0.845966</span>
<span class="mf">3.22549</span> <span class="o">-</span><span class="mf">0.602203</span> <span class="o">-</span><span class="mf">1.50065</span>
</code></pre></div></div>
<p>Like the two-by-two case, this implementation is quite a bit faster, too.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="n">x</span> <span class="o">*</span> <span class="o">$</span><span class="n">y</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">176</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">1</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">205.100</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">219.162</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">229.711</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.74</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">1.679</span> <span class="n">μs</span> <span class="x">(</span><span class="mf">75.82</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">530</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">x_stack</span><span class="x">))[]</span> <span class="o">*</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">y_stack</span><span class="x">))[]</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">15.097</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">15.665</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">16.987</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">103.605</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">997</span>
</code></pre></div></div>A normal function outputs the result of the computation by the function. In contrast, a generated function outputs the code that implements the function. While generating this code, the generated function can only make use of the types of the arguments, not their values. In a sense, generated functions offer “on-demand code generation”. This mechanism is quite powerful and can be used when normal functions in combination with multiple dispatch cannot give you what you need.Julia Learning Circle: Memory Allocations and Garbage Collection2020-11-23T00:00:00+00:002020-11-23T00:00:00+00:00https://wessel.ai/2020/11/23/julia-learning-circle-meeting-2<h2 id="immutable-and-mutable-types">Immutable and Mutable Types</h2>
<p>Concrete types in Julia are either immutable or mutable.
Immutable types are created with <code class="language-plaintext highlighter-rouge">struct ImmutableType</code> and mutable types are created with <code class="language-plaintext highlighter-rouge">mutable struct MutableType</code>.
The advantage of immutable types is that they can be allocated on the <em>stack</em> as opposed to on the <em>heap</em>.
Allocating objects on the stack is typically more performant due to cache locality and the stack’s simpler, but more rigid, memory structure.</p>
<p>An interesting situation occurs when an <em>immutable</em> type references a <em>mutable</em> type.
<a href="https://github.com/JuliaLang/julia/blob/release-1.5/NEWS.md#compilerruntime-improvements">Since Julia 1.5, such immutable types can be allocated on the stack.</a></p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="k">struct</span><span class="nc"> A</span>
<span class="n">data</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span>
<span class="k">end</span>
<span class="n">julia</span><span class="o">></span> <span class="n">a</span> <span class="o">=</span> <span class="n">A</span><span class="x">(</span><span class="n">randn</span><span class="x">(</span><span class="mi">3</span><span class="x">))</span>
<span class="n">A</span><span class="x">([</span><span class="mf">0.9462871255469765</span><span class="x">,</span> <span class="mf">1.1995018446247545</span><span class="x">,</span> <span class="mf">0.7153882414691778</span><span class="x">])</span>
</code></pre></div></div>
<p>Here <code class="language-plaintext highlighter-rouge">A</code> is immutable, but references a <code class="language-plaintext highlighter-rouge">Vector{Float64}</code>, which is mutable.
This means that the field <code class="language-plaintext highlighter-rouge">a.data</code> cannot be reassigned, but, since <code class="language-plaintext highlighter-rouge">a.data</code> itself is mutable, its elements, such as <code class="language-plaintext highlighter-rouge">a.data[1]</code>, <em>can</em> be changed.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">a</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">3</span><span class="x">)</span>
<span class="n">ERROR</span><span class="o">:</span> <span class="n">setfield!</span> <span class="n">immutable</span> <span class="k">struct</span><span class="nc"> of</span> <span class="n">type</span> <span class="n">A</span> <span class="n">cannot</span> <span class="n">be</span> <span class="n">changed</span>
<span class="n">Stacktrace</span><span class="o">:</span>
<span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="n">setproperty!</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="n">A</span><span class="x">,</span> <span class="n">f</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">,</span> <span class="n">v</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="err">@</span> <span class="n">Base</span> <span class="o">./</span><span class="n">Base</span><span class="o">.</span><span class="n">jl</span><span class="o">:</span><span class="mi">34</span>
<span class="x">[</span><span class="mi">2</span><span class="x">]</span> <span class="n">top</span><span class="o">-</span><span class="n">level</span> <span class="n">scope</span>
<span class="err">@</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">5</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span>
<span class="n">julia</span><span class="o">></span> <span class="n">a</span><span class="o">.</span><span class="n">data</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="mf">1.0</span>
</code></pre></div></div>
<p>Types <code class="language-plaintext highlighter-rouge">T</code> that satisfy <code class="language-plaintext highlighter-rouge">isbitstype(T) == true</code> are a subset of immutable types.
They are immutable types that reference only other <code class="language-plaintext highlighter-rouge">isbitstype</code> types or <em>primitive types</em>.
Primitive types are types whose data are a simple collection of bits.
A collection of primitive types <a href="https://docs.julialang.org/en/v1/manual/types/#Primitive-Types">is defined by Base</a>.
The purpose of primitive types is to facilitate interoperability with LLVM.</p>
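<p>As a quick REPL sanity check, using the type <code class="language-plaintext highlighter-rouge">A</code> defined above:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code>julia> isbitstype(Float64)  # primitive type
true
julia> isbitstype(NTuple{3, Float64})  # immutable, references only isbitstype types
true
julia> isbitstype(Vector{Float64})  # mutable
false
julia> isbitstype(A)  # immutable, but references the mutable Vector{Float64}
false
</code></pre></div></div>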
<h2 id="case-study-stack-allocated-vectors-aka-a-very-brief-introduction-to-staticarraysjl">Case Study: Stack-Allocated Vectors (A.K.A. a Very Brief Introduction to StaticArrays.jl)</h2>
<p>The usual <code class="language-plaintext highlighter-rouge">Vector{Float64}</code> is mutable, which means that it is heap allocated.
Let’s see if we can create a more performant vector by creating a vector type that is allocated on the <em>stack</em>.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">struct</span><span class="nc"> StackVector</span><span class="x">{</span><span class="n">N</span><span class="x">}</span>
<span class="n">data</span><span class="o">::</span><span class="kt">NTuple</span><span class="x">{</span><span class="n">N</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="k">end</span>
<span class="n">StackVector</span><span class="x">(</span><span class="n">data</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="o">=</span> <span class="n">StackVector</span><span class="x">(</span><span class="kt">Tuple</span><span class="x">(</span><span class="n">data</span><span class="x">))</span>
</code></pre></div></div>
<p>Define <code class="language-plaintext highlighter-rouge">+</code> for our newly defined <code class="language-plaintext highlighter-rouge">StackVector</code>.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">import</span> <span class="n">Base</span><span class="o">:</span> <span class="o">+</span>
<span class="o">+</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="n">StackVector</span><span class="x">{</span><span class="n">N</span><span class="x">},</span> <span class="n">y</span><span class="o">::</span><span class="n">StackVector</span><span class="x">{</span><span class="n">N</span><span class="x">})</span> <span class="k">where</span> <span class="n">N</span> <span class="o">=</span> <span class="n">StackVector</span><span class="x">{</span><span class="n">N</span><span class="x">}(</span><span class="n">x</span><span class="o">.</span><span class="n">data</span> <span class="o">.+</span> <span class="n">y</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
</code></pre></div></div>
<p>Let’s check that this works as intended.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">10</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">y</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">10</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">stack_x</span> <span class="o">=</span> <span class="n">StackVector</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">stack_y</span> <span class="o">=</span> <span class="n">StackVector</span><span class="x">(</span><span class="n">y</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="mi">10</span><span class="o">-</span><span class="n">element</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="o">-</span><span class="mf">0.5453143850886275</span>
<span class="mf">2.120385168072067</span>
<span class="mf">1.1278328263047377</span>
<span class="mf">1.6358682579762607</span>
<span class="o">-</span><span class="mf">0.22486252827622277</span>
<span class="o">-</span><span class="mf">2.1333012655133836</span>
<span class="mf">2.6754332229859767</span>
<span class="o">-</span><span class="mf">0.7701873679976846</span>
<span class="mf">0.26775849165909</span>
<span class="o">-</span><span class="mf">2.7389288669831786</span>
<span class="n">julia</span><span class="o">></span> <span class="n">collect</span><span class="x">((</span><span class="n">stack_x</span> <span class="o">+</span> <span class="n">stack_y</span><span class="x">)</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
<span class="mi">10</span><span class="o">-</span><span class="n">element</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
<span class="o">-</span><span class="mf">0.5453143850886275</span>
<span class="mf">2.120385168072067</span>
<span class="mf">1.1278328263047377</span>
<span class="mf">1.6358682579762607</span>
<span class="o">-</span><span class="mf">0.22486252827622277</span>
<span class="o">-</span><span class="mf">2.1333012655133836</span>
<span class="mf">2.6754332229859767</span>
<span class="o">-</span><span class="mf">0.7701873679976846</span>
<span class="mf">0.26775849165909</span>
<span class="o">-</span><span class="mf">2.7389288669831786</span>
</code></pre></div></div>
<p>That looks good.
Now let’s see what avoiding allocations on the heap gets us.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="k">using</span> <span class="n">BenchmarkTools</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="n">x</span> <span class="o">+</span> <span class="o">$</span><span class="n">y</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">160</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">1</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">53.664</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">56.126</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">59.544</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">1.83</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">572.958</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">87.42</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">987</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="n">stack_x</span> <span class="o">+</span> <span class="o">$</span><span class="n">stack_y</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.052</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.055</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.055</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.099</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">1000</span>
</code></pre></div></div>
<p>Whoa!
What happened here is that the compiler is a little too clever:
it managed to compute the result at compile time and essentially hardcoded the answer.
Compare this with</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="x">(</span><span class="n">stack_x</span> <span class="o">+</span> <span class="n">stack_y</span><span class="x">)</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.052</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.055</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">0.056</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">8.968</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">1000</span>
</code></pre></div></div>
<p>To stop the compiler from being too clever, <a href="https://github.com/JuliaCI/BenchmarkTools.jl#quick-start">BenchmarkTools.jl</a> advises the following trick:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="nd">@benchmark</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">stack_x</span><span class="x">))[]</span> <span class="o">+</span> <span class="o">$</span><span class="x">(</span><span class="kt">Ref</span><span class="x">(</span><span class="n">stack_y</span><span class="x">))[]</span>
<span class="n">BenchmarkTools</span><span class="o">.</span><span class="n">Trial</span><span class="o">:</span>
<span class="n">memory</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span> <span class="n">bytes</span>
<span class="n">allocs</span> <span class="n">estimate</span><span class="o">:</span> <span class="mi">0</span>
<span class="o">--------------</span>
<span class="n">minimum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">2.276</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">median</span> <span class="n">time</span><span class="o">:</span> <span class="mf">2.293</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">mean</span> <span class="n">time</span><span class="o">:</span> <span class="mf">2.401</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="n">maximum</span> <span class="n">time</span><span class="o">:</span> <span class="mf">30.049</span> <span class="n">ns</span> <span class="x">(</span><span class="mf">0.00</span><span class="o">%</span> <span class="n">GC</span><span class="x">)</span>
<span class="o">--------------</span>
<span class="n">samples</span><span class="o">:</span> <span class="mi">10000</span>
<span class="n">evals</span><span class="o">/</span><span class="n">sample</span><span class="o">:</span> <span class="mi">1000</span>
</code></pre></div></div>
<p>That looks more reasonable.
For this small array, compared to allocating on the heap, that’s a 25x improvement in runtime!
This example demonstrates that allocations on the heap can substantially contribute to the total runtime of a program.</p>
<p>The idea of allocating vectors on the stack is certainly not mine.
Check out the fantastic <a href="https://github.com/JuliaArrays/StaticArrays.jl">StaticArrays.jl</a>, which provides a generic implementation of stack-allocated arrays.
If the size of the array is small, <a href="https://github.com/JuliaArrays/StaticArrays.jl#speed">these stack-allocated arrays can be significantly more performant than their heap-allocated counterparts</a>.
StaticArrays.jl works by automagically generating implementations of linear algebra operations that are optimised for specific sizes of vectors or matrices by using <a href="https://docs.julialang.org/en/v1/manual/metaprogramming/#Generated-functions">generated functions</a>.</p>
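<p>To give a flavour of the package, here is a minimal example (see the StaticArrays.jl documentation for the full API):</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code>julia> using StaticArrays
julia> v = SVector(1.0, 2.0, 3.0);  # the size is part of the type: SVector{3, Float64}
julia> w = v + v;  # stack-allocated, no heap allocations
</code></pre></div></div>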
<h2 id="garbage-collection">Garbage Collection</h2>
<p>As more and more objects are allocated on the heap, eventually the heap fills up.
The purpose of the <em>garbage collector</em> is to clean up the heap every once in a while.
The underlying principle of garbage collection is that objects are considered <em>garbage</em>, and hence can be cleaned up, if it can be proven that they cannot be <em>reached</em> (used) by any future code.</p>
<p>Julia’s garbage collector algorithm is called <em>mark and sweep</em>.
This algorithm consists of two phases:
the <em>mark phase</em>, where all objects that are <em>not</em> garbage are found and marked so;
and the <em>sweep phase</em>, where all <em>unmarked</em> objects are cleaned.
The mark phase first establishes a set of objects that are definitely <em>not</em> garbage.
This set is called the <em>root set</em>, and <a href="https://stackoverflow.com/questions/30080745/how-does-the-mark-in-mark-and-sweep-function-trace-out-the-set-of-objects-acce">essentially consists of all global variables and everything on the stack</a>.
The garbage collector then follows everything that the root set references, and everything that those references reference, and marks those objects along the way.</p>
<p>During the sweep phase, the unmarked objects are <em>freed</em>, which simply means that it is internally recorded that their memory can be freely overwritten and used for something else.
These unmarked objects are found by walking through the whole heap.
Marked objects, on the other hand, remain untouched.
They are also not moved around:
you can imagine that the memory used by marked objects could sometimes be rearranged into a more compact layout, but doing so takes time.
Because Julia’s garbage collector leaves marked objects in place, its mark-and-sweep algorithm is called <em>non-moving</em> or <em>non-compacting</em>.</p>
<p>There is more fancy stuff going on.
For example, Julia’s garbage collector is <a href="https://en.wikipedia.org/wiki/Tracing_garbage_collection#Generational_GC_(ephemeral_GC)"><em>generational</em></a>.
You can check out the docstrings of <a href="https://github.com/JuliaLang/julia/blob/master/src/gc.c">gc.c</a> for more details.</p>Immutable and Mutable TypesJulia Learning Circle: JIT and Method Invalidations2020-11-07T00:00:00+00:002020-11-07T00:00:00+00:00https://wessel.ai/2020/11/07/julia-learning-circle-meeting-1<p>I am participating in a learning circle with the goal of gaining a better understanding of the <a href="https://julialang.org/">Julia language</a>.
To better retain what we learn, I will be turning my notes into small blog posts.
The posts should be simple, quick, but hopefully enjoyable reads.</p>
<p>The code snippets in this post are run on Julia 1.6.0-DEV.1440.</p>
<h2 id="just-in-time-compilation">Just-in-Time Compilation</h2>
<p>The first time a method is run, it is compiled <a href="https://en.wikipedia.org/wiki/Just-in-time_compilation">just-in-time</a> (JIT).
The compilation time can be measured with <code class="language-plaintext highlighter-rouge">@time</code>.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">A</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float64</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">3</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">inv</span><span class="x">(</span><span class="n">A</span><span class="x">);</span>
<span class="mf">0.244590</span> <span class="n">seconds</span> <span class="x">(</span><span class="mf">559.50</span> <span class="n">k</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">31.983</span> <span class="n">MiB</span><span class="x">,</span> <span class="mf">2.82</span><span class="o">%</span> <span class="n">gc</span> <span class="n">time</span><span class="x">,</span> <span class="mf">99.94</span><span class="o">%</span> <span class="n">compilation</span> <span class="n">time</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">inv</span><span class="x">(</span><span class="n">A</span><span class="x">);</span>
<span class="mf">0.000015</span> <span class="n">seconds</span> <span class="x">(</span><span class="mi">4</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">1.953</span> <span class="n">KiB</span><span class="x">)</span>
</code></pre></div></div>
<p>The method <code class="language-plaintext highlighter-rouge">inv(::Matrix{Float64})</code> is now compiled and fast to call.
However, <code class="language-plaintext highlighter-rouge">inv(::Matrix{Float32})</code>, for example, is not yet compiled, so its first call will incur compilation time.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">A</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">3</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">inv</span><span class="x">(</span><span class="n">A</span><span class="x">);</span>
<span class="mf">0.188690</span> <span class="n">seconds</span> <span class="x">(</span><span class="mf">449.85</span> <span class="n">k</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">25.852</span> <span class="n">MiB</span><span class="x">,</span> <span class="mf">96.79</span><span class="o">%</span> <span class="n">compilation</span> <span class="n">time</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">inv</span><span class="x">(</span><span class="n">A</span><span class="x">);</span>
<span class="mf">0.000017</span> <span class="n">seconds</span> <span class="x">(</span><span class="mi">4</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">1.125</span> <span class="n">KiB</span><span class="x">)</span>
</code></pre></div></div>
<p>The Julia JIT is simple:
it compiles a method the first time the method is called.
This simplicity, however, comes at the cost of start-up time and compilation delays during runtime.
Other approaches, like <a href="https://www.pypy.org/">PyPy</a>, first run the code on an interpreter, profile the code, and then compile bits of the code based on the profiling results;
this is called <a href="https://en.wikipedia.org/wiki/Profile-guided_optimization">profile-guided optimisation</a> (PGO).</p>
<h2 id="method-invalidation">Method Invalidation</h2>
<p>Once a method is compiled, it can happen that it needs to be recompiled.
Namely, a method is compiled under certain assumptions, and these assumptions may no longer hold as more code is loaded.</p>
<p>For example, suppose that a compiled method <code class="language-plaintext highlighter-rouge">m</code> uses the instance <code class="language-plaintext highlighter-rouge">my_add(x::Float64, y::Float64)</code> obtained from the implementation for <code class="language-plaintext highlighter-rouge">my_add(x::Real, y::Real)</code>.
If a direct implementation of <code class="language-plaintext highlighter-rouge">my_add(x::Float64, y::Float64)</code> is then added, the compiled method <code class="language-plaintext highlighter-rouge">m</code> needs to be recompiled to make use of this direct implementation: <code class="language-plaintext highlighter-rouge">m</code> gets <em>invalidated</em>.</p>
<p>Here’s that example:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">my_add</span> <span class="x">(</span><span class="n">generic</span> <span class="k">function</span><span class="nf"> with</span> <span class="mi">1</span> <span class="n">method</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span><span class="o"><:</span><span class="kt">Real</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="n">my_add</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">init</span><span class="o">=</span><span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">))</span>
<span class="n">my_sum</span> <span class="x">(</span><span class="n">generic</span> <span class="k">function</span><span class="nf"> with</span> <span class="mi">1</span> <span class="n">method</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_sum</span><span class="x">(</span><span class="n">randn</span><span class="x">(</span><span class="mi">10</span><span class="x">))</span>
<span class="mf">0.65443378603631</span>
</code></pre></div></div>
<p>We then add a direct implementation for <code class="language-plaintext highlighter-rouge">my_add(x::Float64, y::Float64)</code>.
To detect the method invalidation, we use <a href="https://github.com/timholy/SnoopCompile.jl">SnoopCompile.jl</a>.</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">trees</span> <span class="o">=</span> <span class="n">invalidation_trees</span><span class="x">(</span><span class="nd">@snoopr</span> <span class="k">begin</span>
<span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="k">end</span><span class="x">)</span>
<span class="mi">1</span><span class="o">-</span><span class="n">element</span> <span class="kt">Vector</span><span class="x">{</span><span class="n">SnoopCompile</span><span class="o">.</span><span class="n">MethodInvalidations</span><span class="x">}</span><span class="o">:</span>
<span class="n">inserting</span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="k">in</span> <span class="n">Main</span> <span class="n">at</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">12</span><span class="x">]</span><span class="o">:</span><span class="mi">2</span> <span class="n">invalidated</span><span class="o">:</span>
<span class="n">backedges</span><span class="o">:</span> <span class="mi">1</span><span class="o">:</span> <span class="n">superseding</span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span> <span class="k">in</span> <span class="n">Main</span> <span class="n">at</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">8</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span> <span class="n">with</span> <span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">my_add</span><span class="x">(</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="x">(</span><span class="mi">10</span> <span class="n">children</span><span class="x">)</span>
<span class="mi">1</span> <span class="n">mt_cache</span>
<span class="n">julia</span><span class="o">></span> <span class="n">trees</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">backedges</span><span class="x">[</span><span class="k">end</span><span class="x">]</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">my_add</span><span class="x">(</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="n">at</span> <span class="n">depth</span> <span class="mi">0</span> <span class="n">with</span> <span class="mi">10</span> <span class="n">children</span>
<span class="n">julia</span><span class="o">></span> <span class="n">show</span><span class="x">(</span><span class="n">trees</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">backedges</span><span class="x">[</span><span class="k">end</span><span class="x">];</span> <span class="n">minchildren</span><span class="o">=</span><span class="mi">0</span><span class="x">,</span> <span class="n">maxdepth</span><span class="o">=</span><span class="mi">100</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">my_add</span><span class="x">(</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="x">(</span><span class="mi">10</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">BottomRF</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">)})(</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="x">(</span><span class="mi">9</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">_foldl_impl</span><span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">BottomRF</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">)},</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">8</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">foldl_impl</span><span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">BottomRF</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">)},</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">7</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">mapfoldl_impl</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">identity</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">6</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">_mapreduce_dim</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">identity</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">},</span> <span class="o">::</span><span class="kt">Colon</span><span class="x">)</span> <span class="x">(</span><span class="mi">5</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">var</span><span class="s">"#mapreduce#665"</span><span class="x">(</span><span class="o">::</span><span class="kt">Colon</span><span class="x">,</span> <span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">mapreduce</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">identity</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">4</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">var</span><span class="s">"#mapreduce##kw"</span><span class="x">)(</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">{(</span><span class="o">:</span><span class="n">init</span><span class="x">,),</span> <span class="kt">Tuple</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}},</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">mapreduce</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">identity</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">3</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">var</span><span class="s">"#reduce#667"</span><span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">Iterators</span><span class="o">.</span><span class="n">Pairs</span><span class="x">{</span><span class="kt">Symbol</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Tuple</span><span class="x">{</span><span class="kt">Symbol</span><span class="x">},</span> <span class="kt">NamedTuple</span><span class="x">{(</span><span class="o">:</span><span class="n">init</span><span class="x">,),</span> <span class="kt">Tuple</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}}},</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">reduce</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">2</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="x">(</span><span class="o">::</span><span class="n">Base</span><span class="o">.</span><span class="n">var</span><span class="s">"#reduce##kw"</span><span class="x">)(</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">{(</span><span class="o">:</span><span class="n">init</span><span class="x">,),</span> <span class="kt">Tuple</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}},</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">reduce</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">my_add</span><span class="x">),</span> <span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">1</span> <span class="n">children</span><span class="x">)</span>
<span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">my_sum</span><span class="x">(</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span> <span class="x">(</span><span class="mi">0</span> <span class="n">children</span><span class="x">)</span>
</code></pre></div></div>
<p>This shows the whole call stack.
You can interactively navigate the stack with <code class="language-plaintext highlighter-rouge">ascend(trees[1].backedges[end])</code>, which uses <a href="https://github.com/JuliaDebug/Cthulhu.jl">Cthulhu.jl</a>.</p>
<p>Let’s perform some timings to see whether we can detect delays due to method invalidations.
Start up a fresh Julia REPL.</p>
<div title="Invalidation" class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="k">using</span> <span class="n">SnoopCompile</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">10</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="x">;</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span><span class="o"><:</span><span class="kt">Real</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="n">my_add</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">init</span><span class="o">=</span><span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">));</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="mf">0.023856</span> <span class="n">seconds</span> <span class="x">(</span><span class="mf">79.31</span> <span class="n">k</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">4.761</span> <span class="n">MiB</span><span class="x">,</span> <span class="mf">99.88</span><span class="o">%</span> <span class="n">compilation</span> <span class="n">time</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Float64</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Float64</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="x">;</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="mf">0.016896</span> <span class="n">seconds</span> <span class="x">(</span><span class="mf">53.17</span> <span class="n">k</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">2.952</span> <span class="n">MiB</span><span class="x">,</span> <span class="mf">99.94</span><span class="o">%</span> <span class="n">compilation</span> <span class="n">time</span><span class="x">)</span>
</code></pre></div></div>
<div title="No Invalidation" class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="k">using</span> <span class="n">SnoopCompile</span>
<span class="n">julia</span><span class="o">></span> <span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="mi">10</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="x">;</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span><span class="o"><:</span><span class="kt">Real</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="n">my_add</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">init</span><span class="o">=</span><span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">));</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="mf">0.023979</span> <span class="n">seconds</span> <span class="x">(</span><span class="mf">79.31</span> <span class="n">k</span> <span class="n">allocations</span><span class="o">:</span> <span class="mf">4.761</span> <span class="n">MiB</span><span class="x">,</span> <span class="mf">99.89</span><span class="o">%</span> <span class="n">compilation</span> <span class="n">time</span><span class="x">)</span>
<span class="n">julia</span><span class="o">></span> <span class="n">my_add</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Float32</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Float32</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="x">;</span>
<span class="n">julia</span><span class="o">></span> <span class="nd">@time</span> <span class="n">my_sum</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="mf">0.000004</span> <span class="n">seconds</span> <span class="x">(</span><span class="mi">1</span> <span class="n">allocation</span><span class="o">:</span> <span class="mi">16</span> <span class="n">bytes</span><span class="x">)</span>
</code></pre></div></div>
<p>In the first case, where <code class="language-plaintext highlighter-rouge">my_add(::Float64, ::Float64)</code> gets invalidated, the second call of <code class="language-plaintext highlighter-rouge">my_sum(x)</code> again incurs compilation time.
This does not happen in the second case.</p>
<p>Lastly, we discuss one more common scenario in which method invalidations happen.
Consider</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="o">=</span> <span class="mi">1</span><span class="x">;</span>
<span class="n">julia</span><span class="o">></span> <span class="n">g</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="x">);</span>
<span class="n">julia</span><span class="o">></span> <span class="n">g</span><span class="x">(</span><span class="s">"1"</span><span class="x">)</span>
<span class="n">ERROR</span><span class="o">:</span> <span class="kt">MethodError</span><span class="o">:</span> <span class="n">no</span> <span class="n">method</span> <span class="n">matching</span> <span class="n">f</span><span class="x">(</span><span class="o">::</span><span class="kt">String</span><span class="x">)</span>
<span class="n">Closest</span> <span class="n">candidates</span> <span class="n">are</span><span class="o">:</span>
<span class="n">f</span><span class="x">(</span><span class="o">::</span><span class="kt">Int64</span><span class="x">)</span> <span class="n">at</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">8</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span>
<span class="n">Stacktrace</span><span class="o">:</span>
<span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="n">g</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">String</span><span class="x">)</span>
<span class="err">@</span> <span class="n">Main</span> <span class="o">./</span><span class="n">REPL</span><span class="x">[</span><span class="mi">9</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span>
<span class="x">[</span><span class="mi">2</span><span class="x">]</span> <span class="n">top</span><span class="o">-</span><span class="n">level</span> <span class="n">scope</span>
<span class="err">@</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">10</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span>
</code></pre></div></div>
<p>The compiled method instance <code class="language-plaintext highlighter-rouge">g(::String)</code> throws a <code class="language-plaintext highlighter-rouge">MethodError</code>.
In particular, it assumes that there is no implementation for <code class="language-plaintext highlighter-rouge">f(::String)</code>.
If we add that implementation, then <code class="language-plaintext highlighter-rouge">g(::String)</code> needs to be recompiled to make use of the then-available <code class="language-plaintext highlighter-rouge">f(::String)</code>.
Invalidations of this kind link back to the method table.
They show up in the property <code class="language-plaintext highlighter-rouge">mt_backedges</code> of <code class="language-plaintext highlighter-rouge">MethodInvalidations</code>:</p>
<div class="language-julia highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">julia</span><span class="o">></span> <span class="n">invalidation_trees</span><span class="x">(</span><span class="nd">@snoopr</span> <span class="k">begin</span> <span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">String</span><span class="x">)</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">end</span><span class="x">)</span>
<span class="mi">1</span><span class="o">-</span><span class="n">element</span> <span class="kt">Vector</span><span class="x">{</span><span class="n">SnoopCompile</span><span class="o">.</span><span class="n">MethodInvalidations</span><span class="x">}</span><span class="o">:</span>
<span class="n">inserting</span> <span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">String</span><span class="x">)</span> <span class="k">in</span> <span class="n">Main</span> <span class="n">at</span> <span class="n">REPL</span><span class="x">[</span><span class="mi">11</span><span class="x">]</span><span class="o">:</span><span class="mi">1</span> <span class="n">invalidated</span><span class="o">:</span>
<span class="n">mt_backedges</span><span class="o">:</span> <span class="mi">1</span><span class="o">:</span> <span class="n">signature</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">String</span><span class="x">}</span> <span class="n">triggered</span> <span class="n">MethodInstance</span> <span class="k">for</span> <span class="n">g</span><span class="x">(</span><span class="o">::</span><span class="kt">String</span><span class="x">)</span> <span class="x">(</span><span class="mi">0</span> <span class="n">children</span><span class="x">)</span>
</code></pre></div></div>
<p>I am participating in a learning circle with the goal of gaining a better understanding of the Julia language. To better retain what we learn, I will be turning my notes into small blog posts. The posts should be simple, quick, but hopefully enjoyable reads.</p>
<p><strong>Solutions for High-Dimensional Statistics</strong>, 2020-08-21. <a href="https://wessel.ai/2020/08/21/high-dimensional-statistics">https://wessel.ai/2020/08/21/high-dimensional-statistics</a></p>
<p>A brief update:
<a href="https://scholar.google.com/citations?user=Jp7hKlAAAAAJ">Jiri</a> and I have been working through the new book <a href="https://www.cambridge.org/core/books/highdimensional-statistics/8A91ECEEC38F46DAB53E9FF8757C7A4E"><em>High-Dimensional Statistics: A Non-Asymptotic Viewpoint</em> by Martin E. Wainwright</a>, which has been really good so far.
In the process, we have produced solutions for a subset of the exercises.
Since some of the exercises are considerably challenging, we have decided to publicly post our worked solutions.
<a href="https://high-dimensional-statistics.github.io/">Check it out!</a></p>
<p><strong>A Short Note on The Y Combinator</strong>, 2018-08-16. <a href="https://wessel.ai/2018/08/16/y-combinator">https://wessel.ai/2018/08/16/y-combinator</a></p>
<p class="pretitle">Cross-posted at the <a href="https://invenia.github.io/blog/2018/08/20/ycombinator/">Invenia blog</a>.</p>
<h2 id="introduction">Introduction</h2>
<p>This post is a short note on the notorious <em>Y combinator</em>.
No, not <a href="https://ycombinator.com">that company</a>, but the computer-sciency object that looks like this:</p>
\[\label{eq:Y-combinator}
Y = \lambda\, f : (\lambda\, x : f\,(x\, x))\, (\lambda\, x : f\,(x\, x)).\]
<p>Don’t worry if that looks complicated; we’ll get down to some examples and the nitty gritty details in just a second.
But first, <em>what</em> even is this Y combinator thing?
Simply put, the Y combinator is a higher-order function \(Y\) that can be used to define recursive functions in languages that don’t support recursion.
Cool!</p>
<p>For readers unfamiliar with the above notation, the right-hand side of Equation \eqref{eq:Y-combinator} is a <em>lambda term</em>, which is a valid expression in <a href="https://en.wikipedia.org/wiki/Lambda_calculus"><em>lambda calculus</em></a>:</p>
<ol>
<li>\(x\), a variable, is a lambda term;</li>
<li>if \(t\) is a lambda term, then the anonymous function \(\lambda\, x : t\) is a lambda term;</li>
<li>if \(s\) and \(t\) are lambda terms, then \(s\, t\) is a lambda term, which should be interpreted as \(s\) applied with argument \(t\); and</li>
<li>nothing else is a lambda term.</li>
</ol>
<p>For example, if we apply \(\lambda\, x : y\,x\) to \(z\), we find</p>
\[\label{eq:example}
(\lambda\, x : y\,x)\, z = y\,z.\]
<p>Although the notation in Equation \eqref{eq:example} suggests multiplication, note that everything is function application, because really that’s all there is in lambda calculus.</p>
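<p>As a quick sanity check, this reduction can be mimicked in Python, where anonymous functions are written with <code class="language-plaintext highlighter-rouge">lambda</code>. The function <code class="language-plaintext highlighter-rouge">y</code> below is a hypothetical placeholder, chosen purely for illustration:</p>

```python
# A hypothetical unary function standing in for the free variable y.
y = lambda v: ("y applied to", v)

# (lambda x : y x) applied to "z" reduces to y "z".
lhs = (lambda x: y(x))("z")
rhs = y("z")
assert lhs == rhs  # both sides evaluate to ("y applied to", "z")
```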
<p>Consider the factorial function \(\code{fact}\):</p>
\[\label{eq:fact-recursive}
\code{fact} =
\lambda\, n :
(\code{if}\,
(\code{iszero}\, n) \,
1 \,
(\code{multiply}\,
n\,
(\code{fact}\,
(\code{subtract}\, n\, 1)))).\]
<p>In words, if \(n\) is zero, return \(1\); otherwise, multiply \(n\) with \(\code{fact}(n-1)\).
Equation \eqref{eq:fact-recursive} would be a valid expression if lambda calculus allowed us to use \(\code{fact}\) in the definition of \(\code{fact}\).
Unfortunately, it doesn’t.
Tricky.
Let’s replace the inner \(\code{fact}\) by a variable \(f\):</p>
\[\code{fact}' =
\lambda\, f: \lambda\, n :
(\code{if}\,
(\code{iszero}\, n) \,
1 \,
(\code{multiply}\,
n\,
(f\,
(\code{subtract}\, n\, 1)))).\]
<p>Now, crucially, the Y combinator \(Y\) is precisely designed to construct \(\code{fact}\) from \(\code{fact}'\):</p>
\[Y\, \code{fact}' = \code{fact}.\]
<p>To see this, let’s denote \(\code{fact2}=Y\,\code{fact}'\) and verify that \(\code{fact2}\) indeed equals \(\code{fact}\):</p>
<p>\begin{align}
\code{fact2}
&= Y\, \code{fact}' \newline
&= (\lambda\, f : (\lambda\, x : f\,(x\, x))\, (\lambda\, x : f\,(x\, x)))\, \code{fact}' \newline
&= (\lambda\, x : \code{fact}'\,(x\, x) )\, (\lambda\, x : \code{fact}'\,(x\, x)) \label{eq:step-1} \newline
&= \code{fact}'\, ((\lambda\, x : \code{fact}'\, (x\, x))\,(\lambda\, x : \code{fact}'\, (x\, x))) \label{eq:step-2} \newline
&= \code{fact}'\, (Y\, \code{fact}') \newline
&= \code{fact}'\, \code{fact2},
\end{align}</p>
<p>which is <em>exactly</em> what we’re looking for, because the first argument to \(\code{fact}'\) should be the actual factorial function, \(\code{fact2}\) in this case.
Neat!</p>
<p>We hence see that \(Y\) can indeed be used to define recursive functions in languages that don’t support recursion.
Where does this magic come from, you say?
Sit tight, because that’s up next!</p>
<h2 id="deriving-the-y-combinator">Deriving the Y Combinator</h2>
<p>This section introduces a simple trick that can be used to derive Equation \eqref{eq:Y-combinator}.
We also show how this trick can be used to derive analogues of the Y combinator that implement <em>mutual recursion</em> in languages that don’t even support simple recursion.</p>
<p>Again, let’s start out by considering a recursive function:</p>
\[f = \lambda\, x:g[f, x]\]
<p>where \(g\) is some lambda term that depends on \(f\) and \(x\).
As we discussed before, such a definition is not allowed.
However, pulling out \(f\),</p>
\[\label{eq:fixed-point}
f = \underbrace{(\lambda \, f' :\lambda\, x:g[f', x])}_{h}\,\, f = h\, f.\]
<p>we do find that \(f\) is a <em>fixed point</em> of \(h\): \(f\) is invariant under applications of \(h\).
Now—and this is the trick—suppose that \(f\) is the result of a function \(\hat{f}\) applied to itself: \(f=\hat{f}\,\hat{f}\).
Then Equation \eqref{eq:fixed-point} becomes</p>
\[{\color{red}\hat{f}} \,\hat{f}
= h\,(\hat{f}\, \hat{f})
= ({\color{red}\lambda\,x:h(x\,x)})\,\,\hat{f},\]
<p>from which we, by inspection, infer that</p>
\[\hat{f} = \lambda\,x:h(x\,x).\]
<p>Therefore,</p>
\[f
= \hat{f}\hat{f}
= (\lambda\,x:h(x\,x))\,(\lambda\,x:h(x\,x)).\]
<p>Pulling out \(h\),</p>
\[f
= (\lambda\, h': (\lambda\,x:h'\,(x\,x))\,(\lambda\,x:h'\,(x\,x)))\, h
= Y\, h,\]
<p>where suddenly a wild Y combinator has appeared.</p>
<p>The above derivation shows that \(Y\) is a <em>fixed-point combinator</em>.
Passed some function \(h\), \(Y\,h\) gives a fixed point of \(h\):
\(f = Y\,h\) satisfies \(f = h\,f\).</p>
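<p>The fixed-point property can also be checked numerically. A literal transcription of \(Y\) loops forever under Python's eager evaluation (as discussed in the implementation section below), so this sketch uses the eta-expanded variant derived there:</p>

```python
# Eta-expanded Y combinator (the lazily evaluating variant).
Y = lambda f: (lambda x: f(lambda y: x(x)(y)))(lambda x: f(lambda y: x(x)(y)))

# h maps a candidate factorial to an improved factorial.
h = lambda f: lambda n: 1 if n == 0 else n * f(n - 1)

# f = Y h is a fixed point of h: f and h f agree pointwise.
f = Y(h)
for n in range(6):
    assert f(n) == h(f)(n)
assert f(5) == 120
```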
<p>Pushing it even further, consider two functions that depend on each other:</p>
<p>\begin{align}
f &= \lambda\,x:k_f[x, f, g], &
g &= \lambda\,x:k_g[x, f, g]
\end{align}</p>
<p>where \(k_f\) and \(k_g\) are lambda terms that depend on \(x\), \(f\), and \(g\).
This is foul play, as we know.
We proceed as before and pull out \(f\) and \(g\):</p>
<p>\begin{align}
f
= \underbrace{
(\lambda\,f':\lambda\,g':\lambda\,x:k_f[x, f', g'])
}_{h_f} \,\, f\, g
= h_f\, f\, g
\end{align}</p>

<p>\begin{align}
g
= \underbrace{
(\lambda\,f':\lambda\,g':\lambda\,x:k_g[x, f', g'])
}_{h_g} \,\, f\, g
= h_g\, f\, g.
\end{align}</p>
<p>Now—here’s that trick again—let \(f = \hat{f}\,\hat{f}\,\hat{g}\) and \(g = \hat{g}\,\hat{f}\,\hat{g}\).<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>
Then</p>
<p>\begin{align}
\hat{f}\,\hat{f}\,\hat{g}
&= h_f\,(\hat{f}\,\hat{f}\,\hat{g})\,(\hat{g}\,\hat{f}\,\hat{g})
= (\lambda\,x:\lambda\,y:h_f\,(x\,x\,y)\,(y\,x\,y))\,\,\hat{f}\,\hat{g},\newline
\hat{g}\,\hat{f}\,\hat{g}
&= h_g\,(\hat{f}\,\hat{f}\,\hat{g})\,(\hat{g}\,\hat{f}\,\hat{g})
= (\lambda\,x:\lambda\,y:h_g\,(x\,x\,y)\,(y\,x\,y))\,\,\hat{f}\,\hat{g},
\end{align}</p>
<p>which suggests that</p>
<p>\begin{align}
\hat{f} &= \lambda\,x:\lambda\,y:h_f\,(x\,x\,y)\,(y\,x\,y), \newline
\hat{g} &= \lambda\,x:\lambda\,y:h_g\,(x\,x\,y)\,(y\,x\,y).
\end{align}</p>
<p>Therefore</p>
<p>\begin{align}
f
&= \hat{f}\,\hat{f}\,\hat{g} \newline
&=
(\lambda\,x:\lambda\,y:h_f\,(x\,x\,y)\,(y\,x\,y))\,
(\lambda\,x:\lambda\,y:h_f\,(x\,x\,y)\,(y\,x\,y))\,
(\lambda\,x:\lambda\,y:h_g\,(x\,x\,y)\,(y\,x\,y)) \newline
&= Y_f\, h_f\, h_g
\end{align}</p>
<p>where</p>
\[Y_f = (\lambda\, h_f':
\lambda\, h_g':
(\lambda\,x:\lambda\,y:h_f'\,(x\,x\,y)\,(y\,x\,y))\,
(\lambda\,x:\lambda\,y:h_f'\,(x\,x\,y)\,(y\,x\,y))\,
(\lambda\,x:\lambda\,y:h_g'\,(x\,x\,y)\,(y\,x\,y))).\]
<p>Similarly,</p>
\[g = Y_g\, h_f\, h_g.\]
<p><em>Dang</em>, laborious, but that worked.
And thus we have derived two analogues \(Y_f\) and \(Y_g\) of the Y combinator that implement mutual recursion in languages that don’t even support simple recursion.</p>
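<p>As a sketch of how this construction plays out in practice, here is a Python transcription for a hypothetical mutually recursive pair, <code class="language-plaintext highlighter-rouge">is_even</code> and <code class="language-plaintext highlighter-rouge">is_odd</code>, which are not part of the derivation above. As with \(Y\) itself (see the next section), the self-applications \(x\,x\,y\) and \(y\,x\,y\) must be eta-expanded to terminate under Python's eager evaluation:</p>

```python
# hf and hg play the roles of h_f and h_g: each takes candidate
# implementations of both functions and returns an improved version.
hf = lambda f: lambda g: lambda n: True if n == 0 else g(n - 1)   # "even"
hg = lambda f: lambda g: lambda n: False if n == 0 else f(n - 1)  # "odd"

# f-hat and g-hat from the derivation, with eta-expanded self-applications.
fhat = lambda x: lambda y: hf(lambda n: x(x)(y)(n))(lambda n: y(x)(y)(n))
ghat = lambda x: lambda y: hg(lambda n: x(x)(y)(n))(lambda n: y(x)(y)(n))

# f = f-hat f-hat g-hat and g = g-hat f-hat g-hat.
is_even = fhat(fhat)(ghat)
is_odd = ghat(fhat)(ghat)

assert is_even(10) and not is_even(7)
assert is_odd(7) and not is_odd(10)
```

<p>Calling <code class="language-plaintext highlighter-rouge">is_even(10)</code> bounces between the two functions all the way down to zero, even though neither function ever refers to itself or to the other by name.</p>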
<h2 id="implementing-the-y-combinator-in-python">Implementing the Y Combinator in Python</h2>
<p>Well, that’s cool and all, but let’s see whether this Y combinator thing actually works.
Consider the following nearly 1-to-1 translation of \(Y\) and \(\code{fact}'\) to Python:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Y</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">(</span><span class="n">x</span><span class="p">)))(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="n">fact</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="k">lambda</span> <span class="n">n</span><span class="p">:</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">n</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">n</span> <span class="o">*</span> <span class="n">f</span><span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
<p>If we try to run this, we run into some weird recursion:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">Y</span><span class="p">(</span><span class="n">fact</span><span class="p">)(</span><span class="mi">4</span><span class="p">)</span>
<span class="nb">RecursionError</span><span class="p">:</span> <span class="n">maximum</span> <span class="n">recursion</span> <span class="n">depth</span> <span class="n">exceeded</span>
</code></pre></div></div>
<p>Eh?
What’s going on?
Let’s, for closer inspection, once more write down \(Y\):</p>
\[Y = \lambda\, f: (\lambda\, x : f\,(x\, x))\, (\lambda\, x : f\,(x\, x)).\]
<p>After \(f\) is passed to \(Y\), \((\lambda\, x : f\,(x\, x))\) is passed to \((\lambda\, x : f\,(x\, x))\); which then evaluates \(x\, x\), which passes \((\lambda\, x : f\,(x\, x))\) to \((\lambda\, x : f\,(x\, x))\); which then again evaluates \(x\, x\), which again passes \((\lambda\, x : f\,(x\, x))\) to \((\lambda\, x : f\,(x\, x))\); <em>ad infinitum</em>.
Written down differently, evaluation of \(Y\, f\, x\) yields</p>
\[Y\, f\, x
= (Y\, f)\, x
= (Y\, (Y\, f))\, x
= (Y\, (Y\, (Y\, f)))\, x
= (Y\, (Y\, (Y\, (Y\, f))))\, x
= \ldots,\]
<p>which goes on indefinitely.
Consequently, \(Y\, f\) will not evaluate in finite time, and this is the cause of the <code class="language-plaintext highlighter-rouge">RecursionError</code>.
But we can fix this, and quite simply so: only allow the recursion—the \(x\,x\) bit—to happen when it’s passed an argument; in other words, replace</p>
\[\label{eq:strict-evaluation}
x\,x \to \lambda\,y:x\,x\,y.\]
<p>Substituting Equation \eqref{eq:strict-evaluation} into Equation \eqref{eq:Y-combinator}, we find</p>
\[\label{eq:strict-Y-combinator}
Y = \lambda\, f : (\lambda\, x : f(\lambda\, y: x\, x\,y))\, (\lambda\, x : f(\lambda\, y:x\, x\, y)).\]
<p>Translating to Python,</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Y</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">f</span><span class="p">(</span><span class="k">lambda</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span><span class="p">(</span><span class="n">x</span><span class="p">)(</span><span class="n">y</span><span class="p">)))(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">f</span><span class="p">(</span><span class="k">lambda</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span><span class="p">(</span><span class="n">x</span><span class="p">)(</span><span class="n">y</span><span class="p">)))</span>
</code></pre></div></div>
<p>And then we try again:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">Y</span><span class="p">(</span><span class="n">fact</span><span class="p">)(</span><span class="mi">4</span><span class="p">)</span>
<span class="mi">24</span>
<span class="o">>>></span> <span class="n">Y</span><span class="p">(</span><span class="n">fact</span><span class="p">)(</span><span class="mi">3</span><span class="p">)</span>
<span class="mi">6</span>
<span class="o">>>></span> <span class="n">Y</span><span class="p">(</span><span class="n">fact</span><span class="p">)(</span><span class="mi">2</span><span class="p">)</span>
<span class="mi">2</span>
<span class="o">>>></span> <span class="n">Y</span><span class="p">(</span><span class="n">fact</span><span class="p">)(</span><span class="mi">1</span><span class="p">)</span>
<span class="mi">1</span>
</code></pre></div></div>
<p>Sweet success!</p>
<h2 id="summary">Summary</h2>
<p>To recapitulate, the Y combinator is a higher-order function that can be used to define recursion—and even mutual recursion—in languages that don’t support recursion.
One way of deriving \(Y\) is to assume that the recursive function under consideration \(f\) is the result of some other function \(\hat{f}\) applied to itself:
\(f = \hat{f}\,\hat{f}\);
after some simple manipulation, the result can then be determined by inspection.
Although \(Y\) can indeed be used to define recursive functions, it cannot be applied literally in a contemporary, eagerly evaluated programming language; recursion errors then occur.
Fortunately, this can be fixed simply by letting the recursion in \(Y\) happen when needed—that is, <em>lazily</em>.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>Do you see why this is the appropriate generalisation of letting \(f=\hat{f}\,\hat{f}\)? <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<p><strong>Hello, World</strong>, 2018-06-19. <a href="https://wessel.ai/2018/06/19/hello-world">https://wessel.ai/2018/06/19/hello-world</a></p>
<p>Hello, world! Another blog has come into existence. Woo! Find out more about me <a href="/portfolio">here</a> and <a href="/about">here</a>.</p>
<p>Posts to follow soon. I promise.</p>