1use std::collections::HashMap;
5use std::convert::Infallible;
6use std::hash::Hash;
7use std::sync::{Arc, RwLock};
8
9use async_trait::async_trait;
10use futures::stream::{iter, StreamExt};
11
12use crate::{event, message, version};
13
14pub trait Streamer<StreamId, Event>: Send + Sync
17where
18 StreamId: Send + Sync,
19 Event: message::Message + Send + Sync,
20{
21 type Error: Send + Sync;
23
24 fn stream(
27 &self,
28 id: &StreamId,
29 select: event::VersionSelect,
30 ) -> event::Stream<StreamId, Event, Self::Error>;
31}
32
33#[derive(Debug, thiserror::Error)]
35pub enum AppendError {
36 #[error("failed to append new domain events: {0}")]
39 Conflict(#[from] version::ConflictError),
40 #[error("failed to append new domain events, an error occurred: {0}")]
42 Internal(#[from] anyhow::Error),
43}
44
45#[async_trait]
46pub trait Appender<StreamId, Event>: Send + Sync
48where
49 StreamId: Send + Sync,
50 Event: message::Message + Send + Sync,
51{
52 async fn append(
57 &self,
58 id: StreamId,
59 version_check: version::Check,
60 events: Vec<event::Envelope<Event>>,
61 ) -> Result<version::Version, AppendError>;
62}
63
64pub trait Store<StreamId, Event>:
69 Streamer<StreamId, Event> + Appender<StreamId, Event> + Send + Sync
70where
71 StreamId: Send + Sync,
72 Event: message::Message + Send + Sync,
73{
74}
75
76impl<T, StreamId, Event> Store<StreamId, Event> for T
77where
78 T: Streamer<StreamId, Event> + Appender<StreamId, Event> + Send + Sync,
79 StreamId: Send + Sync,
80 Event: message::Message + Send + Sync,
81{
82}
83
84#[derive(Debug)]
85struct InMemoryBackend<Id, Evt>
86where
87 Evt: message::Message,
88{
89 event_streams: HashMap<Id, Vec<event::Persisted<Id, Evt>>>,
90}
91
92impl<Id, Evt> Default for InMemoryBackend<Id, Evt>
93where
94 Evt: message::Message,
95{
96 fn default() -> Self {
97 Self {
98 event_streams: HashMap::default(),
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
106pub struct InMemory<Id, Evt>
107where
108 Evt: message::Message,
109{
110 backend: Arc<RwLock<InMemoryBackend<Id, Evt>>>,
111}
112
113impl<Id, Evt> Default for InMemory<Id, Evt>
114where
115 Evt: message::Message,
116{
117 fn default() -> Self {
118 Self {
119 backend: Arc::default(),
120 }
121 }
122}
123
124impl<Id, Evt> Streamer<Id, Evt> for InMemory<Id, Evt>
125where
126 Id: Clone + Eq + Hash + Send + Sync,
127 Evt: message::Message + Clone + Send + Sync,
128{
129 type Error = Infallible;
130
131 fn stream(&self, id: &Id, select: event::VersionSelect) -> event::Stream<Id, Evt, Self::Error> {
132 let backend = self
133 .backend
134 .read()
135 .expect("acquire read lock on event store backend");
136
137 let events = backend
138 .event_streams
139 .get(id)
140 .cloned()
141 .unwrap_or_default() .into_iter()
143 .filter(move |evt| match select {
144 event::VersionSelect::All => true,
145 event::VersionSelect::From(v) => evt.version >= v,
146 });
147
148 iter(events).map(Ok).boxed()
149 }
150}
151
152#[async_trait]
153impl<Id, Evt> Appender<Id, Evt> for InMemory<Id, Evt>
154where
155 Id: Clone + Eq + Hash + Send + Sync,
156 Evt: message::Message + Clone + Send + Sync,
157{
158 async fn append(
159 &self,
160 id: Id,
161 version_check: version::Check,
162 events: Vec<event::Envelope<Evt>>,
163 ) -> Result<version::Version, AppendError> {
164 let mut backend = self
165 .backend
166 .write()
167 .expect("acquire write lock on event store backend");
168
169 let last_event_stream_version = backend
170 .event_streams
171 .get(&id)
172 .and_then(|events| events.last())
173 .map(|event| event.version)
174 .unwrap_or_default();
175
176 if let version::Check::MustBe(expected) = version_check {
177 if last_event_stream_version != expected {
178 return Err(AppendError::Conflict(version::ConflictError {
179 expected,
180 actual: last_event_stream_version,
181 }));
182 }
183 }
184
185 let mut persisted_events: Vec<event::Persisted<Id, Evt>> = events
186 .into_iter()
187 .enumerate()
188 .map(|(i, event)| event::Persisted {
189 stream_id: id.clone(),
190 version: last_event_stream_version + (i as u64) + 1,
191 event,
192 })
193 .collect();
194
195 let new_last_event_stream_version = persisted_events
196 .last()
197 .map(|evt| evt.version)
198 .unwrap_or_default();
199
200 backend
201 .event_streams
202 .entry(id)
203 .and_modify(|events| events.append(&mut persisted_events))
204 .or_insert_with(|| persisted_events);
205
206 Ok(new_last_event_stream_version)
207 }
208}
209
210#[derive(Debug, Clone)]
216pub struct Tracking<T, StreamId, Event>
217where
218 T: Store<StreamId, Event> + Send + Sync,
219 StreamId: Send + Sync,
220 Event: message::Message + Send + Sync,
221{
222 store: T,
223
224 #[allow(clippy::type_complexity)] events: Arc<RwLock<Vec<event::Persisted<StreamId, Event>>>>,
226}
227
228impl<T, StreamId, Event> Tracking<T, StreamId, Event>
229where
230 T: Store<StreamId, Event> + Send + Sync,
231 StreamId: Clone + Send + Sync,
232 Event: message::Message + Clone + Send + Sync,
233{
234 pub fn recorded_events(&self) -> Vec<event::Persisted<StreamId, Event>> {
241 self.events
242 .read()
243 .expect("acquire lock on recorded events list")
244 .clone()
245 }
246
247 pub fn reset_recorded_events(&self) {
254 self.events
255 .write()
256 .expect("acquire lock on recorded events list")
257 .clear();
258 }
259}
260
261impl<T, StreamId, Event> Streamer<StreamId, Event> for Tracking<T, StreamId, Event>
262where
263 T: Store<StreamId, Event> + Send + Sync,
264 StreamId: Clone + Send + Sync,
265 Event: message::Message + Clone + Send + Sync,
266{
267 type Error = <T as Streamer<StreamId, Event>>::Error;
268
269 fn stream(
270 &self,
271 id: &StreamId,
272 select: event::VersionSelect,
273 ) -> event::Stream<StreamId, Event, Self::Error> {
274 self.store.stream(id, select)
275 }
276}
277
278#[async_trait]
279impl<T, StreamId, Event> Appender<StreamId, Event> for Tracking<T, StreamId, Event>
280where
281 T: Store<StreamId, Event> + Send + Sync,
282 StreamId: Clone + Send + Sync,
283 Event: message::Message + Clone + Send + Sync,
284{
285 async fn append(
286 &self,
287 id: StreamId,
288 version_check: version::Check,
289 events: Vec<event::Envelope<Event>>,
290 ) -> Result<version::Version, AppendError> {
291 let new_version = self
292 .store
293 .append(id.clone(), version_check, events.clone())
294 .await?;
295
296 let events_size = events.len();
297 let previous_version = new_version - (events_size as version::Version);
298
299 let mut persisted_events = events
300 .into_iter()
301 .enumerate()
302 .map(|(i, event)| event::Persisted {
303 stream_id: id.clone(),
304 version: previous_version + (i as version::Version) + 1,
305 event,
306 })
307 .collect();
308
309 self.events
310 .write()
311 .expect("acquire lock on recorded events list")
312 .append(&mut persisted_events);
313
314 Ok(new_version)
315 }
316}
317
318pub trait EventStoreExt<StreamId, Event>: Store<StreamId, Event> + Send + Sync + Sized
321where
322 StreamId: Clone + Send + Sync,
323 Event: message::Message + Clone + Send + Sync,
324{
325 fn with_recorded_events_tracking(self) -> Tracking<Self, StreamId, Event> {
328 Tracking {
329 store: self,
330 events: Arc::default(),
331 }
332 }
333}
334
335impl<T, StreamId, Event> EventStoreExt<StreamId, Event> for T
336where
337 T: Store<StreamId, Event> + Send + Sync,
338 StreamId: Clone + Send + Sync,
339 Event: message::Message + Clone + Send + Sync,
340{
341}
342
#[allow(clippy::semicolon_if_nothing_returned)]
#[cfg(test)]
mod test {
    use std::sync::LazyLock;

    use futures::TryStreamExt;

    use super::*;
    use crate::event;
    use crate::event::store::{Appender, Streamer};
    use crate::message::tests::StringMessage;
    use crate::version::Version;

    const STREAM_ID: &str = "stream:test";

    /// Three distinct events shared by every test in this module.
    static EVENTS: LazyLock<Vec<event::Envelope<StringMessage>>> = LazyLock::new(|| {
        ["event-1", "event-2", "event-3"]
            .into_iter()
            .map(|payload| event::Envelope::from(StringMessage(payload)))
            .collect()
    });

    #[tokio::test]
    async fn it_works() {
        let event_store = InMemory::<&'static str, StringMessage>::default();

        let new_event_stream_version = event_store
            .append(STREAM_ID, version::Check::MustBe(0), EVENTS.clone())
            .await
            .expect("append should not fail");

        assert_eq!(EVENTS.len() as Version, new_event_stream_version);

        // Versions are assigned consecutively starting at 1.
        let expected_events: Vec<_> = EVENTS
            .iter()
            .cloned()
            .zip(1..)
            .map(|(event, version)| event::Persisted {
                stream_id: STREAM_ID,
                version,
                event,
            })
            .collect();

        let event_stream: Vec<_> = event_store
            .stream(&STREAM_ID, event::VersionSelect::All)
            .try_collect()
            .await
            .expect("opening an event stream should not fail");

        assert_eq!(expected_events, event_stream);
    }

    #[tokio::test]
    async fn tracking_store_works() {
        let tracking_event_store =
            InMemory::<&'static str, StringMessage>::default().with_recorded_events_tracking();

        tracking_event_store
            .append(STREAM_ID, version::Check::MustBe(0), EVENTS.clone())
            .await
            .expect("append should not fail");

        let recorded = tracking_event_store.recorded_events();

        let event_stream: Vec<_> = tracking_event_store
            .stream(&STREAM_ID, event::VersionSelect::All)
            .try_collect()
            .await
            .expect("opening an event stream should not fail");

        assert_eq!(event_stream, recorded);
    }

    #[tokio::test]
    async fn version_conflict_checks_work_as_expected() {
        let event_store = InMemory::<&'static str, StringMessage>::default();

        let append_error = event_store
            .append(STREAM_ID, version::Check::MustBe(3), EVENTS.clone())
            .await
            .expect_err("the event stream version should be zero");

        match &append_error {
            AppendError::Conflict(err) => assert_eq!(
                &version::ConflictError {
                    expected: 3,
                    actual: 0,
                },
                err
            ),
            AppendError::Internal(_) => {
                panic!("expected conflict error, received: {append_error}")
            }
        }
    }
}