environment test

This commit is contained in:
2023-10-26 01:43:56 +08:00
commit 49529a9400
14 changed files with 554 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
from recsim import document
from gym import spaces
import numpy as np
class FlashcardDocument(document.AbstractDocument):
def __init__(self, doc_id, difficulty):
self.base_difficulty = difficulty
# doc_id is an integer representing the unique ID of this document
super(FlashcardDocument, self).__init__(doc_id)
def create_observation(self):
return np.array(self.base_difficulty)
@staticmethod
def observation_space():
return spaces.Box(shape=(1,3), dtype=np.float32, low=0.0, high=1.0)
def __str__(self):
return "Flashcard {} with difficulty {}.".format(self._doc_id, self.base_difficulty)

View File

@@ -0,0 +1,14 @@
from .FlashcardDocument import FlashcardDocument
from recsim import document
class FlashcardDocumentSampler(document.AbstractDocumentSampler):
def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
self._doc_count = 0
def sample_document(self):
doc_features = {}
doc_features['doc_id'] = self._doc_count
doc_features['difficulty'] = self._rng.random_sample((1, 3))
self._doc_count += 1
return self._doc_ctor(**doc_features)

1
document/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .FlashcardDocumentSampler import FlashcardDocumentSampler