Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v6.0.0
hooks:
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade
rev: v2.4.1
rev: v3.21.2
hooks:
- id: pyupgrade
- repo: https://github.com/pycqa/isort
rev: 5.8.0
rev: 7.0.0
hooks:
- id: isort
exclude: ^tests/|torchflare/callbacks/__init__.py
- repo: https://github.com/python/black
rev: 21.4b2
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.11.0
hooks:
- id: black
exclude: ^tests/
args: [ --safe, --quiet ]
args: [ --safe, --quiet ]
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
author = "Atharva Phatak"

# The full version, including alpha/beta/rc tags
with open("../../version.txt", "r") as f:
with open("../../version.txt") as f:
release = str(f.readline().strip())


Expand Down
3 changes: 2 additions & 1 deletion examples/Advanced-Tutorials/KD/vanilla-kd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Vanilla Knowledge distillation using Torchflare.
This example only shows how to modify the training script for KD.
"""

from typing import Dict

import torch
Expand All @@ -11,7 +12,7 @@

class KDExperiment(Experiment):
def __init__(self, temperature, alpha, **kwargs):
super(KDExperiment, self).__init__(**kwargs)
super().__init__(**kwargs)
self.temperature = temperature
self.alpha = alpha

Expand Down
7 changes: 4 additions & 3 deletions examples/Advanced-Tutorials/autoencoders/mnist-vae.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generating MNIST digits using Variational Autoencoders."""

import torch
import torch.nn.functional as F
from torch import nn
Expand All @@ -15,13 +16,13 @@ def __init__(self, d):
super().__init__()
self.d = d
self.encoder = nn.Sequential(
nn.Linear(784, self.d ** 2), nn.ReLU(), nn.Linear(self.d ** 2, self.d * 2)
nn.Linear(784, self.d**2), nn.ReLU(), nn.Linear(self.d**2, self.d * 2)
)

self.decoder = nn.Sequential(
nn.Linear(self.d, self.d ** 2),
nn.Linear(self.d, self.d**2),
nn.ReLU(),
nn.Linear(self.d ** 2, 784),
nn.Linear(self.d**2, 784),
nn.Sigmoid(),
)

Expand Down
7 changes: 4 additions & 3 deletions examples/Advanced-Tutorials/gans/dcgan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generating MNIST Digits using DCGAN."""

import os

import torch
Expand All @@ -20,7 +21,7 @@ def __init__(self, latent_dim, batchnorm=True):
latent_dim (int): latent dimension ("noise vector")
batchnorm (bool): Whether or not to use batch normalization
"""
super(Generator, self).__init__()
super().__init__()
self.latent_dim = latent_dim
self.batchnorm = batchnorm
self._init_modules()
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(self, output_dim):
Images must be single-channel and 28x28 pixels.
Output activation is Sigmoid.
"""
super(Discriminator, self).__init__()
super().__init__()
self.output_dim = output_dim
self._init_modules() # I know this is overly-organized. Fight me.

Expand Down Expand Up @@ -127,7 +128,7 @@ def forward(self, input_tensor):
class DCGANExperiment(Experiment):
def __init__(self, latent_dim, batch_size, **kwargs):

super(DCGANExperiment, self).__init__(**kwargs)
super().__init__(**kwargs)

self.noise_fn = lambda x: torch.randn((x, latent_dim), device=self.device)
self.target_ones = torch.ones((batch_size, 1), device=self.device)
Expand Down
6 changes: 3 additions & 3 deletions examples/Advanced-Tutorials/self-supervision/ssl_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
# Defining the models.
class MLPHead(nn.Module):
def __init__(self, in_channels: int, projection_size: int = 256, hidden_size: int = 4096):
super(MLPHead, self).__init__()
super().__init__()

self.net = nn.Sequential(
nn.Linear(in_channels, hidden_size),
Expand All @@ -81,7 +81,7 @@ def forward(self, x):
# Defining resnet encoders.
class ResnetEncoder(nn.Module):
def __init__(self, pretrained, mlp_params):
super(ResnetEncoder, self).__init__()
super().__init__()
resnet = torchvision.models.resnet18(pretrained=pretrained)
self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
self.projector = MLPHead(in_channels=resnet.fc.in_features, **mlp_params)
Expand All @@ -95,7 +95,7 @@ def forward(self, x):
# Defining custom training method required as required by Bootstrap your own latent.(SSL)
class BYOLExperiment(Experiment):
def __init__(self, momentum, augmentation_fn, image_size, **kwargs):
super(BYOLExperiment, self).__init__(**kwargs)
super().__init__(**kwargs)
self.momentum = momentum
self.augmentation_fn = augmentation_fn(image_size)

Expand Down
2 changes: 1 addition & 1 deletion examples/Basic-Tutorials/fit_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class Net(torch.nn.Module):
def __init__(self, out_features):
super(Net, self).__init__()
super().__init__()

self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
Expand Down
2 changes: 1 addition & 1 deletion examples/Basic-Tutorials/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class Model(torch.nn.Module):
def __init__(self, dropout, out_features):

super(Model, self).__init__()
super().__init__()
self.bert = transformers.BertModel.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
self.bert_drop = nn.Dropout(dropout)
self.out = nn.Linear(128, out_features)
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Setup.py for torchflare."""

