environment test
This commit is contained in:
65
user/FlashcardUserModel.py
Normal file
65
user/FlashcardUserModel.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from recsim import user
|
||||
from recsim.choice_model import MultinomialLogitChoiceModel
|
||||
from .UserState import UserState
|
||||
from .UserSampler import UserSampler
|
||||
from .UserResponse import UserResponse
|
||||
from util import eval_result
|
||||
import numpy as np
|
||||
|
||||
class FlashcardUserModel(user.AbstractUserModel):
    """RecSim user model of a learner reviewing flashcards.

    The per-card state kept on the sampled ``UserState`` is:

    * ``_history[doc_id]`` -- three counters: times shown, times recalled,
      times forgotten (see ``update_state``).
    * ``_last_review[doc_id]`` -- value of ``_time`` when the card was last
      shown.
    * ``_W[doc_id]`` -- three weights that turn the history counters into a
      memory strength (see ``_generate_response``).
    """

    def __init__(self, num_candidates, time_budget, slate_size, seed=0):
        """Build the model with a ``UserSampler`` that yields ``UserState``s.

        Args:
            num_candidates: number of flashcards; forwarded to the sampler.
            time_budget: forwarded to the sampler; the episode ends once the
                user's ``_time`` exceeds it.
            slate_size: number of documents per slate, passed to the base
                class.
            seed: forwarded to the sampler as a keyword argument.
        """
        super().__init__(
            UserResponse,
            UserSampler(UserState, num_candidates, time_budget, seed=seed),
            slate_size,
        )
        self.choice_model = MultinomialLogitChoiceModel({})

    def is_terminal(self):
        """Return True once the user's time exceeds the time budget.

        When the episode has terminated this also calls ``eval_result`` with
        the current time and copies of the last-review, history and weight
        arrays, so the evaluation runs on every call made after termination.
        """
        state = self._user_state
        terminated = state._time > state._time_budget
        if terminated:  # run evaluation process
            eval_result(
                state._time,
                state._last_review.copy(),
                state._history.copy(),
                state._W.copy(),
            )
        return terminated

    def update_state(self, slate_documents, responses):
        """Fold the responses for one slate into the user state.

        Every document in the slate gets its "shown" counter incremented,
        then either its "recalled" or its "forgotten" counter depending on
        ``response._recall``, and its last-review time set to the current
        time. Finally the clock advances.

        Args:
            slate_documents: the documents that were presented.
            responses: one response per document, in the same order.
        """
        state = self._user_state
        for doc, response in zip(slate_documents, responses):
            doc_id = doc._doc_id
            counters = state._history[doc_id]  # row view; edits hit _history
            counters[0] += 1
            if response._recall:
                counters[1] += 1
            else:
                counters[2] += 1
            state._last_review[doc_id] = state._time
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether the clock ticked once per document or once per slate. One
        # tick per slate is used here; both readings agree for slate_size 1.
        state._time += 1

    def simulate_response(self, slate_documents):
        """Simulate the user's response to a slate.

        A default response is created for every document; the choice model
        then picks one document, and only that document's response is
        filled in by ``_generate_response``.

        Args:
            slate_documents: the documents presented to the user.

        Returns:
            A list of responses, one per document in ``slate_documents``.
        """
        responses = [self._response_model_ctor() for _ in slate_documents]
        # Get the selected item from the choice model.
        self.choice_model.score_documents(
            self._user_state,
            [doc.create_observation() for doc in slate_documents],
        )
        selected_index = self.choice_model.choose_item()
        # The choice model reports "nothing selected" as None; indexing the
        # slate with None would raise TypeError, so leave defaults in place.
        if selected_index is None:
            return responses
        # Populate the selected item.
        self._generate_response(
            slate_documents[selected_index], responses[selected_index]
        )
        return responses

    def _generate_response(self, doc, response):
        """Decide whether the user recalls ``doc`` and record it on ``response``.

        The recall probability is ``exp(-elapsed / exp(W . x))`` where
        ``elapsed`` is the time since the card was last reviewed, ``x`` is
        the card's history counters and ``W`` the card's weights.

        Args:
            doc: the selected document; must expose ``_doc_id`` and
                ``base_difficulty``.
            response: the response object to fill in (``_recall``, ``_pr``).
        """
        state = self._user_state
        doc_id = doc._doc_id
        W = state._W[doc_id]
        if not W.any():  # uninitialized (all-zero row)
            # Base difficulty plus a uniform error for each user.
            state._W[doc_id] = W = doc.base_difficulty + np.random.uniform(
                -0.5, 0.5, (1, 3)
            )
            # NOTE(review): original indentation lost; assumed to print only
            # when the weights are first drawn.
            print(W)
        # Use an exponential function to simulate whether the user recalls.
        elapsed = state._time - state._last_review[doc_id]
        x = state._history[doc_id]

        # squeeze() collapses the (1,)-shaped result produced when W is the
        # freshly drawn (1, 3) array.
        pr = np.exp(-elapsed / np.exp(np.dot(W, x))).squeeze()
        print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}")
        if np.random.rand() < pr:  # remembered
            response._recall = True
        # NOTE(review): original indentation lost; the probability is
        # recorded whether or not the card was recalled.
        response._pr = pr
