Calculate parameters automatically

Logic derived from d8e2f6a0ff

Closes #2661
This commit is contained in:
Damien Elmes 2023-09-18 16:43:12 +10:00
parent e7bf248a62
commit 6074865763
3 changed files with 185 additions and 162 deletions

View File

@ -340,7 +340,11 @@ message ComputeFsrsWeightsResponse {
message ComputeOptimalRetentionRequest {
repeated float weights = 1;
OptimalRetentionParameters params = 2;
uint32 deck_size = 2;
uint32 days_to_simulate = 3;
uint32 max_seconds_of_study_per_day = 4;
uint32 max_interval = 5;
string search = 6;
}
message ComputeOptimalRetentionResponse {
@ -348,15 +352,11 @@ message ComputeOptimalRetentionResponse {
}
message OptimalRetentionParameters {
uint32 deck_size = 2;
uint32 days_to_simulate = 3;
uint32 max_seconds_of_study_per_day = 4;
uint32 max_interval = 5;
double recall_secs_hard = 6;
double recall_secs_good = 7;
double recall_secs_easy = 8;
uint32 forget_secs = 9;
uint32 learn_secs = 10;
double forget_secs = 9;
double learn_secs = 10;
double first_rating_probability_again = 11;
double first_rating_probability_hard = 12;
double first_rating_probability_good = 13;

View File

@ -5,8 +5,10 @@ use anki_proto::scheduler::ComputeOptimalRetentionRequest;
use anki_proto::scheduler::OptimalRetentionParameters;
use fsrs::SimulatorConfig;
use fsrs::FSRS;
use itertools::Itertools;
use crate::prelude::*;
use crate::revlog::RevlogReviewKind;
use crate::search::SortMode;
#[derive(Default, Clone, Copy, Debug)]
@ -22,74 +24,177 @@ impl Collection {
) -> Result<f32> {
let mut anki_progress = self.new_progress_handler::<ComputeRetentionProgress>();
let fsrs = FSRS::new(None)?;
let p = req.params.as_ref().or_invalid("missing params")?;
Ok(fsrs.optimal_retention(
&SimulatorConfig {
deck_size: p.deck_size as usize,
learn_span: p.days_to_simulate as usize,
max_cost_perday: p.max_seconds_of_study_per_day as f64,
max_ivl: p.max_interval as f64,
recall_costs: [p.recall_secs_hard, p.recall_secs_good, p.recall_secs_easy],
forget_cost: p.forget_secs as f64,
learn_cost: p.learn_secs as f64,
first_rating_prob: [
p.first_rating_probability_again,
p.first_rating_probability_hard,
p.first_rating_probability_good,
p.first_rating_probability_easy,
],
review_rating_prob: [
p.review_rating_probability_hard,
p.review_rating_probability_good,
p.review_rating_probability_easy,
],
},
&req.weights,
|ip| {
anki_progress
.update(false, |p| {
p.total = ip.total as u32;
p.current = ip.current as u32;
})
.is_ok()
},
)? as f32)
if req.days_to_simulate == 0 {
invalid_input!("no days to simulate")
}
let p = self.get_optimal_retention_parameters(&req.search)?;
Ok(fsrs
.optimal_retention(
&SimulatorConfig {
deck_size: req.deck_size as usize,
learn_span: req.days_to_simulate as usize,
max_cost_perday: req.max_seconds_of_study_per_day as f64,
max_ivl: req.max_interval as f64,
recall_costs: [p.recall_secs_hard, p.recall_secs_good, p.recall_secs_easy],
forget_cost: p.forget_secs,
learn_cost: p.learn_secs,
first_rating_prob: [
p.first_rating_probability_again,
p.first_rating_probability_hard,
p.first_rating_probability_good,
p.first_rating_probability_easy,
],
review_rating_prob: [
p.review_rating_probability_hard,
p.review_rating_probability_good,
p.review_rating_probability_easy,
],
},
&req.weights,
|ip| {
anki_progress
.update(false, |p| {
p.total = ip.total as u32;
p.current = ip.current as u32;
})
.is_ok()
},
)?
.max(0.8)
.min(0.97) as f32)
}
pub fn get_optimal_retention_parameters(
&mut self,
search: &str,
) -> Result<OptimalRetentionParameters> {
let guard = self.search_cards_into_table(search, SortMode::NoOrder)?;
let deck_size = guard.cards as u32;
// if you need access to cards too:
// let cards = self.storage.all_searched_cards()?;
let _revlogs = guard
let revlogs = self
.search_cards_into_table(search, SortMode::NoOrder)?
.col
.storage
.get_revlog_entries_for_searched_cards_in_order()?;
// todo: compute values from revlogs
let first_rating_count = revlogs
.iter()
.filter(|r| {
r.review_kind == RevlogReviewKind::Learning
&& r.last_interval == 0
&& r.button_chosen >= 1
})
.counts_by(|r| r.button_chosen);
let total_first = first_rating_count.values().sum::<usize>() as f64;
let first_rating_prob = if total_first > 0.0 {
let mut arr = [0.0; 4];
first_rating_count
.iter()
.for_each(|(button_chosen, count)| {
arr[*button_chosen as usize - 1] = *count as f64 / total_first
});
arr
} else {
return Err(AnkiError::FsrsInsufficientData);
};
let review_rating_count = revlogs
.iter()
.filter(|r| r.review_kind == RevlogReviewKind::Review && r.button_chosen != 1)
.counts_by(|r| r.button_chosen);
let total_reviews = review_rating_count.values().sum::<usize>() as f64;
let review_rating_prob = if total_reviews > 0.0 {
let mut arr = [0.0; 3];
review_rating_count
.iter()
.filter(|(&button_chosen, ..)| button_chosen >= 2)
.for_each(|(button_chosen, count)| {
arr[*button_chosen as usize - 2] = *count as f64 / total_reviews;
});
arr
} else {
return Err(AnkiError::FsrsInsufficientData);
};
let recall_costs = {
let default = [14.0, 14.0, 10.0, 6.0];
let mut arr = default;
revlogs
.iter()
.filter(|r| r.review_kind == RevlogReviewKind::Review && r.button_chosen > 0)
.sorted_by(|a, b| a.button_chosen.cmp(&b.button_chosen))
.group_by(|r| r.button_chosen)
.into_iter()
.for_each(|(button_chosen, group)| {
let group_vec = group.into_iter().map(|r| r.taken_millis).collect_vec();
let average_secs =
group_vec.iter().sum::<u32>() as f64 / group_vec.len() as f64 / 1000.0;
arr[button_chosen as usize - 1] = average_secs
});
if arr == default {
return Err(AnkiError::FsrsInsufficientData);
}
arr
};
let learn_cost = {
let revlogs_filter = revlogs
.iter()
.filter(|r| r.review_kind == RevlogReviewKind::Learning && r.last_interval == 0)
.map(|r| r.taken_millis);
let count = revlogs_filter.clone().count() as f64;
if count > 0.0 {
revlogs_filter.sum::<u32>() as f64 / count / 1000.0
} else {
return Err(AnkiError::FsrsInsufficientData);
}
};
let forget_cost = {
let review_kind_to_total_millis = revlogs
.iter()
.sorted_by(|a, b| a.cid.cmp(&b.cid).then(a.id.cmp(&b.id)))
.group_by(|r| r.review_kind)
/*
for example:
o x x o o x x x o o x x o x
|<->| |<--->| |<->| |<>|
x means forgotten, there are 4 consecutive sets of internal relearning in this card.
So each group is counted separately, and each group is summed up internally.(following code)
Finally averaging all groups, so sort by cid and id.
*/
.into_iter()
.map(|(review_kind, group)| {
let total_millis: u32 = group.into_iter().map(|r| r.taken_millis).sum();
(review_kind, total_millis)
})
.collect_vec();
let mut group_sec_by_review_kind: [Vec<_>; 5] = Default::default();
for (review_kind, sec) in review_kind_to_total_millis.into_iter() {
group_sec_by_review_kind[review_kind as usize].push(sec)
}
let mut arr = [0.0; 5];
for (review_kind, group) in group_sec_by_review_kind.iter().enumerate() {
if group.is_empty() && review_kind == RevlogReviewKind::Relearning as usize {
return Err(AnkiError::FsrsInsufficientData);
}
let average_secs = group.iter().sum::<u32>() as f64 / group.len() as f64 / 1000.0;
arr[review_kind] = average_secs
}
arr
};
let forget_cost = forget_cost[RevlogReviewKind::Relearning as usize] + recall_costs[0];
let params = OptimalRetentionParameters {
deck_size,
days_to_simulate: 365,
max_seconds_of_study_per_day: 1800,
// this should be filled in by the frontend based on their configured value
max_interval: 0,
recall_secs_hard: 14.0,
recall_secs_good: 10.0,
recall_secs_easy: 6.0,
forget_secs: 50,
learn_secs: 20,
first_rating_probability_again: 0.15,
first_rating_probability_hard: 0.2,
first_rating_probability_good: 0.6,
first_rating_probability_easy: 0.05,
review_rating_probability_hard: 0.3,
review_rating_probability_good: 0.6,
review_rating_probability_easy: 0.1,
recall_secs_hard: recall_costs[1],
recall_secs_good: recall_costs[2],
recall_secs_easy: recall_costs[3],
forget_secs: forget_cost,
learn_secs: learn_cost,
first_rating_probability_again: first_rating_prob[0],
first_rating_probability_hard: first_rating_prob[1],
first_rating_probability_good: first_rating_prob[2],
first_rating_probability_easy: first_rating_prob[3],
review_rating_probability_hard: review_rating_prob[0],
review_rating_probability_good: review_rating_prob[1],
review_rating_probability_easy: review_rating_prob[2],
};
Ok(params)
}

View File

@ -7,12 +7,11 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
ComputeRetentionProgress,
type ComputeWeightsProgress,
} from "@tslib/anki/collection_pb";
import { OptimalRetentionParameters } from "@tslib/anki/scheduler_pb";
import { ComputeOptimalRetentionRequest } from "@tslib/anki/scheduler_pb";
import {
computeFsrsWeights,
computeOptimalRetention,
evaluateWeights,
getOptimalRetentionParameters,
setWantsAbort,
} from "@tslib/backend";
import * as tr from "@tslib/ftl";
@ -39,7 +38,11 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
| ComputeRetentionProgress
| undefined;
let optimalParams = new OptimalRetentionParameters({});
const optimalRetentionRequest = new ComputeOptimalRetentionRequest({
deckSize: 10000,
daysToSimulate: 365,
maxSecondsOfStudyPerDay: 1800,
});
async function computeWeights(): Promise<void> {
if (computing) {
await setWantsAbort({});
@ -125,10 +128,10 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
try {
await runWithBackendProgress(
async () => {
const resp = await computeOptimalRetention({
params: optimalParams,
weights: $config.fsrsWeights,
});
optimalRetentionRequest.maxInterval = $config.maximumReviewInterval;
optimalRetentionRequest.weights = $config.fsrsWeights;
optimalRetentionRequest.search = `preset:"${state.getCurrentName()}"`;
const resp = await computeOptimalRetention(optimalRetentionRequest);
$config.desiredRetention = resp.optimalRetention;
if (computeRetentionProgress) {
computeRetentionProgress.current =
@ -146,23 +149,6 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
}
}
async function getRetentionParams(): Promise<void> {
if (computing) {
return;
}
computing = true;
try {
// await
const resp = await getOptimalRetentionParameters({
search: `preset:"${state.getCurrentName()}"`,
});
optimalParams = resp.params!;
optimalParams.maxInterval = $config.maximumReviewInterval;
} finally {
computing = false;
}
}
$: computeWeightsProgressString = renderWeightProgress(computeWeightsProgress);
$: computeRetentionProgressString = renderRetentionProgress(
computeRetentionProgress,
@ -248,90 +234,22 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
Deck size:
<br />
<input type="number" bind:value={optimalParams.deckSize} />
<input type="number" bind:value={optimalRetentionRequest.deckSize} />
<br />
Days to simulate
<br />
<input type="number" bind:value={optimalParams.daysToSimulate} />
<input type="number" bind:value={optimalRetentionRequest.daysToSimulate} />
<br />
Max seconds of study per day:
<br />
<input type="number" bind:value={optimalParams.maxSecondsOfStudyPerDay} />
<input
type="number"
bind:value={optimalRetentionRequest.maxSecondsOfStudyPerDay}
/>
<br />
Seconds to forget a card (again):
<br />
<input type="number" bind:value={optimalParams.forgetSecs} />
<br />
Seconds to recall a card (hard):
<br />
<input type="number" bind:value={optimalParams.recallSecsHard} />
<br />
Seconds to recall a card (good):
<br />
<input type="number" bind:value={optimalParams.recallSecsGood} />
<br />
Seconds to recall a card (easy):
<br />
<input type="number" bind:value={optimalParams.recallSecsEasy} />
<br />
Seconds to learn a card:
<br />
<input type="number" bind:value={optimalParams.learnSecs} />
<br />
First rating probability (again):
<br />
<input type="number" bind:value={optimalParams.firstRatingProbabilityAgain} />
<br />
First rating probability (hard):
<br />
<input type="number" bind:value={optimalParams.firstRatingProbabilityHard} />
<br />
First rating probability (good):
<br />
<input type="number" bind:value={optimalParams.firstRatingProbabilityGood} />
<br />
First rating probability (easy):
<br />
<input type="number" bind:value={optimalParams.firstRatingProbabilityEasy} />
<br />
Review rating probability (hard):
<br />
<input type="number" bind:value={optimalParams.reviewRatingProbabilityHard} />
<br />
Review rating probability (good):
<br />
<input type="number" bind:value={optimalParams.reviewRatingProbabilityGood} />
<br />
Review rating probability (easy):
<br />
<input type="number" bind:value={optimalParams.reviewRatingProbabilityEasy} />
<br />
<button
class="btn {computing ? 'btn-warning' : 'btn-primary'}"
on:click={() => getRetentionParams()}
>
{#if computing}
{tr.actionsCancel()}
{:else}
{tr.deckConfigGetParams()}
{/if}
</button>
<button
class="btn {computing ? 'btn-warning' : 'btn-primary'}"
on:click={() => computeRetention()}