# flake8: noqa
import os

Expand All @@ -15,15 +16,15 @@


readme_file_path = os.path.join(current_file_path, "README.md")
with open(readme_file_path, "r", encoding="utf-8") as f:
with open(readme_file_path, encoding="utf-8") as f:
readme = f.read()


version_file_path = os.path.join(current_file_path, "version.txt")
with open(version_file_path, "r", encoding="utf-8") as f:
with open(version_file_path, encoding="utf-8") as f:
version = f.read().strip()

with open(os.path.join(current_file_path, "requirements.txt"), "r") as f:
with open(os.path.join(current_file_path, "requirements.txt")) as f:
requirements = f.read().splitlines()


Expand Down
2 changes: 1 addition & 1 deletion tests/experiment/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class Model(torch.nn.Module):
def __init__(self, num_features, num_classes):
super(Model, self).__init__()
super().__init__()
self.model = torch.nn.Linear(num_features, num_classes)

def forward(self, x):
Expand Down
1 change: 1 addition & 0 deletions torchflare/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementation of Callbacks and CallbackRunner."""

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion torchflare/callbacks/callback_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class FunctionalCallback(Callbacks):
"""

def __init__(self, func, order):
super(FunctionalCallback, self).__init__(order=order)
super().__init__(order=order)
self.func = func
functools.update_wrapper(self, func)

Expand Down
2 changes: 1 addition & 1 deletion torchflare/callbacks/comet_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
tags: List[str],
):
"""Constructor for CometLogger class."""
super(CometLogger, self).__init__(order=CallbackOrder.LOGGING)
super().__init__(order=CallbackOrder.LOGGING)
self.api_token = api_token
self.project_name = project_name
self.workspace = workspace
Expand Down
2 changes: 1 addition & 1 deletion torchflare/callbacks/criterion_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class AvgLoss(Callbacks):
"""Class for averaging the loss."""

def __init__(self):
super(AvgLoss, self).__init__(order=CallbackOrder.LOSS)
super().__init__(order=CallbackOrder.LOSS)
self.accum_loss, self.count = {}, 0
self.reset()

Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementation of Early stopping."""

import math
from abc import ABC
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
min_delta: float = 1e-7,
):
"""Constructor for EarlyStopping class."""
super(EarlyStopping, self).__init__(order=CallbackOrder.STOPPING)
super().__init__(order=CallbackOrder.STOPPING)

if monitor.startswith("train_") or monitor.startswith("val_"):
self.monitor = monitor
Expand Down
1 change: 1 addition & 0 deletions torchflare/callbacks/extra_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements extra utilities required."""

import math
from functools import partial

Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/load_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements Load checkpoint."""

from abc import ABC
from typing import TYPE_CHECKING

Expand All @@ -16,7 +17,7 @@ class LoadCheckpoint(Callbacks, ABC):

def __init__(self, path_to_model: str = None):
"""Constructor method for LoadCheckpoint Class."""
super(LoadCheckpoint, self).__init__(order=CallbackOrder.MODEL_INIT)
super().__init__(order=CallbackOrder.MODEL_INIT)
self.path = path_to_model

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/lr_schedulers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements LrScheduler callbacks."""

from abc import ABC
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Union

