The Plan

In the first section, you'll learn how to generate impressively cool images like this one, commonly named the Mandelbrot set, both in NumPy and in PyTorch.

In the second section, you'll speed up your image generation by over a factor of 10, learning about broadcasting and the speedups that come from maxing out the GPU.

In the third section, you'll generate mesmerizing videos exploring the very fabric of Julia fractal-space.

mandelbrot

Above: The Mandelbrot set, which you'll be able to generate yourself by the end of Part 1.

Hang on...what is a Julia fractal?

Steven Wittens does an incredible job of showing what a Julia set is in his post How to Fold a Julia Fractal. If you need a refresher on Julia sets, or just want to marvel at Wittens's impressive ability to build intuition through evocative visualizations, I suggest you take a few minutes to check it out now. Then return when you're ready to start coding.

Part 1: From Math to Code

A Julia fractal is computed by iteratively applying a function over a set of points in the complex plane. Mathematically, this looks like:

$$ z_{n+1} = f(z_n, c) $$

where $ f $ is a complex, non-linear function (i.e. a map) and $ z_0 $ and $ c $ are both complex numbers.

We apply this map iteratively to every point in the complex plane. Points whose iterates diverge to infinity are excluded from the Julia set; points whose iterates stay bounded are part of the set.
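To make the iteration concrete, here is a minimal single-point sketch in plain Python. The helper name `escape_step`, the bound of 2, and the cap of 50 iterations are assumptions chosen for illustration; the rest of the post builds the vectorized version of the same idea.

```python
def escape_step(z0, c, max_iter=50, bound=2.0):
    """Return the first iteration at which |z| exceeds `bound`,
    or None if that never happens within `max_iter` iterations."""
    z = z0
    for n in range(max_iter):
        if abs(z) > bound:
            return n          # treated as "diverged"
        z = z * z + c         # one application of the quadratic map
    return None               # treated as "still bounded"

print(escape_step(0j, 0j))        # 0 stays at 0 forever
print(escape_step(1.5 + 0j, 0j))  # 1.5 -> 2.25, which exceeds the bound
print(escape_step(3 + 0j, 0j))    # already outside the bound
```

A point that returns None is only "bounded so far": a larger `max_iter` could still reveal that it escapes.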

Defining the complex plane

Let's get started defining the region of the complex plane we want to plot (xrange and yrange), as well as the resolution of our image.

xrange = yrange = (-2, 2)
resolution = (5, 5) # (x-resolution, y-resolution)

Each point corresponds to one pixel, so in this case, we'll end up with a 5x5 pixel image with 25 points evenly spaced in the grid defined by the real x-axis ranging from -2 to 2, and the imaginary y-axis ranging from -2i to 2i.

This fully specifies our complex plane. Let's create this with NumPy:

import numpy as np

def np_complex_plane(xrange=(-2,2), yrange=None, res=1024):
    '''Return a grid of points on the complex plane.'''
    
    # conveniences to simplify parameters
    if yrange is None: yrange = xrange    # default: y-range = x-range
    if isinstance(res, int): res = (res, res)  # default: x-resolution = y-resolution
    
    # define the x and y axis values on the (not yet) complex plane
    x = np.linspace(xrange[0], xrange[1], res[0])
    # indexed backwards because images are indexed in the opposite direction on the y axis vs. the coordinate plane
    y = np.linspace(yrange[1], yrange[0], res[1])
    
    # get a grid of points corresponding to the x (real) and y (imaginary) axis values
    real_plane, imag_plane = np.meshgrid(x, y)

    # convert points to complex numbers
    cplane = real_plane + 1j*imag_plane
    return cplane

Testing this complex plane generating function, we see that it matches our expectations.

np_complex_plane(xrange, yrange, resolution)
array([[-2.+2.j, -1.+2.j,  0.+2.j,  1.+2.j,  2.+2.j],
       [-2.+1.j, -1.+1.j,  0.+1.j,  1.+1.j,  2.+1.j],
       [-2.+0.j, -1.+0.j,  0.+0.j,  1.+0.j,  2.+0.j],
       [-2.-1.j, -1.-1.j,  0.-1.j,  1.-1.j,  2.-1.j],
       [-2.-2.j, -1.-2.j,  0.-2.j,  1.-2.j,  2.-2.j]])

Converting this from NumPy to PyTorch is straightforward. At the time of writing, PyTorch didn't support complex numbers on CUDA, so we'll have to keep track of the real and imaginary values of the coordinates separately.

import torch

def torch_complex_plane(xrange=(-2,2), yrange=None, res=1000):
    '''Return a 2-tuple of grids corresponding to the real and imaginary points 
    on the complex plane, respectively.'''
    if yrange is None: yrange = xrange
    if isinstance(res, int): res = (res, res)
    
    # np.linspace(...) --> torch.linspace(...).cuda()
    x = torch.linspace(xrange[0], xrange[1], res[0]).cuda()
    y = torch.linspace(yrange[1], yrange[0], res[1]).cuda()
    
    # np.meshgrid --> torch.meshgrid
    real_plane, imag_plane = torch.meshgrid(x,y)
    
    cplane = tuple([real_plane.transpose(0,1), imag_plane.transpose(0,1)])
    
    return cplane

Note the slight difference in the way we represent the complex plane in PyTorch compared to NumPy, above. The first element in the tuple is the real value of the point in the complex plane, the second is the imaginary value.

torch_complex_plane(xrange, yrange, resolution)
(tensor([[-2., -1.,  0.,  1.,  2.],
         [-2., -1.,  0.,  1.,  2.],
         [-2., -1.,  0.,  1.,  2.],
         [-2., -1.,  0.,  1.,  2.],
         [-2., -1.,  0.,  1.,  2.]], device='cuda:0'),
 tensor([[ 2.,  2.,  2.,  2.,  2.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 0.,  0.,  0.,  0.,  0.],
         [-1., -1., -1., -1., -1.],
         [-2., -2., -2., -2., -2.]], device='cuda:0'))

Defining divergence

Now that we have our complex plane, we need to know how to detect whether a point has diverged. This will determine what color we plot it.

Of course, we don't have an infinitely sized numeric data type to determine whether a point has 'reached infinity' yet. On top of this, we don't have an infinite amount of time to find out whether each point exceeds this bound.

One way to approximate this is to count how many iterations it takes before a point exceeds a specified divergence_value, up to a fixed number of iterations, n_iterations.

divergence_value = 2
n_iterations = 50

n_iterations is an especially convenient parameter because it directly manages our tradeoff between time and accuracy. More iterations take longer to compute, but result in higher quality maps.
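Why is 2 a reasonable choice for divergence_value with the quadratic map $ z_{n+1} = z_n^2 + c $? Here is a short sketch of the standard argument, under the assumption that $ |c| \le 2 $ (this holds for the example value of c used later in this section, but not for every corner of a $[-2, 2]$ grid of c-values). Suppose $ |z_n| = 2 + \delta $ for some $ \delta > 0 $. Then, by the reverse triangle inequality:

$$ |z_{n+1}| = |z_n^2 + c| \ge |z_n|^2 - |c| \ge (2 + \delta)^2 - 2 = 2 + 4\delta + \delta^2 $$

So the amount by which the point exceeds 2 grows by at least a factor of 4 on every iteration, and the point can never come back. Once the magnitude passes 2, divergence is guaranteed, which is why counting the steps until that happens is a sensible stand-in for "reaching infinity".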

Choosing a mapping function

Now that we know how to detect divergence, we just need to choose a map. Let's recall our generating function:

$$ z_{n+1} = f(z_n, c) $$

We'll spend most of our time here working with the most popular function for generating Julia fractals, the quadratic map:

$$ f(z) = z^2 + c $$

Two other popular choices are the sine map (implemented at the very end) and cosine map (left as an exercise for the reader):

$$ f(z) = c \cdot \sin(z) $$

$$ f(z) = c \cdot \cos(z) $$

Let's start by implementing the quadratic map in NumPy:

def np_quadratic_method(c, z, n_iterations, divergence_value):
    '''Iteratively apply the quadratic map `z = z^2 + c` `n_iterations` times.
        c: constant complex value (a scalar, or an array with the same shape as `z`)
        z: initial complex values (note: this array is modified in place)
    '''
    # initialize matrix of counters
    stable_steps = np.zeros_like(z, dtype=np.int32)
    
    for i in range(n_iterations):
        # mask keeps track of the points that have not diverged yet
        mask = np.less(np.abs(z), divergence_value)
        
        # increment counter for all points that haven't diverged yet
        stable_steps += mask
        
        # do one iteration of the function
        # (for performance, only update the points that haven't diverged yet)
        if np.ndim(c) > 0:
            # update z when c is a matrix of complex numbers
            z[mask] = z[mask]**2 + c[mask]
        else:
            # update z when c is a single complex number
            z[mask] = z[mask]**2 + c
            
    # normalize values to the range [0,1]
    return stable_steps / n_iterations
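If the mask-and-count pattern above feels opaque, the following toy run may help. It is a sketch on three hand-picked points (the values and the constant `c = 0.1` are assumptions for illustration), using the same three operations as the loop body: build the mask, add it to the counters, and update only the masked points.

```python
import numpy as np

# three starting points: one stable, two already outside the bound of 2
z = np.array([0.5 + 0j, 1.5 + 1.5j, 3 + 0j])
c = 0.1 + 0j
stable_steps = np.zeros(z.shape, dtype=np.int32)

for _ in range(3):
    mask = np.abs(z) < 2          # True where the point has not diverged yet
    stable_steps += mask          # True counts as 1, False as 0
    z[mask] = z[mask]**2 + c      # diverged points are left untouched

print(stable_steps)
print(z)
```

Only the first point keeps accumulating steps; the other two are frozen at their starting values because the mask excludes them from every update.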

Finally, to visually verify our model, let's choose a value for c for which the image is already known. Paul Bourke lists some particularly interesting ones on his website. I like $ c = -0.54 + 0.54i $, which, according to Bourke's website, looks like this:

dragon-fractal

c = -0.54 + 0.54j

z_init = np_complex_plane(xrange, yrange, resolution)

img = np_quadratic_method(c, z_init, n_iterations, divergence_value)
img
array([[0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.04, 0.06, 0.02, 0.  ],
       [0.  , 0.08, 1.  , 0.08, 0.  ],
       [0.  , 0.02, 0.06, 0.04, 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  ]])

Plotting a Julia set

This is difficult to inspect by eye. Let's write a function to plot these values. By experimentation, I've found that plotting the negative log of the returned values usually gives a nice contrast of values. eps controls the sensitivity to values very close to zero.
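To get a feel for what the negative log does to the normalized counts, here is a small sketch using only the standard library. The helper name `neg_log_transform` is hypothetical; it applies the same formula, `-log(value + eps)`, to a single number.

```python
import math

def neg_log_transform(value, eps=0.1):
    """Map a normalized count in [0, 1] to its plotted intensity."""
    return -math.log(value + eps)

for value in (0.0, 0.02, 0.5, 1.0):
    print(f"{value:4.2f} -> {neg_log_transform(value):6.3f}")
```

With `eps=0.1`, points that escape immediately land near 2.3 and points that never escape land slightly below 0. Most of that range is spent on the small counts, which is where the fine structure near the boundary of the set lives. A smaller `eps` stretches the low end even further.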

import matplotlib.pyplot as plt

def np_plot_julia(julia_img, sz=16, eps=.1):
    img = -np.log(julia_img + eps)
    plt.figure(figsize=(sz,sz))
    plt.imshow(img)
    plt.show()
np_plot_julia(img, sz=6)

Hmm...looks like we don't have enough resolution to get a clear picture.

resolution = 64

c = -0.54 + 0.54j
z_init = np_complex_plane(xrange, yrange, resolution)
img = np_quadratic_method(c, z_init, n_iterations, divergence_value)
np_plot_julia(img, sz=6)

This is looking promising! Let's increase the resolution even further, and make the image bigger for a better comparison.

resolution = 1024

c = -0.54 + 0.54j
z_init = np_complex_plane(xrange, yrange, resolution)
img = np_quadratic_method(c, z_init, n_iterations, divergence_value)
np_plot_julia(img, sz=12)

We're starting to get very close to the original, but we're lacking in detail. Let's increase how many iterations we do from 50 to 200 iterations.

n_iterations = 200

c = -0.54 + 0.54j
z_init = np_complex_plane(xrange, yrange, resolution)
img = np_quadratic_method(c, z_init, n_iterations, divergence_value)
np_plot_julia(img, sz=12)

Looks like the same fractal to me!

And along the way, we've seen how to improve the quality of our generated images by increasing the image resolution and/or increasing the number of iterations.

Accelerating the quadratic map

Okay, now that it looks like we've got the NumPy implementation correct, let's reimplement the quadratic map in PyTorch.

def torch_complex_magnitude(r,i):
    '''returns the magnitude of a complex tensor, given a real component, `r`, and an imaginary component, `i`'''
    return torch.sqrt(r**2+i**2)

def torch_quadratic_method(c, z, n_iterations, divergence_value):
    '''Iteratively apply the quadratic map `z = z^2 + c` for `n_iterations` times
        c: tuple of the real and imaginary components of the constant value
        z: tuple of the real and imaginary components of the initial z-value
    '''
    # np.zeros_like(...) --> torch.zeros_like(...).cuda()
    stable_steps = torch.zeros_like(z[0]).cuda()
    
    for i in range(n_iterations):
        # NumPy computed complex magnitudes for us; in PyTorch we implement this ourselves
        mask = torch.lt(torch_complex_magnitude(*z), divergence_value)
        stable_steps += mask.to(torch.float32)
        
        # likewise, we manually implement one iteration of the quadratic map
        z = (z[0]**2-z[1]**2 + c[0], # real
             2*z[0]*z[1] + c[1])     # imaginary
        
    # don't forget to put the array onto the cpu for plotting!
    return (stable_steps / n_iterations).cpu()
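The real and imaginary update lines come from expanding the square: $ (a + bi)^2 = (a^2 - b^2) + 2abi $. As a sanity check, the sketch below applies the same arithmetic to separate real and imaginary arrays and compares it with NumPy's built-in complex type. It uses NumPy so it runs without a GPU; the helper name `quadratic_step` and the sample values are assumptions for illustration.

```python
import numpy as np

def quadratic_step(zr, zi, cr, ci):
    """One step of z -> z**2 + c on separate real/imaginary arrays."""
    return zr**2 - zi**2 + cr, 2 * zr * zi + ci

zr = np.array([0.5, -1.0, 0.25])
zi = np.array([0.5, 2.0, -0.75])
cr, ci = -0.54, 0.54

new_r, new_i = quadratic_step(zr, zi, cr, ci)

# reference result using NumPy's complex numbers
expected = (zr + 1j * zi)**2 + complex(cr, ci)
print(np.allclose(new_r, expected.real), np.allclose(new_i, expected.imag))
```

Note that both components must be computed from the old `z` before either is overwritten, which is why the PyTorch version builds a new tuple rather than updating `z[0]` and `z[1]` one after the other.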

It's trivial to convert our plotting function from NumPy to PyTorch:

def torch_plot_julia(julia_img, sz=16, eps=.1):
    # np.log --> torch.log
    img = -torch.log(julia_img + eps)
    plt.figure(figsize=(sz,sz))
    plt.imshow(img)
    plt.show()

Now let's replicate the spiral fractal using our PyTorch code.

resolution = 1024

# As before, note the change from using the built-in complex data type, which NumPy can handle
# to using the convention: [real, imaginary], which we manually handle using PyTorch
c = [-0.54, 0.54]

z_init = torch_complex_plane(xrange, yrange, resolution)
img = torch_quadratic_method(c, z_init, n_iterations, divergence_value)
torch_plot_julia(img, sz=12)

Nice! Looks identical to our NumPy version.

So we used a constant value for c to create this image, but the famous Mandelbrot set is generated by varying c, setting it equal to the complex plane. That is, c varies for each point in the plane, so that c is equal to the coordinate at that point.

We can actually already do this, using the same function we use to initialize the complex plane! Let's test it out:

# Mandelbrot set
c = torch_complex_plane(xrange, yrange, resolution)

z_init = torch_complex_plane(xrange, yrange, resolution)
img = torch_quadratic_method(c, z_init, n_iterations, divergence_value)
torch_plot_julia(img, sz=6)

Although it is perhaps not as striking as the image of the Mandelbrot set on Wikipedia (see below), it looks like we've successfully replicated it!

mandelbrot-set-wikipedia
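One detail worth checking: the Mandelbrot set is usually defined by starting every orbit at $ z_0 = 0 $, whereas the code above starts at $ z_0 = c $. The sketch below (plain Python, with a hypothetical helper `orbit` and an arbitrary sample value of c) suggests why both give the same picture: starting at c is the same orbit, just one step ahead.

```python
def orbit(z0, c, n):
    """Return the first n points of the orbit of z0 under z -> z**2 + c."""
    points, z = [], z0
    for _ in range(n):
        points.append(z)
        z = z * z + c
    return points

c = -0.5 + 0.25j
from_zero = orbit(0j, c, 6)   # 0, c, c**2 + c, ...
from_c = orbit(c, c, 5)       # c, c**2 + c, ...

print(from_zero[1:] == from_c)
```

Since the two orbits coincide after the first step, they diverge (or not) together; the only difference is an offset of one in the iteration count.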

Part 2: Broadcasting - From One to Many

This tiny change illustrates the power of broadcasting. Broadcasting is what makes this so amazingly efficient. If you've never heard of broadcasting, here's a good place to review it. Skip to the section on 'Broadcasting Rules' if you just need a quick refresher. Take all the time you need; I'm not going anywhere.

Good? Good.

Updating the data model

So, what area is a good candidate for speeding up? We can't easily parallelize the calculation of an image any further due to the iterative nature of the calculation. However, we can compute multiple images in parallel.

Right now, we're generating 1 image, in 2 dimensions (x, y). This is defined by our variable, resolution, and is reflected in the shape of our constant, c, and $ z_0 $ , z_init.

resolution, c[0].shape, z_init[0].shape
(1024, torch.Size([1024, 1024]), torch.Size([1024, 1024]))

One way to do this is to vary our complex constant, c, generating multiple images. For example, we can add a new dimension in front, describing how many images to calculate at once, sampling c along the complex plane. I'll refer to this as the grid of images.

I'll demonstrate this directly in PyTorch to get the immediate speed advantages, although it is straightforward to transcribe into NumPy.

# number of images in x and y directions
grid_resolution = 5

# sample c evenly between these values
grid_xrange = grid_yrange = [-2, 2]

# number of pixels in each image
image_resolution = 64

# range of complex valued points in each image
image_xrange = image_yrange = [-2, 2]

# take note of the dimensions here
c_real, c_imag = [ri.reshape(grid_resolution, grid_resolution, 1, 1) 
                  for ri in torch_complex_plane(grid_xrange, grid_yrange, grid_resolution)]

# take note of the dimensions here
z_real, z_imag = [ri.reshape(1, 1, image_resolution, image_resolution)
                  for ri in torch_complex_plane(image_xrange, image_yrange, image_resolution)]

c_real.shape, z_real.shape
(torch.Size([5, 5, 1, 1]), torch.Size([1, 1, 64, 64]))

Now look at the resulting shape when we add c to z:

(c_real + z_real).shape
torch.Size([5, 5, 64, 64])

The way to interpret this is that we're broadcasting the initialization of the complex plane, z, across every individual value of c, effectively representing a 5x5 grid of images, each with dimension 64x64. The image at [0,0,:,:] will represent the top-left-most value of c, which according to grid_xrange and grid_yrange is $ c = -2 + 2i $.

If you don't follow this yet, don't worry! Note your confusion, continue reading, and circle back after seeing what we're going to do with this. This is not easy to follow, but I ask that you please trust me, we're nearly there and it will all make sense very soon.
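A scaled-down example may make the shapes easier to follow. This sketch uses NumPy integers instead of PyTorch tensors (the broadcasting rules are the same), and the sizes and values are made up so that you can read the result off directly: each "image" is a 4x4 block of 0..15, and each c-value is a multiple of 100.

```python
import numpy as np

grid_rows, grid_cols, img_size = 2, 3, 4

# one value of c per image in the grid: shape (2, 3, 1, 1)
c = np.arange(grid_rows * grid_cols).reshape(grid_rows, grid_cols, 1, 1) * 100

# one shared image of z-values: shape (1, 1, 4, 4)
z = np.arange(img_size * img_size).reshape(1, 1, img_size, img_size)

total = c + z
print(total.shape)    # every (c, z) combination gets its own entry
print(total[1, 2])    # the image for the c-value at grid position [1, 2]
```

The axes of length 1 are stretched to match the other operand, so the single image of z-values is paired with every c-value without ever being copied by hand.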

Let's generalize the creation of c and z.

def make_cz_grids(grid_res, img_res, grid_x_rng=(-2,2), grid_y_rng=(-2,2), img_x_rng=(-2,2), img_y_rng=(-2,2)):
    if type(grid_res) == int: grid_res = (grid_res, grid_res)
    if type(img_res ) == int: img_res  = (img_res , img_res )
    c = [x.reshape(grid_res[0],grid_res[1],1,1) for x in torch_complex_plane(grid_x_rng, grid_y_rng, grid_res)]
    z = [x.reshape(1,1,img_res[0],img_res[1]) for x in torch_complex_plane(img_x_rng, img_y_rng, img_res)]
    return c,z
c, z = make_cz_grids(grid_resolution, image_resolution)

(c[0] + z[0]).shape
torch.Size([5, 5, 64, 64])

Now that we've updated our variables, let's update our method. Note that, because of broadcasting, we only need to change one line of code.

Updating the iteration method

I suggest you take a moment to read through this function to convince yourself that this will work.

def torch_quadratic_method(c, z, n_iterations, divergence_value):
    '''Iteratively apply the quadratic map `z = z^2 + c` for `n_iterations` times
        c: tuple of the real and imaginary components of the constant value
        z: tuple of the real and imaginary components of the initial z-value
    '''
    # add c[0] to get the right shape
    stable_steps = torch.zeros_like(c[0] + z[0]).cuda()
    
    for i in range(n_iterations):
        mask = torch.lt(torch_complex_magnitude(*z), divergence_value)
        stable_steps += mask.to(torch.float32)
        z = (z[0]**2-z[1]**2 + c[0], # real
             2*z[0]*z[1] + c[1])     # imaginary
    return (stable_steps / n_iterations).cpu()

This is but a small taste of the power and the beauty of broadcasting. Once we set up an appropriate data model, we get parallelization for free.

Time to test this out.

n_iterations = 50
c, z = make_cz_grids(grid_resolution, image_resolution)
imgs = torch_quadratic_method(c, z, n_iterations, divergence_value)
imgs.shape
torch.Size([5, 5, 64, 64])

This is the output shape we expected. Now let's plot these images!

Plotting a grid of Julia sets

def torch_plot_julia_grid(images, figsize=(12,12)):
    rows, cols = images.shape[:2]
    fig, axs = plt.subplots(rows, cols, figsize=figsize)
    for i in range(rows):
        for j in range(cols):
            axs[i,j].imshow(images[i,j])
            axs[i,j].get_xaxis().set_visible(False)
            axs[i,j].get_yaxis().set_visible(False)
    plt.show()
torch_plot_julia_grid(imgs)

This is a start, but it's not quite clear yet whether we've actually done it right. Let's try increasing the resolution of our grid and zooming in on the appropriate region.

# this is the main region where the Mandelbrot set exists
#   and also happens to be where the Julia sets are connected
grid_x_rng = (-1.5, 0.5)
grid_y_rng = (-1, 1)
grid_resolution = 21
image_resolution = 512

c, z = make_cz_grids(grid_resolution, image_resolution, grid_x_rng=grid_x_rng, grid_y_rng=grid_y_rng)
imgs = torch_quadratic_method(c, z, n_iterations, divergence_value)
torch_plot_julia_grid(imgs, figsize=(16,16))

This looks pretty good!

And what does one of these images look like?

x = 15
y = 2

print(f'c = {c[0][y][x][0][0]:0.3f} + {c[1][y][x][0][0]:0.3f}i')
torch_plot_julia(imgs[y, x], sz=8)
c = 0.000 + 0.800i

This particular Julia set ($ c = 0 + 0.8i $) is also on Bourke's website. Let's compare:

julia-c0plus8i

Looks like I probably mixed up my axes somewhere, causing a problem with symmetry. Maybe I'll get to fixing this someday. If you happen to spot it, let me know where I went wrong! (tweet me: @dcato98)

Now that you've seen what we can do with broadcasting, this is a good time to go back to the beginning of the section and review the changes necessary to make this work.

Finally, I promised you we'd make a video to visualize the Julia fractal in another way. Let's get to work.

Part 3: Visualizing the Julia set as a video

To create the video we'll be using ffmpeg. If you want to follow along with this section, make sure you have it installed. I'd suggest a link, but installation varies across platforms so I'll leave you to search for instructions relevant to your situation.

Preparing the video frames

For this video, let's visualize varying c along the real x-axis, holding the imaginary y-axis constant at 0.25i.

First, we'll use the code we've already written to generate our frames. If your GPU runs out of memory at this point, try decreasing either the image resolution or the number of frames.

One way to get around this and create large videos at high quality would be to create and save the images in smaller batches before creating the video (don't forget to name the frames in sequence!).
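As a rough sketch of that batching idea, the two hypothetical helpers below estimate how much memory one stack of frames needs and split the frame indices into consecutive chunks. The 4 bytes per value assumes float32, and the estimate covers a single tensor only; the iteration keeps several tensors of this size alive at once, so the real peak is a multiple of it.

```python
def tensor_mib(n_frames, height, width, bytes_per_value=4):
    """Approximate size in MiB of one (n_frames, height, width) tensor."""
    return n_frames * height * width * bytes_per_value / 2**20

def frame_batches(n_frames, batch_size):
    """Yield (start, stop) index pairs that cover range(n_frames) in order."""
    for start in range(0, n_frames, batch_size):
        yield start, min(start + batch_size, n_frames)

print(tensor_mib(120, 1024, 1024))     # all 120 frames at once
print(tensor_mib(50, 1024, 1024))      # a batch of 50 frames
print(list(frame_batches(120, 50)))    # index ranges for each batch
```

Each `(start, stop)` pair can then select a slice of c-values to send to the GPU, and `start` can offset the frame counter so that the saved file names stay in sequence.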

# vary the c-values along the x-axis
grid_x_rng = (-1.5, 0.5)

# hold the y-axis constant
grid_y_rng = (0.25, 0.25)

# define number of frames
grid_resolution = (120, 1)

# image quality parameters
image_resolution = 1024
n_iterations = 200

c, z = make_cz_grids(grid_resolution, image_resolution, grid_x_rng, grid_y_rng)
imgs = torch_quadratic_method(c, z, n_iterations, divergence_value)

# temporarily reshape our images into a square grid to easily visualize our frames
torch_plot_julia_grid(imgs.reshape(10,12,image_resolution, image_resolution), figsize=(16,16))

Looks good. Now let's smoosh together the grid, treating imgs as a list, not a grid, of images.

imgs = imgs.reshape(-1, image_resolution, image_resolution)
imgs.shape
torch.Size([120, 1024, 1024])

Next let's set up our file paths.

We need a working directory, path; a subfolder in this location for saving the list of images, frame_folder; and a filename for the video, video_fn.

from pathlib import Path

# set this to your preferred working directory
path = Path('/home/dc')

frame_folder = path / 'julia_frames'

# save the video here
# video_fn = path / 'julia_visualization.mp4' # for some reason, I couldn't play .mp4 files in the notebook
video_fn = path / 'julia_visualization.webm' # MUCH slower to convert, also it looks like it might be resolution, try mp4 first

# create the frame folder
frame_folder.mkdir(exist_ok=True, parents=True)
# save images into the frame folder, this can take a while
for i, img in enumerate(imgs):
    fig = plt.figure(figsize=(16,16))
    plt.imshow(img)
    
    # turn off axis ticks
    plt.xticks([])
    plt.yticks([])
    
    # save image to `frame_folder` directory
    plt.savefig(frame_folder / f'julia_{i:08d}.png')
    # prevent plot from displaying in notebook, also reduces memory use
    plt.close(fig) 

Creating a video

# set the frame rate
fps = 30  # fps=frames per second

# note: `Path.cwd()` only reports the current working directory, it never changes it.
# We don't need to change directories here because `image_path_format` is an absolute path.
image_path_format = frame_folder / 'julia_%08d.png'

# create video
#! ffmpeg -r $fps -i $image_path_format -pix_fmt yuv420p -y $video_fn
! ffmpeg -r $fps -i $image_path_format -y $video_fn

# confirm the working directory is unchanged
Path.cwd()
ffmpeg version 4.0 Copyright (c) 2000-2018 the FFmpeg developers
  built with gcc 7.2.0 (crosstool-NG fa8859cb)
  configuration: --prefix=/home/dc/anaconda3/envs/fastai --cc=/opt/conda/conda-bld/ffmpeg_1531088893642/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-shared --enable-static --enable-zlib --enable-pic --enable-gpl --enable-version3 --disable-nonfree --enable-hardcoded-tables --enable-avresample --enable-libfreetype --disable-openssl --disable-gnutls --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --disable-libx264
  libavutil      56. 14.100 / 56. 14.100
  libavcodec     58. 18.100 / 58. 18.100
  libavformat    58. 12.100 / 58. 12.100
  libavdevice    58.  3.100 / 58.  3.100
  libavfilter     7. 16.100 /  7. 16.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  1.100 /  5.  1.100
  libswresample   3.  1.100 /  3.  1.100
  libpostproc    55.  1.100 / 55.  1.100
Input #0, image2, from '/home/dc/julia_frames/julia_%08d.png':
  Duration: 00:00:04.80, start: 0.000000, bitrate: N/A
    Stream #0:0: Video: png, rgba(pc), 1152x1152 [SAR 2834:2834 DAR 1:1], 25 fps, 25 tbr, 25 tbn, 25 tbc
Stream mapping:
  Stream #0:0 -> #0:0 (png (native) -> vp9 (libvpx-vp9))
Press [q] to stop, [?] for help
[libvpx-vp9 @ 0x55bab7fa2fc0] v1.7.0
Output #0, webm, to '/home/dc/julia_visualization.webm':
  Metadata:
    encoder         : Lavf58.12.100
    Stream #0:0: Video: vp9 (libvpx-vp9), yuva420p, 1152x1152 [SAR 1:1 DAR 1:1], q=-1--1, 200 kb/s, 30 fps, 1k tbn, 30 tbc
    Metadata:
      encoder         : Lavc58.18.100 libvpx-vp9
    Side data:
      cpb: bitrate max/min/avg: 0/0/0 buffer size: 0 vbv_delay: -1
frame=  120 fps=1.2 q=0.0 Lsize=     111kB time=00:00:03.96 bitrate= 228.6kbits/s speed=0.0402x    
video:100kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 10.341538%
PosixPath('/home/dc/git/blog/_notebooks')
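The `!` prefix only works inside IPython and Jupyter. If you are running a plain Python script instead, one option is to build the same command as a list and pass it to `subprocess.run`. This is a sketch: the helper name `build_ffmpeg_command` and the relative paths are assumptions, and it uses only the three flags from the command above (`-r` for the frame rate, `-i` for the input pattern, `-y` to overwrite the output).

```python
import subprocess
from pathlib import Path

def build_ffmpeg_command(fps, image_pattern, video_path):
    """Return the ffmpeg argument list equivalent to the notebook command."""
    return ["ffmpeg", "-r", str(fps), "-i", str(image_pattern), "-y", str(video_path)]

frame_folder = Path("julia_frames")   # hypothetical relative folder
cmd = build_ffmpeg_command(30, frame_folder / "julia_%08d.png", "julia_visualization.webm")
print(cmd)

# the frame names written earlier match ffmpeg's %08d pattern
print(f"julia_{7:08d}.png" == "julia_%08d.png" % 7)

# Uncomment to actually run ffmpeg (requires ffmpeg on your PATH):
# subprocess.run(cmd, check=True)
```

Passing a list rather than a single string avoids shell quoting issues if your paths contain spaces.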

Playing our video

This needs no explanation. Sit back and enjoy :)

from ipywidgets import Video

Video.from_file(video_fn)

And there you have it! This is a visualization of the Julia set when $ c = r + 0.25i $ where $ r \in [-1.5, 0.5] $.

What's next?

There are so many fun directions to go from here! Here are a few ideas:

  • Colorization: Try experimenting with different thresholds, multiple colors, etc... A quick google search turns up many amazing examples. Can you replicate these? Can you invent your own?
  • Map: There's nothing inherently special about the quadratic map. It's well studied, but try exploring fractals made from another map.
  • Video: Instead of updating c in a line, consider updating c to travel in a circle or other continuous curve. Or, continuously zoom in on a particularly interesting patch.
  • 3D: Our image grid and video are 2 ways to try to grasp the shape of a 4D object. What would this object look like in 3D (say, by holding the real or imaginary component of c constant)? Then, create a video of the 3D object as you travel along the 4th dimension.
  • Batching: Occasionally when I wrote this, I experimented with large parameters which caused the GPU to run out of memory (e.g. when creating the video) and required the notebook to be restarted. Batching what gets sent to the GPU could prevent this from happening.
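As a starting point for the Video idea, here is a sketch of how the c-values for a circular path could be generated. The helper name `circle_path`, the center, and the radius are all assumptions picked for illustration; each returned value would play the role of c for one frame.

```python
import cmath
import math

def circle_path(center, radius, n_frames):
    """Return n_frames complex values evenly spaced around a circle."""
    return [center + radius * cmath.exp(2j * math.pi * k / n_frames)
            for k in range(n_frames)]

center, radius = -0.5 + 0j, 0.3
c_values = circle_path(center, radius, 120)

print(c_values[0])    # the path starts directly to the right of the center
print(c_values[30])   # a quarter of the way around: directly above the center
```

Because the last value sits one step before the first, a video made from this path loops seamlessly.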

For a bit of inspiration, here's an example of experimenting with colorization and maps: the "Electric Christmas Trees" fractal (see image below), generated using the sine map!

First, I've refactored our code a bit to make our Julia generator and plotting functions more generic:

# implement the quadratic map for complex numbers (the Sine map is implemented further below)
def torch_quadratic_map(c, z): return (z[0]**2-z[1]**2 + c[0], 2*z[0]*z[1] + c[1])

# modify our quadratic julia method to generalize to any mapping function
def torch_julia_generator(c=None, z=None, n_iterations=50, divergence_value=50, map_func=torch_quadratic_map):
    '''A generic julia fractal generator (defaults produce the Mandelbrot set).'''
    if c is None: c = torch_complex_plane()
    if z is None: z = torch_complex_plane()
    stable_steps = torch.zeros_like(c[0]+z[0]).cuda()
    for i in range(n_iterations):
        # note: unlike before, divergence is tested on the magnitude of the imaginary component only
        mask = torch.lt(z[1].abs(), divergence_value)
        stable_steps += mask.to(torch.float32)
        z = map_func(c, z)
    return (stable_steps / n_iterations).cpu()

# add `transform_func` to our standard plotting function to accept arbitrary image transformations
def torch_plot_julia(julia_img, sz=16, eps=.1, transform_func=None):
    img = -torch.log(julia_img + eps) if transform_func is None else transform_func(julia_img)
    plt.figure(figsize=(sz,sz))
    plt.imshow(img)
    plt.show()

Testing the default Julia generator:

torch_plot_julia(torch_julia_generator())

Now, we implement the Sine map. Look how easy this is with our newly refactored code!

# implementing the Sine map in PyTorch
def torch_complex_mult(ar,ai,br,bi): return (ar*br-ai*bi, ar*bi+ai*br)
def torch_complex_sin(r,i): return (torch.sin(r)*torch.cosh(i), torch.sinh(i)*torch.cos(r))
def torch_sin_map(c, z): return torch_complex_mult(*torch_complex_sin(*z), *c)
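The sine implementation relies on the identity $ \sin(a + bi) = \sin(a)\cosh(b) + i\cos(a)\sinh(b) $. The sketch below re-implements the same three steps on plain Python floats (the `_parts` and `_pair` names are hypothetical stand-ins for the tensor versions) and checks them against the standard library's `cmath.sin`.

```python
import cmath
import math

def complex_sin_parts(r, i):
    """Real and imaginary parts of sin(r + i*1j)."""
    return math.sin(r) * math.cosh(i), math.cos(r) * math.sinh(i)

def complex_mult_parts(ar, ai, br, bi):
    """Real and imaginary parts of (ar + ai*1j) * (br + bi*1j)."""
    return ar * br - ai * bi, ar * bi + ai * br

def sin_map_pair(c, z):
    """The sine map c * sin(z) on (real, imag) pairs."""
    return complex_mult_parts(*complex_sin_parts(*z), *c)

c, z = (1.0, 0.1), (0.7, -1.3)
got = complex(*sin_map_pair(c, z))
expected = complex(*c) * cmath.sin(complex(*z))
print(abs(got - expected) < 1e-12)
```

Because `cosh` and `sinh` grow exponentially in the imaginary part, orbits under this map blow up very quickly once they drift away from the real axis.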

After exploring the Sine map, this is one of my favorites!

# parameters for the Electric Christmas Trees fractal
map_func = torch_sin_map
c = [1.0, 0.1]
xrange = yrange = (-5,5)
print('Electric Christmas Trees! (Sine map: c = 1.0 + 0.1i)')

# image quality parameters
resolution = 1024
n_iterations = 256
figsize = 16
transform_func = lambda x: x   # don't transform the output (the default is -log(julia_img + eps))

# putting it all together!
z = torch_complex_plane(xrange, yrange, resolution)
img = torch_julia_generator(c, z, n_iterations, divergence_value, map_func=map_func)
torch_plot_julia(img, sz=figsize, transform_func=transform_func)
Electric Christmas Trees! (Sine map: c = 1.0 + 0.1i)

I'm excited to see where you take this. Tweet me at @dcato98 with a pic/video of your amazing creations!