Performance drawbacks of TensorFlow Datasets

TensorFlow, the popular machine learning framework, recommends its new dataset API (`tf.data`) for preprocessing and serving data. The API supports useful tricks such as caching data in memory and prefetching in parallel threads, along with others described in the tutorials. Still, TensorFlow is slow at slicing data, so the dataset API may actually do harm in setups where the computation per batch is relatively cheap; this includes linear models, word2vec and other shallow networks. In this note I illustrate the issue by training a classifier on the MNIST dataset (images of handwritten digits). As the benchmark below shows, it sometimes makes sense to serve your data with custom generators rather than the recommended API.

The model is logistic regression. The loss function is implemented via the logsumexp trick, for numerical stability and computational efficiency.

## model: Logistic Regression

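import tensorflow as tf

# weights: one column of 28*28 pixel coefficients per digit class
w = tf.Variable(tf.random.normal(shape=(28*28, 10), stddev=0.1), trainable=True)
optimizer = tf.optimizers.SGD(0.01)

@tf.function
def train_step(x, y):
  # x: float32 images flattened to (n_batch, 28*28); y: integer labels of shape (n_batch,)
  with tf.GradientTape() as tape:
    all_logits = tf.matmul(x, w) # (n_batch,n_class)
    y_logits = tf.gather(all_logits, y, batch_dims=1) # (n_batch,)
    logp = y_logits - tf.reduce_logsumexp(all_logits, axis=1)
    loss = -logp # per-example negative log-likelihood, (n_batch,)
  # for a non-scalar target, tape.gradient differentiates the sum of its elements
  gradients = tape.gradient(loss, [w])
  optimizer.apply_gradients(zip(gradients, [w]))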
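
To see what the logsumexp trick buys, here is a minimal plain-Python sketch, independent of TensorFlow. The helper names `logsumexp` and `log_prob` are illustrative and do not come from the notebook. The idea is to factor the largest logit out of the sum before exponentiating, so no intermediate value overflows.

```python
import math

def logsumexp(logits):
    """Stable log(sum(exp(z))): factor out the maximum logit first."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))

def log_prob(logits, label):
    """Log-probability of `label` under softmax(logits)."""
    return logits[label] - logsumexp(logits)

# Large logits: a naive math.exp(1000.0) would raise OverflowError,
# while the shifted version only ever exponentiates values <= 0.
logits = [1000.0, 1001.0, 1002.0]
nll = -log_prob(logits, 2)  # negative log-likelihood of class 2
print(round(nll, 6))
```

The loss in `train_step` has the same shape: the logit of the true class minus the logsumexp over all classes, negated.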

Now we compare two ways of serving the data: by a custom generator and by the Dataset API (with caching and prefetching for better performance).

## serving data: a) via custom generator b) via TF Dataset API

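def gen_batch(X, n_window):
  # returns a generator function that yields consecutive slices of X, n_window rows each
  def gen():
    for i in range(0, len(X), n_window):
      yield X[i:i+n_window]
  return gen

def gen_data():
  # a) custom generator: pairs of (image batch, label batch) sliced from x_train and y_train
  n_window = 32
  return zip(gen_batch(x_train, n_window)(), gen_batch(y_train, n_window)())

# b) Dataset API: batch, cache in memory, and prefetch one batch ahead
tf_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
tf_data = tf_data.batch(32, drop_remainder=True)
tf_data = tf_data.cache()
tf_data = tf_data.prefetch(1)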
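
One detail worth checking in such a comparison is that both pipelines serve the same batches. The slicing generator yields a shorter final batch when the data length is not a multiple of the window, whereas `batch(32, drop_remainder=True)` drops it. The small sketch below copies the generator from above and runs it on a toy list, which is an assumption made purely for illustration.

```python
def gen_batch(X, n_window):
    # same slicing generator as in the snippet above
    def gen():
        for i in range(0, len(X), n_window):
            yield X[i:i + n_window]
    return gen

# Toy data (hypothetical): 10 items in windows of 4 leave a short last batch.
toy = list(range(10))
sizes = [len(batch) for batch in gen_batch(toy, 4)()]

# Number of slices the generator produces for 60,000 rows and a window of 32.
n_mnist_batches = len(range(0, 60000, 32))
print(sizes, n_mnist_batches)
```

Assuming `x_train` is the standard 60,000-image MNIST training split, the length is divisible by 32, so the two pipelines produce the same number of full batches and the timing comparison is like for like.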

Before running the comparison, the graph and the dataset are warmed up by one full pass (so that the graph gets built and the API pipeline is cached). Both approaches fit the same classifier, but the code with the custom generator runs 50% faster. For the full code, see the Jupyter notebook.
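
The measurement procedure described here could be sketched roughly as follows. This is not the notebook's code: `time_pass`, `toy_step` and `toy_batches` are hypothetical names, and the toy stand-ins only exist so that the snippet runs without TensorFlow.

```python
import time

def time_pass(make_batches, step, n_warmup=1, n_runs=3):
    """Best wall-clock time (seconds) of one full pass over the data.

    make_batches: zero-argument callable returning an iterable of (x, y) batches
    step:         callable applied to every batch, e.g. a training step
    """
    # Warm-up passes: let graphs get traced and caches get filled.
    for _ in range(n_warmup):
        for x, y in make_batches():
            step(x, y)
    best = float("inf")
    for _ in range(n_runs):
        start = time.perf_counter()
        for x, y in make_batches():
            step(x, y)
        best = min(best, time.perf_counter() - start)
    return best

# Toy stand-ins (assumptions for illustration only).
calls = []

def toy_step(x, y):
    calls.append((x, y))

def toy_batches():
    return zip([[1, 2], [3, 4]], [[0], [1]])

elapsed = time_pass(toy_batches, toy_step)
print(f"{elapsed:.6f} s")
```

With the definitions from this note, the two candidates would presumably be timed as `time_pass(gen_data, train_step)` and `time_pass(lambda: tf_data, train_step)`. Taking the best of several runs reduces noise from other processes on the machine.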
