# Mirror of https://github.com/aqeelanwar/DRLwithTL.git
# (synced 2023-04-23 23:24:24 +03:00; 36 lines, 915 B, Python)
# Based on code from https://github.com/rlcode/per

from network.SumTree import SumTree
import random
#-------------------- MEMORY --------------------------
class Memory:
    """Prioritized replay memory backed by a ``SumTree``.

    Every stored entry gets priority ``(|error| + e) ** a``. ``sample`` draws
    a stratified batch: the tree's total priority mass is split into ``n``
    equal segments and one point is drawn uniformly from each segment, then
    looked up with ``SumTree.get``. (Presumably ``get`` returns the leaf whose
    cumulative-priority interval contains the point, making high-priority
    entries proportionally more likely; confirm against ``SumTree``.)

    The payload passed to ``add`` is opaque to this class; the original
    author describes it as a ``(s, a, r, s_)`` transition.
    """

    # Small constant added to every error so an entry with zero error still
    # has a non-zero chance of being sampled.
    e = 0.01
    # Priority exponent: 0 makes every priority 1 (uniform sampling),
    # 1 makes priority proportional to |error| + e.
    a = 0.6

    def __init__(self, capacity, e=None, a=None):
        """Create an empty memory.

        Args:
            capacity: Size passed to ``SumTree`` (presumably the maximum
                number of stored entries).
            e: Optional per-instance override of the class-level ``e``.
            a: Optional per-instance override of the class-level ``a``.
        """
        self.tree = SumTree(capacity)
        # Only shadow the class attributes when explicitly overridden, so code
        # that tunes ``Memory.e`` / ``Memory.a`` globally keeps working.
        if e is not None:
            self.e = e
        if a is not None:
            self.a = a

    def _getPriority(self, error):
        """Map an error (signed or not) to a priority ``(|error| + e) ** a``.

        The magnitude is used because, in Python 3, raising a negative float
        to a fractional power such as ``a = 0.6`` yields a complex number,
        which would corrupt the tree's priority sums.
        """
        return (abs(error) + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` in the tree with a priority derived from ``error``."""
        self.tree.add(self._getPriority(error), sample)

    def sample(self, n):
        """Draw a stratified batch of ``n`` entries.

        Args:
            n: Number of entries to draw.

        Returns:
            A list of ``(idx, data)`` pairs as returned by ``SumTree.get``;
            ``idx`` is the value to pass back to ``update``. An empty list is
            returned when ``n <= 0`` (previously ``n == 0`` divided by zero).
        """
        if n <= 0:
            return []

        batch = []
        segment = self.tree.total() / n
        for i in range(n):
            # One uniform draw inside the i-th equal slice of total priority.
            lo = segment * i
            hi = segment * (i + 1)
            s = random.uniform(lo, hi)
            idx, _priority, data = self.tree.get(s)
            batch.append((idx, data))
        return batch

    def update(self, idx, error):
        """Re-prioritize the tree entry ``idx`` from a new ``error``."""
        self.tree.update(idx, self._getPriority(error))