mandelbrot.py (Source)

"""
A little Mandelbrot plotter.
"""
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
# Record type with fields ``z`` and ``n``. It is not referenced in the code
# shown here; presumably it pairs an iterate with its step count (TODO confirm).
IntermediateResult = namedtuple("IntermediateResult", "z n")
class Mandelbrot:
    """Iterate ``z -> z**2 + c`` for one point or an array of points and plot it.

    The instance keeps running state: ``self.z`` holds the current iterate
    (starting at zero) and ``self.count`` the number of iterations each point
    has taken while still within the cutoff. ``c`` may be a scalar or a numpy
    array; 2-D grids built by ``from_limits`` (rows = imaginary part,
    columns = real part) can be drawn with ``imshow``.
    """

    def __init__(self, c, *, ax=None):
        """Store the parameter(s) ``c`` and zero the iteration state.

        Args:
            c: complex parameter, either a scalar or a numpy array.
            ax: matplotlib Axes used by ``imshow``. If ``None``, the current
                axes are looked up the first time ``imshow`` is called.
        """
        self.ax = ax
        self.image = None  # most recent image created by imshow()
        self.scalar = np.isscalar(c)
        self.c = c
        self.z = np.zeros_like(c)  # current iterate, zeros shaped like c
        self.count = np.zeros_like(c, dtype=int)  # iterations taken per point

    def __del__(self):
        """Best-effort removal of the drawn image when the object is collected."""
        # ``image`` may be missing if __init__ never ran to completion.
        image = getattr(self, "image", None)
        if image is None:
            return
        try:
            image.remove()
        except Exception:
            # A finalizer must never raise. The artist may already have been
            # removed, or the interpreter may be shutting down.
            pass

    @classmethod
    def from_axis(cls, ax=None, **kwargs):
        """Build an instance sampling the region currently shown by ``ax``.

        Args:
            ax: matplotlib Axes; defaults to ``plt.gca()``.
            **kwargs: forwarded to ``from_limits`` (e.g. ``px``).
        """
        if ax is None:
            ax = plt.gca()
        xmin, xmax, ymin, ymax = ax.axis()
        return cls.from_limits(xmin, xmax, ymin, ymax, ax=ax, **kwargs)

    @classmethod
    def from_limits(cls, xmin, xmax, ymin, ymax, px=1024*1024, **kwargs):
        """Build an instance on a grid covering ``[xmin, xmax] x [ymin, ymax]``.

        The roughly ``px`` pixels are split between rows and columns in
        proportion to the region's aspect ratio. The grid has shape
        ``(px_y, px_x)``: row ``i`` has imaginary part ``y[i]`` and column
        ``j`` has real part ``x[j]``. The limits may be given in reversed
        order (as an inverted axis reports them); the grid then runs in that
        order. Zero-width or zero-height regions are not supported.

        Args:
            xmin, xmax: real-axis limits.
            ymin, ymax: imaginary-axis limits.
            px: approximate total number of grid points.
            **kwargs: forwarded to the constructor (e.g. ``ax``).
        """
        dx, dy = xmax - xmin, ymax - ymin
        # Only the magnitude of the aspect ratio matters for the pixel split;
        # reversed limits would otherwise produce sqrt(negative) -> NaN.
        ratio = abs(dx / dy)
        # Clamp so extreme aspect ratios still yield at least one row and
        # one column, instead of a zero-size grid or a ZeroDivisionError.
        px_y = max(1, int(np.sqrt(px / ratio)))
        px_x = max(1, px // px_y)
        x = np.linspace(xmin, xmax, px_x)
        y = np.linspace(ymin, ymax, px_y)
        re, im = np.meshgrid(x, y)
        c = re + 1j * im
        return cls(c, **kwargs)

    def advance(self, *, n=1, z=None, c=None):
        """Apply ``z -> z**2 + c`` ``n`` times and return the result.

        ``z`` and ``c`` default to ``self.z`` and ``self.c``. The stored state
        is not modified, and ``n=0`` returns ``z`` unchanged.
        """
        if z is None:
            z = self.z
        if c is None:
            c = self.c
        for _ in range(n):
            z = z**2 + c
        return z

    def divergence_count(self, *, n=1, z=None, cutoff=2):
        """Run up to ``n`` iterations and count steps taken inside ``cutoff``.

        A point stops being iterated (and counted) once ``|z| > cutoff``.
        The loop ends early when every point has escaped.

        Args:
            n: maximum number of iterations to run in this call.
            z: starting iterate. If ``None``, continue from the stored state:
                ``self.z`` is advanced and ``self.count`` accumulates across
                calls. If given, a fresh count is returned and the stored
                state is left alone. Note that an array ``z`` is updated in
                place at its non-escaped entries.
            cutoff: escape radius.

        Returns:
            Integer count(s), shaped like ``self.c``.
        """
        use_state = z is None
        if use_state:
            z = self.z
            count = self.count
        else:
            count = np.zeros_like(self.c, dtype=int)
        if self.scalar:
            for _ in range(n):
                if np.abs(z) > cutoff:
                    break
                count += 1
                z = self.advance(z=z, c=self.c)
            if use_state:
                # A scalar cannot be updated in place, so persist the new
                # iterate explicitly. Otherwise the next call would restart
                # from the old z while self.count keeps growing.
                self.z = z
        else:
            for _ in range(n):
                escape = np.abs(z) > cutoff
                if escape.all():
                    break
                active = ~escape
                count[active] += 1
                # Escaped points are frozen; only the active ones advance.
                z[active] = self.advance(z=z[active], c=self.c[active])
        return count

    def evolution(self, *, n=1, z=None, c=None):
        """Yield ``n`` successive iterates of ``z -> z**2 + c``.

        ``z`` and ``c`` default to the stored state, which is not modified.
        """
        for _ in range(n):
            z = self.advance(z=z, c=c)
            yield z

    def _extent(self):
        """Return ``(left, right, bottom, top)`` for drawing over ``self.c``.

        For a 2-D grid (rows = imaginary part, columns = real part), the corner
        values are used. A grid built from reversed limits therefore keeps its
        orientation when drawn with ``origin='lower'``. Other shapes fall back
        to the min/max of the real and imaginary parts.
        """
        c = np.asarray(self.c)
        if c.ndim == 2:
            origin, row_end, col_end = c[0, 0], c[0, -1], c[-1, 0]
            return (origin.real, row_end.real, origin.imag, col_end.imag)
        return (c.real.min(), c.real.max(), c.imag.min(), c.imag.max())

    def imshow(self, data, **kwargs):
        """Draw ``data`` over the region covered by ``self.c``; return the image.

        Any previously drawn image is removed first, so repeated calls replace
        rather than stack images. ``extent`` defaults to the grid's geometry
        and ``origin`` to ``'lower'``; both can be overridden through
        ``kwargs``. All other ``kwargs`` go to ``Axes.imshow``. Without an
        ``ax``, the current axes (``plt.gca()``) are used and remembered.
        """
        if self.image is not None:
            try:
                self.image.remove()
            except (ValueError, NotImplementedError):
                # The artist is no longer in its axes (e.g. they were cleared)
                # or cannot be removed; there is nothing left to clean up.
                pass
            self.image = None
        if self.ax is None:
            self.ax = plt.gca()
        if "extent" not in kwargs:
            kwargs["extent"] = self._extent()
        kwargs.setdefault("origin", "lower")
        self.image = self.ax.imshow(data, **kwargs)
        return self.image