Close
0%
0%

Train Neural Networks on STM32 and Arduino

Have you ever wondered how fast it's gonna be to train ANN on STM32 and Arduino, or is it even possible? Here's the benchmark for MCU.

Similar projects worth following
Train neural networks on STM32 and Arduino, now capable of training small ANN, and running CNN inference with MNIST dataset.

Project Homepage: 

ONNX (Open Neural Network Exchange): https://github.com/wuhanstudio/onnx-backend

ANN (RT-Thread): https://github.com/wuhanstudio/rt-libann

I'm just curious about how fast it is or is it possible to train neural networks on STM32 with limited FLASH and RAM, or even on 8-bit MCU Arduino. And it seems to be possible.

msh />onnx_mnist 1
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@        @@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@              @@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@                    @@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@          @@@@@@@@    @@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@        @@@@@@@@@@    @@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@  @@@@@@    @@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@  @@@@      @@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@          @@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@              @@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@            @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@          @@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@      @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@  @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@    @@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@  @@@@@@@@@@@@@@@@@@      @@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@    @@@@@@@@@@@@        @@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@                      @@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@                  @@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

Predictions:
0.007383 0.000000 0.057510 0.570970 0.000000 0.105505 0.000000 0.000039 0.257576 0.001016

The number is 3

Here's the video recorded with asciinema, you can replay it on your computer with:

asciinema play rt-thread.cast

Here's the recorded file rt-thread.cast

And test results on different MCUs:

MCUFrequencyFLASHRAMTrain TimePred TimeAccuracy
STM32F429IGT6180MHz1024KB256KB9s7ms96.0
STM32F401RET684MHz512KB96KB18s15ms96.0
STM32L475VET680MHz512KB128KB24s22ms96.0
STM32F103RCT672MHz256KB20KB32s26ms96.0
STM32F103C8T672MHz64KB20KB32s26ms96.0
Arduino M0 Pro48MHz256KB32KB135s97ms96.0
ATmega 256016MHz256KB8KB182s138ms96.0

rt-thread.cast

Demo video recorded with asciinema, you can replay it on your computer.

cast - 19.04 kB - 06/28/2019 at 08:41

Download

STM32F103C8T6.bin

Pre-compiled firmware for STM32F103C8T6, you can type command through UART1 at baudrate 115200.

octet-stream - 58.74 kB - 06/28/2019 at 09:15

Download

STM32F103RCT6.bin

Pre-compiled firmware for STM32f103RCT6, you can type command through UART1 at baudrate 115200.

octet-stream - 59.04 kB - 06/28/2019 at 09:20

Download

STM32F401RET6.bin

Pre-compiled firmware for STM32F401RET6, you can type command through UART1 at baudrate 115200.

octet-stream - 67.16 kB - 06/28/2019 at 09:19

