anki/rslib/proto/python.rs

260 lines
8.0 KiB
Rust
Raw Normal View History

// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::io::BufWriter;
use std::io::Write;
use std::path::Path;
2023-06-12 09:10:22 +02:00
use anki_io::create_dir_all;
use anki_io::create_file;
use anyhow::Result;
use inflections::Inflect;
use prost_reflect::DescriptorPool;
use prost_reflect::FieldDescriptor;
use prost_reflect::Kind;
use prost_reflect::MessageDescriptor;
use prost_reflect::MethodDescriptor;
use prost_reflect::ServiceDescriptor;
use crate::utils::Comments;
pub(crate) fn write_python_interface(pool: &DescriptorPool) -> Result<()> {
let output_path = Path::new("../../out/pylib/anki/_backend_generated.py");
2023-06-12 09:10:22 +02:00
create_dir_all(output_path.parent().unwrap())?;
let mut out = BufWriter::new(create_file(output_path)?);
write_header(&mut out)?;
for service in pool.services() {
if service.name() == "AnkidroidService" {
continue;
}
let comments = Comments::from_file(service.parent_file().file_descriptor_proto());
for method in service.methods() {
render_method(&service, &method, &comments, &mut out);
}
}
Ok(())
}
/// Generates text like the following:
///
/// def get_field_names_raw(self, message: bytes) -> bytes:
/// return self._run_command(7, 16, message)
///
/// def get_field_names(self, ntid: int) -> Sequence[str]:
/// message = anki.notetypes_pb2.NotetypeId(ntid=ntid)
/// raw_bytes = self._run_command(7, 16, message.SerializeToString())
/// output = anki.generic_pb2.StringList()
/// output.ParseFromString(raw_bytes)
/// return output.vals
fn render_method(
service: &ServiceDescriptor,
method: &MethodDescriptor,
comments: &Comments,
out: &mut impl Write,
) {
let method_name = method.name().to_snake_case();
let input = method.input();
let output = method.output();
let service_idx = service.index();
let method_idx = method.index();
let comments = format_comments(comments.get_for_path(method.path()));
// raw bytes
write!(
out,
r#" def {method_name}_raw(self, message: bytes) -> bytes:
{comments}return self._run_command({service_idx}, {method_idx}, message)
"#
)
.unwrap();
// (possibly destructured) message
let (input_params, input_assign) = maybe_destructured_input(&input);
let output_constructor = full_name_to_python(output.full_name());
let (output_msg_or_single_field, output_type) = maybe_destructured_output(&output);
write!(
out,
r#" def {method_name}({input_params}) -> {output_type}:
{comments}{input_assign}
raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString())
output = {output_constructor}()
output.ParseFromString(raw_bytes)
return {output_msg_or_single_field}
"#
)
.unwrap();
}
fn format_comments(comments: Option<&str>) -> String {
comments
.as_ref()
.map(|c| {
format!(
r#""""{c}"""
"#
)
})
.unwrap_or_default()
}
/// If any of the following apply to the input type:
/// - it has a single field
/// - its name ends in Request
/// - it has any optional fields
/// ...then destructuring will be skipped, and the method will take the input
/// message directly. Returns (params_line, assignment_lines)
fn maybe_destructured_input(input: &MessageDescriptor) -> (String, String) {
if (input.name().ends_with("Request") || input.fields().len() < 2)
&& input.oneofs().next().is_none()
{
// destructure
let method_args = build_method_arguments(input);
let input_type = full_name_to_python(input.full_name());
let input_message_args = build_input_message_arguments(input);
let assignment = format!("message = {input_type}({input_message_args})",);
(method_args, assignment)
} else {
// no destructure
let params = format!("self, message: {}", full_name_to_python(input.full_name()));
let assignment = String::new();
(params, assignment)
}
}
/// e.g. "self, *, note_ids: Sequence[int], new_fields: Sequence[int]"
fn build_method_arguments(input: &MessageDescriptor) -> String {
let fields = input.fields();
let mut args = vec!["self".to_string()];
if fields.len() >= 2 {
args.push("*".to_string());
}
for field in fields {
let arg = format!("{}: {}", field.name(), python_type(&field));
args.push(arg);
}
args.join(", ")
}
/// e.g. "note_ids=note_ids, new_fields=new_fields"
fn build_input_message_arguments(input: &MessageDescriptor) -> String {
input
.fields()
.map(|field| {
let name = field.name();
format!("{name}={name}")
})
.collect::<Vec<_>>()
.join(", ")
}
// If output type has a single field and is not an enum, we return its single
// field value directly. Returns (expr, type), where expr is 'output' or
// 'output.<only_field>'.
fn maybe_destructured_output(output: &MessageDescriptor) -> (String, String) {
let first_field = output.fields().next();
if output.fields().len() == 1 && !matches!(first_field.as_ref().unwrap().kind(), Kind::Enum(_))
{
let field = first_field.unwrap();
(format!("output.{}", field.name()), python_type(&field))
} else {
("output".into(), full_name_to_python(output.full_name()))
}
}
/// e.g. uint32 -> int; repeated bool -> Sequence[bool]
fn python_type(field: &FieldDescriptor) -> String {
let kind = match field.kind() {
Kind::Int32
| Kind::Int64
| Kind::Uint32
| Kind::Uint64
| Kind::Sint32
| Kind::Sint64
| Kind::Fixed32
| Kind::Fixed64
| Kind::Sfixed32
| Kind::Sfixed64 => "int".into(),
Kind::Float | Kind::Double => "float".into(),
Kind::Bool => "bool".into(),
Kind::String => "str".into(),
Kind::Bytes => "bytes".into(),
Kind::Message(msg) => full_name_to_python(msg.full_name()),
Kind::Enum(en) => format!("{}.V", full_name_to_python(en.full_name())),
};
if field.is_list() {
format!("Sequence[{}]", kind)
} else if field.is_map() {
let map_kind = field.kind();
let map_kind = map_kind.as_message().unwrap();
let map_kv: Vec<_> = map_kind.fields().map(|f| python_type(&f)).collect();
format!("Mapping[{}, {}]", map_kv[0], map_kv[1])
} else {
kind
}
}
// e.g. anki.import_export.ImportResponse ->
// anki.import_export_pb2.ImportResponse
fn full_name_to_python(name: &str) -> String {
let mut name = name.splitn(3, '.');
format!(
"{}.{}_pb2.{}",
name.next().unwrap(),
name.next().unwrap(),
name.next().unwrap()
)
}
fn write_header(out: &mut impl Write) -> Result<()> {
out.write_all(
br#"# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; https://www.gnu.org/licenses/agpl.html
# pylint: skip-file
from __future__ import annotations
"""
THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT.
Please do not access methods on the backend directly - they may be changed
or removed at any time. Instead, please use the methods on the collection
instead. Eg, don't use col.backend.all_deck_config(), instead use
col.decks.all_config()
"""
from typing import *
import anki
import anki.backend_pb2
import anki.card_rendering_pb2
import anki.cards_pb2
import anki.collection_pb2
import anki.config_pb2
import anki.deckconfig_pb2
import anki.decks_pb2
import anki.i18n_pb2
import anki.image_occlusion_pb2
import anki.import_export_pb2
import anki.links_pb2
import anki.media_pb2
import anki.notes_pb2
import anki.notetypes_pb2
import anki.scheduler_pb2
import anki.search_pb2
import anki.stats_pb2
import anki.sync_pb2
import anki.tags_pb2
class RustBackendGenerated:
def _run_command(self, service: int, method: int, input: Any) -> bytes:
raise Exception("not implemented")
"#,
)?;
Ok(())
}