Skip to content

Commit cbafb49

Browse files
committed
fix: enforce flake8 and various code lints
1 parent 974d9a9 commit cbafb49

File tree

21 files changed

+161
-103
lines changed

21 files changed

+161
-103
lines changed

.pre-commit-config.yaml

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,22 @@
22
# See https://pre-commit.com/hooks.html for more hooks
33
repos:
44
- repo: https://github.com/pre-commit/pre-commit-hooks
5-
rev: v3.2.0
5+
rev: v4.1.0
66
hooks:
7-
- id: trailing-whitespace
8-
- id: end-of-file-fixer
9-
- id: check-yaml
7+
- id: check-case-conflict
8+
- id: check-json
9+
- id: check-symlinks
10+
- id: check-yaml
11+
- id: destroyed-symlinks
12+
- id: end-of-file-fixer
13+
exclude: docs/CNAME
14+
- id: fix-byte-order-marker
15+
- id: fix-encoding-pragma
16+
args: [--remove]
17+
- id: mixed-line-ending
18+
args: [--fix=lf]
19+
- id: requirements-txt-fixer
20+
- id: trailing-whitespace
1021
- repo: https://github.com/psf/black
1122
rev: 22.10.0
1223
hooks:
@@ -17,3 +28,7 @@ repos:
1728
hooks:
1829
- id: isort
1930
name: isort (python)
31+
- repo: https://github.com/pycqa/flake8
32+
rev: 6.0.0
33+
hooks:
34+
- id: flake8

docs/make.bat

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,35 @@
1-
@ECHO OFF
2-
3-
pushd %~dp0
4-
5-
REM Command file for Sphinx documentation
6-
7-
if "%SPHINXBUILD%" == "" (
8-
set SPHINXBUILD=sphinx-build
9-
)
10-
set SOURCEDIR=source
11-
set BUILDDIR=build
12-
13-
if "%1" == "" goto help
14-
15-
%SPHINXBUILD% >NUL 2>NUL
16-
if errorlevel 9009 (
17-
echo.
18-
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19-
echo.installed, then set the SPHINXBUILD environment variable to point
20-
echo.to the full path of the 'sphinx-build' executable. Alternatively you
21-
echo.may add the Sphinx directory to PATH.
22-
echo.
23-
echo.If you don't have Sphinx installed, grab it from
24-
echo.https://www.sphinx-doc.org/
25-
exit /b 1
26-
)
27-
28-
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29-
goto end
30-
31-
:help
32-
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33-
34-
:end
35-
popd
1+
@ECHO OFF
2+
3+
pushd %~dp0
4+
5+
REM Command file for Sphinx documentation
6+
7+
if "%SPHINXBUILD%" == "" (
8+
set SPHINXBUILD=sphinx-build
9+
)
10+
set SOURCEDIR=source
11+
set BUILDDIR=build
12+
13+
if "%1" == "" goto help
14+
15+
%SPHINXBUILD% >NUL 2>NUL
16+
if errorlevel 9009 (
17+
echo.
18+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19+
echo.installed, then set the SPHINXBUILD environment variable to point
20+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
21+
echo.may add the Sphinx directory to PATH.
22+
echo.
23+
echo.If you don't have Sphinx installed, grab it from
24+
echo.https://www.sphinx-doc.org/
25+
exit /b 1
26+
)
27+
28+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29+
goto end
30+
31+
:help
32+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33+
34+
:end
35+
popd

docs/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
sphinx==4.0.0
2-
sphinx_rtd_theme
31
accelerate==0.12.0
42
datasets==2.4.0
53
deepspeed==0.7.3
64
einops==0.4.1
75
numpy==1.23.2
6+
sphinx==4.0.0
7+
sphinx_rtd_theme
8+
torchtyping
89
tqdm==4.64.0
910
transformers==4.21.2
1011
wandb==0.13.2
11-
torchtyping

docs/source/conf.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import os
1414
import sys
1515

16+
import sphinx_rtd_theme
17+
1618
sys.path.insert(0, os.path.abspath('../..'))
1719

1820

@@ -28,16 +30,7 @@
2830
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
2931
# ones.
3032

