Using the Fourier/spectral method to solve PDEs
I’ve just watched a number of videos about the discrete Fourier transform. One key point I had never fully understood is how the Fourier transform lets one compute derivatives numerically. Because of this, the Fourier transform is a wonderful tool for computing solutions to partial differential equations (PDEs). To showcase this (and solidify my understanding), I’ll do two things in this blog post:
- compare finite difference derivatives with Fourier-computed derivatives
- solve the wave equation using the Fourier transform
Most of what I’ll describe is based on the videos from Steven Brunton’s playlist on Fourier methods: https://www.youtube.com/playlist?list=PLMrJAkhIeNNT_Xh3Oy0Y4LTj0Oxo8GqsC
Finite difference derivative vs the spectral derivative
Suppose we’re investigating the function $f(x) = \cos(x) e^{-x^2}$. Its derivative is $f'(x) = e^{-x^2}(-\sin(x) -2 x \cos(x))$.
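If we have samples $f_n = f(x_n)$ of this function on a uniform grid with spacing $\Delta x$, we can compute its derivative at the same points with the central finite difference approximation $f'(x_n) \approx \frac{f_{n+1} - f_{n-1}}{2 \Delta x}$. This needs a neighbor on each side, so it only gives values at the interior points.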
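On the other hand, we can compute the Fourier transform $\hat{f} = \mathcal{F}(f) = \int f(x) e^{-i k x} dx$. Integrating by parts (the boundary terms vanish because $f$ decays at infinity), the Fourier transform of the derivative of $f$ is $\mathcal{F}(f') = i k \hat{f}$. So we can transform, multiply by $ik$, and transform back.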
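As a minimal sketch of the $ik$ rule on its own (my own example, not from the videos), here it is applied to a function that is exactly periodic on a grid over $[0, 2\pi)$. I use `endpoint=False` so that the grid covers exactly one period; under that assumption, the discrete version of the rule should reproduce the derivative up to round-off.

```python
import numpy as np

# Periodic grid on [0, 2*pi): endpoint=False so the last sample is not a copy of the first
N = 32
x = np.linspace(0, 2 * np.pi, num=N, endpoint=False)
dx = x[1] - x[0]

# Smooth periodic test function and its exact derivative
f = np.sin(3 * x) + 0.5 * np.cos(x)
df_exact = 3 * np.cos(3 * x) - 0.5 * np.sin(x)

# Angular wavenumbers of the FFT bins (here, up to round-off, 0, 1, ..., 15, -16, ..., -1)
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)

# Spectral derivative: transform, multiply each coefficient by i*k, transform back
df_spectral = np.fft.ifft(1j * k * np.fft.fft(f)).real

print(np.max(np.abs(df_spectral - df_exact)))  # expected at round-off level
```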
Let’s see how these methods compare.
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-5, 5, num=25)
dx = np.diff(x)[0]
f = np.cos(x) * np.exp(-x**2)
df_exact = np.exp(-x**2) * (-np.sin(x) - 2*x*np.cos(x))
df_findiff = (f[2:] - f[:-2]) / (2 * dx)
F = np.fft.fft(f)
k = 2 * np.pi * np.fft.fftfreq(f.size, d=dx)
dF = 1j * k * F
df_spectral = np.fft.ifft(dF).real
# Plots
fig, ax = plt.subplots()
ax.plot(x, df_exact, '-x', label='exact')
ax.plot(x[1:-1], df_findiff, label='finite diff.')
ax.plot(x, df_spectral, label='spectral')
ax.legend()
The above plot shows that, with only 25 grid points, the spectral method is much more accurate than the finite difference method.
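A side note on the above: I have treated the continuous Fourier transform, i.e. $\hat{f} = \mathcal{F}(f) = \int f(x) e^{-i k x} dx$, and the discrete Fourier transform, i.e. $X_k = \sum_{n=0}^{N-1} x_n \cdot e^{-i2\pi \tfrac{k}{N}n}$, as if there were no difference between them. That is not quite right, but it turns out to be a good first approximation: the DFT coefficients define a trigonometric interpolant of the samples, and multiplying them by $ik$ differentiates that interpolant. This works well here because $f$ is smooth and decays to essentially zero at both ends of the interval, so it looks periodic to the DFT.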
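To make the "good first approximation" concrete, here is a sketch (my own check, with grid parameters chosen for illustration) comparing the DFT with the continuous transform of a Gaussian, $\int e^{-x^2} e^{-ikx} dx = \sqrt{\pi} e^{-k^2/4}$. Approximating the integral by a sum over grid points $x_n = x_0 + n \Delta x$ gives $\Delta x \, e^{-i k_m x_0} X_m$ for the $m$-th bin, so the two transforms differ only by the factor $\Delta x$ and a phase set by where the grid starts.

```python
import numpy as np

# Grid parameters chosen for illustration: [-10, 10) with 128 points
L, N = 20.0, 128
dx = L / N
x = -L / 2 + dx * np.arange(N)   # x_n = x_0 + n*dx with x_0 = -L/2
f = np.exp(-x**2)

# Angular wavenumbers of the DFT bins
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)

# Riemann sum of the integral: sum_n f(x_n) exp(-i k x_n) dx = dx * exp(-i k x_0) * DFT(f)
F_approx = dx * np.exp(-1j * k * x[0]) * np.fft.fft(f)

# Closed-form continuous transform with the convention F(k) = int f(x) exp(-i k x) dx
F_exact = np.sqrt(np.pi) * np.exp(-k**2 / 4)

print(np.max(np.abs(F_approx - F_exact)))  # expected to be tiny for this smooth, decaying f
```

The phase factor is what makes the comparison work for a grid that does not start at $x = 0$; dropping it leaves magnitudes unchanged but scrambles the signs of the coefficients.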
Another note on the above: I have used the `fftfreq` function, which returns the frequencies of the Fourier bins computed by the `fft` function. This is a very useful helper. Note that I had to multiply its output by $2 \pi$ to get the correct wavenumbers: `fftfreq` returns frequencies in cycles per unit length, while the $e^{-ikx}$ convention above uses angular wavenumbers, $k = 2\pi \times \text{frequency}$.
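A small sketch of what `fftfreq` returns and why the $2\pi$ matters. The sample count, spacing and test function are arbitrary choices for illustration; the derivative computed with the raw `fftfreq` values comes out too small by exactly a factor $2\pi$.

```python
import numpy as np

N, d = 8, 0.5                        # 8 samples spaced 0.5 apart: the FFT assumes a period of N * d = 4
freqs = np.fft.fftfreq(N, d=d)       # frequencies in cycles per unit length
k = 2 * np.pi * freqs                # angular wavenumbers in radians per unit length
print(freqs)                         # values: 0, 0.25, 0.5, 0.75, -1, -0.75, -0.5, -0.25

# One full sine cycle over the period, and its exact derivative
x = d * np.arange(N)
f = np.sin(2 * np.pi * x / 4)
df_exact = (2 * np.pi / 4) * np.cos(2 * np.pi * x / 4)

F = np.fft.fft(f)
df_with_2pi = np.fft.ifft(1j * k * F).real          # matches df_exact
df_without_2pi = np.fft.ifft(1j * freqs * F).real   # too small by a factor 2*pi
```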
Now, what happens if we have more data points? We should expect the error to go down for both methods. But do they converge at the same rate? We can define an error function and plot the result on a log-log scale.
def calc_error(x, df_approx):
    """Sum of squared errors of df_approx against the exact derivative at the points x."""
    df_exact = np.exp(-x**2) * (-np.sin(x) - 2*x*np.cos(x))
    return np.sum(np.square(df_approx - df_exact))
errors = []
for n in np.logspace(1, 2, num=40):
    # Grid with n points (from 10 to 100, log-spaced)
    x = np.linspace(-5, 5, num=int(n))
    dx = np.diff(x)[0]
    f = np.cos(x) * np.exp(-x**2)
    df_findiff = (f[2:] - f[:-2]) / (2 * dx)
    F = np.fft.fft(f)
    k = 2 * np.pi * np.fft.fftfreq(f.size, d=dx)
    dF = 1j * k * F
    df_spectral = np.fft.ifft(dF).real
    # Finite differences are only available at the interior points x[1:-1]
    errors.append((int(n), calc_error(x[1:-1], df_findiff), calc_error(x, df_spectral)))
import pandas as pd
df = pd.DataFrame(errors, columns=['Npoints', 'error_fin_diff', 'error_spectral']).set_index('Npoints')
df.plot(logx=True, logy=True)
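The above graph clearly shows how much better the spectral derivative is: its error quickly drops to machine precision, while the finite difference error decreases much more slowly. This matches theory: the central difference has an error of order $\Delta x^2$, whereas for a smooth function like this one the spectral error decreases faster than any power of $\Delta x$. Of course, this comes at a cost: a forward and an inverse FFT, $O(N \log N)$ operations, instead of the $O(N)$ of the finite difference stencil.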
Note: this more or less follows what is done in this video https://www.youtube.com/watch?v=y8SqkjoKV4k.
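To put a number on "much more slowly", here is a sketch that checks the expected second-order behavior of the central difference: halving $\Delta x$ should divide the error by about $2^2 = 4$. It uses the maximum pointwise error rather than the sum of squares from `calc_error`, so that the ratio does not depend on how many points are summed; the two grid sizes are my own choice.

```python
import numpy as np

def f(x):
    return np.cos(x) * np.exp(-x**2)

def df_exact(x):
    return np.exp(-x**2) * (-np.sin(x) - 2 * x * np.cos(x))

def max_fd_error(n):
    """Max abs error of the central difference on the interior of an n-point grid on [-5, 5]."""
    x = np.linspace(-5, 5, num=n)
    dx = x[1] - x[0]
    fx = f(x)
    df_fd = (fx[2:] - fx[:-2]) / (2 * dx)
    return np.max(np.abs(df_fd - df_exact(x[1:-1])))

# n = 101 gives dx = 0.1, n = 201 gives dx = 0.05
ratio = max_fd_error(101) / max_fd_error(201)
print(ratio)  # expected close to 4 for a second-order scheme
```

Repeating the same experiment with the spectral derivative would not give a fixed ratio, since its error does not follow a power law in $\Delta x$.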
Solving PDEs with the Fourier transform
The above has shown how easily, and how precisely, one can compute derivatives with a Fourier transform. Spatial derivatives like these are the building blocks of partial differential equations. Here, I want to solve the wave equation:
$$ u_{tt} = c^2 u_{xx} $$
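If we Fourier transform both sides in space, each $\partial_x$ becomes a multiplication by $ik$, so $u_{xx}$ becomes $(ik)^2 \hat{u} = -k^2 \hat{u}$. We are left with an ordinary differential equation (ODE) in time for $\hat{u}(k, t)$, the Fourier transform of $u(x, t)$: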
$$ \hat{u}_{tt} = \frac{\mathrm{d}^2 \hat{u}}{\mathrm{d} t^2} = -c^2 k^2 \hat{u} $$
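Each wavenumber $k$ gives an independent harmonic oscillator with angular frequency $c|k|$, so this ODE can also be solved exactly: $\hat{u}(k, t) = \hat{u}(k, 0) \cos(ckt) + \hat{v}(k, 0) \sin(ckt) / (ck)$ for $k \neq 0$. With zero initial velocity only the cosine term remains. The sketch below (my own check, on a grid built with `endpoint=False` rather than the post's `linspace`) compares this with d'Alembert's formula $u(x, t) = [f(x - ct) + f(x + ct)] / 2$, where $f$ has to be extended periodically because the DFT treats the grid as one period.

```python
import numpy as np

# Periodic grid on [-5, 5) (an assumption: endpoint=False instead of the post's linspace)
L, N, c = 10.0, 64, 1.0
x = -L / 2 + (L / N) * np.arange(N)
k = 2 * np.pi * np.fft.fftfreq(N, d=L / N)

def f(x):
    """Initial displacement, as in the post."""
    return np.cos(x) * np.exp(-x**2)

def spectral_solution(t):
    """u(x, t) for u_tt = c^2 u_xx, u(x, 0) = f, u_t(x, 0) = 0, exact in time mode by mode."""
    u_hat = np.fft.fft(f(x)) * np.cos(c * k * t)
    return np.fft.ifft(u_hat).real

def wrap(z):
    """Map positions back into [-L/2, L/2), i.e. extend f periodically."""
    return (z + L / 2) % L - L / 2

def dalembert(t):
    """d'Alembert's solution (f(x - ct) + f(x + ct)) / 2 with periodic f."""
    return 0.5 * (f(wrap(x - c * t)) + f(wrap(x + c * t)))

for t in [0.0, 1.0, 3.7, 12.0]:
    print(t, np.max(np.abs(spectral_solution(t) - dalembert(t))))
```

The exact-in-time version has no time-stepping error at all, which makes it a handy reference for checking a numerical integrator such as `solve_ivp`.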
We can easily solve this ODE numerically using `scipy.integrate.solve_ivp`, which is designed for this. The only adjustment is that `solve_ivp` expects a first-order system, so we introduce $\hat{v} = \mathrm{d} \hat{u} / \mathrm{d} t$, the classic way of turning a second-order ODE into a first-order one. So we want to solve:
$$ \left \lbrace \begin{aligned} \frac{\mathrm{d} \hat{u}}{\mathrm{d} t} &= \hat{v} \\ \frac{\mathrm{d} \hat{v}}{\mathrm{d} t} &= -c^2 k^2 \hat{u} \end{aligned} \right . $$
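Our unknown is then the vector $[\hat{u}, \hat{v}]$ of length $2N$. Its entries are complex, which `solve_ivp` accepts as long as the initial state has a complex dtype (here it does, since it comes from `np.fft.fft`).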
from scipy.integrate import solve_ivp
N = 50
x = np.linspace(-5, 5, num=N)
dx = np.diff(x)[0]
c = 1.
f0 = np.cos(x) * np.exp(-x**2)
f0_hat = np.fft.fft(f0)
k = 2 * np.pi * np.fft.fftfreq(f0.size, d=dx)
w0 = np.concatenate([f0_hat, np.zeros_like(f0_hat)])
def make_rhs_fun(t, state):
    # state = [u_hat, v_hat], each of length N (complex)
    u_hat, v_hat = state[:N], state[N:]
    du_hat = v_hat
    dv_hat = - c**2 * k**2 * u_hat
    return np.concatenate([du_hat, dv_hat])
sol = solve_ivp(make_rhs_fun, [0, 20], w0, t_eval=np.linspace(0, 20, num=100))
snapshots = []
for ind, t in enumerate(sol.t):
    snapshot = np.fft.ifft(sol.y[:N, ind]).real
    snapshots.append(snapshot)
fig, ax = plt.subplots()
for ind, snapshot in enumerate(snapshots):
    ax.plot(x, snapshot + 0.1 * ind, color='k', lw=1)
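This diagram looks like what we would expect: the initial bump splits into two waves traveling in opposite directions. What looks like a reflection at the boundaries is actually a consequence of the periodic boundary conditions built into the FFT: a wave leaving at one end re-enters at the other. Because the initial condition is symmetric, the two waves reach the edges at the same time, which looks just like a reflection.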
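To see what the spectral solver does at the edges of the domain, here is a sketch with a single pulse moving to the right (my own example). Instead of the ODE solver it uses the Fourier shift property, $f(x - ct) \leftrightarrow e^{-ikct} \hat{f}(k)$, which is the exact solution of the one-way equation $u_t + c u_x = 0$. Because the DFT treats the grid as periodic, a pulse that leaves on the right should re-enter on the left.

```python
import numpy as np

# Periodic grid on [-5, 5), chosen for illustration
L, N, c = 10.0, 128, 1.0
dx = L / N
x = -L / 2 + dx * np.arange(N)
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)

# Narrow pulse centered at x = 3, moving right: u(x, t) = f0(x - c t).
# In Fourier space the shift is a multiplication by exp(-i k c t).
f0 = np.exp(-4 * (x - 3)**2)
t = 4.0
u = np.fft.ifft(np.fft.fft(f0) * np.exp(-1j * k * c * t)).real

# On an infinite line the peak would now sit at x = 7, outside the domain.
# On the periodic grid it re-enters from the left, near 7 - L = -3.
print(x[np.argmax(u)])
```

Getting genuinely reflecting walls would require a different basis (for example sine or cosine transforms) rather than the plain FFT.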
Let’s transform this into an animation.
from IPython.display import HTML
import matplotlib.animation as manim
fig, ax = plt.subplots()
l, = ax.plot(x, snapshots[0])
def frame(i):
    l.set_ydata(snapshots[i])
anim = manim.FuncAnimation(fig, frame, frames=range(1, len(snapshots)))
plt.close(fig)
HTML(anim.to_jshtml(fps=25))
Great, we have done it: two waves propagating along the mesh!
This section again follows a video by Steven Brunton: https://www.youtube.com/watch?v=mMdIxa5qC9Y
Bonus: 2D wave equation
Can we solve the wave equation in 2D?
$$ u_{tt} = c^2 (u_{xx} + u_{yy}) $$
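We need a 2D Fourier transform. Let’s denote the 2D transform of $u(x, y, t)$ by $\hat{u}(k_x, k_y, t)$. Since $\partial_x$ becomes a multiplication by $i k_x$ and $\partial_y$ a multiplication by $i k_y$, we arrive at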
$$ \frac{\mathrm{d}^2}{\mathrm{d} t^2} \hat{u}(k_x, k_y, t) = -c^2 (k_x ^2 + k_y^2) \hat{u}(k_x, k_y, t) $$
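The factor $-(k_x^2 + k_y^2)$ is the Fourier-space version of the Laplacian $u_{xx} + u_{yy}$. As a sanity check (my own, on a Gaussian whose Laplacian is known in closed form, $\nabla^2 e^{-(x^2+y^2)} = (4(x^2+y^2) - 4) e^{-(x^2+y^2)}$), here is a sketch computing it with `np.fft.fft2`. The grid is deliberately non-square, so that mixing up the $x$ and $y$ axes would show up as an error instead of passing silently.

```python
import numpy as np

# Non-square periodic grid (sizes chosen for illustration): x in [-10, 10), y in [-8, 8)
Lx, Nx = 20.0, 128
Ly, Ny = 16.0, 96
dx, dy = Lx / Nx, Ly / Ny
x = -Lx / 2 + dx * np.arange(Nx)
y = -Ly / 2 + dy * np.arange(Ny)
X, Y = np.meshgrid(x, y)             # shape (Ny, Nx): rows follow y, columns follow x

# Gaussian and its exact Laplacian
f = np.exp(-(X**2 + Y**2))
lap_exact = (4 * (X**2 + Y**2) - 4) * f

kx = 2 * np.pi * np.fft.fftfreq(Nx, d=dx)
ky = 2 * np.pi * np.fft.fftfreq(Ny, d=dy)

# Spectral Laplacian: multiply the 2D transform by -(kx^2 + ky^2)
lap_spectral = np.fft.ifft2(-(kx**2 + ky[:, None]**2) * np.fft.fft2(f)).real

print(np.max(np.abs(lap_spectral - lap_exact)))  # expected to be tiny
```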
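Before going 2D, a quick size check on the state vector. In 1D, `f0_hat.shape` is `(50,)` and `w0.shape` is `(100,)`, since the state stacks $\hat{u}$ and $\hat{v}$. On a $50 \times 50$ grid, the flattened 2D state would hold $50 \times 50 \times 2 = 5000$ complex values; with the $100 \times 100$ grid used below, it holds 20000. `solve_ivp` only sees one flat vector, so the right-hand side has to reshape it into 2D arrays and flatten the result again.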
N = 100
x = np.linspace(-5, 5, num=N)
y = np.linspace(-5, 5, num=N)
X, Y = np.meshgrid(x, y)
dx = np.diff(x)[0]
dy = np.diff(y)[0]
c = 1.
f0 = np.cos(2.5 * X) * np.sin(Y) * np.exp(-(X*0.8)**2 -Y**2)
f0_hat = np.fft.fft2(f0)
M = f0_hat.size
kx = 2 * np.pi * np.fft.fftfreq(f0_hat.shape[1], d=dx)
ky = 2 * np.pi * np.fft.fftfreq(f0_hat.shape[0], d=dy)
w0 = np.concatenate([f0_hat.ravel(), np.zeros_like(f0_hat).ravel()])
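def make_rhs_fun(t, state):
    # Unpack the flat state and reshape it to the 2D grid of Fourier coefficients
    u_hat, v_hat = state[:M], state[M:]
    u_hat, v_hat = u_hat.reshape(f0_hat.shape), v_hat.reshape(f0_hat.shape)
    du_hat = v_hat
    # kx**2 varies along the columns (x), ky[:, None]**2 along the rows (y)
    dv_hat = - c**2 * (kx**2 + ky[:, None]**2) * u_hat
    return np.concatenate([du_hat.ravel(), dv_hat.ravel()])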
sol = solve_ivp(make_rhs_fun, [0, 50], w0, t_eval=np.linspace(0, 50, num=200))
snapshots = []
for ind, t in enumerate(sol.t):
    snapshot = np.fft.ifft2(sol.y[:M, ind].reshape(f0_hat.shape)).real
    snapshots.append(snapshot)
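The above solves the wave equation spectrally. One tricky point is the array broadcasting in the computation of `dv_hat`. Since `np.meshgrid(x, y)` returns arrays whose rows follow $y$ and whose columns follow $x$, `f0_hat` has shape `(len(y), len(x))`. `kx**2` has shape `(N,)` and is broadcast along the last axis (the columns, i.e. $x$), while `ky[:, None]**2` has shape `(N, 1)` and is broadcast along the rows ($y$), so their sum is the full 2D array of $k_x^2 + k_y^2$ values, matching `u_hat`.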
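A tiny sketch of the shapes involved, on a deliberately non-square grid (the sizes and spacing are arbitrary). It checks that `kx**2 + ky[:, None]**2` matches the array built explicitly with `np.meshgrid`, and shows that broadcasting the other way round produces the transposed shape, which would no longer line up with `u_hat`.

```python
import numpy as np

Nx, Ny = 6, 4                               # deliberately different sizes
kx = 2 * np.pi * np.fft.fftfreq(Nx, d=0.5)
ky = 2 * np.pi * np.fft.fftfreq(Ny, d=0.5)

K2 = kx**2 + ky[:, None]**2                 # (Nx,) + (Ny, 1) broadcasts to (Ny, Nx)
print(K2.shape)                             # (4, 6)

# Same values spelled out with meshgrid (default indexing='xy'): K2[i, j] = kx[j]**2 + ky[i]**2
KX, KY = np.meshgrid(kx, ky)
print(np.array_equal(K2, KX**2 + KY**2))    # True

# Broadcasting the other way round gives the transposed shape
print((kx[:, None]**2 + ky**2).shape)       # (6, 4)
```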
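Nrows = 4
fig, axes = plt.subplots(nrows=Nrows, ncols=Nrows, layout='tight', figsize=(6, 6))
# Pick Nrows*Nrows snapshot indices spread evenly over the whole simulation
inds = np.linspace(0, len(snapshots) - 1, num=Nrows*Nrows).astype(int)
for ax, ind in zip(axes.ravel(), inds):
    ax.imshow(snapshots[ind])
    # Label with the actual simulation time, not the snapshot index
    ax.set_title(f"t={sol.t[ind]:.1f}")
    ax.axis(False)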
Finally, let’s make a movie out of this.
from IPython.display import HTML
import matplotlib.animation as manim
fig, ax = plt.subplots()
vmax = 0.2
img = ax.imshow(snapshots[0], vmin=-vmax, vmax=vmax)
ax.axis(False)
def frame(i):
    img.set_data(snapshots[i])
anim = manim.FuncAnimation(fig, frame, frames=range(1, len(snapshots)))
plt.close(fig)
HTML(anim.to_jshtml(fps=10))
That’s it. I hope this post shows how useful Fourier methods can be.
This post was entirely written using the Jupyter Notebook. Its content is BSD-licensed. You can see a static view or download this notebook with the help of nbviewer at 20250915_spectral_method_wave_equation1d.ipynb.