environment test
This commit is contained in:
19
document/FlashcardDocument.py
Normal file
19
document/FlashcardDocument.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from recsim import document
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
class FlashcardDocument(document.AbstractDocument):
|
||||
def __init__(self, doc_id, difficulty):
|
||||
self.base_difficulty = difficulty
|
||||
# doc_id is an integer representing the unique ID of this document
|
||||
super(FlashcardDocument, self).__init__(doc_id)
|
||||
|
||||
def create_observation(self):
|
||||
return np.array(self.base_difficulty)
|
||||
|
||||
@staticmethod
|
||||
def observation_space():
|
||||
return spaces.Box(shape=(1,3), dtype=np.float32, low=0.0, high=1.0)
|
||||
|
||||
def __str__(self):
|
||||
return "Flashcard {} with difficulty {}.".format(self._doc_id, self.base_difficulty)
|
||||
14
document/FlashcardDocumentSampler.py
Normal file
14
document/FlashcardDocumentSampler.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .FlashcardDocument import FlashcardDocument
|
||||
from recsim import document
|
||||
|
||||
class FlashcardDocumentSampler(document.AbstractDocumentSampler):
|
||||
def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
|
||||
super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
|
||||
self._doc_count = 0
|
||||
|
||||
def sample_document(self):
|
||||
doc_features = {}
|
||||
doc_features['doc_id'] = self._doc_count
|
||||
doc_features['difficulty'] = self._rng.random_sample((1, 3))
|
||||
self._doc_count += 1
|
||||
return self._doc_ctor(**doc_features)
|
||||
1
document/__init__.py
Normal file
1
document/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .FlashcardDocumentSampler import FlashcardDocumentSampler
|
||||
Reference in New Issue
Block a user