Expand All @@ -21,7 +22,7 @@ def __init__(self, scheduler, step_on_batch: bool):
scheduler: A pytorch scheduler
step_on_batch: Whether the scheduler steps after batch or not.
"""
super(LRSchedulerCallback, self).__init__(order=CallbackOrder.SCHEDULER)
super().__init__(order=CallbackOrder.SCHEDULER)
self._scheduler = scheduler
self.step_on_batch = step_on_batch
self.scheduler = None
Expand Down
5 changes: 3 additions & 2 deletions torchflare/callbacks/message_notifiers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements notifiers for slack and discord."""

import json
from abc import ABC
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -42,7 +43,7 @@ class SlackNotifierCallback(Callbacks, ABC):

def __init__(self, webhook_url: str):
"""Constructor method for SlackNotifierCallback."""
super(SlackNotifierCallback, self).__init__(order=CallbackOrder.EXTERNAL)
super().__init__(order=CallbackOrder.EXTERNAL)
self.webhook_url = webhook_url

def on_epoch_end(self, experiment: "Experiment"):
Expand Down Expand Up @@ -81,7 +82,7 @@ class DiscordNotifierCallback(Callbacks, ABC):

def __init__(self, exp_name: str, webhook_url: str):
"""Constructor method for DiscordNotifierCallback."""
super(DiscordNotifierCallback, self).__init__(order=CallbackOrder.EXTERNAL)
super().__init__(order=CallbackOrder.EXTERNAL)
self.exp_name = exp_name
self.webhook_url = webhook_url

Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/metric_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements container for loss and metric computation."""

from typing import TYPE_CHECKING, Dict, List

from torchmetrics import MetricCollection
Expand All @@ -23,7 +24,7 @@ def __init__(self, metrics: List = None):
Args:
metrics: The list of metrics
"""
super(MetricCallback, self).__init__(CallbackOrder.METRICS)
super().__init__(CallbackOrder.METRICS)
metrics = MetricCollection(metrics)
self.metrics = {
"train": metrics.clone(prefix="train_"),
Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements Model Checkpoint Callback."""

import os
from abc import ABC
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -62,7 +63,7 @@ class ModelCheckpoint(Callbacks, ABC):

def __init__(self, mode: str, monitor: str, save_dir: str = "./", file_name: str = "model.bin"):
"""Constructor for ModelCheckpoint class."""
super(ModelCheckpoint, self).__init__(order=CallbackOrder.CHECKPOINT)
super().__init__(order=CallbackOrder.CHECKPOINT)
if monitor.startswith("train_") or monitor.startswith("val_"):
self.monitor = monitor
else:
Expand Down
2 changes: 1 addition & 1 deletion torchflare/callbacks/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class History(Callbacks, ABC):

def __init__(self):
"""Constructor class for History Class."""
super(History, self).__init__(order=CallbackOrder.LOGGING)
super().__init__(order=CallbackOrder.LOGGING)
self.history = None

def on_experiment_start(self, experiment: "Experiment"):
Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/neptune_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements Neptune Logger."""

from abc import ABC
from typing import TYPE_CHECKING, List

Expand Down Expand Up @@ -55,7 +56,7 @@ def __init__(
tags: List[str] = None,
):
"""Constructor for NeptuneLogger Class."""
super(NeptuneLogger, self).__init__(order=CallbackOrder.LOGGING)
super().__init__(order=CallbackOrder.LOGGING)
self.project_dir = project_dir
self.api_token = api_token
self.params = params
Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementation of Progress Bar."""

import math
import sys
import time
Expand Down Expand Up @@ -26,7 +27,7 @@ def __init__(
unit_name: str = "step",
):
"""Constructor class for ProgressBar."""
super(ProgressBar, self).__init__(order=CallbackOrder.EXTERNAL)
super().__init__(order=CallbackOrder.EXTERNAL)
self.num_epochs = None
self.width = width
self.interval = interval
Expand Down
1 change: 1 addition & 0 deletions torchflare/callbacks/states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Definitions of experiment states and Callback order."""

from enum import IntEnum


Expand Down
3 changes: 2 additions & 1 deletion torchflare/callbacks/tensorboard_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implements Tensorboard Logger."""

from abc import ABC
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -26,7 +27,7 @@ class TensorboardLogger(Callbacks, ABC):

def __init__(self, log_dir: str):
"""Constructor for TensorboardLogger class."""
super(TensorboardLogger, self).__init__(order=CallbackOrder.LOGGING)
super().__init__(order=CallbackOrder.LOGGING)
self.log_dir = log_dir
self._experiment = None

Expand Down
Loading
Loading