Quantization: PyTorch vs ESP32 S3

A project log for Running a PyTorch Model on the ESP32 S3

This log describes the steps I am taking to make a model trained in PyTorch work on the ESP32 S3 (M5Stack Core S3).

E/S Pronk, 05/13/2024 at 19:13

I'm working on a custom model, and I'm using PyTorch to train it. Most of the layers are custom, so I can't just export to some standard format and hope for the best. Instead, I'm going to duplicate the layers' logic in C on the ESP32, then use PyTorch to quantize my model weights.

I would like to try the ESP-DL library from Espressif, but unfortunately it uses a different quantization scheme than PyTorch, and Espressif claims you can't use your model with their API. This is not entirely true: there is no easy way to use your model with their quantization scheme, but you certainly can.

The key thing to understand is how both quantization schemes work. PyTorch uses a zero-point and a scale:

f32 = (i8 - zero_point) * scale

while ESP-DL uses an exponent:

f32 = i8 * (2 ** exponent) 

which they claim is not compatible.

We can make this work, though, if we force PyTorch to use a zero-point of 0 and a scale that is always 2 raised to a (signed) integer power.
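To see why those two constraints are enough, here is a small pure-Python sketch (no PyTorch or ESP-DL involved; the function names and the example exponent are made up for illustration). It dequantizes every int8 value under both formulas and checks that they agree once the zero-point is 0 and the scale is a power of two.

```python
# Illustrative helpers only; neither function is part of PyTorch or ESP-DL.

def dequant_pytorch(i8: int, scale: float, zero_point: int) -> float:
    """PyTorch-style affine dequantization: f32 = (i8 - zero_point) * scale."""
    return (i8 - zero_point) * scale


def dequant_espdl(i8: int, exponent: int) -> float:
    """ESP-DL-style dequantization: f32 = i8 * 2 ** exponent."""
    return i8 * 2.0 ** exponent


exponent = -6                 # example value, chosen arbitrarily
scale = 2.0 ** exponent       # the matching power-of-two scale
zero_point = 0                # what a symmetric scheme gives us

# Compare both schemes over the full int8 range.
all_match = all(
    dequant_pytorch(q, scale, zero_point) == dequant_espdl(q, exponent)
    for q in range(-128, 128)
)
print(all_match)
```

With any non-zero zero-point or a scale that is not a power of two, the comparison would no longer hold in general, which is the incompatibility ESP-DL refers to.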

Getting a zero-point of 0 is easy: we set the qconfig to use a symmetric quantization scheme with a signed dtype (for example torch.per_tensor_symmetric with torch.qint8). The scale is a little bit harder, but no rocket science either. We can subclass a suitable observer and override its qparams calculation so that the scale is snapped to the nearest power of two:

scale = 2 ** round( log2( scale ))

Like so:

import torch
import torch.ao.quantization as Q

class ESP32MovingAverageMinMaxObserver(Q.MovingAverageMinMaxObserver):
    """Moving-average min/max observer whose scale is always a power of two."""
    def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor):
        s, z = super()._calculate_qparams(min_val, max_val)
        # ESP-DL has no zero-point, so the scheme must be symmetric.
        assert (z == 0).all()
        # Round the exponent to the nearest integer and keep it in int8 range.
        s = 2 ** s.log2().round().clamp(-128, 127)
        return s, z
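As a sanity check on what the override does to a single scale, the following sketch repeats the same arithmetic with plain Python floats instead of tensors. The function name and the example scale 0.0123 are assumptions for illustration, not values from my model.

```python
import math


def snap_to_power_of_two(scale: float) -> float:
    """Round a positive scale to the nearest power of two (nearest in log2 terms)."""
    exponent = round(math.log2(scale))
    exponent = max(-128, min(127, exponent))  # same clamp as in the observer
    return 2.0 ** exponent


original = 0.0123                      # hypothetical scale from an observer
snapped = snap_to_power_of_two(original)
ratio = snapped / original

print(snapped)                         # 0.015625, i.e. 2 ** -6
# Rounding log2 moves the exponent by at most 0.5, so the scale changes
# by at most a factor of sqrt(2) in either direction.
print(1 / math.sqrt(2) <= ratio <= math.sqrt(2))
```

When the snapped scale ends up smaller than the original, the largest observed values may no longer fit in int8 and will saturate; when it ends up larger, some resolution is lost. Whether either effect matters will depend on the model.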

Then, when it is time to export the weights, we also export the exponent that ESP-DL needs, which is simply the log2 of the weight tensor's scale.
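A minimal sketch of that export step, assuming the scale has already been read out of the quantized weight tensor as a Python float (for example with q_scale() on a per-tensor quantized tensor). The helper name is hypothetical. It uses math.frexp rather than a floating-point log2 so that it can refuse any scale that is not an exact power of two, which would suggest the custom observer was not actually in use.

```python
import math


def exponent_from_scale(scale: float) -> int:
    """Return e such that scale == 2 ** e, or raise if scale is not a power of two."""
    # frexp gives scale == mantissa * 2 ** exp with 0.5 <= mantissa < 1
    # for positive inputs, so an exact power of two has mantissa 0.5.
    mantissa, exp = math.frexp(scale)
    if mantissa != 0.5:
        raise ValueError(f"scale {scale!r} is not an exact power of two")
    return exp - 1


print(exponent_from_scale(0.015625))   # -6
print(exponent_from_scale(4.0))        # 2
```

The resulting integer is what gets stored next to the int8 weights, so the C side can compute f32 = i8 * (2 ** exponent) as in the ESP-DL formula above.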