Aravind Ramalingam
- Python
- JavaScript
- CNN
- Jupyter Notebook
- Deep Learning
Digit Emoji Predictor: Build UI for Deep Learning model in Notebook
Build a digit emoji predictor in a Jupyter notebook by passing data between JavaScript & Python. Typically, JavaScript is used for data visualization in notebooks, but it can also be used for prototyping a front-end/UI for deep learning models.
Why the hell build UI in Notebook?
Short Answer:
Share the Deep Learning model notebook with colleagues from different teams (Business, Data Science, Front-End development, DevOps) and get their opinion before the real software development takes place.
Jupyter notebook is meant for quick prototyping and everyone knows this, but what many miss is that we can quickly prototype the UI as well. Many data scientists are quick to start building the big deep learning model with whatever little data they have, without even thinking about what they are developing and for whom 🤷‍♂️. Trust me, the first model you are satisfied with (based on metrics like accuracy) won't be deployed into production for various reasons like slow inference time, no on-device support, memory hogging, blah blah.
Even experienced project/product managers sometimes won't know what they really want, so first try to get agreement on the workflow or pipeline from all the players involved; even then, some things will slip through the cracks. "I am a data scientist, my job is only to build ML models" is not the right mindset, at least in my opinion. A tad bit of knowledge in Business, Web development and DevOps can save you a ton of time when you debug, or when you explain why the model/workflow you built is the best for the business. Better to be a jack of all trades here. BTW, this does not mean that you should take over the responsibilities of your colleagues. Just try to get a sense of what is happening outside your circle. In this post, let me share an example of how creating a simple input & output UI for a deep learning model can help you see the cracks in your workflow and increase productivity.
Application
If you are not familiar with the MNIST dataset, all you need to know is that it contains greyscale images of hand-drawn digits, from zero through nine, which is enough to build a Digit Emoji Predictor. So our UI is going to let the user draw a digit & display the corresponding emoji.
First, let us decide what we are going to build. Instead of creating something from scratch, let us try to recreate Digit Recognizer in a notebook without using Django. This use case is a good one because we can't simply use ipywidgets, which forces us to dig deep into the world of HTML & JavaScript. In the long run this route has more benefits: there is no limit on what you can build, free resources are plentiful and, more importantly, it is closer to the actual front-end. Of course, if all you want is a slider, then just roll with ipywidgets.
We will first set up the back-end and then get back to the UI.
Digit Emoji Recognizer
I have trained a CNN deep learning model using PyTorch; we will just download the model & weights from MNIST Model. This is a simple model with 3 Conv blocks and 2 fully connected layers. The transformations used for this model are Resize & ToTensor. Calling the predict function on an image returns the predicted digit.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# Digit Recognizer model definition
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        # Flatten the 64 feature maps of size 3x3
        x = x.view(-1, 3*3*64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

# Predict the digit given a PIL image
def predict(image, modelName='mnist_model.pth'):
    # Resize before converting the image to Tensor
    Trfms = transforms.Compose([transforms.Resize(28), transforms.ToTensor()])
    # Apply transformations and add the mini-batch dimension
    imageTensor = Trfms(image).unsqueeze(0)
    # Load the previously downloaded model from the local file
    # (PyTorch 2.6+ defaults to weights_only=True, so a fully pickled
    # model may need torch.load(modelName, weights_only=False))
    model = torch.load(modelName)
    # Activate eval mode so dropout is disabled
    model.eval()
    # Pass the image tensor through the model & return the index of the max output
    with torch.no_grad():
        out = model(imageTensor)
    return int(torch.max(out, 1)[1])
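The `3*3*64` input size of `fc1` follows from the layer arithmetic. Here is a minimal sketch of that calculation, assuming a 28x28 input, convolutions with stride 1 and no padding, and `max_pool2d` with its default stride equal to the kernel size. The helper names `conv_out` and `pool_out` are made up for this illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution without dilation (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial size after max pooling whose stride equals the kernel size."""
    return size // kernel


size = 28                      # MNIST-sized input after Resize(28)
size = conv_out(size, 5)       # conv1: 28 -> 24
size = conv_out(size, 5)       # conv2: 24 -> 20
size = pool_out(size, 2)       # max_pool2d: 20 -> 10
size = conv_out(size, 5)       # conv3: 10 -> 6
size = pool_out(size, 2)       # max_pool2d: 6 -> 3

flattened = size * size * 64   # 64 channels come out of conv3
print(flattened)               # 576, i.e. 3*3*64
```

If you change the input size or a kernel size, rerun the arithmetic before touching `fc1`, otherwise the `view` call will fail or silently mix samples.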
We need a helper function to convert the BASE64 image (the canvas area in which the user is going to draw the digit) to a PIL Image.
import re, base64
from PIL import Image
from io import BytesIO

# Decode the image drawn by the user
def decodeImage(codec):
    # Remove the data URL prefix in front of the base64 payload
    base64_data = re.sub(r'^data:image/.+;base64,', '', codec)
    # base64 decode into raw bytes
    byte_data = base64.b64decode(base64_data)
    # Wrap the bytes in an in-memory file object
    image_data = BytesIO(byte_data)
    # Open as an image & convert to grayscale
    img = Image.open(image_data).convert('L')
    return img
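To see what the prefix stripping and decoding steps do without needing a canvas, here is a small stdlib-only round trip. It is a sketch under assumptions: the payload is just the 8-byte PNG signature standing in for a real image, and the variable names are illustrative.

```python
import base64
import re

# Stand-in payload: the PNG file signature, not a complete image
payload = b"\x89PNG\r\n\x1a\n"

# A data URL has the shape "data:<mime>;base64,<payload>",
# which is what canvas.toDataURL() produces for PNG output
data_url = "data:image/png;base64," + base64.b64encode(payload).decode("ascii")

# Same regular expression as in decodeImage: drop everything up to "base64,"
stripped = re.sub(r"^data:image/.+;base64,", "", data_url)

# Decoding the remainder gives back the original bytes
decoded = base64.b64decode(stripped)
print(decoded == payload)
```

With a real canvas export, `decoded` would be the full PNG file, which is why wrapping it in `BytesIO` is enough for `Image.open`.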
Now let us combine the two functions, namely decodeImage and predict, into a separate function. This way is easier because the output of the function (emoji html code) can be directly passed to the HTML tag via JavaScript. Hang on for a bit if you are confused, things will get much easier to follow in the next section.
# Decode the image and predict the value
def decode_predict(imgStr):
    # Decode the image
    image = decodeImage(imgStr)
    # Declare html codes for 0-9 emoji
    emojis = [
        "0️⃣",
        "1️⃣",
        "2️⃣",
        "3️⃣",
        "4️⃣",
        "5️⃣",
        "6️⃣",
        "7️⃣",
        "8️⃣",
        "9️⃣"
    ]
    # Call the predict function
    digit = predict(image)
    # Get the corresponding emoji
    return emojis[digit]
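As a side note, each keycap emoji is the plain digit followed by two code points, U+FE0F (variation selector) and U+20E3 (combining enclosing keycap), so the list can also be generated. The sketch below shows that, plus a hypothetical helper `html_entities` (not part of any library) that spells a string out as HTML numeric character references, in case you prefer to pass pure ASCII to the page.

```python
# Variation selector-16 followed by the combining enclosing keycap
KEYCAP_SUFFIX = "\ufe0f\u20e3"

# Build the ten keycap emoji from the digits 0-9
emojis = [str(digit) + KEYCAP_SUFFIX for digit in range(10)]


def html_entities(text):
    """Spell out every character as an HTML numeric character reference."""
    return "".join(f"&#x{ord(ch):X};" for ch in text)


print(html_entities(emojis[0]))  # &#x30;&#xFE0F;&#x20E3;
```

Either form works as the content of the result element; the literal list in `decode_predict` is simply easier to read.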
UI
The HTML & CSS parts are quite straightforward; the JavaScript lets the user draw inside the canvas and call functions via the Predict & Clear buttons. The Clear button cleans up the canvas and sets the result to the 🤔 emoji. We will set up the Predict button JS function in the next section.
from IPython.display import HTML
html = """
<div class="outer">
    <!-- HEADER SECTION -->
    <div>
        <h3 style="margin-left: 30px;"> Digit Emoji Predictor 🚀 </h3>
        <br>
        <h7 style="margin-left: 40px;"> Draw a digit from 0️⃣ - 9️⃣</h7>
    </div>
    <div>
        <!-- CANVAS TO DRAW THE DIGIT -->
        <canvas id="canvas" width="250" height="250" style="border:2px solid; float: left; border-radius: 5px; cursor: crosshair;">
        </canvas>
        <!-- SHOW PREDICTED DIGIT EMOJI-->
        <div class="wrapper1"> <p id="result">🤔</p></div>
        <!-- BUTTONS TO CALL DL MODEL & CLEAR THE CANVAS -->
        <div class="wrapper2">
            <button type="button" id="predictButton" style="color: #4CAF50;margin:10px;"> Predict </button>
            <button type="button" id="clearButton" style="color: #f44336;margin:10px;"> Clear </button>
        </div>
    </div>
</div>
"""
css = """
<style>
.wrapper1
.wrapper2
.outer
</style>
"""
javascript = """
<script type="text/javascript">
(function() ());
</script>
"""
HTML(html + css + javascript)
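The body of the drawing script is not shown above, so here is a minimal sketch of what such a script might look like. Everything in it is an assumption for illustration, not the author's code: the stroke width, the mouse-only event handling and the white background fill are my choices. It relies only on the standard canvas 2D API and the element ids from the HTML above (`canvas`, `result`, `clearButton`).

```python
# Hypothetical drawing script: an illustrative sketch, not the original code
drawing_js = """
<script type="text/javascript">
(function() {
    var canvasObj = document.getElementById("canvas");
    var ctx = canvasObj.getContext("2d");
    var drawing = false;

    /* Fill with white so the exported PNG has an opaque background */
    function clearCanvas() {
        ctx.fillStyle = "white";
        ctx.fillRect(0, 0, canvasObj.width, canvasObj.height);
        document.getElementById("result").innerHTML = "🤔";
    }

    /* Start a new stroke where the mouse button goes down */
    canvasObj.addEventListener("mousedown", function(e) {
        drawing = true;
        ctx.beginPath();
        ctx.moveTo(e.offsetX, e.offsetY);
    });

    /* Extend the stroke while the button is held */
    canvasObj.addEventListener("mousemove", function(e) {
        if (!drawing) { return; }
        ctx.lineWidth = 15;
        ctx.lineCap = "round";
        ctx.strokeStyle = "black";
        ctx.lineTo(e.offsetX, e.offsetY);
        ctx.stroke();
    });

    /* Stop drawing wherever the button is released */
    window.addEventListener("mouseup", function() { drawing = false; });

    document.getElementById("clearButton").addEventListener("click", clearCanvas);
    clearCanvas();
})();
</script>
"""
```

Displaying `HTML(html + css + drawing_js)` would wire it up. Note that `canvasObj` is local to the wrapping function here, so any other script has to look the canvas up by its id again.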
Great! We are able to draw inside the canvas and, upon clicking the Clear button, the canvas is cleaned up. The tricky part starts when the user clicks the Predict button, as we have to do the following:
- JS to Python 👉 Grab the drawing inside the canvas and pass it to Python
/* PASS CANVAS BASE64 IMAGE TO PYTHON VARIABLE imgStr*/
var imgData = canvasObj.toDataURL();
var imgVar = 'imgStr';
var passImgCode = imgVar + " = '" + imgData + "'";
var kernel = IPython.notebook.kernel;
kernel.execute(passImgCode);
- Python to JS 👉 Set the predicted emoji as the value of the HTML element (#result)
/* CALL PYTHON FUNCTION "decode_predict" WITH "imgStr" AS ARGUMENT */
function handle_output(response)
var callbacks = ;
var getPredictionCode = "decode_predict(imgStr)";
kernel.execute(getPredictionCode, callbacks, );
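It helps to look at both directions from the Python side. The sketch below simulates them with plain Python, under two assumptions: the string that JavaScript sends is executed as ordinary Python source (mimicked here with `exec` in a scratch namespace), and the `text/plain` form of a returned string is its `repr`, quotes included. The sample base64 payload is a shortened placeholder, not a real image.

```python
# JS to Python: the JavaScript side builds an assignment statement as text
img_data = "data:image/png;base64,iVBORw0KGgo="   # placeholder payload
pass_img_code = "imgStr = '" + img_data + "'"

# The kernel runs that text like any other cell; exec() mimics this here
namespace = {}
exec(pass_img_code, namespace)
print(namespace["imgStr"] == img_data)

# Python to JS: a returned string reaches the browser as its repr,
# so the surrounding quotes have to be stripped before display
emoji = "7\ufe0f\u20e3"                 # keycap 7
text_plain = repr(emoji)                # what the front-end receives
shown = text_plain[1:-1]                # drop the leading and trailing quote
print(shown == emoji)
```

The string concatenation is safe here only because a base64 data URL never contains a quote character; arbitrary text would need escaping first.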
Linking all together.
predictJS = """
<script type="text/javascript">
/* PREDICTION BUTTON */
$("#predictButton").click(function() );
</script>
"""
HTML(predictJS)
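For reference, here is one way the complete click handler could be written. Treat it as a sketch under assumptions rather than the original code: it assumes the classic Jupyter Notebook front-end, where `IPython.notebook.kernel` and jQuery's `$` are available (JupyterLab does not expose them), and it assumes the result arrives as an `execute_result` message whose `text/plain` field holds the quoted string. The variable name `predict_js_sketch` is made up.

```python
# Hypothetical full handler: classic Notebook only, names are illustrative
predict_js_sketch = """
<script type="text/javascript">
$("#predictButton").click(function() {
    var canvasObj = document.getElementById("canvas");
    var kernel = IPython.notebook.kernel;

    /* JS to Python: store the canvas as a base64 string in imgStr */
    kernel.execute("imgStr = '" + canvasObj.toDataURL() + "'");

    /* Python to JS: show the value returned by decode_predict */
    function handle_output(response) {
        if (response.header.msg_type !== "execute_result") { return; }
        /* text/plain is the Python repr, so drop the surrounding quotes */
        var text = response.content.data["text/plain"];
        document.getElementById("result").innerHTML = text.slice(1, -1);
    }
    var callbacks = { iopub: { output: handle_output } };
    kernel.execute("decode_predict(imgStr)", callbacks, { silent: false });
});
</script>
"""
```

Passing `silent: false` matters because a silent execution does not publish its result on the iopub channel, so the callback would never see the emoji.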
We have both good & bad news: the good news is that the data connection between Python and JavaScript works, but the predictions are all wrong (every drawing is mapped to the same value, 8). Let us compare the canvas image passed to Python with an MNIST dataset image.
It is very clear now that we need to invert the image, because the model was trained on a dataset with a black background and white strokes.
import PIL.ImageOps

# Decode the image and predict the value
def decode_predict(imgStr):
    # Decode the image
    image = decodeImage(imgStr)
    # Invert the image as the model expects black background & white strokes
    image = PIL.ImageOps.invert(image)
    # Declare html codes for 0-9 emoji
    emojis = [
        "0️⃣",
        "1️⃣",
        "2️⃣",
        "3️⃣",
        "4️⃣",
        "5️⃣",
        "6️⃣",
        "7️⃣",
        "8️⃣",
        "9️⃣"
    ]
    # Call the predict function
    digit = predict(image)
    # Get the corresponding emoji
    return emojis[digit]
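For an 8-bit grayscale image, inverting means replacing every pixel value v with 255 - v. The toy example below shows the effect on a tiny 3x3 "canvas" using plain lists, so it runs without Pillow; the function name `invert_pixels` is made up for this illustration.

```python
def invert_pixels(pixels):
    """Invert 8-bit grayscale values: white (255) <-> black (0)."""
    return [[255 - value for value in row] for row in pixels]


# White canvas with a single black stroke pixel in the middle
canvas = [
    [255, 255, 255],
    [255,   0, 255],
    [255, 255, 255],
]

# After inversion: black background with a white stroke, as in MNIST
mnist_like = invert_pixels(canvas)
for row in mnist_like:
    print(row)
```

Applying the function twice gives back the original, which makes it easy to check whether an image has already been inverted somewhere upstream.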
Awesome! It works.
Conclusion
Now you can quickly share this notebook with everyone and ask for their opinion. Don't forget to mention the caveat that this is an early prototype. Without much effort, we sort of replicated the functionality of a web app. The biggest plus is that we never left the comfort of the notebook and potentially bridged the gap between training & test data.
Hopefully I managed to convince you that basic web developer skills are really helpful for a data scientist. The notebook can be accessed from here. Feel free to reach out via Contact.