
Commit b403a4b

Auto-upgrade syntax via pyupgrade (#158)
* Auto-upgrade syntax via `pyupgrade`
* Add changes
1 parent bea0d98 commit b403a4b

23 files changed: +54 -48 lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,10 @@
 repos:
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v2.31.1
+    hooks:
+      - id: pyupgrade
+        args:
+          - --py37-plus
   - repo: https://github.com/PyCQA/isort
     rev: 5.10.1
     hooks:
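With this hook in place, `pre-commit` runs `pyupgrade --py37-plus` on the Python files touched by each commit (or across the whole repository via `pre-commit run pyupgrade --all-files`), rewriting spellings that are only needed for Python 2 compatibility into the forms available from Python 3.7 onward. A minimal sketch of the kind of rewrite involved; the classes and strings below are illustrative stand-ins, not code from this commit:

# Legacy spellings that pyupgrade targets:
class LegacyBuffer(object):  # explicit `object` base class
    def describe(self, size):
        return "buffer of size {}".format(size)  # str.format call

# Roughly what pyupgrade --py37-plus leaves behind:
class ModernBuffer:  # implicit object base
    def describe(self, size):
        return f"buffer of size {size}"  # f-string

The exact set of rewrites depends on the pyupgrade version pinned in `rev` (v2.31.1 here).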

cleanrl/c51.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def thunk():
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
     def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
-        super(QNetwork, self).__init__()
+        super().__init__()
         self.env = env
         self.n_atoms = n_atoms
         self.register_buffer("atoms", torch.linspace(v_min, v_max, steps=n_atoms))
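Most of the diff is this single mechanical change repeated across files: the two-argument `super(QNetwork, self).__init__()` becomes the zero-argument `super().__init__()`. In Python 3 the two spellings resolve to the same call, because the compiler supplies an implicit `__class__` reference for methods defined inside a class body. A standalone sketch of the equivalence (plain classes, no torch dependency, names are illustrative):

class Base:
    def __init__(self):
        # record which concrete class was being constructed
        self.constructed = type(self).__name__

class Explicit(Base):
    def __init__(self):
        super(Explicit, self).__init__()  # legacy two-argument form

class Implicit(Base):
    def __init__(self):
        super().__init__()  # zero-argument form, Python 3 only

print(Explicit().constructed)  # Explicit
print(Implicit().constructed)  # Implicit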

cleanrl/c51_atari.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def thunk():
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
     def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
-        super(QNetwork, self).__init__()
+        super().__init__()
         self.env = env
         self.n_atoms = n_atoms
         self.register_buffer("atoms", torch.linspace(v_min, v_max, steps=n_atoms))

cleanrl/ddpg_continuous_action.py

Lines changed: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ def thunk():
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
     def __init__(self, env):
-        super(QNetwork, self).__init__()
+        super().__init__()
         self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
         self.fc2 = nn.Linear(256, 256)
         self.fc3 = nn.Linear(256, 1)
@@ -101,7 +101,7 @@ def forward(self, x, a):
 
 class Actor(nn.Module):
     def __init__(self, env):
-        super(Actor, self).__init__()
+        super().__init__()
         self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
         self.fc2 = nn.Linear(256, 256)
         self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))

cleanrl/dqn.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def thunk():
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
     def __init__(self, env):
-        super(QNetwork, self).__init__()
+        super().__init__()
         self.network = nn.Sequential(
             nn.Linear(np.array(env.single_observation_space.shape).prod(), 64),
             nn.Tanh(),

cleanrl/dqn_atari.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def thunk():
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
     def __init__(self, env):
-        super(QNetwork, self).__init__()
+        super().__init__()
         self.network = nn.Sequential(
             nn.Conv2d(4, 32, 8, stride=4),
             nn.ReLU(),

cleanrl/ppg_procgen.py

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ def get_output_shape(self):
 
 class Agent(nn.Module):
     def __init__(self, envs):
-        super(Agent, self).__init__()
+        super().__init__()
         h, w, c = envs.single_observation_space.shape
         shape = (c, h, w)
         conv_seqs = []

cleanrl/ppo.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 
 class Agent(nn.Module):
     def __init__(self, envs):
-        super(Agent, self).__init__()
+        super().__init__()
         self.critic = nn.Sequential(
             layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
             nn.Tanh(),

cleanrl/ppo_atari.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 
 class Agent(nn.Module):
     def __init__(self, envs):
-        super(Agent, self).__init__()
+        super().__init__()
         self.network = nn.Sequential(
             layer_init(nn.Conv2d(4, 32, 8, stride=4)),
             nn.ReLU(),

cleanrl/ppo_atari_envpool.py

Lines changed: 4 additions & 4 deletions
@@ -81,7 +81,7 @@ def parse_args():
 
 class RecordEpisodeStatistics(gym.Wrapper):
     def __init__(self, env, deque_size=100):
-        super(RecordEpisodeStatistics, self).__init__(env)
+        super().__init__(env)
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
@@ -94,7 +94,7 @@ def __init__(self, env, deque_size=100):
             print("env has lives")
 
     def reset(self, **kwargs):
-        observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
+        observations = super().reset(**kwargs)
         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
         self.lives = np.zeros(self.num_envs, dtype=np.int32)
@@ -103,7 +103,7 @@ def reset(self, **kwargs):
         return observations
 
     def step(self, action):
-        observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(action)
+        observations, rewards, dones, infos = super().step(action)
         self.episode_returns += infos["reward"]
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
@@ -133,7 +133,7 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 
 class Agent(nn.Module):
     def __init__(self, envs):
-        super(Agent, self).__init__()
+        super().__init__()
         self.network = nn.Sequential(
             layer_init(nn.Conv2d(4, 32, 8, stride=4)),
             nn.ReLU(),
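In `RecordEpisodeStatistics` the rewrite also lands outside `__init__`: `super().reset(**kwargs)` and `super().step(action)` still delegate through `gym.Wrapper` to the inner environment, since zero-argument `super()` works in any method defined in the class body. A simplified, gym-free sketch of that delegation pattern (here the base class itself stands in for the environment; classes and values are illustrative, not the wrapper from this file):

class BaseEnv:
    # stand-in for the wrapped environment
    def reset(self):
        return 0
    def step(self, action):
        return 0, 1.0, False, {}

class EpisodeLengthWrapper(BaseEnv):
    # simplified stand-in for an episode-statistics wrapper
    def reset(self):
        self.episode_length = 0
        return super().reset()  # was: super(EpisodeLengthWrapper, self).reset()
    def step(self, action):
        obs, reward, done, info = super().step(action)  # was the two-argument form
        self.episode_length += 1
        return obs, reward, done, info

env = EpisodeLengthWrapper()
env.reset()
print(env.step(0), env.episode_length)  # (0, 1.0, False, {}) 1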
