Files
DRLwithTL-drone/network/Memory.py
Anwar, Malik Aqeel d4e1f1136c Initial Commit
2019-10-14 18:15:33 -04:00

36 lines
915 B
Python

# Code used from https://github.com/rlcode/per
from network.SumTree import SumTree
import random
#-------------------- MEMORY --------------------------
class Memory:  # stored as ( s, a, r, s_ ) in SumTree
    """Prioritized experience-replay buffer backed by a SumTree.

    Each stored sample is given the priority ``(|error| + e) ** a`` and is
    inserted into ``self.tree``. ``sample`` draws entries by querying the
    tree at values spread over its total priority. The selection
    probability is presumably proportional to priority, but that depends on
    ``SumTree.get``, which is not visible here.
    """

    e = 0.01  # added to |error| so a zero-error sample keeps a non-zero priority
    a = 0.6   # exponent shaping how strongly priority grows with the error

    def __init__(self, capacity):
        """Create an empty buffer backed by ``SumTree(capacity)``."""
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        """Map an error value (e.g. a TD error) to a positive priority.

        ``abs`` is applied first. A negative ``error`` would otherwise be
        raised to a fractional power, which yields a complex number for
        Python floats (NaN for numpy floats) and corrupts the tree sums.
        For non-negative errors the result is unchanged.
        """
        return (abs(error) + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` in the tree with a priority derived from ``error``."""
        self.tree.add(self._getPriority(error), sample)

    def sample(self, n):
        """Draw ``n`` entries using stratified prioritized sampling.

        The cumulative-priority range ``[0, total]`` is split into ``n``
        equal segments, and one value is drawn uniformly from each segment,
        so the batch covers the whole priority mass.

        Args:
            n: number of entries to draw; ``n <= 0`` returns an empty list.

        Returns:
            A list of ``(idx, data)`` pairs taken from ``SumTree.get``.
            ``idx`` is presumably the tree index to pass back to ``update``.
        """
        if n <= 0:
            # Avoid dividing the total by zero; nothing to draw anyway.
            return []
        segment = self.tree.total() / n
        batch = []
        for i in range(n):
            low = segment * i
            high = segment * (i + 1)
            # NOTE(review): random.uniform may return ``high`` itself due to
            # float rounding, i.e. the tree total on the last segment; whether
            # SumTree.get handles that boundary cannot be checked from here.
            value = random.uniform(low, high)
            idx, _priority, data = self.tree.get(value)
            batch.append((idx, data))
        return batch

    def update(self, idx, error):
        """Replace the priority of the entry at ``idx`` using a new ``error``."""
        self.tree.update(idx, self._getPriority(error))