# RecSim_FlashcardLearning/recsim_environment.py
# -*- coding: utf-8 -*-
"""RecSim Environment
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1KJbwKa0URSOU9B7GsDAkYOoFAoU5g14Y
"""
# !pip install --upgrade --no-cache-dir recsim  # Colab shell command; run in a notebook cell or a shell
#@title Generic imports
import numpy as np
from gym import spaces
import matplotlib.pyplot as plt
from scipy import stats
#@title RecSim imports
from recsim import document
from recsim import user
from recsim.choice_model import MultinomialLogitChoiceModel
from recsim.simulator import environment
from recsim.simulator import recsim_gym
# disable eager execution to avoid errors
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
"""# Flashcard Learning Environment Build
## Documents (Flashcards)
- difficulty (w)
- deadline
- other features?
### Document Model
### Sampler
## Users
### User State and Transition
**static**
- learning ability
**dynamic**
- recall history (#correct, #wrong)
### Sampler
### User Choice Model
- user has no choice but to review the card agent provides
### User Response
- user's self evaluation (remember or not) -> update history
## Reward (From User Response)
- gain = maximum additional retention rate if the card is chosen
- time factor = α * sqrt(lnδ/n_t)
"""
slate_size = 1
num_candidates = 10
class FlashcardDocument(document.AbstractDocument):
  def __init__(self, doc_id, difficulty):
    self.base_difficulty = difficulty
    # doc_id is an integer representing the unique ID of this document
    super(FlashcardDocument, self).__init__(doc_id)

  def create_observation(self):
    return np.array(self.base_difficulty)

  @staticmethod
  def observation_space():
    return spaces.Box(shape=(1, 3), dtype=np.float32, low=0.0, high=1.0)

  def __str__(self):
    return "Flashcard {} with difficulty {}.".format(self._doc_id, self.base_difficulty)
class FlashcardDocumentSampler(document.AbstractDocumentSampler):
  def __init__(self, doc_ctor=FlashcardDocument, **kwargs):
    super(FlashcardDocumentSampler, self).__init__(doc_ctor, **kwargs)
    self._doc_count = 0

  def sample_document(self):
    doc_features = {}
    doc_features['doc_id'] = self._doc_count
    doc_features['difficulty'] = self._rng.random_sample((1, 3))
    self._doc_count += 1
    return self._doc_ctor(**doc_features)
class UserState(user.AbstractUserState):
  def __init__(self, num_candidates, time_budget):
    self._cards = num_candidates
    self._history = np.zeros((num_candidates, 3))
    self._last_review = np.zeros((num_candidates,))
    self._time_budget = time_budget
    self._time = 0
    self._W = np.zeros((num_candidates, 3))
    super(UserState, self).__init__()

  def create_observation(self):
    return {'history': self._history, 'last_review': self._last_review,
            'time': self._time, 'time_budget': self._time_budget}

  @staticmethod
  def observation_space():
    return spaces.Dict({
        'history': spaces.Box(shape=(num_candidates, 3), low=0, high=np.inf, dtype=int),
        'last_review': spaces.Box(shape=(num_candidates,), low=0, high=np.inf, dtype=int),
        'time': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
        'time_budget': spaces.Box(shape=(1,), low=0, high=np.inf, dtype=int),
    })

  def score_document(self, doc_obs):
    return 1
class UserSampler(user.AbstractUserSampler):
  _state_parameters = {'num_candidates': num_candidates, 'time_budget': 60}

  def __init__(self,
               user_ctor=UserState,
               **kwargs):
    # self._state_parameters = {'num_candidates': num_candidates}
    super(UserSampler, self).__init__(user_ctor, **kwargs)

  def sample_user(self):
    return self._user_ctor(**self._state_parameters)

sampler = UserSampler()
# for i in range(10):
u = sampler.sample_user()
u.observation_space()
class UserResponse(user.AbstractResponse):
  def __init__(self, recall=False, pr=0):
    self._recall = recall
    self._pr = pr

  def create_observation(self):
    return {'recall': int(self._recall), 'pr': self._pr}

  @classmethod
  def response_space(cls):
    # return spaces.Discrete(2)
    return spaces.Dict({'recall': spaces.Discrete(2), 'pr': spaces.Box(low=0.0, high=1.0)})
"""# Evaluation
Calling `eval_result()` to evaluate the agent performance. This function should be outside the RecSim structure to avoid changing the training status.
"""
from datetime import datetime

def eval_result(train_time, last_review, history, W):
  with open(f"{datetime.now()}.txt", "w") as f:
    print(train_time, file=f)
    print(last_review, file=f)
    print(history, file=f)
    print(W, file=f)
    # np.einsum('ij,ij->i', a, b)
    last_review = train_time - last_review
    mem_param = np.exp(np.einsum('ij,ij->i', history, W))
    pr = np.exp(-last_review / mem_param)
    print(pr, file=f)
    print(pr)
    print("score:", np.sum(pr) / pr.shape[0], file=f)
    print("score:", np.sum(pr) / pr.shape[0])
class FlashcardUserModel(user.AbstractUserModel):
  def __init__(self, slate_size, seed=0):
    super(FlashcardUserModel, self).__init__(
        UserResponse, UserSampler(
            UserState, seed=seed
        ), slate_size)
    self.choice_model = MultinomialLogitChoiceModel({})

  def is_terminal(self):
    terminated = self._user_state._time > self._user_state._time_budget
    if terminated:  # run evaluation process
      eval_result(self._user_state._time,
                  self._user_state._last_review.copy(),
                  self._user_state._history.copy(),
                  self._user_state._W.copy())
    return terminated

  def update_state(self, slate_documents, responses):
    for doc, response in zip(slate_documents, responses):
      doc_id = doc._doc_id
      self._user_state._history[doc_id][0] += 1  # total reviews
      if response._recall:
        self._user_state._history[doc_id][1] += 1  # correct recalls
      else:
        self._user_state._history[doc_id][2] += 1  # failed recalls
      self._user_state._last_review[doc_id] = self._user_state._time
    self._user_state._time += 1

  def simulate_response(self, slate_documents):
    responses = [self._response_model_ctor() for _ in slate_documents]
    # Get click from choice model.
    self.choice_model.score_documents(
        self._user_state, [doc.create_observation() for doc in slate_documents])
    scores = self.choice_model.scores
    selected_index = self.choice_model.choose_item()
    # Populate clicked item.
    self._generate_response(slate_documents[selected_index],
                            responses[selected_index])
    return responses

  def _generate_response(self, doc, response):
    # W = np.array([1,1,1])
    doc_id = doc._doc_id
    W = self._user_state._W[doc_id]
    if not W.any():  # uninitialized
      self._user_state._W[doc_id] = W = doc.base_difficulty + np.random.uniform(-1, 1, (1, 3))  # a uniform error for each user
      print(W)
    # use an exponential forgetting curve to simulate whether the user recalls
    last_review = self._user_state._time - self._user_state._last_review[doc_id]
    x = self._user_state._history[doc_id]
    pr = np.exp(-last_review / np.exp(np.dot(W, x))).squeeze()
    print(f"time: {self._user_state._time}, reviewing flashcard {doc_id}, recall rate = {pr}")
    if np.random.rand() < pr:  # remembered
      response._recall = True
    response._pr = pr
ltsenv = environment.Environment(
    FlashcardUserModel(slate_size),
    FlashcardDocumentSampler(),
    num_candidates,
    slate_size,
    resample_documents=False)
def reward(responses):
  reward = 0.0
  for response in responses:
    reward += int(response._recall)
  return reward

def update_metrics(responses, metrics, info):
  # print("responses: ", responses)
  prs = []
  for response in responses:
    prs.append(response['pr'])
  if not isinstance(metrics, list):
    metrics = [prs]
  else:
    metrics.append(prs)
  # print(metrics)
  return metrics
observation = ltsenv.reset()
# user - history (n, n+, n-)
print("Observation space of user:")
print(u.observation_space(), '\n')
print("User history:")
print(observation[0]['history'], '\n')
# user - last review time of each card
print("User last_review:")
print(observation[0]['last_review'], '\n')
# user - current time (you can get the delta by time - last_review)
print("User time:")
print(observation[0]['time'], '\n')
# user - time budget (deadline)
print("User time budget:")
print(observation[0]['time_budget'])
# ltsenv.reset()
lts_gym_env = recsim_gym.RecSimGymEnv(ltsenv, reward, update_metrics)
lts_gym_env.reset()
try_observation = lts_gym_env.reset()
for i in range(len(try_observation['doc'])):
  print(try_observation['user']['history'][i])
# print(try_observation['user']['history'].shape[0])

# scratch check: index of the maximum element in a list (argmax)
my_list = [10.0, 5.5, 8.1, 2.0, 1.57]
max_value = max(my_list)
print(my_list.index(max_value))
def create_agent(sess, environment, eval_mode, summary_writer=None):
  kwargs = {
      'observation_space': environment.observation_space,
      'action_space': environment.action_space,
      'summary_writer': summary_writer,
      'eval_mode': eval_mode,
  }
  return full_slate_q_agent.FullSlateQAgent(sess, **kwargs)
#@title Importing RecSim components
from recsim.environments import interest_evolution
from recsim.agents import full_slate_q_agent
from recsim.simulator import runner_lib
tmp_base_dir = '/tmp/recsim/'
runner = runner_lib.TrainRunner(
    base_dir=tmp_base_dir,
    create_agent_fn=create_agent,
    env=lts_gym_env,
    episode_log_file="",
    max_training_steps=5,
    num_iterations=1
)
runner.run_experiment()
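# Optional follow-up (a sketch, not part of the original notebook): recsim's
# runner_lib also provides an EvalRunner that runs the agent in eval mode and
# writes evaluation logs under the same base_dir. The episode count below is
# illustrative.
eval_runner = runner_lib.EvalRunner(
    base_dir=tmp_base_dir,
    create_agent_fn=create_agent,
    env=lts_gym_env,
    max_eval_episodes=5,
    test_mode=True)
eval_runner.run_experiment()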
# Commented out IPython magic to ensure Python compatibility.
# Load the TensorBoard notebook extension
# %load_ext tensorboard
#@title Tensorboard
# %tensorboard --logdir=/tmp/recsim/