Compare commits

1 commit

| Author | SHA1 | Date |
|---|---|---|
| | a583855e47 | |
```diff
@@ -1,2 +1 @@
 from .util import create_agent_helper
-from .greedy import GreedyAgent
```
```diff
@@ -1,58 +0,0 @@
-from recsim.agent import AbstractEpisodicRecommenderAgent
-import numpy as np
-
-class GreedyAgent(AbstractEpisodicRecommenderAgent):
-  def __init__(self, sess, observation_space, action_space, eval_mode, summary_writer):
-    super(GreedyAgent, self).__init__(action_space, summary_writer)
-    self._num_candidates = int(action_space.nvec[0])
-    self._W = np.array([[3, 1.5, 0.5]] * self._num_candidates)
-    assert self._slate_size == 1
-  def begin_episode(self, observation=None):
-    user = observation['user']
-    docs = observation['doc']
-    if 'W' in user: # use observable W
-      self._W = user['W']
-    else:
-      w = []
-      for doc_id in docs:
-        w.append(docs[doc_id])
-      self._W = np.array(w).reshape((-1, 3))
-    print("agent W:", self._W)
-    self._episode_num += 1
-    return self.step(0, observation)
-  def step(self, reward, observation):
-    docs = observation['doc']
-    user = observation['user']
-
-    base_pr = self.calc_prs(user['time'], user['last_review'], user['history'], self._W)
-    # np.exp(-last_review / np.exp(np.dot(W, x))).squeeze()
-    max_pr = -self._num_candidates
-    max_id = 0
-    for did in docs:
-      doc_id = int(did)
-      last_review = user['last_review'].copy()
-      history = user['history'].copy()
-      last_review[doc_id] = user['time']
-      time = user['time'] + 1
-
-      history[doc_id][0] += 1
-      history[doc_id][1] += 1
-      pr1 = self.calc_prs(time, last_review, history, self._W)
-      history[doc_id][1] -= 1
-      history[doc_id][2] += 1
-      pr2 = self.calc_prs(time, last_review, history, self._W)
-      pr = (pr1 + pr2) / 2 - base_pr
-      sum_pr = np.sum(pr)
-      if sum_pr > max_pr:
-        max_pr = sum_pr
-        max_id = doc_id
-      # print("pr1", pr1)
-      # print("pr2", pr2)
-      # print("pr0", base_pr)
-    print(f"choose doc{max_id} with marginal gain {max_pr}")
-    return [max_id]
-  def calc_prs(self, train_time, last_review, history, W):
-    last_review = train_time - last_review
-    mem_param = np.exp(np.einsum('ij,ij->i', history, W))
-    pr = np.exp(-last_review / mem_param)
-    return pr
```
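For readability, the recall model implemented by the removed `GreedyAgent.calc_prs` (and mirrored by the user simulator's `pr` computation further down) can be stated on its own: a card's recall probability decays exponentially with the time since its last review, with memory strength `exp(history · W)`. The agent's `step()` then simulates one review of each candidate card, averages the recalled and forgotten outcomes, and picks the card whose summed recall probability gains the most over the no-review baseline. A minimal sketch, with illustrative inputs that are not from the repository:

```python
import numpy as np

# Standalone restatement of GreedyAgent.calc_prs: the probability of recalling
# card i after a gap of (t - last_review_i) time steps is
# exp(-(t - last_review_i) / exp(history_i . W_i)).
def calc_prs(train_time, last_review, history, W):
    elapsed = train_time - last_review                     # time since each card was last reviewed
    mem_param = np.exp(np.einsum('ij,ij->i', history, W))  # per-card memory strength
    return np.exp(-elapsed / mem_param)

# Illustrative inputs (not taken from the repository): 4 cards, history columns
# [reviews, correct, wrong], and the agent's default difficulty weights.
history = np.array([[2., 1., 1.], [1., 1., 0.], [0., 0., 0.], [3., 3., 0.]])
last_review = np.array([5.0, 8.0, -1.0, 9.0])
W = np.array([[3, 1.5, 0.5]] * 4)
print(calc_prs(10.0, last_review, history, W))
```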
```diff
@@ -2,13 +2,13 @@ from .FlashcardDocument import FlashcardDocument
 from recsim import document
 
 class FlashcardDocumentSampler(document.AbstractDocumentSampler):
-  def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
-    super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
+  def __init__(self, doc_ctor=FlashcardDocument, seed=0, **kwargs):
+    super(FlashcardDocumentSampler, self).__init__(doc_ctor, seed, **kwargs)
     self._doc_count = 0
 
   def sample_document(self):
     doc_features = {}
     doc_features['doc_id'] = self._doc_count
-    doc_features['difficulty'] = self._rng.uniform(0, 5, (1, 3))
+    doc_features['difficulty'] = self._rng.uniform(0, 3, (1, 3))
     self._doc_count += 1
     return self._doc_ctor(**doc_features)
```
main.py
```diff
@@ -17,8 +17,8 @@ tf.compat.v1.disable_eager_execution()
 create_agent_fn = create_agent_helper(full_slate_q_agent.FullSlateQAgent)
 
 ltsenv = environment.Environment(
-    FlashcardUserModel(num_candidates, time_budget, slate_size),
-    FlashcardDocumentSampler(),
+    FlashcardUserModel(num_candidates, time_budget, slate_size, seed=0, sample_seed=0),
+    FlashcardDocumentSampler(seed=0),
     num_candidates,
     slate_size,
     resample_documents=False)
```
```diff
@@ -7,13 +7,14 @@ from util import eval_result
 import numpy as np
 
 class FlashcardUserModel(user.AbstractUserModel):
-  def __init__(self, num_candidates, time_budget, slate_size, seed=0):
+  def __init__(self, num_candidates, time_budget, slate_size, seed=0, sample_seed=0):
     super(FlashcardUserModel, self).__init__(
         UserResponse, UserSampler(
             UserState, num_candidates, time_budget,
-            seed=seed
+            seed=sample_seed
         ), slate_size)
     self.choice_model = MultinomialLogitChoiceModel({})
+    self._rng = np.random.RandomState(seed)
 
   def is_terminal(self):
     terminated = self._user_state._time > self._user_state._time_budget
@@ -52,7 +53,8 @@ class FlashcardUserModel(user.AbstractUserModel):
     doc_id = doc._doc_id
     W = self._user_state._W[doc_id]
     if not W.any(): # uninitialzed
-      self._user_state._W[doc_id] = W = doc.base_difficulty * np.random.uniform(0.5, 2.0, (1, 3)) # a uniform error for each user
+      error = self._user_state._doc_error[doc_id] # a uniform error for each user
+      self._user_state._W[doc_id] = W = doc.base_difficulty * error
       print(W)
     # use exponential function to simulate whether the user recalls
     last_review = self._user_state._time - self._user_state._last_review[doc_id]
@@ -60,6 +62,6 @@ class FlashcardUserModel(user.AbstractUserModel):
 
     pr = np.exp(-last_review / np.exp(np.dot(W, x))).squeeze()
     print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}")
-    if np.random.rand() < pr: # remembered
+    if self._rng.random_sample() < pr: # remembered
       response._recall = True
       response._pr = pr
```
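A side effect of the constructor change above is that the user model now has two independent random streams: `sample_seed` drives the embedded `UserSampler`, while `seed` drives `self._rng`, which replaces the global `np.random` for the per-step recall draw. A minimal sketch of the idea (names here are illustrative, not from the repository):

```python
import numpy as np

# Two independent RNG streams, as in the updated FlashcardUserModel: one for
# sampling the user state, one for the per-step recall coin flips.
sample_rng = np.random.RandomState(0)  # role of UserSampler's seed (sample_seed)
recall_rng = np.random.RandomState(1)  # role of FlashcardUserModel._rng (seed)

pr = 0.8  # recall probability from the forgetting model
recalled = recall_rng.random_sample() < pr
print("recalled:", recalled)
```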
```diff
@@ -7,9 +7,13 @@ class UserSampler(user.AbstractUserSampler):
                num_candidates=10,
                time_budget=60,
                **kwargs):
-    self._state_parameters = {'num_candidates': num_candidates, 'time_budget': time_budget}
     super(UserSampler, self).__init__(user_ctor, **kwargs)
+
+    doc_error = self._rng.uniform(0.5, 1.5, (num_candidates, 3))
+    self._state_parameters = {
+        'num_candidates': num_candidates,
+        'time_budget': time_budget,
+        'doc_error': doc_error
+    }
 
   def sample_user(self):
     return self._user_ctor(**self._state_parameters)
```
```diff
@@ -3,13 +3,14 @@ import numpy as np
 from gym import spaces
 
 class UserState(user.AbstractUserState):
-  def __init__(self, num_candidates, time_budget):
+  def __init__(self, num_candidates, time_budget, doc_error):
     self._cards = num_candidates
     self._history = np.zeros((num_candidates, 3))
     self._last_review = np.repeat(-1.0, num_candidates)
     self._time_budget = time_budget
     self._time = 0
     self._W = np.zeros((num_candidates, 3))
+    self._doc_error = doc_error
     super(UserState, self).__init__()
   def create_observation(self):
     return {'history': self._history, 'last_review': self._last_review, 'time': self._time, 'time_budget': self._time_budget}
```
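Taken together, the `doc_error` changes move the per-user, per-card noise out of the response step and into the sampled user state: `UserSampler` draws it once from its seeded RNG, `UserState` stores it, and the user model multiplies it into the card's base difficulty on first review. A minimal sketch of that data flow (not repository code; shapes follow the diffs above):

```python
import numpy as np

# Sketch of how doc_error now reaches the response model.
rng = np.random.RandomState(0)
num_candidates = 10

doc_error = rng.uniform(0.5, 1.5, (num_candidates, 3))  # drawn once in UserSampler
base_difficulty = rng.uniform(0, 3, (1, 3))              # drawn per document by FlashcardDocumentSampler
doc_id = 0
W = base_difficulty * doc_error[doc_id]                  # set on the card's first review
print(W)                                                 # difficulty weights used in the recall model
```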