Optimizer and Learning Rate Scheduler

The optimizer is at the heart of the gradient descent process and a key component in training a good model. PyTorch Tabular uses the Adam optimizer with a learning rate of 1e-3 by default, a common rule-of-thumb choice that provides a good starting point.

Learning rate schedulers give you finer control over how the learning rate changes over the course of optimization. By default, PyTorch Tabular applies no learning rate scheduler.
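
Spelled out explicitly, these defaults correspond to the configuration below (a minimal sketch; note that the 1e-3 learning rate is not an OptimizerConfig field and is set elsewhere in PyTorch Tabular):

from pytorch_tabular.config import OptimizerConfig

# Equivalent to OptimizerConfig(): Adam with default parameters, no LR scheduler
optimizer_config = OptimizerConfig(
    optimizer="Adam",
    optimizer_params={},
    lr_scheduler=None,
    lr_scheduler_params={},
)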

Basic Usage

  • optimizer: str: Any of the standard optimizers from torch.optim. Defaults to Adam
  • optimizer_params: Dict: The parameters for the optimizer. If left blank, will use default parameters.
  • lr_scheduler: str: The name of the LearningRateScheduler to use, if any, from torch.optim.lr_scheduler. If None, will not use any scheduler. Defaults to None
  • lr_scheduler_params: Dict: The parameters for the LearningRateScheduler. If left blank, will use default parameters.
  • lr_scheduler_monitor_metric: str: Used with ReduceLROnPlateau, where the plateau is decided based on this metric. Defaults to valid_loss (see the second sketch under Usage Example below)

Usage Example

from pytorch_tabular.config import OptimizerConfig

optimizer_config = OptimizerConfig(
    optimizer="RMSprop", lr_scheduler="StepLR", lr_scheduler_params={"step_size": 10}
)
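
ReduceLROnPlateau additionally needs a metric to watch, which is what lr_scheduler_monitor_metric is for. A second sketch, using the same OptimizerConfig import as above (the patience and factor values are purely illustrative):

optimizer_config = OptimizerConfig(
    optimizer="Adam",
    lr_scheduler="ReduceLROnPlateau",
    lr_scheduler_params={"patience": 5, "factor": 0.1},  # standard torch.optim ReduceLROnPlateau arguments
    lr_scheduler_monitor_metric="valid_loss",
)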

Advanced Usage

While the config object restricts you to the standard optimizers and learning rate schedulers in torch.optim, you can use any custom optimizer or learning rate scheduler, as long as it is a drop-in replacement for a standard one. You can do this using the fit method of TabularModel, which allows you to override the optimizer and learning rate that are set through the config.

Usage Example

from torch_optimizer import QHAdam

tabular_model.fit(
    train=train,
    validation=val,
    optimizer=QHAdam,
    optimizer_params={"nus": (0.7, 1.0), "betas": (0.95, 0.998)},
)
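
Note that the optimizer class itself (here QHAdam, not an instance) is passed to fit, while its keyword arguments go in optimizer_params; PyTorch Tabular takes care of instantiating it with the model's parameters.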

pytorch_tabular.config.OptimizerConfig dataclass

Optimizer and Learning Rate Scheduler configuration.

Parameters:

  • optimizer (str): Any of the standard optimizers from torch.optim, or a full python path, for example "torch_optimizer.RAdam". Default: 'Adam'
  • optimizer_params (Dict): The parameters for the optimizer. If left blank, will use default parameters. Default: {}
  • lr_scheduler (Optional[str]): The name of the LearningRateScheduler to use, if any, from torch.optim.lr_scheduler. If None, will not use any scheduler. Default: None
  • lr_scheduler_params (Optional[Dict]): The parameters for the LearningRateScheduler. If left blank, will use default parameters. Default: {}
  • lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is decided based on this metric. Default: 'valid_loss'
Source code in src/pytorch_tabular/config/config.py
@dataclass
class OptimizerConfig:
    """Optimizer and Learning Rate Scheduler configuration.

    Args:
        optimizer (str): Any of the standard optimizers from
                [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,
                for example "torch_optimizer.RAdam".

        optimizer_params (Dict): The parameters for the optimizer. If left blank, will use default
                parameters.

        lr_scheduler (Optional[str]): The name of the LearningRateScheduler to use, if any, from
                [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-
                rate). If None, will not use any scheduler. Defaults to `None`

        lr_scheduler_params (Optional[Dict]): The parameters for the LearningRateScheduler. If left blank,
                will use default parameters.

        lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
                decided based on this metric

    """

    optimizer: str = field(
        default="Adam",
        metadata={
            "help": "Any of the standard optimizers from"
            " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,"
            " for example 'torch_optimizer.RAdam'."
        },
    )
    optimizer_params: Dict = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
    )
    lr_scheduler: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the LearningRateScheduler to use, if any, from"
            " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
            " If None, will not use any scheduler. Defaults to `None`",
        },
    )
    lr_scheduler_params: Optional[Dict] = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
    )

    lr_scheduler_monitor_metric: Optional[str] = field(
        default="valid_loss",
        metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
    )

    @staticmethod
    def read_from_yaml(filename: str = "config/optimizer_config.yml"):
        config = _read_yaml(filename)
        if config["lr_scheduler_params"] is None:
            config["lr_scheduler_params"] = {}
        return OptimizerConfig(**config)
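
read_from_yaml unpacks the parsed YAML straight into the dataclass, so the file should contain the field names documented above. A minimal sketch, assuming a hypothetical config/optimizer_config.yml with the contents shown in the comments:

# config/optimizer_config.yml (hypothetical contents):
#   optimizer: RMSprop
#   optimizer_params: {}
#   lr_scheduler: StepLR
#   lr_scheduler_params:
#     step_size: 10
#   lr_scheduler_monitor_metric: valid_loss

from pytorch_tabular.config import OptimizerConfig

optimizer_config = OptimizerConfig.read_from_yaml("config/optimizer_config.yml")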