Naman Saxena / Cart Pole problem · Commits · e701bb83

Commit e701bb83, authored 9 months ago by Naman Saxena
Upload New File
Parent: 69a96a28

Showing 1 changed file: control_algorithm.py (new file, mode 100644) — 310 additions, 0 deletions
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from abc import ABC, abstractmethod
import numpy as np
import config
from exploration_strategies import EpsilonGreedyStrategy, SoftmaxStrategy
class ControlAlgorithm(ABC):
    def __init__(self, control_params, exploration_strategy):
        self.params = control_params
        self.exploration_strategy = exploration_strategy

    @abstractmethod
    def get_action(self, state):
        pass

    @abstractmethod
    def update(self, state, action, reward, next_state, done):
        pass

    @abstractmethod
    def _discretize_state(self, state):
        pass
class DQNNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQNNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
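
# --- Illustrative sketch (not part of the original commit) ---
# A minimal shape check for DQNNetwork, assuming CartPole's 4-dimensional
# observation and 2 discrete actions. Defined as a helper so importing this
# module stays side-effect free; the dimensions are placeholders and in
# practice come from the environment handed to DQNControl below.
def _demo_dqn_network_shapes():
    net = DQNNetwork(input_dim=4, output_dim=2)
    dummy_obs = torch.zeros(1, 4)    # batch of one observation
    q_values = net(dummy_obs)        # shape (1, 2): one Q-value per action
    assert q_values.shape == (1, 2)
    return q_values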
# 2. Define the Replay Buffer class
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def size(self):
        return len(self.buffer)
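
# --- Illustrative sketch (not part of the original commit) ---
# How DQNControl uses ReplayBuffer: transitions are stored as
# (state, action, reward, next_state, done) tuples and sampled uniformly
# at random once the buffer holds at least one batch. The dummy transition
# below is a placeholder.
def _demo_replay_buffer():
    buffer = ReplayBuffer(capacity=1000)
    for _ in range(64):
        # A fake transition; during training this comes from the environment.
        buffer.add((np.zeros(4), 0, 1.0, np.zeros(4), False))
    batch = buffer.sample(32)                                 # list of 32 transitions
    states, actions, rewards, next_states, dones = zip(*batch)
    return len(states)                                        # 32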
# Continue with your other classes like QLearningControl or DQNControl
# 3. Define the DQNControl class (as described in the previous step)
class DQNControl(ControlAlgorithm):
    def __init__(self, control_params, exploration_strategy, state_dim, action_dim):
        super().__init__(control_params, exploration_strategy)
        self.q_network = DQNNetwork(state_dim, action_dim)
        self.target_network = DQNNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=control_params['learning_rate'])
        self.criterion = nn.MSELoss()
        self.replay_buffer = ReplayBuffer(1000)
        self.batch_size = control_params.get('batch_size', 64)
        self.gamma = control_params['discount_factor']
        self.epsilon = control_params['epsilon']
        self.min_epsilon = control_params.get('min_epsilon', 0.01)
        self.decay_rate = control_params.get('decay_rate', 0.995)
        self.update_target_steps = control_params.get('update_target_steps', 1000)
        self.steps = 0
        self.model = DQNNetwork(state_dim, action_dim)
    def get_action(self, state, explore=True):
        """
        Decide action based on exploration or exploitation.
        """
        if explore and np.random.rand() < self.epsilon:
            # Exploration: random action
            return np.random.choice(self.q_network.fc3.out_features)
        else:
            # Exploitation: best action
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state)
            return q_values.argmax().item()
    def update(self, state, action, reward, next_state, done):
        # Add experience to the replay buffer
        self.replay_buffer.add((state, action, reward, next_state, done))

        # Only start learning once we have enough samples in the buffer
        if self.replay_buffer.size() < self.batch_size:
            return

        # Sample a batch of experiences from the replay buffer
        experiences = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*experiences)

        # Convert to torch tensors
        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(next_states, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Calculate the current Q values
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Calculate the next Q values using the target network
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]

        # Calculate the target Q values
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Compute loss
        loss = self.criterion(current_q_values, target_q_values)

        # Perform gradient descent
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the target network periodically
        if self.steps % self.update_target_steps == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        self.steps += 1
    def decay_epsilon(self):
        """
        Decay the epsilon value over time.
        """
        self.epsilon = max(self.min_epsilon, self.epsilon * self.decay_rate)

    def _discretize_state(self, state):
        # DQN typically doesn't discretize states, so we return the state as is
        return state
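
# --- Illustrative sketch (not part of the original commit) ---
# A minimal training loop for DQNControl, assuming the Gymnasium API
# (reset() -> (obs, info), step() -> (obs, reward, terminated, truncated, info)).
# The hyperparameter values and the `None` exploration strategy are
# placeholders: DQNControl handles epsilon-greedy exploration internally and
# never calls the strategy object it receives.
def _demo_train_dqn(num_episodes=10):
    import gymnasium as gym
    env = gym.make("CartPole-v1")
    params = {"learning_rate": 1e-3, "discount_factor": 0.99, "epsilon": 1.0}
    agent = DQNControl(params, exploration_strategy=None,
                       state_dim=env.observation_space.shape[0],
                       action_dim=env.action_space.n)
    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action = agent.get_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state, done)
            state = next_state
        agent.decay_epsilon()  # anneal exploration once per episode
    env.close()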
class QLearningControl(ControlAlgorithm):
    def __init__(self, control_params, exploration_strategy):
        super().__init__(control_params, exploration_strategy)
        self.q_table = {}
        self.epsilon = control_params['epsilon']
        self.min_epsilon = control_params.get('min_epsilon', 0.01)
        self.decay_rate = control_params.get('decay_rate', 0.995)

    def get_action(self, state, epsilon=0.1):
        """
        Selects an action using an epsilon-greedy strategy.

        Args:
            state: The current state.
            epsilon: Probability of choosing a random action (exploration).

        Returns:
            action: The selected action.
        """
        state = self._discretize_state(state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]  # Initialize Q-values for unseen state
        # action = self.exploration_strategy.select_action(self.q_table[state])
        if np.random.rand() < self.epsilon:
            action = np.random.choice(len(self.q_table[state]))  # Exploration: random action
        else:
            action = np.argmax(self.q_table[state])  # Exploitation: best action
        return action

    def update(self, state, action, reward, next_state, done):
        state = self._discretize_state(state)
        next_state = self._discretize_state(next_state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]
        if next_state not in self.q_table:
            self.q_table[next_state] = [0, 0]

        current_q = self.q_table[state][action]
        max_next_q = np.max(self.q_table[next_state])
        new_q = current_q + self.params['learning_rate'] * (
            reward + self.params['discount_factor'] * max_next_q - current_q
        )
        self.q_table[state][action] = new_q

    def update_learning_rate(self, new_lr):
        """
        Dynamically update the learning rate during training.
        """
        # The Q-learning update reads the rate from self.params, so store it there.
        self.params['learning_rate'] = new_lr

    def decay_epsilon(self):
        """
        Decay the epsilon value over time.
        """
        self.epsilon = max(self.min_epsilon, self.epsilon * self.decay_rate)

    def _discretize_state(self, state):
        return tuple(np.round(x, 1) for x in state)
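
# --- Illustrative sketch (not part of the original commit) ---
# A single worked Q-learning update, assuming learning_rate=0.1 and
# discount_factor=0.9 as placeholder values. It mirrors the rule above:
#     Q(s, a) <- Q(s, a) + lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))
# States are discretized by rounding each component to one decimal place.
def _demo_q_learning_update():
    params = {"learning_rate": 0.1, "discount_factor": 0.9, "epsilon": 0.1}
    agent = QLearningControl(params, exploration_strategy=None)
    state = np.array([0.02, 0.01, -0.03, 0.04])
    next_state = np.array([0.03, -0.2, -0.01, 0.3])
    agent.update(state, action=1, reward=1.0, next_state=next_state, done=False)
    # Both Q-values start at 0, so the new value is 0 + 0.1 * (1.0 + 0.9 * 0 - 0) = 0.1
    key = agent._discretize_state(state)
    return agent.q_table[key][1]  # 0.1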
'''
class SarsaControl(ControlAlgorithm):
    def __init__(self, control_params, exploration_strategy):
        super().__init__(control_params, exploration_strategy)
        self.q_table = {}
        self.state_bins = self._create_bins()  # Create bins for discretization

    def _create_bins(self):
        # Example binning for CartPole state.
        # Adjust the number of bins and limits based on your environment's state space.
        bins = {
            'x': np.linspace(-4.8, 4.8, 10),          # Cart position
            'x_dot': np.linspace(-3.0, 3.0, 10),      # Cart velocity
            'theta': np.linspace(-0.418, 0.418, 10),  # Pole angle (in radians)
            'theta_dot': np.linspace(-2.0, 2.0, 10)   # Pole angular velocity
        }
        return bins

    def _discretize_state(self, state):
        # Discretize the state into bins
        x, x_dot, theta, theta_dot = state
        state_discrete = (
            np.digitize(x, self.state_bins['x']) - 1,
            np.digitize(x_dot, self.state_bins['x_dot']) - 1,
            np.digitize(theta, self.state_bins['theta']) - 1,
            np.digitize(theta_dot, self.state_bins['theta_dot']) - 1
        )
        return state_discrete

    def get_action(self, state):
        state = self._discretize_state(state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]  # Initialize Q-values for unseen state
        action = self.exploration_strategy.select_action(self.q_table[state])
        return action

    def update(self, state, action, reward, next_state, next_action, done):
        state = self._discretize_state(state)
        next_state = self._discretize_state(next_state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]
        if next_state not in self.q_table:
            self.q_table[next_state] = [0, 0]

        # SARSA update rule
        current_q = self.q_table[state][action]
        next_q = self.q_table[next_state][next_action]
        new_q = current_q + self.params['learning_rate'] * (reward + self.params['discount_factor'] * next_q - current_q)
        self.q_table[state][action] = new_q
'''
class SarsaControl(ControlAlgorithm):
    def __init__(self, control_params, exploration_strategy):
        super().__init__(control_params, exploration_strategy)
        self.q_table = {}
        self.state_bins = self._create_bins()  # Create bins for discretization
        self.epsilon = self.params['epsilon']  # Initial epsilon
        self.epsilon_decay = self.params.get('epsilon_decay', 0.99)  # Decay factor
        self.min_epsilon = self.params.get('min_epsilon', 0.01)  # Minimum epsilon

    def _create_bins(self):
        # Example binning for CartPole state
        bins = {
            'x': np.linspace(-4.8, 4.8, 10),          # Cart position
            'x_dot': np.linspace(-3.0, 3.0, 10),      # Cart velocity
            'theta': np.linspace(-0.418, 0.418, 10),  # Pole angle (in radians)
            'theta_dot': np.linspace(-2.0, 2.0, 10)   # Pole angular velocity
        }
        return bins

    def _discretize_state(self, state):
        x, x_dot, theta, theta_dot = state
        state_discrete = (
            np.digitize(x, self.state_bins['x']) - 1,
            np.digitize(x_dot, self.state_bins['x_dot']) - 1,
            np.digitize(theta, self.state_bins['theta']) - 1,
            np.digitize(theta_dot, self.state_bins['theta_dot']) - 1
        )
        return state_discrete

    def get_action(self, state):
        state = self._discretize_state(state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]  # Initialize Q-values for unseen state
        if np.random.rand() < self.epsilon:
            action = np.random.choice(len(self.q_table[state]))  # Exploration
        else:
            action = np.argmax(self.q_table[state])  # Exploitation
        return action

    def update(self, state, action, reward, next_state, next_action, done):
        state = self._discretize_state(state)
        next_state = self._discretize_state(next_state)
        if state not in self.q_table:
            self.q_table[state] = [0, 0]
        if next_state not in self.q_table:
            self.q_table[next_state] = [0, 0]

        # SARSA update rule
        current_q = self.q_table[state][action]
        next_q = self.q_table[next_state][next_action]
        new_q = current_q + self.params['learning_rate'] * (
            reward + self.params['discount_factor'] * next_q - current_q
        )
        self.q_table[state][action] = new_q

    def decay_epsilon(self):
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
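
# --- Illustrative sketch (not part of the original commit) ---
# An on-policy SARSA episode loop, assuming the Gymnasium API. Unlike the
# Q-learning loop, the update needs the action actually chosen in the next
# state, so the next action is selected before the update is applied. The
# hyperparameter values and the `None` exploration strategy are placeholders.
def _demo_train_sarsa(num_episodes=10):
    import gymnasium as gym
    env = gym.make("CartPole-v1")
    params = {"learning_rate": 0.1, "discount_factor": 0.99, "epsilon": 1.0}
    agent = SarsaControl(params, exploration_strategy=None)
    for _ in range(num_episodes):
        state, _ = env.reset()
        action = agent.get_action(state)
        done = False
        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_action = agent.get_action(next_state)
            agent.update(state, action, reward, next_state, next_action, done)
            state, action = next_state, next_action
        agent.decay_epsilon()
    env.close()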