anki/rslib/proto/python.rs
Damien Elmes 81ac5a91ad Include backend comments in Python and Rust codegen
This is of limited usefulness at the moment, as it doesn't help consumers
of the public API.

Also removed detached comments from the included comments.
2023-06-16 10:59:00 +10:00

260 lines
8.0 KiB
Rust

// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::io::BufWriter;
use std::io::Write;
use std::path::Path;
use anki_io::create_dir_all;
use anki_io::create_file;
use anyhow::Result;
use inflections::Inflect;
use prost_reflect::DescriptorPool;
use prost_reflect::FieldDescriptor;
use prost_reflect::Kind;
use prost_reflect::MessageDescriptor;
use prost_reflect::MethodDescriptor;
use prost_reflect::ServiceDescriptor;
use crate::utils::Comments;
pub(crate) fn write_python_interface(pool: &DescriptorPool) -> Result<()> {
let output_path = Path::new("../../out/pylib/anki/_backend_generated.py");
create_dir_all(output_path.parent().unwrap())?;
let mut out = BufWriter::new(create_file(output_path)?);
write_header(&mut out)?;
for service in pool.services() {
if service.name() == "AnkidroidService" {
continue;
}
let comments = Comments::from_file(service.parent_file().file_descriptor_proto());
for method in service.methods() {
render_method(&service, &method, &comments, &mut out);
}
}
Ok(())
}
/// Writes the two Python wrapper methods for a single RPC method to `out`.
///
/// The first wrapper, `<name>_raw`, forwards already-serialized bytes to
/// `_run_command` and returns the raw reply. The second, `<name>`, takes
/// either the input message or its destructured fields (see
/// `maybe_destructured_input`), serializes it, runs the command, parses the
/// reply and returns either the whole output message or its only field (see
/// `maybe_destructured_output`). Any comment found for the method is inserted
/// into both wrappers as a docstring.
///
/// Generates text like the following:
///
/// ```text
/// def get_field_names_raw(self, message: bytes) -> bytes:
///     return self._run_command(7, 16, message)
///
/// def get_field_names(self, ntid: int) -> Sequence[str]:
///     message = anki.notetypes_pb2.NotetypeId(ntid=ntid)
///     raw_bytes = self._run_command(7, 16, message.SerializeToString())
///     output = anki.generic_pb2.StringList()
///     output.ParseFromString(raw_bytes)
///     return output.vals
/// ```
///
/// # Panics
///
/// Panics if writing to `out` fails.
fn render_method(
service: &ServiceDescriptor,
method: &MethodDescriptor,
comments: &Comments,
out: &mut impl Write,
) {
// Python method names are the snake_case form of the RPC name.
let method_name = method.name().to_snake_case();
let input = method.input();
let output = method.output();
// These two indices are the numeric arguments handed to `_run_command`.
let service_idx = service.index();
let method_idx = method.index();
// Shadows the lookup table with the formatted docstring for this method
// (an empty string when the method has no comment).
let comments = format_comments(comments.get_for_path(method.path()));
// NOTE(review): the generated Python is indentation-sensitive, so the leading
// whitespace inside the raw-string templates below is significant - confirm
// it has not been collapsed by a formatter or a copy/paste.
// raw bytes
write!(
out,
r#" def {method_name}_raw(self, message: bytes) -> bytes:
{comments}return self._run_command({service_idx}, {method_idx}, message)
"#
)
.unwrap();
// (possibly destructured) message
let (input_params, input_assign) = maybe_destructured_input(&input);
// The reply is always parsed into the full output message type, even when
// only a single field of it ends up being returned.
let output_constructor = full_name_to_python(output.full_name());
let (output_msg_or_single_field, output_type) = maybe_destructured_output(&output);
write!(
out,
r#" def {method_name}({input_params}) -> {output_type}:
{comments}{input_assign}
raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString())
output = {output_constructor}()
output.ParseFromString(raw_bytes)
return {output_msg_or_single_field}
"#
)
.unwrap();
}
/// Turns an optional comment into a Python docstring line.
///
/// Returns the comment wrapped in triple quotes and followed by a newline, or
/// an empty string when there is no comment. The comment text is inserted
/// verbatim.
fn format_comments(comments: Option<&str>) -> String {
    match comments {
        // Whatever follows the newline in this template becomes the prefix of
        // the next generated line.
        Some(c) => format!(
            r#""""{c}"""
"#
        ),
        None => String::new(),
    }
}
/// Decides how the generated Python method receives its input.
///
/// The input message is destructured into one argument per field when both of
/// the following hold:
/// - its name ends in `Request`, or it has fewer than two fields
/// - it contains no oneofs
///
/// Otherwise the method takes the input message object directly.
///
/// NOTE(review): proto3 `optional` fields are presumably reported as
/// (synthetic) oneofs and so would also disable destructuring - confirm.
///
/// Returns `(params_line, assignment_line)`. The assignment is empty when the
/// message is passed through as-is.
fn maybe_destructured_input(input: &MessageDescriptor) -> (String, String) {
    let eligible = input.name().ends_with("Request") || input.fields().len() < 2;
    let destructure = eligible && input.oneofs().next().is_none();

    if !destructure {
        // The caller supplies a ready-made message, so nothing needs assigning.
        let params = format!("self, message: {}", full_name_to_python(input.full_name()));
        return (params, String::new());
    }

    // One parameter per field, plus a line that rebuilds the message from them.
    let params = build_method_arguments(input);
    let constructor = full_name_to_python(input.full_name());
    let kwargs = build_input_message_arguments(input);
    (params, format!("message = {constructor}({kwargs})"))
}
/// Builds the Python parameter list for a destructured input message, e.g.
/// `self, *, note_ids: Sequence[int], new_fields: Sequence[int]`.
///
/// A bare `*` is inserted after `self` when the message has two or more
/// fields; a message with fewer fields gets no `*`.
fn build_method_arguments(input: &MessageDescriptor) -> String {
    let star = (input.fields().len() >= 2).then(|| "*".to_string());
    let typed_fields = input
        .fields()
        .map(|field| format!("{}: {}", field.name(), python_type(&field)));

    std::iter::once("self".to_string())
        .chain(star)
        .chain(typed_fields)
        .collect::<Vec<_>>()
        .join(", ")
}
/// Builds the keyword arguments that forward each destructured parameter to
/// the message constructor, e.g. `note_ids=note_ids, new_fields=new_fields`.
fn build_input_message_arguments(input: &MessageDescriptor) -> String {
    let mut kwargs = Vec::with_capacity(input.fields().len());
    for field in input.fields() {
        // Parameter and message field share the same name.
        kwargs.push(format!("{0}={0}", field.name()));
    }
    kwargs.join(", ")
}
/// Decides what the generated Python method returns.
///
/// When the output message has exactly one field and that field is not an
/// enum, the field's value is returned directly; otherwise the whole message
/// is returned.
///
/// Returns `(expr, type)`, where `expr` is `output` or `output.<only_field>`
/// and `type` is the matching Python type annotation.
fn maybe_destructured_output(output: &MessageDescriptor) -> (String, String) {
    let mut fields = output.fields();
    // Looking at the first two fields is enough to tell "exactly one" apart
    // from "none" and "several".
    match (fields.next(), fields.next()) {
        (Some(only), None) if !matches!(only.kind(), Kind::Enum(_)) => {
            (format!("output.{}", only.name()), python_type(&only))
        }
        _ => ("output".into(), full_name_to_python(output.full_name())),
    }
}
/// Maps a protobuf field to the Python type annotation used in the generated
/// code, e.g. uint32 -> `int`; repeated bool -> `Sequence[bool]`.
///
/// - every integer kind becomes `int`, float/double become `float`
/// - messages use their `_pb2` path, enums that path with `.V` appended
/// - list fields are wrapped in `Sequence[...]`
/// - map fields become `Mapping[<key>, <value>]`, built from the first two
///   fields of the field's message kind
fn python_type(field: &FieldDescriptor) -> String {
    let element = match field.kind() {
        Kind::Bool => "bool".to_string(),
        Kind::String => "str".to_string(),
        Kind::Bytes => "bytes".to_string(),
        Kind::Float | Kind::Double => "float".to_string(),
        Kind::Message(msg) => full_name_to_python(msg.full_name()),
        Kind::Enum(en) => format!("{}.V", full_name_to_python(en.full_name())),
        Kind::Int32
        | Kind::Int64
        | Kind::Uint32
        | Kind::Uint64
        | Kind::Sint32
        | Kind::Sint64
        | Kind::Fixed32
        | Kind::Fixed64
        | Kind::Sfixed32
        | Kind::Sfixed64 => "int".to_string(),
    };

    if field.is_map() {
        // Annotate a map with the types of the first two fields of its
        // message kind.
        let entry = field.kind();
        let entry = entry.as_message().unwrap();
        let key_value: Vec<_> = entry.fields().map(|f| python_type(&f)).collect();
        return format!("Mapping[{}, {}]", key_value[0], key_value[1]);
    }
    if field.is_list() {
        return format!("Sequence[{element}]");
    }
    element
}
/// Converts a fully-qualified protobuf name into the matching Python path by
/// appending `_pb2` to its second component, e.g.
/// `anki.import_export.ImportResponse` ->
/// `anki.import_export_pb2.ImportResponse`.
///
/// Everything after the second dot is kept as-is, so nested names keep their
/// remaining dots.
///
/// # Panics
///
/// Panics if `name` has fewer than three dot-separated components.
fn full_name_to_python(name: &str) -> String {
    let mut parts = name.splitn(3, '.');
    let package = parts.next().unwrap();
    let module = parts.next().unwrap();
    let item = parts.next().unwrap();
    format!("{package}.{module}_pb2.{item}")
}
/// Writes the fixed preamble of the generated Python module to `out`: the
/// licence header, the "automatically generated" notice, the `anki.*_pb2`
/// imports and the opening of the `RustBackendGenerated` class with its
/// placeholder `_run_command` method.
///
/// # Errors
///
/// Returns an error if writing to `out` fails.
fn write_header(out: &mut impl Write) -> Result<()> {
    // Emitted verbatim; the text is Python source, so its whitespace matters.
    const HEADER: &[u8] = br#"# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; https://www.gnu.org/licenses/agpl.html
# pylint: skip-file
from __future__ import annotations
"""
THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT.
Please do not access methods on the backend directly - they may be changed
or removed at any time. Instead, please use the methods on the collection
instead. Eg, don't use col.backend.all_deck_config(), instead use
col.decks.all_config()
"""
from typing import *
import anki
import anki.backend_pb2
import anki.card_rendering_pb2
import anki.cards_pb2
import anki.collection_pb2
import anki.config_pb2
import anki.deckconfig_pb2
import anki.decks_pb2
import anki.i18n_pb2
import anki.image_occlusion_pb2
import anki.import_export_pb2
import anki.links_pb2
import anki.media_pb2
import anki.notes_pb2
import anki.notetypes_pb2
import anki.scheduler_pb2
import anki.search_pb2
import anki.stats_pb2
import anki.sync_pb2
import anki.tags_pb2
class RustBackendGenerated:
def _run_command(self, service: int, method: int, input: Any) -> bytes:
raise Exception("not implemented")
"#;

    out.write_all(HEADER)?;
    Ok(())
}