TensorFlow Keras: Train Model with a List of Tensors as Input instead of a Single Tensor

When a model has several inputs, or concatenates features partway through its graph, it is often simpler to feed the model a list of tensors, one per input, instead of a single pre-concatenated tensor.

Descriptions for the code below:

  • Inp: the whole dataset, with samples along the first axis
  • Inp1, Inp2: the dataset split by feature group into feature 1 and feature 2
  • Keras adds the batch dimension itself, slicing the data into batches of batch_size samples
  • x: a batch
  • x1, x2: the two parts of a batch, split into feature 1 and feature 2
  • x1a, x1b, x1c: feature 1 of individual samples a, b, c
  • x2a, x2b, x2c: feature 2 of individual samples a, b, c
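The naming above can be checked with NumPy alone, no TensorFlow needed. This is a minimal sketch assuming a hypothetical dataset of 100 samples, where feature 1 has 3 columns and feature 2 has 2 columns, and a batch_size of 32; all sizes are made up for illustration.

```python
import numpy as np

num_samples, d1, d2, batch_size = 100, 3, 2, 32  # hypothetical sizes

# Inp: the whole dataset, one row per sample
Inp = np.arange(num_samples * (d1 + d2), dtype=np.float32).reshape(num_samples, d1 + d2)

# Inp1, Inp2: split by feature group (columns); samples stay on axis 0
Inp1, Inp2 = Inp[:, :d1], Inp[:, d1:]

# x: the first batch, sliced along the sample axis
x = Inp[:batch_size]
# x1, x2: the same batch split into feature 1 and feature 2
x1, x2 = Inp1[:batch_size], Inp2[:batch_size]

# x1a, x2a: feature 1 and feature 2 of a single sample (sample a = row 0)
x1a, x2a = Inp1[0], Inp2[0]

print(Inp1.shape, Inp2.shape, x.shape, x1.shape, x2.shape, x1a.shape)
# (100, 3) (100, 2) (32, 5) (32, 3) (32, 2) (3,)
```

Note that a single sample such as x1a has no batch axis; the batch axis only appears once samples are grouped.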

Note 1: Usage with a single input per sample

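import tensorflow as tf
from tensorflow.keras import Model

class MyModel(Model):
    @tf.function
    def call(self, x):
        ...

model = MyModel()
model.compile(...)

# Inp is an array or tensor whose first axis indexes the samples
# (not a plain Python list, which Keras treats as multiple inputs)
model.evaluate(x=Inp, y=...)
model.fit(x=Inp, y=...)
model.predict(x=Inp)  # Slower than model(Inp) for small inputs; iterates in batches, returns NumPy arrays
model.call(Inp)       # Runs the forward pass directly, bypassing Keras's __call__ logic
model(Inp)            # Preferred for direct calls; wraps model.call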
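How the sample axis gets cut into batches can be sketched in NumPy. This mimics the slicing only; it is not Keras's internal code. It assumes a hypothetical 100-sample array and batch_size = 32, which is the default batch_size of fit, evaluate and predict.

```python
import math
import numpy as np

batch_size = 32                       # Keras's default batch_size
Inp = np.zeros((100, 5), np.float32)  # hypothetical single-input dataset: 100 samples

# Slice the sample axis (axis 0) into consecutive batches; the last one may be smaller
batches = [Inp[i:i + batch_size] for i in range(0, len(Inp), batch_size)]

steps = math.ceil(len(Inp) / batch_size)
print(steps, [b.shape[0] for b in batches])  # 4 [32, 32, 32, 4]
```

Every batch keeps the feature axis intact; only the first axis is split, which is why the model's call receives tensors shaped (batch, features).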

Note 2: Usage with multiple inputs per sample

import tensorflow as tf
from tensorflow.keras import Model

class MyModel(Model):
    @tf.function
    def call(self, xs):
        x1 = xs[0]
        x2 = xs[1]
        ...

model = MyModel()
model.compile(...)

# Inp1 holds feature 1 of every sample, Inp2 holds feature 2 of every sample;
# both have samples along the first axis, in the same order
model.evaluate(x=[Inp1, Inp2], y=...)
model.fit(x=[Inp1, Inp2], y=...)
model.predict(x=[Inp1, Inp2])  # Slower than model([...]) for small inputs
model.call([Inp1, Inp2])       # Bypasses Keras's __call__ logic
model([Inp1, Inp2])            # Preferred for direct calls; wraps model.call

# Build the inputs from individual samples instead of Inp1, Inp2:
# stack each input's samples into one batch tensor
X1 = tf.stack([x1a, x1b, x1c, ...])
X2 = tf.stack([x2a, x2b, x2c, ...])
model.predict(x=[X1, X2])
model.call([X1, X2])
model([X1, X2])
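Stacking individual samples per input gives exactly the batch you would get by slicing Inp1 and Inp2. A minimal sketch with hypothetical data, using np.stack as a stand-in for tf.stack (both stack along a new leading axis):

```python
import numpy as np

# Hypothetical data: 4 samples, feature 1 has 3 values, feature 2 has 2 values
Inp1 = np.arange(12, dtype=np.float32).reshape(4, 3)
Inp2 = np.arange(8, dtype=np.float32).reshape(4, 2) * 10

# Individual samples a, b, c
x1a, x1b, x1c = Inp1[0], Inp1[1], Inp1[2]
x2a, x2b, x2c = Inp2[0], Inp2[1], Inp2[2]

# Stack each input's samples along a new axis 0 -> one batch per input
X1 = np.stack([x1a, x1b, x1c])  # shape (3, 3)
X2 = np.stack([x2a, x2b, x2c])  # shape (3, 2)

# Row i of X1 and row i of X2 belong to the same sample
inputs = [X1, X2]
print([t.shape for t in inputs])  # [(3, 3), (3, 2)]
```

The two inputs may have different feature sizes; they only need the same number of rows, in the same sample order.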

A Keras model that takes the list of inputs and concatenates the two branches in the middle of the graph:

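import tensorflow as tf
from tensorflow.keras import Model

class MyModel(Model):
    @tf.function
    def call(self, xs):
        x1 = xs[0]
        x2 = xs[1]
        ...
        # ha calculated from x1
        # hb calculated from x2
        # For 2-D (batch, features) tensors, axis 0 is the batch axis,
        # so axis=1 joins the feature axes sample by sample
        z = tf.concat([ha, hb], axis=1)
        ...

How to call a Keras model with multiple inputs:

my_model = MyModel(...)
my_model.compile(...)
my_model.fit(...)

# Wrong way: this groups the values by sample instead of by input,
# so x1 would not receive a batch of feature 1
out = my_model([[inp1a, inp2a], [inp1b, inp2b]])

# Correct way: the outer list has one entry per model input (x1, x2),
# and each entry is a batch holding that input for every sample
# inp1a and inp2a are the x1 and x2 of sample A
# inp1b and inp2b are the x1 and x2 of sample B
out = my_model([tf.stack([inp1a, inp1b]), tf.stack([inp2a, inp2b])])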
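The two groupings and the mid-graph concat can be compared with NumPy. A minimal sketch: np.stack and np.concatenate stand in for tf.stack and tf.concat, the sample values are hypothetical, and ha, hb are simply set to x1, x2 in place of real layers.

```python
import numpy as np

# Hypothetical per-sample values: x1 has 3 features, x2 has 2 features
inp1a, inp2a = np.array([1., 2., 3.]), np.array([10., 20.])  # sample A
inp1b, inp2b = np.array([4., 5., 6.]), np.array([40., 50.])  # sample B

# Correct grouping: one batch per input, samples along axis 0
x1 = np.stack([inp1a, inp1b])  # (2, 3): feature 1 of samples A and B
x2 = np.stack([inp2a, inp2b])  # (2, 2): feature 2 of samples A and B

# Wrong grouping, [inp1a, inp2a], mixes feature 1 and feature 2 of one sample;
# here the shapes differ, so np.stack([inp1a, inp2a]) would raise ValueError

# Stand-ins for ha, hb (real code would compute them with layers)
ha, hb = x1, x2
z = np.concatenate([ha, hb], axis=1)  # each row holds one sample's joined features
print(z.shape)  # (2, 5)
```

Concatenating along axis=1 keeps each sample's features on the same row; axis=0 would instead append samples, which is not what the model needs.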


More from 19411