In [None]:
import os
import pickle
import tarfile
import datetime
import numpy as np
import urllib.request
import sklearn.metrics
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# at start-up. Memory growth can only be set before the GPUs are initialised;
# if that already happened, TF raises RuntimeError, which is reported here
# rather than aborting the notebook.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(err)

In [None]:
# Register the %tensorboard magic (used later to view this run's logs inline).
%load_ext tensorboard

In [None]:
# Root folder for all training runs; each run writes into a timestamped
# sub-folder created further down.
HISTORY_DIR = './history'
os.makedirs(name=HISTORY_DIR, exist_ok=True)

In [None]:
def get_data():
    """Download (if needed) and load the CIFAR-10 python-format dataset.

    On first use the archive is fetched from the CIFAR-10 site and extracted
    into ``./cifar-10-batches-py/``; later calls reuse the extracted folder.

    Returns:
        A 5-tuple ``(X_train, y_train, label_names, X_test, y_test)``:
        ``X_train``/``X_test`` are the flat image rows reshaped to
        ``(n, 32, 32, 3)`` (channels-last), ``y_train``/``y_test`` are 1-D
        label arrays, and ``label_names`` is the ``'label_names'`` entry of
        ``batches.meta``.
    """
    data_dir = 'cifar-10-batches-py'

    def _unpickle(path, **kwargs):
        # Open with a context manager so the handle is always closed.
        # NOTE(review): pickle.load can execute arbitrary code; acceptable only
        # because the files come from the official CIFAR-10 archive.
        with open(path, 'rb') as fh:
            return pickle.load(fh, **kwargs)

    def _to_images(flat):
        # Each row stores the three 32x32 colour planes one after another;
        # reshape to (n, 3, 32, 32) and move channels last.
        return np.asarray(flat).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    if not os.path.isdir(data_dir):
        archive = 'cifar-10-python.tar.gz'
        urllib.request.urlretrieve('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', archive)
        with tarfile.open(archive, 'r:gz') as tar:
            # Use the safe 'data' extraction filter when this Python provides
            # it (3.12+ and security backports) to reject path-traversal members.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(filter='data')
            else:
                tar.extractall()

    # encoding='bytes' is needed to unpickle the numpy payload of the batch
    # files; it leaves the dict keys as bytes (b'data', b'labels').
    train_batches = [
        _unpickle(os.path.join(data_dir, f'data_batch_{i}'), encoding='bytes')
        for i in range(1, 6)
    ]
    test_batch = _unpickle(os.path.join(data_dir, 'test_batch'), encoding='bytes')
    meta = _unpickle(os.path.join(data_dir, 'batches.meta'))

    return (
        _to_images(np.concatenate([batch[b'data'] for batch in train_batches], axis=0)),
        np.concatenate([batch[b'labels'] for batch in train_batches], axis=0),
        meta['label_names'],
        _to_images(test_batch[b'data']),
        np.array(test_batch[b'labels']),
    )

In [None]:
def plot_confusion_matrix(model: tf.keras.models.Model, X: np.ndarray, y: np.ndarray, labels: list[str], batch_size: int = 8, output_index: int = 0, **kwargs):
    """Plot a confusion matrix of ``model``'s arg-max predictions on ``X`` vs ``y``.

    Args:
        model: Trained Keras model; its ``predict`` output is assumed to hold
            one score per class on the last axis (argmax is taken there).
        X: Input samples passed to ``model.predict``.
        y: Ground-truth class indices aligned with ``X``.
        labels: Display names for the classes, in class-index order.
        batch_size: Batch size used for prediction.
        output_index: For multi-output models (``predict`` returns a list or
            tuple), which output to evaluate. Ignored for single-output models.
        **kwargs: Forwarded to ``ConfusionMatrixDisplay.from_predictions``
            (e.g. ``ax``, ``normalize``, ``cmap``).
    """
    predictions = model.predict(X, verbose=False, batch_size=batch_size)
    # A single-output model returns one array; indexing it would select a
    # *sample* rather than an output, so only pick an output when there are several.
    if isinstance(predictions, (list, tuple)):
        predictions = predictions[output_index]
    sklearn.metrics.ConfusionMatrixDisplay.from_predictions(
        y,
        np.asarray(predictions).argmax(axis=-1),
        display_labels=labels,
        xticks_rotation='vertical',
        **kwargs
    )

In [None]:
# Load CIFAR-10 (downloaded on first run): training images/labels, class names, test images/labels.
(X, y, labels, X_test, y_test) = get_data()

In [None]:
# TODO: replace `...` with the tf.keras model to train. As written, `model` is
# the Ellipsis placeholder, so model.summary()/compile()/fit() below will fail.
model = ...

In [None]:
# Print a layer-by-layer overview (output shapes, parameter counts) of the model.
model.summary()

In [None]:
# TODO: fill in loss, metrics and optimizer. The labels are presumably integer
# class ids (plot_confusion_matrix compares them to argmax indices), so a sparse
# categorical cross-entropy loss fits a model with one output unit per class.
model.compile(loss=..., metrics=[...], optimizer=...)

In [None]:
# Give this run its own sub-folder, named after its start time (to the second)
# so run folders sort chronologically.
logdir = os.path.join(HISTORY_DIR, f"{datetime.datetime.now():%Y%m%d-%H%M%S}")

In [None]:
# Save the full model whenever the validation loss improves. monitor='val_loss'
# is Keras' default, made explicit here; it requires `fit` to receive
# validation data.
# The '.keras' suffix is mandatory for full-model checkpoints in Keras 3
# (TF >= 2.16), which raises ValueError for an extension-less path; older
# tf.keras releases accept it as well.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(logdir, 'model.keras'),
    monitor='val_loss',
    save_best_only=True
)

In [None]:
# Write TensorBoard event files for this run into its own 'logs' sub-folder.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(logdir, 'logs'))

In [None]:
# Show TensorBoard inline for this run only (IPython expands `$logdir`).
# Point --logdir at HISTORY_DIR instead to compare several runs side by side.
%tensorboard --logdir $logdir

In [None]:
# TODO: replace the `...` placeholders with training inputs/targets, validation
# data, batch size and epoch count. Validation data is required: the checkpoint
# callback keeps the best model according to a validation metric.
# NOTE(review): get_data() returns no separate validation split; if X_test/y_test
# are passed here, best-checkpoint selection is tuned on the test set. Consider
# holding out part of X/y instead.
model.fit(..., validation_data=(...), batch_size=..., epochs=..., callbacks=[model_checkpoint_callback, tensorboard_callback])

In [None]:
# Confusion matrix on the training split. This uses the weights currently held
# by `model` (after the final epoch); ModelCheckpoint only writes the best model
# to disk and does not restore it.
fig, ax = plt.subplots()
plot_confusion_matrix(model, X, y, labels=labels, ax=ax)
ax.set_title('Confusion matrix: training set');

In [None]:
# Confusion matrix on the held-out test split, using the in-memory weights
# (after the final epoch), which are not necessarily the best checkpoint.
fig, ax = plt.subplots()
plot_confusion_matrix(model, X_test, y_test, labels=labels, ax=ax)
ax.set_title('Confusion matrix: test set');