Demo: From a script towards a workflow
In this episode we will explore code quality and good practices in Python using a hands-on approach. Together, we will build up a small project and improve it step by step.
We will start from a relatively simple image processing script that reads a telescope image of stars; our goal is to count the number of stars in the image. Later, we will want to process many such images.
The (fictional) telescope images look like the one below (more can be found in this repository):
Rough plan for this demo
(15 min) Discuss how we would solve the problem, then run the example code and make it work (as part of a Jupyter notebook)
(15 min) Refactor the position-finding code into a function and a module
(15 min) Now we wish to process many images: discuss how we would approach this
(15 min) Introduce a command-line interface (CLI) and discuss the benefits
(30 min) From a script to a workflow (using Snakemake)
Starting point (spoiler alert)
We can imagine that we pieced together the following code based on some examples we found online:
import matplotlib.pyplot as plt
from skimage import io, filters, color
from skimage.measure import label, regionprops
image = io.imread("stars.png")
sigma = 0.5
# if there is a fourth channel (alpha channel), ignore it
rgb_image = image[:, :, :3]
gray_image = color.rgb2gray(rgb_image)
# apply a gaussian filter to reduce noise
image_smooth = filters.gaussian(gray_image, sigma)
# threshold the image to create a binary image (bright stars will be white, background black)
thresh = filters.threshold_otsu(image_smooth)
binary_image = image_smooth > thresh
# label connected regions (stars) in the binary image
labeled_image = label(binary_image)
# get properties of labeled regions
regions = regionprops(labeled_image)
# extract star positions (centroids)
star_positions = [region.centroid for region in regions]
# plot the original image
plt.figure(figsize=(8, 8))
plt.imshow(image, cmap="gray")
# overlay star positions with crosses
for star in star_positions:
    plt.plot(star[1], star[0], "rx", markersize=5, markeredgewidth=0.1)
plt.savefig("detected-stars.png", dpi=300)
print(f"number of stars detected: {len(star_positions)}")
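The two calls that do the actual counting are label and regionprops. To make their role concrete, here is a minimal pure-Python sketch of what they compute on a tiny hand-made binary grid: touching bright pixels are grouped into regions (using 8-connectivity, which to our understanding matches scikit-image's default for 2D images), and the mean (row, column) of each region is its centroid. The names label_regions and centroid are hypothetical; scikit-image's own implementation is different and much faster.

```python
from collections import deque


def label_regions(binary):
    """Group 8-connected True pixels into regions (illustrative sketch).

    Returns a list of regions; each region is a list of (row, col) pixels.
    """
    n_rows, n_cols = len(binary), len(binary[0])
    seen = [[False] * n_cols for _ in range(n_rows)]
    # all eight neighbours (horizontal, vertical and diagonal)
    offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]
    regions = []
    for r in range(n_rows):
        for c in range(n_cols):
            if not binary[r][c] or seen[r][c]:
                continue
            # breadth-first flood fill starting from an unvisited bright pixel
            queue = deque([(r, c)])
            seen[r][c] = True
            pixels = []
            while queue:
                y, x = queue.popleft()
                pixels.append((y, x))
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < n_rows and 0 <= nx < n_cols and binary[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            regions.append(pixels)
    return regions


def centroid(pixels):
    """Mean (row, col) of a region, the same order regionprops uses."""
    n = len(pixels)
    return (sum(y for y, _ in pixels) / n, sum(x for _, x in pixels) / n)


# 1 = bright pixel after thresholding, 0 = background
grid = [
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
]
binary = [[value == 1 for value in row] for row in grid]
regions = label_regions(binary)
print(len(regions))  # 3
print([centroid(region) for region in regions])  # [(0.5, 1.5), (2.5, 4.0), (3.0, 0.0)]
```

Each region counts as one star, so len(regions) plays the role of len(star_positions) in the script. Centroids come back as (row, column), which is why the script plots star[1] against star[0].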
Plan
Topics we wish to show and discuss:
Naming (and other) conventions, project organization, modularity
The value of pure functions and immutability
Refactoring (explained through examples)
Auto-formatting and linting with tools such as black, vulture, and ruff
Moving a project under Git
How to document dependencies
Structuring larger software projects in a modular way
Command-line interfaces
Workflows with Snakemake
We will work together on the code on the big screen, and participants will be encouraged to give suggestions and ask questions. We will end up with a Git repository which will be shared with workshop participants.
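As a small preview of the topic of pure functions: the starting script mixes computing the result with writing it out. The sketch below separates the pure part, which is trivial to test, from the one function that has a side effect. The helpers format_report and write_report are hypothetical and not part of the solution that follows.

```python
def format_report(star_positions):
    """Pure function: same input, same output, no side effects."""
    return f"number of stars detected: {len(star_positions)}\n"


def write_report(text, file_name):
    """The side effect is isolated here: the only function that touches the disk."""
    with open(file_name, "w") as f:
        f.write(text)


# centroids stored as an immutable tuple of tuples, so no function can modify them in place
positions = ((0.5, 1.5), (2.5, 4.0), (3.0, 0.0))
report = format_report(positions)
print(report, end="")  # number of stars detected: 3
```

Because format_report neither reads nor writes files, it can be checked with a single assertion, while write_report stays too simple to need much testing.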
Possible solutions
Script after some work, with a command-line interface (spoiler alert)
This is one possible solution (countstars.py):
from pathlib import Path

import click
import matplotlib.pyplot as plt
from skimage import io, filters, color
from skimage.measure import label, regionprops


def convert_to_gray(image):
    # if there is a fourth channel (alpha channel), ignore it
    rgb_image = image[:, :, :3]
    return color.rgb2gray(rgb_image)


def locate_positions(image):
    gray_image = convert_to_gray(image)

    # apply a gaussian filter to reduce noise
    image_smooth = filters.gaussian(gray_image, sigma=0.5)

    # threshold the image to create a binary image (bright objects will be white, background black)
    thresh = filters.threshold_otsu(image_smooth)
    binary_image = image_smooth > thresh

    # label connected regions in the binary image
    labeled_image = label(binary_image)

    # get properties of labeled regions
    regions = regionprops(labeled_image)

    # extract positions (centroids), each as a (row, column) tuple
    positions = [region.centroid for region in regions]
    return positions


def plot_positions(image, positions, file_name):
    # plot the original image
    plt.figure(figsize=(8, 8))
    plt.imshow(image, cmap="gray")

    # overlay positions with crosses; centroids are (row, column) = (y, x),
    # but plt.plot expects (x, y)
    for y, x in positions:
        plt.plot(x, y, "rx", markersize=5, markeredgewidth=0.1)

    plt.savefig(file_name, dpi=300)
    plt.close()  # free memory when processing many images


@click.command()
@click.option(
    "--image-file",
    type=click.Path(exists=True),
    required=True,
    help="Path to the input image",
)
@click.option(
    "--output-file", type=click.Path(), required=True, help="Path to the output file"
)
@click.option("--generate-plot", is_flag=True, default=False)
def main(image_file, output_file, generate_plot):
    image = io.imread(image_file)
    star_positions = locate_positions(image)

    if generate_plot:
        # save e.g. "detected-stars-1.png" in the current directory; prefixing the
        # full input path would point into a directory that does not exist
        plot_positions(image, star_positions, f"detected-{Path(image_file).name}")

    with open(output_file, "w") as f:
        f.write(f"number of stars detected: {len(star_positions)}\n")


if __name__ == "__main__":
    main()
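The solution above uses click. For comparison, the same three options could be declared with argparse from the standard library; this is our substitution, not part of the original solution. Unlike click.Path(exists=True), this sketch does not check that the input file exists, and parse_args is a hypothetical helper name. Passing an explicit argument list makes the parser easy to test without a shell.

```python
import argparse


def parse_args(argv=None):
    """Parse the same options as countstars.py (argparse sketch)."""
    parser = argparse.ArgumentParser(description="Count stars in a telescope image.")
    parser.add_argument("--image-file", required=True, help="Path to the input image")
    parser.add_argument("--output-file", required=True, help="Path to the output file")
    parser.add_argument(
        "--generate-plot", action="store_true", help="Also save a plot of the detected positions"
    )
    # argv=None means "read sys.argv", as a real script would
    return parser.parse_args(argv)


# dashes in option names become underscores in attribute names
args = parse_args(["--image-file", "stars-1.png", "--output-file", "results/1.txt", "--generate-plot"])
print(args.image_file, args.output_file, args.generate_plot)  # stars-1.png results/1.txt True
```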
Snakemake rules which define a workflow (spoiler alert)
This is one possible solution (snakefile):
# glob_wildcards returns a named tuple with one list per wildcard;
# the trailing comma unpacks its single field into "numbers"
numbers, = glob_wildcards("input-images/stars-{number}.png")


# rule that collects the target files
rule all:
    input:
        expand("results/{number}.txt", number=numbers)


rule process_data:
    input:
        "input-images/stars-{number}.png"
    output:
        "results/{number}.txt"
    log:
        "logs/{number}.txt"
    shell:
        """
        python countstars.py --image-file {input} --output-file {output} 2> {log}
        """
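To see how these two rules fit together, here is a rough pure-Python emulation (not Snakemake's implementation) of three steps: glob_wildcards extracts the {number} values from existing file names, expand turns them into target files for rule all, and Snakemake then matches each target against the output pattern of process_data to work out which input file it needs. The sketch works on a list of path strings instead of the real file system, and the helper names are hypothetical.

```python
import re


def find_numbers(paths):
    """Rough stand-in for glob_wildcards("input-images/stars-{number}.png")."""
    pattern = re.compile(r"input-images/stars-(?P<number>.+)\.png")
    numbers = []
    for path in paths:
        match = pattern.fullmatch(path)
        if match:
            numbers.append(match.group("number"))
    return numbers


def expand_targets(numbers):
    """Rough stand-in for expand("results/{number}.txt", number=numbers)."""
    return [f"results/{number}.txt" for number in numbers]


def input_for(target):
    """Infer the input of process_data from a requested output file."""
    match = re.fullmatch(r"results/(?P<number>.+)\.txt", target)
    return f"input-images/stars-{match.group('number')}.png"


paths = ["input-images/stars-1.png", "input-images/stars-2.png", "input-images/README.txt"]
numbers = find_numbers(paths)
targets = expand_targets(numbers)
print(numbers)  # ['1', '2']
print(targets)  # ['results/1.txt', 'results/2.txt']
print([input_for(t) for t in targets])  # ['input-images/stars-1.png', 'input-images/stars-2.png']
```

Rule all only lists what should exist in the end; the connection to process_data comes from matching file name patterns, not from calling the rule explicitly.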
We can process as many images as we like by running:
$ snakemake --cores 4 # adjust to the number of available cores
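One advantage over a hand-written loop over all images is that Snakemake skips jobs whose outputs are already up to date, so rerunning the command after adding a few new images only processes those. The basic idea can be sketched by comparing file modification times. This is a simplification: recent Snakemake versions can also rerun jobs for other reasons, such as changed code or parameters, and needs_update is a hypothetical name.

```python
import os
import tempfile


def needs_update(input_path, output_path):
    """Simplified rerun check: output is missing or older than its input."""
    if not os.path.exists(output_path):
        return True
    return os.path.getmtime(output_path) < os.path.getmtime(input_path)


with tempfile.TemporaryDirectory() as tmp:
    image = os.path.join(tmp, "stars-1.png")
    result = os.path.join(tmp, "1.txt")
    open(image, "w").close()

    missing_output = needs_update(image, result)  # True: result does not exist yet

    open(result, "w").close()
    os.utime(image, (1000, 1000))  # set (access, modification) times explicitly
    os.utime(result, (2000, 2000))
    up_to_date = needs_update(image, result)  # False: result is newer than image

    os.utime(image, (3000, 3000))  # pretend the image was replaced afterwards
    stale = needs_update(image, result)  # True: image is now newer than result

print(missing_output, up_to_date, stale)  # True False True
```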