
Support output in batch-level methods of callback #330

Closed
innat opened this issue Dec 15, 2022 · 37 comments


innat commented Dec 15, 2022

System information.

TensorFlow version (you are using): 2.9
Are you willing to contribute it (Yes/No): No (It's not clear where to start; it may also require an API design-level discussion.)

Describe the feature and the current behavior/state

Currently, if we want to use a dataset in a callback, we usually pass it to the corresponding callback's constructor. That dataset can be the validation set or any special dataset on which we want to run .evaluate or .predict, etc. The features of the current callback methods are listed here.

Now I've encountered a case where I need the following from the validation set.

  • image_path,
  • image_tensor,
  • target,
  • predicted

I have a .fit call as follows, and I need a callback on the validation data (on_test_batch_end) to save the image tensor as an image file along with the predicted value (here, a mask), and to compute various metrics at the instance level (sample-wise).

model.fit(
    training_dataset, 
    validation_data=valid_dataset, 
    epochs=EPOCHS,
)
  • I am not sure whether the epoch-level or global-level methods need such a modification.
  • We can, of course, run the model on the validation dataset inside the callback and get the true and predicted results, but IMHO that is the wrong approach. Doing so runs the validation twice, once for validation_data in .fit and again in the callback, which looks unnecessary. Also, there can be multiple callbacks that require the same outputs.
  • There can be some hacks to achieve this, but it might be more natural to have those attributes (image_paths, tensor, target, prediction, and so on) available from the callback-level methods.
  • PyTorch-Lightning reference.

Will this change the current API? How?

Currently

on_(train|test|predict)_batch_end(self, batch, logs=None)

Maybe,

on_(train|test|predict)_batch_end(self, output, batch_id, logs=None)

outputs["image_path"],
outputs["image"],
outputs["gt_label"],
outputs["prediction"],

Here the term output (also used in the post title) is inspired by PyTorch-Lightning and is just for demonstration purposes.
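
To make the idea concrete, here is a purely hypothetical sketch of how such a callback could look if this signature existed (the SaveValPredictions name and the outputs keys are illustrative only, not an existing Keras API):

class SaveValPredictions(tf.keras.callbacks.Callback):
    def on_test_batch_end(self, outputs, batch_id, logs=None):
        image = outputs["image"]            # input tensor of this validation batch
        target = outputs["gt_label"]        # ground-truth label / mask
        prediction = outputs["prediction"]  # model output, no extra forward pass needed
        # ...save (image, prediction) pairs to disk, compute sample-wise metrics, etc.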

Who will benefit from this feature?

Keras users.

Contributing

  • Do you want to contribute a PR? (yes/no): no
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):
haifeng-jin (Contributor) commented:

Summary: Want to access model output in callbacks without running the forward pass of the model again.

@innat Would it be better just to use the submodel API for this with Model.test_step()?
Link


innat commented Dec 22, 2022

@innat Would it be better just to use the submodel API for this with Model.test_step()?
Link

IMO, not really. Please note that I raised this ticket in general terms. Some may have a functional model (written for classification, object detection, etc.), and rewriting it as a subclass just for this looks inconvenient and unintuitive. Having the output available from the callback API would be much more reasonable and cleaner.


rchao commented Dec 22, 2022

Hello @innat, it would require a larger-scope change to the Keras training/testing APIs before such support is available, and in the meantime overriding the training or testing step is still the best solution. Whether or not we'll support it by default still needs to be assessed, based on how general such usage is.


innat commented Dec 24, 2022

Agreed that it would require a larger change in the API. Please do the assessment as needed.


rchao commented Jan 12, 2023

Circling back here as I heard back from the team - currently we don't have a plan to provide the predicted output in the batch-level callbacks, since that may further complicate the training flow: the output may be on the remote workers and not available on the chief (and syncing may have a performance penalty). For now, a test_step override is still the best solution for the original request.


innat commented Jan 13, 2023

@rchao
Thanks for the update.

since that may further complicate the training flow as the output may be on the remote workers and not available on chief (and syncing may have performance penalty).

I understand that this may require a major change in the API. But rather than a complication, I think it would be a very useful feature to have, for example, when we need to unpack the dataloader (generator or tf.data API) to run some tf or non-tf computation over the model's output. As for overriding test_step, I tried it once lightly and faced some issues with the eager-mode and graph-mode training approaches (will revisit with details).


rchao commented Jan 19, 2023

@innat thanks for the info. If you're interested, can you show me a possible way of achieving this while not adding much complexity to the library, and not breaking any existing tests?


innat commented Jan 21, 2023

@rchao
Sorry, from your last reply I think I didn't explain myself properly. I didn't mean to suggest that this could be achieved without revisiting the API design logic.


Anyway, as suggested by @haifeng-jin, the best option here could be overriding test_step. Below is something that I tried, but I faced some issues.

import numpy as np
import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_gt = []
        self.val_pred = []
        
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
 
        self.val_gt.append(y)
        self.val_pred.append(y_pred) 
        return {m.name: m.result() for m in self.metrics}

class MyCallback(tf.keras.callbacks.Callback):
    
    def on_epoch_begin(self, epoch, logs=None):
        print('called on begin')
        # reset 
        self.model.val_gt = []
        self.model.val_pred = []
        
    def on_epoch_end(self, epoch, logs=None):
        print('called on end')
        print(self.model.val_gt)
        print(self.model.val_pred)
        print()
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(
    x, y, epochs=3, validation_split=0.1, verbose=2, 
    callbacks=MyCallback()
)
Epoch 1/3
called on begin
29/29 - 1s - loss: 0.3536 - mae: 0.4871 - val_loss: 0.1599 - val_mae: 0.3253
called on end
ListWrapper([<tf.Tensor 'IteratorGetNext:1' shape=(None, 1) dtype=float32>])
ListWrapper([<tf.Tensor 'custom_model_4/dense_4/BiasAdd:0' shape=(None, 1) dtype=float32>])

Epoch 2/3
called on begin
29/29 - 0s - loss: 0.2058 - mae: 0.3681 - val_loss: 0.1534 - val_mae: 0.3148
called on end
ListWrapper([])
ListWrapper([])

Epoch 3/3
called on begin
29/29 - 0s - loss: 0.1969 - mae: 0.3620 - val_loss: 0.1482 - val_mae: 0.3104
called on end
ListWrapper([])
ListWrapper([])

Running this code in graph mode (not eager mode), I face an issue getting the actual values (only symbolic IteratorGetNext tensors get appended). Also, via the callback, empty lists are returned at the end of the later epochs. Any catch?
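
My guess (I may be wrong) is that in graph mode test_step is traced into a tf.function, so the Python-side list.append runs only once at tracing time and records symbolic tensors; after on_epoch_begin replaces the lists, nothing re-runs the appends, hence the empty lists in later epochs. A minimal sketch of that behaviour, assuming nothing beyond plain TensorFlow:

import tensorflow as tf

collected = []

@tf.function
def step(x):
    collected.append(x)  # executes only while tracing, with a symbolic tensor
    return x + 1

step(tf.constant(1.0))
step(tf.constant(2.0))
print(collected)  # a single symbolic tensor, not two eager values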


innat commented Jan 23, 2023

With eager mode it works (but not in graph mode), cc. @rchao @haifeng-jin

class MyCallback(tf.keras.callbacks.Callback):
    
    def on_epoch_begin(self, epoch, logs=None):
        print('called on begin ', epoch)
        self.model.val_gt = []
        self.model.val_pred = []
        
    def on_epoch_end(self, epoch, logs=None):
        print('called on end ', epoch)
        print(len(self.model.val_gt), len(self.model.val_pred))
        
        gt = [i.numpy()[-1] for a in self.model.val_gt for i in a]
        pred = [i.numpy()[-1] for a in self.model.val_pred for i in a]
        
        print(gt)
        print(pred)
        print()

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(
    optimizer="adam", loss="mse", metrics=["mae"], run_eagerly=True
)

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
x_test = np.random.random((10, 32))
y_test = np.random.random((10, 1))

model.fit(
    x, y, 
    epochs=2, 
    validation_data=(x_test, y_test),
    verbose=2, 
    batch_size=4,
    callbacks=MyCallback()
)
Epoch 1/2
called on begin  0
250/250 - 2s - loss: 0.3321 - mae: 0.4534 - val_loss: 0.1830 - val_mae: 0.3323
called on end  0
3 3
[0.0333783, 0.1717369, 0.7665437, 0.8238915, 0.7495797, 0.2540941, 0.8106554, 0.62855566, 0.9265095, 0.6541081]
[0.3376781, 0.19622403, 0.41489872, 0.8612992, 0.27498952, 0.0058316663, 0.5520814, -0.05125352, 0.98022455, -0.2361697]

Epoch 2/2
called on begin  1
250/250 - 3s - loss: 0.2312 - mae: 0.3831 - val_loss: 0.1736 - val_mae: 0.3258
called on end  1
3 3
[0.0333783, 0.1717369, 0.7665437, 0.8238915, 0.7495797, 0.2540941, 0.8106554, 0.62855566, 0.9265095, 0.6541081]
[0.32530463, 0.20046362, 0.40798, 0.7969549, 0.24543351, 0.04316484, 0.49703956, -0.014348835, 0.88799894, -0.18776876]

<keras.callbacks.History at 0x7f1268071450>
gt = [i.numpy()[-1] for a in model.val_gt for i in a]
gt
[0.6062543,
 0.9596564,
 0.80562544,
 0.28538862,
 0.4508967,
 0.4312933,
 0.17705786,
 0.6797946,
 0.68246496,
 0.93645567]

# should match with the last epoch
pred = [i.numpy()[-1] for a in model.val_pred for i in a]
pred
[0.03851793,
 0.43239358,
 0.18904832,
 0.3528205,
 0.16835782,
 0.12055702,
 0.50357527,
 -0.12372895,
 0.05625192,
 0.43881953]


innat commented Jan 23, 2023

@bhack any thoughts on this?


bhack commented Jan 23, 2023

I don't understand what exactly the goal is here. Are you just looking to access the tf.data dataset in the callback, as in https://stackoverflow.com/questions/64128947/how-to-access-tf-data-dataset-within-a-keras-custom-callback ?


innat commented Jan 23, 2023

Not just the input (x, y) but also the prediction (y_pred) from the callback with the given dataset. (The design is mostly inspired by PyTorch-Lightning.)

on_(train|test|predict)_batch_end(self, output, batch_id, logs=None)

outputs["image_path"],
outputs["image"],
outputs["gt_label"],
outputs["prediction"],

cc. @haifeng-jin

Summary: Want to access model output in callbacks without running the forward pass of the model again.

@innat Would it be better just to use the submodel API for this with Model.test_step()? Link


bhack commented Jan 23, 2023

Not just the input (x, y) but also the prediction (y_pred) from callback with given dataset. (

For the output, are you re-proposing the old keras-team/keras#3469?


innat commented Jan 23, 2023

For that issue, I think the layer outputs can be obtained via the self.model attribute. But here I'm expecting to get the model input (x: image_tensor, y: [target_label | target_mask | ...]) and the model's output (y_pred: [predicted_label | predicted_mask | ...]).

on_(train|test|predict)_batch_end(self, output, batch_id, logs=None)

outputs["image_path"],
outputs["image_tensor"],
outputs["target_label"],
outputs["predicted_label"],
...

So, if we run model.evaluate or model.predict, for example, the above information could be retrieved from the callback API (much like Lightning does). (It was suggested that we can use test_step, but I faced issues with graph-mode training.)


bhack commented Jan 23, 2023

Do you have a minimal self-contained colab or gist to reproduce the error?


innat commented Jan 23, 2023

Already mentioned in a previous message, here.

model.compile(
    optimizer="adam", loss="mse", metrics=["mae"], 
    run_eagerly=True # False causes issue
)


innat commented Jan 23, 2023

Not really; in that callback the forward pass is repeated.

model_train_step = model.train_step

def outer_train_step(data):
    # https://github.com/keras-team/keras/blob/v2.7.0/keras/engine/training.py
    # Note: `self` here is assumed to be an object (e.g. a callback) that holds
    # pre-created tf.Variables for x, w, y_true and y_pred.
    x, y_true, w = keras.utils.unpack_x_y_sample_weight(data)

    self.x.assign(x)
    if w is not None:
        self.w.assign(w)
    self.y_true.assign(y_true)

    result = model_train_step(data)

    # the extra forward pass that makes this approach expensive
    y_pred = model(x)
    self.y_pred.assign(y_pred)

    return result


bhack commented Jan 23, 2023

Yes, as claimed, it is expensive. It is one of the many solutions accumulated over time in that reply, as it was quite an old and recurring topic on SO.


None of the other solutions accumulated in that SO reply give efficient access to the input, target, and output at the same time.


innat commented Jan 25, 2023

Instead of Python lists, using tf.Variable works.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import numpy as tnp

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_x = tf.Variable((
            tnp.empty((0, 32), dtype=tf.float32)), shape=[None, 32]
        )
        self.val_gt = tf.Variable(
            tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1]
        )
        self.val_pred = tf.Variable(
            tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1]
        )

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)

        self.val_x.assign(
            tf.concat([self.val_x, x], axis=0)
        )
        self.val_gt.assign(
            tf.concat([self.val_gt, y], axis=0)
        )
        self.val_pred.assign(
            tf.concat([self.val_pred, y_pred], axis=0)
        )
        return {m.name: m.result() for m in self.metrics}
class MyCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        print('called on begin ', epoch)
        self.model.val_x.assign(
            tf.Variable((
                tnp.empty((0, 32), dtype=tf.float32)), shape=[None, 32]
            )
        )
        self.model.val_gt.assign(
            tf.Variable(
                tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1]
            )
        )
        self.model.val_pred.assign(
            tf.Variable(
                tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1]
            )
        )

    def on_epoch_end(self, epoch, logs=None):
        print('called on end ', epoch)
        print(self.model.val_gt.numpy())
        print(self.model.val_pred.numpy())
        print(self.model.val_x.numpy().shape)
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(
    optimizer="adam", loss="mse", metrics=["mae"], run_eagerly=0
)

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
x_test = np.random.random((10, 32))
y_test = np.random.random((10, 1))

model.fit(
    x, 
    y, 
    epochs=5, 
    validation_data=(x_test, y_test),
    verbose=2, 
    batch_size=4,
    callbacks=[
        MyCallback(), 
    ]
)
Epoch 1/5
called on begin  0
250/250 - 1s - loss: 0.3816 - mae: 0.4723 - val_loss: 0.0360 - val_mae: 0.1434
called on end  0
[0.82510984 0.8162336  0.94785255 0.06877889 0.05006607 0.4200096
 0.8379941  0.23999517 0.32227454 0.12522219]
[ 0.81243783  0.96750426  0.9465156   0.5163584  -0.14448965  0.4210961
  1.0099922   0.39414424  0.4558667   0.2907895 ]
(10, 32)

Epoch 2/5
called on begin  1
250/250 - 0s - loss: 0.1617 - mae: 0.3254 - val_loss: 0.0380 - val_mae: 0.1485
called on end  1
[0.82510984 0.8162336  0.94785255 0.06877889 0.05006607 0.4200096
 0.8379941  0.23999517 0.32227454 0.12522219]
[ 0.81986165  0.94598633  0.9096545   0.5447165  -0.08035474  0.44689882
  0.9836677   0.41099367  0.47807842  0.33092263]
(10, 32)

Epoch 3/5
called on begin  2
250/250 - 0s - loss: 0.1455 - mae: 0.3091 - val_loss: 0.0398 - val_mae: 0.1494
called on end  2
[0.82510984 0.8162336  0.94785255 0.06877889 0.05006607 0.4200096
 0.8379941  0.23999517 0.32227454 0.12522219]
[ 0.8137208   0.91400844  0.8605845   0.564705   -0.01779287  0.46524122
  0.94323516  0.41769716  0.48927492  0.36420074]
(10, 32)

Epoch 4/5
called on begin  3
250/250 - 0s - loss: 0.1296 - mae: 0.2932 - val_loss: 0.0290 - val_mae: 0.1234
called on end  3
[0.82510984 0.8162336  0.94785255 0.06877889 0.05006607 0.4200096
 0.8379941  0.23999517 0.32227454 0.12522219]
[ 0.71483564  0.8042053   0.73826283  0.48975974 -0.0262439   0.39364785
  0.8218028   0.34616536  0.39640903  0.30753985]
(10, 32)

Epoch 5/5
called on begin  4
250/250 - 0s - loss: 0.1177 - mae: 0.2815 - val_loss: 0.0360 - val_mae: 0.1352
called on end  4
[0.82510984 0.8162336  0.94785255 0.06877889 0.05006607 0.4200096
 0.8379941  0.23999517 0.32227454 0.12522219]
[0.72407675 0.78460646 0.70396656 0.5215519  0.04386391 0.4255894
 0.7968867  0.3694296  0.42922068 0.35909805]
(10, 32)
<keras.callbacks.History at 0x7f8ad85bad10>


innat commented Jan 29, 2023

I ran the above code on CPU and it gives the target and prediction as coded. But in GPU mode, I set

import tensorflow as tf
tf.config.optimizer.set_jit(True)

And the same code gives the following error

Node: 'AssignVariableOp_2'
Shape of resource Resource-50-at-0x2602e640 cannot be 
changed after initialization: old shape was [0,1], new shape is [4,1] 
(defined @ <ipython-input-6-9bc5a65d609d>:10)

	 [[{{node AssignVariableOp_2}}]]
	 [[cluster_3_1/xla_compile]] [Op:__inference_test_function_7495]

In the above error message, the zero in [0, 1] acts as a placeholder for the sample size, and [4, 1] reflects the batch size. FYI, in the actual code I'm working on, the error doesn't show up; it simply gets stuck in a deadlock at the end of the first epoch.


bhack commented Jan 29, 2023

Yes the XLA bridge requires a constant shape:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/xla_resource.cc#L102:L109
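
A possible constant-shape workaround (untested sketch; the NUM_VAL buffer size, BATCH constant, and index bookkeeping below are assumptions of this example, not code from the thread) is to preallocate fixed-size buffers and scatter each validation batch into them, instead of growing a [None, 1] variable with tf.concat:

NUM_VAL = 10   # assumed: total number of validation samples is known in advance
BATCH = 4

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_pred = tf.Variable(tf.zeros((NUM_VAL, 1)), trainable=False)
        self.batch_idx = tf.Variable(0, dtype=tf.int32, trainable=False)

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        start = self.batch_idx * BATCH
        rows = tf.range(start, start + tf.shape(y_pred)[0])[:, None]
        self.val_pred.scatter_nd_update(rows, y_pred)  # buffer shape stays constant
        self.batch_idx.assign_add(1)                   # reset it in on_epoch_begin
        return {m.name: m.result() for m in self.metrics}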


innat commented Jan 31, 2023

Oh, error after error. 🤕
Now facing an issue with TPU (it works well on GPU without jit compile).

InvalidArgumentError: Dst node should be assigned to an allowed device. Found an edge from node concat_variable_140584752799952_handle_inputs_0/shape/_4 assigned to /job:localhost/replica:0/task:0/device:COMPOSITE:0 to node TPUReplicate/_compile/_4530194686673722251/_7 assigned to /job:worker/replica:0/task:0/device:CPU:0

logs

2023-01-31 17:16:49.309748: W tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc:76] Unable to destroy remote tensor handles. If you are running a tf.function, it usually indicates some op in the graph gets an error: Dst node should be assigned to an allowed device. Found an edge from node concat_variable_140584752799952_handle_inputs_0/shape/_4 assigned to /job:localhost/replica:0/task:0/device:COMPOSITE:0 to node TPUReplicate/_compile/_4530194686673722251/_7 assigned to /job:worker/replica:0/task:0/device:CPU:0
2023-01-31 17:16:49.309773: W ./tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h:57] Ignoring an error encountered when deleting remote tensors handles: Invalid argument: Unable to find the relevant tensor remote_handle: Op ID: 2613, Output num: 1
Additional GRPC error information from remote target /job:worker/replica:0/task:0:
:{"created":"@1675185409.307338558","description":"Error received from peer ipv4:10.0.0.2:8470","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Unable to find the relevant tensor remote_handle: Op ID: 2613, Output num: 1","grpc_status":3}


innat commented Feb 10, 2023

Yes the XLA bridge requires a constant shape:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/xla_resource.cc#L102:L109

@bhack is it a known issue or expected behaviour? Any workaround, for example using tf.TensorArray(.., dynamic_size)?


bhack commented Feb 10, 2023

Mhh.. I don't think so, you could try to check tensorflow/tensorflow#47170


innat commented Feb 10, 2023

Ah, I see. If I compile the model in eager mode, it works with tf.config.optimizer.set_jit(True); otherwise, I get that error.


bhack commented Feb 10, 2023


innat commented Feb 10, 2023

Just tried as follows

import tensorflow as tf
tf.config.optimizer.set_jit(True)
tf.config.experimental.enable_mlir_bridge()

....
InternalError:  Invalid input index for variable write.
	 [[{{node cluster_0_1/xla_run}}]] [Op:__inference_train_function_2987]

Function call stack:
train_function


bhack commented Feb 10, 2023

Is it the same if you invert the order?


innat commented Feb 10, 2023

Tried both. But it runs only with tf.config.experimental.enable_mlir_bridge() alone.

# using both (gives error)
tf.config.optimizer.set_jit(True)
tf.config.experimental.enable_mlir_bridge()

# only (works)
tf.config.experimental.enable_mlir_bridge()

Added a gist.


bhack commented Feb 10, 2023

Is it running with 2.11?


innat commented Feb 10, 2023

Colab crashed (no clue why). I'm using the Kaggle environment, where it is 2.6. That's why I installed it in Colab too.


bhack commented Feb 10, 2023

With nightly the compiler crashed

---> 52 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     53     inputs, attrs, num_outputs)
     54 except core._NotOkStatusException as e:

InternalError: Graph execution error:
RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:626) dnn != nullptr
	 [[{{node cluster_1_1/xla_compile}}]] [Op:__inference_train_function_1394]


innat commented Feb 12, 2023

Probably for the same reason, this approach also doesn't work on TPU (which doesn't support dynamic shapes). tensorflow/tensorflow#59511; kind of a misleading error message too.

cc. @rchao @haifeng-jin


innat commented Feb 26, 2023

I ran the above code on CPU and it gives the target and prediction as coded. But in GPU mode, I set

import tensorflow as tf
tf.config.optimizer.set_jit(True)

And the same code gives the following error

Node: 'AssignVariableOp_2'
Shape of resource Resource-50-at-0x2602e640 cannot be 
changed after initialization: old shape was [0,1], new shape is [4,1] 
(defined @ <ipython-input-6-9bc5a65d609d>:10)

	 [[{{node AssignVariableOp_2}}]]
	 [[cluster_3_1/xla_compile]] [Op:__inference_test_function_7495]

In the above error message, the zero in [0, 1] acts as a placeholder for the sample size, and [4, 1] reflects the batch size. FYI, in the actual code I'm working on, the error doesn't show up; it simply gets stuck in a deadlock at the end of the first epoch.

Update

Creating the variables on the CPU and outside the strategy scope works.

with tf.device('/CPU:0'):
    val_gt = tf.Variable(
        tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1], trainable=False
    )
    val_pred = tf.Variable(
        tnp.empty((0, 1), dtype=tf.float32), shape=[None, 1], trainable=False
    )

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ...
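
A hypothetical continuation of that snippet (my own sketch; it simply mirrors the earlier test_step) would attach the CPU-placed variables to the model and update them as before:

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_gt = val_gt        # the variables created above on /CPU:0
        self.val_pred = val_pred

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        self.val_gt.assign(tf.concat([self.val_gt, y], axis=0))
        self.val_pred.assign(tf.concat([self.val_pred, y_pred], axis=0))
        return {m.name: m.result() for m in self.metrics}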


innat commented Jul 13, 2023

@rchao Hi,
As Keras-core is developing, I was wondering if there is any interest in reconsidering this and also this request as features. They are both available in PyTorch-Lightning. As torch has become a backend of Keras-core, practitioners might expect to have these as well.


rchao commented Jul 13, 2023

Thanks @innat, I would defer to @fchollet to make a call on this.
