anki/pylib/anki/rsbackend.py
2020-04-06 15:39:47 +10:00

625 lines
19 KiB
Python

# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
# pylint: skip-file
"""
Python bindings for Anki's Rust libraries.
Please do not access methods on the backend directly - they may be changed
or removed at any time. Please use the methods on the collection instead.
E.g., don't use col.backend.all_deck_config(); use col.decks.all_config()
instead.
"""
import enum
import json
import os
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NewType,
NoReturn,
Optional,
Sequence,
Tuple,
Union,
)
import ankirspy # pytype: disable=import-error
import anki.backend_pb2 as pb
import anki.buildinfo
from anki import hooks
from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB
from anki.fluent_pb2 import FluentString as TR
from anki.models import AllTemplateReqs
from anki.sound import AVTag, SoundOrVideoTag, TTSTag
from anki.types import assert_impossible_literal
from anki.utils import intTime
# Fail at import time if the build hash reported by the ankirspy extension
# differs from the one recorded in anki.buildinfo.
assert ankirspy.buildhash() == anki.buildinfo.buildhash
# Short module-level aliases for protobuf types defined in anki.backend_pb2.
SchedTimingToday = pb.SchedTimingTodayOut
BuiltinSortKind = pb.BuiltinSortKind
BackendCard = pb.Card
TagUsnTuple = pb.TagUsnTuple
try:
    import orjson
except ImportError:
    # add compat layer for 32 bit builds that can't use orjson
    # (only a failed import triggers the fallback; KeyboardInterrupt and
    # SystemExit are no longer swallowed here)
    print("reverting to stock json")

    class orjson:  # type: ignore
        """Stand-in exposing the two orjson calls this module relies on,
        implemented with the standard-library json module."""

        @staticmethod
        def dumps(obj: Any) -> bytes:
            """Serialize obj to JSON and return it as UTF-8 encoded bytes."""
            return json.dumps(obj).encode("utf8")

        # staticmethod so that access through an instance does not bind self
        loads = staticmethod(json.loads)
class Interrupted(Exception):
    """Exception returned by proto_exception_to_native() for "interrupted"."""
class StringError(Exception):
    """Exception whose string form is its first constructor argument."""

    def __str__(self) -> str:
        message = self.args[0]  # pylint: disable=unsubscriptable-object
        return message
# Aliases for the enum types nested inside the protobuf error messages;
# used as the return annotations of NetworkError.kind() / SyncError.kind().
NetworkErrorKind = pb.NetworkError.NetworkErrorKind
SyncErrorKind = pb.SyncError.SyncErrorKind
class NetworkError(StringError):
    """StringError carrying a NetworkErrorKind as its second argument."""

    def kind(self) -> NetworkErrorKind:
        """Return the second constructor argument (the error kind)."""
        error_kind = self.args[1]
        return error_kind
class SyncError(StringError):
    """StringError carrying a SyncErrorKind as its second argument."""

    def kind(self) -> SyncErrorKind:
        """Return the second constructor argument (the error kind)."""
        error_kind = self.args[1]
        return error_kind
class IOError(StringError):
    """Exception returned by proto_exception_to_native() for "io_error".

    Note that within this module the name shadows the builtin IOError.
    """
class DBError(StringError):
    """Exception returned by proto_exception_to_native() for "db_error"."""
class TemplateError(StringError):
    """Exception returned by proto_exception_to_native() for "template_parse"."""
def proto_exception_to_native(err: pb.BackendError) -> Exception:
    """Build the Python exception matching a backend error message.

    The variant set in err's "value" oneof selects the exception class.
    Apart from Interrupted, each exception receives err.localized as its
    message; network and sync errors also receive their kind.
    The exception is returned, not raised.
    """
    variant = err.WhichOneof("value")

    if variant == "interrupted":
        return Interrupted()
    if variant == "network_error":
        return NetworkError(err.localized, err.network_error.kind)
    if variant == "sync_error":
        return SyncError(err.localized, err.sync_error.kind)

    # the remaining variants only carry the localized message
    message_only = {
        "io_error": IOError,
        "db_error": DBError,
        "template_parse": TemplateError,
        "invalid_input": StringError,
    }
    exc_class = message_only.get(variant)
    if exc_class is not None:
        return exc_class(err.localized)

    assert_impossible_literal(variant)
