TensorFlow Keras: Train Model with a List of Tensors as Input instead of a Single Tensor
July 4, 2023•311 words
When a model combines several inputs, or concatenates them partway through its computation graph, it is often simpler to feed it a list of tensors than a single pre-concatenated tensor.
Names used in the code below:
- Inp: The whole dataset
- Inp1, Inp2: The dataset split into feature 1 and feature 2
- Keras adds the batch dimension itself, feeding the model batch_size samples at a time
- x: A batch
- x1, x2: The 2 parts of a batch, split into feature 1 and feature 2
- x1a, x1b, x1c: Feature 1 of single samples a, b, c
- x2a, x2b, x2c: Feature 2 of single samples a, b, c
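To make these names concrete, here is a small NumPy sketch. The sizes (6 samples, 3 columns for feature 1, 2 columns for feature 2, batch_size 4) are arbitrary assumptions, and the manual slicing only mimics how Keras cuts each input into batches along the first axis (ignoring shuffling).

```python
import numpy as np

num_samples, d1, d2, batch_size = 6, 3, 2, 4

# Inp: the whole dataset, one row per sample (feature 1 and feature 2 side by side)
Inp = np.arange(num_samples * (d1 + d2), dtype=np.float32).reshape(num_samples, d1 + d2)

# Inp1, Inp2: the same samples, split column-wise into feature 1 and feature 2
Inp1 = Inp[:, :d1]  # shape (6, 3)
Inp2 = Inp[:, d1:]  # shape (6, 2)

# x: one batch, as a list of inputs [x1, x2], each sliced along axis 0
x1 = Inp1[:batch_size]  # shape (4, 3)
x2 = Inp2[:batch_size]  # shape (4, 2)
x = [x1, x2]

# x1a, x2a: feature 1 and feature 2 of a single sample ("a" is row 0 here)
x1a, x2a = Inp1[0], Inp2[0]  # shapes (3,) and (2,)
```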
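Note 1: Usage with a single input per sample
import tensorflow as tf
from tensorflow.keras import Model

class model(Model):
    @tf.function
    def call(self, x):
        ...

my_model = model()
my_model.compile(...)  # Needed before fit and evaluate
# Inp holds all samples stacked along the first axis (an array or tensor;
# a plain Python list would be read as multiple inputs)
my_model.evaluate(x=Inp, y=...)
my_model.fit(x=Inp, y=...)
my_model.predict(x=Inp)  # Slower than my_model(Inp) for small inputs
my_model.call(Inp)
my_model(Inp)  # Preferred: wraps my_model.call (builds the model, converts inputs)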
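A runnable version of Note 1, as a sketch: the class name SingleInputModel, its single Dense layer, and the random data are hypothetical stand-ins for the elided parts. It leaves out @tf.function because Keras already runs call as a graph inside fit, evaluate, and predict by default.

```python
import numpy as np
import tensorflow as tf


class SingleInputModel(tf.keras.Model):
    """Hypothetical single-input model: one Dense layer over each whole sample."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x):
        # x has shape (batch_size, 5): Keras slices Inp along axis 0
        return self.dense(x)


rng = np.random.default_rng(0)
Inp = rng.normal(size=(32, 5)).astype("float32")  # 32 samples, 5 features each
y = rng.normal(size=(32, 1)).astype("float32")

model = SingleInputModel()
model.compile(optimizer="adam", loss="mse")
model.fit(x=Inp, y=y, batch_size=8, epochs=2, verbose=0)
loss = model.evaluate(x=Inp, y=y, verbose=0)  # scalar loss
pred = model.predict(x=Inp, verbose=0)        # NumPy array, computed batch by batch
out = model(Inp)                              # tf.Tensor, one forward pass
```

Here predict returns a NumPy array and loops over batches (32 samples per batch by default), while the direct call returns a tf.Tensor from a single forward pass; the Keras documentation recommends the direct call for small inputs that fit in one batch.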
Note 2: Usage with multiple inputs per sample
class model(Model):
    @tf.function
    def call(self, xs):
        x1 = xs[0]
        x2 = xs[1]
        ...

my_model = model()
my_model.compile(...)  # Needed before fit and evaluate
# Inp1: feature 1 of every sample
# Inp2: feature 2 of every sample, in the same sample order
my_model.evaluate(x=[Inp1, Inp2], y=...)
my_model.fit(x=[Inp1, Inp2], y=...)
my_model.predict(x=[Inp1, Inp2])  # Slower than my_model(...) for small inputs
my_model.call([Inp1, Inp2])
my_model([Inp1, Inp2])  # Preferred: wraps my_model.call
# Use manual lists of samples instead of Inp1, Inp2
my_model.predict(x=[[x1a, x1b, x1c, ...], [x2a, x2b, x2c, ...]])
my_model.call([[x1a, x1b, x1c, ...], [x2a, x2b, x2c, ...]])
my_model([[x1a, x1b, x1c, ...], [x2a, x2b, x2c, ...]])  # Preferred: wraps my_model.call
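Besides NumPy arrays, fit, evaluate, and predict also accept a tf.data.Dataset. A minimal sketch, assuming random float32 data with 3 columns for feature 1 and 2 for feature 2: it builds a dataset whose elements are ((x1, x2), y) and shows how batching groups them.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
Inp1 = rng.normal(size=(10, 3)).astype("float32")  # feature 1 of 10 samples
Inp2 = rng.normal(size=(10, 2)).astype("float32")  # feature 2 of 10 samples
y = rng.normal(size=(10, 1)).astype("float32")

# Each dataset element is ((x1, x2), y) for ONE sample:
# from_tensor_slices slices every array along axis 0.
ds = tf.data.Dataset.from_tensor_slices(((Inp1, Inp2), y))

# batch() stacks samples per component, so each batch is ((x1, x2), y)
# with x1 of shape (4, 3) and x2 of shape (4, 2); the last batch holds 2 samples.
ds = ds.batch(4)

(x1, x2), y_batch = next(iter(ds))
batch_sizes = [int(b[0][0].shape[0]) for b in ds]
```

Such a dataset can be passed as my_model.fit(ds, epochs=...) without y= or batch_size=, since it already yields batched (inputs, targets) pairs; inside call, xs[0] and xs[1] then receive x1 and x2.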
Keras model:
class model(Model):
    @tf.function
    def call(self, xs):
        x1 = xs[0]
        x2 = xs[1]
        ...
        # ha calculated from x1
        # hb calculated from x2
        z = tf.concat([ha, hb], axis=1)  # Axis 0 is the batch axis
        ...
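A runnable sketch of this two-input model, assuming each branch is a small Dense layer. The class name TwoInputModel, the layer sizes, and the random data are hypothetical; only the list input and the tf.concat in the middle of call follow the code above. As before, @tf.function is left out because Keras already runs call as a graph during fit and predict by default.

```python
import numpy as np
import tensorflow as tf


class TwoInputModel(tf.keras.Model):
    """Hypothetical two-branch model: ha from x1, hb from x2, joined by tf.concat."""

    def __init__(self):
        super().__init__()
        self.branch_a = tf.keras.layers.Dense(4, activation="relu")  # ha from x1
        self.branch_b = tf.keras.layers.Dense(4, activation="relu")  # hb from x2
        self.head = tf.keras.layers.Dense(1)

    def call(self, xs):
        x1 = xs[0]                       # (batch_size, 3)
        x2 = xs[1]                       # (batch_size, 2)
        ha = self.branch_a(x1)           # (batch_size, 4)
        hb = self.branch_b(x2)           # (batch_size, 4)
        z = tf.concat([ha, hb], axis=1)  # (batch_size, 8)
        return self.head(z)              # (batch_size, 1)


rng = np.random.default_rng(0)
Inp1 = rng.normal(size=(32, 3)).astype("float32")  # feature 1 of 32 samples
Inp2 = rng.normal(size=(32, 2)).astype("float32")  # feature 2 of 32 samples
y = rng.normal(size=(32, 1)).astype("float32")

my_model = TwoInputModel()
my_model.compile(optimizer="adam", loss="mse")
my_model.fit(x=[Inp1, Inp2], y=y, batch_size=8, epochs=2, verbose=0)

pred = my_model.predict(x=[Inp1, Inp2], verbose=0)  # NumPy array
out = my_model([Inp1, Inp2])                         # tf.Tensor
```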
How to call a Keras model with multiple inputs:
my_model = model(...)
my_model.compile(...)
my_model.fit(...)
# Wrong way to call a Keras model with multiple inputs:
# here xs[0] would be sample A, not feature 1 of every sample
out = my_model([[inp1a, inp2a], [inp1b, inp2b]])
# Call the model the correct way
# The outer list holds one entry per input; each inner list is a batch
# inp1a and inp2a are the x1 and x2 of sample A
# inp1b and inp2b are the x1 and x2 of sample B
out = my_model([[inp1a, inp1b], [inp2a, inp2b]])
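If the data arrives grouped per sample (the wrong-way layout), it can be regrouped per input before the call. A minimal NumPy sketch, assuming each feature is a fixed-length 1-D array; the values are made up.

```python
import numpy as np

# Per-sample grouping (the wrong-way layout): one [x1, x2] pair per sample
inp1a, inp2a = np.array([1.0, 2.0, 3.0]), np.array([10.0, 20.0])
inp1b, inp2b = np.array([4.0, 5.0, 6.0]), np.array([30.0, 40.0])
samples = [[inp1a, inp2a], [inp1b, inp2b]]

# zip(*samples) transposes the two outer list levels: (all x1s), (all x2s);
# np.stack then puts the batch axis in front of each input.
per_input = [np.stack(part) for part in zip(*samples)]
# per_input[0] stacks inp1a, inp1b -> shape (2, 3)
# per_input[1] stacks inp2a, inp2b -> shape (2, 2)
```

The result matches the correct layout and could be passed as my_model(per_input). Because zip(*samples) is a plain-Python transpose, the same line works for any number of inputs per sample.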