31-
import sphinx_rtd_theme
32-
33-
extensions = [
34-
'sphinx_rtd_theme',
35-
'sphinx.ext.todo',
36-
'sphinx.ext.viewcode',
37-
'sphinx.ext.autodoc',
38-
'sphinx.ext.autosummary',
39-
'sphinx.ext.autosectionlabel'
40-
]
33+
extensions = ['sphinx_rtd_theme', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel']
4134

4235
# Add any paths that contain templates here, relative to this directory.
4336
templates_path = ['_templates']

examples/experiments/grounded_program_synthesis/lang.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# flake8: noqa
12
import copy
23
import json
34
import random

examples/randomwalks/randomwalks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int:
1010
return x
1111

1212

13-
def generate_random_walks(
13+
def generate_random_walks( # noqa: max-complexity
1414
n_nodes=21, max_length=10, n_walks=1000, p_edge=0.1, seed=1002, gpt2_tokenizer=False
1515
):
1616
rng = np.random.RandomState(seed)
@@ -30,6 +30,8 @@ def generate_random_walks(
3030

3131
goal = 0
3232
sample_walks = []
33+
delimiter = "|" if gpt2_tokenizer else ""
34+
3335
for _ in range(n_walks):
3436
node = randexclude(rng, n_nodes, goal)
3537
walk = [node]
@@ -43,7 +45,6 @@ def generate_random_walks(
4345
# code each node by a letter
4446
# for bpe tokenizer join them over | for a guaranteed split
4547
walk = [node_to_char[ix] for ix in walk]
46-
delimiter = "|" if gpt2_tokenizer else ""
4748

4849
sample_walks.append(delimiter.join(walk))
4950

setup.cfg

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,28 @@ dev =
3737
exclude =
3838
docs*
3939
tests*
40+
41+
[flake8]
42+
max-complexity = 10
43+
max-line-length = 127
44+
# flake8 error codes: https://flake8.pycqa.org/en/latest/user/error-codes.html
45+
# pycodestyle codes: https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes
46+
# E203 # whitespace before ‘,’, ‘;’, or ‘:’
47+
# E741 # do not use variables named ‘l’, ‘O’, or ‘I’
48+
# F401 # module imported but unused
49+
# F821 # undefined name name
50+
# W503 # line break before binary operator
51+
# W605 # invalid escape sequence ‘x’
52+
ignore =
53+
E203
54+
E741
55+
F821
56+
W503
57+
W605
58+
per-file-ignores = __init__.py:F401,loading.py:F401
59+
exclude =
60+
.git
61+
__pycache__
62+
docs/source/conf.py
63+
build
64+
dist

tests/test_ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_lm_heads(self):
4848
def test_frozen_head(self):
4949
# Ensure that all parameters of the `hydra_model.frozen_head` are actually frozen
5050
for parameter in TestHydraHead.hydra_model.frozen_head.parameters():
51-
self.assertTrue(parameter.requires_grad == False)
51+
self.assertTrue(parameter.requires_grad is False)
5252

5353
def test_forward(self):
5454
with torch.no_grad():

trlx/data/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import random
21
from dataclasses import dataclass
3-
from typing import Any, Callable, Iterable, List
2+
from typing import Any, Iterable
43

54
from torchtyping import TensorType
65

6+
from . import configs
7+
78

89
@dataclass
910
class GeneralElement:

trlx/data/configs.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from dataclasses import dataclass
2-
from typing import Any, Dict, Optional, Set
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, Optional, Set, Tuple
33

44
import yaml
55

@@ -56,7 +56,7 @@ class OptimizerConfig:
5656
"""
5757

5858
name: str
59-
kwargs: Dict[str, Any] = None
59+
kwargs: Dict[str, Any] = field(default_factory=dict)
6060

6161
@classmethod
6262
def from_dict(cls, config: Dict[str, Any]):
@@ -76,7 +76,7 @@ class SchedulerConfig:
7676
"""
7777

7878
name: str
79-
kwargs: Dict[str, Any] = None
79+
kwargs: Dict[str, Any] = field(default_factory=dict)
8080

8181
@classmethod
8282
def from_dict(cls, config: Dict[str, Any]):
@@ -100,6 +100,9 @@ class TrainConfig:
100100
:param batch_size: Batch size for training
101101
:type batch_size: int
102102
103+
:param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
104+
:type trackers: Tuple[str]
105+
103106
:param checkpoint_interval: Save model every checkpoint_interval steps
104107
:type checkpoint_interval: int
105108
@@ -123,7 +126,8 @@ class TrainConfig:
123126
:param checkpoint_dir: Directory to save checkpoints
124127
:type checkpoint_dir: str
125128
126-
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
129+
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
130+
Only used by AcceleratePPOTrainer.
127131
:type rollout_logging_dir: Optional[str]
128132
129133
:param seed: Random seed
@@ -148,6 +152,7 @@ class TrainConfig:
148152
checkpoint_dir: str = "ckpts"
149153
rollout_logging_dir: Optional[str] = None
150154

155+
trackers: Tuple[str] = ("wandb",)
151156
seed: int = 1000
152157

153158
@classmethod

0 commit comments

Comments
 (0)