environment test
commit 49529a9400
agent/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .util import create_create_agent
agent/util.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from recsim.agents import full_slate_q_agent

def create_create_agent(agent=full_slate_q_agent.FullSlateQAgent):
    def create_agent(sess, environment, eval_mode, summary_writer=None):
        kwargs = {
            'observation_space': environment.observation_space,
            'action_space': environment.action_space,
            'summary_writer': summary_writer,
            'eval_mode': eval_mode,
        }
        return agent(sess, **kwargs)
    return create_agent
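
Note: a quick usage sketch of the factory above, mirroring what main.py later in this commit does (assuming the default FullSlateQAgent):

from agent import create_create_agent
from recsim.agents import full_slate_q_agent

# create_create_agent returns a create_agent(sess, environment, eval_mode,
# summary_writer) callable that runner_lib can use to build the agent.
create_agent_fn = create_create_agent(full_slate_q_agent.FullSlateQAgent)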
document/FlashcardDocument.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from recsim import document
from gym import spaces
import numpy as np

class FlashcardDocument(document.AbstractDocument):
    def __init__(self, doc_id, difficulty):
        self.base_difficulty = difficulty
        # doc_id is an integer representing the unique ID of this document
        super(FlashcardDocument, self).__init__(doc_id)

    def create_observation(self):
        return np.array(self.base_difficulty)

    @staticmethod
    def observation_space():
        return spaces.Box(shape=(1, 3), dtype=np.float32, low=0.0, high=1.0)

    def __str__(self):
        return "Flashcard {} with difficulty {}.".format(self._doc_id, self.base_difficulty)
document/FlashcardDocumentSampler.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from .FlashcardDocument import FlashcardDocument
from recsim import document

class FlashcardDocumentSampler(document.AbstractDocumentSampler):
    def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
        super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
        self._doc_count = 0

    def sample_document(self):
        doc_features = {}
        doc_features['doc_id'] = self._doc_count
        doc_features['difficulty'] = self._rng.random_sample((1, 3))
        self._doc_count += 1
        return self._doc_ctor(**doc_features)
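
Note: a minimal sketch of exercising the sampler (hypothetical snippet, not part of the commit):

sampler = FlashcardDocumentSampler()
doc = sampler.sample_document()
# doc ids count up from 0; difficulty is a (1, 3) draw from the sampler's
# RNG, matching FlashcardDocument.observation_space()
print(doc)  # "Flashcard 0 with difficulty [[...]]."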
document/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .FlashcardDocumentSampler import FlashcardDocumentSampler
main.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import tensorflow as tf
from recsim.simulator import environment
from user import FlashcardUserModel
from document import FlashcardDocumentSampler
from recsim.simulator import recsim_gym
from recsim.agents import full_slate_q_agent
from recsim.simulator import runner_lib
from agent import create_create_agent
from util import reward, update_metrics

slate_size = 1
num_candidates = 10
time_budget = 60

tf.compat.v1.disable_eager_execution()

create_agent_fn = create_create_agent(full_slate_q_agent.FullSlateQAgent)

ltsenv = environment.Environment(
    FlashcardUserModel(num_candidates, time_budget, slate_size),
    FlashcardDocumentSampler(),
    num_candidates,
    slate_size,
    resample_documents=False)

lts_gym_env = recsim_gym.RecSimGymEnv(ltsenv, reward, update_metrics)
lts_gym_env.reset()

tmp_base_dir = './recsim/'
runner = runner_lib.TrainRunner(
    base_dir=tmp_base_dir,
    create_agent_fn=create_agent_fn,
    env=lts_gym_env,
    episode_log_file="",
    max_training_steps=5,
    num_iterations=1
)

runner.run_experiment()
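
Note: a follow-up evaluation pass could reuse the same pieces; a sketch assuming RecSim's EvalRunner API from its tutorial (not part of this commit):

eval_runner = runner_lib.EvalRunner(
    base_dir=tmp_base_dir,
    create_agent_fn=create_agent_fn,
    env=lts_gym_env,
    max_eval_episodes=5,
    test_mode=True)
eval_runner.run_experiment()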
recsim_environment.py (new file, 309 lines)
@@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""RecSim Environment

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1KJbwKa0URSOU9B7GsDAkYOoFAoU5g14Y
"""

!pip install --upgrade --no-cache-dir recsim

#@title Generic imports
import numpy as np
from gym import spaces
import matplotlib.pyplot as plt
from scipy import stats

#@title RecSim imports
from recsim import document
from recsim import user
from recsim.choice_model import MultinomialLogitChoiceModel
from recsim.simulator import environment
from recsim.simulator import recsim_gym

# disable eager execution to avoid error
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

"""# Flashcard Learning Environment Build

## Documents (Flashcards)
- difficulty (w)
- deadline
- other features?

### Document Model
### Sampler

## Users
### User State and Transition
**static**
- learning ability

**dynamic**
- recall history (#correct, #wrong)

### Sampler

### User Choice Model
- user has no choice but to review the card the agent provides

### User Response
- user's self evaluation (remember or not) -> update history

## Reward (From User Response)
- gain = maximum additional retention rate if the card is chosen
- time factor = α * sqrt(ln δ / n_t)
"""

slate_size = 1
num_candidates = 10

class FlashcardDocument(document.AbstractDocument):
    def __init__(self, doc_id, difficulty):
        self.base_difficulty = difficulty
        # doc_id is an integer representing the unique ID of this document
        super(FlashcardDocument, self).__init__(doc_id)

    def create_observation(self):
        return np.array(self.base_difficulty)

    @staticmethod
    def observation_space():
        return spaces.Box(shape=(1, 3), dtype=np.float32, low=0.0, high=1.0)

    def __str__(self):
        return "Flashcard {} with difficulty {}.".format(self._doc_id, self.base_difficulty)


class FlashcardDocumentSampler(document.AbstractDocumentSampler):
    def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
        super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
        self._doc_count = 0

    def sample_document(self):
        doc_features = {}
        doc_features['doc_id'] = self._doc_count
        doc_features['difficulty'] = self._rng.random_sample((1, 3))
        self._doc_count += 1
        return self._doc_ctor(**doc_features)
class UserState(user.AbstractUserState):
    def __init__(self, num_candidates, time_budget):
        self._cards = num_candidates
        self._history = np.zeros((num_candidates, 3))
        self._last_review = np.zeros((num_candidates,))
        self._time_budget = time_budget
        self._time = 0
        self._W = np.zeros((num_candidates, 3))
        super(UserState, self).__init__()

    def create_observation(self):
        return {'history': self._history, 'last_review': self._last_review,
                'time': self._time, 'time_budget': self._time_budget}

    @staticmethod
    def observation_space():
        # reads the module-level num_candidates: a static method has no
        # access to self._cards
        return spaces.Dict({
            'history': spaces.Box(shape=(num_candidates, 3), low=0, high=np.inf, dtype=int),
            'last_review': spaces.Box(shape=(num_candidates,), low=0, high=np.inf, dtype=int),
            'time': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
            'time_budget': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
        })

    def score_document(self, doc_obs):
        return 1


class UserSampler(user.AbstractUserSampler):
    _state_parameters = {'num_candidates': num_candidates, 'time_budget': 60}

    def __init__(self,
                 user_ctor=UserState,
                 **kwargs):
        # self._state_parameters = {'num_candidates': num_candidates}
        super(UserSampler, self).__init__(user_ctor, **kwargs)

    def sample_user(self):
        return self._user_ctor(**self._state_parameters)


sampler = UserSampler()
# for i in range(10):
u = sampler.sample_user()
u.observation_space()

class UserResponse(user.AbstractResponse):
    def __init__(self, recall=False, pr=0):
        self._recall = recall
        self._pr = pr

    def create_observation(self):
        return {'recall': int(self._recall), 'pr': self._pr}

    @classmethod
    def response_space(cls):
        # return spaces.Discrete(2)
        return spaces.Dict({'recall': spaces.Discrete(2), 'pr': spaces.Box(low=0.0, high=1.0)})

"""# Evaluation

Call `eval_result()` to evaluate the agent's performance. This function should live outside the RecSim structure so that it does not change the training status.
"""

from datetime import datetime

def eval_result(train_time, last_review, history, W):
    with open(f"{datetime.now()}.txt", "w") as f:
        print(train_time, file=f)
        print(last_review, file=f)
        print(history, file=f)
        print(W, file=f)
        # np.einsum('ij,ij->i', a, b)
        last_review = train_time - last_review
        mem_param = np.exp(np.einsum('ij,ij->i', history, W))
        pr = np.exp(-last_review / mem_param)
        print(pr, file=f)
        print(pr)
        print("score:", np.sum(pr) / pr.shape[0], file=f)
        print("score:", np.sum(pr) / pr.shape[0])

class FlashcardUserModel(user.AbstractUserModel):
    def __init__(self, slate_size, seed=0):
        super(FlashcardUserModel, self).__init__(
            UserResponse, UserSampler(
                UserState, seed=seed
            ), slate_size)
        self.choice_model = MultinomialLogitChoiceModel({})

    def is_terminal(self):
        terminated = self._user_state._time > self._user_state._time_budget
        if terminated:  # run evaluation process
            eval_result(self._user_state._time,
                        self._user_state._last_review.copy(),
                        self._user_state._history.copy(),
                        self._user_state._W.copy())
        return terminated

    def update_state(self, slate_documents, responses):
        for doc, response in zip(slate_documents, responses):
            doc_id = doc._doc_id
            self._user_state._history[doc_id][0] += 1
            if response._recall:
                self._user_state._history[doc_id][1] += 1
            else:
                self._user_state._history[doc_id][2] += 1
            self._user_state._last_review[doc_id] = self._user_state._time
            self._user_state._time += 1

    def simulate_response(self, slate_documents):
        responses = [self._response_model_ctor() for _ in slate_documents]
        # Get click from choice model.
        self.choice_model.score_documents(
            self._user_state, [doc.create_observation() for doc in slate_documents])
        scores = self.choice_model.scores
        selected_index = self.choice_model.choose_item()
        # Populate clicked item.
        self._generate_response(slate_documents[selected_index],
                                responses[selected_index])
        return responses

    def _generate_response(self, doc, response):
        # W = np.array([1,1,1])
        doc_id = doc._doc_id
        W = self._user_state._W[doc_id]
        if not W.any():  # uninitialized
            self._user_state._W[doc_id] = W = doc.base_difficulty + np.random.uniform(-1, 1, (1, 3))  # a uniform error for each user
            print(W)
        # use an exponential function to simulate whether the user recalls
        last_review = self._user_state._time - self._user_state._last_review[doc_id]
        x = self._user_state._history[doc_id]

        pr = np.exp(-last_review / np.exp(np.dot(W, x))).squeeze()
        print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}")
        if np.random.rand() < pr:  # remembered
            response._recall = True
        response._pr = pr
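
# Note (annotation, not part of the original cell): score_document() returns 1
# for every card, so MultinomialLogitChoiceModel scores all slate items equally
# and choose_item() picks uniformly at random; with slate_size = 1 the single
# recommended card is always the one reviewed.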

ltsenv = environment.Environment(
    FlashcardUserModel(slate_size),
    FlashcardDocumentSampler(),
    num_candidates,
    slate_size,
    resample_documents=False)

def reward(responses):
    reward = 0.0
    for response in responses:
        reward += int(response._recall)
    return reward

def update_metrics(responses, metrics, info):
    # print("responses: ", responses)
    prs = []
    for response in responses:
        prs.append(response['pr'])
    if not isinstance(metrics, list):
        metrics = [prs]
    else:
        metrics.append(prs)
    # print(metrics)
    return metrics

observation = ltsenv.reset()
# user - history (n, n+, n-)
print("Observation space of user:")
print(u.observation_space(), '\n')
print("User history:")
print(observation[0]['history'], '\n')
# user - last review time of each card
print("User last_review:")
print(observation[0]['last_review'], '\n')
# user - current time (you can get the delta by time - last_review)
print("User time:")
print(observation[0]['time'], '\n')
# user - time budget (deadline)
print("User time budget:")
print(observation[0]['time_budget'])

# ltsenv.reset()
lts_gym_env = recsim_gym.RecSimGymEnv(ltsenv, reward, update_metrics)
lts_gym_env.reset()

try_observation = lts_gym_env.reset()

for i in range(len(try_observation['doc'])):
    print(try_observation['user']['history'][i])

#print(try_observation['user']['history'].shape[0])

# scratch cell: index of the maximum element in a list
my_list = [10.0, 5.5, 8.1, 2.0, 1.57]
max_value = max(my_list)
print(my_list.index(max(my_list)))

def create_agent(sess, environment, eval_mode, summary_writer=None):
    kwargs = {
        'observation_space': environment.observation_space,
        'action_space': environment.action_space,
        'summary_writer': summary_writer,
        'eval_mode': eval_mode,
    }
    return full_slate_q_agent.FullSlateQAgent(sess, **kwargs)

#@title Importing RecSim components
from recsim.environments import interest_evolution
from recsim.agents import full_slate_q_agent
from recsim.simulator import runner_lib

tmp_base_dir = '/tmp/recsim/'
runner = runner_lib.TrainRunner(
    base_dir=tmp_base_dir,
    create_agent_fn=create_agent,
    env=lts_gym_env,
    episode_log_file="",
    max_training_steps=5,
    num_iterations=1
)

runner.run_experiment()

# Commented out IPython magic to ensure Python compatibility.
# Load the TensorBoard notebook extension
# %load_ext tensorboard
#@title Tensorboard
# %tensorboard --logdir=/tmp/recsim/
user/FlashcardUserModel.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from recsim import user
from recsim.choice_model import MultinomialLogitChoiceModel
from .UserState import UserState
from .UserSampler import UserSampler
from .UserResponse import UserResponse
from util import eval_result
import numpy as np

class FlashcardUserModel(user.AbstractUserModel):
    def __init__(self, num_candidates, time_budget, slate_size, seed=0):
        super(FlashcardUserModel, self).__init__(
            UserResponse, UserSampler(
                UserState, num_candidates, time_budget,
                seed=seed
            ), slate_size)
        self.choice_model = MultinomialLogitChoiceModel({})

    def is_terminal(self):
        terminated = self._user_state._time > self._user_state._time_budget
        if terminated:  # run evaluation process
            eval_result(self._user_state._time,
                        self._user_state._last_review.copy(),
                        self._user_state._history.copy(),
                        self._user_state._W.copy())
        return terminated

    def update_state(self, slate_documents, responses):
        for doc, response in zip(slate_documents, responses):
            doc_id = doc._doc_id
            self._user_state._history[doc_id][0] += 1
            if response._recall:
                self._user_state._history[doc_id][1] += 1
            else:
                self._user_state._history[doc_id][2] += 1
            self._user_state._last_review[doc_id] = self._user_state._time
            self._user_state._time += 1

    def simulate_response(self, slate_documents):
        responses = [self._response_model_ctor() for _ in slate_documents]
        # Get click from choice model.
        self.choice_model.score_documents(
            self._user_state, [doc.create_observation() for doc in slate_documents])
        scores = self.choice_model.scores
        selected_index = self.choice_model.choose_item()
        # Populate clicked item.
        self._generate_response(slate_documents[selected_index],
                                responses[selected_index])
        return responses

    def _generate_response(self, doc, response):
        # W = np.array([1,1,1])
        doc_id = doc._doc_id
        W = self._user_state._W[doc_id]
        if not W.any():  # uninitialized
            self._user_state._W[doc_id] = W = doc.base_difficulty + np.random.uniform(-0.5, 0.5, (1, 3))  # a uniform error for each user
            print(W)
        # use an exponential function to simulate whether the user recalls
        last_review = self._user_state._time - self._user_state._last_review[doc_id]
        x = self._user_state._history[doc_id]

        pr = np.exp(-last_review / np.exp(np.dot(W, x))).squeeze()
        print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}")
        if np.random.rand() < pr:  # remembered
            response._recall = True
        response._pr = pr
user/UserResponse.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from recsim import user
from gym import spaces

class UserResponse(user.AbstractResponse):
    def __init__(self, recall=False, pr=0):
        self._recall = recall
        self._pr = pr

    def create_observation(self):
        return {'recall': int(self._recall), 'pr': self._pr}

    @classmethod
    def response_space(cls):
        # return spaces.Discrete(2)
        return spaces.Dict({'recall': spaces.Discrete(2), 'pr': spaces.Box(low=0.0, high=1.0)})
user/UserSampler.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from .UserState import UserState
from recsim import user

class UserSampler(user.AbstractUserSampler):
    def __init__(self,
                 user_ctor=UserState,
                 num_candidates=10,
                 time_budget=60,
                 **kwargs):
        self._state_parameters = {'num_candidates': num_candidates, 'time_budget': time_budget}
        super(UserSampler, self).__init__(user_ctor, **kwargs)

    def sample_user(self):
        return self._user_ctor(**self._state_parameters)
user/UserState.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from recsim import user
import numpy as np
from gym import spaces

class UserState(user.AbstractUserState):
    def __init__(self, num_candidates, time_budget):
        self._cards = num_candidates
        self._history = np.zeros((num_candidates, 3))
        self._last_review = np.zeros((num_candidates,))
        self._time_budget = time_budget
        self._time = 0
        self._W = np.zeros((num_candidates, 3))
        super(UserState, self).__init__()

    def create_observation(self):
        return {'history': self._history, 'last_review': self._last_review,
                'time': self._time, 'time_budget': self._time_budget}

    def observation_space(self):  # can this work?
        return spaces.Dict({
            'history': spaces.Box(shape=(self._cards, 3), low=0, high=np.inf, dtype=int),
            'last_review': spaces.Box(shape=(self._cards,), low=0, high=np.inf, dtype=int),
            'time': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
            'time_budget': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
        })

    def score_document(self, doc_obs):
        return 1
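
Note: unlike the notebook's UserState, observation_space() here is an instance method, so the Box shapes can follow self._cards instead of a module-level num_candidates; the inline "can this work?" flags that RecSim's handling of an instance method for this hook is still unverified.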
user/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .FlashcardUserModel import FlashcardUserModel
from .UserResponse import UserResponse
util/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from util.util import *
util/util.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from datetime import datetime
import numpy as np

def reward(responses):
    reward = 0.0
    for response in responses:
        reward += int(response._recall)
    return reward

def update_metrics(responses, metrics, info):
    # print("responses: ", responses)
    prs = []
    for response in responses:
        prs.append(response['pr'])
    if not isinstance(metrics, list):
        metrics = [prs]
    else:
        metrics.append(prs)
    # print(metrics)
    return metrics

def eval_result(train_time, last_review, history, W):
    with open(f"{datetime.now()}.txt", "w") as f:
        print(train_time, file=f)
        print(last_review, file=f)
        print(history, file=f)
        print(W, file=f)
        # np.einsum('ij,ij->i', a, b)
        last_review = train_time - last_review
        mem_param = np.exp(np.einsum('ij,ij->i', history, W))
        pr = np.exp(-last_review / mem_param)
        print(pr, file=f)
        print(pr)
        print("score:", np.sum(pr) / pr.shape[0], file=f)
        print("score:", np.sum(pr) / pr.shape[0])