def proto_template_reqs_to_legacy(
    reqs: List[pb.TemplateRequirement],
) -> AllTemplateReqs:
    """Convert protobuf template requirements to (index, kind, ords) tuples.

    The index is the requirement's position in reqs. kind is "any" or "all"
    with the sorted ordinals of that variant, or "none" with an empty list
    for anything else.
    """
    converted = []
    for position, requirement in enumerate(reqs):
        variant = requirement.WhichOneof("value")
        # fixme: sorting is for the unit tests - should check if any
        # code depends on the order
        if variant in ("any", "all"):
            ords: List[int] = sorted(getattr(requirement, variant).ords)
            converted.append((position, variant, ords))
        else:
            no_ords: List[int] = []
            converted.append((position, "none", no_ords))
    return converted
def av_tag_to_native(tag: pb.AVTag) -> AVTag:
    """Convert a protobuf AV tag into a SoundOrVideoTag or a TTSTag.

    Any variant other than "sound_or_video" is treated as a TTS tag.
    """
    if tag.WhichOneof("value") == "sound_or_video":
        return SoundOrVideoTag(filename=tag.sound_or_video)

    tts = tag.tts
    return TTSTag(
        field_text=tts.field_text,
        lang=tts.lang,
        # copy the repeated fields into plain lists
        voices=list(tts.voices),
        other_args=list(tts.other_args),
        speed=tts.speed,
    )
@dataclass
class TemplateReplacement:
    """Field replacement entry, as built by proto_replacement_list_to_native()."""

    # all three values are copied from the protobuf node's `replacement`
    field_name: str
    current_text: str
    filters: List[str]
# List mixing plain strings and TemplateReplacement entries, as built by
# proto_replacement_list_to_native().
TemplateReplacementList = List[Union[str, TemplateReplacement]]
# Aliases for protobuf types defined in anki.backend_pb2.
MediaSyncProgress = pb.MediaSyncProgress
MediaCheckOutput = pb.MediaCheckOut
FormatTimeSpanContext = pb.FormatTimeSpanIn.Context
@dataclass
class ExtractedLatex:
    """One entry of the `latex` list built by RustBackend.extract_latex()."""

    # both values are copied from the backend's output entry
    filename: str
    latex_body: str
@dataclass
class ExtractedLatexOutput:
    """Return value of RustBackend.extract_latex()."""

    # set from the `text` field of the backend's output
    html: str
    # one ExtractedLatex per entry in the backend's `latex` list
    latex: List[ExtractedLatex]
class ProgressKind(enum.Enum):
    """Kinds of progress update produced by proto_progress_to_native()."""

    MediaSync = 0
    MediaCheck = 1
@dataclass
class Progress:
    """Progress update in native form: a kind plus its payload."""

    kind: ProgressKind
    # payload taken from the protobuf field matching `kind`
    val: Union[MediaSyncProgress, str]
def proto_replacement_list_to_native(
    nodes: List[pb.RenderedTemplateNode],
) -> TemplateReplacementList:
    """Convert rendered template nodes to strings and TemplateReplacements.

    A node whose "value" oneof is "text" becomes its text; any other node
    becomes a TemplateReplacement built from node.replacement.
    """

    def to_native(node: Any) -> Union[str, TemplateReplacement]:
        if node.WhichOneof("value") == "text":
            return node.text
        replacement = node.replacement
        return TemplateReplacement(
            field_name=replacement.field_name,
            current_text=replacement.current_text,
            filters=list(replacement.filters),
        )

    native_nodes: TemplateReplacementList = [to_native(node) for node in nodes]
    return native_nodes
def proto_progress_to_native(progress: pb.Progress) -> Progress:
    """Wrap a protobuf progress message in the native Progress dataclass.

    Handles the "media_sync" and "media_check" variants; anything else is
    passed to assert_impossible_literal().
    """
    variant = progress.WhichOneof("value")
    if variant == "media_sync":
        return Progress(kind=ProgressKind.MediaSync, val=progress.media_sync)
    if variant == "media_check":
        return Progress(kind=ProgressKind.MediaCheck, val=progress.media_check)
    assert_impossible_literal(variant)
