Freak

Freak is a Python package that lets you interact with a running program’s state remotely. You define the state object as a pydantic model and use Freak to expose it over HTTP. It supports nested models, partial updates and data validation, and it uses FastAPI to run the web server.

It is useful for quickly setting up control over long-running programs such as bots or neural network training runs.

Here, Freak is used to create a PyTorch Lightning callback that gracefully interrupts model training. This is particularly useful in a distributed training context.

from typing import Optional

import lightning as L
from freak import Freak
from lightning.pytorch.utilities.distributed import rank_zero_only
from pydantic import BaseModel


class TrainingState(BaseModel):
    """Remotely controllable training state exposed through Freak."""

    # Set to True (e.g. via the Freak HTTP API) to request that training stop.
    should_stop: bool = False


class TrainingStopCallback(L.Callback):
    """
    Callback which stops training when ``self.state.should_stop`` is set to True.

    On global rank 0 the ``state`` object is handed to Freak, so ``should_stop``
    can be changed remotely while training runs. At the end of every training
    epoch the flag is broadcast from rank 0 to all ranks so that every process
    makes the same stop decision.

    Args:
        freak: Freak instance used to serve the state. When omitted, a new
            ``Freak(host="127.0.0.1")`` is created.
    """

    def __init__(self, freak: Optional[Freak] = None):
        super().__init__()
        self.freak = freak if freak is not None else Freak(host="127.0.0.1")
        self.state = TrainingState()

    @rank_zero_only
    def on_train_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        self.freak.control(self.state)  # launch the Freak server in a background thread

    def on_train_epoch_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        # Snapshot the flag as an immutable bool and broadcast *that*. The Freak
        # server may change ``self.state.should_stop`` at any moment; if we read
        # the live object after broadcasting, rank 0 could see True while the
        # other ranks received False, leaving rank 0 alone in the barrier below.
        should_stop = trainer.strategy.broadcast(self.state.should_stop, 0)

        # Leave rank 0's ``self.state`` object in place: it is the object passed
        # to ``freak.control`` (assumes Freak keeps that reference), and some
        # strategies may return a deserialized copy from ``broadcast`` even on
        # the source rank. Other ranks just mirror the synced value.
        if not trainer.is_global_zero:
            self.state.should_stop = should_stop

        if should_stop:  # call the Freak API to set this to True
            # this triggers lightning to stop training; every rank reaches this
            # branch because they all act on the same broadcast value
            trainer.should_stop = True
            trainer.strategy.barrier()

    @rank_zero_only
    def on_train_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        self.freak.stop()