|
||||
15
user/UserResponse.py
Normal file
15
user/UserResponse.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from recsim import user
|
||||
from gym import spaces
|
||||
|
||||
class UserResponse(user.AbstractResponse):
    """Response of the user to a single flashcard.

    Attributes are set by the user model after construction:
    ``_recall`` says whether the card was remembered and ``_pr`` holds the
    recall probability that was used for the draw.
    """

    def __init__(self, recall=False, pr=0):
        """Store the recall flag and the recall probability.

        Args:
            recall: whether the user recalled the card.
            pr: recall probability associated with this response.
        """
        self._recall = recall
        self._pr = pr

    def create_observation(self):
        """Return the response as a dict with keys ``'recall'`` and ``'pr'``."""
        return {'recall': int(self._recall), 'pr': self._pr}

    @classmethod
    def response_space(cls):
        """Return the Gym space describing ``create_observation``'s output."""
        return spaces.Dict({
            'recall': spaces.Discrete(2),
            # Scalar bounds need an explicit shape: without it some Gym
            # releases cannot infer the box shape and fail at construction.
            'pr': spaces.Box(low=0.0, high=1.0, shape=(1,)),
        })
|
||||
15
user/UserSampler.py
Normal file
15
user/UserSampler.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .UserState import UserState
|
||||
from recsim import user
|
||||
|
||||
class UserSampler(user.AbstractUserSampler):
    """Sampler that builds user states from a fixed set of parameters."""

    def __init__(self,
                 user_ctor=UserState,
                 num_candidates=10,
                 time_budget=60,
                 **kwargs):
        """Remember the state parameters and initialise the base sampler.

        Args:
            user_ctor: callable used to construct each sampled user state.
            num_candidates: passed to ``user_ctor`` as ``num_candidates``.
            time_budget: passed to ``user_ctor`` as ``time_budget``.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        self._state_parameters = dict(
            num_candidates=num_candidates,
            time_budget=time_budget,
        )
        super().__init__(user_ctor, **kwargs)

    def sample_user(self):
        """Construct and return a new user state from the stored parameters."""
        state_kwargs = self._state_parameters
        return self._user_ctor(**state_kwargs)
|
||||
26
user/UserState.py
Normal file
26
user/UserState.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from recsim import user
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
class UserState(user.AbstractUserState):
    """State of a flashcard learner: per-card history, timing and weights."""

    def __init__(self, num_candidates, time_budget):
        """Create a fresh state with every per-card array zeroed.

        Args:
            num_candidates: number of flashcards tracked by this state.
            time_budget: stored as ``_time_budget``.
        """
        self._cards = num_candidates
        self._time = 0
        self._time_budget = time_budget
        # One row per card, three columns each.
        per_card_triplet = (num_candidates, 3)
        self._history = np.zeros(per_card_triplet)
        self._W = np.zeros(per_card_triplet)
        self._last_review = np.zeros((num_candidates,))
        super().__init__()

    def create_observation(self):
        """Return history, last-review times, time and time budget as a dict."""
        observation = {
            'history': self._history,
            'last_review': self._last_review,
            'time': self._time,
            'time_budget': self._time_budget,
        }
        return observation

    def observation_space(self):
        """Return the Gym space matching ``create_observation``.

        NOTE(review): this is an instance method because the shapes depend
        on ``self._cards``; the original author was unsure whether the
        framework accepts that -- confirm against the base class.
        """
        def non_negative_box(shape):
            # Every entry is declared as an integer in [0, inf).
            return spaces.Box(shape=shape, low=0, high=np.inf, dtype=int)

        return spaces.Dict({
            'history': non_negative_box((self._cards, 3)),
            'last_review': non_negative_box((self._cards,)),
            'time': non_negative_box((1,)),
            'time_budget': non_negative_box((1,)),
        })

    def score_document(self, doc_obs):
        """Return the same score, 1, for every document observation."""
        return 1
|
||||
2
user/__init__.py
Normal file
2
user/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .FlashcardUserModel import FlashcardUserModel
|
||||
from .UserResponse import UserResponse
|
||||
Reference in New Issue
Block a user