def _on_progress(progress_bytes: bytes) -> bool:
    """Decode a serialized pb.Progress and forward it to the progress hook.

    Returns whatever hooks.bg_thread_progress_callback() returns when called
    with True and the native progress value.
    """
    message = pb.Progress()
    message.ParseFromString(progress_bytes)
    return hooks.bg_thread_progress_callback(
        True, proto_progress_to_native(message)
    )
class RustBackend:
    """Python-side handle on a backend instance opened through ankirspy.

    Most methods build a single pb.BackendInput, send it with _run_command()
    and return the matching field of the pb.BackendOutput. The db_* methods
    instead exchange JSON with the backend via _db_command().
    """

    def __init__(
        self,
        ftl_folder: Optional[str] = None,
        langs: Optional[List[str]] = None,
        server: bool = False,
    ) -> None:
        """Open the backend and register _on_progress as progress callback.

        ftl_folder and langs fall back to values taken from anki.lang when
        they are None.
        """
        # pick up global defaults if not provided
        # NOTE(review): anki.lang is not among this module's visible imports;
        # presumably it has already been loaded elsewhere - confirm.
        if ftl_folder is None:
            ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
        if langs is None:
            langs = [anki.lang.currentLang]

        init_msg = pb.BackendInit(
            locale_folder_path=ftl_folder, preferred_langs=langs, server=server,
        )
        self._backend = ankirspy.open_backend(init_msg.SerializeToString())
        self._backend.set_progress_callback(_on_progress)

    def _run_command(
        self, input: pb.BackendInput, release_gil: bool = False
    ) -> pb.BackendOutput:
        """Serialize input, run it on the backend and return the parsed output.

        Raises the exception built by proto_exception_to_native() when the
        output's "value" oneof is "error".
        """
        output_bytes = self._backend.command(input.SerializeToString(), release_gil)
        output = pb.BackendOutput()
        output.ParseFromString(output_bytes)
        if output.WhichOneof("value") == "error":
            raise proto_exception_to_native(output.error)
        return output

    @staticmethod
    def _optional_int32(value: Optional[int]) -> Optional[pb.OptionalInt32]:
        """Wrap value in pb.OptionalInt32, or pass None through unchanged."""
        if value is None:
            return None
        return pb.OptionalInt32(val=value)

    def open_collection(
        self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
    ) -> None:
        """Send the open_collection command (with release_gil=True)."""
        self._run_command(
            pb.BackendInput(
                open_collection=pb.OpenCollectionIn(
                    collection_path=col_path,
                    media_folder_path=media_folder_path,
                    media_db_path=media_db_path,
                    log_path=log_path,
                )
            ),
            release_gil=True,
        )

    def close_collection(self, downgrade: bool = True) -> None:
        """Send the close_collection command (with release_gil=True).

        downgrade is passed on as downgrade_to_schema11.
        """
        self._run_command(
            pb.BackendInput(
                close_collection=pb.CloseCollectionIn(downgrade_to_schema11=downgrade)
            ),
            release_gil=True,
        )

    def template_requirements(
        self, template_fronts: List[str], field_map: Dict[str, int]
    ) -> AllTemplateReqs:
        """Return the backend's template requirements in legacy tuple form."""
        input = pb.BackendInput(
            template_requirements=pb.TemplateRequirementsIn(
                template_front=template_fronts, field_names_to_ordinals=field_map
            )
        )
        output = self._run_command(input).template_requirements
        reqs: List[pb.TemplateRequirement] = output.requirements  # type: ignore
        return proto_template_reqs_to_legacy(reqs)

    def sched_timing_today(
        self,
        created_secs: int,
        created_mins_west: Optional[int],
        now_mins_west: Optional[int],
        rollover: Optional[int],
    ) -> SchedTimingToday:
        """Run the sched_timing_today command and return its output.

        now_secs is taken from intTime(). Each optional argument is wrapped
        in pb.OptionalInt32 when set, and passed as None otherwise.
        """
        return self._run_command(
            pb.BackendInput(
                sched_timing_today=pb.SchedTimingTodayIn(
                    created_secs=created_secs,
                    now_secs=intTime(),
                    created_mins_west=self._optional_int32(created_mins_west),
                    now_mins_west=self._optional_int32(now_mins_west),
                    rollover_hour=self._optional_int32(rollover),
                )
            )
        ).sched_timing_today

    def render_card(
        self, qfmt: str, afmt: str, fields: Dict[str, str], card_ord: int
    ) -> Tuple[TemplateReplacementList, TemplateReplacementList]:
        """Run the render_card command.

        Returns a (question nodes, answer nodes) tuple in native form.
        """
        out = self._run_command(
            pb.BackendInput(
                render_card=pb.RenderCardIn(
                    question_template=qfmt,
                    answer_template=afmt,
                    fields=fields,
                    card_ordinal=card_ord,
                )
            )
        ).render_card
        qnodes = proto_replacement_list_to_native(out.question_nodes)  # type: ignore
        anodes = proto_replacement_list_to_native(out.answer_nodes)  # type: ignore
        return (qnodes, anodes)

    def local_minutes_west(self, stamp: int) -> int:
        """Return the backend's local_minutes_west result for stamp."""
        return self._run_command(
            pb.BackendInput(local_minutes_west=stamp)
        ).local_minutes_west

    def strip_av_tags(self, text: str) -> str:
        """Return the backend's strip_av_tags result for text."""
        return self._run_command(pb.BackendInput(strip_av_tags=text)).strip_av_tags

    def extract_av_tags(
        self, text: str, question_side: bool
    ) -> Tuple[str, List[AVTag]]:
        """Run extract_av_tags; return (output text, list of native AV tags)."""
        out = self._run_command(
            pb.BackendInput(
                extract_av_tags=pb.ExtractAVTagsIn(
                    text=text, question_side=question_side
                )
            )
        ).extract_av_tags
        native_tags = list(map(av_tag_to_native, out.av_tags))
        return out.text, native_tags

    def extract_latex(
        self, text: str, svg: bool, expand_clozes: bool
    ) -> ExtractedLatexOutput:
        """Run extract_latex and return its output as an ExtractedLatexOutput."""
        out = self._run_command(
            pb.BackendInput(
                extract_latex=pb.ExtractLatexIn(
                    text=text, svg=svg, expand_clozes=expand_clozes
                )
            )
        ).extract_latex
        return ExtractedLatexOutput(
            html=out.text,
            latex=[
                ExtractedLatex(filename=entry.filename, latex_body=entry.latex_body)
                for entry in out.latex
            ],
        )

    def add_file_to_media_folder(self, desired_name: str, data: bytes) -> str:
        """Run add_media_file and return the string the backend replies with."""
        return self._run_command(
            pb.BackendInput(
                add_media_file=pb.AddMediaFileIn(desired_name=desired_name, data=data)
            )
        ).add_media_file

    def sync_media(self, hkey: str, endpoint: str) -> None:
        """Send the sync_media command (with release_gil=True)."""
        self._run_command(
            pb.BackendInput(sync_media=pb.SyncMediaIn(hkey=hkey, endpoint=endpoint,)),
            release_gil=True,
        )

    def check_media(self) -> MediaCheckOutput:
        """Run check_media (with release_gil=True) and return its output."""
        return self._run_command(
            pb.BackendInput(check_media=pb.Empty()), release_gil=True,
        ).check_media

    def trash_media_files(self, fnames: List[str]) -> None:
        """Send the trash_media_files command for the given file names."""
        self._run_command(
            pb.BackendInput(trash_media_files=pb.TrashMediaFilesIn(fnames=fnames))
        )

    def translate(self, key: TR, **kwargs: Union[str, int, float]) -> str:
        """Return the backend's translate_string result for key and kwargs."""
        return self._run_command(
            pb.BackendInput(translate_string=translate_string_in(key, **kwargs))
        ).translate_string

    def format_time_span(
        self,
        seconds: float,
        context: FormatTimeSpanContext = FormatTimeSpanContext.INTERVALS,
    ) -> str:
        """Return the backend's format_time_span result."""
        return self._run_command(
            pb.BackendInput(
                format_time_span=pb.FormatTimeSpanIn(seconds=seconds, context=context)
            )
        ).format_time_span

    def studied_today(self, cards: int, seconds: float) -> str:
        """Return the backend's studied_today result."""
        return self._run_command(
            pb.BackendInput(
                studied_today=pb.StudiedTodayIn(cards=cards, seconds=seconds)
            )
        ).studied_today

    def learning_congrats_msg(self, next_due: float, remaining: int) -> str:
        """Return the backend's congrats_learn_msg result."""
        return self._run_command(
            pb.BackendInput(
                congrats_learn_msg=pb.CongratsLearnMsgIn(
                    next_due=next_due, remaining=remaining
                )
            )
        ).congrats_learn_msg

    def empty_trash(self) -> None:
        """Send the empty_trash command."""
        self._run_command(pb.BackendInput(empty_trash=pb.Empty()))

    def restore_trash(self) -> None:
        """Send the restore_trash command."""
        self._run_command(pb.BackendInput(restore_trash=pb.Empty()))

    def db_query(
        self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
    ) -> List[DBRow]:
        """Send a "query" DB command and return the decoded reply."""
        return self._db_command(
            dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
        )

    def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
        """Send an "executemany" DB command and return the decoded reply."""
        return self._db_command(dict(kind="executemany", sql=sql, args=args))

    def db_begin(self) -> None:
        """Send a "begin" DB command."""
        return self._db_command(dict(kind="begin"))

    def db_commit(self) -> None:
        """Send a "commit" DB command."""
        return self._db_command(dict(kind="commit"))

    def db_rollback(self) -> None:
        """Send a "rollback" DB command."""
        return self._db_command(dict(kind="rollback"))

    def _db_command(self, input: Dict[str, Any]) -> Any:
        """JSON-encode input, pass it to the backend's db_command and decode
        the JSON reply."""
        return orjson.loads(self._backend.db_command(orjson.dumps(input)))

    def search_cards(
        self, search: str, order: Union[bool, str, int], reverse: bool = False
    ) -> Sequence[int]:
        """Run search_cards and return the matching card ids.

        order selects the sort mode: a str is sent as a custom order, True
        as from_config, False as none, and any other value as a
        BuiltinSortKind. reverse is only passed on in the builtin case.
        """
        if isinstance(order, str):
            mode = pb.SortOrder(custom=order)
        elif order is True:
            mode = pb.SortOrder(from_config=pb.Empty())
        elif order is False:
            mode = pb.SortOrder(none=pb.Empty())
        else:
            # sadly we can't use the protobuf type in a Union, so we
            # have to accept an int and convert it
            kind = BuiltinSortKind.Value(BuiltinSortKind.Name(order))
            mode = pb.SortOrder(
                builtin=pb.BuiltinSearchOrder(kind=kind, reverse=reverse)
            )

        return self._run_command(
            pb.BackendInput(search_cards=pb.SearchCardsIn(search=search, order=mode))
        ).search_cards.card_ids

    def search_notes(self, search: str) -> Sequence[int]:
        """Run search_notes and return the matching note ids."""
        return self._run_command(
            pb.BackendInput(search_notes=pb.SearchNotesIn(search=search))
        ).search_notes.note_ids

    def get_card(self, cid: int) -> Optional[pb.Card]:
        """Return the card for cid, or None if the reply has no card field."""
        output = self._run_command(pb.BackendInput(get_card=cid)).get_card
        if output.HasField("card"):
            return output.card
        return None

    def update_card(self, card: BackendCard) -> None:
        """Send the update_card command for card."""
        self._run_command(pb.BackendInput(update_card=card))

    # returns the new card id
    def add_card(self, card: BackendCard) -> int:
        """Send the add_card command and return the backend's add_card reply."""
        return self._run_command(pb.BackendInput(add_card=card)).add_card

    def get_deck_config(self, dcid: int) -> Dict[str, Any]:
        """Return the JSON-decoded get_deck_config reply for dcid."""
        jstr = self._run_command(pb.BackendInput(get_deck_config=dcid)).get_deck_config
        return orjson.loads(jstr)

    def add_or_update_deck_config(self, conf: Dict[str, Any], preserve_usn) -> None:
        """Send conf as JSON to add_or_update_deck_config.

        The id returned by the backend is stored in conf["id"], so the
        passed-in dict is modified in place.
        """
        conf_json = orjson.dumps(conf)
        new_id = self._run_command(
            pb.BackendInput(
                add_or_update_deck_config=pb.AddOrUpdateDeckConfigIn(
                    config=conf_json, preserve_usn_and_mtime=preserve_usn
                )
            )
        ).add_or_update_deck_config
        conf["id"] = new_id

    def all_deck_config(self) -> Sequence[Dict[str, Any]]:
        """Return the JSON-decoded all_deck_config reply."""
        jstr = self._run_command(
            pb.BackendInput(all_deck_config=pb.Empty())
        ).all_deck_config
        return orjson.loads(jstr)

    def new_deck_config(self) -> Dict[str, Any]:
        """Return the JSON-decoded new_deck_config reply."""
        jstr = self._run_command(
            pb.BackendInput(new_deck_config=pb.Empty())
        ).new_deck_config
        return orjson.loads(jstr)

    def remove_deck_config(self, dcid: int) -> None:
        """Send the remove_deck_config command for dcid."""
        self._run_command(pb.BackendInput(remove_deck_config=dcid))

    def abort_media_sync(self) -> None:
        """Send the abort_media_sync command."""
        self._run_command(pb.BackendInput(abort_media_sync=pb.Empty()))

    def all_tags(self) -> Iterable[TagUsnTuple]:
        """Return the tags field of the all_tags reply."""
        return self._run_command(pb.BackendInput(all_tags=pb.Empty())).all_tags.tags

    def canonify_tags(self, tags: str) -> Tuple[str, bool]:
        """Run canonify_tags; return (tags, tag_list_changed) from the reply."""
        out = self._run_command(pb.BackendInput(canonify_tags=tags)).canonify_tags
        return (out.tags, out.tag_list_changed)

    def register_tags(self, tags: str, usn: Optional[int], clear_first: bool) -> bool:
        """Run register_tags and return the backend's reply.

        When usn is None, usn=0 and preserve_usn=False are sent; otherwise
        the given usn is sent with preserve_usn=True.
        """
        preserve_usn = usn is not None
        usn_ = 0 if usn is None else usn
        return self._run_command(
            pb.BackendInput(
                register_tags=pb.RegisterTagsIn(
                    tags=tags,
                    usn=usn_,
                    preserve_usn=preserve_usn,
                    clear_first=clear_first,
                )
            )
        ).register_tags

    def before_upload(self) -> None:
        """Send the before_upload command."""
        self._run_command(pb.BackendInput(before_upload=pb.Empty()))

    def get_changed_tags(self, usn: int) -> List[str]:
        """Return the tags of the get_changed_tags reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(get_changed_tags=usn)
            ).get_changed_tags.tags
        )

    def get_config_json(self, key: str) -> Any:
        """Return the JSON-decoded config value stored under key.

        Raises KeyError(key) when the backend replies with empty bytes.
        """
        b = self._run_command(pb.BackendInput(get_config_json=key)).get_config_json
        if b == b"":
            # include the key so the error says which entry was missing
            raise KeyError(key)
        return orjson.loads(b)

    def set_config_json(self, key: str, val: Any) -> None:
        """Send set_config_json with val encoded as JSON."""
        self._run_command(
            pb.BackendInput(
                set_config_json=pb.SetConfigJson(key=key, val=orjson.dumps(val))
            )
        )

    def remove_config(self, key: str) -> None:
        """Send set_config_json for key with the remove field set."""
        self._run_command(
            pb.BackendInput(
                set_config_json=pb.SetConfigJson(key=key, remove=pb.Empty())
            )
        )

    def get_all_config(self) -> Dict[str, Any]:
        """Return the JSON-decoded get_all_config reply."""
        jstr = self._run_command(
            pb.BackendInput(get_all_config=pb.Empty())
        ).get_all_config
        return orjson.loads(jstr)

    def set_all_config(self, conf: Dict[str, Any]) -> None:
        """Send conf, encoded as JSON, with the set_all_config command."""
        self._run_command(pb.BackendInput(set_all_config=orjson.dumps(conf)))
def translate_string_in(
    key: TR, **kwargs: Union[str, int, float]
) -> pb.TranslateStringIn:
    """Build the pb.TranslateStringIn message for key and its arguments.

    str values are wrapped as TranslateArgValue(str=...), everything else
    as TranslateArgValue(number=...).
    """
    wrapped_args = {
        name: (
            pb.TranslateArgValue(str=value)
            if isinstance(value, str)
            else pb.TranslateArgValue(number=value)
        )
        for name, value in kwargs.items()
    }
    return pb.TranslateStringIn(key=key, args=wrapped_args)
# temporarily force logging of media handling
# (only applied when RUST_LOG is not already set in the environment)
os.environ.setdefault("RUST_LOG", "warn,anki::media=debug")