mirror of https://github.com/gmihaila/ml_things.git
synced 2021-10-04 01:29:04 +03:00
Create testing.py
reinforcement_learning/testing.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Unit testing

To run unit testing:

    python -m unittest test_dqn_helper_functions
"""

import unittest
import numpy as np
from dqn_helper_functions import ReplayBuffer
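
# NOTE: the ReplayBuffer interface assumed by these tests, inferred from the
# assertions below (a sketch, not the actual dqn_helper_functions code):
#
#     class ReplayBuffer:
#         def __init__(self, max_size, input_shape, n_actions, discrete):
#             # Raises TypeError when an argument has the wrong type.
#         def collect_experience(self, state, action, reward, state_, done):
#             # Stores one transition, increments mem_cntr, and fills the
#             # observation/action/reward/new_observation/terminal arrays
#             # (terminal is stored as 1 - done).
#         def sample_buffer(self, batch_size):
#             # Returns (observations, actions, rewards, new_observations,
#             # terminal) arrays of length batch_size.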


class TestReplayBuffer(unittest.TestCase):

    def test_type(self):
        # Make sure type errors are raised when necessary. Arguments are
        # passed to assertRaises unpacked, so each call exercises one
        # wrongly typed constructor argument.
        self.assertRaises(TypeError, ReplayBuffer, 1.0, 20, 3, True)
        self.assertRaises(TypeError, ReplayBuffer, 1, 2.0, 3, True)
        self.assertRaises(TypeError, ReplayBuffer, 1, 2, 3.0, True)
        self.assertRaises(TypeError, ReplayBuffer, 1, 2, 3, 'value')

    def test_collect_experience(self):
        # Test a single call to collect_experience.
        max_size = 2
        input_shape = 20
        n_actions = 3
        discrete = False

        memory = ReplayBuffer(max_size, input_shape, n_actions, discrete)

        # Create a dummy environment transition.
        state = np.random.rand(input_shape)
        action = np.random.rand(n_actions)
        reward = np.random.randint(10)
        state_ = np.random.rand(input_shape)
        done = False

        memory.collect_experience(state, action, reward, state_, done)

        # One transition stored; the second buffer slot is still empty.
        self.assertEqual(memory.mem_cntr, 1)
        np.testing.assert_array_equal(memory.terminal_memory, [1., 0.])
        np.testing.assert_array_equal(memory.reward_memory, [reward, 0])
        np.testing.assert_array_equal(
            memory.action_memory,
            np.array([action, np.zeros(n_actions)], dtype=np.float32))
        np.testing.assert_array_equal(
            memory.new_observation_memory,
            np.array([state_, np.zeros(input_shape)]))
        np.testing.assert_array_equal(
            memory.observation_memory,
            np.array([state, np.zeros(input_shape)]))

    def test_sample_buffer(self):
        # Test a sample drawn from the buffer with sample_buffer.
        max_size = 2
        input_shape = 20
        n_actions = 3
        discrete = False
        batch_size = 2

        # Create memory.
        memory = ReplayBuffer(max_size, input_shape, n_actions, discrete)

        # Create a dummy environment transition.
        state = np.random.rand(input_shape)
        action = np.random.rand(n_actions)
        reward = np.random.randint(10)
        state_ = np.random.rand(input_shape)
        done = False

        # Perform two collections.
        memory.collect_experience(state, action, reward, state_, done)
        memory.collect_experience(state, action, reward, state_, done)

        # Get a sample.
        observations, actions, rewards, new_observations, terminal = memory.sample_buffer(batch_size)

        # Test observations.
        np.testing.assert_array_equal(observations, np.array([state, state]))
        # Test new observations.
        np.testing.assert_array_equal(new_observations, np.array([state_, state_]))
        # Test actions.
        np.testing.assert_array_equal(actions, np.array([action, action], dtype=np.float32))
        # Test rewards.
        np.testing.assert_array_equal(rewards, np.array([reward, reward]))
        # Test terminal flags (done=False is stored as 1).
        np.testing.assert_array_equal(terminal, np.array([1, 1], dtype=np.float32))


class TestDqnAgent(unittest.TestCase):

    def test_type(self):
        # Make sure type errors are raised when necessary.
        pass
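

# Standard unittest entry point (a convenience addition, assuming the file may
# also be run directly as `python testing.py`; `python -m unittest` works too).
if __name__ == '__main__':
    unittest.main()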