diff --git a/document/FlashcardDocumentSampler.py b/document/FlashcardDocumentSampler.py index 3cdf9d1..a8a1f96 100644 --- a/document/FlashcardDocumentSampler.py +++ b/document/FlashcardDocumentSampler.py @@ -2,13 +2,13 @@ from .FlashcardDocument import FlashcardDocument from recsim import document class FlashcardDocumentSampler(document.AbstractDocumentSampler): - def __init__(self, doc_ctor=FlashcardDocument, **kwargs): - super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs) + def __init__(self, doc_ctor=FlashcardDocument, seed=0, **kwargs): + super(FlashcardDocumentSampler, self).__init__(doc_ctor, seed, **kwargs) self._doc_count = 0 def sample_document(self): doc_features = {} doc_features['doc_id'] = self._doc_count - doc_features['difficulty'] = self._rng.uniform(0, 5, (1, 3)) + doc_features['difficulty'] = self._rng.uniform(0, 3, (1, 3)) self._doc_count += 1 return self._doc_ctor(**doc_features) \ No newline at end of file diff --git a/main.py b/main.py index 7cef75c..fe9ca21 100644 --- a/main.py +++ b/main.py @@ -17,8 +17,8 @@ tf.compat.v1.disable_eager_execution() create_agent_fn = create_agent_helper(full_slate_q_agent.FullSlateQAgent) ltsenv = environment.Environment( - FlashcardUserModel(num_candidates, time_budget, slate_size), - FlashcardDocumentSampler(), + FlashcardUserModel(num_candidates, time_budget, slate_size, seed=0, sample_seed=0), + FlashcardDocumentSampler(seed=0), num_candidates, slate_size, resample_documents=False) diff --git a/user/FlashcardUserModel.py b/user/FlashcardUserModel.py index 0ef3e06..959d306 100644 --- a/user/FlashcardUserModel.py +++ b/user/FlashcardUserModel.py @@ -7,13 +7,14 @@ from util import eval_result import numpy as np class FlashcardUserModel(user.AbstractUserModel): - def __init__(self, num_candidates, time_budget, slate_size, seed=0): + def __init__(self, num_candidates, time_budget, slate_size, seed=0, sample_seed=0): super(FlashcardUserModel, self).__init__( UserResponse, UserSampler( UserState, num_candidates, time_budget, - seed=seed + seed=sample_seed ), slate_size) self.choice_model = MultinomialLogitChoiceModel({}) + self._rng = np.random.RandomState(seed) def is_terminal(self): terminated = self._user_state._time > self._user_state._time_budget @@ -52,7 +53,8 @@ class FlashcardUserModel(user.AbstractUserModel): doc_id = doc._doc_id W = self._user_state._W[doc_id] if not W.any(): # uninitialzed - self._user_state._W[doc_id] = W = doc.base_difficulty * np.random.uniform(0.5, 2.0, (1, 3)) # a uniform error for each user + error = self._user_state._doc_error[doc_id] # a uniform error for each user + self._user_state._W[doc_id] = W = doc.base_difficulty * error print(W) # use exponential function to simulate whether the user recalls last_review = self._user_state._time - self._user_state._last_review[doc_id] @@ -60,6 +62,6 @@ class FlashcardUserModel(user.AbstractUserModel): pr = np.exp(-last_review / np.exp(np.dot(W, x))).squeeze() print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}") - if np.random.rand() < pr: # remembered + if self._rng.random_sample() < pr: # remembered response._recall = True response._pr = pr \ No newline at end of file diff --git a/user/UserSampler.py b/user/UserSampler.py index fc1613b..fa282dc 100644 --- a/user/UserSampler.py +++ b/user/UserSampler.py @@ -7,9 +7,13 @@ class UserSampler(user.AbstractUserSampler): num_candidates=10, time_budget=60, **kwargs): - self._state_parameters = {'num_candidates': num_candidates, 'time_budget': time_budget} super(UserSampler, self).__init__(user_ctor, **kwargs) - + doc_error = self._rng.uniform(0.5, 1.5, (num_candidates, 3)) + self._state_parameters = { + 'num_candidates': num_candidates, + 'time_budget': time_budget, + 'doc_error': doc_error + } def sample_user(self): return self._user_ctor(**self._state_parameters) \ No newline at end of file diff --git a/user/UserState.py b/user/UserState.py index c253668..f554f65 100644 --- a/user/UserState.py +++ b/user/UserState.py @@ -3,13 +3,14 @@ import numpy as np from gym import spaces class UserState(user.AbstractUserState): - def __init__(self, num_candidates, time_budget): + def __init__(self, num_candidates, time_budget, doc_error): self._cards = num_candidates self._history = np.zeros((num_candidates, 3)) self._last_review = np.repeat(-1.0, num_candidates) self._time_budget = time_budget self._time = 0 self._W = np.zeros((num_candidates, 3)) + self._doc_error = doc_error super(UserState, self).__init__() def create_observation(self): return {'history': self._history, 'last_review': self._last_review, 'time': self._time, 'time_budget': self._time_budget}