cfd — Plotting

This module contains the main public function for creating cohort flow diagrams.

Main Function

pycohortflow.cfd.plot_cfd(data, ax=None, save_dir=None, img_name=None, save_format='png', figure_title=None, style='white', style_config_path=None, transparent=False, verbose=False, **kwargs)

Draw a vertical cohort flow diagram.

Each element of data represents one step (node) in the cohort pipeline. The function calculates each step's exclusion count (the drop in N from the preceding node), lays out the boxes and arrows, and applies the colour scheme of the selected style.
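A minimal sketch of how the exclusion counts can be derived, assuming each count is the preceding node's N minus the current node's N. `exclusion_counts` is a hypothetical helper for illustration, not part of pycohortflow:

```python
def exclusion_counts(data: list[dict]) -> list[int]:
    """Return the number excluded between each pair of consecutive nodes.

    Hypothetical helper: assumes each side-box count is the previous
    node's N minus the current node's N.
    """
    counts = []
    # Pair every node with its successor: (node 1, node 2), (node 2, node 3), ...
    for prev, curr in zip(data, data[1:]):
        counts.append(prev["N"] - curr["N"])
    return counts


data = [
    {"heading": "Registered", "N": 350},
    {"heading": "Screened", "N": 150, "exclusion_description": "Not eligible"},
    {"heading": "Analysed", "N": 120, "exclusion_description": "Lost to follow-up"},
]
print(exclusion_counts(data))  # [200, 30]
```

Under this assumption the first node has no predecessor, so n nodes yield n - 1 exclusion counts.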

Parameters:
  • data (list[dict]) –

    Ordered list of cohort nodes. Every dictionary must contain an "N" key (int) giving the number of participants remaining at that step. Optional keys:

    • "heading" (str) – Title shown inside the box (defaults to "Step <i>").

    • "description" (str) – Body text below the title.

    • "exclusion_description" (str) – Label for the side-exclusion box (defaults to "Excluded").

    • "color" (str) – Override colour for this node’s main box (hex string or Matplotlib colour name).

    • "exclusion_color" (str) – Override colour for the exclusion box.

    • "heading_fontweight" (str) – Per-node override for the box heading weight, e.g. "bold" or "normal". Defaults to the heading_fontweight setting in the [text] section of the active style.

  • ax (matplotlib.axes.Axes | None) – An existing Matplotlib axes object to draw on. When provided, the function does not create a new figure; instead, it renders the diagram into the given axes and returns (ax.figure, ax). This is useful for embedding the flow chart in a subplot layout. When None (the default), a new figure and axes are created automatically.

  • save_dir (str | os.PathLike | None) – Directory for saved images. Only used when img_name is also provided.

  • img_name (str | None) – Base file name (without extension) under which the figure is saved. When None, the figure is not written to disk.

  • save_format (str | list[str]) – Image format(s). Defaults to "png". Pass a list for multiple formats, e.g. ["png", "svg"].

  • figure_title (str | None) – Optional title rendered above the diagram. Applied as the axes title via set_title().

  • style (str) – Name of the built-in style to use. "white" (default) produces boxes with no background colour; "colorful" applies pastel gradients; "minimal" renders white boxes with normal-weight headings and italic side text instead of exclusion boxes. See Customise for details.

  • style_config_path (str | os.PathLike | None) – Path to a custom TOML file that selectively overrides the chosen built-in style. See Customise for details.

  • transparent (bool) – If True, the figure and axes backgrounds are set to transparent. Useful for embedding the diagram in slides or posters. Defaults to False.

  • verbose (bool) – If True, print a line of the form "Saved: <path>" to stdout for every file written when img_name is provided. Defaults to False (silent).

  • **kwargs

    Ad-hoc overrides. Currently recognised keys:

    • dpi (int) – Figure resolution (ignored when ax is provided).

    • figsize (tuple[float, float]) – (width, height) in inches (ignored when ax is provided).

    • main_palette (list[str]) – Explicit list of hex colours for main boxes.

    • exclusion_palette (list[str]) – Explicit list of hex colours for exclusion boxes.
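How the per-node keys, their documented defaults, and the palette keyword arguments might interact can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `resolve_nodes` is hypothetical, `<i>` in the default heading is assumed to be the 1-based step number, a per-node "color"/"exclusion_color" is assumed to take precedence over the palettes, and palettes shorter than the data are assumed to cycle.

```python
from typing import Optional


def resolve_nodes(
    data: list[dict],
    main_palette: Optional[list[str]] = None,
    exclusion_palette: Optional[list[str]] = None,
) -> list[dict]:
    """Fill documented defaults and pick colours for every node (hypothetical)."""
    resolved = []
    for i, node in enumerate(data, start=1):
        # Assumption: palettes cycle when they are shorter than the node list.
        main_default = main_palette[(i - 1) % len(main_palette)] if main_palette else None
        excl_default = (
            exclusion_palette[(i - 1) % len(exclusion_palette)] if exclusion_palette else None
        )
        resolved.append(
            {
                "N": node["N"],
                # Documented defaults: "Step <i>" and "Excluded".
                "heading": node.get("heading", f"Step {i}"),
                # Assumption: a missing description means no body text.
                "description": node.get("description", ""),
                "exclusion_description": node.get("exclusion_description", "Excluded"),
                # Per-node overrides win over palette entries; None stands for
                # "fall back to the selected style" in this sketch.
                "color": node.get("color", main_default),
                "exclusion_color": node.get("exclusion_color", excl_default),
            }
        )
    return resolved


nodes = resolve_nodes(
    [{"N": 350}, {"N": 150, "color": "tab:red"}, {"N": 120}],
    main_palette=["#dbe9f6", "#b3d3ea"],
)
for node in nodes:
    print(node["heading"], node["color"], node["exclusion_description"])
```

The colour values here are arbitrary examples; any hex string or Matplotlib colour name would do.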

Returns:

The Matplotlib figure and axes objects so that callers can further customise the plot. When ax is provided the returned figure is ax.figure.

Return type:

tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

Raises:

ValueError – If data is empty, a node has a higher N than the preceding node, or style is not recognised.
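The three documented error conditions can be checked up front. A minimal sketch, assuming the built-in styles are exactly the three named in the style parameter ("white", "colorful", "minimal") and that equal consecutive N values are allowed, since only a higher N is documented as an error. `validate_cfd_input` is a hypothetical helper:

```python
BUILT_IN_STYLES = {"white", "colorful", "minimal"}


def validate_cfd_input(data: list[dict], style: str = "white") -> None:
    """Raise ValueError for the documented error conditions (hypothetical helper)."""
    if not data:
        raise ValueError("data must contain at least one node")
    # Each node's N may equal, but must not exceed, the preceding node's N.
    for i in range(1, len(data)):
        prev_n, curr_n = data[i - 1]["N"], data[i]["N"]
        if curr_n > prev_n:
            raise ValueError(
                f"node {i + 1} has N={curr_n}, higher than the preceding N={prev_n}"
            )
    if style not in BUILT_IN_STYLES:
        raise ValueError(
            f"unknown style {style!r}; expected one of {sorted(BUILT_IN_STYLES)}"
        )


try:
    validate_cfd_input([{"N": 100}, {"N": 120}])
except ValueError as err:
    print(err)
```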

Example

>>> from pycohortflow import plot_cfd
>>> data = [
...     {"heading": "Registered", "N": 350},
...     {"heading": "Screened", "N": 150,
...      "exclusion_description": "Not eligible"},
...     {"heading": "Analysed", "N": 120,
...      "exclusion_description": "Lost to follow-up"},
... ]
>>> fig, ax = plot_cfd(data, figure_title="Study")

Drawing into an existing axes (e.g. a subplot):

>>> import matplotlib.pyplot as plt
>>> fig, axes = plt.subplots(1, 2, figsize=(20, 8))
>>> _ = plot_cfd(data, ax=axes[0], figure_title="Left")
>>> _ = plot_cfd(data, ax=axes[1], style="colorful")
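The saving parameters (save_dir, img_name, save_format, verbose) are requested with a call such as plot_cfd(data, save_dir="figures", img_name="cohort", save_format=["png", "svg"], verbose=True). The sketch below shows how they plausibly combine into output paths. `output_paths` is hypothetical, and both the naming scheme `<save_dir>/<img_name>.<format>` and the fallback to the current directory when save_dir is None are assumptions, not documented behaviour.

```python
from os import PathLike
from pathlib import Path
from typing import Optional, Union


def output_paths(
    save_dir: Optional[Union[str, PathLike]] = None,
    img_name: Optional[str] = None,
    save_format: Union[str, list[str]] = "png",
) -> list[Path]:
    """List the files a plot_cfd-style call might write (hypothetical helper)."""
    if img_name is None:
        # Documented: without img_name nothing is written and save_dir is ignored.
        return []
    # A single format string and a list of formats are both accepted.
    formats = [save_format] if isinstance(save_format, str) else list(save_format)
    # Assumption: save_dir=None means the current working directory.
    base = Path(save_dir) if save_dir is not None else Path(".")
    return [base / f"{img_name}.{fmt}" for fmt in formats]


for path in output_paths("figures", "cohort", ["png", "svg"]):
    # Mirrors the documented verbose=True message format "Saved: <path>".
    print(f"Saved: {path}")
```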