// queries.rs (14 KB)
  1. use heed3::RoTxn;
  2. use get_routes::handler;
  3. use helix_db::{field_remapping, identifier_remapping, traversal_remapping, exclude_field, value_remapping};
  4. use helix_db::helix_engine::vector_core::vector::HVector;
  5. use helix_db::{
  6. helix_engine::graph_core::ops::{
  7. g::G,
  8. in_::{in_::InAdapter, in_e::InEdgesAdapter, to_n::ToNAdapter, to_v::ToVAdapter},
  9. out::{from_n::FromNAdapter, from_v::FromVAdapter, out::OutAdapter, out_e::OutEdgesAdapter},
  10. source::{
  11. add_e::{AddEAdapter, EdgeType},
  12. add_n::AddNAdapter,
  13. e_from_id::EFromIdAdapter,
  14. e_from_type::EFromTypeAdapter,
  15. n_from_id::NFromIdAdapter,
  16. n_from_type::NFromTypeAdapter,
  17. n_from_index::NFromIndexAdapter,
  18. },
  19. tr_val::{Traversable, TraversalVal},
  20. util::{
  21. dedup::DedupAdapter, filter_mut::FilterMut,
  22. filter_ref::FilterRefAdapter, range::RangeAdapter, update::UpdateAdapter,
  23. map::MapAdapter, paths::ShortestPathAdapter, props::PropsAdapter, drop::Drop,
  24. },
  25. vectors::{insert::InsertVAdapter, search::SearchVAdapter, brute_force_search::BruteForceSearchVAdapter},
  26. bm25::search_bm25::SearchBM25Adapter,
  27. },
  28. helix_engine::types::GraphError,
  29. helix_gateway::router::router::HandlerInput,
  30. node_matches, props,
  31. protocol::count::Count,
  32. protocol::remapping::{RemappingMap, ResponseRemapping},
  33. protocol::response::Response,
  34. protocol::{
  35. filterable::Filterable, remapping::Remapping, return_values::ReturnValue, value::Value, id::ID,
  36. },
  37. };
  38. use sonic_rs::{Deserialize, Serialize};
  39. use std::collections::{HashMap, HashSet};
  40. use std::sync::Arc;
  41. use std::time::Instant;
  42. use std::cell::RefCell;
  43. use chrono::{DateTime, Utc};
  /// Schema type for a `Company` node.
  ///
  /// The handlers in this file store these two fields as node properties and look
  /// companies up through the `company_number` secondary index.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct Company {
      /// Identifier used as the `company_number` index key.
      pub company_number: String,
      /// Stored as the `number_of_filings` node property.
      pub number_of_filings: i32,
  }

  /// Schema type for a `DocumentEdge` edge running from a [`Company`] to a
  /// [`DocumentEmbedding`].
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct DocumentEdge {
      /// Source endpoint of the edge.
      pub from: Company,
      /// Target endpoint of the edge.
      pub to: DocumentEmbedding,
      /// Stored as the `filing_id` edge property.
      pub filing_id: String,
      /// Stored as the `category` edge property.
      pub category: String,
      /// Stored as the `subcategory` edge property.
      pub subcategory: String,
      /// Stored as the `date` edge property.
      pub date: String,
      /// Stored as the `description` edge property.
      pub description: String,
  }

  /// Schema type for a `DocumentEmbedding` vector and its properties.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct DocumentEmbedding {
      /// Stored as the `text` property.
      pub text: String,
      /// Stored as the `chunk_id` property.
      pub chunk_id: String,
      /// Stored as the `page_number` property.
      ///
      /// NOTE(review): declared as `u16` here, while the request payloads in this file
      /// carry `page_number` as `i32` -- confirm which width is intended.
      pub page_number: u16,
      /// Stored as the `reference` property.
      pub reference: String,
      /// Stored as the `source_link` property.
      pub source_link: String,
      /// Stored as the `source_date` property.
      pub source_date: String,
  }
  65. #[derive(Serialize, Deserialize)]
  66. pub struct GetAllCompanyEmbeddingsInput {
  67. pub company_number: String
  68. }
  69. #[handler]
  70. pub fn GetAllCompanyEmbeddings (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  71. let data: GetAllCompanyEmbeddingsInput = match sonic_rs::from_slice(&input.request.body) {
  72. Ok(data) => data,
  73. Err(err) => return Err(GraphError::from(err)),
  74. };
  75. let mut remapping_vals = RemappingMap::new();
  76. let db = Arc::clone(&input.graph.storage);
  77. let txn = db.graph_env.read_txn().unwrap();
  78. let c = G::new(Arc::clone(&db), &txn)
  79. .n_from_index("company_number", &data.company_number).collect_to::<Vec<_>>();
  80. let embeddings = G::new_from(Arc::clone(&db), &txn, c.clone())
  81. .out("DocumentEdge",&EdgeType::Vec).collect_to::<Vec<_>>();
  82. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  83. return_vals.insert("embeddings".to_string(), ReturnValue::from_traversal_value_array_with_mixin(embeddings.clone(), remapping_vals.borrow_mut()));
  84. txn.commit().unwrap();
  85. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  86. Ok(())
  87. }
  88. #[derive(Serialize, Deserialize)]
  89. pub struct AddVectorInput {
  90. pub vector: Vec<f64>,
  91. pub text: String,
  92. pub chunk_id: String,
  93. pub page_number: i32,
  94. pub reference: String
  95. }
  96. #[handler]
  97. pub fn AddVector (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  98. let data: AddVectorInput = match sonic_rs::from_slice(&input.request.body) {
  99. Ok(data) => data,
  100. Err(err) => return Err(GraphError::from(err)),
  101. };
  102. let mut remapping_vals = RemappingMap::new();
  103. let db = Arc::clone(&input.graph.storage);
  104. let mut txn = db.graph_env.write_txn().unwrap();
  105. let embedding = G::new_mut(Arc::clone(&db), &mut txn)
  106. .insert_v::<fn(&HVector, &RoTxn) -> bool>(&data.vector, "DocumentEmbedding", Some(props! { "text" => data.text, "page_number" => data.page_number, "chunk_id" => data.chunk_id, "reference" => data.reference })).collect_to::<Vec<_>>();
  107. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  108. return_vals.insert("embedding".to_string(), ReturnValue::from_traversal_value_array_with_mixin(embedding.clone(), remapping_vals.borrow_mut()));
  109. txn.commit().unwrap();
  110. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  111. Ok(())
  112. }
  113. #[derive(Serialize, Deserialize)]
  114. pub struct AddCompanyInput {
  115. pub company_number: String,
  116. pub number_of_filings: i32
  117. }
  118. #[handler]
  119. pub fn AddCompany (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  120. let data: AddCompanyInput = match sonic_rs::from_slice(&input.request.body) {
  121. Ok(data) => data,
  122. Err(err) => return Err(GraphError::from(err)),
  123. };
  124. let mut remapping_vals = RemappingMap::new();
  125. let db = Arc::clone(&input.graph.storage);
  126. let mut txn = db.graph_env.write_txn().unwrap();
  127. let company = G::new_mut(Arc::clone(&db), &mut txn)
  128. .add_n("Company", Some(props! { "number_of_filings" => data.number_of_filings.clone(), "company_number" => data.company_number.clone() }), Some(&["company_number"])).collect_to::<Vec<_>>();
  129. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  130. return_vals.insert("company".to_string(), ReturnValue::from_traversal_value_array_with_mixin(company.clone(), remapping_vals.borrow_mut()));
  131. txn.commit().unwrap();
  132. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  133. Ok(())
  134. }
  135. #[handler]
  136. pub fn DeleteAll (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  137. let mut remapping_vals = RemappingMap::new();
  138. let db = Arc::clone(&input.graph.storage);
  139. let mut txn = db.graph_env.write_txn().unwrap();
  140. Drop::<Vec<_>>::drop_traversal(
  141. G::new(Arc::clone(&db), &txn)
  142. .n_from_type("Company").collect::<Vec<_>>(),
  143. Arc::clone(&db),
  144. &mut txn,
  145. )?;;
  146. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  147. return_vals.insert("success".to_string(), ReturnValue::from(Value::from("success")));
  148. txn.commit().unwrap();
  149. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  150. Ok(())
  151. }
  152. #[derive(Serialize, Deserialize)]
  153. pub struct HasCompanyInput {
  154. pub company_number: String
  155. }
  156. #[handler]
  157. pub fn HasCompany (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  158. let data: HasCompanyInput = match sonic_rs::from_slice(&input.request.body) {
  159. Ok(data) => data,
  160. Err(err) => return Err(GraphError::from(err)),
  161. };
  162. let mut remapping_vals = RemappingMap::new();
  163. let db = Arc::clone(&input.graph.storage);
  164. let txn = db.graph_env.read_txn().unwrap();
  165. let company = G::new(Arc::clone(&db), &txn)
  166. .n_from_index("company_number", &data.company_number).collect_to::<Vec<_>>();
  167. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  168. return_vals.insert("company".to_string(), ReturnValue::from_traversal_value_array_with_mixin(company.clone(), remapping_vals.borrow_mut()));
  169. txn.commit().unwrap();
  170. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  171. Ok(())
  172. }
  173. #[derive(Serialize, Deserialize)]
  174. pub struct DeleteCompanyInput {
  175. pub company_number: String
  176. }
  177. #[handler]
  178. pub fn DeleteCompany (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  179. let data: DeleteCompanyInput = match sonic_rs::from_slice(&input.request.body) {
  180. Ok(data) => data,
  181. Err(err) => return Err(GraphError::from(err)),
  182. };
  183. let mut remapping_vals = RemappingMap::new();
  184. let db = Arc::clone(&input.graph.storage);
  185. let mut txn = db.graph_env.write_txn().unwrap();
  186. Drop::<Vec<_>>::drop_traversal(
  187. G::new(Arc::clone(&db), &txn)
  188. .n_from_index("company_number", &data.company_number)
  189. .out("DocumentEdge",&EdgeType::Vec).collect::<Vec<_>>(),
  190. Arc::clone(&db),
  191. &mut txn,
  192. )?;;
  193. Drop::<Vec<_>>::drop_traversal(
  194. G::new(Arc::clone(&db), &txn)
  195. .n_from_index("company_number", &data.company_number).collect::<Vec<_>>(),
  196. Arc::clone(&db),
  197. &mut txn,
  198. )?;;
  199. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  200. return_vals.insert("success".to_string(), ReturnValue::from(Value::from("success")));
  201. txn.commit().unwrap();
  202. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  203. Ok(())
  204. }
  205. #[derive(Serialize, Deserialize)]
  206. pub struct HasDocumentEmbeddingsInput {
  207. pub company_number: String
  208. }
  209. #[handler]
  210. pub fn HasDocumentEmbeddings (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  211. let data: HasDocumentEmbeddingsInput = match sonic_rs::from_slice(&input.request.body) {
  212. Ok(data) => data,
  213. Err(err) => return Err(GraphError::from(err)),
  214. };
  215. let mut remapping_vals = RemappingMap::new();
  216. let db = Arc::clone(&input.graph.storage);
  217. let txn = db.graph_env.read_txn().unwrap();
  218. let c = G::new(Arc::clone(&db), &txn)
  219. .n_from_index("company_number", &data.company_number).collect_to::<Vec<_>>();
  220. let embeddings = G::new_from(Arc::clone(&db), &txn, c.clone())
  221. .out("DocumentEdge",&EdgeType::Vec).collect_to::<Vec<_>>();
  222. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  223. return_vals.insert("embeddings".to_string(), ReturnValue::from_traversal_value_array_with_mixin(embeddings.clone(), remapping_vals.borrow_mut()));
  224. txn.commit().unwrap();
  225. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  226. Ok(())
  227. }
  228. #[handler]
  229. pub fn GetCompanies (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  230. let mut remapping_vals = RemappingMap::new();
  231. let db = Arc::clone(&input.graph.storage);
  232. let txn = db.graph_env.read_txn().unwrap();
  233. let companies = G::new(Arc::clone(&db), &txn)
  234. .n_from_type("Company").collect_to::<Vec<_>>();
  235. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  236. return_vals.insert("companies".to_string(), ReturnValue::from_traversal_value_array_with_mixin(companies.clone(), remapping_vals.borrow_mut()));
  237. txn.commit().unwrap();
  238. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  239. Ok(())
  240. }
  241. #[derive(Serialize, Deserialize)]
  242. pub struct embeddings_dataData {
  243. pub category: String,
  244. pub subcategory: String,
  245. pub reference: String,
  246. pub date1: String,
  247. pub source: String,
  248. pub chunk_id: String,
  249. pub description: String,
  250. pub filing_id: String,
  251. pub vector: Vec<f64>,
  252. pub page_number: i32,
  253. pub date2: String,
  254. pub text: String,
  255. }
  256. #[derive(Serialize, Deserialize)]
  257. pub struct AddEmbeddingsToCompanyInput {
  258. pub company_number: String,
  259. pub embeddings_data: Vec<embeddings_dataData>
  260. }
  261. #[handler]
  262. pub fn AddEmbeddingsToCompany (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  263. let data: AddEmbeddingsToCompanyInput = match sonic_rs::from_slice(&input.request.body) {
  264. Ok(data) => data,
  265. Err(err) => return Err(GraphError::from(err)),
  266. };
  267. let mut remapping_vals = RemappingMap::new();
  268. let db = Arc::clone(&input.graph.storage);
  269. let mut txn = db.graph_env.write_txn().unwrap();
  270. let c = G::new(Arc::clone(&db), &txn)
  271. .n_from_index("company_number", &data.company_number).collect_to::<Vec<_>>();
  272. for data in data.embeddings_data {
  273. let embedding = G::new_mut(Arc::clone(&db), &mut txn)
  274. .insert_v::<fn(&HVector, &RoTxn) -> bool>(&data.vector, "DocumentEmbedding", Some(props! { "source_date" => data.date1, "source_link" => data.source, "page_number" => data.page_number, "reference" => data.reference, "text" => data.text, "chunk_id" => data.chunk_id })).collect_to::<Vec<_>>();
  275. let edges = G::new_mut(Arc::clone(&db), &mut txn)
  276. .add_e("DocumentEdge", Some(props! { "filing_id" => data.filing_id.clone(), "date" => data.date2.clone(), "subcategory" => data.subcategory.clone(), "category" => data.category.clone(), "description" => data.description.clone() }), c.id(), embedding.id(), true, EdgeType::Node).collect_to::<Vec<_>>();
  277. }
  278. ;
  279. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  280. return_vals.insert("success".to_string(), ReturnValue::from(Value::from("success")));
  281. txn.commit().unwrap();
  282. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  283. Ok(())
  284. }
  285. #[derive(Serialize, Deserialize)]
  286. pub struct SearchVectorInput {
  287. pub query: Vec<f64>,
  288. pub k: i32
  289. }
  290. #[handler]
  291. pub fn SearchVector (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  292. let data: SearchVectorInput = match sonic_rs::from_slice(&input.request.body) {
  293. Ok(data) => data,
  294. Err(err) => return Err(GraphError::from(err)),
  295. };
  296. let mut remapping_vals = RemappingMap::new();
  297. let db = Arc::clone(&input.graph.storage);
  298. let txn = db.graph_env.read_txn().unwrap();
  299. let embedding_search = G::new(Arc::clone(&db), &txn)
  300. .search_v::<fn(&HVector, &RoTxn) -> bool>(&data.query, data.k as usize, None).collect_to::<Vec<_>>();
  301. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  302. return_vals.insert("embedding_search".to_string(), ReturnValue::from_traversal_value_array_with_mixin(embedding_search.clone(), remapping_vals.borrow_mut()));
  303. txn.commit().unwrap();
  304. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  305. Ok(())
  306. }
  307. #[derive(Serialize, Deserialize)]
  308. pub struct CompanyEmbeddingSearchInput {
  309. pub company_number: String,
  310. pub query: Vec<f64>,
  311. pub k: i32
  312. }
  313. #[handler]
  314. pub fn CompanyEmbeddingSearch (input: &HandlerInput, response: &mut Response) -> Result<(), GraphError> {
  315. let data: CompanyEmbeddingSearchInput = match sonic_rs::from_slice(&input.request.body) {
  316. Ok(data) => data,
  317. Err(err) => return Err(GraphError::from(err)),
  318. };
  319. let mut remapping_vals = RemappingMap::new();
  320. let db = Arc::clone(&input.graph.storage);
  321. let txn = db.graph_env.read_txn().unwrap();
  322. let c = G::new(Arc::clone(&db), &txn)
  323. .n_from_index("company_number", &data.company_number)
  324. .out_e("DocumentEdge")
  325. .to_v().collect_to::<Vec<_>>();
  326. let embedding_search = G::new_from(Arc::clone(&db), &txn, c.clone())
  327. .brute_force_search_v(&data.query, data.k as usize).collect_to::<Vec<_>>();
  328. let mut return_vals: HashMap<String, ReturnValue> = HashMap::new();
  329. return_vals.insert("embedding_search".to_string(), ReturnValue::from_traversal_value_array_with_mixin(embedding_search.clone(), remapping_vals.borrow_mut()));
  330. txn.commit().unwrap();
  331. response.body = sonic_rs::to_vec(&return_vals).unwrap();
  332. Ok(())
  333. }