Download

  • 1 × Atmega 2560 Arduino running at 16MHz
  • 1 × STM32F103C8T6 / STM32F103RCT6 STM32F10x Series running at 72MHz
  • 1 × STM32L475 STM32L475 running at 80MHz
  • 1 × STM32F401 STM32F401 running at 84MHz
  • 1 × STM32F429 STM32F429 running at 180MHz

  • CNN on STM32

    wuhanstudio08/12/2019 at 10:26 0 comments

    Now capable of running CNN model pre-trained with keras on STM32 with 16KB RAM.

    Model Description:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    conv2d_5 (Conv2D)            (None, 28, 28, 2)         20        
    _________________________________________________________________
    max_pooling2d_5 (MaxPooling2 (None, 14, 14, 2)         0         
    _________________________________________________________________
    dropout_5 (Dropout)          (None, 14, 14, 2)         0         
    _________________________________________________________________
    conv2d_6 (Conv2D)            (None, 14, 14, 2)         38        
    _________________________________________________________________
    max_pooling2d_6 (MaxPooling2 (None, 7, 7, 2)           0         
    _________________________________________________________________
    dropout_6 (Dropout)          (None, 7, 7, 2)           0         
    _________________________________________________________________
    flatten_3 (Flatten)          (None, 98)                0         
    _________________________________________________________________
    dense_5 (Dense)              (None, 4)                 396       
    _________________________________________________________________
    dense_6 (Dense)              (None, 10)                50        
    =================================================================
    Total params: 504
    Trainable params: 504
    Non-trainable params: 0
    _________________________________________________________________
    

    CNN inference on STM32:

    msh />onnx_mnist 1
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@        @@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@              @@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@                    @@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@          @@@@@@@@    @@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@        @@@@@@@@@@    @@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@  @@@@@@    @@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@  @@@@      @@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@          @@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@              @@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@            @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@          @@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@      @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@  @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@    @@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@  @@@@@@@@@@@@@@@@@@      @@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@    @@@@@@@@@@@@        @@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@                      @@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@                  @@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    
    Predictions:
    0.007383 0.000000 0.057510 0.570970 0.000000 0.105505 0.000000 0.000039 0.257576 0.001016
    
    The number is 3
    

  • Add test results for Arduino M0 Pro

    wuhanstudio07/13/2019 at 13:03 0 comments

    MCUFrequencyFLASHRAMTrain TimePred TimeAccuracy
    STM32F429IGT6180MHz1024KB256KB9s7ms96.0
    STM32F401RET684MHz512KB96KB18s15ms96.0
    STM32L475VET680MHz512KB128KB24s22ms96.0
    STM32F103RCT672MHz256KB20KB32s26ms96.0
    STM32F103C8T672MHz64KB20KB32s26ms96.0
    Arduino M0 Pro48MHz256KB32KB135s97ms96.0
    ATmega 256016MHz256KB8KB182s138ms96.0

  • Benchmark for STM32F1/STM32F4/STM32L4 and Atmega2560

    wuhanstudio06/28/2019 at 08:55 0 comments

    Benchmark results on following MCUs:

    MCUFrequencyFLASHRAMTrain TimePred TimeAccuracy
    STM32F429IGT6180MHz1024KB256KB9s7ms96.0%
    STM32F401RET684MHz512KB96KB18s15ms96.0%
    STM32L475VET680MHz512KB128KB24s22ms96.0%
    STM32F103RCT672MHz256KB20KB32s26ms96.0%
    STM32F103C8T672MHz64KB20KB32s26ms96.0%
    ATmega 256016MHz256KB8KB182s138ms96.0%

    The dataset I use is Iris dataset with 150 samples, and the ANN model has 4 inputs, 1 hidden layer with 4 nodes, and 3 outputs, training for 500 loops.

View all 3 project logs

  • 1
    Prerequisites

    The first step is to install prerequisites for RT-Thread. You may have noticed that there is a RTOS running. It is RT-Thread, a burgeoning RTOS in China.

    The command line interface is really convenient when debugging on STM32, so that I can train ANN interactively. Another reason I choose RT-Thread is the large quantity of MCUs it supports. Now it supports nearly 100 MCUs, which means I can run my program on all of these MCU without extra work. This is a real good news for benchmark purpose.

    As you can see, I can choose different types of MCUs and trim the Kernel as needed with KConfig. If you have used buildroot to compile Linux Kernel and build root file system before, you'll find the GUI really handy. My program has been released as  an online package on the platform of RT-Thread.

    You can download all the tools required here https://www.rt-thread.org/page/download.html

    If you are using Windows, please download RT-Thread Env tools first (built-in with all the tools you need) : https://pan.baidu.com/s/1cg28rk

    If you are using Linux, such as Ubuntu, make sure you have Scons and cross-compile tools installed.

    Finally for both Windows and Linux, clone the source code of RT-Thread, and we're ready to go.

    git clone https://github.com/RT-Thread/rt-thread
  • 2
    Install package libann

    The following instructions is based on STM32F103RCT6, but should work for all MCUs RT-Thread supports.

    If you are using Windows, simply dive into the directory rt-thread\bsp\stm32\stm32f103-atk-nano and right click ConEmu here, then type menuconfig, you should see configuration menu on the first step.

    If you are using Linux

    $ cd rt-thread\bsp\stm32\stm32f103-atk-nano
    rt-thread\bsp\stm32\stm32f103-atk-nano$ scons --menuconfig

     Now here's the configuration menu, you can use Arrow to navigate and Space to select/deselect

     Remember use Space to select, let's navigate to,:

    RT-Thread online Packages -->  
        miscellaneous packages  --->
            [*] libann: a light-weight ANN library, capable of training, saving and loading models.  --->
                [*]   Iris load model from flash and predict example
                [ ]   Iris train model and predict example
                [ ]   Iris load model and predict example

     You may try the last two examples if you have a SD card on your board, but as for now, I'll choose the first one only that loads dataset from FLASH. Now, it's time to compile.

    If you are using Windows:

    # Make sure you are in the directory rt-thread\bsp\stm32\stm32f103-atk-nano
    $ pkgs --update
    $ scons
    
    # If you prefer use Keil MDK5 to compile
    # pkgs --update
    # scons --target=mdk5 -s 
    # scons

    If you are using Linux:

    # Make sure you are in the directory rt-thread\bsp\stm32\stm32f103-atk-nano
    # You may need to modify rtconfig.py if Python failed to find your cross-compiler
    $ source ~/.env/env.sh
    $ pkgs --update
    $ scons
    
  • 3
    Benchmark

    Upload compiled firmware onto your board, open UART1 at baudrate 115200, you should see the command line interface.

    Simply type in:

    msh /> iris_train_and_predict_flash

     You should see STM32 training:

    Congratulations !!

View all 3 instructions

Enjoy this project?

Share

Discussions

Similar Projects

Does this project spark your interest?

Become a member to follow this project and never miss any updates