
Core Classes

Source code in src/pytorch_tabular/tabular_model.py
class TabularModel:
    def __init__(
        self,
        config: Optional[DictConfig] = None,
        data_config: Optional[Union[DataConfig, str]] = None,
        model_config: Optional[Union[ModelConfig, str]] = None,
        optimizer_config: Optional[Union[OptimizerConfig, str]] = None,
        trainer_config: Optional[Union[TrainerConfig, str]] = None,
        experiment_config: Optional[Union[ExperimentConfig, str]] = None,
        model_callable: Optional[Callable] = None,
        model_state_dict_path: Optional[Union[str, Path]] = None,
        verbose: bool = True,
        suppress_lightning_logger: bool = False,
    ) -> None:
        """The core model which orchestrates everything from initializing the datamodule, the model, trainer, etc.

        Args:
            config (Optional[Union[DictConfig, str]], optional): Single OmegaConf DictConfig object or
                the path to the yaml file holding all the config parameters. Defaults to None.

            data_config (Optional[Union[DataConfig, str]], optional):
                DataConfig object or path to the yaml file. Defaults to None.

            model_config (Optional[Union[ModelConfig, str]], optional):
                A subclass of ModelConfig or path to the yaml file.
                Determines which model to run from the type of config. Defaults to None.

            optimizer_config (Optional[Union[OptimizerConfig, str]], optional):
                OptimizerConfig object or path to the yaml file. Defaults to None.

            trainer_config (Optional[Union[TrainerConfig, str]], optional):
                TrainerConfig object or path to the yaml file. Defaults to None.

            experiment_config (Optional[Union[ExperimentConfig, str]], optional):
                ExperimentConfig object or path to the yaml file.
                If provided, configures the experiment tracking. Defaults to None.

            model_callable (Optional[Callable], optional):
                If provided, will override the model callable that will be loaded from the config.
                Typically used when providing Custom Models

            model_state_dict_path (Optional[Union[str, Path]], optional):
                If provided, will load the state dict after initializing the model from config.

            verbose (bool): Turns the logging on or off. Defaults to True.

            suppress_lightning_logger (bool): If True, will suppress the default logging from PyTorch Lightning.
                Defaults to False.
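
        Example:
            A minimal construction sketch. The imports are from the public
            `pytorch_tabular` API; the column names and the choice of
            `CategoryEmbeddingModelConfig` are illustrative assumptions only:

                >>> from pytorch_tabular import TabularModel
                >>> from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
                >>> from pytorch_tabular.models import CategoryEmbeddingModelConfig
                >>> data_config = DataConfig(
                ...     target=["target"],
                ...     continuous_cols=["num_a", "num_b"],
                ...     categorical_cols=["cat_a"],
                ... )
                >>> tabular_model = TabularModel(
                ...     data_config=data_config,
                ...     model_config=CategoryEmbeddingModelConfig(task="classification"),
                ...     optimizer_config=OptimizerConfig(),
                ...     trainer_config=TrainerConfig(),
                ... )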

        """
        super().__init__()
        if suppress_lightning_logger:
            suppress_lightning_logs()
        self.verbose = verbose
        self.exp_manager = ExperimentRunManager()
        if config is None:
            assert all(c is not None for c in (data_config, model_config, optimizer_config, trainer_config)), (
                "If `config` is None, `data_config`, `model_config`,"
                " `trainer_config`, and `optimizer_config` cannot be None"
            )
            data_config = self._read_parse_config(data_config, DataConfig)
            model_config = self._read_parse_config(model_config, ModelConfig)
            trainer_config = self._read_parse_config(trainer_config, TrainerConfig)
            optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig)
            if model_config.task != "ssl":
                assert data_config.target is not None, (
                    "`target` in data_config should not be None for" f" {model_config.task} task"
                )
            if experiment_config is None:
                if self.verbose:
                    logger.info("Experiment Tracking is turned off")
                self.track_experiment = False
                self.config = OmegaConf.merge(
                    OmegaConf.to_container(data_config),
                    OmegaConf.to_container(model_config),
                    OmegaConf.to_container(trainer_config),
                    OmegaConf.to_container(optimizer_config),
                )
            else:
                experiment_config = self._read_parse_config(experiment_config, ExperimentConfig)
                self.track_experiment = True
                self.config = OmegaConf.merge(
                    OmegaConf.to_container(data_config),
                    OmegaConf.to_container(model_config),
                    OmegaConf.to_container(trainer_config),
                    OmegaConf.to_container(experiment_config),
                    OmegaConf.to_container(optimizer_config),
                )
        else:
            self.config = config
            if hasattr(config, "log_target") and (config.log_target is not None):
                # experiment_config = OmegaConf.structured(experiment_config)
                self.track_experiment = True
            else:
                if self.verbose:
                    logger.info("Experiment Tracking is turned off")
                self.track_experiment = False

        self.run_name, self.uid = self._get_run_name_uid()
        if self.track_experiment:
            self._setup_experiment_tracking()
        else:
            self.logger = None

        self.exp_manager = ExperimentRunManager()
        if model_callable is None:
            self.model_callable = getattr_nested(self.config._module_src, self.config._model_name)
            self.custom_model = False
        else:
            self.model_callable = model_callable
            self.custom_model = True
        self.model_state_dict_path = model_state_dict_path
        self._is_config_updated_with_data = False
        self._run_validation()
        self._is_fitted = False

    @property
    def has_datamodule(self):
        if hasattr(self, "datamodule") and self.datamodule is not None:
            return True
        else:
            return False

    @property
    def has_model(self):
        if hasattr(self, "model") and self.model is not None:
            return True
        else:
            return False

    @property
    def is_fitted(self):
        return self._is_fitted

    @property
    def name(self):
        if self.has_model:
            return self.model.__class__.__name__
        else:
            return self.config._model_name

    @property
    def num_params(self):
        if self.has_model:
            return count_parameters(self.model)

    def _run_validation(self):
        """Validates the Config params and throws errors if something is wrong."""
        if self.config.task == "classification":
            if len(self.config.target) > 1:
                raise NotImplementedError("Multi-Target Classification is not implemented.")
        if self.config.task == "regression":
            if self.config.target_range is not None:
                if (
                    (len(self.config.target_range) != len(self.config.target))
                    or any(len(range_) != 2 for range_ in self.config.target_range)
                    or any(range_[0] > range_[1] for range_ in self.config.target_range)
                ):
                    raise ValueError(
                        "Target Range, if defined, should be a list of tuples of"
                        " length two (min, max). The length of the list should be"
                        " equal to the length of target columns"
                    )
        if self.config.task == "ssl":
            assert not self.config.handle_unknown_categories, (
                "SSL only supports handle_unknown_categories=False. Please set this" " in your DataConfig"
            )
            assert not self.config.handle_missing_values, (
                "SSL only supports handle_missing_values=False. Please set this in" " your DataConfig"
            )

    def _read_parse_config(self, config, cls):
        if isinstance(config, str):
            if os.path.exists(config):
                _config = OmegaConf.load(config)
                if cls == ModelConfig:
                    cls = getattr_nested(_config._module_src, _config._config_name)
                config = cls(
                    **{
                        k: v
                        for k, v in _config.items()
                        if (k in cls.__dataclass_fields__.keys()) and (cls.__dataclass_fields__[k].init)
                    }
                )
            else:
                raise ValueError(f"{config} is not a valid path")
        config = OmegaConf.structured(config)
        return config

    def _get_run_name_uid(self) -> Tuple[str, int]:
        """Gets the name of the experiment and increments version by 1.

        Returns:
            tuple[str, int]: Returns the name and version number

        """
        if hasattr(self.config, "run_name") and self.config.run_name is not None:
            name = self.config.run_name
        elif hasattr(self.config, "checkpoints_name") and self.config.checkpoints_name is not None:
            name = self.config.checkpoints_name
        else:
            name = self.config.task
        uid = self.exp_manager.update_versions(name)
        return name, uid

    def _setup_experiment_tracking(self):
        """Sets up the Experiment Tracking Framework according to the choices made in the ExperimentConfig."""
        if self.config.log_target == "tensorboard":
            self.logger = pl.loggers.TensorBoardLogger(
                name=self.run_name, save_dir=self.config.project_name, version=self.uid
            )
        elif self.config.log_target == "wandb":
            self.logger = pl.loggers.WandbLogger(
                name=f"{self.run_name}_{self.uid}",
                project=self.config.project_name,
                offline=False,
            )
        else:
            raise NotImplementedError(
                f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]"
            )

    def _prepare_callbacks(self, callbacks=None) -> List:
        """Prepares the necessary callbacks for the Trainer based on the configuration.

        Returns:
            List: A list of callbacks

        """
        callbacks = [] if callbacks is None else callbacks
        if self.config.early_stopping is not None:
            early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
                monitor=self.config.early_stopping,
                min_delta=self.config.early_stopping_min_delta,
                patience=self.config.early_stopping_patience,
                mode=self.config.early_stopping_mode,
                **self.config.early_stopping_kwargs,
            )
            callbacks.append(early_stop_callback)
        if self.config.checkpoints:
            ckpt_name = f"{self.run_name}-{self.uid}"
            ckpt_name = ckpt_name.replace(" ", "_") + "_{epoch}-{valid_loss:.2f}"
            model_checkpoint = pl.callbacks.ModelCheckpoint(
                monitor=self.config.checkpoints,
                dirpath=self.config.checkpoints_path,
                filename=ckpt_name,
                save_top_k=self.config.checkpoints_save_top_k,
                mode=self.config.checkpoints_mode,
                every_n_epochs=self.config.checkpoints_every_n_epochs,
                **self.config.checkpoints_kwargs,
            )
            callbacks.append(model_checkpoint)
            self.config.enable_checkpointing = True
        else:
            self.config.enable_checkpointing = False
        if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True):
            callbacks.append(RichProgressBar())
        if self.verbose:
            logger.debug(f"Callbacks used: {callbacks}")
        return callbacks

    def _prepare_trainer(self, callbacks: List, max_epochs: int = None, min_epochs: int = None) -> pl.Trainer:
        """Prepares the Trainer object.

        Args:
            callbacks (List): A list of callbacks to be used
            max_epochs (int, optional): Maximum number of epochs to train for. Defaults to None.
            min_epochs (int, optional): Minimum number of epochs to train for. Defaults to None.

        Returns:
            pl.Trainer: A PyTorch Lightning Trainer object

        """
        if self.verbose:
            logger.info("Preparing the Trainer")
        if max_epochs is not None:
            self.config.max_epochs = max_epochs
        if min_epochs is not None:
            self.config.min_epochs = min_epochs
        # Getting Trainer Arguments from the init signature
        trainer_sig = inspect.signature(pl.Trainer.__init__)
        trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"]
        trainer_args_config = {k: v for k, v in self.config.items() if k in trainer_args}
        # For some weird reason, checkpoint_callback is not appearing in the Trainer vars
        trainer_args_config["enable_checkpointing"] = self.config.enable_checkpointing
        # turn off progress bar if progress_bar=='none'
        trainer_args_config["enable_progress_bar"] = self.config.progress_bar != "none"
        # Adding trainer_kwargs from config to trainer_args
        trainer_args_config.update(self.config.trainer_kwargs)
        if trainer_args_config["devices"] == -1:
            # Setting devices to auto if -1 so that lightning will use all available GPUs/CPUs
            trainer_args_config["devices"] = "auto"
        return pl.Trainer(
            logger=self.logger,
            callbacks=callbacks,
            **trainer_args_config,
        )

    def _check_and_set_target_transform(self, target_transform):
        if target_transform is not None:
            if isinstance(target_transform, Iterable):
                assert len(target_transform) == 2, (
                    "If `target_transform` is a tuple, it should have exactly two"
                    " elements: the forward and backward transformations"
                )
            elif isinstance(target_transform, TransformerMixin):
                pass
            else:
                raise ValueError(
                    "`target_transform` should either be a sklearn Transformer or a" " tuple of callables."
                )
        if self.config.task == "classification" and target_transform is not None:
            logger.warning("For the classification task, target transform is not used. Ignoring the" " parameter")
            target_transform = None
        return target_transform

    def _prepare_for_training(self, model, datamodule, callbacks=None, max_epochs=None, min_epochs=None):
        self.callbacks = self._prepare_callbacks(callbacks)
        self.trainer = self._prepare_trainer(self.callbacks, max_epochs, min_epochs)
        self.model = model
        self.datamodule = datamodule

    @classmethod
    def _load_weights(cls, model, path: Union[str, Path]) -> None:
        """Loads the model weights in the specified directory.

        Args:
            path (str): The path to the file to load the model from

        Returns:
            None

        """
        ckpt = pl_load(path, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt.get("state_dict") or ckpt)

    @classmethod
    def load_model(cls, dir: str, map_location=None, strict=True):
        """Loads a saved model from the directory.

        Args:
            dir (str): The directory where the model was saved, along with the checkpoints
            map_location (Union[Dict[str, str], str, device, int, Callable, None]) : If your checkpoint
                saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map
                to the new setup. The behaviour is the same as in torch.load()
            strict (bool) : Whether to strictly enforce that the keys in checkpoint_path match the keys
                returned by this module's state dict. Default: True.

        Returns:
            TabularModel (TabularModel): The saved TabularModel
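
        Example:
            A sketch; the directory below is a placeholder for wherever the model
            was previously saved:

                >>> loaded_model = TabularModel.load_model("path/to/saved_model_dir")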

        """
        config = OmegaConf.load(os.path.join(dir, "config.yml"))
        datamodule = joblib.load(os.path.join(dir, "datamodule.sav"))
        if (
            hasattr(config, "log_target")
            and (config.log_target is not None)
            and os.path.exists(os.path.join(dir, "exp_logger.sav"))
        ):
            logger = joblib.load(os.path.join(dir, "exp_logger.sav"))
        else:
            logger = None
        if os.path.exists(os.path.join(dir, "callbacks.sav")):
            callbacks = joblib.load(os.path.join(dir, "callbacks.sav"))
            # Excluding Gradient Accumulation Scheduler Callback as we are creating
            # a new one in trainer
            callbacks = [c for c in callbacks if not isinstance(c, GradientAccumulationScheduler)]
        else:
            callbacks = []
        if os.path.exists(os.path.join(dir, "custom_model_callable.sav")):
            model_callable = joblib.load(os.path.join(dir, "custom_model_callable.sav"))
            custom_model = True
        else:
            model_callable = getattr_nested(config._module_src, config._model_name)
            # model_callable = getattr(
            #     getattr(models, config._module_src), config._model_name
            # )
            custom_model = False
        inferred_config = datamodule.update_config(config)
        inferred_config = OmegaConf.structured(inferred_config)
        model_args = {
            "config": config,
            "inferred_config": inferred_config,
        }
        custom_params = joblib.load(os.path.join(dir, "custom_params.sav"))
        if custom_params.get("custom_loss") is not None:
            model_args["loss"] = "MSELoss"  # For compatibility. Not Used
        if custom_params.get("custom_metrics") is not None:
            model_args["metrics"] = ["mean_squared_error"]  # For compatibility. Not Used
            model_args["metrics_params"] = [{}]  # For compatibility. Not Used
            model_args["metrics_prob_inputs"] = [False]  # For compatibility. Not Used
        if custom_params.get("custom_optimizer") is not None:
            model_args["optimizer"] = "Adam"  # For compatibility. Not Used
        if custom_params.get("custom_optimizer_params") is not None:
            model_args["optimizer_params"] = {}  # For compatibility. Not Used

        # Initializing with default metrics, losses, and optimizers. Will revert once initialized
        try:
            model = model_callable.load_from_checkpoint(
                checkpoint_path=os.path.join(dir, "model.ckpt"),
                map_location=map_location,
                strict=strict,
                **model_args,
            )
        except RuntimeError as e:
            if (
                "Unexpected key(s) in state_dict" in str(e)
                and "loss.weight" in str(e)
                and "custom_loss.weight" in str(e)
            ):
                # Custom loss will be loaded after the model is initialized
                # continuing with strict=False
                model = model_callable.load_from_checkpoint(
                    checkpoint_path=os.path.join(dir, "model.ckpt"),
                    map_location=map_location,
                    strict=False,
                    **model_args,
                )
            else:
                raise e
        if custom_params.get("custom_optimizer") is not None:
            model.custom_optimizer = custom_params["custom_optimizer"]
        if custom_params.get("custom_optimizer_params") is not None:
            model.custom_optimizer_params = custom_params["custom_optimizer_params"]
        if custom_params.get("custom_loss") is not None:
            model.loss = custom_params["custom_loss"]
        if custom_params.get("custom_metrics") is not None:
            model.custom_metrics = custom_params.get("custom_metrics")
            model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")]
            model.hparams.metrics_params = [{}]
            model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs")
        model._setup_loss()
        model._setup_metrics()
        tabular_model = cls(config=config, model_callable=model_callable)
        tabular_model.model = model
        tabular_model.custom_model = custom_model
        tabular_model.datamodule = datamodule
        tabular_model.callbacks = callbacks
        tabular_model.trainer = tabular_model._prepare_trainer(callbacks=callbacks)
        # tabular_model.trainer.model = model
        tabular_model.logger = logger
        return tabular_model

    def prepare_dataloader(
        self,
        train: DataFrame,
        validation: Optional[DataFrame] = None,
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        seed: Optional[int] = 42,
        cache_data: str = "memory",
    ) -> TabularDatamodule:
        """Prepares the dataloaders for training and validation.

        Args:
            train (DataFrame): Training Dataframe

            validation (Optional[DataFrame], optional):
                If provided, will use this dataframe as the validation while training.
                Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
                Defaults to None.

            train_sampler (Optional[torch.utils.data.Sampler], optional):
                Custom PyTorch batch samplers which will be passed to the DataLoaders.
                Useful for dealing with imbalanced data and other custom batching strategies

            target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
                If provided, applies the transform to the target before modelling and inverse the transform during
                prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or
                a tuple of callables (transform_func, inverse_transform_func)

            seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.

            cache_data (str): Decides how to cache the data in the dataloader. If set to
                "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
        Returns:
            TabularDatamodule: The prepared datamodule
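
        Example:
            A sketch of the first step of the decomposed (low-level) workflow,
            assuming `tabular_model` is a constructed TabularModel and `train_df`/`val_df`
            are pandas DataFrames with the configured columns:

                >>> datamodule = tabular_model.prepare_dataloader(train=train_df, validation=val_df)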

        """
        if self.verbose:
            logger.info("Preparing the DataLoaders")
        target_transform = self._check_and_set_target_transform(target_transform)

        datamodule = TabularDatamodule(
            train=train,
            validation=validation,
            config=self.config,
            target_transform=target_transform,
            train_sampler=train_sampler,
            seed=seed,
            cache_data=cache_data,
            verbose=self.verbose,
        )
        datamodule.prepare_data()
        datamodule.setup("fit")
        return datamodule

    def prepare_model(
        self,
        datamodule: TabularDatamodule,
        loss: Optional[torch.nn.Module] = None,
        metrics: Optional[List[Callable]] = None,
        metrics_prob_inputs: Optional[List[bool]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optimizer_params: Dict = None,
    ) -> BaseModel:
        """Prepares the model for training.

        Args:
            datamodule (TabularDatamodule): The datamodule

            loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

            metrics (Optional[List[Callable]], optional): Custom metric functions (Callable) which have the
                signature metric_fn(y_hat, y) and work on torch tensor inputs

            metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
                classification metrics. If the metric function requires probabilities as inputs, set this to True.
                The length of the list should be equal to the number of metrics. Defaults to None.

            optimizer (Optional[torch.optim.Optimizer], optional):
                Custom optimizers which are a drop in replacements for standard PyTorch optimizers.
                This should be the Class and not the initialized object

            optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

        Returns:
            BaseModel: The prepared model
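
        Example:
            A sketch continuing the decomposed workflow, assuming `datamodule` was
            returned by `prepare_dataloader`:

                >>> model = tabular_model.prepare_model(datamodule)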

        """
        if self.verbose:
            logger.info(f"Preparing the Model: {self.config._model_name}")
        # Fetching the config as some data specific configs have been added in the datamodule
        self.inferred_config = self._read_parse_config(datamodule.update_config(self.config), InferredConfig)
        model = self.model_callable(
            self.config,
            custom_loss=loss,  # Unused in SSL tasks
            custom_metrics=metrics,  # Unused in SSL tasks
            custom_metrics_prob_inputs=metrics_prob_inputs,  # Unused in SSL tasks
            custom_optimizer=optimizer,
            custom_optimizer_params=optimizer_params or {},
            inferred_config=self.inferred_config,
        )
        # Data Aware Initialization(for the models that need it)
        model.data_aware_initialization(datamodule)
        if self.model_state_dict_path is not None:
            self._load_weights(model, self.model_state_dict_path)
        if self.track_experiment and self.config.log_target == "wandb":
            self.logger.watch(model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq)
        return model

    def train(
        self,
        model: pl.LightningModule,
        datamodule: TabularDatamodule,
        callbacks: Optional[List[pl.Callback]] = None,
        max_epochs: int = None,
        min_epochs: int = None,
        handle_oom: bool = True,
    ) -> pl.Trainer:
        """Trains the model.

        Args:
            model (pl.LightningModule): The PyTorch Lightning model to be trained.

            datamodule (TabularDatamodule): The datamodule

            callbacks (Optional[List[pl.Callback]], optional):
                List of callbacks to be used during training. Defaults to None.

            max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

            min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

            handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

        Returns:
            pl.Trainer: The PyTorch Lightning Trainer instance
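
        Example:
            A sketch of the last step of the decomposed workflow, assuming `model` and
            `datamodule` come from `prepare_model` and `prepare_dataloader`:

                >>> trainer = tabular_model.train(model, datamodule)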

        """
        self._prepare_for_training(model, datamodule, callbacks, max_epochs, min_epochs)
        train_loader, val_loader = (
            self.datamodule.train_dataloader(),
            self.datamodule.val_dataloader(),
        )
        self.model.train()
        if self.config.auto_lr_find and (not self.config.fast_dev_run):
            if self.verbose:
                logger.info("Auto LR Find Started")
            with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
                result = Tuner(self.trainer).lr_find(
                    self.model,
                    train_dataloaders=train_loader,
                    val_dataloaders=val_loader,
                )
            if oom_handler.oom_triggered:
                raise OOMException(
                    "OOM detected during LR Find. Try reducing your batch_size or the"
                    " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
                )
            if self.verbose:
                logger.info(
                    f"Suggested LR: {result.suggestion()}. For plot and detailed"
                    " analysis, use `find_learning_rate` method."
                )
            self.model.reset_weights()
            # Parameters in models need to be initialized again after LR find
            self.model.data_aware_initialization(self.datamodule)
        self.model.train()
        if self.verbose:
            logger.info("Training Started")
        with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
            self.trainer.fit(self.model, train_loader, val_loader)
        if oom_handler.oom_triggered:
            raise OOMException(
                "OOM detected during Training. Try reducing your batch_size or the"
                " model parameters."
                "\n" + "Original Error: " + oom_handler.oom_msg
            )
        self._is_fitted = True
        if self.verbose:
            logger.info("Training the model completed")
        if self.config.load_best:
            self.load_best_model()
        return self.trainer

    def fit(
        self,
        train: Optional[DataFrame],
        validation: Optional[DataFrame] = None,
        loss: Optional[torch.nn.Module] = None,
        metrics: Optional[List[Callable]] = None,
        metrics_prob_inputs: Optional[List[bool]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optimizer_params: Dict = None,
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        seed: Optional[int] = 42,
        callbacks: Optional[List[pl.Callback]] = None,
        datamodule: Optional[TabularDatamodule] = None,
        cache_data: str = "memory",
        handle_oom: bool = True,
    ) -> pl.Trainer:
        """The fit method which takes in the data and triggers the training.

        Args:
            train (DataFrame): Training Dataframe

            validation (Optional[DataFrame], optional):
                If provided, will use this dataframe as the validation while training.
                Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
                Defaults to None.

            loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

            metrics (Optional[List[Callable]], optional): Custom metric functions (Callable) which have the
                signature metric_fn(y_hat, y) and work on torch tensor inputs. y_hat is expected to be of shape
                (batch_size, num_classes) for classification and (batch_size, 1) for regression and y is expected to be
                of shape (batch_size, 1)

            metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
                classification metrics. If the metric function requires probabilities as inputs, set this to True.
                The length of the list should be equal to the number of metrics. Defaults to None.

            optimizer (Optional[torch.optim.Optimizer], optional):
                Custom optimizers which are a drop in replacements for
                standard PyTorch optimizers. This should be the Class and not the initialized object

            optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

            train_sampler (Optional[torch.utils.data.Sampler], optional):
                Custom PyTorch batch samplers which will be passed
                to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

            target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
                If provided, applies the transform to the target before modelling and inverse the transform during
                prediction. The parameter can either be a sklearn Transformer
                which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func)

            max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

            min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

            seed: (int): Random seed for reproducibility. Defaults to 42.

            callbacks (Optional[List[pl.Callback]], optional):
                List of callbacks to be used during training. Defaults to None.

            datamodule (Optional[TabularDatamodule], optional): The datamodule.
                If provided, will ignore the rest of the parameters like train, test etc. and use the datamodule.
                Defaults to None.

            cache_data (str): Decides how to cache the data in the dataloader. If set to
                "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

            handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

        Returns:
            pl.Trainer: The PyTorch Lightning Trainer instance
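
        Example:
            A minimal sketch, assuming `tabular_model` is a constructed TabularModel and
            `train_df`/`val_df` are pandas DataFrames with the configured feature and
            target columns:

                >>> tabular_model.fit(train=train_df, validation=val_df)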

        """
        assert self.config.task != "ssl", (
            "`fit` is not valid for SSL task. Please use `pretrain` for" " self-supervised learning"
        )
        if metrics is not None:
            assert len(metrics) == len(
                metrics_prob_inputs or []
            ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
        seed = seed or self.config.seed
        if seed:
            seed_everything(seed)
        if datamodule is None:
            datamodule = self.prepare_dataloader(
                train,
                validation,
                train_sampler,
                target_transform,
                seed,
                cache_data,
            )
        else:
            if train is not None:
                warnings.warn(
                    "Both train data and datamodule are provided."
                    " Ignoring the train data and using the datamodule."
                    " Set either one of them to None to avoid this warning."
                )
        model = self.prepare_model(
            datamodule,
            loss,
            metrics,
            metrics_prob_inputs,
            optimizer,
            optimizer_params or {},
        )

        return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)

    def pretrain(
        self,
        train: Optional[DataFrame],
        validation: Optional[DataFrame] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optimizer_params: Dict = None,
        # train_sampler: Optional[torch.utils.data.Sampler] = None,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        seed: Optional[int] = 42,
        callbacks: Optional[List[pl.Callback]] = None,
        datamodule: Optional[TabularDatamodule] = None,
        cache_data: str = "memory",
    ) -> pl.Trainer:
        """The pretrain method which takes in the data and triggers the training.

        Args:
            train (DataFrame): Training Dataframe

            validation (Optional[DataFrame], optional): If provided, will use this dataframe as the validation while
                training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
                Defaults to None.

            optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizers which are a drop in replacements
                for standard PyTorch optimizers. This should be the Class and not the initialized object

            optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

            max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

            min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

            seed: (int): Random seed for reproducibility. Defaults to 42.

            callbacks (Optional[List[pl.Callback]], optional): List of callbacks to be used during training.
                Defaults to None.

            datamodule (Optional[TabularDatamodule], optional): The datamodule. If provided, will ignore the rest of the
                parameters like train, test etc. and use the datamodule. Defaults to None.

            cache_data (str): Decides how to cache the data in the dataloader. If set to
                "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
        Returns:
            pl.Trainer: The PyTorch Lightning Trainer instance
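
        Example:
            A sketch, assuming `ssl_model` is a TabularModel configured for the "ssl"
            task and `train_df` is an (unlabeled) pandas DataFrame:

                >>> ssl_model.pretrain(train=train_df)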

        """
        assert self.config.task == "ssl", (
            f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead."
        )
        seed = seed or self.config.seed
        if seed:
            seed_everything(seed)
        if datamodule is None:
            datamodule = self.prepare_dataloader(
                train,
                validation,
                train_sampler=None,
                target_transform=None,
                seed=seed,
                cache_data=cache_data,
            )
        else:
            if train is not None:
                warnings.warn(
                    "Both train data and datamodule are provided."
                    " Ignoring the train data and using the datamodule."
                    " Set either one of them to None to avoid this warning."
                )
        model = self.prepare_model(
            datamodule,
            optimizer=optimizer,
            optimizer_params=optimizer_params or {},
        )

        return self.train(model, datamodule, callbacks, max_epochs, min_epochs)

    def create_finetune_model(
        self,
        task: str,
        head: str,
        head_config: Dict,
        train: DataFrame,
        validation: Optional[DataFrame] = None,
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        target: Optional[str] = None,
        optimizer_config: Optional[OptimizerConfig] = None,
        trainer_config: Optional[TrainerConfig] = None,
        experiment_config: Optional[ExperimentConfig] = None,
        loss: Optional[torch.nn.Module] = None,
        metrics: Optional[List[Union[Callable, str]]] = None,
        metrics_prob_input: Optional[List[bool]] = None,
        metrics_params: Optional[Dict] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optimizer_params: Dict = None,
        learning_rate: Optional[float] = None,
        target_range: Optional[Tuple[float, float]] = None,
        seed: Optional[int] = 42,
    ):
        """Creates a new TabularModel model using the pretrained weights and the new task and head.

        Args:
            task (str): The task to be performed. One of "regression", "classification"

            head (str): The head to be used for the model. Should be one of the heads defined
                in `pytorch_tabular.models.common.heads`. Defaults to LinearHead. Choices are:
                [`None`,`LinearHead`,`MixtureDensityHead`].

            head_config (Dict): The config as a dict which defines the head. If left empty,
                will be initialized as default linear head.

            train (DataFrame): The training data with labels

            validation (Optional[DataFrame], optional): The validation data with labels. Defaults to None.

            train_sampler (Optional[torch.utils.data.Sampler], optional): If provided, will be used as a batch sampler
                for training. Defaults to None.

            target_transform (Optional[Union[TransformerMixin, Tuple]], optional): If provided, will be used
                to transform the target before training and inverse transform the predictions.

            target (Optional[str], optional): The target column name if not provided in the initial pretraining stage.
                Defaults to None.

            optimizer_config (Optional[OptimizerConfig], optional):
                If provided, will redefine the optimizer for fine-tuning stage. Defaults to None.

            trainer_config (Optional[TrainerConfig], optional):
                If provided, will redefine the trainer for fine-tuning stage. Defaults to None.

            experiment_config (Optional[ExperimentConfig], optional):
                If provided, will redefine the experiment for fine-tuning stage. Defaults to None.

            loss (Optional[torch.nn.Module], optional):
                If provided, will be used as the loss function for the fine-tuning.
                By default, it is MSELoss for regression and CrossEntropyLoss for classification.

            metrics (Optional[List[Callable]], optional): List of metrics (either callables or str) to be used for the
                fine-tuning stage. If str, it should be one of the functional metrics implemented in
                ``torchmetrics.functional``. Defaults to None.

            metrics_prob_input (Optional[List[bool]], optional): Is a mandatory parameter for classification metrics
                This defines whether the input to the metric function is the probability or the class.
                Length should be same as the number of metrics. Defaults to None.

            metrics_params (Optional[Dict], optional): The parameters for the metrics in the same order as metrics.
                For eg. f1_score for multi-class needs a parameter `average` to fully define the metric.
                Defaults to None.

            optimizer (Optional[torch.optim.Optimizer], optional):
                Custom optimizers which are a drop in replacements for standard PyTorch optimizers. If provided,
                the OptimizerConfig is ignored in favor of this. Defaults to None.

            optimizer_params (Dict, optional): The parameters for the optimizer. Defaults to {}.

            learning_rate (Optional[float], optional): The learning rate to be used. Defaults to 1e-3.

            target_range (Optional[Tuple[float, float]], optional): The target range for the regression task.
                Is ignored for classification. Defaults to None.

            seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.
        Returns:
            TabularModel (TabularModel): The new TabularModel model for fine-tuning
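
        Example:
            A sketch, assuming `ssl_model` has already been pretrained and
            `labeled_train_df` contains the target column; the head config values are
            illustrative assumptions:

                >>> finetune_model = ssl_model.create_finetune_model(
                ...     task="classification",
                ...     head="LinearHead",
                ...     head_config={"layers": "64-32"},
                ...     train=labeled_train_df,
                ... )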

        """
        config = self.config
        optimizer_params = optimizer_params or {}
        if target is None:
            assert (
                hasattr(config, "target") and config.target is not None
            ), "`target` cannot be None if it was not set in the initial `DataConfig`"
        else:
            assert isinstance(target, list), "`target` should be a list of strings"
            config.target = target
        config.task = task
        # Add code to update configs with newly provided ones
        if optimizer_config is not None:
            for key, value in optimizer_config.__dict__.items():
                config[key] = value
            if len(optimizer_params) > 0:
                config.optimizer_params = optimizer_params
            else:
                config.optimizer_params = {}
        if trainer_config is not None:
            for key, value in trainer_config.__dict__.items():
                config[key] = value
        if experiment_config is not None:
            for key, value in experiment_config.__dict__.items():
                config[key] = value
        else:
            if self.track_experiment:
                # Renaming the experiment run so that a different log is created for finetuning
                if self.verbose:
                    logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
                config["run_name"] = config["run_name"] + "_finetuned"

        datamodule = self.datamodule.copy(
            train=train,
            validation=validation,
            target_transform=target_transform,
            train_sampler=train_sampler,
            seed=seed,
            config_override={"target": target} if target is not None else {},
        )
        model_callable = _GenericModel
        inferred_config = OmegaConf.structured(datamodule._inferred_config)
        # Adding dummy attributes for compatibility. Not used because custom metrics are provided
        if not hasattr(config, "metrics"):
            config.metrics = "dummy"
        if not hasattr(config, "metrics_params"):
            config.metrics_params = {}
        if not hasattr(config, "metrics_prob_input"):
            config.metrics_prob_input = metrics_prob_input or [False]
        if metrics is not None:
            assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same"
            assert len(metrics) == len(metrics_prob_input), "Number of metrics and metrics_prob_input should be same"
            metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics]
        if task == "regression":
            loss = loss or torch.nn.MSELoss()
            if metrics is None:
                metrics = [torchmetrics.functional.mean_squared_error]
                metrics_params = [{}]
        elif task == "classification":
            loss = loss or torch.nn.CrossEntropyLoss()
            if metrics is None:
                metrics = [torchmetrics.functional.accuracy]
                metrics_params = [
                    {
                        "task": "multiclass",
                        "num_classes": inferred_config.output_dim,
                        "top_k": 1,
                    }
                ]
                metrics_prob_input = [False]
            else:
                for i, mp in enumerate(metrics_params):
                    # For classification task, output_dim == number of classes
                    metrics_params[i]["task"] = mp.get("task", "multiclass")
                    metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
                    metrics_params[i]["top_k"] = mp.get("top_k", 1)
        else:
            raise ValueError(f"Task {task} not supported")
        # Forming partial callables using metrics and metric params
        metrics = [partial(m, **mp) for m, mp in zip(metrics, metrics_params)]
        self.model.mode = "finetune"
        if learning_rate is not None:
            config.learning_rate = learning_rate
        config.target_range = target_range
        model_args = {
            "backbone": self.model,
            "head": head,
            "head_config": head_config,
            "config": config,
            "inferred_config": inferred_config,
            "custom_loss": loss,
            "custom_metrics": metrics,
            "custom_metrics_prob_inputs": metrics_prob_input,
            "custom_optimizer": optimizer,
            "custom_optimizer_params": optimizer_params,
        }
        # Initializing with default metrics, losses, and optimizers. Will revert once initialized
        model = model_callable(
            **model_args,
        )
        tabular_model = TabularModel(config=config, verbose=self.verbose)
        tabular_model.model = model
        tabular_model.datamodule = datamodule
        # Setting a flag to identify this as a fine-tune model
        tabular_model._is_finetune_model = True
        return tabular_model

    def finetune(
        self,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        callbacks: Optional[List[pl.Callback]] = None,
        freeze_backbone: bool = False,
    ) -> pl.Trainer:
        """Finetunes the model on the provided data.

        Args:
            max_epochs (Optional[int], optional): The maximum number of epochs to train for. Defaults to None.

            min_epochs (Optional[int], optional): The minimum number of epochs to train for. Defaults to None.

            callbacks (Optional[List[pl.Callback]], optional): If provided, will be added to the callbacks for Trainer.
                Defaults to None.

            freeze_backbone (bool, optional): If True, will freeze the backbone by turning off gradients.
                Defaults to False, which means the pretrained weights are also further tuned during fine-tuning.

        Returns:
            pl.Trainer: The trainer object

        """
        assert self._is_finetune_model, (
            "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
        )
        seed_everything(self.config.seed)
        if freeze_backbone:
            for param in self.model.backbone.parameters():
                param.requires_grad = False
        return self.train(
            self.model,
            self.datamodule,
            callbacks=callbacks,
            max_epochs=max_epochs,
            min_epochs=min_epochs,
        )
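
    # A minimal usage sketch (not part of the library source). Assumes `pretrained_model` is an
    # already-trained TabularModel and `train_df`/`val_df` are new dataframes containing the new
    # target column "label"; all of these names are placeholders.
    # >>> finetune_model = pretrained_model.create_finetune_model(
    # ...     task="classification",
    # ...     train=train_df,
    # ...     validation=val_df,
    # ...     target=["label"],
    # ... )
    # >>> finetune_model.finetune(max_epochs=10, freeze_backbone=True)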

    def find_learning_rate(
        self,
        model: pl.LightningModule,
        datamodule: TabularDatamodule,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: Optional[float] = 4.0,
        plot: bool = True,
        callbacks: Optional[List] = None,
    ) -> Tuple[float, DataFrame]:
        """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
        picking a good starting learning rate.

        Args:
            model (pl.LightningModule): The PyTorch Lightning model to be trained.

            datamodule (TabularDatamodule): The datamodule

            min_lr (Optional[float], optional): minimum learning rate to investigate

            max_lr (Optional[float], optional): maximum learning rate to investigate

            num_training (Optional[int], optional): number of learning rates to test

            mode (Optional[str], optional): search strategy, either 'linear' or 'exponential'. If set to
                'linear' the learning rate will be searched by linearly increasing
                after each batch. If set to 'exponential', will increase learning
                rate exponentially.

            early_stop_threshold (Optional[float], optional): threshold for stopping the search. If the
                loss at any point is larger than early_stop_threshold*best_loss
                then the search is stopped. To disable, set to None.

            plot (bool, optional): If true, will plot using matplotlib

            callbacks (Optional[List], optional): If provided, will be added to the callbacks for Trainer.

        Returns:
            The suggested learning rate and the learning rate finder results

        """
        self._prepare_for_training(model, datamodule, callbacks, max_epochs=None, min_epochs=None)
        train_loader, _ = datamodule.train_dataloader(), datamodule.val_dataloader()
        lr_finder = Tuner(self.trainer).lr_find(
            model=self.model,
            train_dataloaders=train_loader,
            val_dataloaders=None,
            min_lr=min_lr,
            max_lr=max_lr,
            num_training=num_training,
            mode=mode,
            early_stop_threshold=early_stop_threshold,
        )
        if plot:
            fig = lr_finder.plot(suggest=True)
            fig.show()
        new_lr = lr_finder.suggestion()
        # Resetting the model, trainer, datamodule, and callbacks that were loaded
        self.model = None
        self.trainer = None
        self.datamodule = None
        self.callbacks = None
        return new_lr, DataFrame(lr_finder.results)
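
    # A minimal usage sketch (not part of the library source) of the low-level API around
    # `find_learning_rate`. `tabular_model`, `train_df`, and `val_df` are placeholder names;
    # `prepare_dataloader` and `prepare_model` are the same helpers used by `cross_validate` below.
    # >>> dm = tabular_model.prepare_dataloader(train=train_df, validation=val_df, seed=42)
    # >>> model = tabular_model.prepare_model(dm)
    # >>> suggested_lr, lr_results = tabular_model.find_learning_rate(model, dm, plot=False)
    # >>> suggested_lr  # use this to set the learning rate before fitting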

    def evaluate(
        self,
        test: Optional[DataFrame] = None,
        test_loader: Optional[torch.utils.data.DataLoader] = None,
        ckpt_path: Optional[Union[str, Path]] = None,
        verbose: bool = True,
    ) -> Union[dict, list]:
        """Evaluates the dataframe using the loss and metrics already set in config.

        Args:
            test (Optional[DataFrame]): The dataframe to be evaluated. If not provided, will try to use the
                test provided during fit. If that was also not provided will return an empty dictionary

            test_loader (Optional[torch.utils.data.DataLoader], optional): The dataloader to be used for evaluation.
                If provided, will use the dataloader instead of the test dataframe or the test data provided during fit.
                Defaults to None.

            ckpt_path (Optional[Union[str, Path]], optional): The path to the checkpoint to be loaded. If not provided,
                will try to use the best checkpoint during training.

            verbose (bool, optional): If true, will print the results. Defaults to True.
        Returns:
            The final test result dictionary.

        """
        assert not (test_loader is None and test is None), (
            "Either `test_loader` or `test` should be provided."
            " If `test_loader` is not provided, `test` should be provided."
        )
        if test_loader is None:
            test_loader = self.datamodule.prepare_inference_dataloader(test)
        result = self.trainer.test(
            model=self.model,
            dataloaders=test_loader,
            ckpt_path=ckpt_path,
            verbose=verbose,
        )
        return result
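
    # A minimal usage sketch (not part of the library source). `tabular_model` is a fitted model
    # and `test_df` is an assumed hold-out dataframe with the same columns used during training.
    # >>> result = tabular_model.evaluate(test=test_df, verbose=True)
    # >>> result[0]  # dict of test metrics, e.g. {"test_loss": ..., "test_accuracy": ...}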

    def _generate_predictions(
        self,
        model,
        inference_dataloader,
        quantiles,
        n_samples,
        ret_logits,
        progress_bar,
        is_probabilistic,
    ):
        point_predictions = []
        quantile_predictions = []
        logits_predictions = defaultdict(list)
        for batch in progress_bar(inference_dataloader):
            for k, v in batch.items():
                if isinstance(v, list) and (len(v) == 0):
                    continue  # Skipping empty list
                batch[k] = v.to(model.device)
            if is_probabilistic:
                samples, ret_value = model.sample(batch, n_samples, ret_model_output=True)
                y_hat = torch.mean(samples, dim=-1)
                quantile_preds = []
                for q in quantiles:
                    quantile_preds.append(torch.quantile(samples, q=q, dim=-1).unsqueeze(1))
            else:
                y_hat, ret_value = model.predict(batch, ret_model_output=True)
            if ret_logits:
                for k, v in ret_value.items():
                    logits_predictions[k].append(v.detach().cpu())
            point_predictions.append(y_hat.detach().cpu())
            if is_probabilistic:
                quantile_predictions.append(torch.cat(quantile_preds, dim=-1).detach().cpu())
        point_predictions = torch.cat(point_predictions, dim=0)
        if point_predictions.ndim == 1:
            point_predictions = point_predictions.unsqueeze(-1)
        if is_probabilistic:
            quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1)
            if quantile_predictions.ndim == 2:
                quantile_predictions = quantile_predictions.unsqueeze(-1)
        return point_predictions, quantile_predictions, logits_predictions

    def _format_predicitons(
        self,
        test,
        point_predictions,
        quantile_predictions,
        logits_predictions,
        quantiles,
        ret_logits,
        include_input_features,
        is_probabilistic,
    ):
        pred_df = test.copy() if include_input_features else DataFrame(index=test.index)
        if self.config.task == "regression":
            point_predictions = point_predictions.numpy()
            # Probabilistic Models are only implemented for Regression
            if is_probabilistic:
                quantile_predictions = quantile_predictions.numpy()
            for i, target_col in enumerate(self.config.target):
                if self.datamodule.do_target_transform:
                    if self.config.target[i] in pred_df.columns:
                        pred_df[self.config.target[i]] = self.datamodule.target_transforms[i].inverse_transform(
                            pred_df[self.config.target[i]].values.reshape(-1, 1)
                        )
                    pred_df[f"{target_col}_prediction"] = self.datamodule.target_transforms[i].inverse_transform(
                        point_predictions[:, i].reshape(-1, 1)
                    )
                    if is_probabilistic:
                        for j, q in enumerate(quantiles):
                            col_ = f"{target_col}_q{int(q*100)}"
                            pred_df[col_] = self.datamodule.target_transforms[i].inverse_transform(
                                quantile_predictions[:, j, i].reshape(-1, 1)
                            )
                else:
                    pred_df[f"{target_col}_prediction"] = point_predictions[:, i]
                    if is_probabilistic:
                        for j, q in enumerate(quantiles):
                            pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)

        elif self.config.task == "classification":
            point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy()
            for i, class_ in enumerate(self.datamodule.label_encoder.classes_):
                pred_df[f"{class_}_probability"] = point_predictions[:, i]
            pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
                np.argmax(point_predictions, axis=1)
            )
            warnings.warn(
                "Classification prediction column will be renamed to"
                " `{target_col}_prediction` in the next release to maintain"
                " consistency with regression.",
                DeprecationWarning,
            )
        if ret_logits:
            for k, v in logits_predictions.items():
                v = torch.cat(v, dim=0).numpy()
                if v.ndim == 1:
                    v = v.reshape(-1, 1)
                for i in range(v.shape[-1]):
                    if v.shape[-1] > 1:
                        pred_df[f"{k}_{i}"] = v[:, i]
                    else:
                        pred_df[f"{k}"] = v[:, i]
        return pred_df

    def _predict(
        self,
        test: DataFrame,
        quantiles: Optional[List] = [0.25, 0.5, 0.75],
        n_samples: Optional[int] = 100,
        ret_logits=False,
        include_input_features: bool = False,
        device: Optional[torch.device] = None,
        progress_bar: Optional[str] = None,
    ) -> DataFrame:
        """Uses the trained model to predict on new data and return as a dataframe.

        Args:
            test (DataFrame): The new dataframe with the features defined during training
            quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
                the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
                For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
            n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
                Ignored for non-probabilistic models. Defaults to 100
            ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
                with the dataframe. Defaults to False
            include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
                Defaults to False
            progress_bar: choose progress bar for tracking the progress. "rich" or "tqdm" will set the respective
                progress bars. If None, no progress bar will be shown.

        Returns:
            DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
                If classification, it returns probabilities and final prediction

        """
        assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
        model = self.model  # default
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
            if self.model.device != device:
                model = self.model.to(device)
        model.eval()
        inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
        is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic

        if progress_bar == "rich":
            from rich.progress import track

            progress_bar = partial(track, description="Generating Predictions...")
        elif progress_bar == "tqdm":
            from tqdm.auto import tqdm

            progress_bar = partial(tqdm, desc="Generating Predictions...")
        else:
            progress_bar = lambda it: it  # E731
        point_predictions, quantile_predictions, logits_predictions = self._generate_predictions(
            model,
            inference_dataloader,
            quantiles,
            n_samples,
            ret_logits,
            progress_bar,
            is_probabilistic,
        )
        pred_df = self._format_predicitons(
            test,
            point_predictions,
            quantile_predictions,
            logits_predictions,
            quantiles,
            ret_logits,
            include_input_features,
            is_probabilistic,
        )
        return pred_df

    def predict(
        self,
        test: DataFrame,
        quantiles: Optional[List] = [0.25, 0.5, 0.75],
        n_samples: Optional[int] = 100,
        ret_logits=False,
        include_input_features: bool = False,
        device: Optional[torch.device] = None,
        progress_bar: Optional[str] = None,
        test_time_augmentation: Optional[bool] = False,
        num_tta: Optional[float] = 5,
        alpha_tta: Optional[float] = 0.1,
        aggregate_tta: Optional[str] = "mean",
        tta_seed: Optional[int] = 42,
    ) -> DataFrame:
        """Uses the trained model to predict on new data and return as a dataframe.

        Args:
            test (DataFrame): The new dataframe with the features defined during training

            quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
                the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
                For other models it is ignored. Defaults to [0.25, 0.5, 0.75]

            n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
                Ignored for non-probabilistic models. Defaults to 100

            ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
                with the dataframe. Defaults to False

            include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
                Defaults to False

            progress_bar: choose progress bar for tracking the progress. "rich" or "tqdm" will set the respective
                progress bars. If None, no progress bar will be shown.

            test_time_augmentation (bool): If True, will use test time augmentation to generate predictions.
                The approach is very similar to what is described [here](https://kozodoi.me/blog/20210908/tta-tabular),
                but we add noise to the embedded inputs to handle categorical features as well.
                \\(x_{aug} = x_{orig} + \\alpha * \\epsilon\\) where \\(\\epsilon \\sim \\mathcal{N}(0, 1)\\).
                Defaults to False

            num_tta (float): The number of augmentations to run TTA for. Defaults to 5

            alpha_tta (float): The standard deviation of the gaussian noise to be added to the input features

            aggregate_tta (Union[str, Callable], optional): The function to be used to aggregate the
                predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max"
                for regression. For classification, the previous options are applied to the confidence
                scores (soft voting) and then converted to final prediction. An additional option
                "hard_voting" is available for classification.
                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

            tta_seed (int): The random seed to be used for the noise added in TTA. Defaults to 42.

        Returns:
            DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
                If classification, it returns probabilities and final prediction

        """
        warnings.warn(
            "`include_input_features` will be deprecated in the next release."
            " Please add index columns to the test dataframe if you want to"
            " retain some features like the key or id",
            DeprecationWarning,
        )
        if test_time_augmentation:
            assert num_tta > 0, "num_tta should be greater than 0"
            assert alpha_tta > 0, "alpha_tta should be greater than 0"
            assert include_input_features is False, "include_input_features cannot be True for TTA."
            if not callable(aggregate_tta):
                assert aggregate_tta in [
                    "mean",
                    "median",
                    "min",
                    "max",
                    "hard_voting",
                ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
            if self.config.task == "regression":
                assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"

            torch.manual_seed(tta_seed)

            def add_noise(module, input, output):
                return output + alpha_tta * torch.randn_like(output, memory_format=torch.contiguous_format)

            # Register the hook to the embedding_layer
            handle = self.model.embedding_layer.register_forward_hook(add_noise)
            pred_prob_l = []
            for _ in range(num_tta):
                pred_df = self._predict(
                    test,
                    quantiles,
                    n_samples,
                    ret_logits,
                    include_input_features=False,
                    device=device,
                    progress_bar=progress_bar or "None",
                )
                pred_idx = pred_df.index
                if self.config.task == "classification":
                    pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
                elif self.config.task == "regression":
                    pred_prob_l.append(pred_df.values)
            pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
            # Remove the hook
            handle.remove()
        else:
            pred_df = self._predict(
                test,
                quantiles,
                n_samples,
                ret_logits,
                include_input_features,
                device,
                progress_bar,
            )
        return pred_df
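
    # A minimal usage sketch (not part of the library source) of prediction with test-time
    # augmentation. `tabular_model` and `test_df` are placeholder names.
    # >>> pred_df = tabular_model.predict(
    # ...     test_df,
    # ...     test_time_augmentation=True,
    # ...     num_tta=5,
    # ...     alpha_tta=0.1,
    # ...     aggregate_tta="mean",
    # ...     tta_seed=42,
    # ... )
    # For classification, pred_df has `<class>_probability` columns and a `prediction` column.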

    def load_best_model(self) -> None:
        """Loads the best model after training is done."""
        if self.trainer.checkpoint_callback is not None:
            if self.verbose:
                logger.info("Loading the best model")
            ckpt_path = self.trainer.checkpoint_callback.best_model_path
            if ckpt_path != "":
                if self.verbose:
                    logger.debug(f"Model Checkpoint: {ckpt_path}")
                ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
                self.model.load_state_dict(ckpt["state_dict"])
            else:
                logger.warning("No best model available to load. Did you run it more than 1" " epoch?...")
        else:
            logger.warning(
                "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work"
            )

    def save_datamodule(self, dir: str, inference_only: bool = False) -> None:
        """Saves the datamodule in the specified directory.

        Args:
            dir (str): The path to the directory to save the datamodule
            inference_only (bool): If True, will only save the inference datamodule
                without data. This cannot be used for further training, but can be
                used for inference. Defaults to False.

        """
        if inference_only:
            dm = self.datamodule.inference_only_copy()
        else:
            dm = self.datamodule

        joblib.dump(dm, os.path.join(dir, "datamodule.sav"))

    def save_config(self, dir: str) -> None:
        """Saves the config in the specified directory."""
        with open(os.path.join(dir, "config.yml"), "w") as fp:
            OmegaConf.save(self.config, fp, resolve=True)

    def save_model(self, dir: str, inference_only: bool = False) -> None:
        """Saves the model and checkpoints in the specified directory.

        Args:
            dir (str): The path to the directory to save the model
            inference_only (bool): If True, will only save the inference
                only version of the datamodule

        """
        if os.path.exists(dir) and (os.listdir(dir)):
            logger.warning("Directory is not empty. Overwriting the contents.")
            for f in os.listdir(dir):
                os.remove(os.path.join(dir, f))
        os.makedirs(dir, exist_ok=True)
        self.save_config(dir)
        self.save_datamodule(dir, inference_only=inference_only)
        if hasattr(self.config, "log_target") and self.config.log_target is not None:
            joblib.dump(self.logger, os.path.join(dir, "exp_logger.sav"))
        if hasattr(self, "callbacks"):
            joblib.dump(self.callbacks, os.path.join(dir, "callbacks.sav"))
        self.trainer.save_checkpoint(os.path.join(dir, "model.ckpt"))
        custom_params = {}
        custom_params["custom_loss"] = getattr(self.model, "custom_loss", None)
        custom_params["custom_metrics"] = getattr(self.model, "custom_metrics", None)
        custom_params["custom_metrics_prob_inputs"] = getattr(self.model, "custom_metrics_prob_inputs", None)
        custom_params["custom_optimizer"] = getattr(self.model, "custom_optimizer", None)
        custom_params["custom_optimizer_params"] = getattr(self.model, "custom_optimizer_params", None)
        joblib.dump(custom_params, os.path.join(dir, "custom_params.sav"))
        if self.custom_model:
            joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav"))
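
    # A minimal usage sketch (not part of the library source): persist the config, datamodule,
    # checkpoint, and custom objects to a directory. The path is a placeholder.
    # >>> tabular_model.save_model("saved_models/my_model", inference_only=True)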

    def save_weights(self, path: Union[str, Path]) -> None:
        """Saves the model weights in the specified directory.

        Args:
            path (str): The path to the file to save the model

        """
        torch.save(self.model.state_dict(), path)

    def load_weights(self, path: Union[str, Path]) -> None:
        """Loads the model weights in the specified directory.

        Args:
            path (str): The path to the file to load the model from

        """
        self._load_weights(self.model, path)

    # TODO Need to test ONNX export
    def save_model_for_inference(
        self,
        path: Union[str, Path],
        kind: str = "pytorch",
        onnx_export_params: Dict = {"opset_version": 12},
    ) -> bool:
        """Saves the model for inference.

        Args:
            path (Union[str, Path]): path to save the model
            kind (str): "pytorch" or "onnx" (Experimental)
            onnx_export_params (Dict): parameters for onnx export to be
                passed to torch.onnx.export

        Returns:
            bool: True if the model was saved successfully

        """
        if kind == "pytorch":
            torch.save(self.model, str(path))
            return True
        elif kind == "onnx":
            # Export the model
            onnx_export_params["input_names"] = ["categorical", "continuous"]
            onnx_export_params["output_names"] = onnx_export_params.get("output_names", ["output"])
            onnx_export_params["dynamic_axes"] = {
                onnx_export_params["input_names"][0]: {0: "batch_size"},
                onnx_export_params["output_names"][0]: {0: "batch_size"},
            }
            cat = torch.zeros(
                self.config.batch_size,
                len(self.config.categorical_cols),
                dtype=torch.int,
            )
            cont = torch.randn(
                self.config.batch_size,
                len(self.config.continuous_cols),
                requires_grad=True,
            )
            x = {"continuous": cont, "categorical": cat}
            torch.onnx.export(self.model, x, str(path), **onnx_export_params)
            return True
        else:
            raise ValueError("`kind` must be either pytorch or onnx")
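
    # A minimal usage sketch (not part of the library source) of the experimental ONNX export.
    # The output path and opset version are assumptions.
    # >>> tabular_model.save_model_for_inference(
    # ...     "model.onnx", kind="onnx", onnx_export_params={"opset_version": 13}
    # ... )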

    def summary(self, model=None, max_depth: int = -1) -> None:
        """Prints a summary of the model.

        Args:
            max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
                Defaults to -1, which means will display all the modules.

        """
        if model is not None:
            print(summarize(model, max_depth=max_depth))
        elif self.has_model:
            print(summarize(self.model, max_depth=max_depth))
        else:
            rich_print(f"[bold green]{self.__class__.__name__}[/bold green]")
            rich_print("-" * 100)
            rich_print("[bold yellow]Config[/bold yellow]")
            rich_print("-" * 100)
            pprint(self.config.__dict__["_content"])
            rich_print(
                ":triangular_flag:[bold red]Full Model Summary once model has "
                "been initialized or passed in as an argument[/bold red]"
            )

    def __str__(self) -> str:
        return self.summary()

    def feature_importance(self) -> DataFrame:
        """Returns the feature importance of the model as a pandas DataFrame."""
        return self.model.feature_importance()

    def _prepare_input_for_captum(self, test_dl: torch.utils.data.DataLoader) -> Dict:
        tensor_inp = []
        tensor_tgt = []
        for x in test_dl:
            tensor_inp.append(self.model.embed_input(x))
            tensor_tgt.append(x["target"].squeeze(1))
        tensor_inp = torch.cat(tensor_inp, dim=0)
        tensor_tgt = torch.cat(tensor_tgt, dim=0)
        return tensor_inp, tensor_tgt

    def _prepare_baselines_captum(
        self,
        baselines: Union[float, torch.tensor, str],
        test_dl: torch.utils.data.DataLoader,
        do_baselines: bool,
        is_full_baselines: bool,
    ):
        if do_baselines and baselines is not None and isinstance(baselines, str):
            if baselines.startswith("b|"):
                num_samples = int(baselines.split("|")[1])
                tensor_inp_tr = []
                # tensor_tgt_tr = []
                count = 0
                for x in self.datamodule.train_dataloader():
                    tensor_inp_tr.append(self.model.embed_input(x))
                    # tensor_tgt_tr.append(x["target"])
                    count += x["target"].shape[0]
                    if count >= num_samples:
                        break
                tensor_inp_tr = torch.cat(tensor_inp_tr, dim=0)
                # tensor_tgt_tr = torch.cat(tensor_tgt_tr, dim=0)
                baselines = tensor_inp_tr[:num_samples]
                if is_full_baselines:
                    pass
                else:
                    baselines = baselines.mean(dim=0, keepdim=True)
            else:
                raise ValueError(
                    "Invalid value for `baselines`. Please refer to the documentation" " for more details."
                )
        return baselines

    def _handle_categorical_embeddings_attributions(
        self,
        attributions: torch.tensor,
        is_embedding1d: bool,
        is_embedding2d: bool,
        is_embbeding_dims: bool,
    ):
        # post processing to get attributions for categorical features
        if is_embedding1d and is_embbeding_dims:
            if self.model.hparams.categorical_dim > 0:
                cat_attributions = []
                index_counter = self.model.hparams.continuous_dim
                for _, embed_dim in self.model.hparams.embedding_dims:
                    cat_attributions.append(attributions[:, index_counter : index_counter + embed_dim].sum(dim=1))
                    index_counter += embed_dim
                cat_attributions = torch.stack(cat_attributions, dim=1)
                attributions = torch.cat(
                    [
                        attributions[:, : self.model.hparams.continuous_dim],
                        cat_attributions,
                    ],
                    dim=1,
                )
        elif is_embedding2d:
            attributions = attributions.mean(dim=-1)
        return attributions

    def explain(
        self,
        data: DataFrame,
        method: str = "GradientShap",
        method_args: Optional[Dict] = {},
        baselines: Union[float, torch.tensor, str] = None,
        **kwargs,
    ) -> DataFrame:
        """Returns the feature attributions/explanations of the model as a pandas DataFrame. The shape of the returned
        dataframe is (num_samples, num_features)

        Args:
            data (DataFrame): The dataframe to be explained
            method (str): The method to be used for explaining the model.
                It should be one of the supported Captum attribution methods. Defaults to "GradientShap".
                For more details, refer to https://captum.ai/api/attribution.html
            method_args (Optional[Dict], optional): The arguments to be passed to the initialization
                of the Captum method.
            baselines (Union[float, torch.tensor, str]): The baselines to be used for the explanation.
                If a scalar is provided, will use that value as the baseline for all the features.
                If a tensor is provided, will use that tensor as the baseline for all the features.
                If a string like `b|<num_samples>` is provided, will use that many samples from the train
                data as the baseline. Using the whole train data as the baseline is not recommended as it
                can be computationally expensive. By default, PyTorch Tabular uses 10000 samples from the
                train data as the baseline. You can configure this by passing a special string
                "b|<num_samples>" where <num_samples> is the number of samples to be used as the
                baseline. For example, "b|1000" will use 1000 samples from the train data.
                If None, will use the Captum default for the chosen method (e.g. zero; this is method dependent).
                For `GradientShap`, it is the train data.
                Defaults to None.

            **kwargs: Additional keyword arguments to be passed to the Captum method `attribute` function.

        Returns:
            DataFrame: The dataframe with the feature importance

        """
        assert CAPTUM_INSTALLED, (
            "Captum not installed. Please install using `pip install captum` or"
            " install PyTorch Tabular using `pip install pytorch-tabular[extra]`"
        )
        ALLOWED_METHODS = [
            "GradientShap",
            "IntegratedGradients",
            "DeepLift",
            "DeepLiftShap",
            "InputXGradient",
            "FeaturePermutation",
            "FeatureAblation",
            "KernelShap",
        ]
        assert method in ALLOWED_METHODS, f"method should be one of {ALLOWED_METHODS}"
        if isinstance(data, pd.Series):
            data = data.to_frame().T
        if method in ["DeepLiftShap", "KernelShap"]:
            warnings.warn(
                f"{method} is computationally expensive and will take some time. For"
                " faster results, try usingsome other methods like GradientShap,"
                " IntegratedGradients etc."
            )
        if method in ["FeaturePermutation", "FeatureAblation"]:
            assert data.shape[0] > 1, f"{method} only works when the number of samples is greater than 1"
            if len(data) <= 100:
                warnings.warn(
                    f"{method} gives better results when the number of samples is"
                    " large. For better results, try using more samples or some other"
                    " methods like GradientShap which works well on single examples."
                )
        is_full_baselines = method in ["GradientShap", "DeepLiftShap"]
        is_not_supported = self.model._get_name() in [
            "TabNetModel",
            "MDNModel",
            "TabTransformerModel",
        ]
        do_baselines = method not in [
            "Saliency",
            "InputXGradient",
            "FeaturePermutation",
            "LRP",
        ]
        if is_full_baselines and (baselines is None or isinstance(baselines, (float, int))):
            raise ValueError(
                f"baselines cannot be a scalar or None for {method}. Please "
                "provide a tensor or a string like `b|<num_samples>`"
            )
        if is_not_supported:
            raise NotImplementedError(f"Attributions are not implemented for {self.model._get_name()}")

        is_embedding1d = isinstance(self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer))
        is_embedding2d = isinstance(self.model.embedding_layer, Embedding2dLayer)
        # Models like NODE may have no embedding dims (doing leaveOneOut encoding) even if categorical_dim > 0
        is_embbeding_dims = (
            hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None
        )
        if (not is_embedding1d) and (not is_embedding2d):
            raise NotImplementedError(
                "Attributions are not implemented for models with this type of" " embedding layer"
            )
        test_dl = self.datamodule.prepare_inference_dataloader(data)
        self.model.eval()
        # prepare inputs for Captum
        tensor_inp, tensor_tgt = self._prepare_input_for_captum(test_dl)
        baselines = self._prepare_baselines_captum(baselines, test_dl, do_baselines, is_full_baselines)
        # prepare model for Captum
        try:
            interp_model = _CaptumModel(self.model)
            captum_interp_cls = getattr(captum.attr, method)(interp_model, **method_args)
            if do_baselines:
                attributions = captum_interp_cls.attribute(
                    tensor_inp,
                    baselines=baselines,
                    target=(tensor_tgt if self.config.task == "classification" else None),
                    **kwargs,
                )
            else:
                attributions = captum_interp_cls.attribute(
                    tensor_inp,
                    target=(tensor_tgt if self.config.task == "classification" else None),
                    **kwargs,
                )
            attributions = self._handle_categorical_embeddings_attributions(
                attributions, is_embedding1d, is_embedding2d, is_embbeding_dims
            )
        finally:
            self.model.train()
        assert attributions.shape[1] == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim, (
            "Something went wrong. The number of features in the attributions"
            f" ({attributions.shape[1]}) does not match the number of features in"
            " the model"
            f" ({self.model.hparams.continuous_dim+self.model.hparams.categorical_dim})"
        )
        return pd.DataFrame(
            attributions.detach().cpu().numpy(),
            columns=self.config.continuous_cols + self.config.categorical_cols,
        )
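
    # A minimal usage sketch (not part of the library source): feature attributions for a few
    # rows of an assumed dataframe `sample_df`, using 100 training samples as the baseline.
    # >>> attribution_df = tabular_model.explain(
    # ...     sample_df, method="GradientShap", baselines="b|100"
    # ... )
    # attribution_df has one attribution column per continuous and categorical feature.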

    def _check_cv(self, cv):
        cv = 5 if cv is None else cv
        if isinstance(cv, int):
            if self.config.task == "classification":
                return StratifiedKFold(cv)
            else:
                return KFold(cv)
        elif isinstance(cv, Iterable) and not isinstance(cv, str):
            # An iterable yielding (train, test) splits as arrays of indices.
            return cv
        elif isinstance(cv, BaseCrossValidator):
            return cv
        else:
            raise ValueError("cv must be int, iterable or scikit-learn splitter")

    def _split_kwargs(self, kwargs):
        prep_dl_kwargs = {}
        prep_model_kwargs = {}
        train_kwargs = {}
        # using the defined args in self.prepare_dataloader, self.prepare_model, and self.train
        # to split the kwargs
        for k, v in kwargs.items():
            if k in self.prepare_dataloader.__code__.co_varnames:
                prep_dl_kwargs[k] = v
            elif k in self.prepare_model.__code__.co_varnames:
                prep_model_kwargs[k] = v
            elif k in self.train.__code__.co_varnames:
                train_kwargs[k] = v
            else:
                raise ValueError(f"Invalid keyword argument: {k}")
        return prep_dl_kwargs, prep_model_kwargs, train_kwargs

    def cross_validate(
        self,
        cv: Optional[Union[int, Iterable, BaseCrossValidator]],
        train: DataFrame,
        metric: Optional[Union[str, Callable]] = None,
        return_oof: bool = False,
        groups: Optional[Union[str, np.ndarray]] = None,
        verbose: bool = True,
        reset_datamodule: bool = True,
        handle_oom: bool = True,
        **kwargs,
    ):
        """Cross validate the model.

        Args:
            cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
                Possible inputs for cv are:

                - None, to use the default 5-fold cross validation (KFold for
                Regression and StratifiedKFold for Classification),
                - integer, to specify the number of folds in a (Stratified)KFold,
                - An iterable yielding (train, test) splits as arrays of indices.
                - A scikit-learn CV splitter.

            train (DataFrame): The training data with labels

            metric (Optional[Union[str, Callable]], optional): The metrics to be used for evaluation.
                If None, will use the first metric in the config. If str is provided, will use that
                metric from the defined ones. If callable is provided, will use that function as the
                metric. We expect callable to be of the form `metric(y_true, y_pred)`. For classification
                problems, the `y_pred` is a dataframe with the probabilities for each class
                (`<class>_probability`) and a final prediction (`prediction`). For regression, it is a
                dataframe with a final prediction (`<target>_prediction`).
                Defaults to None.

            return_oof (bool, optional): If True, will return the out-of-fold predictions
                along with the cross validation results. Defaults to False.

            groups (Optional[Union[str, np.ndarray]], optional): Group labels for
                the samples used while splitting. If provided, will be used as the
                `groups` argument for the `split` method of the cross validator.
                If input is str, will use the column in the input dataframe with that
                name as the group labels. If input is array-like, will use that as the
                group. The only constraint is that the group labels should have the
                same size as the number of rows in the input dataframe. Defaults to None.

            verbose (bool, optional): If True, will log the results. Defaults to True.

            reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
                It will be slower because we will be fitting the transformations for each fold.
                If False, we take an approximation that once the transformations are fit on the first
                fold, they will be valid for all the other folds. Defaults to True.

            handle_oom (bool, optional): If True, will handle out of memory errors gracefully. Defaults to True.
            **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

        Returns:
            DataFrame: The dataframe with the cross validation results

        """
        cv = self._check_cv(cv)
        prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
        is_callable_metric = False
        if metric is None:
            metric = "test_" + self.config.metrics[0]
        elif isinstance(metric, str):
            metric = metric if metric.startswith("test_") else "test_" + metric
        elif callable(metric):
            is_callable_metric = True

        if isinstance(cv, BaseCrossValidator):
            it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
        else:
            # when iterable is directly passed
            it = enumerate(cv)
        cv_metrics = []
        datamodule = None
        model = None
        oof_preds = []
        for fold, (train_idx, val_idx) in it:
            if verbose:
                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
            # train_fold = train.iloc[train_idx]
            # val_fold = train.iloc[val_idx]
            if reset_datamodule:
                datamodule = None
            if datamodule is None:
                # Initialize datamodule and model in the first fold
                # uses train data from this fold to fit all transformers
                datamodule = self.prepare_dataloader(
                    train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
                )
                model = self.prepare_model(datamodule, **prep_model_kwargs)
            else:
                # Preprocess the current fold data using the fitted transformers and save in datamodule
                datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference")
                datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference")

            # Train the model
            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
            if return_oof or is_callable_metric:
                preds = self.predict(train.iloc[val_idx], include_input_features=False)
                oof_preds.append(preds)
            if is_callable_metric:
                cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds))
            else:
                result = self.evaluate(train.iloc[val_idx], verbose=False)
                cv_metrics.append(result[0][metric])
            if verbose:
                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
            self.model.reset_weights()
        return cv_metrics, oof_preds
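
    # A minimal usage sketch (not part of the library source): 5-fold cross validation on an
    # assumed labelled dataframe `train_df`, keeping out-of-fold predictions. Assumes "accuracy"
    # is one of the metrics defined in the config.
    # >>> cv_scores, oof_preds = tabular_model.cross_validate(
    # ...     cv=5, train=train_df, metric="accuracy", return_oof=True
    # ... )
    # >>> np.mean(cv_scores)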

    def _combine_predictions(
        self,
        pred_prob_l: List[DataFrame],
        pred_idx: Union[pd.Index, List],
        aggregate: Union[str, Callable],
        weights: Optional[List[float]] = None,
    ):
        if aggregate == "mean":
            bagged_pred = np.average(pred_prob_l, axis=0, weights=weights)
        elif aggregate == "median":
            bagged_pred = np.median(pred_prob_l, axis=0)
        elif aggregate == "min":
            bagged_pred = np.min(pred_prob_l, axis=0)
        elif aggregate == "max":
            bagged_pred = np.max(pred_prob_l, axis=0)
        elif aggregate == "hard_voting" and self.config.task == "classification":
            pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]
            final_pred = np.apply_along_axis(
                lambda x: np.argmax(np.bincount(x)),
                axis=0,
                arr=pred_l,
            )
        elif callable(aggregate):
            bagged_pred = aggregate(pred_prob_l)
        if self.config.task == "classification":
            classes = self.datamodule.label_encoder.classes_
            if aggregate == "hard_voting":
                pred_df = pd.DataFrame(
                    np.concatenate(pred_prob_l, axis=1),
                    columns=[
                        f"{c}_probability_fold_{i}"
                        for i in range(len(pred_prob_l))
                        for c in self.datamodule.label_encoder.classes_
                    ],
                    index=pred_idx,
                )
                pred_df["prediction"] = classes[final_pred]
            else:
                final_pred = classes[np.argmax(bagged_pred, axis=1)]
                pred_df = pd.DataFrame(
                    bagged_pred,
                    columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_],
                    index=pred_idx,
                )
                pred_df["prediction"] = final_pred
        elif self.config.task == "regression":
            pred_df = pd.DataFrame(bagged_pred, columns=self.config.target, index=pred_idx)
        else:
            raise NotImplementedError(f"Task {self.config.task} not supported for bagging")
        return pred_df

    def bagging_predict(
        self,
        cv: Optional[Union[int, Iterable, BaseCrossValidator]],
        train: DataFrame,
        test: DataFrame,
        groups: Optional[Union[str, np.ndarray]] = None,
        verbose: bool = True,
        reset_datamodule: bool = True,
        return_raw_predictions: bool = False,
        aggregate: Union[str, Callable] = "mean",
        weights: Optional[List[float]] = None,
        handle_oom: bool = True,
        **kwargs,
    ):
        """Bagging predict on the test data.

        Args:
            cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
                Possible inputs for cv are:

                - None, to use the default 5-fold cross validation (KFold for
                Regression and StratifiedKFold for Classification),
                - integer, to specify the number of folds in a (Stratified)KFold,
                - An iterable yielding (train, test) splits as arrays of indices.
                - A scikit-learn CV splitter.

            train (DataFrame): The training data with labels

            test (DataFrame): The test data to be predicted

            groups (Optional[Union[str, np.ndarray]], optional): Group labels for
                the samples used while splitting. If provided, will be used as the
                `groups` argument for the `split` method of the cross validator.
                If input is str, will use the column in the input dataframe with that
                name as the group labels. If input is array-like, will use that as the
                group. The only constraint is that the group labels should have the
                same size as the number of rows in the input dataframe. Defaults to None.

            verbose (bool, optional): If True, will log the results. Defaults to True.

            reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
                It will be slower because we will be fitting the transformations for each fold.
                If False, we take an approximation that once the transformations are fit on the first
                fold, they will be valid for all the other folds. Defaults to True.

            return_raw_predictions (bool, optional): If True, will return the raw predictions
                from each fold. Defaults to False.

            aggregate (Union[str, Callable], optional): The function to be used to aggregate the
                predictions from each fold. If str, should be one of "mean", "median", "min", or "max"
                for regression. For classification, the previous options are applied to the confidence
                scores (soft voting) and then converted to final prediction. An additional option
                "hard_voting" is available for classification.
                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

            weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
                from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
                Defaults to None.

            handle_oom (bool, optional): If True, will handle out of memory errors gracefully. Defaults to True.

            **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

        Returns:
            DataFrame: The dataframe with the bagged predictions.

        """
        if weights is not None:
            assert len(weights) == cv.n_splits, "Number of weights should be equal to the number of folds"
        assert self.config.task in [
            "classification",
            "regression",
        ], "Bagging is only available for classification and regression"
        if not callable(aggregate):
            assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
                "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
            )
        if self.config.task == "regression":
            assert aggregate != "hard_voting", "hard_voting is only available for classification"
        cv = self._check_cv(cv)
        prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
        pred_prob_l = []
        datamodule = None
        model = None
        for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
            if verbose:
                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
            train_fold = train.iloc[train_idx]
            val_fold = train.iloc[val_idx]
            if reset_datamodule:
                datamodule = None
            if datamodule is None:
                # Initialize datamodule and model in the first fold
                # uses train data from this fold to fit all transformers
                datamodule = self.prepare_dataloader(train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs)
                model = self.prepare_model(datamodule, **prep_model_kwargs)
            else:
                # Preprocess the current fold data using the fitted transformers and save in datamodule
                datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference")
                datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")

            # Train the model
            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
            fold_preds = self.predict(test, include_input_features=False)
            pred_idx = fold_preds.index
            if self.config.task == "classification":
                pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
            elif self.config.task == "regression":
                pred_prob_l.append(fold_preds.values)
            if verbose:
                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
            self.model.reset_weights()
        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
        if return_raw_predictions:
            return pred_df, pred_prob_l
        else:
            return pred_df
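
    # A minimal usage sketch (not part of the library source): bagged predictions over cross
    # validation folds. `train_df` and `test_df` are placeholder dataframes.
    # >>> bagged_df = tabular_model.bagging_predict(
    # ...     cv=5, train=train_df, test=test_df, aggregate="mean"
    # ... )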

__init__(config=None, data_config=None, model_config=None, optimizer_config=None, trainer_config=None, experiment_config=None, model_callable=None, model_state_dict_path=None, verbose=True, suppress_lightning_logger=False)

The core model which orchestrates everything from initializing the datamodule, the model, trainer, etc.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | Optional[Union[DictConfig, str]] | Single OmegaConf DictConfig object or the path to the yaml file holding all the config parameters. Defaults to None. | None |
| data_config | Optional[Union[DataConfig, str]] | DataConfig object or path to the yaml file. Defaults to None. | None |
| model_config | Optional[Union[ModelConfig, str]] | A subclass of ModelConfig or path to the yaml file. Determines which model to run from the type of config. Defaults to None. | None |
| optimizer_config | Optional[Union[OptimizerConfig, str]] | OptimizerConfig object or path to the yaml file. Defaults to None. | None |
| trainer_config | Optional[Union[TrainerConfig, str]] | TrainerConfig object or path to the yaml file. Defaults to None. | None |
| experiment_config | Optional[Union[ExperimentConfig, str]] | ExperimentConfig object or path to the yaml file. If provided, configures the experiment tracking. Defaults to None. | None |
| model_callable | Optional[Callable] | If provided, will override the model callable that will be loaded from the config. Typically used when providing Custom Models. | None |
| model_state_dict_path | Optional[Union[str, Path]] | If provided, will load the state dict after initializing the model from config. | None |
| verbose | bool | Turns the logging on and off. Defaults to True. | True |
| suppress_lightning_logger | bool | If True, will suppress the default logging from PyTorch Lightning. Defaults to False. | False |
Source code in src/pytorch_tabular/tabular_model.py
def __init__(
    self,
    config: Optional[DictConfig] = None,
    data_config: Optional[Union[DataConfig, str]] = None,
    model_config: Optional[Union[ModelConfig, str]] = None,
    optimizer_config: Optional[Union[OptimizerConfig, str]] = None,
    trainer_config: Optional[Union[TrainerConfig, str]] = None,
    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
    model_callable: Optional[Callable] = None,
    model_state_dict_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
    suppress_lightning_logger: bool = False,
) -> None:
    """The core model which orchestrates everything from initializing the datamodule, the model, trainer, etc.

    Args:
        config (Optional[Union[DictConfig, str]], optional): Single OmegaConf DictConfig object or
            the path to the yaml file holding all the config parameters. Defaults to None.

        data_config (Optional[Union[DataConfig, str]], optional):
            DataConfig object or path to the yaml file. Defaults to None.

        model_config (Optional[Union[ModelConfig, str]], optional):
            A subclass of ModelConfig or path to the yaml file.
            Determines which model to run from the type of config. Defaults to None.

        optimizer_config (Optional[Union[OptimizerConfig, str]], optional):
            OptimizerConfig object or path to the yaml file. Defaults to None.

        trainer_config (Optional[Union[TrainerConfig, str]], optional):
            TrainerConfig object or path to the yaml file. Defaults to None.

        experiment_config (Optional[Union[ExperimentConfig, str]], optional):
            ExperimentConfig object or path to the yaml file.
            If provided, configures the experiment tracking. Defaults to None.

        model_callable (Optional[Callable], optional):
            If provided, will override the model callable that will be loaded from the config.
            Typically used when providing Custom Models

        model_state_dict_path (Optional[Union[str, Path]], optional):
            If provided, will load the state dict after initializing the model from config.

        verbose (bool): Turns logging on or off. Defaults to True.

        suppress_lightning_logger (bool): If True, will suppress the default logging from PyTorch Lightning.
            Defaults to False.

    """
    super().__init__()
    if suppress_lightning_logger:
        suppress_lightning_logs()
    self.verbose = verbose
    self.exp_manager = ExperimentRunManager()
    if config is None:
        assert any(c is not None for c in (data_config, model_config, optimizer_config, trainer_config)), (
            "If `config` is None, `data_config`, `model_config`,"
            " `trainer_config`, and `optimizer_config` cannot be None"
        )
        data_config = self._read_parse_config(data_config, DataConfig)
        model_config = self._read_parse_config(model_config, ModelConfig)
        trainer_config = self._read_parse_config(trainer_config, TrainerConfig)
        optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig)
        if model_config.task != "ssl":
            assert data_config.target is not None, (
                "`target` in data_config should not be None for" f" {model_config.task} task"
            )
        if experiment_config is None:
            if self.verbose:
                logger.info("Experiment Tracking is turned off")
            self.track_experiment = False
            self.config = OmegaConf.merge(
                OmegaConf.to_container(data_config),
                OmegaConf.to_container(model_config),
                OmegaConf.to_container(trainer_config),
                OmegaConf.to_container(optimizer_config),
            )
        else:
            experiment_config = self._read_parse_config(experiment_config, ExperimentConfig)
            self.track_experiment = True
            self.config = OmegaConf.merge(
                OmegaConf.to_container(data_config),
                OmegaConf.to_container(model_config),
                OmegaConf.to_container(trainer_config),
                OmegaConf.to_container(experiment_config),
                OmegaConf.to_container(optimizer_config),
            )
    else:
        self.config = config
        if hasattr(config, "log_target") and (config.log_target is not None):
            # experiment_config = OmegaConf.structured(experiment_config)
            self.track_experiment = True
        else:
            if self.verbose:
                logger.info("Experiment Tracking is turned off")
            self.track_experiment = False

    self.run_name, self.uid = self._get_run_name_uid()
    if self.track_experiment:
        self._setup_experiment_tracking()
    else:
        self.logger = None

    self.exp_manager = ExperimentRunManager()
    if model_callable is None:
        self.model_callable = getattr_nested(self.config._module_src, self.config._model_name)
        self.custom_model = False
    else:
        self.model_callable = model_callable
        self.custom_model = True
    self.model_state_dict_path = model_state_dict_path
    self._is_config_updated_with_data = False
    self._run_validation()
    self._is_fitted = False
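
Example usage (a minimal sketch; the column names and config values below are illustrative and not taken from this page — each config can also be passed as a path to an equivalent YAML file):

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

# hypothetical column names; replace with those of your dataframe
data_config = DataConfig(
    target=["target"],
    continuous_cols=["num_a", "num_b"],
    categorical_cols=["cat_a"],
)
trainer_config = TrainerConfig(max_epochs=20, batch_size=1024)
model_config = CategoryEmbeddingModelConfig(task="classification")
optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=True,
)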

bagging_predict(cv, train, test, groups=None, verbose=True, reset_datamodule=True, return_raw_predictions=False, aggregate='mean', weights=None, handle_oom=True, **kwargs)

Bagging predict on the test data.

Parameters:

Name Type Description Default
cv Optional[Union[int, Iterable, BaseCrossValidator]]

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 5-fold cross validation (KFold for Regression and StratifiedKFold for Classification),
  • integer, to specify the number of folds in a (Stratified)KFold,
  • An iterable yielding (train, test) splits as arrays of indices.
  • A scikit-learn CV splitter.
required
train DataFrame

The training data with labels

required
test DataFrame

The test data to be predicted

required
groups Optional[Union[str, ndarray]]

Group labels for the samples used while splitting. If provided, will be used as the groups argument for the split method of the cross validator. If input is str, will use the column in the input dataframe with that name as the group labels. If input is array-like, will use that as the group. The only constraint is that the group labels should have the same size as the number of rows in the input dataframe. Defaults to None.

None
verbose bool

If True, will log the results. Defaults to True.

True
reset_datamodule bool

If True, will reset the datamodule for each iteration. It will be slower because we will be fitting the transformations for each fold. If False, we take an approximation that once the transformations are fit on the first fold, they will be valid for all the other folds. Defaults to True.

True
return_raw_predictions bool

If True, will return the raw predictions from each fold. Defaults to False.

False
aggregate Union[str, Callable]

The function to be used to aggregate the predictions from each fold. If str, should be one of "mean", "median", "min", or "max" for regression. For classification, the previous options are applied to the confidence scores (soft voting) and then converted to final prediction. An additional option "hard_voting" is available for classification. If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets) and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

'mean'
weights Optional[List[float]]

The weights to be used for aggregating the predictions from each fold. If None, will use equal weights. This is only used when aggregate is "mean". Defaults to None.

None
handle_oom bool

If True, will handle out of memory errors elegantly

True
**kwargs

Additional keyword arguments to be passed to the fit method of the model.

{}

Returns:

Name Type Description
DataFrame

The dataframe with the bagged predictions.

Source code in src/pytorch_tabular/tabular_model.py
def bagging_predict(
    self,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]],
    train: DataFrame,
    test: DataFrame,
    groups: Optional[Union[str, np.ndarray]] = None,
    verbose: bool = True,
    reset_datamodule: bool = True,
    return_raw_predictions: bool = False,
    aggregate: Union[str, Callable] = "mean",
    weights: Optional[List[float]] = None,
    handle_oom: bool = True,
    **kwargs,
):
    """Bagging predict on the test data.

    Args:
        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation (KFold for
            Regression and StratifiedKFold for Classification),
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.

        train (DataFrame): The training data with labels

        test (DataFrame): The test data to be predicted

        groups (Optional[Union[str, np.ndarray]], optional): Group labels for
            the samples used while splitting. If provided, will be used as the
            `groups` argument for the `split` method of the cross validator.
            If input is str, will use the column in the input dataframe with that
            name as the group labels. If input is array-like, will use that as the
            group. The only constraint is that the group labels should have the
            same size as the number of rows in the input dataframe. Defaults to None.

        verbose (bool, optional): If True, will log the results. Defaults to True.

        reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
            It will be slower because we will be fitting the transformations for each fold.
            If False, we take an approximation that once the transformations are fit on the first
            fold, they will be valid for all the other folds. Defaults to True.

        return_raw_predictions (bool, optional): If True, will return the raw predictions
            from each fold. Defaults to False.

        aggregate (Union[str, Callable], optional): The function to be used to aggregate the
            predictions from each fold. If str, should be one of "mean", "median", "min", or "max"
            for regression. For classification, the previous options are applied to the confidence
            scores (soft voting) and then converted to final prediction. An additional option
            "hard_voting" is available for classification.
            If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
            and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

        weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
            from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
            Defaults to None.

        handle_oom (bool, optional): If True, will handle out of memory errors elegantly

        **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

    Returns:
        DataFrame: The dataframe with the bagged predictions.

    """
    if weights is not None:
        assert len(weights) == cv.n_splits, "Number of weights should be equal to the number of folds"
    assert self.config.task in [
        "classification",
        "regression",
    ], "Bagging is only available for classification and regression"
    if not callable(aggregate):
        assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
            "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
        )
    if self.config.task == "regression":
        assert aggregate != "hard_voting", "hard_voting is only available for classification"
    cv = self._check_cv(cv)
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
    pred_prob_l = []
    datamodule = None
    model = None
    for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
        if verbose:
            logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
        train_fold = train.iloc[train_idx]
        val_fold = train.iloc[val_idx]
        if reset_datamodule:
            datamodule = None
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = self.prepare_dataloader(train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs)
            model = self.prepare_model(datamodule, **prep_model_kwargs)
        else:
            # Preprocess the current fold data using the fitted transformers and save in datamodule
            datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference")
            datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")

        # Train the model
        handle_oom = train_kwargs.pop("handle_oom", handle_oom)
        self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
        fold_preds = self.predict(test, include_input_features=False)
        pred_idx = fold_preds.index
        if self.config.task == "classification":
            pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
        elif self.config.task == "regression":
            pred_prob_l.append(fold_preds.values)
        if verbose:
            logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
        self.model.reset_weights()
    pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
    if return_raw_predictions:
        return pred_df, pred_prob_l
    else:
        return pred_df
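
Example usage (a hedged sketch; train_df and test_df are hypothetical pandas DataFrames, and tabular_model is assumed to be configured for classification):

# 5-fold bagging: train on each fold, predict on `test`, aggregate the fold predictions
bagged_preds = tabular_model.bagging_predict(
    cv=5,                     # StratifiedKFold for classification, KFold for regression
    train=train_df,
    test=test_df,
    aggregate="mean",         # soft voting over the per-fold probabilities
    reset_datamodule=False,   # reuse transformations fitted on the first fold (faster, approximate)
)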

create_finetune_model(task, head, head_config, train, validation=None, train_sampler=None, target_transform=None, target=None, optimizer_config=None, trainer_config=None, experiment_config=None, loss=None, metrics=None, metrics_prob_input=None, metrics_params=None, optimizer=None, optimizer_params=None, learning_rate=None, target_range=None, seed=42)

Creates a new TabularModel model using the pretrained weights and the new task and head.

Parameters:

Name Type Description Default
task str

The task to be performed. One of "regression", "classification"

required
head str

The head to be used for the model. Should be one of the heads defined in pytorch_tabular.models.common.heads. Defaults to LinearHead. Choices are: [None,LinearHead,MixtureDensityHead].

required
head_config Dict

The config as a dict which defines the head. If left empty, will be initialized as default linear head.

required
train DataFrame

The training data with labels

required
validation Optional[DataFrame]

The validation data with labels. Defaults to None.

None
train_sampler Optional[Sampler]

If provided, will be used as a batch sampler for training. Defaults to None.

None
target_transform Optional[Union[TransformerMixin, Tuple]]

If provided, will be used to transform the target before training and inverse transform the predictions.

None
target Optional[str]

The target column name if not provided in the initial pretraining stage. Defaults to None.

None
optimizer_config Optional[OptimizerConfig]

If provided, will redefine the optimizer for fine-tuning stage. Defaults to None.

None
trainer_config Optional[TrainerConfig]

If provided, will redefine the trainer for fine-tuning stage. Defaults to None.

None
experiment_config Optional[ExperimentConfig]

If provided, will redefine the experiment for fine-tuning stage. Defaults to None.

None
loss Optional[Module]

If provided, will be used as the loss function for the fine-tuning. By default, it is MSELoss for regression and CrossEntropyLoss for classification.

None
metrics Optional[List[Callable]]

List of metrics (either callables or str) to be used for the fine-tuning stage. If str, it should be one of the functional metrics implemented in torchmetrics.functional. Defaults to None.

None
metrics_prob_input Optional[List[bool]]

A mandatory parameter for classification metrics. This defines whether the input to the metric function is the probability or the class. Length should be the same as the number of metrics. Defaults to None.

None
metrics_params Optional[Dict]

The parameters for the metrics in the same order as metrics. For example, f1_score for multi-class needs the parameter average to fully define the metric. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are drop-in replacements for standard PyTorch optimizers. If provided, the OptimizerConfig is ignored in favor of this. Defaults to None.

None
optimizer_params Dict

The parameters for the optimizer. Defaults to {}.

None
learning_rate Optional[float]

The learning rate to be used. Defaults to 1e-3.

None
target_range Optional[Tuple[float, float]]

The target range for the regression task. Is ignored for classification. Defaults to None.

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42

Returns: TabularModel: The new TabularModel for fine-tuning.

Source code in src/pytorch_tabular/tabular_model.py
def create_finetune_model(
    self,
    task: str,
    head: str,
    head_config: Dict,
    train: DataFrame,
    validation: Optional[DataFrame] = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    target: Optional[str] = None,
    optimizer_config: Optional[OptimizerConfig] = None,
    trainer_config: Optional[TrainerConfig] = None,
    experiment_config: Optional[ExperimentConfig] = None,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Union[Callable, str]]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    metrics_params: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    learning_rate: Optional[float] = None,
    target_range: Optional[Tuple[float, float]] = None,
    seed: Optional[int] = 42,
):
    """Creates a new TabularModel model using the pretrained weights and the new task and head.

    Args:
        task (str): The task to be performed. One of "regression", "classification"

        head (str): The head to be used for the model. Should be one of the heads defined
            in `pytorch_tabular.models.common.heads`. Defaults to  LinearHead. Choices are:
            [`None`,`LinearHead`,`MixtureDensityHead`].

        head_config (Dict): The config as a dict which defines the head. If left empty,
            will be initialized as default linear head.

        train (DataFrame): The training data with labels

        validation (Optional[DataFrame], optional): The validation data with labels. Defaults to None.

        train_sampler (Optional[torch.utils.data.Sampler], optional): If provided, will be used as a batch sampler
            for training. Defaults to None.

        target_transform (Optional[Union[TransformerMixin, Tuple]], optional): If provided, will be used
            to transform the target before training and inverse transform the predictions.

        target (Optional[str], optional): The target column name if not provided in the initial pretraining stage.
            Defaults to None.

        optimizer_config (Optional[OptimizerConfig], optional):
            If provided, will redefine the optimizer for fine-tuning stage. Defaults to None.

        trainer_config (Optional[TrainerConfig], optional):
            If provided, will redefine the trainer for fine-tuning stage. Defaults to None.

        experiment_config (Optional[ExperimentConfig], optional):
            If provided, will redefine the experiment for fine-tuning stage. Defaults to None.

        loss (Optional[torch.nn.Module], optional):
            If provided, will be used as the loss function for the fine-tuning.
            By default, it is MSELoss for regression and CrossEntropyLoss for classification.

        metrics (Optional[List[Callable]], optional): List of metrics (either callables or str) to be used for the
            fine-tuning stage. If str, it should be one of the functional metrics implemented in
            ``torchmetrics.functional``. Defaults to None.

        metrics_prob_input (Optional[List[bool]], optional): Is a mandatory parameter for classification metrics.
            This defines whether the input to the metric function is the probability or the class.
            Length should be the same as the number of metrics. Defaults to None.

        metrics_params (Optional[Dict], optional): The parameters for the metrics in the same order as metrics.
            For eg. f1_score for multi-class needs a parameter `average` to fully define the metric.
            Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for standard PyTorch optimizers. If provided,
            the OptimizerConfig is ignored in favor of this. Defaults to None.

        optimizer_params (Dict, optional): The parameters for the optimizer. Defaults to {}.

        learning_rate (Optional[float], optional): The learning rate to be used. Defaults to 1e-3.

        target_range (Optional[Tuple[float, float]], optional): The target range for the regression task.
            Is ignored for classification. Defaults to None.

        seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.
    Returns:
        TabularModel (TabularModel): The new TabularModel model for fine-tuning

    """
    config = self.config
    optimizer_params = optimizer_params or {}
    if target is None:
        assert (
            hasattr(config, "target") and config.target is not None
        ), "`target` cannot be None if it was not set in the initial `DataConfig`"
    else:
        assert isinstance(target, list), "`target` should be a list of strings"
        config.target = target
    config.task = task
    # Add code to update configs with newly provided ones
    if optimizer_config is not None:
        for key, value in optimizer_config.__dict__.items():
            config[key] = value
        if len(optimizer_params) > 0:
            config.optimizer_params = optimizer_params
        else:
            config.optimizer_params = {}
    if trainer_config is not None:
        for key, value in trainer_config.__dict__.items():
            config[key] = value
    if experiment_config is not None:
        for key, value in experiment_config.__dict__.items():
            config[key] = value
    else:
        if self.track_experiment:
            # Renaming the experiment run so that a different log is created for finetuning
            if self.verbose:
                logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
            config["run_name"] = config["run_name"] + "_finetuned"

    datamodule = self.datamodule.copy(
        train=train,
        validation=validation,
        target_transform=target_transform,
        train_sampler=train_sampler,
        seed=seed,
        config_override={"target": target} if target is not None else {},
    )
    model_callable = _GenericModel
    inferred_config = OmegaConf.structured(datamodule._inferred_config)
    # Adding dummy attributes for compatibility. Not used because custom metrics are provided
    if not hasattr(config, "metrics"):
        config.metrics = "dummy"
    if not hasattr(config, "metrics_params"):
        config.metrics_params = {}
    if not hasattr(config, "metrics_prob_input"):
        config.metrics_prob_input = metrics_prob_input or [False]
    if metrics is not None:
        assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same"
        assert len(metrics) == len(metrics_prob_input), "Number of metrics and metrics_prob_input should be same"
        metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics]
    if task == "regression":
        loss = loss or torch.nn.MSELoss()
        if metrics is None:
            metrics = [torchmetrics.functional.mean_squared_error]
            metrics_params = [{}]
    elif task == "classification":
        loss = loss or torch.nn.CrossEntropyLoss()
        if metrics is None:
            metrics = [torchmetrics.functional.accuracy]
            metrics_params = [
                {
                    "task": "multiclass",
                    "num_classes": inferred_config.output_dim,
                    "top_k": 1,
                }
            ]
            metrics_prob_input = [False]
        else:
            for i, mp in enumerate(metrics_params):
                # For classification task, output_dim == number of classes
                metrics_params[i]["task"] = mp.get("task", "multiclass")
                metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
                metrics_params[i]["top_k"] = mp.get("top_k", 1)
    else:
        raise ValueError(f"Task {task} not supported")
    # Forming partial callables using metrics and metric params
    metrics = [partial(m, **mp) for m, mp in zip(metrics, metrics_params)]
    self.model.mode = "finetune"
    if learning_rate is not None:
        config.learning_rate = learning_rate
    config.target_range = target_range
    model_args = {
        "backbone": self.model,
        "head": head,
        "head_config": head_config,
        "config": config,
        "inferred_config": inferred_config,
        "custom_loss": loss,
        "custom_metrics": metrics,
        "custom_metrics_prob_inputs": metrics_prob_input,
        "custom_optimizer": optimizer,
        "custom_optimizer_params": optimizer_params,
    }
    # Initializing with default metrics, losses, and optimizers. Will revert once initialized
    model = model_callable(
        **model_args,
    )
    tabular_model = TabularModel(config=config, verbose=self.verbose)
    tabular_model.model = model
    tabular_model.datamodule = datamodule
    # Setting a flag to identify this as a fine-tune model
    tabular_model._is_finetune_model = True
    return tabular_model
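
Example usage (a hedged sketch of the pretrain-then-finetune flow; pretrained_model stands for a TabularModel that has already been pretrained, and the dataframes, target name, and head settings are placeholders):

# `pretrained_model` is assumed to be an already pretrained TabularModel (e.g. on an SSL task)
finetune_model = pretrained_model.create_finetune_model(
    task="classification",
    head="LinearHead",
    head_config={"layers": "", "dropout": 0.1},  # illustrative head settings
    train=labeled_train_df,
    validation=labeled_val_df,
    target=["label"],          # required if no target was set during pretraining
)
finetune_model.finetune(freeze_backbone=True, max_epochs=10)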

cross_validate(cv, train, metric=None, return_oof=False, groups=None, verbose=True, reset_datamodule=True, handle_oom=True, **kwargs)

Cross validate the model.

Parameters:

Name Type Description Default
cv Optional[Union[int, Iterable, BaseCrossValidator]]

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 5-fold cross validation (KFold for Regression and StratifiedKFold for Classification),
  • integer, to specify the number of folds in a (Stratified)KFold,
  • An iterable yielding (train, test) splits as arrays of indices.
  • A scikit-learn CV splitter.
required
train DataFrame

The training data with labels

required
metric Optional[Union[str, Callable]]

The metric to be used for evaluation. If None, will use the first metric in the config. If a str is provided, will use that metric from the defined ones. If a callable is provided, will use that function as the metric. The callable is expected to be of the form metric(y_true, y_pred). For classification problems, y_pred is a dataframe with the probabilities for each class (<class>_probability) and a final prediction (prediction). For regression, it is a dataframe with a final prediction (<target>_prediction). Defaults to None.

None
return_oof bool

If True, will return the out-of-fold predictions along with the cross validation results. Defaults to False.

False
groups Optional[Union[str, ndarray]]

Group labels for the samples used while splitting. If provided, will be used as the groups argument for the split method of the cross validator. If input is str, will use the column in the input dataframe with that name as the group labels. If input is array-like, will use that as the group. The only constraint is that the group labels should have the same size as the number of rows in the input dataframe. Defaults to None.

None
verbose bool

If True, will log the results. Defaults to True.

True
reset_datamodule bool

If True, will reset the datamodule for each iteration. It will be slower because we will be fitting the transformations for each fold. If False, we take an approximation that once the transformations are fit on the first fold, they will be valid for all the other folds. Defaults to True.

True
handle_oom bool

If True, will handle out of memory errors elegantly

True
**kwargs

Additional keyword arguments to be passed to the fit method of the model.

{}

Returns:

Name Type Description
DataFrame

The dataframe with the cross validation results

Source code in src/pytorch_tabular/tabular_model.py
def cross_validate(
    self,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]],
    train: DataFrame,
    metric: Optional[Union[str, Callable]] = None,
    return_oof: bool = False,
    groups: Optional[Union[str, np.ndarray]] = None,
    verbose: bool = True,
    reset_datamodule: bool = True,
    handle_oom: bool = True,
    **kwargs,
):
    """Cross validate the model.

    Args:
        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation (KFold for
            Regression and StratifiedKFold for Classification),
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.

        train (DataFrame): The training data with labels

        metric (Optional[Union[str, Callable]], optional): The metrics to be used for evaluation.
            If None, will use the first metric in the config. If str is provided, will use that
            metric from the defined ones. If callable is provided, will use that function as the
            metric. We expect callable to be of the form `metric(y_true, y_pred)`. For classification
            problems, The `y_pred` is a dataframe with the probabilities for each class
            (<class>_probability) and a final prediction(prediction). And for Regression, it is a
            dataframe with a final prediction (<target>_prediction).
            Defaults to None.

        return_oof (bool, optional): If True, will return the out-of-fold predictions
            along with the cross validation results. Defaults to False.

        groups (Optional[Union[str, np.ndarray]], optional): Group labels for
            the samples used while splitting. If provided, will be used as the
            `groups` argument for the `split` method of the cross validator.
            If input is str, will use the column in the input dataframe with that
            name as the group labels. If input is array-like, will use that as the
            group. The only constraint is that the group labels should have the
            same size as the number of rows in the input dataframe. Defaults to None.

        verbose (bool, optional): If True, will log the results. Defaults to True.

        reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
            It will be slower because we will be fitting the transformations for each fold.
            If False, we take an approximation that once the transformations are fit on the first
            fold, they will be valid for all the other folds. Defaults to True.

        handle_oom (bool, optional): If True, will handle out of memory errors elegantly
        **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

    Returns:
        DataFrame: The dataframe with the cross validation results

    """
    cv = self._check_cv(cv)
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
    is_callable_metric = False
    if metric is None:
        metric = "test_" + self.config.metrics[0]
    elif isinstance(metric, str):
        metric = metric if metric.startswith("test_") else "test_" + metric
    elif callable(metric):
        is_callable_metric = True

    if isinstance(cv, BaseCrossValidator):
        it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
    else:
        # when iterable is directly passed
        it = enumerate(cv)
    cv_metrics = []
    datamodule = None
    model = None
    oof_preds = []
    for fold, (train_idx, val_idx) in it:
        if verbose:
            logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
        # train_fold = train.iloc[train_idx]
        # val_fold = train.iloc[val_idx]
        if reset_datamodule:
            datamodule = None
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = self.prepare_dataloader(
                train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
            )
            model = self.prepare_model(datamodule, **prep_model_kwargs)
        else:
            # Preprocess the current fold data using the fitted transformers and save in datamodule
            datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference")
            datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference")

        # Train the model
        handle_oom = train_kwargs.pop("handle_oom", handle_oom)
        self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
        if return_oof or is_callable_metric:
            preds = self.predict(train.iloc[val_idx], include_input_features=False)
            oof_preds.append(preds)
        if is_callable_metric:
            cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds))
        else:
            result = self.evaluate(train.iloc[val_idx], verbose=False)
            cv_metrics.append(result[0][metric])
        if verbose:
            logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
        self.model.reset_weights()
    return cv_metrics, oof_preds
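
Example usage (a hedged sketch; train_df is a hypothetical labeled DataFrame, and the metric name must correspond to one configured for the model since it is internally prefixed with "test_"):

from sklearn.model_selection import KFold

cv_scores, oof_preds = tabular_model.cross_validate(
    cv=KFold(n_splits=5, shuffle=True, random_state=42),
    train=train_df,
    metric="mean_squared_error",  # resolved to "test_mean_squared_error" internally
    return_oof=True,
)
print(f"CV mean score: {sum(cv_scores) / len(cv_scores):.4f}")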

evaluate(test=None, test_loader=None, ckpt_path=None, verbose=True)

Evaluates the dataframe using the loss and metrics already set in config.

Parameters:

Name Type Description Default
test Optional[DataFrame]

The dataframe to be evaluated. If not provided, will try to use the test data provided during fit. If that was also not provided, will return an empty dictionary.

None
test_loader Optional[DataLoader]

The dataloader to be used for evaluation. If provided, will use the dataloader instead of the test dataframe or the test data provided during fit. Defaults to None.

None
ckpt_path Optional[Union[str, Path]]

The path to the checkpoint to be loaded. If not provided, will try to use the best checkpoint during training.

None
verbose bool

If true, will print the results. Defaults to True.

True

Returns: The final test result dictionary.

Source code in src/pytorch_tabular/tabular_model.py
def evaluate(
    self,
    test: Optional[DataFrame] = None,
    test_loader: Optional[torch.utils.data.DataLoader] = None,
    ckpt_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
) -> Union[dict, list]:
    """Evaluates the dataframe using the loss and metrics already set in config.

    Args:
        test (Optional[DataFrame]): The dataframe to be evaluated. If not provided, will try to use the
            test provided during fit. If that was also not provided will return an empty dictionary

        test_loader (Optional[torch.utils.data.DataLoader], optional): The dataloader to be used for evaluation.
            If provided, will use the dataloader instead of the test dataframe or the test data provided during fit.
            Defaults to None.

        ckpt_path (Optional[Union[str, Path]], optional): The path to the checkpoint to be loaded. If not provided,
            will try to use the best checkpoint during training.

        verbose (bool, optional): If true, will print the results. Defaults to True.
    Returns:
        The final test result dictionary.

    """
    assert not (test_loader is None and test is None), (
        "Either `test_loader` or `test` should be provided."
        " If `test_loader` is not provided, `test` should be provided."
    )
    if test_loader is None:
        test_loader = self.datamodule.prepare_inference_dataloader(test)
    result = self.trainer.test(
        model=self.model,
        dataloaders=test_loader,
        ckpt_path=ckpt_path,
        verbose=verbose,
    )
    return result
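
Example usage (a minimal sketch; test_df is a hypothetical held-out DataFrame containing the target column, and the keys shown in the comment depend on the configured metrics):

result = tabular_model.evaluate(test=test_df, verbose=True)
# `result` is the list of dicts returned by the Lightning Trainer,
# e.g. [{"test_loss": ..., "test_accuracy": ...}] depending on the configured metrics
print(result[0])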

explain(data, method='GradientShap', method_args={}, baselines=None, **kwargs)

Returns the feature attributions/explanations of the model as a pandas DataFrame. The shape of the returned dataframe is (num_samples, num_features)

Parameters:

Name Type Description Default
data DataFrame

The dataframe to be explained

required
method str

The Captum attribution method to be used for explaining the model. It should be one of the supported methods (see ALLOWED_METHODS in the source below). Defaults to "GradientShap". For more details, refer to https://captum.ai/api/attribution.html

'GradientShap'
method_args Optional[Dict]

The arguments to be passed to the initialization of the Captum method.

{}
baselines Union[float, tensor, str]

The baselines to be used for the explanation. If a scalar is provided, will use that value as the baseline for all the features. If a tensor is provided, will use that tensor as the baseline for all the features. If a string like b|<num_samples> is provided, will use that many samples from the train data as the baseline. Using the whole train data as the baseline is not recommended as it can be computationally expensive. By default, PyTorch Tabular uses 10000 samples from the train data as the baseline. You can configure this by passing a special string "b|<num_samples>" where <num_samples> is the number of samples to be used as the baseline. For example, "b|1000" will use 1000 samples from the train data. If None, will use default settings like zero in Captum (which is method dependent). For GradientShap, it is the train data. Defaults to None.

None
**kwargs

Additional keyword arguments to be passed to the Captum method attribute function.

{}

Returns:

Name Type Description
DataFrame DataFrame

The dataframe with the feature importance

Source code in src/pytorch_tabular/tabular_model.py
def explain(
    self,
    data: DataFrame,
    method: str = "GradientShap",
    method_args: Optional[Dict] = {},
    baselines: Union[float, torch.tensor, str] = None,
    **kwargs,
) -> DataFrame:
    """Returns the feature attributions/explanations of the model as a pandas DataFrame. The shape of the returned
    dataframe is (num_samples, num_features)

    Args:
        data (DataFrame): The dataframe to be explained
        method (str): The method to be used for explaining the model.
            It should be one of the supported Captum attribution methods. Defaults to "GradientShap".
            For more details, refer to https://captum.ai/api/attribution.html
        method_args (Optional[Dict], optional): The arguments to be passed to the initialization
            of the Captum method.
        baselines (Union[float, torch.tensor, str]): The baselines to be used for the explanation.
            If a scalar is provided, will use that value as the baseline for all the features.
            If a tensor is provided, will use that tensor as the baseline for all the features.
            If a string like `b|<num_samples>` is provided, will use that many samples from the train data as the baseline.
            Using the whole train data as the baseline is not recommended as it can be
            computationally expensive. By default, PyTorch Tabular uses 10000 samples from the
            train data as the baseline. You can configure this by passing a special string
            "b|<num_samples>" where <num_samples> is the number of samples to be used as the
            baseline. For eg. "b|1000" will use 1000 samples from the train.
            If None, will use default settings like zero in captum(which is method dependent).
            For `GradientShap`, it is the train data.
            Defaults to None.

        **kwargs: Additional keyword arguments to be passed to the Captum method `attribute` function.

    Returns:
        DataFrame: The dataframe with the feature importance

    """
    assert CAPTUM_INSTALLED, "Captum not installed. Please install using `pip install captum` or "
    "install PyTorch Tabular using `pip install pytorch-tabular[extra]`"
    ALLOWED_METHODS = [
        "GradientShap",
        "IntegratedGradients",
        "DeepLift",
        "DeepLiftShap",
        "InputXGradient",
        "FeaturePermutation",
        "FeatureAblation",
        "KernelShap",
    ]
    assert method in ALLOWED_METHODS, f"method should be one of {ALLOWED_METHODS}"
    if isinstance(data, pd.Series):
        data = data.to_frame().T
    if method in ["DeepLiftShap", "KernelShap"]:
        warnings.warn(
            f"{method} is computationally expensive and will take some time. For"
            " faster results, try usingsome other methods like GradientShap,"
            " IntegratedGradients etc."
        )
    if method in ["FeaturePermutation", "FeatureAblation"]:
        assert data.shape[0] > 1, f"{method} only works when the number of samples is greater than 1"
        if len(data) <= 100:
            warnings.warn(
                f"{method} gives better results when the number of samples is"
                " large. For better results, try using more samples or some other"
                " methods like GradientShap which works well on single examples."
            )
    is_full_baselines = method in ["GradientShap", "DeepLiftShap"]
    is_not_supported = self.model._get_name() in [
        "TabNetModel",
        "MDNModel",
        "TabTransformerModel",
    ]
    do_baselines = method not in [
        "Saliency",
        "InputXGradient",
        "FeaturePermutation",
        "LRP",
    ]
    if is_full_baselines and (baselines is None or isinstance(baselines, (float, int))):
        raise ValueError(
            f"baselines cannot be a scalar or None for {method}. Please "
            "provide a tensor or a string like `b|<num_samples>`"
        )
    if is_not_supported:
        raise NotImplementedError(f"Attributions are not implemented for {self.model._get_name()}")

    is_embedding1d = isinstance(self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer))
    is_embedding2d = isinstance(self.model.embedding_layer, Embedding2dLayer)
    # Models like NODE may have no embedding dims (doing leaveOneOut encoding) even if categorical_dim > 0
    is_embbeding_dims = (
        hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None
    )
    if (not is_embedding1d) and (not is_embedding2d):
        raise NotImplementedError(
            "Attributions are not implemented for models with this type of" " embedding layer"
        )
    test_dl = self.datamodule.prepare_inference_dataloader(data)
    self.model.eval()
    # prepare import for Captum
    tensor_inp, tensor_tgt = self._prepare_input_for_captum(test_dl)
    baselines = self._prepare_baselines_captum(baselines, test_dl, do_baselines, is_full_baselines)
    # prepare model for Captum
    try:
        interp_model = _CaptumModel(self.model)
        captum_interp_cls = getattr(captum.attr, method)(interp_model, **method_args)
        if do_baselines:
            attributions = captum_interp_cls.attribute(
                tensor_inp,
                baselines=baselines,
                target=(tensor_tgt if self.config.task == "classification" else None),
                **kwargs,
            )
        else:
            attributions = captum_interp_cls.attribute(
                tensor_inp,
                target=(tensor_tgt if self.config.task == "classification" else None),
                **kwargs,
            )
        attributions = self._handle_categorical_embeddings_attributions(
            attributions, is_embedding1d, is_embedding2d, is_embbeding_dims
        )
    finally:
        self.model.train()
    assert attributions.shape[1] == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim, (
        "Something went wrong. The number of features in the attributions"
        f" ({attributions.shape[1]}) does not match the number of features in"
        " the model"
        f" ({self.model.hparams.continuous_dim+self.model.hparams.categorical_dim})"
    )
    return pd.DataFrame(
        attributions.detach().cpu().numpy(),
        columns=self.config.continuous_cols + self.config.categorical_cols,
    )
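
Example usage (a hedged sketch; test_df is a hypothetical DataFrame of rows to explain, and the baseline size is illustrative):

attributions = tabular_model.explain(
    test_df.head(100),
    method="GradientShap",
    baselines="b|1000",   # sample 1000 training rows as the baseline
)
# mean absolute attribution per feature, as a quick global summary
print(attributions.abs().mean().sort_values(ascending=False))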

feature_importance()

Returns the feature importance of the model as a pandas DataFrame.

Source code in src/pytorch_tabular/tabular_model.py
def feature_importance(self) -> DataFrame:
    """Returns the feature importance of the model as a pandas DataFrame."""
    return self.model.feature_importance()
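
Example usage (a minimal sketch; only models whose backbone exposes feature importances support this call — others will raise an error):

importance_df = tabular_model.feature_importance()
print(importance_df.head(10))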

find_learning_rate(model, datamodule, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, plot=True, callbacks=None)

Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.

Parameters:

Name Type Description Default
model LightningModule

The PyTorch Lightning model to be trained.

required
datamodule TabularDatamodule

The datamodule

required
min_lr Optional[float]

minimum learning rate to investigate

1e-08
max_lr Optional[float]

maximum learning rate to investigate

1
num_training Optional[int]

number of learning rates to test

100
mode Optional[str]

search strategy, either 'linear' or 'exponential'. If set to 'linear' the learning rate will be searched by linearly increasing after each batch. If set to 'exponential', will increase learning rate exponentially.

'exponential'
early_stop_threshold Optional[float]

threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

4.0
plot bool

If true, will plot using matplotlib

True
callbacks Optional[List]

If provided, will be added to the callbacks for Trainer.

None

Returns:

Type Description
Tuple[float, DataFrame]

The suggested learning rate and the learning rate finder results

Source code in src/pytorch_tabular/tabular_model.py
def find_learning_rate(
    self,
    model: pl.LightningModule,
    datamodule: TabularDatamodule,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    plot: bool = True,
    callbacks: Optional[List] = None,
) -> Tuple[float, DataFrame]:
    """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
    picking a good starting learning rate.

    Args:
        model (pl.LightningModule): The PyTorch Lightning model to be trained.

        datamodule (TabularDatamodule): The datamodule

        min_lr (Optional[float], optional): minimum learning rate to investigate

        max_lr (Optional[float], optional): maximum learning rate to investigate

        num_training (Optional[int], optional): number of learning rates to test

        mode (Optional[str], optional): search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold (Optional[float], optional): threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        plot (bool, optional): If true, will plot using matplotlib

        callbacks (Optional[List], optional): If provided, will be added to the callbacks for Trainer.

    Returns:
        The suggested learning rate and the learning rate finder results

    """
    self._prepare_for_training(model, datamodule, callbacks, max_epochs=None, min_epochs=None)
    train_loader, _ = datamodule.train_dataloader(), datamodule.val_dataloader()
    lr_finder = Tuner(self.trainer).lr_find(
        model=self.model,
        train_dataloaders=train_loader,
        val_dataloaders=None,
        min_lr=min_lr,
        max_lr=max_lr,
        num_training=num_training,
        mode=mode,
        early_stop_threshold=early_stop_threshold,
    )
    if plot:
        fig = lr_finder.plot(suggest=True)
        fig.show()
    new_lr = lr_finder.suggestion()
    # cancelling the model and trainer that was loaded
    self.model = None
    self.trainer = None
    self.datamodule = None
    self.callbacks = None
    return new_lr, DataFrame(lr_finder.results)
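
Example usage (a hedged sketch using the same prepare_dataloader/prepare_model flow shown elsewhere on this page; train_df and val_df are hypothetical DataFrames):

datamodule = tabular_model.prepare_dataloader(train=train_df, validation=val_df, seed=42)
model = tabular_model.prepare_model(datamodule)
suggested_lr, lr_results = tabular_model.find_learning_rate(model, datamodule, plot=False)
print(f"Suggested learning rate: {suggested_lr}")
# the model/trainer/datamodule created for the search are reset afterwards,
# so run `fit` separately with the chosen learning rate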

finetune(max_epochs=None, min_epochs=None, callbacks=None, freeze_backbone=False)

Finetunes the model on the provided data.

Parameters:

Name Type Description Default
max_epochs Optional[int]

The maximum number of epochs to train for. Defaults to None.

None
min_epochs Optional[int]

The minimum number of epochs to train for. Defaults to None.

None
callbacks Optional[List[Callback]]

If provided, will be added to the callbacks for Trainer. Defaults to None.

None
freeze_backbone bool

If True, will freeze the backbone by turning off gradients. Defaults to False, which means the pretrained weights are also further tuned during fine-tuning.

False

Returns:

Type Description
Trainer

pl.Trainer: The trainer object

Source code in src/pytorch_tabular/tabular_model.py
def finetune(
    self,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    callbacks: Optional[List[pl.Callback]] = None,
    freeze_backbone: bool = False,
) -> pl.Trainer:
    """Finetunes the model on the provided data.

    Args:
        max_epochs (Optional[int], optional): The maximum number of epochs to train for. Defaults to None.

        min_epochs (Optional[int], optional): The minimum number of epochs to train for. Defaults to None.

        callbacks (Optional[List[pl.Callback]], optional): If provided, will be added to the callbacks for Trainer.
            Defaults to None.

        freeze_backbone (bool, optional): If True, will freeze the backbone by turning off gradients.
            Defaults to False, which means the pretrained weights are also further tuned during fine-tuning.

    Returns:
        pl.Trainer: The trainer object

    """
    assert self._is_finetune_model, (
        "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
    )
    seed_everything(self.config.seed)
    if freeze_backbone:
        for param in self.model.backbone.parameters():
            param.requires_grad = False
    return self.train(
        self.model,
        self.datamodule,
        callbacks=callbacks,
        max_epochs=max_epochs,
        min_epochs=min_epochs,
    )

fit(train, validation=None, loss=None, metrics=None, metrics_prob_inputs=None, optimizer=None, optimizer_params=None, train_sampler=None, target_transform=None, max_epochs=None, min_epochs=None, seed=42, callbacks=None, datamodule=None, cache_data='memory', handle_oom=True)

The fit method which takes in the data and triggers the training.

Parameters:

Name Type Description Default
train DataFrame

Training Dataframe

required
validation Optional[DataFrame]

If provided, will use this dataframe as the validation while training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation. Defaults to None.

None
loss Optional[Module]

Custom Loss functions which are not in standard pytorch library

None
metrics Optional[List[Callable]]

Custom metric functions (Callable) which have the signature metric_fn(y_hat, y) and work on torch tensor inputs. y_hat is expected to be of shape (batch_size, num_classes) for classification and (batch_size, 1) for regression, and y is expected to be of shape (batch_size, 1).

None
metrics_prob_inputs Optional[List[bool]]

This is a mandatory parameter for classification metrics. If the metric function requires probabilities as inputs, set this to True. The length of the list should be equal to the number of metrics. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are drop-in replacements for standard PyTorch optimizers. This should be the class and not an initialized object.

None
optimizer_params Optional[Dict]

The parameters to initialize the custom optimizer.

None
train_sampler Optional[Sampler]

Custom PyTorch batch samplers which will be passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

None
target_transform Optional[Union[TransformerMixin, Tuple(Callable)]]

If provided, applies the transform to the target before modelling and inverse the transform during prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func)

None
max_epochs Optional[int]

Overwrite maximum number of epochs to be run. Defaults to None.

None
min_epochs Optional[int]

Overwrite minimum number of epochs to be run. Defaults to None.

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42
callbacks Optional[List[Callback]]

List of callbacks to be used during training. Defaults to None.

None
datamodule Optional[TabularDatamodule]

The datamodule. If provided, will ignore the rest of the parameters like train, test etc and use the datamodule. Defaults to None.

None
cache_data str

Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

'memory'
handle_oom bool

If True, will try to handle OOM errors elegantly. Defaults to True.

True

Returns:

Type Description
Trainer

pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def fit(
    self,
    train: Optional[DataFrame],
    validation: Optional[DataFrame] = None,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Callable]] = None,
    metrics_prob_inputs: Optional[List[bool]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    seed: Optional[int] = 42,
    callbacks: Optional[List[pl.Callback]] = None,
    datamodule: Optional[TabularDatamodule] = None,
    cache_data: str = "memory",
    handle_oom: bool = True,
) -> pl.Trainer:
    """The fit method which takes in the data and triggers the training.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional):
            If provided, will use this dataframe as the validation while training.
            Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

        metrics (Optional[List[Callable]], optional): Custom metric functions(Callable) which has the
            signature metric_fn(y_hat, y) and works on torch tensor inputs. y_hat is expected to be of shape
            (batch_size, num_classes) for classification and (batch_size, 1) for regression and y is expected to be
            of shape (batch_size, 1)

        metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
            classification metrics. If the metric function requires probabilities as inputs, set this to True.
            The length of the list should be equal to the number of metrics. Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for
            standard PyTorch optimizers. This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

        train_sampler (Optional[torch.utils.data.Sampler], optional):
            Custom PyTorch batch samplers which will be passed
            to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
            If provided, applies the transform to the target before modelling and inverse the transform during
            prediction. The parameter can either be a sklearn Transformer
            which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func)

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        seed: (int): Random seed for reproducibility. Defaults to 42.

        callbacks (Optional[List[pl.Callback]], optional):
            List of callbacks to be used during training. Defaults to None.

        datamodule (Optional[TabularDatamodule], optional): The datamodule.
            If provided, will ignore the rest of the parameters like train, test etc and use the datamodule.
            Defaults to None.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

        handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    assert self.config.task != "ssl", (
        "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning"
    )
    if metrics is not None:
        assert len(metrics) == len(
            metrics_prob_inputs or []
        ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
    seed = seed or self.config.seed
    if seed:
        seed_everything(seed)
    if datamodule is None:
        datamodule = self.prepare_dataloader(
            train,
            validation,
            train_sampler,
            target_transform,
            seed,
            cache_data,
        )
    else:
        if train is not None:
            warnings.warn(
                "train data and datamodule is provided."
                " Ignoring the train data and using the datamodule."
                " Set either one of them to None to avoid this warning."
            )
    model = self.prepare_model(
        datamodule,
        loss,
        metrics,
        metrics_prob_inputs,
        optimizer,
        optimizer_params or {},
    )

    return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)
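
For illustration, a minimal sketch of a fit call with a custom metric and optimizer (the names tabular_model, train_df, valid_df and accuracy_fn are assumptions, not part of the library):

import torch

def accuracy_fn(y_hat, y):
    # y_hat: (batch_size, num_classes), y: (batch_size, 1)
    return (y_hat.argmax(dim=-1) == y.squeeze(-1)).float().mean()

trainer = tabular_model.fit(
    train=train_df,
    validation=valid_df,
    metrics=[accuracy_fn],
    metrics_prob_inputs=[False],          # one flag per metric
    optimizer=torch.optim.AdamW,          # pass the class, not an instance
    optimizer_params={"weight_decay": 1e-5},
    max_epochs=20,
)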

load_best_model()

Loads the best model after training is done.

Source code in src/pytorch_tabular/tabular_model.py
def load_best_model(self) -> None:
    """Loads the best model after training is done."""
    if self.trainer.checkpoint_callback is not None:
        if self.verbose:
            logger.info("Loading the best model")
        ckpt_path = self.trainer.checkpoint_callback.best_model_path
        if ckpt_path != "":
            if self.verbose:
                logger.debug(f"Model Checkpoint: {ckpt_path}")
            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(ckpt["state_dict"])
        else:
            logger.warning("No best model available to load. Did you run it more than 1" " epoch?...")
    else:
        logger.warning(
            "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work"
        )

load_model(dir, map_location=None, strict=True) classmethod

Loads a saved model from the directory.

Parameters:

Name Type Description Default
dir str

The directory where the model was saved, along with the checkpoints

required
map_location Union[Dict[str, str], str, device, int, Callable, None])

If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load()

None
strict bool)

Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict. Default: True.

True

Returns:

Name Type Description
TabularModel TabularModel

The saved TabularModel

Source code in src/pytorch_tabular/tabular_model.py
@classmethod
def load_model(cls, dir: str, map_location=None, strict=True):
    """Loads a saved model from the directory.

    Args:
        dir (str): The directory where the model was saved, along with the checkpoints
        map_location (Union[Dict[str, str], str, device, int, Callable, None]) : If your checkpoint
            saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map
            to the new setup. The behaviour is the same as in torch.load()
        strict (bool) : Whether to strictly enforce that the keys in checkpoint_path match the keys
            returned by this module's state dict. Default: True.

    Returns:
        TabularModel (TabularModel): The saved TabularModel

    """
    config = OmegaConf.load(os.path.join(dir, "config.yml"))
    datamodule = joblib.load(os.path.join(dir, "datamodule.sav"))
    if (
        hasattr(config, "log_target")
        and (config.log_target is not None)
        and os.path.exists(os.path.join(dir, "exp_logger.sav"))
    ):
        logger = joblib.load(os.path.join(dir, "exp_logger.sav"))
    else:
        logger = None
    if os.path.exists(os.path.join(dir, "callbacks.sav")):
        callbacks = joblib.load(os.path.join(dir, "callbacks.sav"))
        # Excluding Gradient Accumulation Scheduler Callback as we are creating
        # a new one in trainer
        callbacks = [c for c in callbacks if not isinstance(c, GradientAccumulationScheduler)]
    else:
        callbacks = []
    if os.path.exists(os.path.join(dir, "custom_model_callable.sav")):
        model_callable = joblib.load(os.path.join(dir, "custom_model_callable.sav"))
        custom_model = True
    else:
        model_callable = getattr_nested(config._module_src, config._model_name)
        # model_callable = getattr(
        #     getattr(models, config._module_src), config._model_name
        # )
        custom_model = False
    inferred_config = datamodule.update_config(config)
    inferred_config = OmegaConf.structured(inferred_config)
    model_args = {
        "config": config,
        "inferred_config": inferred_config,
    }
    custom_params = joblib.load(os.path.join(dir, "custom_params.sav"))
    if custom_params.get("custom_loss") is not None:
        model_args["loss"] = "MSELoss"  # For compatibility. Not Used
    if custom_params.get("custom_metrics") is not None:
        model_args["metrics"] = ["mean_squared_error"]  # For compatibility. Not Used
        model_args["metrics_params"] = [{}]  # For compatibility. Not Used
        model_args["metrics_prob_inputs"] = [False]  # For compatibility. Not Used
    if custom_params.get("custom_optimizer") is not None:
        model_args["optimizer"] = "Adam"  # For compatibility. Not Used
    if custom_params.get("custom_optimizer_params") is not None:
        model_args["optimizer_params"] = {}  # For compatibility. Not Used

    # Initializing with default metrics, losses, and optimizers. Will revert once initialized
    try:
        model = model_callable.load_from_checkpoint(
            checkpoint_path=os.path.join(dir, "model.ckpt"),
            map_location=map_location,
            strict=strict,
            **model_args,
        )
    except RuntimeError as e:
        if (
            "Unexpected key(s) in state_dict" in str(e)
            and "loss.weight" in str(e)
            and "custom_loss.weight" in str(e)
        ):
            # Custom loss will be loaded after the model is initialized
            # continuing with strict=False
            model = model_callable.load_from_checkpoint(
                checkpoint_path=os.path.join(dir, "model.ckpt"),
                map_location=map_location,
                strict=False,
                **model_args,
            )
        else:
            raise e
    if custom_params.get("custom_optimizer") is not None:
        model.custom_optimizer = custom_params["custom_optimizer"]
    if custom_params.get("custom_optimizer_params") is not None:
        model.custom_optimizer_params = custom_params["custom_optimizer_params"]
    if custom_params.get("custom_loss") is not None:
        model.loss = custom_params["custom_loss"]
    if custom_params.get("custom_metrics") is not None:
        model.custom_metrics = custom_params.get("custom_metrics")
        model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")]
        model.hparams.metrics_params = [{}]
        model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs")
    model._setup_loss()
    model._setup_metrics()
    tabular_model = cls(config=config, model_callable=model_callable)
    tabular_model.model = model
    tabular_model.custom_model = custom_model
    tabular_model.datamodule = datamodule
    tabular_model.callbacks = callbacks
    tabular_model.trainer = tabular_model._prepare_trainer(callbacks=callbacks)
    # tabular_model.trainer.model = model
    tabular_model.logger = logger
    return tabular_model
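
A minimal sketch of a save/load round trip, assuming a fitted model tabular_model, a test DataFrame test_df, and a writable directory (all hypothetical names):

from pytorch_tabular import TabularModel

tabular_model.save_model("saved_model_dir")                  # writes config.yml, datamodule.sav, model.ckpt, ...
loaded_model = TabularModel.load_model("saved_model_dir")    # restores config, datamodule and weights
pred_df = loaded_model.predict(test_df)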

load_weights(path)

Loads the model weights in the specified directory.

Parameters:

Name Type Description Default
path str

The path to the file to load the model from

required
Source code in src/pytorch_tabular/tabular_model.py
def load_weights(self, path: Union[str, Path]) -> None:
    """Loads the model weights in the specified directory.

    Args:
        path (str): The path to the file to load the model from

    """
    self._load_weights(self.model, path)

predict(test, quantiles=[0.25, 0.5, 0.75], n_samples=100, ret_logits=False, include_input_features=False, device=None, progress_bar=None, test_time_augmentation=False, num_tta=5, alpha_tta=0.1, aggregate_tta='mean', tta_seed=42)

Uses the trained model to predict on new data and return as a dataframe.

Parameters:

Name Type Description Default
test DataFrame

The new dataframe with the features defined during training

required
quantiles Optional[List]

For probabilistic models like Mixture Density Networks, this specifies the different quantiles to be extracted apart from the central_tendency and added to the dataframe. For other models it is ignored. Defaults to [0.25, 0.5, 0.75]

[0.25, 0.5, 0.75]
n_samples Optional[int]

Number of samples to draw from the posterior to estimate the quantiles. Ignored for non-probabilistic models. Defaults to 100

100
ret_logits bool

Flag to return raw model outputs/logits except the backbone features along with the dataframe. Defaults to False

False
include_input_features bool

DEPRECATED: Flag to include the input features in the returned dataframe. Defaults to False

False
progress_bar Optional[str]

Choose the progress bar for tracking progress. "rich" or "tqdm" will set the respective progress bars. If None, no progress bar will be shown.

None
test_time_augmentation bool

If True, will use test time augmentation to generate predictions. The approach is very similar to what is described here, but we add noise to the embedded inputs to handle categorical features as well: \(x_{aug} = x_{orig} + \alpha \cdot \epsilon\) where \(\epsilon \sim \mathcal{N}(0, 1)\). Defaults to False

False
num_tta float

The number of augmentations to run TTA for. Defaults to 5

5
alpha_tta float

The standard deviation of the Gaussian noise to be added to the input features. Defaults to 0.1

0.1
aggregate_tta Union[str, Callable]

The function to be used to aggregate the predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max" for regression. For classification, the previous options are applied to the confidence scores (soft voting) and then converted to the final prediction. An additional option "hard_voting" is available for classification. If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets) and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

'mean'
tta_seed int

The random seed to be used for the noise added in TTA. Defaults to 42.

42

Returns:

Name Type Description
DataFrame DataFrame

Returns a dataframe with predictions and features (if include_input_features=True). If classification, it returns probabilities and final prediction

Source code in src/pytorch_tabular/tabular_model.py
def predict(
    self,
    test: DataFrame,
    quantiles: Optional[List] = [0.25, 0.5, 0.75],
    n_samples: Optional[int] = 100,
    ret_logits=False,
    include_input_features: bool = False,
    device: Optional[torch.device] = None,
    progress_bar: Optional[str] = None,
    test_time_augmentation: Optional[bool] = False,
    num_tta: Optional[float] = 5,
    alpha_tta: Optional[float] = 0.1,
    aggregate_tta: Optional[str] = "mean",
    tta_seed: Optional[int] = 42,
) -> DataFrame:
    """Uses the trained model to predict on new data and return as a dataframe.

    Args:
        test (DataFrame): The new dataframe with the features defined during training

        quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
            the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
            For other models it is ignored. Defaults to [0.25, 0.5, 0.75]

        n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
            Ignored for non-probabilistic models. Defaults to 100

        ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
            with the dataframe. Defaults to False

        include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
            Defaults to False

        progress_bar: Choose the progress bar for tracking progress. "rich" or "tqdm" will set the respective
            progress bars. If None, no progress bar will be shown.

        test_time_augmentation (bool): If True, will use test time augmentation to generate predictions.
            The approach is very similar to what is described [here](https://kozodoi.me/blog/20210908/tta-tabular)
            But, we add noise to the embedded inputs to handle categorical features as well.\
            \\(x_{aug} = x_{orig} + \\alpha * \\epsilon\\) where \\(\\epsilon \\sim \\mathcal{N}(0, 1)\\)
            Defaults to False
        num_tta (float): The number of augmentations to run TTA for. Defaults to 5

        alpha_tta (float): The standard deviation of the gaussian noise to be added to the input features

        aggregate_tta (Union[str, Callable], optional): The function to be used to aggregate the
            predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max"
            for regression. For classification, the previous options are applied to the confidence
            scores (soft voting) and then converted to final prediction. An additional option
            "hard_voting" is available for classification.
            If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
            and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

        tta_seed (int): The random seed to be used for the noise added in TTA. Defaults to 42.

    Returns:
        DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
            If classification, it returns probabilities and final prediction

    """
    warnings.warn(
        "`include_input_features` will be deprecated in the next release."
        " Please add index columns to the test dataframe if you want to"
        " retain some features like the key or id",
        DeprecationWarning,
    )
    if test_time_augmentation:
        assert num_tta > 0, "num_tta should be greater than 0"
        assert alpha_tta > 0, "alpha_tta should be greater than 0"
        assert include_input_features is False, "include_input_features cannot be True for TTA."
        if not callable(aggregate_tta):
            assert aggregate_tta in [
                "mean",
                "median",
                "min",
                "max",
                "hard_voting",
            ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
        if self.config.task == "regression":
            assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"

        torch.manual_seed(tta_seed)

        def add_noise(module, input, output):
            return output + alpha_tta * torch.randn_like(output, memory_format=torch.contiguous_format)

        # Register the hook to the embedding_layer
        handle = self.model.embedding_layer.register_forward_hook(add_noise)
        pred_prob_l = []
        for _ in range(num_tta):
            pred_df = self._predict(
                test,
                quantiles,
                n_samples,
                ret_logits,
                include_input_features=False,
                device=device,
                progress_bar=progress_bar or "None",
            )
            pred_idx = pred_df.index
            if self.config.task == "classification":
                pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
            elif self.config.task == "regression":
                pred_prob_l.append(pred_df.values)
        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
        # Remove the hook
        handle.remove()
    else:
        pred_df = self._predict(
            test,
            quantiles,
            n_samples,
            ret_logits,
            include_input_features,
            device,
            progress_bar,
        )
    return pred_df
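
A minimal sketch of prediction with test-time augmentation for a fitted classification model (tabular_model and test_df are assumed names):

pred_df = tabular_model.predict(
    test_df,
    test_time_augmentation=True,
    num_tta=10,            # number of noisy forward passes
    alpha_tta=0.05,        # std-dev of the Gaussian noise added to the embedded inputs
    aggregate_tta="mean",  # soft voting over class probabilities
    tta_seed=42,
)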

prepare_dataloader(train, validation=None, train_sampler=None, target_transform=None, seed=42, cache_data='memory')

Prepares the dataloaders for training and validation.

Parameters:

Name Type Description Default
train DataFrame

Training Dataframe

required
validation Optional[DataFrame]

If provided, will use this dataframe as the validation while training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation. Defaults to None.

None
train_sampler Optional[Sampler]

Custom PyTorch batch samplers which will be passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

None
target_transform Optional[Union[TransformerMixin, Tuple(Callable)]]

If provided, applies the transform to the target before modelling and inverse the transform during prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func)

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42
cache_data str

Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

'memory'

Returns: TabularDatamodule: The prepared datamodule

Source code in src/pytorch_tabular/tabular_model.py
def prepare_dataloader(
    self,
    train: DataFrame,
    validation: Optional[DataFrame] = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    seed: Optional[int] = 42,
    cache_data: str = "memory",
) -> TabularDatamodule:
    """Prepares the dataloaders for training and validation.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional):
            If provided, will use this dataframe as the validation while training.
            Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        train_sampler (Optional[torch.utils.data.Sampler], optional):
            Custom PyTorch batch samplers which will be passed to the DataLoaders.
            Useful for dealing with imbalanced data and other custom batching strategies

        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
            If provided, applies the transform to the target before modelling and inverse the transform during
            prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or
            a tuple of callables (transform_func, inverse_transform_func)

        seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
    Returns:
        TabularDatamodule: The prepared datamodule

    """
    if self.verbose:
        logger.info("Preparing the DataLoaders")
    target_transform = self._check_and_set_target_transform(target_transform)

    datamodule = TabularDatamodule(
        train=train,
        validation=validation,
        config=self.config,
        target_transform=target_transform,
        train_sampler=train_sampler,
        seed=seed,
        cache_data=cache_data,
        verbose=self.verbose,
    )
    datamodule.prepare_data()
    datamodule.setup("fit")
    return datamodule
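
A minimal sketch of preparing the datamodule separately, here with a log/exp target transform for a regression target (tabular_model, train_df and valid_df are assumed names):

import numpy as np

datamodule = tabular_model.prepare_dataloader(
    train=train_df,
    validation=valid_df,
    target_transform=(np.log1p, np.expm1),  # (transform_func, inverse_transform_func)
    seed=42,
    cache_data="memory",
)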

prepare_model(datamodule, loss=None, metrics=None, metrics_prob_inputs=None, optimizer=None, optimizer_params=None)

Prepares the model for training.

Parameters:

Name Type Description Default
datamodule TabularDatamodule

The datamodule

required
loss Optional[Module]

Custom Loss functions which are not in standard pytorch library

None
metrics Optional[List[Callable]]

Custom metric functions (Callable) which have the signature metric_fn(y_hat, y) and work on torch tensor inputs

None
metrics_prob_inputs Optional[List[bool]]

This is a mandatory parameter for classification metrics. If the metric function requires probabilities as inputs, set this to True. The length of the list should be equal to the number of metrics. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are a drop in replacements for standard PyTorch optimizers. This should be the Class and not the initialized object

None
optimizer_params Optional[Dict]

The parameters to initialize the custom optimizer.

None

Returns:

Name Type Description
BaseModel BaseModel

The prepared model

Source code in src/pytorch_tabular/tabular_model.py
def prepare_model(
    self,
    datamodule: TabularDatamodule,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Callable]] = None,
    metrics_prob_inputs: Optional[List[bool]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
) -> BaseModel:
    """Prepares the model for training.

    Args:
        datamodule (TabularDatamodule): The datamodule

        loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

        metrics (Optional[List[Callable]], optional): Custom metric functions (Callable) which have the
            signature metric_fn(y_hat, y) and work on torch tensor inputs

        metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
            classification metrics. If the metric function requires probabilities as inputs, set this to True.
            The length of the list should be equal to the number of metrics. Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for standard PyTorch optimizers.
            This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

    Returns:
        BaseModel: The prepared model

    """
    if self.verbose:
        logger.info(f"Preparing the Model: {self.config._model_name}")
    # Fetching the config as some data specific configs have been added in the datamodule
    self.inferred_config = self._read_parse_config(datamodule.update_config(self.config), InferredConfig)
    model = self.model_callable(
        self.config,
        custom_loss=loss,  # Unused in SSL tasks
        custom_metrics=metrics,  # Unused in SSL tasks
        custom_metrics_prob_inputs=metrics_prob_inputs,  # Unused in SSL tasks
        custom_optimizer=optimizer,
        custom_optimizer_params=optimizer_params or {},
        inferred_config=self.inferred_config,
    )
    # Data Aware Initialization(for the models that need it)
    model.data_aware_initialization(datamodule)
    if self.model_state_dict_path is not None:
        self._load_weights(model, self.model_state_dict_path)
    if self.track_experiment and self.config.log_target == "wandb":
        self.logger.watch(model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq)
    return model

pretrain(train, validation=None, optimizer=None, optimizer_params=None, max_epochs=None, min_epochs=None, seed=42, callbacks=None, datamodule=None, cache_data='memory')

The pretrain method which takes in the data and triggers the training.

Parameters:

Name Type Description Default
train DataFrame

Training Dataframe

required
validation Optional[DataFrame]

If provided, will use this dataframe as the validation while training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are a drop in replacements for standard PyTorch optimizers. This should be the Class and not the initialized object

None
optimizer_params Optional[Dict]

The parameters to initialize the custom optimizer.

None
max_epochs Optional[int]

Overwrite maximum number of epochs to be run. Defaults to None.

None
min_epochs Optional[int]

Overwrite minimum number of epochs to be run. Defaults to None.

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42
callbacks Optional[List[Callback]]

List of callbacks to be used during training. Defaults to None.

None
datamodule Optional[TabularDatamodule]

The datamodule. If provided, will ignore the rest of the parameters like train, test etc. and use the datamodule. Defaults to None.

None
cache_data str

Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

'memory'

Returns: pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def pretrain(
    self,
    train: Optional[DataFrame],
    validation: Optional[DataFrame] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    # train_sampler: Optional[torch.utils.data.Sampler] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    seed: Optional[int] = 42,
    callbacks: Optional[List[pl.Callback]] = None,
    datamodule: Optional[TabularDatamodule] = None,
    cache_data: str = "memory",
) -> pl.Trainer:
    """The pretrained method which takes in the data and triggers the training.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional): If provided, will use this dataframe as the validation while
            training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizers which are a drop in replacements
            for standard PyTorch optimizers. This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        seed (int): Random seed for reproducibility. Defaults to 42.

        callbacks (Optional[List[pl.Callback]], optional): List of callbacks to be used during training.
            Defaults to None.

        datamodule (Optional[TabularDatamodule], optional): The datamodule. If provided, will ignore the rest of the
            parameters like train, test etc. and use the datamodule. Defaults to None.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    assert self.config.task == "ssl", (
        f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead."
    )
    seed = seed or self.config.seed
    if seed:
        seed_everything(seed)
    if datamodule is None:
        datamodule = self.prepare_dataloader(
            train,
            validation,
            train_sampler=None,
            target_transform=None,
            seed=seed,
            cache_data=cache_data,
        )
    else:
        if train is not None:
            warnings.warn(
                "train data and datamodule is provided."
                " Ignoring the train data and using the datamodule."
                " Set either one of them to None to avoid this warning."
            )
    model = self.prepare_model(
        datamodule,
        optimizer=optimizer,
        optimizer_params=optimizer_params or {},
    )

    return self.train(model, datamodule, callbacks, max_epochs, min_epochs)
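
A minimal sketch of self-supervised pretraining, assuming ssl_model is a TabularModel configured with an SSL task and unlabeled_df is a DataFrame without target columns (both hypothetical names):

trainer = ssl_model.pretrain(
    train=unlabeled_df,
    max_epochs=50,
    seed=42,
)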

save_config(dir)

Saves the config in the specified directory.

Source code in src/pytorch_tabular/tabular_model.py
def save_config(self, dir: str) -> None:
    """Saves the config in the specified directory."""
    with open(os.path.join(dir, "config.yml"), "w") as fp:
        OmegaConf.save(self.config, fp, resolve=True)

save_datamodule(dir, inference_only=False)

Saves the datamodule in the specified directory.

Parameters:

Name Type Description Default
dir str

The path to the directory to save the datamodule

required
inference_only bool

If True, will only save the inference datamodule without data. This cannot be used for further training, but can be used for inference. Defaults to False.

False
Source code in src/pytorch_tabular/tabular_model.py
def save_datamodule(self, dir: str, inference_only: bool = False) -> None:
    """Saves the datamodule in the specified directory.

    Args:
        dir (str): The path to the directory to save the datamodule
        inference_only (bool): If True, will only save the inference datamodule
            without data. This cannot be used for further training, but can be
            used for inference. Defaults to False.

    """
    if inference_only:
        dm = self.datamodule.inference_only_copy()
    else:
        dm = self.datamodule

    joblib.dump(dm, os.path.join(dir, "datamodule.sav"))
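
For example, to persist only the fitted preprocessing state without the cached training data (the directory name is an assumption):

tabular_model.save_datamodule("artifacts", inference_only=True)  # lighter copy; usable for inference only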

save_model(dir, inference_only=False)

Saves the model and checkpoints in the specified directory.

Parameters:

Name Type Description Default
dir str

The path to the directory to save the model

required
inference_only bool

If True, will only save the inference only version of the datamodule

False
Source code in src/pytorch_tabular/tabular_model.py
def save_model(self, dir: str, inference_only: bool = False) -> None:
    """Saves the model and checkpoints in the specified directory.

    Args:
        dir (str): The path to the directory to save the model
        inference_only (bool): If True, will only save the inference
            only version of the datamodule

    """
    if os.path.exists(dir) and (os.listdir(dir)):
        logger.warning("Directory is not empty. Overwriting the contents.")
        for f in os.listdir(dir):
            os.remove(os.path.join(dir, f))
    os.makedirs(dir, exist_ok=True)
    self.save_config(dir)
    self.save_datamodule(dir, inference_only=inference_only)
    if hasattr(self.config, "log_target") and self.config.log_target is not None:
        joblib.dump(self.logger, os.path.join(dir, "exp_logger.sav"))
    if hasattr(self, "callbacks"):
        joblib.dump(self.callbacks, os.path.join(dir, "callbacks.sav"))
    self.trainer.save_checkpoint(os.path.join(dir, "model.ckpt"))
    custom_params = {}
    custom_params["custom_loss"] = getattr(self.model, "custom_loss", None)
    custom_params["custom_metrics"] = getattr(self.model, "custom_metrics", None)
    custom_params["custom_metrics_prob_inputs"] = getattr(self.model, "custom_metrics_prob_inputs", None)
    custom_params["custom_optimizer"] = getattr(self.model, "custom_optimizer", None)
    custom_params["custom_optimizer_params"] = getattr(self.model, "custom_optimizer_params", None)
    joblib.dump(custom_params, os.path.join(dir, "custom_params.sav"))
    if self.custom_model:
        joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav"))

save_model_for_inference(path, kind='pytorch', onnx_export_params={'opset_version': 12})

Saves the model for inference.

Parameters:

Name Type Description Default
path Union[str, Path]

path to save the model

required
kind str

"pytorch" or "onnx" (Experimental)

'pytorch'
onnx_export_params Dict

parameters for onnx export to be passed to torch.onnx.export

{'opset_version': 12}

Returns:

Name Type Description
bool bool

True if the model was saved successfully

Source code in src/pytorch_tabular/tabular_model.py
def save_model_for_inference(
    self,
    path: Union[str, Path],
    kind: str = "pytorch",
    onnx_export_params: Dict = {"opset_version": 12},
) -> bool:
    """Saves the model for inference.

    Args:
        path (Union[str, Path]): path to save the model
        kind (str): "pytorch" or "onnx" (Experimental)
        onnx_export_params (Dict): parameters for onnx export to be
            passed to torch.onnx.export

    Returns:
        bool: True if the model was saved successfully

    """
    if kind == "pytorch":
        torch.save(self.model, str(path))
        return True
    elif kind == "onnx":
        # Export the model
        onnx_export_params["input_names"] = ["categorical", "continuous"]
        onnx_export_params["output_names"] = onnx_export_params.get("output_names", ["output"])
        onnx_export_params["dynamic_axes"] = {
            onnx_export_params["input_names"][0]: {0: "batch_size"},
            onnx_export_params["output_names"][0]: {0: "batch_size"},
        }
        cat = torch.zeros(
            self.config.batch_size,
            len(self.config.categorical_cols),
            dtype=torch.int,
        )
        cont = torch.randn(
            self.config.batch_size,
            len(self.config.continuous_cols),
            requires_grad=True,
        )
        x = {"continuous": cont, "categorical": cat}
        torch.onnx.export(self.model, x, str(path), **onnx_export_params)
        return True
    else:
        raise ValueError("`kind` must be either pytorch or onnx")
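
A minimal sketch of both export paths (file paths and opset version are assumptions):

# Plain PyTorch serialization of the full model object
tabular_model.save_model_for_inference("model.pt", kind="pytorch")

# Experimental ONNX export; onnx_export_params are forwarded to torch.onnx.export
tabular_model.save_model_for_inference(
    "model.onnx",
    kind="onnx",
    onnx_export_params={"opset_version": 13},
)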

save_weights(path)

Saves the model weights in the specified directory.

Parameters:

Name Type Description Default
path str

The path to the file to save the model

required
Source code in src/pytorch_tabular/tabular_model.py
def save_weights(self, path: Union[str, Path]) -> None:
    """Saves the model weights in the specified directory.

    Args:
        path (str): The path to the file to save the model

    """
    torch.save(self.model.state_dict(), path)
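
A minimal sketch of a weights-only round trip; unlike save_model, this stores just the state dict, so the receiving TabularModel must already be prepared with the same configuration (the file name is an assumption):

tabular_model.save_weights("model_weights.pt")
# ... later, on a model prepared with the same config ...
tabular_model.load_weights("model_weights.pt")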

summary(model=None, max_depth=-1)

Prints a summary of the model.

Parameters:

Name Type Description Default
max_depth int

The maximum depth to traverse the modules and displayed in the summary. Defaults to -1, which means will display all the modules.

-1
Source code in src/pytorch_tabular/tabular_model.py
def summary(self, model=None, max_depth: int = -1) -> None:
    """Prints a summary of the model.

    Args:
        max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
            Defaults to -1, which means will display all the modules.

    """
    if model is not None:
        print(summarize(model, max_depth=max_depth))
    elif self.has_model:
        print(summarize(self.model, max_depth=max_depth))
    else:
        rich_print(f"[bold green]{self.__class__.__name__}[/bold green]")
        rich_print("-" * 100)
        rich_print("[bold yellow]Config[/bold yellow]")
        rich_print("-" * 100)
        pprint(self.config.__dict__["_content"])
        rich_print(
            ":triangular_flag:[bold red]Full Model Summary once model has "
            "been initialized or passed in as an argument[/bold red]"
        )

train(model, datamodule, callbacks=None, max_epochs=None, min_epochs=None, handle_oom=True)

Trains the model.

Parameters:

Name Type Description Default
model LightningModule

The PyTorch Lightning model to be trained.

required
datamodule TabularDatamodule

The datamodule

required
callbacks Optional[List[Callback]]

List of callbacks to be used during training. Defaults to None.

None
max_epochs Optional[int]

Overwrite maximum number of epochs to be run. Defaults to None.

None
min_epochs Optional[int]

Overwrite minimum number of epochs to be run. Defaults to None.

None
handle_oom bool

If True, will try to handle OOM errors elegantly. Defaults to True.

True

Returns:

Type Description
Trainer

pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def train(
    self,
    model: pl.LightningModule,
    datamodule: TabularDatamodule,
    callbacks: Optional[List[pl.Callback]] = None,
    max_epochs: int = None,
    min_epochs: int = None,
    handle_oom: bool = True,
) -> pl.Trainer:
    """Trains the model.

    Args:
        model (pl.LightningModule): The PyTorch Lightning model to be trained.

        datamodule (TabularDatamodule): The datamodule

        callbacks (Optional[List[pl.Callback]], optional):
            List of callbacks to be used during training. Defaults to None.

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    self._prepare_for_training(model, datamodule, callbacks, max_epochs, min_epochs)
    train_loader, val_loader = (
        self.datamodule.train_dataloader(),
        self.datamodule.val_dataloader(),
    )
    self.model.train()
    if self.config.auto_lr_find and (not self.config.fast_dev_run):
        if self.verbose:
            logger.info("Auto LR Find Started")
        with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
            result = Tuner(self.trainer).lr_find(
                self.model,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader,
            )
        if oom_handler.oom_triggered:
            raise OOMException(
                "OOM detected during LR Find. Try reducing your batch_size or the"
                " model parameters." + "/n" + "Original Error: " + oom_handler.oom_msg
            )
        if self.verbose:
            logger.info(
                f"Suggested LR: {result.suggestion()}. For plot and detailed"
                " analysis, use `find_learning_rate` method."
            )
        self.model.reset_weights()
        # Parameters in some models need to be initialized again after LR find
        self.model.data_aware_initialization(self.datamodule)
    self.model.train()
    if self.verbose:
        logger.info("Training Started")
    with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
        self.trainer.fit(self.model, train_loader, val_loader)
    if oom_handler.oom_triggered:
        raise OOMException(
            "OOM detected during Training. Try reducing your batch_size or the"
            " model parameters."
            "/n" + "Original Error: " + oom_handler.oom_msg
        )
    self._is_fitted = True
    if self.verbose:
        logger.info("Training the model completed")
    if self.config.load_best:
        self.load_best_model()
    return self.trainer
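
A minimal sketch of the decomposed workflow that fit wraps: prepare the datamodule, prepare the model, then train (tabular_model, train_df and valid_df are assumed names):

datamodule = tabular_model.prepare_dataloader(train=train_df, validation=valid_df, seed=42)
model = tabular_model.prepare_model(datamodule)
trainer = tabular_model.train(model, datamodule, max_epochs=10, handle_oom=True)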

TabularDatamodule

Bases: LightningDataModule

Source code in src/pytorch_tabular/tabular_datamodule.py
class TabularDatamodule(pl.LightningDataModule):
    CONTINUOUS_TRANSFORMS = {
        "quantile_uniform": {
            "callable": QuantileTransformer,
            "params": {"output_distribution": "uniform", "random_state": None},
        },
        "quantile_normal": {
            "callable": QuantileTransformer,
            "params": {"output_distribution": "normal", "random_state": None},
        },
        "box-cox": {
            "callable": PowerTransformer,
            "params": {"method": "box-cox", "standardize": False},
        },
        "yeo-johnson": {
            "callable": PowerTransformer,
            "params": {"method": "yeo-johnson", "standardize": False},
        },
    }

    class CACHE_MODES(Enum):
        MEMORY = "memory"
        DISK = "disk"
        INFERENCE = "inference"

    def __init__(
        self,
        train: DataFrame,
        config: DictConfig,
        validation: DataFrame = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        seed: Optional[int] = 42,
        cache_data: str = "memory",
        copy_data: bool = True,
        verbose: bool = True,
    ):
        """The Pytorch Lightning Datamodule for Tabular Data.

        Args:
            train (DataFrame): The Training Dataframe

            config (DictConfig): Merged configuration object from ModelConfig, DataConfig,
                TrainerConfig, OptimizerConfig & ExperimentConfig

            validation (DataFrame, optional): Validation Dataframe.
                If left empty, we use the validation split from DataConfig to split a random sample as validation.
                Defaults to None.

            target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
                If provided, applies the transform to the target before modelling and inverse the transform during
                prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or
                a tuple of callables (transform_func, inverse_transform_func)
                Defaults to None.

            train_sampler (Optional[torch.utils.data.Sampler], optional):
                If provided, the sampler will be used to sample the train data. Defaults to None.

            seed (Optional[int], optional): Seed to use for reproducible dataloaders. Defaults to 42.

            cache_data (str): Decides how to cache the data in the dataloader. If set to
                "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

            copy_data (bool): If True, will copy the dataframes before preprocessing. Defaults to True.

            verbose (bool): Sets the verbosity of the datamodule logging

        """
        super().__init__()
        self.train = train.copy() if copy_data else train
        if validation is not None:
            self.validation = validation.copy() if copy_data else validation
        else:
            self.validation = None
        self._set_target_transform(target_transform)
        self.target = config.target or []
        self.batch_size = config.batch_size
        self.train_sampler = train_sampler
        self.config = config
        self.seed = seed
        self.verbose = verbose
        self._fitted = False
        self._setup_cache(cache_data)
        self._inferred_config = self._update_config(config)

    @property
    def categorical_encoder(self):
        """Returns the categorical encoder."""
        return getattr(self, "_categorical_encoder", None)

    @categorical_encoder.setter
    def categorical_encoder(self, value):
        self._categorical_encoder = value

    @property
    def continuous_transform(self):
        """Returns the continuous transform."""
        return getattr(self, "_continuous_transform", None)

    @continuous_transform.setter
    def continuous_transform(self, value):
        self._continuous_transform = value

    @property
    def scaler(self):
        """Returns the scaler."""
        return getattr(self, "_scaler", None)

    @scaler.setter
    def scaler(self, value):
        self._scaler = value

    @property
    def label_encoder(self):
        """Returns the label encoder."""
        return getattr(self, "_label_encoder", None)

    @label_encoder.setter
    def label_encoder(self, value):
        self._label_encoder = value

    @property
    def target_transforms(self):
        """Returns the target transforms."""
        if self.do_target_transform:
            return self._target_transforms
        else:
            return None

    @target_transforms.setter
    def target_transforms(self, value):
        self._target_transforms = value

    def _setup_cache(self, cache_data: Union[str, bool]) -> None:
        cache_data = cache_data.lower()
        if cache_data == self.CACHE_MODES.MEMORY.value:
            self.cache_mode = self.CACHE_MODES.MEMORY
        elif isinstance(cache_data, str):
            self.cache_mode = self.CACHE_MODES.DISK
            self.cache_dir = Path(cache_data)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        else:
            logger.warning(f"{cache_data} is not a valid path. Caching in memory")
            self.cache_mode = self.CACHE_MODES.MEMORY

    def _set_target_transform(self, target_transform: Union[TransformerMixin, Tuple]) -> None:
        if target_transform is not None:
            if isinstance(target_transform, Iterable):
                target_transform = FunctionTransformer(func=target_transform[0], inverse_func=target_transform[1])
            self.do_target_transform = True
        else:
            self.do_target_transform = False
        self.target_transform_template = target_transform

    def _update_config(self, config) -> InferredConfig:
        """Calculates and updates a few key information to the config object.

        Args:
            config (DictConfig): The config object

        Returns:
            InferredConfig: The updated config object

        """
        categorical_dim = len(config.categorical_cols)
        continuous_dim = len(config.continuous_cols)
        # self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
        # self._output_dim_reg = len(config.target) if config.target else None
        if config.task == "regression":
            # self._output_dim_reg = len(config.target) if config.target else None if self.train is not None:
            output_dim = len(config.target) if config.target else None
        elif config.task == "classification":
            # self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
            if self.train is not None:
                output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None
            else:
                output_dim = len(np.unique(self.train_dataset.y)) if config.target else None
        elif config.task == "ssl":
            output_dim = None
        else:
            raise ValueError(f"{config.task} is an unsupported task.")
        if self.train is not None:
            categorical_cardinality = [
                int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
            ]
        else:
            categorical_cardinality = [
                int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
            ]
        if getattr(config, "embedding_dims", None) is not None:
            embedding_dims = config.embedding_dims
        else:
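            # Heuristic default: embedding size is min(50, (cardinality + 1) // 2) for each categorical feature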
            embedding_dims = [(x, min(50, (x + 1) // 2)) for x in categorical_cardinality]
        return InferredConfig(
            categorical_dim=categorical_dim,
            continuous_dim=continuous_dim,
            output_dim=output_dim,
            categorical_cardinality=categorical_cardinality,
            embedding_dims=embedding_dims,
        )

    def update_config(self, config) -> InferredConfig:
        """Calculates and updates a few key information to the config object. Logic happens in _update_config. This is
        just a wrapper to make it accessible from outside and not break current apis.

        Args:
            config (DictConfig): The config object

        Returns:
            InferredConfig: The updated config object

        """
        if self.cache_mode is self.CACHE_MODES.INFERENCE:
            warnings.warn("Cannot update config in inference mode. Returning the cached config")
            return self._inferred_config
        else:
            return self._update_config(config)

    def _encode_date_columns(self, data: DataFrame) -> DataFrame:
        added_features = []
        for field_name, freq, format in self.config.date_columns:
            data = self.make_date(data, field_name, format)
            data, _new_feats = self.add_datepart(data, field_name, frequency=freq, prefix=None, drop=True)
            added_features += _new_feats
        return data, added_features

    def _encode_categorical_columns(self, data: DataFrame, stage: str) -> DataFrame:
        if stage != "fit":
            # Inference
            return self.categorical_encoder.transform(data)
        # Fit
        logger.debug("Encoding Categorical Columns using OrdinalEncoder")
        self.categorical_encoder = OrdinalEncoder(
            cols=self.config.categorical_cols,
            handle_unseen=("impute" if self.config.handle_unknown_categories else "error"),
            handle_missing="impute" if self.config.handle_missing_values else "error",
        )
        data = self.categorical_encoder.fit_transform(data)
        return data

    def _transform_continuous_columns(self, data: DataFrame, stage: str) -> DataFrame:
        if stage == "fit":
            transform = self.CONTINUOUS_TRANSFORMS[self.config.continuous_feature_transform]
            if "random_state" in transform["params"] and self.seed is not None:
                transform["params"]["random_state"] = self.seed
            self.continuous_transform = transform["callable"](**transform["params"])
            # can be accessed through property continuous_transform
            data.loc[:, self.config.continuous_cols] = self.continuous_transform.fit_transform(
                data.loc[:, self.config.continuous_cols]
            )
        else:
            data.loc[:, self.config.continuous_cols] = self.continuous_transform.transform(
                data.loc[:, self.config.continuous_cols]
            )
        return data

    def _normalize_continuous_columns(self, data: DataFrame, stage: str) -> DataFrame:
        if stage == "fit":
            self.scaler = StandardScaler()
            data.loc[:, self.config.continuous_cols] = self.scaler.fit_transform(
                data.loc[:, self.config.continuous_cols]
            )
        else:
            data.loc[:, self.config.continuous_cols] = self.scaler.transform(data.loc[:, self.config.continuous_cols])
        return data

    def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame:
        if self.config.task != "classification":
            return data
        if stage == "fit" or self.label_encoder is None:
            self.label_encoder = LabelEncoder()
            data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]])
        else:
            if self.config.target[0] in data.columns:
                data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]])
        return data

    def _target_transform(self, data: DataFrame, stage: str) -> DataFrame:
        if self.config.task != "regression":
            return data
        # target transform only for regression
        if not all(col in data.columns for col in self.config.target):
            return data
        if self.do_target_transform:
            if stage == "fit" or self.target_transforms is None:
                target_transforms = []
                for col in self.config.target:
                    _target_transform = copy.deepcopy(self.target_transform_template)
                    data[col] = _target_transform.fit_transform(data[col].values.reshape(-1, 1))
                    target_transforms.append(_target_transform)
                self.target_transforms = target_transforms
            else:
                for col, _target_transform in zip(self.config.target, self.target_transforms):
                    data[col] = _target_transform.transform(data[col].values.reshape(-1, 1))
        return data

    def preprocess_data(self, data: DataFrame, stage: str = "inference") -> Tuple[DataFrame, list]:
        """The preprocessing, like Categorical Encoding, Normalization, etc. which any dataframe should undergo before
        feeding into the dataloader.

        Args:
            data (DataFrame): A dataframe with the features and target
            stage (str, optional): Internal parameter. Used to distinguish between fit and inference.
                Defaults to "inference".

        Returns:
            Returns the processed dataframe and the added features(list) as a tuple

        """
        added_features = None
        if self.config.encode_date_columns:
            data, added_features = self._encode_date_columns(data)
        # The only features that are added are the date features extracted
        # from the date which are categorical in nature
        if (added_features is not None) and (stage == "fit"):
            logger.debug(f"Added {added_features} features after encoding the date_columns")
            self.config.categorical_cols += added_features
            # Update the categorical dimension in config
            self.config.categorical_dim = (
                len(self.config.categorical_cols) if self.config.categorical_cols is not None else 0
            )
            self._inferred_config = self._update_config(self.config)
        # Encoding Categorical Columns
        if len(self.config.categorical_cols) > 0:
            data = self._encode_categorical_columns(data, stage)

        # Transforming Continuous Columns
        if (self.config.continuous_feature_transform is not None) and (len(self.config.continuous_cols) > 0):
            data = self._transform_continuous_columns(data, stage)
        # Normalizing Continuous Columns
        if (self.config.normalize_continuous_features) and (len(self.config.continuous_cols) > 0):
            data = self._normalize_continuous_columns(data, stage)
        # Converting target labels to a 0 indexed label
        data = self._label_encode_target(data, stage)
        # Target Transforms
        data = self._target_transform(data, stage)
        return data, added_features

    def _cache_dataset(self):
        train_dataset = TabularDataset(
            task=self.config.task,
            data=self.train,
            categorical_cols=self.config.categorical_cols,
            continuous_cols=self.config.continuous_cols,
            target=self.target,
        )
        self.train = None
        validation_dataset = TabularDataset(
            task=self.config.task,
            data=self.validation,
            categorical_cols=self.config.categorical_cols,
            continuous_cols=self.config.continuous_cols,
            target=self.target,
        )
        self.validation = None

        if self.cache_mode is self.CACHE_MODES.DISK:
            torch.save(train_dataset, self.cache_dir / "train_dataset")
            torch.save(validation_dataset, self.cache_dir / "validation_dataset")
        elif self.cache_mode is self.CACHE_MODES.MEMORY:
            self.train_dataset = train_dataset
            self.validation_dataset = validation_dataset
        elif self.cache_mode is self.CACHE_MODES.INFERENCE:
            self.train_dataset = None
            self.validation_dataset = None
        else:
            raise ValueError(f"{self.cache_mode} is not a valid cache mode")

    def split_train_val(self, train):
        logger.debug(
            "No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
        )
        val_idx = train.sample(
            int(self.config.validation_split * len(train)),
            random_state=self.seed,
        ).index
        validation = train[train.index.isin(val_idx)]
        train = train[~train.index.isin(val_idx)]
        return train, validation

    def setup(self, stage: Optional[str] = None) -> None:
        """Data Operations you want to perform on all GPUs, like train-test split, transformations, etc. This is called
        before accessing the dataloaders.

        Args:
            stage (Optional[str], optional):
                Internal parameter to distinguish between fit and inference. Defaults to None.

        """
        if not (stage is None or stage == "fit" or stage == "ssl_finetune"):
            return
        if self.verbose:
            logger.info(f"Setting up the datamodule for {self.config.task} task")
        is_ssl = stage == "ssl_finetune"
        if self.validation is None:
            self.train, self.validation = self.split_train_val(self.train)
        else:
            self.validation = self.validation.copy()
        # Preprocessing Train, Validation
        self.train, _ = self.preprocess_data(self.train, stage="fit" if not is_ssl else "inference")
        self.validation, _ = self.preprocess_data(self.validation, stage="inference")
        self._fitted = True
        self._cache_dataset()

    def inference_only_copy(self):
        """Creates a copy of the datamodule with the train and validation datasets removed. This is useful for
        inference only scenarios where we don't want to save the train and validation datasets.

        Returns:
            TabularDatamodule: A copy of the datamodule with the train and validation datasets removed.

        """
        if not self._fitted:
            raise RuntimeError("Can create an inference only copy only after model is fitted")
        dm_inference = copy.copy(self)
        dm_inference.train_dataset = None
        dm_inference.validation_dataset = None
        dm_inference.cache_mode = self.CACHE_MODES.INFERENCE
        return dm_inference

    # adapted from gluonts
    @classmethod
    def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
        """Returns a list of time features that will be appropriate for the given frequency string.

        Args:
            freq_str (str): Frequency string of the form `[multiple][granularity]` such as "12H", "5min", "1D" etc.

        Returns:
            List of added features

        """

        features_by_offsets = {
            offsets.YearBegin: [],
            offsets.YearEnd: [],
            offsets.MonthBegin: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
            ],
            offsets.MonthEnd: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
            ],
            offsets.Week: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
                "Is_month_start",
                "Week",
            ],
            offsets.Day: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
                "Is_month_start",
                "WeekDay",
                "Dayofweek",
                "Dayofyear",
            ],
            offsets.BusinessDay: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
                "Is_month_start",
                "WeekDay",
                "Dayofweek",
                "Dayofyear",
            ],
            offsets.Hour: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
                "Is_month_start",
                "WeekDay",
                "Dayofweek",
                "Dayofyear",
                "Hour",
            ],
            offsets.Minute: [
                "Month",
                "Quarter",
                "Is_quarter_end",
                "Is_quarter_start",
                "Is_year_end",
                "Is_year_start",
                "Is_month_start",
                "WeekDay",
                "Dayofweek",
                "Dayofyear",
                "Hour",
                "Minute",
            ],
        }

        offset = to_offset(freq_str)

        for offset_type, feature in features_by_offsets.items():
            if isinstance(offset, offset_type):
                return feature

        supported_freq_msg = f"""
        Unsupported frequency {freq_str}

        The following frequencies are supported:

            Y, YS   - yearly
                alias: A
            M, MS   - monthly
            W   - weekly
            D   - daily
            B   - business days
            H   - hourly
            T   - minutely
                alias: min
        """
        raise RuntimeError(supported_freq_msg)

    # adapted from fastai
    @classmethod
    def make_date(cls, df: DataFrame, date_field: str, date_format: str = "ISO8601") -> DataFrame:
        """Make sure `df[date_field]` is of the right date type.

        Args:
            df (DataFrame): Dataframe

            date_field (str): Date field name

        Returns:
            Dataframe with date field converted to datetime

        """
        field_dtype = df[date_field].dtype
        if isinstance(field_dtype, DatetimeTZDtype):
            field_dtype = np.datetime64
        if not np.issubdtype(field_dtype, np.datetime64):
            df[date_field] = to_datetime(df[date_field], format=date_format)
        return df

    # adapted from fastai
    @classmethod
    def add_datepart(
        cls,
        df: DataFrame,
        field_name: str,
        frequency: str,
        prefix: str = None,
        drop: bool = True,
    ) -> Tuple[DataFrame, List[str]]:
        """Helper function that adds columns relevant to a date in the column `field_name` of `df`.

        Args:
            df (DataFrame): Dataframe

            field_name (str): Date field name

            frequency (str): Frequency string of the form `[multiple][granularity]` such as "12H", "5min", "1D" etc.

            prefix (str, optional): Prefix to add to the new columns. Defaults to None.

            drop (bool, optional): Drop the original column. Defaults to True.

        Returns:
            Dataframe with added columns and list of added columns

        """
        field = df[field_name]
        prefix = (re.sub("[Dd]ate$", "", field_name) if prefix is None else prefix) + "_"
        attr = cls.time_features_from_frequency_str(frequency)
        added_features = []
        for n in attr:
            if n == "Week":
                continue
            df[prefix + n] = getattr(field.dt, n.lower())
            added_features.append(prefix + n)
        # Pandas removed `dt.week` in v1.1.10
        if "Week" in attr:
            week = field.dt.isocalendar().week if hasattr(field.dt, "isocalendar") else field.dt.week
            df.insert(3, prefix + "Week", week)
            added_features.append(prefix + "Week")
        # TODO Not adding Elapsed by default. Need to route it through config
        # mask = ~field.isna()
        # df[prefix + "Elapsed"] = np.where(
        #     mask, field.values.astype(np.int64) // 10 ** 9, None
        # )
        # added_features.append(prefix + "Elapsed")
        if drop:
            df.drop(field_name, axis=1, inplace=True)

        # Removing features with zero variance
        # for col in added_features:
        #     if len(df[col].unique()) == 1:
        #         df.drop(columns=col, inplace=True)
        #         added_features.remove(col)
        return df, added_features

    def _load_dataset_from_cache(self, tag: str = "train"):
        if self.cache_mode is self.CACHE_MODES.MEMORY:
            try:
                dataset = getattr(self, f"_{tag}_dataset")
            except AttributeError:
                raise AttributeError(
                    f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
                )
        elif self.cache_mode is self.CACHE_MODES.DISK:
            try:
                dataset = torch.load(self.cache_dir / f"{tag}_dataset")
            except FileNotFoundError:
                raise FileNotFoundError(
                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
                )
        elif self.cache_mode is self.CACHE_MODES.INFERENCE:
            raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
        else:
            raise ValueError(f"{self.cache_mode} is not a valid cache mode")
        return dataset

    @property
    def train_dataset(self) -> TabularDataset:
        """Returns the train dataset.

        Returns:
            TabularDataset: The train dataset

        """
        return self._load_dataset_from_cache("train")

    @train_dataset.setter
    def train_dataset(self, value):
        self._train_dataset = value

    @property
    def validation_dataset(self) -> TabularDataset:
        """Returns the validation dataset.

        Returns:
            TabularDataset: The validation dataset

        """
        return self._load_dataset_from_cache("validation")

    @validation_dataset.setter
    def validation_dataset(self, value):
        self._validation_dataset = value

    def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
        """Function that loads the train set.

        Args:
            batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.

        Returns:
            DataLoader: Train dataloader

        """
        return DataLoader(
            self.train_dataset,
            batch_size or self.batch_size,
            shuffle=True if self.train_sampler is None else False,
            num_workers=self.config.num_workers,
            sampler=self.train_sampler,
            pin_memory=self.config.pin_memory,
        )

    def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
        """Function that loads the validation set.

        Args:
            batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.

        Returns:
            DataLoader: Validation dataloader

        """
        return DataLoader(
            self.validation_dataset,
            batch_size or self.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory,
        )

    def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
        """Prepare data for inference."""
        # TODO Is the target encoding necessary?
        if len(set(self.target) - set(df.columns)) > 0:
            if self.config.task == "classification":
                df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1)
            else:
                df.loc[:, self.target] = np.zeros((len(df), len(self.target)))
        df, _ = self.preprocess_data(df, stage="inference")
        return df

    def prepare_inference_dataloader(
        self, df: DataFrame, batch_size: Optional[int] = None, copy_df: bool = True
    ) -> DataLoader:
        """Function that prepares and loads the new data.

        Args:
            df (DataFrame): Dataframe with the features and target
            batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.
            copy_df (bool, optional): Whether to copy the dataframe before processing or not. Defaults to True.
        Returns:
            DataLoader: The dataloader for the passed in dataframe

        """
        if copy_df:
            df = df.copy()
        df = self._prepare_inference_data(df)
        dataset = TabularDataset(
            task=self.config.task,
            data=df,
            categorical_cols=self.config.categorical_cols,
            continuous_cols=self.config.continuous_cols,
            target=(self.target if all(col in df.columns for col in self.target) else None),
        )
        return DataLoader(
            dataset,
            batch_size or self.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

    def save_dataloader(self, path: Union[str, Path]) -> None:
        """Saves the dataloader to a path.

        Args:
            path (Union[str, Path]): Path to save the dataloader

        """
        if isinstance(path, str):
            path = Path(path)
        joblib.dump(self, path)

    @classmethod
    def load_datamodule(cls, path: Union[str, Path]):
        """Loads a datamodule from a path.

        Args:
            path (Union[str, Path]): Path to the datamodule

        Returns:
            TabularDatamodule (TabularDatamodule): The datamodule loaded from the path

        """
        if isinstance(path, str):
            path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"{path} does not exist.")
        datamodule = joblib.load(path)
        return datamodule

    def copy(
        self,
        train: DataFrame,
        validation: DataFrame = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        seed: Optional[int] = None,
        cache_data: str = None,
        copy_data: bool = None,
        verbose: bool = None,
        call_setup: bool = True,
        config_override: Optional[Dict] = {},
    ):
        if config_override is not None:
            for k, v in config_override.items():
                setattr(self.config, k, v)
        dm = TabularDatamodule(
            train=train,
            config=self.config,
            validation=validation,
            target_transform=target_transform or self.target_transforms,
            train_sampler=train_sampler or self.train_sampler,
            seed=seed or self.seed,
            cache_data=cache_data or self.cache_mode.value,
            copy_data=copy_data or True,
            verbose=verbose or self.verbose,
        )
        dm.categorical_encoder = self.categorical_encoder
        dm.continuous_transform = self.continuous_transform
        dm.scaler = self.scaler
        dm.label_encoder = self.label_encoder
        dm.target_transforms = self.target_transforms
        if call_setup:
            dm.setup(stage="ssl_finetune")
        return dm

categorical_encoder property writable

Returns the categorical encoder.

continuous_transform property writable

Returns the continuous transform.

label_encoder property writable

Returns the label encoder.

scaler property writable

Returns the scaler.

target_transforms property writable

Returns the target transforms.

train_dataset: TabularDataset property writable

Returns the train dataset.

Returns:

Name Type Description
TabularDataset TabularDataset

The train dataset

validation_dataset: TabularDataset property writable

Returns the validation dataset.

Returns:

Name Type Description
TabularDataset TabularDataset

The validation dataset

__init__(train, config, validation=None, target_transform=None, train_sampler=None, seed=42, cache_data='memory', copy_data=True, verbose=True)

The Pytorch Lightning Datamodule for Tabular Data.

Parameters:

Name Type Description Default
train DataFrame

The Training Dataframe

required
config DictConfig

Merged configuration object from ModelConfig, DataConfig, TrainerConfig, OptimizerConfig & ExperimentConfig

required
validation DataFrame

Validation Dataframe. If left empty, we use the validation split from DataConfig to split a random sample as validation. Defaults to None.

None
target_transform Optional[Union[TransformerMixin, Tuple(Callable)]]

If provided, applies the transform to the target before modelling and inverse the transform during prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func) Defaults to None.

None
train_sampler Optional[Sampler]

If provided, the sampler will be used to sample the train data. Defaults to None.

None
seed Optional[int]

Seed to use for reproducible dataloaders. Defaults to 42.

42
cache_data str

Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

'memory'
copy_data bool

If True, will copy the dataframes before preprocessing. Defaults to True.

True
verbose bool

Sets the verbosity of the datamodule logging

True
Source code in src/pytorch_tabular/tabular_datamodule.py
def __init__(
    self,
    train: DataFrame,
    config: DictConfig,
    validation: DataFrame = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    seed: Optional[int] = 42,
    cache_data: str = "memory",
    copy_data: bool = True,
    verbose: bool = True,
):
    """The Pytorch Lightning Datamodule for Tabular Data.

    Args:
        train (DataFrame): The Training Dataframe

        config (DictConfig): Merged configuration object from ModelConfig, DataConfig,
            TrainerConfig, OptimizerConfig & ExperimentConfig

        validation (DataFrame, optional): Validation Dataframe.
            If left empty, we use the validation split from DataConfig to split a random sample as validation.
            Defaults to None.

        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
            If provided, applies the transform to the target before modelling and inverse the transform during
            prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or
            a tuple of callables (transform_func, inverse_transform_func)
            Defaults to None.

        train_sampler (Optional[torch.utils.data.Sampler], optional):
            If provided, the sampler will be used to sample the train data. Defaults to None.

        seed (Optional[int], optional): Seed to use for reproducible dataloaders. Defaults to 42.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

        copy_data (bool): If True, will copy the dataframes before preprocessing. Defaults to True.

        verbose (bool): Sets the verbosity of the datamodule logging

    """
    super().__init__()
    self.train = train.copy() if copy_data else train
    if validation is not None:
        self.validation = validation.copy() if copy_data else validation
    else:
        self.validation = None
    self._set_target_transform(target_transform)
    self.target = config.target or []
    self.batch_size = config.batch_size
    self.train_sampler = train_sampler
    self.config = config
    self.seed = seed
    self.verbose = verbose
    self._fitted = False
    self._setup_cache(cache_data)
    self._inferred_config = self._update_config(config)
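
In practice you rarely construct `TabularDatamodule` yourself; `TabularModel.prepare_dataloader` builds and sets it up from the merged configs (that call is also used by the tuner further down this page). A minimal, illustrative sketch follows; the column names, config values, and the `train_df`/`val_df` dataframes are assumptions, not part of the library source.

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

# Illustrative configs for a small binary classification problem
data_config = DataConfig(
    target=["target"],
    continuous_cols=["age", "income"],
    categorical_cols=["city"],
)
model_config = CategoryEmbeddingModelConfig(task="classification")
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)
# prepare_dataloader constructs the TabularDatamodule from the dataframes and configs
datamodule = tabular_model.prepare_dataloader(train=train_df, validation=val_df, seed=42)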

add_datepart(df, field_name, frequency, prefix=None, drop=True) classmethod

Helper function that adds columns relevant to a date in the column field_name of df.

Parameters:

Name Type Description Default
df DataFrame

Dataframe

required
field_name str

Date field name

required
frequency str

Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

required
prefix str

Prefix to add to the new columns. Defaults to None.

None
drop bool

Drop the original column. Defaults to True.

True

Returns:

Type Description
Tuple[DataFrame, List[str]]

Dataframe with added columns and list of added columns

Source code in src/pytorch_tabular/tabular_datamodule.py
@classmethod
def add_datepart(
    cls,
    df: DataFrame,
    field_name: str,
    frequency: str,
    prefix: str = None,
    drop: bool = True,
) -> Tuple[DataFrame, List[str]]:
    """Helper function that adds columns relevant to a date in the column `field_name` of `df`.

    Args:
        df (DataFrame): Dataframe

        field_name (str): Date field name

        frequency (str): Frequency string of the form `[multiple][granularity]` such as "12H", "5min", "1D" etc.

        prefix (str, optional): Prefix to add to the new columns. Defaults to None.

        drop (bool, optional): Drop the original column. Defaults to True.

    Returns:
        Dataframe with added columns and list of added columns

    """
    field = df[field_name]
    prefix = (re.sub("[Dd]ate$", "", field_name) if prefix is None else prefix) + "_"
    attr = cls.time_features_from_frequency_str(frequency)
    added_features = []
    for n in attr:
        if n == "Week":
            continue
        df[prefix + n] = getattr(field.dt, n.lower())
        added_features.append(prefix + n)
    # Pandas removed `dt.week` in v1.1.10
    if "Week" in attr:
        week = field.dt.isocalendar().week if hasattr(field.dt, "isocalendar") else field.dt.week
        df.insert(3, prefix + "Week", week)
        added_features.append(prefix + "Week")
    # TODO Not adding Elapsed by default. Need to route it through config
    # mask = ~field.isna()
    # df[prefix + "Elapsed"] = np.where(
    #     mask, field.values.astype(np.int64) // 10 ** 9, None
    # )
    # added_features.append(prefix + "Elapsed")
    if drop:
        df.drop(field_name, axis=1, inplace=True)

    # Removing features with zero variance
    # for col in added_features:
    #     if len(df[col].unique()) == 1:
    #         df.drop(columns=col, inplace=True)
    #         added_features.remove(col)
    return df, added_features
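
As an illustration of the classmethod above, a small self-contained sketch (the column names and values are made up):

import pandas as pd
from pytorch_tabular.tabular_datamodule import TabularDatamodule

df = pd.DataFrame({
    "purchase_date": pd.date_range("2021-01-01", periods=5, freq="D"),
    "amount": [10, 20, 30, 40, 50],
})
# Adds Month, Quarter, Is_quarter_end, ... Dayofyear columns prefixed with "purchase_"
# and drops the original date column (drop defaults to True)
df, added = TabularDatamodule.add_datepart(df, "purchase_date", frequency="1D", prefix="purchase")
print(added)  # ['purchase_Month', 'purchase_Quarter', ..., 'purchase_Dayofyear']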

inference_only_copy()

Creates a copy of the datamodule with the train and validation datasets removed. This is useful for inference only scenarios where we don't want to save the train and validation datasets.

Returns:

Name Type Description
TabularDatamodule

A copy of the datamodule with the train and validation datasets removed.

Source code in src/pytorch_tabular/tabular_datamodule.py
def inference_only_copy(self):
    """Creates a copy of the datamodule with the train and validation datasets removed. This is useful for
    inference only scenarios where we don't want to save the train and validation datasets.

    Returns:
        TabularDatamodule: A copy of the datamodule with the train and validation datasets removed.

    """
    if not self._fitted:
        raise RuntimeError("Can create an inference only copy only after model is fitted")
    dm_inference = copy.copy(self)
    dm_inference.train_dataset = None
    dm_inference.validation_dataset = None
    dm_inference.cache_mode = self.CACHE_MODES.INFERENCE
    return dm_inference
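
A short sketch of how this might be used to ship a lighter artifact for serving; the file name is illustrative and `datamodule` is assumed to be an already fitted instance:

# After setup, drop the cached train/validation datasets and persist only what inference needs
slim_dm = datamodule.inference_only_copy()
slim_dm.save_dataloader("datamodule_inference.joblib")  # much smaller than saving the full datamodule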

load_datamodule(path) classmethod

Loads a datamodule from a path.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the datamodule

required

Returns:

Name Type Description
TabularDatamodule TabularDatamodule

The datamodule loaded from the path

Source code in src/pytorch_tabular/tabular_datamodule.py
@classmethod
def load_datamodule(cls, path: Union[str, Path]):
    """Loads a datamodule from a path.

    Args:
        path (Union[str, Path]): Path to the datamodule

    Returns:
        TabularDatamodule (TabularDatamodule): The datamodule loaded from the path

    """
    if isinstance(path, str):
        path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"{path} does not exist.")
    datamodule = joblib.load(path)
    return datamodule
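
A minimal round trip with `save_dataloader` (documented further below), assuming `datamodule` is an already set-up instance and `new_df` holds fresh data with the training feature columns:

datamodule.save_dataloader("artifacts/datamodule.joblib")

restored = TabularDatamodule.load_datamodule("artifacts/datamodule.joblib")
# The fitted encoders and scalers come back with it, so new data can be prepared directly
loader = restored.prepare_inference_dataloader(new_df)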

make_date(df, date_field, date_format='ISO8601') classmethod

Make sure df[date_field] is of the right date type.

Parameters:

Name Type Description Default
df DataFrame

Dataframe

required
date_field str

Date field name

required

Returns:

Type Description
DataFrame

Dataframe with date field converted to datetime

Source code in src/pytorch_tabular/tabular_datamodule.py
@classmethod
def make_date(cls, df: DataFrame, date_field: str, date_format: str = "ISO8601") -> DataFrame:
    """Make sure `df[date_field]` is of the right date type.

    Args:
        df (DataFrame): Dataframe

        date_field (str): Date field name

    Returns:
        Dataframe with date field converted to datetime

    """
    field_dtype = df[date_field].dtype
    if isinstance(field_dtype, DatetimeTZDtype):
        field_dtype = np.datetime64
    if not np.issubdtype(field_dtype, np.datetime64):
        df[date_field] = to_datetime(df[date_field], format=date_format)
    return df
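
For example, a self-contained sketch (note that the default `ISO8601` format string is only accepted by reasonably recent pandas versions):

import pandas as pd
from pytorch_tabular.tabular_datamodule import TabularDatamodule

df = pd.DataFrame({"signup_date": ["2023-01-05", "2023-02-11"]})
df = TabularDatamodule.make_date(df, "signup_date")  # parses the ISO strings in place
print(df["signup_date"].dtype)  # datetime64[ns]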

prepare_inference_dataloader(df, batch_size=None, copy_df=True)

Function that prepares and loads the new data.

Parameters:

Name Type Description Default
df DataFrame

Dataframe with the features and target

required
batch_size Optional[int]

Batch size. Defaults to self.batch_size.

None
copy_df bool

Whether to copy the dataframe before processing or not. Defaults to True.

True

Returns: DataLoader: The dataloader for the passed in dataframe

Source code in src/pytorch_tabular/tabular_datamodule.py
def prepare_inference_dataloader(
    self, df: DataFrame, batch_size: Optional[int] = None, copy_df: bool = True
) -> DataLoader:
    """Function that prepares and loads the new data.

    Args:
        df (DataFrame): Dataframe with the features and target
        batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.
        copy_df (bool, optional): Whether to copy the dataframe before processing or not. Defaults to True.
    Returns:
        DataLoader: The dataloader for the passed in dataframe

    """
    if copy_df:
        df = df.copy()
    df = self._prepare_inference_data(df)
    dataset = TabularDataset(
        task=self.config.task,
        data=df,
        categorical_cols=self.config.categorical_cols,
        continuous_cols=self.config.continuous_cols,
        target=(self.target if all(col in df.columns for col in self.target) else None),
    )
    return DataLoader(
        dataset,
        batch_size or self.batch_size,
        shuffle=False,
        num_workers=self.config.num_workers,
    )
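
A short usage sketch, assuming `datamodule` has been set up and `new_df` has the same feature columns as the training data (the target column may be absent; it is filled with placeholder values internally):

loader = datamodule.prepare_inference_dataloader(new_df, batch_size=256)
batch = next(iter(loader))  # one batch, as produced by TabularDataset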

preprocess_data(data, stage='inference')

The preprocessing, like Categorical Encoding, Normalization, etc., which any dataframe should undergo before feeding into the dataloader.

Parameters:

Name Type Description Default
data DataFrame

A dataframe with the features and target

required
stage str

Internal parameter. Used to distinguish between fit and inference. Defaults to "inference".

'inference'

Returns:

Type Description
Tuple[DataFrame, list]

Returns the processed dataframe and the added features(list) as a tuple

Source code in src/pytorch_tabular/tabular_datamodule.py
def preprocess_data(self, data: DataFrame, stage: str = "inference") -> Tuple[DataFrame, list]:
    """The preprocessing, like Categorical Encoding, Normalization, etc. which any dataframe should undergo before
    feeding into the dataloder.

    Args:
        data (DataFrame): A dataframe with the features and target
        stage (str, optional): Internal parameter. Used to distinguish between fit and inference.
            Defaults to "inference".

    Returns:
        Returns the processed dataframe and the added features(list) as a tuple

    """
    added_features = None
    if self.config.encode_date_columns:
        data, added_features = self._encode_date_columns(data)
    # The only features that are added are the date features extracted
    # from the date which are categorical in nature
    if (added_features is not None) and (stage == "fit"):
        logger.debug(f"Added {added_features} features after encoding the date_columns")
        self.config.categorical_cols += added_features
        # Update the categorical dimension in config
        self.config.categorical_dim = (
            len(self.config.categorical_cols) if self.config.categorical_cols is not None else 0
        )
        self._inferred_config = self._update_config(self.config)
    # Encoding Categorical Columns
    if len(self.config.categorical_cols) > 0:
        data = self._encode_categorical_columns(data, stage)

    # Transforming Continuous Columns
    if (self.config.continuous_feature_transform is not None) and (len(self.config.continuous_cols) > 0):
        data = self._transform_continuous_columns(data, stage)
    # Normalizing Continuous Columns
    if (self.config.normalize_continuous_features) and (len(self.config.continuous_cols) > 0):
        data = self._normalize_continuous_columns(data, stage)
    # Converting target labels to a 0 indexed label
    data = self._label_encode_target(data, stage)
    # Target Transforms
    data = self._target_transform(data, stage)
    return data, added_features
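
In normal use `setup()` drives these calls, but conceptually the two stages behave as in the sketch below. This is illustrative only: calling it by hand on an already fitted datamodule with stage="fit" will refit the encoders and scalers.

# stage="fit" fits the categorical encoder, scaler, label encoder, etc. on the data
train_processed, added_cols = datamodule.preprocess_data(train_df.copy(), stage="fit")
# stage="inference" reuses the fitted transformers without refitting them
test_processed, _ = datamodule.preprocess_data(test_df.copy(), stage="inference")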

save_dataloader(path)

Saves the dataloader to a path.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to save the dataloader

required
Source code in src/pytorch_tabular/tabular_datamodule.py
def save_dataloader(self, path: Union[str, Path]) -> None:
    """Saves the dataloader to a path.

    Args:
        path (Union[str, Path]): Path to save the dataloader

    """
    if isinstance(path, str):
        path = Path(path)
    joblib.dump(self, path)

setup(stage=None)

Data Operations you want to perform on all GPUs, like train-test split, transformations, etc. This is called before accessing the dataloaders.

Parameters:

Name Type Description Default
stage Optional[str]

Internal parameter to distinguish between fit and inference. Defaults to None.

None
Source code in src/pytorch_tabular/tabular_datamodule.py
def setup(self, stage: Optional[str] = None) -> None:
    """Data Operations you want to perform on all GPUs, like train-test split, transformations, etc. This is called
    before accessing the dataloaders.

    Args:
        stage (Optional[str], optional):
            Internal parameter to distinguish between fit and inference. Defaults to None.

    """
    if not (stage is None or stage == "fit" or stage == "ssl_finetune"):
        return
    if self.verbose:
        logger.info(f"Setting up the datamodule for {self.config.task} task")
    is_ssl = stage == "ssl_finetune"
    if self.validation is None:
        self.train, self.validation = self.split_train_val(self.train)
    else:
        self.validation = self.validation.copy()
    # Preprocessing Train, Validation
    self.train, _ = self.preprocess_data(self.train, stage="fit" if not is_ssl else "inference")
    self.validation, _ = self.preprocess_data(self.validation, stage="inference")
    self._fitted = True
    self._cache_dataset()

time_features_from_frequency_str(freq_str) classmethod

Returns a list of time features that will be appropriate for the given frequency string.

Parameters:

Name Type Description Default
freq_str str

Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

required

Returns:

Type Description
List[str]

List of added features

Source code in src/pytorch_tabular/tabular_datamodule.py
@classmethod
def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
    """Returns a list of time features that will be appropriate for the given frequency string.

    Args:
        freq_str (str): Frequency string of the form `[multiple][granularity]` such as "12H", "5min", "1D" etc.

    Returns:
        List of added features

    """

    features_by_offsets = {
        offsets.YearBegin: [],
        offsets.YearEnd: [],
        offsets.MonthBegin: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
        ],
        offsets.MonthEnd: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
        ],
        offsets.Week: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
            "Is_month_start",
            "Week",
        ],
        offsets.Day: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
            "Is_month_start",
            "WeekDay",
            "Dayofweek",
            "Dayofyear",
        ],
        offsets.BusinessDay: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
            "Is_month_start",
            "WeekDay",
            "Dayofweek",
            "Dayofyear",
        ],
        offsets.Hour: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
            "Is_month_start",
            "WeekDay",
            "Dayofweek",
            "Dayofyear",
            "Hour",
        ],
        offsets.Minute: [
            "Month",
            "Quarter",
            "Is_quarter_end",
            "Is_quarter_start",
            "Is_year_end",
            "Is_year_start",
            "Is_month_start",
            "WeekDay",
            "Dayofweek",
            "Dayofyear",
            "Hour",
            "Minute",
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return feature

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}

    The following frequencies are supported:

        Y, YS   - yearly
            alias: A
        M, MS   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
    """
    raise RuntimeError(supported_freq_msg)
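
For example:

from pytorch_tabular.tabular_datamodule import TabularDatamodule

TabularDatamodule.time_features_from_frequency_str("1D")
# ['Month', 'Quarter', 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end',
#  'Is_year_start', 'Is_month_start', 'WeekDay', 'Dayofweek', 'Dayofyear']
TabularDatamodule.time_features_from_frequency_str("5min")  # additionally includes 'Hour' and 'Minute'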

train_dataloader(batch_size=None)

Function that loads the train set.

Parameters:

Name Type Description Default
batch_size Optional[int]

Batch size. Defaults to self.batch_size.

None

Returns:

Name Type Description
DataLoader DataLoader

Train dataloader

Source code in src/pytorch_tabular/tabular_datamodule.py
def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
    """Function that loads the train set.

    Args:
        batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.

    Returns:
        DataLoader: Train dataloader

    """
    return DataLoader(
        self.train_dataset,
        batch_size or self.batch_size,
        shuffle=True if self.train_sampler is None else False,
        num_workers=self.config.num_workers,
        sampler=self.train_sampler,
        pin_memory=self.config.pin_memory,
    )

update_config(config)

Calculates a few key pieces of information and updates the config object with them. The logic lives in _update_config; this is just a wrapper to make it accessible from outside without breaking current APIs.

Parameters:

Name Type Description Default
config DictConfig

The config object

required

Returns:

Name Type Description
InferredConfig InferredConfig

The updated config object

Source code in src/pytorch_tabular/tabular_datamodule.py
def update_config(self, config) -> InferredConfig:
    """Calculates and updates a few key information to the config object. Logic happens in _update_config. This is
    just a wrapper to make it accessible from outside and not break current apis.

    Args:
        config (DictConfig): The config object

    Returns:
        InferredConfig: The updated config object

    """
    if self.cache_mode is self.CACHE_MODES.INFERENCE:
        warnings.warn("Cannot update config in inference mode. Returning the cached config")
        return self._inferred_config
    else:
        return self._update_config(config)

val_dataloader(batch_size=None)

Function that loads the validation set.

Parameters:

Name Type Description Default
batch_size Optional[int]

Batch size. Defaults to self.batch_size.

None

Returns:

Name Type Description
DataLoader DataLoader

Validation dataloader

Source code in src/pytorch_tabular/tabular_datamodule.py
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
    """Function that loads the validation set.

    Args:
        batch_size (Optional[int], optional): Batch size. Defaults to `self.batch_size`.

    Returns:
        DataLoader: Validation dataloader

    """
    return DataLoader(
        self.validation_dataset,
        batch_size or self.batch_size,
        shuffle=False,
        num_workers=self.config.num_workers,
        pin_memory=self.config.pin_memory,
    )
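
Both train_dataloader and val_dataloader above fall back to the batch size from the config; a per-call override looks like this (assuming a set-up `datamodule`):

train_loader = datamodule.train_dataloader(batch_size=2048)  # override for this loader only
val_loader = datamodule.val_dataloader()  # uses self.batch_size from the config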

Tabular Model Tuner.

This class is used to tune the hyperparameters of a TabularModel, given the search space, strategy and metric to optimize.

Source code in src/pytorch_tabular/tabular_model_tuner.py
class TabularModelTuner:
    """Tabular Model Tuner.

    This class is used to tune the hyperparameters of a TabularModel, given the search space,  strategy and metric to
    optimize.

    """

    ALLOWABLE_STRATEGIES = ["grid_search", "random_search"]
    OUTPUT = namedtuple("OUTPUT", ["trials_df", "best_params", "best_score", "best_model"])

    def __init__(
        self,
        data_config: Optional[Union[DataConfig, str]] = None,
        model_config: Optional[Union[ModelConfig, str]] = None,
        optimizer_config: Optional[Union[OptimizerConfig, str]] = None,
        trainer_config: Optional[Union[TrainerConfig, str]] = None,
        model_callable: Optional[Callable] = None,
        model_state_dict_path: Optional[Union[str, Path]] = None,
        suppress_lightning_logger: bool = True,
        **kwargs,
    ):
        """Tabular Model Tuner helps you tune the hyperparameters of a TabularModel.

        Args:
            data_config (Optional[Union[DataConfig, str]], optional): The DataConfig for the TabularModel.
                If str is passed, will initialize the DataConfig using the yaml file in that path.
                Defaults to None.

            model_config (Optional[Union[ModelConfig, str]], optional): The ModelConfig for the TabularModel.
                If str is passed, will initialize the ModelConfig using the yaml file in that path.
                Defaults to None.

            optimizer_config (Optional[Union[OptimizerConfig, str]], optional): The OptimizerConfig for the
                TabularModel. If str is passed, will initialize the OptimizerConfig using the yaml file in
                that path. Defaults to None.

            trainer_config (Optional[Union[TrainerConfig, str]], optional): The TrainerConfig for the TabularModel.
                If str is passed, will initialize the TrainerConfig using the yaml file in that path.
                Defaults to None.

            model_callable (Optional[Callable], optional): A callable that returns a PyTorch Tabular Model.
                If provided, will ignore the model_config and use this callable to initialize the model.
                Defaults to None.

            model_state_dict_path (Optional[Union[str, Path]], optional): Path to the state dict of the model.

                If provided, will ignore the model_config and use this state dict to initialize the model.
                Defaults to None.

            suppress_lightning_logger (bool, optional): Whether to suppress the lightning logger. Defaults to True.

            **kwargs: Additional keyword arguments to be passed to the TabularModel init.

        """
        if trainer_config.profiler is not None:
            warnings.warn(
                "Profiler is not supported in tuner. Set profiler=None in TrainerConfig to disable this warning."
            )
            trainer_config.profiler = None
        if trainer_config.fast_dev_run:
            warnings.warn("fast_dev_run is turned on. Tuning results won't be accurate.")
        if trainer_config.progress_bar != "none":
            # If config and tuner have progress bar enabled, it will result in a bug within the library (rich.progress)
            trainer_config.progress_bar = "none"
            warnings.warn("Turning off progress bar. Set progress_bar='none' in TrainerConfig to disable this warning.")
        trainer_config.trainer_kwargs.update({"enable_model_summary": False})
        self.data_config = data_config
        self.model_config = model_config
        self.optimizer_config = optimizer_config
        self.trainer_config = trainer_config
        self.suppress_lightning_logger = suppress_lightning_logger
        self.tabular_model_init_kwargs = {
            "model_callable": model_callable,
            "model_state_dict_path": model_state_dict_path,
            **kwargs,
        }

    def _check_assign_config(self, config, param, value):
        if isinstance(config, DictConfig):
            if param in config:
                config[param] = value
            else:
                raise ValueError(f"{param} is not a valid parameter for {str(config)}")
        elif isinstance(config, (ModelConfig, OptimizerConfig)):
            if hasattr(config, param):
                setattr(config, param, value)
            else:
                raise ValueError(f"{param} is not a valid parameter for {str(config)}")

    def _update_configs(
        self,
        optimizer_config: OptimizerConfig,
        model_config: ModelConfig,
        params: Dict,
    ):
        """Update the configs with the new parameters."""
        # update configs with the new parameters
        for k, v in params.items():
            root, param = k.split("__")
            if root.startswith("trainer_config"):
                raise ValueError(
                    "The trainer_config is not supported be tuner. Please remove it from tuner parameters!"
                )
            elif root.startswith("optimizer_config"):
                self._check_assign_config(optimizer_config, param, v)
            elif root.startswith("model_config.head_config"):
                param = param.replace("model_config.head_config.", "")
                self._check_assign_config(model_config.head_config, param, v)
            elif root.startswith("model_config") and "head_config" not in root:
                self._check_assign_config(model_config, param, v)
            else:
                raise ValueError(
                    f"{k} is not in the proper format. Use __ to separate the "
                    "root and param. for eg. `optimizer_config__optimizer` should be "
                    "used to update the optimizer parameter in the optimizer_config"
                )
        return optimizer_config, model_config

    def tune(
        self,
        train: DataFrame,
        search_space: Dict,
        metric: Union[str, Callable],
        mode: str,
        strategy: str,
        validation: Optional[DataFrame] = None,
        n_trials: Optional[int] = None,
        cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None,
        cv_agg_func: Optional[Callable] = np.mean,
        cv_kwargs: Optional[Dict] = {},
        return_best_model: bool = True,
        verbose: bool = False,
        progress_bar: bool = True,
        random_state: Optional[int] = 42,
        ignore_oom: bool = True,
        **kwargs,
    ):
        """Tune the hyperparameters of the TabularModel.

        Args:
            train (DataFrame): Training data

            validation (DataFrame, optional): Validation data. Defaults to None.

            search_space (Dict): A dictionary of the form {param_name: [values to try]}
                for grid search or {param_name: distribution} for random search

            metric (Union[str, Callable]): The metric to be used for evaluation.
                If str is provided, will use that metric from the defined ones.
                If callable is provided, will use that function as the metric.
                We expect the callable to be of the form `metric(y_true, y_pred)`. For classification
                problems, `y_pred` is a dataframe with the probabilities for each class
                (<class>_probability) and a final prediction (prediction). For regression, it is a
                dataframe with the final prediction (<target>_prediction).
                Defaults to None.

            mode (str): One of ['max', 'min']. Whether to maximize or minimize the metric.

            strategy (str): One of ['grid_search', 'random_search']. The strategy to use for tuning.

            n_trials (int, optional): Number of trials to run. Only used for random search.
                Defaults to None.

            cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
                Possible inputs for cv are:

                - None, to not use any cross validation. We will just use the validation data
                - integer, to specify the number of folds in a (Stratified)KFold,
                - An iterable yielding (train, test) splits as arrays of indices.
                - A scikit-learn CV splitter.
                Defaults to None.

            cv_agg_func (Optional[Callable], optional): Function to aggregate the cross validation scores.
                Defaults to np.mean.

            cv_kwargs (Optional[Dict], optional): Additional keyword arguments to be passed to the cross validation
                method. Defaults to {}.

            return_best_model (bool, optional): If True, will return the best model. Defaults to True.

            verbose (bool, optional): Whether to print the results of each trial. Defaults to False.

            progress_bar (bool, optional): Whether to show a progress bar. Defaults to True.

            random_state (Optional[int], optional): Random state to be used for random search. Defaults to 42.

            ignore_oom (bool, optional): Whether to ignore out of memory errors. Defaults to True.

            **kwargs: Additional keyword arguments to be passed to the TabularModel fit.

        Returns:
            OUTPUT: A named tuple with the following attributes:
                trials_df (DataFrame): A dataframe with the results of each trial
                best_params (Dict): The best parameters found
                best_score (float): The best score found
                best_model (TabularModel or None): If return_best_model is True, return best_model otherwise return None

        """
        assert strategy in self.ALLOWABLE_STRATEGIES, f"strategy must be one of {self.ALLOWABLE_STRATEGIES}"
        assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
        assert metric is not None, "metric must be specified"
        assert isinstance(search_space, dict) and len(search_space) > 0, "search_space must be a non-empty dict"
        if self.suppress_lightning_logger:
            suppress_lightning_logs()
        if cv is not None and validation is not None:
            warnings.warn(
                "Both validation and cv are provided. Ignoring validation and using cv. Use "
                "`validation=None` to turn off this warning."
            )
            validation = None

        if strategy == "grid_search":
            assert all(
                isinstance(v, list) for v in search_space.values()
            ), "For grid search, all values in search_space must be a list of values to try"
            iterator = ParameterGrid(search_space)
            if n_trials is not None:
                warnings.warn(
                    "n_trials is ignored for grid search to do a complete sweep of"
                    " the grid. Set n_trials=None to turn off this warning."
                )
            n_trials = sum(1 for _ in iterator)
        elif strategy == "random_search":
            assert n_trials is not None, "n_trials must be specified for random search"
            iterator = ParameterSampler(search_space, n_iter=n_trials, random_state=random_state)
        else:
            raise NotImplementedError(f"{strategy} is not implemented yet.")

        if progress_bar:
            iterator = track(iterator, description=f"[green]{strategy.replace('_',' ').title()}...", total=n_trials)
        verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)

        temp_tabular_model = TabularModel(
            data_config=self.data_config,
            model_config=self.model_config,
            optimizer_config=self.optimizer_config,
            trainer_config=self.trainer_config,
            verbose=verbose_tabular_model,
            **self.tabular_model_init_kwargs,
        )

        prep_dl_kwargs, prep_model_kwargs, train_kwargs = temp_tabular_model._split_kwargs(kwargs)
        if "seed" not in prep_dl_kwargs:
            prep_dl_kwargs["seed"] = random_state
        datamodule = temp_tabular_model.prepare_dataloader(train=train, validation=validation, **prep_dl_kwargs)
        validation = validation if validation is not None else datamodule.validation_dataset.data

        if isinstance(metric, str):
            # metric = metric_to_pt_metric(metric)
            is_callable_metric = False
            metric_str = metric
        elif callable(metric):
            is_callable_metric = True
            metric_str = metric.__name__
        del temp_tabular_model

        trials = []
        best_model = None
        best_score = 0.0
        for i, params in enumerate(iterator):
            # Copying the configs as a base
            # Make sure all default parameters that you want to be set for all
            # trials are in the original configs
            trainer_config_t = deepcopy(self.trainer_config)
            optimizer_config_t = deepcopy(self.optimizer_config)
            model_config_t = deepcopy(self.model_config)

            optimizer_config_t, model_config_t = self._update_configs(optimizer_config_t, model_config_t, params)
            # Initialize Tabular model using the new config
            tabular_model_t = TabularModel(
                data_config=self.data_config,
                model_config=model_config_t,
                optimizer_config=optimizer_config_t,
                trainer_config=trainer_config_t,
                verbose=verbose_tabular_model,
                **self.tabular_model_init_kwargs,
            )

            if cv is not None:
                cv_verbose = cv_kwargs.pop("verbose", False)
                cv_kwargs.pop("handle_oom", None)
                with OutOfMemoryHandler(handle_oom=True) as handler:
                    cv_scores, _ = tabular_model_t.cross_validate(
                        cv=cv,
                        train=train,
                        metric=metric,
                        verbose=cv_verbose,
                        handle_oom=False,
                        **cv_kwargs,
                    )
                if handler.oom_triggered:
                    if not ignore_oom:
                        raise OOMException(
                            "Out of memory error occurred during cross validation. "
                            "Set ignore_oom=True to ignore this error."
                        )
                    else:
                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                else:
                    params.update({metric_str: cv_agg_func(cv_scores)})
            else:
                model = tabular_model_t.prepare_model(
                    datamodule=datamodule,
                    **prep_model_kwargs,
                )
                train_kwargs.pop("handle_oom", None)
                with OutOfMemoryHandler(handle_oom=True) as handler:
                    tabular_model_t.train(model=model, datamodule=datamodule, handle_oom=False, **train_kwargs)
                if handler.oom_triggered:
                    if not ignore_oom:
                        raise OOMException(
                            "Out of memory error occurred during training. " "Set ignore_oom=True to ignore this error."
                        )
                    else:
                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                else:
                    if is_callable_metric:
                        preds = tabular_model_t.predict(validation, include_input_features=False)
                        params.update({metric_str: metric(validation[tabular_model_t.config.target], preds)})
                    else:
                        result = tabular_model_t.evaluate(validation, verbose=False)
                        params.update({k.replace("test_", ""): v for k, v in result[0].items()})

                    if return_best_model:
                        tabular_model_t.datamodule = None
                        if best_model is None:
                            best_model = deepcopy(tabular_model_t)
                            best_score = params[metric_str]
                        else:
                            if mode == "min":
                                if params[metric_str] < best_score:
                                    best_model = deepcopy(tabular_model_t)
                                    best_score = params[metric_str]
                            elif mode == "max":
                                if params[metric_str] > best_score:
                                    best_model = deepcopy(tabular_model_t)
                                    best_score = params[metric_str]

            params.update({"trial_id": i})
            trials.append(params)
            if verbose:
                logger.info(f"Trial {i+1}/{n_trials}: {params} | Score: {params[metric]}")
        trials_df = pd.DataFrame(trials)
        trials = trials_df.pop("trial_id")
        if mode == "max":
            best_idx = trials_df[metric_str].idxmax()
        elif mode == "min":
            best_idx = trials_df[metric_str].idxmin()
        else:
            raise NotImplementedError(f"{mode} is not implemented yet.")
        best_params = trials_df.iloc[best_idx].to_dict()
        best_score = best_params.pop(metric_str)
        trials_df.insert(0, "trial_id", trials)

        if verbose:
            logger.info("Model Tuner Finished")
            logger.info(f"Best Score ({metric_str}): {best_score}")

        if return_best_model and best_model is not None:
            best_model.datamodule = datamodule

            return self.OUTPUT(trials_df, best_params, best_score, best_model)
        else:
            return self.OUTPUT(trials_df, best_params, best_score, None)

__init__(data_config=None, model_config=None, optimizer_config=None, trainer_config=None, model_callable=None, model_state_dict_path=None, suppress_lightning_logger=True, **kwargs)

Tabular Model Tuner helps you tune the hyperparameters of a TabularModel.

Parameters:

data_config (Optional[Union[DataConfig, str]]): The DataConfig for the TabularModel. If a str is passed, the DataConfig is initialized from the yaml file at that path. Defaults to None.

model_config (Optional[Union[ModelConfig, str]]): The ModelConfig for the TabularModel. If a str is passed, the ModelConfig is initialized from the yaml file at that path. Defaults to None.

optimizer_config (Optional[Union[OptimizerConfig, str]]): The OptimizerConfig for the TabularModel. If a str is passed, the OptimizerConfig is initialized from the yaml file at that path. Defaults to None.

trainer_config (Optional[Union[TrainerConfig, str]]): The TrainerConfig for the TabularModel. If a str is passed, the TrainerConfig is initialized from the yaml file at that path. Defaults to None.

model_callable (Optional[Callable]): A callable that returns a PyTorch Tabular model. If provided, model_config is ignored and this callable is used to initialize the model. Defaults to None.

model_state_dict_path (Optional[Union[str, Path]]): Path to the state dict of the model. If provided, model_config is ignored and this state dict is used to initialize the model. Defaults to None.

suppress_lightning_logger (bool): Whether to suppress the lightning logger. Defaults to True.

**kwargs: Additional keyword arguments passed to the TabularModel init.
Source code in src/pytorch_tabular/tabular_model_tuner.py
def __init__(
    self,
    data_config: Optional[Union[DataConfig, str]] = None,
    model_config: Optional[Union[ModelConfig, str]] = None,
    optimizer_config: Optional[Union[OptimizerConfig, str]] = None,
    trainer_config: Optional[Union[TrainerConfig, str]] = None,
    model_callable: Optional[Callable] = None,
    model_state_dict_path: Optional[Union[str, Path]] = None,
    suppress_lightning_logger: bool = True,
    **kwargs,
):
    """Tabular Model Tuner helps you tune the hyperparameters of a TabularModel.

    Args:
        data_config (Optional[Union[DataConfig, str]], optional): The DataConfig for the TabularModel.
            If str is passed, will initialize the DataConfig using the yaml file in that path.
            Defaults to None.

        model_config (Optional[Union[ModelConfig, str]], optional): The ModelConfig for the TabularModel.
            If str is passed, will initialize the ModelConfig using the yaml file in that path.
            Defaults to None.

        optimizer_config (Optional[Union[OptimizerConfig, str]], optional): The OptimizerConfig for the
            TabularModel. If str is passed, will initialize the OptimizerConfig using the yaml file in
            that path. Defaults to None.

        trainer_config (Optional[Union[TrainerConfig, str]], optional): The TrainerConfig for the TabularModel.
            If str is passed, will initialize the TrainerConfig using the yaml file in that path.
            Defaults to None.

        model_callable (Optional[Callable], optional): A callable that returns a PyTorch Tabular Model.
            If provided, will ignore the model_config and use this callable to initialize the model.
            Defaults to None.

        model_state_dict_path (Optional[Union[str, Path]], optional): Path to the state dict of the model.

            If provided, will ignore the model_config and use this state dict to initialize the model.
            Defaults to None.

        suppress_lightning_logger (bool, optional): Whether to suppress the lightning logger. Defaults to True.

        **kwargs: Additional keyword arguments to be passed to the TabularModel init.

    """
    if trainer_config.profiler is not None:
        warnings.warn(
            "Profiler is not supported in tuner. Set profiler=None in TrainerConfig to disable this warning."
        )
        trainer_config.profiler = None
    if trainer_config.fast_dev_run:
        warnings.warn("fast_dev_run is turned on. Tuning results won't be accurate.")
    if trainer_config.progress_bar != "none":
        # If config and tuner have progress bar enabled, it will result in a bug within the library (rich.progress)
        trainer_config.progress_bar = "none"
        warnings.warn("Turning off progress bar. Set progress_bar='none' in TrainerConfig to disable this warning.")
    trainer_config.trainer_kwargs.update({"enable_model_summary": False})
    self.data_config = data_config
    self.model_config = model_config
    self.optimizer_config = optimizer_config
    self.trainer_config = trainer_config
    self.suppress_lightning_logger = suppress_lightning_logger
    self.tabular_model_init_kwargs = {
        "model_callable": model_callable,
        "model_state_dict_path": model_state_dict_path,
        **kwargs,
    }
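
For orientation, here is a minimal usage sketch of the constructor. The column names (target, num_col_1, num_col_2, cat_col_1), the chosen model, and the trainer settings are placeholders to be replaced with your own.

# Hedged sketch: assumes a regression dataset with placeholder column names.
from pytorch_tabular import TabularModelTuner
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

data_config = DataConfig(
    target=["target"],                           # placeholder target column
    continuous_cols=["num_col_1", "num_col_2"],  # placeholder numeric columns
    categorical_cols=["cat_col_1"],              # placeholder categorical column
)
trainer_config = TrainerConfig(
    max_epochs=5,
    batch_size=1024,
    progress_bar="none",  # the tuner turns the progress bar off and warns otherwise
)
model_config = CategoryEmbeddingModelConfig(task="regression")

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=trainer_config,
)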

tune(train, search_space, metric, mode, strategy, validation=None, n_trials=None, cv=None, cv_agg_func=np.mean, cv_kwargs={}, return_best_model=True, verbose=False, progress_bar=True, random_state=42, ignore_oom=True, **kwargs)

Tune the hyperparameters of the TabularModel.

Parameters:

train (DataFrame): Training data. Required.

validation (DataFrame, optional): Validation data. Defaults to None.

search_space (Dict): A dictionary of the form {param_name: [values to try]} for grid search or {param_name: distribution} for random search. Required.

metric (Union[str, Callable]): The metric to be used for evaluation. If a str is provided, will use that metric from the defined ones. If a callable is provided, will use that function as the metric. The callable is expected to be of the form metric(y_true, y_pred). For classification problems, y_pred is a dataframe with the probabilities for each class (<class>_probability) and a final prediction (prediction). For regression, it is a dataframe with a final prediction (<target>_prediction). Required.

mode (str): One of ['max', 'min']. Whether to maximize or minimize the metric. Required.

strategy (str): One of ['grid_search', 'random_search']. The strategy to use for tuning. Required.

n_trials (int, optional): Number of trials to run. Only used for random search. Defaults to None.

cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy. Possible inputs for cv are:

    - None, to not use any cross validation; the validation data is used instead
    - an integer, to specify the number of folds in a (Stratified)KFold
    - an iterable yielding (train, test) splits as arrays of indices
    - a scikit-learn CV splitter

    Defaults to None.

cv_agg_func (Optional[Callable]): Function to aggregate the cross validation scores. Defaults to np.mean.

cv_kwargs (Optional[Dict]): Additional keyword arguments to be passed to the cross validation method. Defaults to {}.

return_best_model (bool): If True, will return the best model. Defaults to True.

verbose (bool): Whether to print the results of each trial. Defaults to False.

progress_bar (bool): Whether to show a progress bar. Defaults to True.

random_state (Optional[int]): Random state to be used for random search. Defaults to 42.

ignore_oom (bool): Whether to ignore out of memory errors. Defaults to True.

**kwargs: Additional keyword arguments to be passed to TabularModel.fit.

Returns:

OUTPUT: A named tuple with the following attributes:

    trials_df (DataFrame): A dataframe with the results of each trial
    best_params (Dict): The best parameters found
    best_score (float): The best score found
    best_model (TabularModel or None): If return_best_model is True, the best model; otherwise None
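
As a concrete illustration of the callable-metric contract described above, here is a small sketch of a regression metric. It assumes the target column is named "target", so the prediction column follows the <target>_prediction convention as "target_prediction".

from sklearn.metrics import mean_absolute_error

def mae_metric(y_true, y_pred):
    # y_true: the target column(s) of the validation dataframe.
    # y_pred: the prediction dataframe returned by predict(); for a regression
    # target named "target" it contains a "target_prediction" column.
    return mean_absolute_error(y_true, y_pred["target_prediction"])

# Used as: tuner.tune(..., metric=mae_metric, mode="min", ...)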

Source code in src/pytorch_tabular/tabular_model_tuner.py
def tune(
    self,
    train: DataFrame,
    search_space: Dict,
    metric: Union[str, Callable],
    mode: str,
    strategy: str,
    validation: Optional[DataFrame] = None,
    n_trials: Optional[int] = None,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None,
    cv_agg_func: Optional[Callable] = np.mean,
    cv_kwargs: Optional[Dict] = {},
    return_best_model: bool = True,
    verbose: bool = False,
    progress_bar: bool = True,
    random_state: Optional[int] = 42,
    ignore_oom: bool = True,
    **kwargs,
):
    """Tune the hyperparameters of the TabularModel.

    Args:
        train (DataFrame): Training data

        validation (DataFrame, optional): Validation data. Defaults to None.

        search_space (Dict): A dictionary of the form {param_name: [values to try]}
            for grid search or {param_name: distribution} for random search

        metric (Union[str, Callable]): The metric to be used for evaluation.
            If str is provided, will use that metric from the defined ones.
            If callable is provided, will use that function as the metric.
            We expect callable to be of the form `metric(y_true, y_pred)`. For classification
            problems, The `y_pred` is a dataframe with the probabilities for each class
            (<class>_probability) and a final prediction(prediction). And for Regression, it is a
            dataframe with a final prediction (<target>_prediction).
            Defaults to None.

        mode (str): One of ['max', 'min']. Whether to maximize or minimize the metric.

        strategy (str): One of ['grid_search', 'random_search']. The strategy to use for tuning.

        n_trials (int, optional): Number of trials to run. Only used for random search.
            Defaults to None.

        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to not use any cross validation. We will just use the validation data
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.
            Defaults to None.

        cv_agg_func (Optional[Callable], optional): Function to aggregate the cross validation scores.
            Defaults to np.mean.

        cv_kwargs (Optional[Dict], optional): Additional keyword arguments to be passed to the cross validation
            method. Defaults to {}.

        return_best_model (bool, optional): If True, will return the best model. Defaults to True.

        verbose (bool, optional): Whether to print the results of each trial. Defaults to False.

        progress_bar (bool, optional): Whether to show a progress bar. Defaults to True.

        random_state (Optional[int], optional): Random state to be used for random search. Defaults to 42.

        ignore_oom (bool, optional): Whether to ignore out of memory errors. Defaults to True.

        **kwargs: Additional keyword arguments to be passed to the TabularModel fit.

    Returns:
        OUTPUT: A named tuple with the following attributes:
            trials_df (DataFrame): A dataframe with the results of each trial
            best_params (Dict): The best parameters found
            best_score (float): The best score found
            best_model (TabularModel or None): If return_best_model is True, return best_model otherwise return None

    """
    assert strategy in self.ALLOWABLE_STRATEGIES, f"tuner must be one of {self.ALLOWABLE_STRATEGIES}"
    assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
    assert metric is not None, "metric must be specified"
    assert isinstance(search_space, dict) and len(search_space) > 0, "search_space must be a non-empty dict"
    if self.suppress_lightning_logger:
        suppress_lightning_logs()
    if cv is not None and validation is not None:
        warnings.warn(
            "Both validation and cv are provided. Ignoring validation and using cv. Use "
            "`validation=None` to turn off this warning."
        )
        validation = None

    if strategy == "grid_search":
        assert all(
            isinstance(v, list) for v in search_space.values()
        ), "For grid search, all values in search_space must be a list of values to try"
        iterator = ParameterGrid(search_space)
        if n_trials is not None:
            warnings.warn(
                "n_trials is ignored for grid search to do a complete sweep of"
                " the grid. Set n_trials=None to turn off this warning."
            )
        n_trials = sum(1 for _ in iterator)
    elif strategy == "random_search":
        assert n_trials is not None, "n_trials must be specified for random search"
        iterator = ParameterSampler(search_space, n_iter=n_trials, random_state=random_state)
    else:
        raise NotImplementedError(f"{strategy} is not implemented yet.")

    if progress_bar:
        iterator = track(iterator, description=f"[green]{strategy.replace('_',' ').title()}...", total=n_trials)
    verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)

    temp_tabular_model = TabularModel(
        data_config=self.data_config,
        model_config=self.model_config,
        optimizer_config=self.optimizer_config,
        trainer_config=self.trainer_config,
        verbose=verbose_tabular_model,
        **self.tabular_model_init_kwargs,
    )

    prep_dl_kwargs, prep_model_kwargs, train_kwargs = temp_tabular_model._split_kwargs(kwargs)
    if "seed" not in prep_dl_kwargs:
        prep_dl_kwargs["seed"] = random_state
    datamodule = temp_tabular_model.prepare_dataloader(train=train, validation=validation, **prep_dl_kwargs)
    validation = validation if validation is not None else datamodule.validation_dataset.data

    if isinstance(metric, str):
        # metric = metric_to_pt_metric(metric)
        is_callable_metric = False
        metric_str = metric
    elif callable(metric):
        is_callable_metric = True
        metric_str = metric.__name__
    del temp_tabular_model

    trials = []
    best_model = None
    best_score = 0.0
    for i, params in enumerate(iterator):
        # Copying the configs as a base
        # Make sure all default parameters that you want to be set for all
        # trials are in the original configs
        trainer_config_t = deepcopy(self.trainer_config)
        optimizer_config_t = deepcopy(self.optimizer_config)
        model_config_t = deepcopy(self.model_config)

        optimizer_config_t, model_config_t = self._update_configs(optimizer_config_t, model_config_t, params)
        # Initialize Tabular model using the new config
        tabular_model_t = TabularModel(
            data_config=self.data_config,
            model_config=model_config_t,
            optimizer_config=optimizer_config_t,
            trainer_config=trainer_config_t,
            verbose=verbose_tabular_model,
            **self.tabular_model_init_kwargs,
        )

        if cv is not None:
            cv_verbose = cv_kwargs.pop("verbose", False)
            cv_kwargs.pop("handle_oom", None)
            with OutOfMemoryHandler(handle_oom=True) as handler:
                cv_scores, _ = tabular_model_t.cross_validate(
                    cv=cv,
                    train=train,
                    metric=metric,
                    verbose=cv_verbose,
                    handle_oom=False,
                    **cv_kwargs,
                )
            if handler.oom_triggered:
                if not ignore_oom:
                    raise OOMException(
                        "Out of memory error occurred during cross validation. "
                        "Set ignore_oom=True to ignore this error."
                    )
                else:
                    params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
            else:
                params.update({metric_str: cv_agg_func(cv_scores)})
        else:
            model = tabular_model_t.prepare_model(
                datamodule=datamodule,
                **prep_model_kwargs,
            )
            train_kwargs.pop("handle_oom", None)
            with OutOfMemoryHandler(handle_oom=True) as handler:
                tabular_model_t.train(model=model, datamodule=datamodule, handle_oom=False, **train_kwargs)
            if handler.oom_triggered:
                if not ignore_oom:
                    raise OOMException(
                        "Out of memory error occurred during training. " "Set ignore_oom=True to ignore this error."
                    )
                else:
                    params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
            else:
                if is_callable_metric:
                    preds = tabular_model_t.predict(validation, include_input_features=False)
                    params.update({metric_str: metric(validation[tabular_model_t.config.target], preds)})
                else:
                    result = tabular_model_t.evaluate(validation, verbose=False)
                    params.update({k.replace("test_", ""): v for k, v in result[0].items()})

                if return_best_model:
                    tabular_model_t.datamodule = None
                    if best_model is None:
                        best_model = deepcopy(tabular_model_t)
                        best_score = params[metric_str]
                    else:
                        if mode == "min":
                            if params[metric_str] < best_score:
                                best_model = deepcopy(tabular_model_t)
                                best_score = params[metric_str]
                        elif mode == "max":
                            if params[metric_str] > best_score:
                                best_model = deepcopy(tabular_model_t)
                                best_score = params[metric_str]

        params.update({"trial_id": i})
        trials.append(params)
        if verbose:
            logger.info(f"Trial {i+1}/{n_trials}: {params} | Score: {params[metric]}")
    trials_df = pd.DataFrame(trials)
    trials = trials_df.pop("trial_id")
    if mode == "max":
        best_idx = trials_df[metric_str].idxmax()
    elif mode == "min":
        best_idx = trials_df[metric_str].idxmin()
    else:
        raise NotImplementedError(f"{mode} is not implemented yet.")
    best_params = trials_df.iloc[best_idx].to_dict()
    best_score = best_params.pop(metric_str)
    trials_df.insert(0, "trial_id", trials)

    if verbose:
        logger.info("Model Tuner Finished")
        logger.info(f"Best Score ({metric_str}): {best_score}")

    if return_best_model and best_model is not None:
        best_model.datamodule = datamodule

        return self.OUTPUT(trials_df, best_params, best_score, best_model)
    else:
        return self.OUTPUT(trials_df, best_params, best_score, None)
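
Putting the pieces together, the sketch below runs a grid search with the tuner constructed earlier. train_df and valid_df are placeholder dataframes, and the search-space keys follow the config__parameter addressing used in the library's tuner tutorials, which is assumed here rather than documented in this section.

search_space = {
    "model_config__layers": ["128-64-32", "64-32"],  # assumed key convention
    "optimizer_config__optimizer": ["Adam", "SGD"],
}

result = tuner.tune(
    train=train_df,
    validation=valid_df,
    search_space=search_space,
    metric="mean_squared_error",  # a metric reported by evaluate(); or pass a callable
    mode="min",
    strategy="grid_search",
    # cv=5,                       # alternatively, 5-fold CV instead of the validation set
    return_best_model=True,
)

result.trials_df    # one row per trial with the tried params and scores
result.best_params  # the best parameter combination
result.best_score   # the best metric value
result.best_model   # a trained TabularModel (because return_best_model=True)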

model_sweep(task, train, test, data_config, optimizer_config, trainer_config, model_list='lite', metrics=None, metrics_params=None, metrics_prob_input=None, validation=None, experiment_config=None, common_model_args={}, rank_metric=('loss', 'lower_is_better'), return_best_model=True, seed=42, ignore_oom=True, progress_bar=True, verbose=True, suppress_lightning_logger=True)

Compare multiple models on the same dataset.

Parameters:

task (str): The type of prediction task. Either 'classification' or 'regression'. Required.

train (DataFrame): The training data. Required.

test (DataFrame): The test data on which performance is evaluated. Required.

data_config (Union[DataConfig, str]): DataConfig object or path to the yaml file. Required.

optimizer_config (Union[OptimizerConfig, str]): OptimizerConfig object or path to the yaml file. Required.

trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file. Required.

model_list (Union[str, List[Union[ModelConfig, str]]]): The list of models to compare. This can be one of the presets defined in pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS or a list of ModelConfig objects. Defaults to 'lite'.

metrics (Optional[List[str]]): The list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in torchmetrics. By default, it is accuracy for classification and mean_squared_error for regression. Defaults to None.

metrics_prob_input (Optional[bool]): A mandatory parameter for classification metrics defined in the config. This defines whether the input to the metric function is the probability or the class. Length should be the same as the number of metrics. Defaults to None.

metrics_params (Optional[List]): The parameters to be passed to the metrics function. task is forced to be multiclass because the multiclass version can handle binary as well, and for simplicity only multiclass is used. Defaults to None.

validation (Optional[DataFrame]): If provided, will use this dataframe as the validation while training. Used in early stopping and logging. If left empty, will use 20% of the train data as validation. Defaults to None.

experiment_config (Optional[Union[ExperimentConfig, str]]): ExperimentConfig object or path to the yaml file. Defaults to None.

common_model_args (Optional[dict]): The model arguments which are common to all models. The list of params can be found in ModelConfig. If not provided, will use defaults. Defaults to {}.

rank_metric (Optional[Tuple[str, str]]): The metric to use for ranking the models. The first element of the tuple is the metric name and the second element is the direction. Defaults to ('loss', 'lower_is_better').

return_best_model (bool): If True, will return the best model. Defaults to True.

seed (int): The seed for reproducibility. Defaults to 42.

ignore_oom (bool): If True, will ignore the Out of Memory error and continue with the next model. Defaults to True.

progress_bar (bool): If True, will show a progress bar. Defaults to True.

verbose (bool): If True, will print the progress. Defaults to True.

suppress_lightning_logger (bool): If True, will suppress the lightning logger. Defaults to True.

Returns:

    results (DataFrame): Training results for each model.
    best_model (TabularModel or None): If return_best_model is True, the best model; otherwise None.
Source code in src/pytorch_tabular/tabular_model_sweep.py
def model_sweep(
    task: str,
    train: pd.DataFrame,
    test: pd.DataFrame,
    data_config: Union[DataConfig, str],
    optimizer_config: Union[OptimizerConfig, str],
    trainer_config: Union[TrainerConfig, str],
    model_list: Union[str, List[Union[ModelConfig, str]]] = "lite",
    metrics: Optional[List[Union[str, Callable]]] = None,
    metrics_params: Optional[List[dict]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    validation: Optional[pd.DataFrame] = None,
    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
    common_model_args: Optional[dict] = {},
    rank_metric: Optional[Tuple[str, str]] = ("loss", "lower_is_better"),
    return_best_model: bool = True,
    seed: int = 42,
    ignore_oom: bool = True,
    progress_bar: bool = True,
    verbose: bool = True,
    suppress_lightning_logger: bool = True,
):
    """Compare multiple models on the same dataset.

    Args:
        task (str): The type of prediction task. Either 'classification' or 'regression'

        train (pd.DataFrame): The training data

        test (pd.DataFrame): The test data on which performance is evaluated

        data_config (Union[DataConfig, str]): DataConfig object or path to the yaml file.

        optimizer_config (Union[OptimizerConfig, str]): OptimizerConfig object or path to the yaml file.

        trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file.

        model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare.
                This can be one of the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS``
                or a list of ``ModelConfig`` objects. Defaults to "lite".

        metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics
                should be one of the functional metrics implemented in ``torchmetrics``. By default, it is
                accuracy if classification and mean_squared_error for regression

        metrics_prob_input (Optional[bool]): Is a mandatory parameter for classification metrics defined in
                the config. This defines whether the input to the metric function is the probability or the class.
                Length should be same as the number of metrics. Defaults to None.

        metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is forced to
                be `multiclass` because the multiclass version can handle binary as well and for simplicity we are
                only using `multiclass`.

        validation (Optional[DataFrame], optional):
                If provided, will use this dataframe as the validation while training.
                Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
                Defaults to None.

        experiment_config (Optional[Union[ExperimentConfig, str]], optional): ExperimentConfig object or path to
                the yaml file.

        common_model_args (Optional[dict], optional): The model argument which are common to all models. The list
                of params can be found in ``ModelConfig``. If not provided, will use defaults. Defaults to {}.

        rank_metric (Optional[Tuple[str, str]], optional): The metric to use for ranking the models. The first element
                of the tuple is the metric name and the second element is the direction.
                Defaults to ('loss', "lower_is_better").

        return_best_model (bool, optional): If True, will return the best model. Defaults to True.

        seed (int, optional): The seed for reproducibility. Defaults to 42.

        ignore_oom (bool, optional): If True, will ignore the Out of Memory error and continue with the next model.

        progress_bar (bool, optional): If True, will show a progress bar. Defaults to True.

        verbose (bool, optional): If True, will print the progress. Defaults to True.

        suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.

    Returns:
        results: Training results.

        best_model: If return_best_model is True, return best_model otherwise return None.

    """
    _validate_args(
        task=task,
        train=train,
        test=test,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list=model_list,
        metrics=metrics,
        metrics_params=metrics_params,
        metrics_prob_input=metrics_prob_input,
        validation=validation,
        experiment_config=experiment_config,
        common_model_args=common_model_args,
        rank_metric=rank_metric,
    )
    if suppress_lightning_logger:
        suppress_lightning_logs()
    if progress_bar:
        if trainer_config.progress_bar != "none":
            # Turning off the internal progress bar to avoid conflict with sweep progress bar
            warnings.warn(
                "Training Progress bar is not `none`. Set `progress_bar=none` in"
                " `trainer_config` to remove this warning"
            )
            trainer_config.progress_bar = "none"

    if model_list in ["full", "high_memory"]:
        warnings.warn(
            "The full model list is quite large and uses a lot of memory. "
            "Consider using `lite` or define configs yourselves for a faster run"
        )
    _model_args = ["metrics", "metrics_params", "metrics_prob_input"]
    # Replacing the common model args with the ones passed in the function
    for arg in _model_args:
        if locals()[arg] is not None:
            common_model_args[arg] = locals()[arg]
    if isinstance(model_list, str):
        model_list = copy.deepcopy(MODEL_SWEEP_PRESETS[model_list])
        model_list = [
            (
                getattr(models, model_config[0])(task=task, **model_config[1], **common_model_args)
                if isinstance(model_config, Tuple)
                else (
                    getattr(models, model_config)(task=task, **common_model_args)
                    if isinstance(model_config, str)
                    else model_config
                )
            )
            for model_config in model_list
        ]

    def _init_tabular_model(m):
        return TabularModel(
            data_config=data_config,
            model_config=m,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            experiment_config=experiment_config,
            verbose=False,
        )

    datamodule = _init_tabular_model(model_list[0]).prepare_dataloader(train=train, validation=validation, seed=seed)
    results = []
    best_model = None
    is_lower_better = rank_metric[1] == "lower_is_better"
    best_score = 1e9 if is_lower_better else -1e9
    it = track(model_list, description="Sweeping Models") if progress_bar else model_list
    ctx = Progress() if progress_bar else nullcontext()
    with ctx as progress:
        if progress_bar:
            task_p = progress.add_task("Sweeping Models", total=len(model_list))
        for model_config in model_list:
            if isinstance(model_config, str):
                model_config = getattr(models, model_config)(task=task, **common_model_args)
            else:
                for key, val in common_model_args.items():
                    if hasattr(model_config, key):
                        setattr(model_config, key, val)
                    else:
                        raise ValueError(
                            f"ModelConfig {model_config.name} does not have an" f" attribute {key} in common_model_args"
                        )
            params = model_config.__dict__
            start_time = time.time()
            tabular_model = _init_tabular_model(model_config)
            name = tabular_model.name
            if verbose:
                logger.info(f"Training {name}")
            model = tabular_model.prepare_model(datamodule)
            if progress_bar:
                progress.update(task_p, description=f"Training {name}", advance=1)
            with OutOfMemoryHandler(handle_oom=True) as handler:
                tabular_model.train(model, datamodule, handle_oom=False)
            res_dict = {
                "model": name,
                "# Params": int_to_human_readable(tabular_model.num_params),
            }
            if handler.oom_triggered:
                if not ignore_oom:
                    raise OOMException(
                        "Out of memory error occurred during cross validation. "
                        "Set ignore_oom=True to ignore this error."
                    )
                else:
                    res_dict.update(
                        {
                            f"test_{rank_metric[0]}": (np.inf if is_lower_better else -np.inf),
                            "epochs": "OOM",
                            "time_taken": "OOM",
                            "time_taken_per_epoch": "OOM",
                        }
                    )
                    res_dict["model"] = name + " (OOM)"
            else:
                if (
                    tabular_model.trainer.early_stopping_callback is not None
                    and tabular_model.trainer.early_stopping_callback.stopped_epoch != 0
                ):
                    res_dict["epochs"] = tabular_model.trainer.early_stopping_callback.stopped_epoch
                else:
                    res_dict["epochs"] = tabular_model.trainer.max_epochs
                res_dict.update(tabular_model.evaluate(test=test, verbose=False)[0])
                res_dict["time_taken"] = time.time() - start_time
                res_dict["time_taken_per_epoch"] = res_dict["time_taken"] / res_dict["epochs"]

                if return_best_model:
                    tabular_model.datamodule = None
                    if best_model is None:
                        best_model = copy.deepcopy(tabular_model)
                        best_score = res_dict[f"test_{rank_metric[0]}"]
                    else:
                        if is_lower_better:
                            if res_dict[f"test_{rank_metric[0]}"] < best_score:
                                best_model = copy.deepcopy(tabular_model)
                                best_score = res_dict[f"test_{rank_metric[0]}"]
                        else:
                            if res_dict[f"test_{rank_metric[0]}"] > best_score:
                                best_model = copy.deepcopy(tabular_model)
                                best_score = res_dict[f"test_{rank_metric[0]}"]

            if verbose:
                logger.info(f"Finished Training {name}")
                logger.info("Results:" f" {', '.join([f'{k}: {v}' for k, v in res_dict.items()])}")
            res_dict["params"] = params
            results.append(res_dict)

    if verbose:
        logger.info("Model Sweep Finished")
        logger.info(f"Best Model: {best_model.name}")
    results = pd.DataFrame(results).sort_values(by=f"test_{rank_metric[0]}", ascending=is_lower_better)
    if return_best_model and best_model is not None:
        best_model.datamodule = datamodule
        return results, best_model
    else:
        return results, None
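
For completeness, a short, hedged sketch of calling model_sweep with the "lite" preset. The column names and the train_df/test_df dataframes are placeholders, and the default rank_metric of ('loss', 'lower_is_better') is used for ranking.

from pytorch_tabular import model_sweep
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

data_config = DataConfig(
    target=["target"],                           # placeholder column names
    continuous_cols=["num_col_1", "num_col_2"],
    categorical_cols=["cat_col_1"],
)
trainer_config = TrainerConfig(max_epochs=5, batch_size=1024, progress_bar="none")

results, best_model = model_sweep(
    task="classification",
    train=train_df,        # placeholder dataframes
    test=test_df,
    data_config=data_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=trainer_config,
    model_list="lite",     # preset; alternatively a list of ModelConfig objects
    return_best_model=True,
)

results  # leaderboard DataFrame sorted by the rank metric (test loss by default)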