Aravind Ramalingam

  • Python
  • Functions
  • Type Dispatch

Write Better Python Functions (Using Type Dispatch)

Writing a function in Python is easy, but writing a good one is not. Many Python programmers are not aware of how to write clear and concise functions, let alone how to use Type Dispatch. Let us see how we can write better functions in Python using Type Dispatch.

Getting Started

Before we get into what Type Dispatch is, or how it can make you a better Python programmer, let us get some basics out of the way, such as: what is a good Python function? To be honest, I would rather let Jeff Knupp explain it.

I truly believe that the only way to really understand how something works is by breaking it into granular pieces and putting it back together, or by building it from scratch. Let us take a very well-defined use case and write a function for it from scratch. By scratch, I mean starting with a one-to-one translation of the pseudo-code into Python code without thinking about any best practices.

Then by analyzing what works & what didn't, we can build on top of this. We are going to repeat this process until we run out of ideas to improve the code. So this post is going to be not one, but many versions of a function for the same use case.

If you are wondering why: because we are going to unlearn everything we know about writing a function and question each line of code as we write it. In this way, we can clearly see how each piece fits. Rather than blindly following a checklist of items about writing a function, we will create our own checklist. In case this post missed something of value, please let me know and I will gladly update it.

We will try to write the best, or at least my best, Python function for a particular use case in 4 iterations. Each iteration is going to be an improvement over the last one, with a clear objective declared upfront. Getting better at anything is an iterative process. Depending on where you are in the spectrum of Python proficiency, one of the 4 iterations is going to resonate with you more than the others. The next step of this post is understanding the use case. Buckle up, it's going to be a very interesting journey.


Use Case

Briefing

Convert the given image input into a PyTorch tensor. Yes, that’s it. Quite simple, right?

Folks with a data science background can skip this part and jump to the next section, as I am going to give a brief introduction about python libraries needed for this use case.

  1. Pillow is a library that adds support for opening, manipulating, and saving many different image file formats.
  2. NumPy stands for Numerical Python. NumPy is the core library for scientific computing in Python. It provides a high-performance multidimensional array object and tools for working with these arrays.
  3. PyTorch is an open-source library that provides two major features: tensor computing (like NumPy) with strong acceleration via GPUs, and deep neural networks built on a tape-based automatic differentiation system. PyTorch can run the code on the CPU if no CUDA-supported GPU is available, but training an ML model that way will be much slower.

If you are wondering why, or to whom, this use case is useful: most machine learning applications that take image data as input (face detection, OCR) use these libraries. This use case is considered data preprocessing in the ML world.

How do we go about this?

The steps involved are very straightforward. Open the image using the Pillow library, convert it to a NumPy array, and transpose it before converting it to a tensor. A gentle reminder: this post is about writing better Python functions and is not a data pre-processing tutorial for deep learning, so if this feels intimidating, don't worry about it. You will feel more comfortable once we start coding.

A couple of points to take note about the process steps before we move on.

  1. When a PIL image is converted to a NumPy array, it takes the shape (Height, Width, # of Channels). In PyTorch, the image tensor must be of the shape (# of Channels, Height, Width), so we are going to transpose the array before converting it to a tensor.
  2. For experienced PyTorch users who are wondering why the hell we are not using torchvision utilities like ToTensor or other transform methods and skipping NumPy: yep, you are right, we can absolutely do that, but to drive the point of this post home better, I have consciously decided to take the longer route.
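
To make the shape change in point 1 concrete, here is a minimal pure-Python sketch of the index mapping that transpose(2, 0, 1) performs on a NumPy array. The helper name hwc_to_chw and the tiny 2 x 2 sample image are made up for illustration; the code in this post uses NumPy for this step.

```python
def hwc_to_chw(pixels):
    """Reorder a nested list from (Height, Width, Channels) to (Channels, Height, Width)."""
    height = len(pixels)
    width = len(pixels[0])
    channels = len(pixels[0][0])
    # out[ch][row][col] == pixels[row][col][ch]
    return [
        [[pixels[row][col][ch] for col in range(width)] for row in range(height)]
        for ch in range(channels)
    ]


# A hypothetical 2 x 2 RGB image: each pixel is [R, G, B]
image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = hwc_to_chw(image)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 2 2
print(chw[0])  # red channel: [[255, 0], [0, 255]]
```

After the reordering, the outermost index selects a colour channel, and each channel is a Height x Width grid.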

Hiccup

Sorry, the use case is not that simple. There is a small hiccup: the function should support multiple input data types, namely str/Path, Pillow (Image) and NumPy (ndarray).

The process steps defined in the above image only support the data types str/Path as input. Therefore we need to add support for the Pillow (Image) and NumPy (ndarray) types. If we think about it, though, these data types are intermediate steps in converting the image file to a torch tensor. So there is really no additional step beyond the above image; we just have to duplicate the steps defined for the str/Path data types and alter a few initial steps to support Pillow (Image) and NumPy (ndarray) as input for our function.

The process steps for the 2 additional workflows are as follows.

After comparing the 3 images it is very clear that:

  • Image data type's process steps are a subset of the original process steps.
  • Array data type's process steps are a subset of the Image data type process steps.

Solution

The first 2 iterations of the function are going to be about the behavior implementation, whereas the last two are all about code refactoring. If you are one of those readers who like to read the ending of a book first (I am not judging), you can jump straight to the section titled Iteration 4.

Iteration 1

Objective:

Make it work. It is as simple as that. As we are starting from scratch, it is best to focus only on getting the basic functionality to work. We don’t need all the bells and whistles. For this iteration, we will focus only on implementing the process steps from the Use Case section and not worry about the Hiccup.

Code:


# Import required libraries
import numpy as np
import torch
from PIL import Image as PILImage

# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Function to convert image to Torch tensor
def Tensor(inp):

    # Read image from disk using PIL Image
    img = PILImage.open(inp).convert('RGB')

    # Convert the image to numpy ndarray
    imgArr = np.asarray(img, np.uint8)

    # Rearrange the shape of the array so that it is pytorch compatible
    imgArr = imgArr.transpose(2, 0, 1)

    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imgArr).to(device)

What have we done?

We managed to translate the process steps into code. The number of lines in the code almost matches the number of boxes in the diagram.

Quick Check:

We need a sample image to test our function so we will download one.

# Download a sample image to disk
import urllib.request   
  
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c3/Python-logo-notext.svg/200px-Python-logo-notext.svg.png"  
filename = "sample.png"  
urllib.request.urlretrieve(url, filename)

Now, let us take our function for a test drive.

# Pass the sample image through our function
imgTensor = Tensor(filename)  
  
 # Check whether the output of the function is a Tensor
assert isinstance(imgTensor, torch.Tensor), "Not a Tensor"

Looks like we managed to nail the basic functionality. Awesome!

What can we improve?

  • We are yet to implement multiple data type input functionality i.e support for data types like PIL Image, Ndarray.
  • Also, we should definitely improve the name of the function; it doesn’t really say anything about what the function does.

Iteration 2

Objective:

When we add support for more than one data type, we need to be very careful, because without proper validation things can go south real quick, and from the end user's point of view the error messages won’t make any sense. So in this iteration we will:

  1. Implement multiple data type support and validations for these types.
  2. Make the function name more useful.

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath

# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Function to convert image to Torch tensor
def from_multiInput_toTensor(inp):

    # Input type - str / Path, then read from disk & convert to array
    if isinstance(inp, (str, PurePath)):
        try:
            image = PILImage.open(inp).convert('RGB')
            imageArray = np.asarray(image.copy(), np.uint8)
        except Exception as error:
            raise error

    # Input type - PIL Image, then we convert it to array
    elif isinstance(inp, PILImage.Image):
        imageArray = np.asarray(inp, np.uint8)

    # Input type - ndarray, then assign it to imageArray variable
    elif isinstance(inp, np.ndarray):
        imageArray = inp

    # Raise TypeError if the input type is not supported
    else:
        raise TypeError("Input must be of Type - String or Path or PIL Image or Numpy array")

    # Rearrange shape of the array so that it is pytorch compatible
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)

    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imageArray).to(device)

What have we done?

We have managed to implement all the functionalities required for this use case. Since ndarray is the last data type before the output data type - tensor, we convert the input first to ndarray for every supported data type and transpose it at the end, right before converting it to tensor. By adopting this style of coding, we are able to avoid having to convert the input to tensor & return the value for every supported data type. Instead we do it only once at the end.

With the help of the isinstance function, we are able to identify the data type of the input and notify the user by raising a TypeError with an appropriate error message if an unsupported data type is passed. With I/O operations something can easily go wrong, so if the data type is str or Path, we read the image file inside a try & except block and let the user know about the error (if any).
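
In the code above, the try & except block re-raises the error unchanged. If more context is wanted, one option is to chain a new exception onto the original one. The sketch below shows the idea using only the standard library; read_image_bytes is a hypothetical helper and is not part of the function in this post, which reads the file with Pillow.

```python
from pathlib import PurePath


def read_image_bytes(inp):
    """Return the raw bytes of a file, with validation similar to Iteration 2."""
    # Reject unsupported input types up front with a clear message
    if not isinstance(inp, (str, PurePath)):
        raise TypeError(
            f"Input must be of type str or Path, got {type(inp).__name__}"
        )

    # I/O can fail in many ways, so add context and keep the original error
    try:
        with open(inp, "rb") as handle:
            return handle.read()
    except OSError as error:
        raise OSError(f"Could not read image file: {inp}") from error


try:
    read_image_bytes("this_file_does_not_exist.png")
except OSError as error:
    print(error)                           # Could not read image file: this_file_does_not_exist.png
    print(type(error.__cause__).__name__)  # FileNotFoundError
```

The original FileNotFoundError stays available through __cause__, so nothing is lost by adding the extra message.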

Quick Check:

Before we proceed, we will write a helper function to check whether two tensors are the same or not.

# Test if two torch tensors are same or not
def is_same_tensor(tensor1, tensor2):  
    assert torch.eq(tensor1, tensor2).all(), "The Tensors are not equal!"  
    return True

When writing a test function, it is better to raise a proper error than to rely on a simple print call, which can get buried among other messages. We will also check a couple of things to ensure that the updated version is all good.

  1. Is the output of the functions from the last two iterations the same?

# Verify that the output of two versions are same or not
is_same_tensor(Tensor(filename), from_multiInput_toTensor(filename))

  2. Is the support for multiple data types working?

# Check the support for Path
path = Path(Path.cwd(), filename).resolve()
is_same_tensor(Tensor(filename), from_multiInput_toTensor(path))

# Check the support for PIL Image
image = PILImage.open(filename).convert('RGB')
is_same_tensor(Tensor(filename), from_multiInput_toTensor(image))

# Check the support for Ndarray
imageArray = np.asarray(image, np.uint8)
is_same_tensor(Tensor(filename), from_multiInput_toTensor(imageArray))

  3. Are the validations working?

 # Validate whether an error is thrown when user passes wrong file 
from_multiInput_toTensor('test.png')

# Validate whether an error is thrown when input type is list
from_multiInput_toTensor([filename])
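
Both calls above are expected to fail, so in a plain script the first one would stop execution before the second is reached. A small helper can turn "an error was raised" into a check that passes silently. This is only a sketch: assert_raises is a hypothetical name, and test frameworks such as pytest offer an equivalent (pytest.raises). The examples use built-ins as stand-ins so that the block runs without torch.

```python
def assert_raises(expected_error, func, *args, **kwargs):
    """Return True if func(*args, **kwargs) raises expected_error, else fail."""
    try:
        func(*args, **kwargs)
    except expected_error:
        return True
    raise AssertionError(f"{func.__name__} did not raise {expected_error.__name__}")


# Stand-in examples using built-ins
assert_raises(FileNotFoundError, open, "test_missing_file.png")
assert_raises(TypeError, len, 42)
```

With the function from this iteration, the corresponding checks would presumably be assert_raises(FileNotFoundError, from_multiInput_toTensor, 'test.png') and assert_raises(TypeError, from_multiInput_toTensor, [filename]).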

Great! We didn’t break anything. Now let us move on to refactoring, as we are done with the behavior implementation.

What can we improve?

  • With the goal of adding support for multiple data types, we sort of wrote spaghetti code. The life cycle of the variable imageArray is confusing, to say the least. This type of code design is not very intuitive and can make life hell for the person who is going to maintain this code base. Of course, for our simple use case this is not the case, but in the spirit of writing better functions, let us avoid this coding style.
  • The function name is a lot better but still not ideal. Why? Because it still does not clearly state what the input type for the function is. Letting the user play trial and error with TypeError is definitely not the right way of coding.
  • As the name of the function suggests, it takes many input types, which means we have a lot of moving parts in one function. In the long run this may lead to many unintended side effects.

Iteration 3

Objective:

After taking a closer look at the points from the last section, it is very clear that we need to break the function into 3 smaller ones, i.e. one function for each data type.

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath

# Change numpy array to torch tensor
def numpy_ToImageTensor(imageArray):

    # if not type - ndarray then raise error
    if not isinstance(imageArray, np.ndarray):
        raise TypeError("Input must be of Type - Numpy array")

    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Transpose the numpy array before converting it to Tensor
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)
    return torch.from_numpy(imageArray).to(device)

# Change image to torch tensor
def pil_ToImageTensor(image):

    # if not type - PIL Image then raise error
    if not isinstance(image, PILImage.Image):
        raise TypeError("Input must be of Type - PIL image")

    # Convert the image to numpy
    imageArray = np.array(image)

    # Return output of numpy_ToImageTensor function
    return numpy_ToImageTensor(imageArray)

# Change image file to torch tensor
def file_ToImageTensor(file):

    # if not input - string or Path then raise error
    if not isinstance(file, (str, PurePath)):
        raise TypeError("Input must be of Type - String or Path")

    # Read the image from disk and raise error (if any)
    try:
        image = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error

    # Return output of pil_ToImageTensor function
    return pil_ToImageTensor(image)

What have we done?

This is good progress: we have removed the spaghetti code and added modularity. We now have a linear chain of three functions in which each of the first two calls the next one, which makes the code easier to understand and maintain.

file_ToImageTensor -> pil_ToImageTensor -> numpy_ToImageTensor

Modularity makes adding a new feature quite straightforward and easy to test, as we have to alter only one specific function. Some examples of functions to modify for hypothetical feature requests:

  1. file_ToImageTensor - Ability to read black & white image as well.
  2. pil_ToImageTensor - Resize the image before converting it to Tensor.

Please do note that even though you didn’t touch the rest of the code base for the examples stated above, it is best to test the other functions as well to avoid any surprises.

Modular design also paved the way for better function names and parameter names, as the input types have become more specific.

from_multiInput_toTensor -> file_ToImageTensor

inp -> file

These subtle changes can go a long way in making the code more user friendly. Honestly, at this point I feel like I am preaching, so I will try to wrap this up quickly.


Quick Check:

This post is getting too long for my own liking, so I am not going to test the negative or different data type scenarios again and will only check the original use case. In practice, you really shouldn’t skip testing after refactoring.

is_same_tensor(from_multiInput_toTensor(filename), file_ToImageTensor(filename))

What can we improve?

  • Looks like we are following the same pattern of checking data type with isinstance function and raising TypeError for every function, which is a kind of code repetition as well.
  • Okay, we have to do something about the function names. I am pretty sure no one will remember these names, definitely not my future self.

Iteration 4

Objective:

Improve the documentation and avoid code repetition with the help of Type Dispatch.

Type Dispatch:

Type dispatch allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia & Swift. All we have to do is add the typedispatch decorator before our function. It is probably easier to demonstrate than to explain.

Example for Type Dispatch:

Function definitions:

import numpy as np
from fastcore.all import *
from typing import List

# Function to multiply two ndarrays
@typedispatch
def multiple(x:np.ndarray, y:np.ndarray):
    return x * y

# Function to multiply a List by an integer
@typedispatch
def multiple(lst:List, x:int):
    return [x*val for val in lst]

Calling 1st function:

x = np.arange(1,3)
print(f'x is {x}')

y = np.array(10)
print(f'y is {y}')

print(f'Result of multiplying two numpy arrays: {multiple(x, y)}')

Calling 2nd function:

x = [1, 2]
print(f'x is {x}')

y = 10
print(f'y is {y}')

print(f'Result of multiplying a List of integers by an integer: {multiple(x, y)}')

Life of a programmer can be made so much better if they don’t have to come up with different function names with no change in purpose (for various data types). If this doesn’t encourage you to use Type Dispatch whenever possible then I don’t know what will 🤷

We will be using the fastcore package to implement Type Dispatch for our use case. For more details on fastcore and Type Dispatch, check this awesome blog by Hamel Husain. Also check out fastai, which inspired me to write this post.
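
For comparison, the Python standard library has a related tool, functools.singledispatch, which dispatches on the type of the first argument only. The sketch below is not fastcore's typedispatch and is not used in the rest of this post; it just shows the same idea without extra dependencies. The function name scale and its behavior are made up for illustration.

```python
from functools import singledispatch


@singledispatch
def scale(x, factor):
    """Fallback for unsupported types."""
    raise TypeError(f"Unsupported input type: {type(x).__name__}")


@scale.register
def _(x: list, factor: int):
    """Multiply every element of a list."""
    return [factor * val for val in x]


@scale.register
def _(x: int, factor: int):
    """Multiply a single integer."""
    return factor * x


print(scale([1, 2], 10))  # [10, 20]
print(scale(3, 10))       # 30
```

Because only the first argument is inspected, this stand-in cannot distinguish (ndarray, ndarray) from (ndarray, int) the way the two-argument example above can.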

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath
from fastcore.all import *


@typedispatch
def to_imageTensor(arr: np.ndarray) -> torch.Tensor:
    """Change ndarray to torch tensor.

    The ndarray would be of the shape (Height, Width, # of Channels)
    but pytorch tensor expects the shape as
    (# of Channels, Height, Width) before putting
    the Tensor on GPU if it's available.

    Args:
        arr[ndarray]: Ndarray which needs to be
            converted to torch tensor

    Returns:
        Torch tensor on GPU (if it's available)
    """

    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Transpose the array before converting to tensor
    imgArr = arr.transpose(2, 0, 1) if arr.ndim == 3 else arr
    return torch.Tensor(imgArr).to(device)


@typedispatch
def to_imageTensor(image: PILImage.Image) -> torch.Tensor:
    """Change image to torch tensor.

    The PIL image is cast as a numpy array with dtype uint8,
    and then passed to the to_imageTensor(arr: np.ndarray) function
    for converting the numpy array to a torch tensor.

    Args:
        image[PILImage.Image]: PIL Image which
            needs to be converted to torch tensor

    Returns:
        Torch tensor on GPU (if it's available)
    """
    return to_imageTensor(np.asarray(image, np.uint8))


@typedispatch
def to_imageTensor(file: (str, PurePath)) -> torch.Tensor:
    """Change image file to torch tensor.

    Read the image from disk as 3 channels (RGB) using PIL,
    and pass it on to the to_imageTensor(image: PILImage.Image)
    function for converting the Image to a torch tensor.

    Args:
        file[str, PurePath]: Image file name which needs to
            be converted to torch tensor

    Returns:
        Torch tensor on GPU (if it's available)

    Raises:
        Any error thrown while reading the image file,
        Mostly FileNotFoundError will be raised.
    """

    try:
        img = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error
    return to_imageTensor(img)


@typedispatch
def to_imageTensor(x:object) -> None:
    """For unsupported data types, raise TypeError."""
    raise TypeError('Input must be of Type - String or Path or PIL Image or Numpy array')

What have we done?

By utilizing the Type dispatch functionality, we managed to use the same name for all 3 functions, and each one’s behavior is differentiated by their input type. The function name is also shortened with the removal of input type. This makes the function name easier to remember.

By evaluating the function name, we can see the different input types supported by the function. Fastcore by default expects two input parameters; since we have only one, it assigns the second one as object. The second parameter will not have any impact on the function's behavior.

to_imageTensor

# OUTPUT  
(ndarray,object) -> to_imageTensor  
(Image,object) -> to_imageTensor  
(str,object) -> to_imageTensor  
(PurePath,object) -> to_imageTensor  
(object,object) -> to_imageTensor

With the help of the inspect module, we can access the input & output types of a particular function.

import inspect  
inspect.signature(to_imageTensor[np.ndarray])

The docstrings implemented in this iteration make the code more readable. The docstring for a particular input type can be accessed through __doc__ together with the input type.

print(to_imageTensor[np.ndarray].__doc__)

As discussed in the last section, we managed to move the TypeError message to a separate function with input type as object.
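
For readers who want to try this kind of lookup without fastcore, the standard library's functools.singledispatch exposes something similar through its registry attribute and dispatch() method. This is a minimal sketch under that substitution; the describe function is hypothetical and unrelated to to_imageTensor.

```python
import inspect
from functools import singledispatch


@singledispatch
def describe(x) -> str:
    """Fallback: raise TypeError for unsupported types."""
    raise TypeError(f"Unsupported input type: {type(x).__name__}")


@describe.register
def _(x: list) -> str:
    """Describe a list by its length."""
    return f"list of {len(x)} items"


# Supported input types (object is the fallback)
print([cls.__name__ for cls in describe.registry])  # ['object', 'list']

# Implementation selected for a given type, with its signature and docstring
impl = describe.dispatch(list)
print(inspect.signature(impl))  # (x: list) -> str
print(impl.__doc__)             # Describe a list by its length.
```

Here the undecorated base function plays the role of the object overload: it is what runs when no registered type matches.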

Quick Check:

Still working!

is_same_tensor(file_ToImageTensor(filename), to_imageTensor(filename))

# Validating error message when unsupported data type is passed to the function.

to_imageTensor([filename])

What can we improve?

This is it. I believe we are done here.


Closing Thoughts

Did we really write the best function for this use case? I think so, but don’t take my word for it. You can find the notebook version of this post here. If you would like to add anything to this post, feel free to reach out to me via Contact. I will gladly update this post based on your comments.

Even though Type Dispatch was the focal point when I started writing this post, soon I realized that the code evolution process is equally important too. So I have decided to include that as well. I hope you enjoyed this journey of writing a python function.

Credits

  1. Video Card icon by IconBros
  2. zwicon