From 5426164ebf7bdc257a2f470215888c552568a765 Mon Sep 17 00:00:00 2001 From: RumovZ Date: Fri, 4 Mar 2022 09:43:27 +0100 Subject: [PATCH] Add regex tag search (#1707) --- rslib/src/search/builder.rs | 5 ++++- rslib/src/search/parser.rs | 36 ++++++++++++++++++++++++++++++++--- rslib/src/search/sqlwriter.rs | 21 ++++++++++++++------ rslib/src/search/writer.rs | 3 ++- rslib/src/storage/sqlite.rs | 21 ++++++++++++++++++++ 5 files changed, 75 insertions(+), 11 deletions(-) diff --git a/rslib/src/search/builder.rs b/rslib/src/search/builder.rs index 2c33e9871..de4560b3e 100644 --- a/rslib/src/search/builder.rs +++ b/rslib/src/search/builder.rs @@ -134,7 +134,10 @@ impl SearchNode { /// Construct [SearchNode] from an unescaped tag name. pub fn from_tag_name(name: &str) -> Self { - Self::Tag(escape_anki_wildcards_for_search_node(name)) + Self::Tag { + tag: escape_anki_wildcards_for_search_node(name), + is_re: false, + } } /// Construct [SearchNode] from an unescaped notetype name. diff --git a/rslib/src/search/parser.rs b/rslib/src/search/parser.rs index bc6997ef2..195b23ec8 100644 --- a/rslib/src/search/parser.rs +++ b/rslib/src/search/parser.rs @@ -64,7 +64,10 @@ pub enum SearchNode { days: u32, ease: RatingKind, }, - Tag(String), + Tag { + tag: String, + is_re: bool, + }, Duplicates { notetype_id: NotetypeId, text: String, @@ -311,7 +314,7 @@ fn search_node_for_text_with_argument<'a>( Ok(match key.to_ascii_lowercase().as_str() { "deck" => SearchNode::Deck(unescape(val)?), "note" => SearchNode::Notetype(unescape(val)?), - "tag" => SearchNode::Tag(unescape(val)?), + "tag" => parse_tag(val)?, "card" => parse_template(val)?, "flag" => parse_flag(val)?, "resched" => parse_resched(val)?, @@ -334,6 +337,20 @@ fn search_node_for_text_with_argument<'a>( }) } +fn parse_tag(s: &str) -> ParseResult<SearchNode> { + Ok(if let Some(re) = s.strip_prefix("re:") { + SearchNode::Tag { + tag: unescape_quotes(re), + is_re: true, + } + } else { + SearchNode::Tag { + tag: unescape(s)?, + is_re: false,
} + }) +} + fn parse_template(s: &str) -> ParseResult<SearchNode> { Ok(SearchNode::CardTemplate(match s.parse::<u16>() { Ok(n) => TemplateKind::Ordinal(n.max(1) - 1), @@ -820,7 +837,20 @@ mod test { ); assert_eq!(parse("note:basic")?, vec![Search(Notetype("basic".into()))]); - assert_eq!(parse("tag:hard")?, vec![Search(Tag("hard".into()))]); + assert_eq!( + parse("tag:hard")?, + vec![Search(Tag { + tag: "hard".into(), + is_re: false + })] + ); + assert_eq!( + parse(r"tag:re:\\")?, + vec![Search(Tag { + tag: r"\\".into(), + is_re: true + })] + ); assert_eq!( parse("nid:1237123712,2,3")?, vec![Search(NoteIds("1237123712,2,3".into()))] diff --git a/rslib/src/search/sqlwriter.rs b/rslib/src/search/sqlwriter.rs index c919538c2..c84584922 100644 --- a/rslib/src/search/sqlwriter.rs +++ b/rslib/src/search/sqlwriter.rs @@ -156,7 +156,7 @@ impl SqlWriter<'_> { SearchNode::Notetype(notetype) => self.write_notetype(&norm(notetype)), SearchNode::Rated { days, ease } => self.write_rated(">", -i64::from(*days), ease)?, - SearchNode::Tag(tag) => self.write_tag(&norm(tag)), + SearchNode::Tag { tag, is_re } => self.write_tag(&norm(tag), *is_re), SearchNode::State(state) => self.write_state(state)?, SearchNode::Flag(flag) => { write!(self.sql, "(c.flags & 7) == {}", flag).unwrap(); @@ -199,17 +199,19 @@ impl SqlWriter<'_> { .unwrap(); } - fn write_tag(&mut self, text: &str) { - if text.contains(' ') { - write!(self.sql, "false").unwrap(); + fn write_tag(&mut self, tag: &str, is_re: bool) { + if is_re { + self.args.push(format!("(?i){tag}")); + write!(self.sql, "regexp_tags(?{}, n.tags)", self.args.len()).unwrap(); } else { - match text { + match tag { "none" => { write!(self.sql, "n.tags = ''").unwrap(); } "*" => { write!(self.sql, "true").unwrap(); } + s if s.contains(' ') => write!(self.sql, "false").unwrap(), text => { write!(self.sql, "n.tags regexp ?").unwrap(); let re = &to_custom_re(text, r"\S"); @@ -660,7 +662,7 @@ impl SearchNode { SearchNode::UnqualifiedText(_) => RequiredTable::Notes,
SearchNode::SingleField { .. } => RequiredTable::Notes, - SearchNode::Tag(_) => RequiredTable::Notes, + SearchNode::Tag { .. } => RequiredTable::Notes, SearchNode::Duplicates { .. } => RequiredTable::Notes, SearchNode::Regex(_) => RequiredTable::Notes, SearchNode::NoCombining(_) => RequiredTable::Notes, @@ -848,6 +850,13 @@ mod test { ); assert_eq!(s(ctx, "tag:none"), ("(n.tags = '')".into(), vec![])); assert_eq!(s(ctx, "tag:*"), ("(true)".into(), vec![])); + assert_eq!( + s(ctx, "tag:re:.ne|tw."), + ( + "(regexp_tags(?1, n.tags))".into(), + vec!["(?i).ne|tw.".into()] + ) + ); // state assert_eq!( diff --git a/rslib/src/search/writer.rs b/rslib/src/search/writer.rs index 51f60f4f5..ce754026b 100644 --- a/rslib/src/search/writer.rs +++ b/rslib/src/search/writer.rs @@ -70,7 +70,7 @@ fn write_search_node(node: &SearchNode) -> String { NotetypeId(NotetypeIdType(i)) => format!("mid:{}", i), Notetype(s) => maybe_quote(&format!("note:{}", s)), Rated { days, ease } => write_rated(days, ease), - Tag(s) => maybe_quote(&format!("tag:{}", s)), + Tag { tag, is_re } => write_single_field("tag", tag, *is_re), Duplicates { notetype_id, text } => write_dupe(notetype_id, text), State(k) => write_state(k), Flag(u) => format!("flag:{}", u), @@ -102,6 +102,7 @@ fn needs_quotation(txt: &str) -> bool { RE.is_match(txt) } +/// Also used by tag search, which has the same syntax. 
fn write_single_field(field: &str, text: &str, is_re: bool) -> String { let re = if is_re { "re:" } else { "" }; let text = if !is_re && text.starts_with("re:") { diff --git a/rslib/src/storage/sqlite.rs b/rslib/src/storage/sqlite.rs index 64974cc22..41e90ed9c 100644 --- a/rslib/src/storage/sqlite.rs +++ b/rslib/src/storage/sqlite.rs @@ -52,6 +52,7 @@ fn open_or_create_collection_db(path: &Path) -> Result<Connection> { add_field_index_function(&db)?; add_regexp_function(&db)?; add_regexp_fields_function(&db)?; + add_regexp_tags_function(&db)?; add_without_combining_function(&db)?; add_fnvhash_function(&db)?; @@ -157,6 +158,26 @@ fn add_regexp_fields_function(db: &Connection) -> rusqlite::Result<()> { ) } +/// Adds sql function `regexp_tags(regex, tags) -> is_match`. +fn add_regexp_tags_function(db: &Connection) -> rusqlite::Result<()> { + db.create_scalar_function( + "regexp_tags", + 2, + FunctionFlags::SQLITE_DETERMINISTIC, + move |ctx| { + assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); + + let re: Arc<Regex> = ctx + .get_or_create_aux(0, |vr| -> std::result::Result<_, BoxError> { + Ok(Regex::new(vr.as_str()?)?) + })?; + let mut tags = ctx.get_raw(1).as_str()?.split(' '); + + Ok(tags.any(|tag| re.is_match(tag))) + }, + ) +} + /// Fetch schema version from database. /// Return (must_create, version) fn schema_version(db: &Connection) -> Result<(bool, u8)> {