2.2.0
Keras stateful LSTM/GRU support


ST Edge AI Core



ST Edge AI Core Technology 2.2.0



r2.0

Introduction

ST Edge AI Core provides initial support for the Keras stateful LSTM/GRU layers (‘legacy’ C-API). Only a batch size of 1 is supported. To fully manage the internal state of a layer, the client application can implement dedicated weak functions. Note that the return_state attribute is not supported.

Example of Keras model with LSTM layer

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(1, 64), batch_size=1)
lstm = tf.keras.layers.LSTM(20, activation='relu', stateful=True,
                            return_sequences=True)(inputs)
dense = tf.keras.layers.Dense(5)(lstm)
model = tf.keras.Model([inputs], [dense])
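As a rough cross-check for this model, the following C sketch computes the size of the recurrent state that Keras carries from one call to the next for a single stateful layer with a batch size of 1: an LSTM keeps a hidden state and a cell state of "units" values each, while a GRU keeps only a hidden state. The helper names are hypothetical and a float32 state is assumed; the size_in_bytes value actually requested by the generated code is decided by the code generator and may differ.

```c
#include <stddef.h>

/* Hypothetical helpers mirroring Keras semantics for one stateful recurrent
   layer with batch size 1. Not part of the ST Edge AI Core API. */
size_t lstm_state_floats(size_t units)
{
  return 2u * units; /* hidden state h + cell state c */
}

size_t gru_state_floats(size_t units)
{
  return units; /* hidden state h only */
}

size_t state_bytes(size_t n_floats)
{
  return n_floats * sizeof(float); /* assumption: float32 state values */
}
```

For the LSTM layer above (20 units), lstm_state_floats(20) gives 40 values, that is 160 bytes with a 4-byte float.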

Limitations

State buffer allocation

The state of a recurrent layer is currently stored in a dedicated memory buffer, which must be allocated outside the activations buffer because its contents must be preserved between two inferences. By default, dedicated weak functions (_allocate_[lstm,gru]_states()) are implemented in the network runtime library to allocate this buffer in the system heap (call of realloc(), which behaves like malloc() when no buffer has been allocated yet). These functions are called during the execution of the ai_<network>_init() function. The allocated buffer is then used by the forward LSTM/GRU function during the inference (ai_<network>_run() function) to restore and save the layer state.

AI_DECLARE_WEAK
void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  ai_handle src = AI_HANDLE_PTR(*states);
  *states = (ai_float*)realloc(src, size_in_bytes);
  // Clear lstm initial state
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
}

AI_DECLARE_WEAK
void _allocate_gru_states(ai_float **states, ai_u32 size_in_bytes)
{
  AI_ASSERT(states && size_in_bytes>0);
  ai_handle src = AI_HANDLE_PTR(*states);
  *states = (ai_float*)realloc(src, size_in_bytes);
  // Clear gru initial state
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
}
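To make the role of this buffer concrete, here is a host-side toy model of a stateful recurrent layer; it is not the ST runtime code. The state is allocated and cleared once at init, read and overwritten at each run, and freed at destroy, which is the life cycle described above. The recurrence h' = 0.5*h + x and all names (toy_rnn, toy_rnn_init(), toy_rnn_run(), toy_rnn_destroy()) are hypothetical, and ai_float and ai_u32 are replaced by assumed stand-in typedefs so that the sketch compiles on its own.

```c
#include <stdint.h>
#include <stdlib.h>

/* Stand-ins for the runtime types (assumption) */
typedef float ai_float;
typedef uint32_t ai_u32;

/* Hypothetical toy layer: the state lives in its own heap buffer,
   outside any scratch (activations-like) memory. */
typedef struct {
  ai_float *state;
  ai_u32 units;
} toy_rnn;

/* "init": allocate the state once and clear it (calloc zero-fills) */
int toy_rnn_init(toy_rnn *rnn, ai_u32 units)
{
  rnn->units = units;
  rnn->state = (ai_float *)calloc(units, sizeof(ai_float));
  return rnn->state != NULL;
}

/* "run": restore the previous state, compute, save the new state */
void toy_rnn_run(toy_rnn *rnn, const ai_float *in, ai_float *out)
{
  for (ai_u32 i = 0; i < rnn->units; i++) {
    ai_float h = 0.5f * rnn->state[i] + in[i]; /* toy recurrence */
    rnn->state[i] = h;                         /* kept for the next call */
    out[i] = h;
  }
}

/* "destroy": release the state buffer */
void toy_rnn_destroy(toy_rnn *rnn)
{
  free(rnn->state);
  rnn->state = NULL;
}
```

With an input of 1.0, two successive calls to toy_rnn_run() return 1.0 and then 1.5, because the second call starts from the saved state. If the state were kept in a scratch buffer that other layers overwrite, this continuity would be lost.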

To set the state to a specific value or to reset it periodically, the client application should implement this function itself, so that it keeps a reference to the allocated buffers, as illustrated in the following code snippet.

#include <stdlib.h>
#include <string.h>

/* NUM_INSTANCES: number of stateful LSTM layers, defined by the application */
static ai_float* _lstm_states[NUM_INSTANCES] = {0};
static ai_u32 _lstm_states_size[NUM_INSTANCES] = {0};
static int _lstm_instance_idx = 0;

/* Client implementation, overriding the weak default of the runtime */
void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  /* Reject calls beyond the table capacity, a NULL handle,
     an already allocated buffer or a zero size */
  if ((_lstm_instance_idx >= NUM_INSTANCES) || (!states) || (*states) || (!size_in_bytes)) {
      /* error - invalid call */
      return;
  }
  _lstm_states[_lstm_instance_idx] = (ai_float *)malloc(size_in_bytes);
  _lstm_states_size[_lstm_instance_idx] = size_in_bytes;
  *states = _lstm_states[_lstm_instance_idx];
  /* 
     Clear lstm initial state or
     set the state with a user-defined value.
  */
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
  _lstm_instance_idx++;
}
...
void clear_lstm_states(void) {
    for (int i=0; i<NUM_INSTANCES; i++) {
        if (_lstm_states[i]) {  /* skip instances that were never allocated */
            memset(_lstm_states[i], 0, _lstm_states_size[i]);
        }
    }
}
...
aiInit();

/* main processing loop */
int sample_n = 0;
while (1) {
    /* 1 - Acquire, pre-process and fill the input buffers */
    acquire_and_process_data(in_data);

    /* 2 - Call inference engine
           (the LSTM states are cleared before every 22nd inference) */
    if (sample_n++ > 20) {
        clear_lstm_states();
        sample_n = 0;
    }
    aiRun(in_data, out_data);

    /* 3 - Post-process the predictions */
    post_process(out_data);
}
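Beyond clearing, the client tables also make it possible to load a user-defined state or to save the current one, for example to restore it later. The following self-contained sketch shows two hypothetical helpers, set_lstm_state() and get_lstm_state(); the stand-in typedefs for ai_float and ai_u32 and the NUM_INSTANCES value are assumptions, and the tables are redeclared only to keep the sketch compilable on its own.

```c
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Stand-ins for the runtime types (assumption) */
typedef float ai_float;
typedef uint32_t ai_u32;

/* Assumption: a single stateful LSTM layer */
#define NUM_INSTANCES 1

static ai_float *_lstm_states[NUM_INSTANCES];
static ai_u32 _lstm_states_size[NUM_INSTANCES];

/* Returns the state buffer of instance idx if it exists and has the
   expected size in bytes, NULL otherwise. */
static ai_float *lstm_state_slot(int idx, ai_u32 size_in_bytes)
{
  if (idx < 0 || idx >= NUM_INSTANCES || !_lstm_states[idx] ||
      size_in_bytes != _lstm_states_size[idx])
    return NULL;
  return _lstm_states[idx];
}

/* Hypothetical helper: load a user-defined state. Returns 0 on success. */
int set_lstm_state(int idx, const ai_float *values, ai_u32 size_in_bytes)
{
  ai_float *slot = lstm_state_slot(idx, size_in_bytes);
  if (!slot || !values)
    return -1;
  memcpy(slot, values, size_in_bytes);
  return 0;
}

/* Hypothetical helper: copy the current state into dst. Returns 0 on success. */
int get_lstm_state(int idx, ai_float *dst, ai_u32 size_in_bytes)
{
  const ai_float *slot = lstm_state_slot(idx, size_in_bytes);
  if (!slot || !dst)
    return -1;
  memcpy(dst, slot, size_in_bytes);
  return 0;
}
```

Both helpers require the exact buffer size recorded at allocation time, so a mismatched copy is rejected rather than silently truncated or overrun. They are meant to be called between two inferences, never during aiRun().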

To free the allocated resources, a second pair of weak functions is also defined. They are called when the model instance is destroyed (ai_<network>_destroy() function).

AI_DECLARE_WEAK
void _deallocate_lstm_states(ai_float **states)
{
  if (*states) {
    free(*states);
    *states = NULL;
  }
}

AI_DECLARE_WEAK
void _deallocate_gru_states(ai_float **states)
{
  if (*states) {
    free(*states);
    *states = NULL;
  }
}
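When the client application provides its own _allocate_lstm_states(), the default _deallocate_lstm_states() still frees the buffer, but the client table keeps a dangling pointer and the instance index is never rewound: calling clear_lstm_states() after ai_<network>_destroy() would write to freed memory, and a later re-initialization would not obtain new buffers. The following self-contained sketch pairs a client allocator with a matching deallocator. The stand-in typedefs and NUM_INSTANCES = 2 are assumptions, and rewinding the index only once every slot is released is a simplification chosen for this sketch.

```c
#include <stdint.h>
#include <stdlib.h>

/* Stand-ins for the runtime types (assumption) */
typedef float ai_float;
typedef uint32_t ai_u32;

/* Assumption: two stateful LSTM layers */
#define NUM_INSTANCES 2

static ai_float *_lstm_states[NUM_INSTANCES];
static ai_u32 _lstm_states_size[NUM_INSTANCES];
static int _lstm_instance_idx = 0;

/* Client allocator: one table slot per layer, zero-filled buffer */
void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  if (_lstm_instance_idx >= NUM_INSTANCES || !states || *states || !size_in_bytes)
    return; /* invalid call or table full */
  *states = (ai_float *)calloc(1, size_in_bytes);
  if (!*states)
    return; /* allocation failure: slot not registered */
  _lstm_states[_lstm_instance_idx] = *states;
  _lstm_states_size[_lstm_instance_idx] = size_in_bytes;
  _lstm_instance_idx++;
}

/* Matching client deallocator: forgets the table entry before freeing,
   and rewinds the index once no slot is in use anymore. */
void _deallocate_lstm_states(ai_float **states)
{
  if (!states || !*states)
    return;
  int in_use = 0;
  for (int i = 0; i < NUM_INSTANCES; i++) {
    if (_lstm_states[i] == *states) {
      _lstm_states[i] = NULL;
      _lstm_states_size[i] = 0;
    }
    in_use |= (_lstm_states[i] != NULL);
  }
  free(*states);
  *states = NULL;
  if (!in_use)
    _lstm_instance_idx = 0;
}
```

With this pairing, a destroy followed by a new init registers fresh buffers, and clear_lstm_states() (which skips NULL entries) never touches released memory.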

Validate command limitation

With the current implementation of the aiValidation firmware, the state buffers are created and reset only during the initialization of the firmware. If the validate command is run several times, the board must be restarted between two executions.

# reset the board before running this command again
$ stm32ai validate model.h5 --mode stm32 -b 20

There is no such limitation for validation on the desktop: the runtime is re-instantiated between two executions.

$ stm32ai validate model.h5 -b 20