diff --git a/rslib/src/search/sqlwriter.rs b/rslib/src/search/sqlwriter.rs index 920afdba0..7f3d1821a 100644 --- a/rslib/src/search/sqlwriter.rs +++ b/rslib/src/search/sqlwriter.rs @@ -3,6 +3,8 @@ use std::{borrow::Cow, fmt::Write}; +use itertools::Itertools; + use super::{ parser::{Node, PropertyKind, RatingKind, SearchNode, StateKind, TemplateKind}, ReturnItemType, @@ -16,7 +18,7 @@ use crate::{ prelude::*, storage::ids_to_string, text::{ - is_glob, matches_glob, normalize_to_nfc, strip_html_preserving_media_filenames, + glob_matcher, is_glob, normalize_to_nfc, strip_html_preserving_media_filenames, to_custom_re, to_re, to_sql, to_text, without_combining, }, timestamp::TimestampSecs, @@ -117,7 +119,7 @@ impl SqlWriter<'_> { // note fields related SearchNode::UnqualifiedText(text) => self.write_unqualified(&self.norm_note(text)), SearchNode::SingleField { field, text, is_re } => { - self.write_single_field(&norm(field), &self.norm_note(text), *is_re)? + self.write_field(&norm(field), &self.norm_note(text), *is_re)? } SearchNode::Duplicates { notetype_id, text } => { self.write_dupe(*notetype_id, &self.norm_note(text))? 
@@ -419,55 +421,103 @@ impl SqlWriter<'_> { } } - fn write_single_field(&mut self, field_name: &str, val: &str, is_re: bool) -> Result<()> { + fn write_field(&mut self, field_name: &str, val: &str, is_re: bool) -> Result<()> { + if matches!(field_name, "*" | "_*" | "*_") { + if is_re { + self.write_all_fields_regexp(val); + } else { + self.write_all_fields(val); + } + Ok(()) + } else if is_re { + self.write_single_field_regexp(field_name, val) + } else { + self.write_single_field(field_name, val) + } + } + + fn write_all_fields_regexp(&mut self, val: &str) { + self.args.push(format!("(?i){}", val)); + write!(self.sql, "regexp_fields(?{}, n.flds)", self.args.len()).unwrap(); + } + + fn write_all_fields(&mut self, val: &str) { + self.args.push(format!("(?i)^{}$", to_re(val))); + write!(self.sql, "regexp_fields(?{}, n.flds)", self.args.len()).unwrap(); + } + + fn write_single_field_regexp(&mut self, field_name: &str, val: &str) -> Result<()> { + let field_indicies_by_notetype = self.fields_indices_by_notetype(field_name)?; + if field_indicies_by_notetype.is_empty() { + write!(self.sql, "false").unwrap(); + return Ok(()); + } + + self.args.push(format!("(?i){}", val)); + let arg_idx = self.args.len(); + + let all_notetype_clauses = field_indicies_by_notetype + .iter() + .map(|(mid, field_indices)| { + let field_index_list = field_indices.iter().join(", "); + format!("(n.mid = {mid} and regexp_fields(?{arg_idx}, n.flds, {field_index_list}))") + }) + .join(" or "); + + write!(self.sql, "({all_notetype_clauses})").unwrap(); + + Ok(()) + } + + fn write_single_field(&mut self, field_name: &str, val: &str) -> Result<()> { + let field_indicies_by_notetype = self.fields_indices_by_notetype(field_name)?; + if field_indicies_by_notetype.is_empty() { + write!(self.sql, "false").unwrap(); + return Ok(()); + } + + self.args.push(to_sql(val).into()); + let arg_idx = self.args.len(); + + let notetype_clause = |(mid, fields): &(NotetypeId, Vec<u32>)| -> String { + let field_index_clause = 
+ |ord| format!("field_at_index(n.flds, {ord}) like ?{arg_idx} escape '\\'",); + let all_field_clauses = fields.iter().map(field_index_clause).join(" or "); + format!("(n.mid = {mid} and ({all_field_clauses}))",) + }; + let all_notetype_clauses = field_indicies_by_notetype + .iter() + .map(notetype_clause) + .join(" or "); + write!(self.sql, "({all_notetype_clauses})").unwrap(); + + Ok(()) + } + + fn fields_indices_by_notetype( + &mut self, + field_name: &str, + ) -> Result<Vec<(NotetypeId, Vec<u32>)>> { let notetypes = self.col.get_all_notetypes()?; + let matches_glob = glob_matcher(field_name); let mut field_map = vec![]; for nt in notetypes.values() { + let mut matched_fields = vec![]; for field in &nt.fields { - if matches_glob(&field.name, field_name) { - field_map.push((nt.id, field.ord)); + if matches_glob(&field.name) { + matched_fields.push(field.ord.unwrap_or_default()); } } + if !matched_fields.is_empty() { + field_map.push((nt.id, matched_fields)); + } } // for now, sort the map for the benefit of unit tests field_map.sort(); - if field_map.is_empty() { - write!(self.sql, "false").unwrap(); - return Ok(()); - } - - let cmp; - let cmp_trailer; - if is_re { - cmp = "regexp"; - cmp_trailer = ""; - self.args.push(format!("(?i){}", val)); - } else { - cmp = "like"; - cmp_trailer = "escape '\\'"; - self.args.push(to_sql(val).into()) - } - - let arg_idx = self.args.len(); - let searches: Vec<_> = field_map - .iter() - .map(|(ntid, ord)| { - format!( - "(n.mid = {mid} and field_at_index(n.flds, {ord}) {cmp} ?{n} {cmp_trailer})", - mid = ntid, - ord = ord.unwrap_or_default(), - cmp = cmp, - cmp_trailer = cmp_trailer, - n = arg_idx - ) - }) - .collect(); - write!(self.sql, "({})", searches.join(" or ")).unwrap(); - - Ok(()) + Ok(field_map) } fn write_dupe(&mut self, ntid: NotetypeId, text: &str) -> Result<()> { @@ -649,20 +699,50 @@ mod test { // user should be able to escape wildcards assert_eq!(s(ctx, r#"te\*s\_t"#).1, vec!["%te*s\\_t%".to_string()]); - // qualified search + // field 
search assert_eq!( s(ctx, "front:te*st"), ( concat!( - "(((n.mid = 1581236385344 and field_at_index(n.flds, 0) like ?1 escape '\\') or ", - "(n.mid = 1581236385345 and field_at_index(n.flds, 0) like ?1 escape '\\') or ", - "(n.mid = 1581236385346 and field_at_index(n.flds, 0) like ?1 escape '\\') or ", - "(n.mid = 1581236385347 and field_at_index(n.flds, 0) like ?1 escape '\\')))" + "(((n.mid = 1581236385344 and (field_at_index(n.flds, 0) like ?1 escape '\\')) or ", + "(n.mid = 1581236385345 and (field_at_index(n.flds, 0) like ?1 escape '\\')) or ", + "(n.mid = 1581236385346 and (field_at_index(n.flds, 0) like ?1 escape '\\')) or ", + "(n.mid = 1581236385347 and (field_at_index(n.flds, 0) like ?1 escape '\\'))))" ) .into(), vec!["te%st".into()] ) ); + // field search with regex + assert_eq!( + s(ctx, "front:re:te.*st"), + ( + concat!( + "(((n.mid = 1581236385344 and regexp_fields(?1, n.flds, 0)) or ", + "(n.mid = 1581236385345 and regexp_fields(?1, n.flds, 0)) or ", + "(n.mid = 1581236385346 and regexp_fields(?1, n.flds, 0)) or ", + "(n.mid = 1581236385347 and regexp_fields(?1, n.flds, 0))))" + ) + .into(), + vec!["(?i)te.*st".into()] + ) + ); + // all field search + assert_eq!( + s(ctx, "*:te*st"), + ( + "(regexp_fields(?1, n.flds))".into(), + vec!["(?i)^te.*st$".into()] + ) + ); + // all field search with regex + assert_eq!( + s(ctx, "*:re:te.*st"), + ( + "(regexp_fields(?1, n.flds))".into(), + vec!["(?i)te.*st".into()] + ) + ); // added let timing = ctx.timing_today().unwrap(); diff --git a/rslib/src/storage/sqlite.rs b/rslib/src/storage/sqlite.rs index 31f01525c..64974cc22 100644 --- a/rslib/src/storage/sqlite.rs +++ b/rslib/src/storage/sqlite.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::{borrow::Cow, cmp::Ordering, hash::Hasher, path::Path, sync::Arc}; +use std::{borrow::Cow, cmp::Ordering, collections::HashSet, hash::Hasher, path::Path, sync::Arc}; 
use fnv::FnvHasher; use regex::Regex; @@ -51,6 +51,7 @@ fn open_or_create_collection_db(path: &Path) -> Result<Connection> { add_field_index_function(&db)?; add_regexp_function(&db)?; + add_regexp_fields_function(&db)?; add_without_combining_function(&db)?; add_fnvhash_function(&db)?; @@ -130,6 +131,32 @@ fn add_regexp_function(db: &Connection) -> rusqlite::Result<()> { ) } +/// Adds sql function `regexp_fields(regex, note_flds, indices...) -> is_match`. +/// If no indices are provided, all fields are matched against. +fn add_regexp_fields_function(db: &Connection) -> rusqlite::Result<()> { + db.create_scalar_function( + "regexp_fields", + -1, + FunctionFlags::SQLITE_DETERMINISTIC, + move |ctx| { + assert!(ctx.len() > 1, "not enough arguments"); + + let re: Arc<Regex> = ctx + .get_or_create_aux(0, |vr| -> std::result::Result<_, BoxError> { + Ok(Regex::new(vr.as_str()?)?) + })?; + let fields = ctx.get_raw(1).as_str()?.split('\x1f'); + let indices: HashSet<usize> = (2..ctx.len()) + .map(|i| ctx.get(i)) + .collect::<rusqlite::Result<_>>()?; + + Ok(fields.enumerate().any(|(idx, field)| { + (indices.is_empty() || indices.contains(&idx)) && re.is_match(field) + })) + }, + ) +} + /// Fetch schema version from database. /// Return (must_create, version) fn schema_version(db: &Connection) -> Result<(bool, u8)> { diff --git a/rslib/src/text.rs b/rslib/src/text.rs index ff2d6ed54..c285f87f4 100644 --- a/rslib/src/text.rs +++ b/rslib/src/text.rs @@ -355,13 +355,23 @@ pub(crate) fn escape_anki_wildcards_for_search_node(txt: &str) -> String { } } -/// Compare text with a possible glob, folding case. -pub(crate) fn matches_glob(text: &str, search: &str) -> bool { +/// Return a function to match input against `search`, +/// which may contain wildcards. 
+pub(crate) fn glob_matcher(search: &str) -> impl Fn(&str) -> bool + '_ {
+    let mut regex = None;
+    let mut cow = None;
     if is_glob(search) {
-        let search = format!("^(?i){}$", to_re(search));
-        Regex::new(&search).unwrap().is_match(text)
+        regex = Some(Regex::new(&format!("^(?i){}$", to_re(search))).unwrap());
     } else {
-        uni_eq(text, &to_text(search))
+        cow = Some(to_text(search));
+    }
+
+    move |text| {
+        if let Some(r) = &regex {
+            r.is_match(text)
+        } else {
+            uni_eq(text, cow.as_ref().unwrap())
+        }
     }
 }
 
@@ -451,6 +461,6 @@ mod test {
         assert_eq!(&to_text(r"\*\_*_"), "*_*_");
         assert!(is_glob(r"\\\\_"));
         assert!(!is_glob(r"\\\_"));
-        assert!(matches_glob("foo*bar123", r"foo\*bar*"));
+        assert!(glob_matcher(r"foo\*bar*")("foo*bar123"));
     }
 }