mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-24 03:26:11 -05:00 
			
		
		
		
	Compare commits
	
		
			120 Commits
		
	
	
		
			dependabot
			...
			5a18f3a529
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 5a18f3a529 | ||
|   | acaa83ad30 | ||
|   | bf7e7cd3b9 | ||
|   | a78a2349bb | ||
|   | a569b0574d | ||
|   | 4b07179b01 | ||
|   | 8bb65af214 | ||
|   | 4318f7dac3 | ||
|   | 9837407879 | ||
|   | d21d0eaf08 | ||
|   | f0eb9d981c | ||
|   | 66f5f3cbee | ||
|   | e00dc63021 | ||
|   | 3825023337 | ||
|   | b9e34bd793 | ||
|   | fcbc438ffd | ||
|   | 4076a35559 | ||
|   | 3bb03062b1 | ||
|   | af1928f734 | ||
|   | 7cc089599c | ||
|   | 4c719948d9 | ||
|   | 867c7d9e62 | ||
|   | 6eb0b21a44 | ||
|   | 95ed997717 | ||
|   | 7bd9b385aa | ||
|   | 541108688a | ||
|   | 74c9fedd4c | ||
|   | 6b99c21710 | ||
|   | 64ff422fef | ||
|   | 540539643c | ||
|   | b52412d776 | ||
|   | da2ac19193 | ||
|   | 3583470856 | ||
|   | 5bfbe856a6 | ||
|   | 20bae4bd41 | ||
|   | b94912a392 | ||
|   | 50e6a4bd61 | ||
|   | 87e5d82c46 | ||
|   | 476844f32a | ||
|   | 01285c96d4 | ||
|   | 3e6ba34c5e | ||
|   | d9cbd3652a | ||
|   | 90bd878cf2 | ||
|   | 62e04ab2fe | ||
|   | dbdc67da7a | ||
|   | 11a4e0d5ba | ||
|   | c4b431f5a6 | ||
|   | d31f4669a2 | ||
|   | 483f1e9438 | ||
|   | d7a358d39d | ||
|   | b94a60d607 | ||
|   | e6d8cd6547 | ||
|   | e2fc7f596d | ||
|   | 20e7f01cec | ||
|   | 96daa5eb18 | ||
|   | 84e17535fc | ||
|   | 77db0c399c | ||
|   | e51c7a27bb | ||
|   | a3455c8373 | ||
|   | cce9dfd5b8 | ||
|   | 3a9257f10a | ||
|   | 3b921da6c3 | ||
|   | ad8519482c | ||
|   | fe205b31c2 | ||
|   | 13ab148c7e | ||
|   | 559caf72c2 | ||
|   | 2481a66544 | ||
|   | f6a3882199 | ||
|   | 8d48d398eb | ||
|   | b3b9a8fb5b | ||
|   | 4cdc629e3d | ||
|   | 5195a97e4c | ||
|   | 96fa522394 | ||
|   | dd1da9f072 | ||
|   | d99f2d6160 | ||
|   | ebd46f08e5 | ||
|   | 6f0c6f39b1 | ||
|   | 0690fd36c5 | ||
|   | 0052f21cea | ||
|   | c809a65571 | ||
|   | bb3336f7bc | ||
|   | a9ed46de11 | ||
|   | 1ccaf66869 | ||
|   | e864a51497 | ||
|   | 4a28be233e | ||
|   | 9183bfc0a4 | ||
|   | 5f26139a5f | ||
|   | ccfc7d98b1 | ||
|   | d1bd2af49c | ||
|   | e2eec6dc71 | ||
|   | 42e3684211 | ||
|   | df8f07555f | ||
|   | 3660336bcf | ||
|   | aeceaf60a2 | ||
|   | 959ebdbb85 | ||
|   | eb1c49090b | ||
|   | 9f8b8a9f20 | ||
|   | f5fc04cfe2 | ||
|   | 3186550fd7 | ||
|   | 74aaf18630 | ||
|   | e6a147079d | ||
|   | 105b823fd9 | ||
|   | be20c48588 | ||
|   | 377dcc39f5 | ||
|   | 767118fa8a | ||
|   | 339612f4ec | ||
|   | e7592c6269 | ||
|   | ffc0b936f3 | ||
|   | 1a6540e8ed | ||
|   | abbf9060d0 | ||
|   | 11a3dfe890 | ||
|   | faa5d3e5b9 | ||
|   | 8d1a8c2c42 | ||
|   | 01dc3cc17c | ||
|   | cfbd5af820 | ||
|   | e8090fd030 | ||
|   | 05896d5b70 | ||
|   | 65b8a74166 | ||
|   | 56b1c7adeb | ||
|   | 55cb9cedc7 | 
| @@ -11,6 +11,7 @@ for command in decrypt_documents \ | ||||
| 	mail_fetcher \ | ||||
| 	document_create_classifier \ | ||||
| 	document_index \ | ||||
| 	document_llmindex \ | ||||
| 	document_renamer \ | ||||
| 	document_retagger \ | ||||
| 	document_thumbnails \ | ||||
|   | ||||
							
								
								
									
										14
									
								
								docker/rootfs/usr/local/bin/document_llmindex
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										14
									
								
								docker/rootfs/usr/local/bin/document_llmindex
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| #!/command/with-contenv /usr/bin/bash | ||||
| # shellcheck shell=bash | ||||
|  | ||||
| # Wrapper around `python3 manage.py document_llmindex` that forwards all | ||||
| # arguments and makes sure the command runs as the `paperless` user. | ||||
|  | ||||
| set -e | ||||
|  | ||||
| cd "${PAPERLESS_SRC_DIR}" | ||||
|  | ||||
| if [[ $(id -u) == 0 ]]; then | ||||
| 	# Invoked as root: run the command via s6-setuidgid as `paperless`. | ||||
| 	s6-setuidgid paperless python3 manage.py document_llmindex "$@" | ||||
| elif [[ $(id -un) == "paperless" ]]; then | ||||
| 	python3 manage.py document_llmindex "$@" | ||||
| else | ||||
| 	# Nothing was executed: report on stderr and exit non-zero so callers | ||||
| 	# and automation can detect the failure instead of seeing success. | ||||
| 	echo "Unknown user." >&2 | ||||
| 	exit 1 | ||||
| fi | ||||
| @@ -1805,3 +1805,67 @@ password. All of these options come from their similarly-named [Django settings] | ||||
| #### [`PAPERLESS_EMAIL_USE_SSL=<bool>`](#PAPERLESS_EMAIL_USE_SSL) {#PAPERLESS_EMAIL_USE_SSL} | ||||
|  | ||||
| : Defaults to false. | ||||
|  | ||||
| ## AI {#ai} | ||||
|  | ||||
| #### [`PAPERLESS_AI_ENABLED=<bool>`](#PAPERLESS_AI_ENABLED) {#PAPERLESS_AI_ENABLED} | ||||
|  | ||||
| : Enables the AI features in Paperless-ngx, including the AI-based suggestions. | ||||
| This setting must be set to true in order to use any of the AI features. | ||||
|  | ||||
|     Defaults to false. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_EMBEDDING_BACKEND=<str>`](#PAPERLESS_AI_LLM_EMBEDDING_BACKEND) {#PAPERLESS_AI_LLM_EMBEDDING_BACKEND} | ||||
|  | ||||
| : The embedding backend to use for RAG. This can be either "openai" or "huggingface". | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_EMBEDDING_MODEL=<str>`](#PAPERLESS_AI_LLM_EMBEDDING_MODEL) {#PAPERLESS_AI_LLM_EMBEDDING_MODEL} | ||||
|  | ||||
| : The model to use for the embedding backend for RAG. This can be set to any of the embedding models supported by the current embedding backend. If not supplied, defaults to "text-embedding-3-small" for OpenAI and "sentence-transformers/all-MiniLM-L6-v2" for Huggingface. | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
| #### [`PAPERLESS_AI_BACKEND=<str>`](#PAPERLESS_AI_BACKEND) {#PAPERLESS_AI_BACKEND} | ||||
|  | ||||
| : The AI backend to use. This can be either "openai" or "ollama". If set to "ollama", the AI | ||||
| features will be run locally on your machine. If set to "openai", the AI features will be run | ||||
| using the OpenAI API. This setting must be set in order to use the AI features. | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
|     !!! note | ||||
|  | ||||
|         The OpenAI API is a paid service. You will need to set up an OpenAI account and | ||||
|         will be charged for usage incurred by Paperless-ngx features and your document data | ||||
|         will (of course) be sent to the OpenAI API. Paperless-ngx does not endorse the use of the | ||||
|         OpenAI API in any way. | ||||
|  | ||||
|         Refer to the OpenAI terms of service, and use at your own risk. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_MODEL=<str>`](#PAPERLESS_AI_LLM_MODEL) {#PAPERLESS_AI_LLM_MODEL} | ||||
|  | ||||
| : The model to use for the AI backend, e.g. "gpt-3.5-turbo", "gpt-4" or any of the models supported by the | ||||
| current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "llama3" for Ollama. | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_API_KEY=<str>`](#PAPERLESS_AI_LLM_API_KEY) {#PAPERLESS_AI_LLM_API_KEY} | ||||
|  | ||||
| : The API key to use for the AI backend. This is required for the OpenAI backend only. | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_ENDPOINT=<str>`](#PAPERLESS_AI_LLM_ENDPOINT) {#PAPERLESS_AI_LLM_ENDPOINT} | ||||
|  | ||||
| : The endpoint / URL to use for the AI backend. This is required for the Ollama backend only. | ||||
|  | ||||
|     Defaults to None. | ||||
|  | ||||
| #### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON} | ||||
|  | ||||
| : Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if | ||||
| AI is enabled and the LLM embedding backend is set. | ||||
|  | ||||
|     Defaults to `10 2 * * *`, once per day. | ||||
|   | ||||
| @@ -25,11 +25,12 @@ physical documents into a searchable online archive so you can keep, well, _less | ||||
| ## Features | ||||
|  | ||||
| -   **Organize and index** your scanned documents with tags, correspondents, types, and more. | ||||
| -   _Your_ data is stored locally on _your_ server and is never transmitted or shared in any way. | ||||
| -   _Your_ data is stored locally on _your_ server and is never transmitted or shared in any way, unless you explicitly choose to do so. | ||||
| -   Performs **OCR** on your documents, adding searchable and selectable text, even to documents scanned with only images. | ||||
| -   Utilizes the open-source Tesseract engine to recognize more than 100 languages. | ||||
| -   Documents are saved as PDF/A format which is designed for long term storage, alongside the unaltered originals. | ||||
| -   Uses machine-learning to automatically add tags, correspondents and document types to your documents. | ||||
| -   **New**: Paperless-ngx can now leverage AI (Large Language Models or LLMs) for document suggestions. This is an optional feature that can be enabled (and is disabled by default). | ||||
| -   Supports PDF documents, images, plain text files, Office documents (Word, Excel, PowerPoint, and LibreOffice equivalents)[^1] and more. | ||||
| -   Paperless stores your documents plain on disk. Filenames and folders are managed by paperless and their format can be configured freely with different configurations assigned to different documents. | ||||
| -   **Beautiful, modern web application** that features: | ||||
|   | ||||
| @@ -274,6 +274,28 @@ Once setup, navigating to the email settings page in Paperless-ngx will allow yo | ||||
| You can also submit a document using the REST API, see [POSTing documents](api.md#file-uploads) | ||||
| for details. | ||||
|  | ||||
| ## Document Suggestions | ||||
|  | ||||
| Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a (non-LLM) machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user. | ||||
|  | ||||
| ## AI Features | ||||
|  | ||||
| Paperless-ngx includes several features that use AI to enhance the document management experience. These features are optional and can be enabled or disabled in the settings. If you are using the AI features, you may want to also enable the "LLM index" feature, which supports Retrieval-Augmented Generation (RAG) designed to improve the quality of AI responses. The LLM index feature is not enabled by default and requires additional configuration. | ||||
|  | ||||
| !!! warning | ||||
|  | ||||
|     Remember that Paperless-ngx will send document content to the AI provider you have configured, so consider the privacy implications of using these features, especially if using a remote model (e.g. OpenAI) instead of a local model (e.g. Ollama). | ||||
|  | ||||
| The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering. This uses the FAISS vector store. | ||||
|  | ||||
| ### AI-Enhanced Suggestions | ||||
|  | ||||
| If enabled, Paperless-ngx can use an LLM to suggest document titles, dates, tags, correspondents and document types for documents. This feature will always be "opt-in" and does not disable the existing classifier-based suggestion system. Currently, both remote (via the OpenAI API) and local (via Ollama) models are supported; see [configuration](configuration.md#ai) for details. | ||||
|  | ||||
| ### Document Chat | ||||
|  | ||||
| Paperless-ngx can use an LLM to answer questions about a document or across multiple documents. Again, this feature works best when RAG is enabled. The chat feature is available in the upper app toolbar and will switch between chatting across multiple documents and chatting about a single document based on the current view. | ||||
|  | ||||
| ## Sharing documents from Paperless-ngx | ||||
|  | ||||
| Paperless-ngx supports sharing documents with other users by assigning them [permissions](#object-permissions) | ||||
|   | ||||
| @@ -42,6 +42,7 @@ dependencies = [ | ||||
|   "drf-spectacular~=0.28", | ||||
|   "drf-spectacular-sidecar~=2025.9.1", | ||||
|   "drf-writable-nested~=0.7.1", | ||||
|   "faiss-cpu>=1.10", | ||||
|   "filelock~=3.19.1", | ||||
|   "flower~=2.0.1", | ||||
|   "gotenberg-client~=0.11.0", | ||||
| @@ -50,8 +51,15 @@ dependencies = [ | ||||
|   "inotifyrecursive~=0.3", | ||||
|   "jinja2~=3.1.5", | ||||
|   "langdetect~=1.0.9", | ||||
|   "llama-index-core>=0.12.33.post1", | ||||
|   "llama-index-embeddings-huggingface>=0.5.3", | ||||
|   "llama-index-embeddings-openai>=0.3.1", | ||||
|   "llama-index-llms-ollama>=0.5.4", | ||||
|   "llama-index-llms-openai>=0.3.38", | ||||
|   "llama-index-vector-stores-faiss>=0.3", | ||||
|   "nltk~=3.9.1", | ||||
|   "ocrmypdf~=16.11.0", | ||||
|   "openai>=1.76", | ||||
|   "pathvalidate~=3.3.1", | ||||
|   "pdf2image~=1.17.0", | ||||
|   "psycopg-pool", | ||||
| @@ -64,6 +72,7 @@ dependencies = [ | ||||
|   "rapidfuzz~=3.14.0", | ||||
|   "redis[hiredis]~=5.2.1", | ||||
|   "scikit-learn~=1.7.0", | ||||
|   "sentence-transformers>=4.1", | ||||
|   "setproctitle~=1.3.4", | ||||
|   "tika-client~=0.10.0", | ||||
|   "tqdm~=4.67.1", | ||||
| @@ -233,6 +242,7 @@ testpaths = [ | ||||
|   "src/paperless_tesseract/tests/", | ||||
|   "src/paperless_tika/tests", | ||||
|   "src/paperless_text/tests/", | ||||
|   "src/paperless_ai/tests", | ||||
| ] | ||||
| addopts = [ | ||||
|   "--pythonwarnings=all", | ||||
|   | ||||
| @@ -35,8 +35,12 @@ | ||||
|                                                     @case (ConfigOptionType.String) { <pngx-input-text [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-text> } | ||||
|                                                     @case (ConfigOptionType.JSON) { <pngx-input-text [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-text> } | ||||
|                                                     @case (ConfigOptionType.File) { <pngx-input-file [formControlName]="option.key" (upload)="uploadFile($event, option.key)" [error]="errors[option.key]"></pngx-input-file> } | ||||
|                                                     @case (ConfigOptionType.Password) { <pngx-input-password [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-password> } | ||||
|                                                 } | ||||
|                                             </div> | ||||
|                                             @if (option.note) { | ||||
|                                                 <div class="form-text fst-italic">{{option.note}}</div> | ||||
|                                             } | ||||
|                                         </div> | ||||
|                                     </div> | ||||
|                                 </div> | ||||
|   | ||||
| @@ -29,6 +29,7 @@ import { SettingsService } from 'src/app/services/settings.service' | ||||
| import { ToastService } from 'src/app/services/toast.service' | ||||
| import { FileComponent } from '../../common/input/file/file.component' | ||||
| import { NumberComponent } from '../../common/input/number/number.component' | ||||
| import { PasswordComponent } from '../../common/input/password/password.component' | ||||
| import { SelectComponent } from '../../common/input/select/select.component' | ||||
| import { SwitchComponent } from '../../common/input/switch/switch.component' | ||||
| import { TextComponent } from '../../common/input/text/text.component' | ||||
| @@ -46,6 +47,7 @@ import { LoadingComponentWithPermissions } from '../../loading-component/loading | ||||
|     TextComponent, | ||||
|     NumberComponent, | ||||
|     FileComponent, | ||||
|     PasswordComponent, | ||||
|     AsyncPipe, | ||||
|     NgbNavModule, | ||||
|     FormsModule, | ||||
|   | ||||
| @@ -92,6 +92,9 @@ const status: SystemStatus = { | ||||
|     sanity_check_status: SystemStatusItemStatus.ERROR, | ||||
|     sanity_check_last_run: new Date().toISOString(), | ||||
|     sanity_check_error: 'Error running sanity check.', | ||||
|     llmindex_status: SystemStatusItemStatus.DISABLED, | ||||
|     llmindex_last_modified: new Date().toISOString(), | ||||
|     llmindex_error: null, | ||||
|   }, | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -30,6 +30,9 @@ | ||||
|     </div> | ||||
|   </div> | ||||
|   <ul ngbNav class="order-sm-3"> | ||||
|     @if (aiEnabled) { | ||||
|       <pngx-chat></pngx-chat> | ||||
|     } | ||||
|     <pngx-toasts-dropdown></pngx-toasts-dropdown> | ||||
|     <li ngbDropdown class="nav-item dropdown"> | ||||
|       <button class="btn ps-1 border-0" id="userDropdown" ngbDropdownToggle> | ||||
|   | ||||
| @@ -44,6 +44,7 @@ import { SettingsService } from 'src/app/services/settings.service' | ||||
| import { TasksService } from 'src/app/services/tasks.service' | ||||
| import { ToastService } from 'src/app/services/toast.service' | ||||
| import { environment } from 'src/environments/environment' | ||||
| import { ChatComponent } from '../chat/chat/chat.component' | ||||
| import { ProfileEditDialogComponent } from '../common/profile-edit-dialog/profile-edit-dialog.component' | ||||
| import { DocumentDetailComponent } from '../document-detail/document-detail.component' | ||||
| import { ComponentWithPermissions } from '../with-permissions/with-permissions.component' | ||||
| @@ -59,6 +60,7 @@ import { ToastsDropdownComponent } from './toasts-dropdown/toasts-dropdown.compo | ||||
|     DocumentTitlePipe, | ||||
|     IfPermissionsDirective, | ||||
|     ToastsDropdownComponent, | ||||
|     ChatComponent, | ||||
|     RouterModule, | ||||
|     NgClass, | ||||
|     NgbDropdownModule, | ||||
| @@ -171,6 +173,10 @@ export class AppFrameComponent | ||||
|       }) | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Value stored under `SETTINGS_KEYS.AI_ENABLED`, as returned by the settings service. | ||||
|    * NOTE(review): presumably read by the component template to decide whether | ||||
|    * the chat dropdown is rendered; confirm against the template. | ||||
|    */ | ||||
|   get aiEnabled(): boolean { | ||||
|     return this.settingsService.get(SETTINGS_KEYS.AI_ENABLED) | ||||
|   } | ||||
|  | ||||
|   /** Marks the menu as collapsed by setting `isMenuCollapsed` to true. */ | ||||
|   closeMenu() { | ||||
|     this.isMenuCollapsed = true | ||||
|   } | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
|  | ||||
| <li ngbDropdown class="nav-item" (openChange)="onOpenChange($event)"> | ||||
| <li ngbDropdown class="nav-item mx-1" (openChange)="onOpenChange($event)"> | ||||
|   @if (toasts.length) { | ||||
|     <span class="badge rounded-pill z-3 pe-none bg-secondary me-2 position-absolute top-0 left-0">{{ toasts.length }}</span> | ||||
|   } | ||||
|   | ||||
							
								
								
									
										35
									
								
								src-ui/src/app/components/chat/chat/chat.component.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								src-ui/src/app/components/chat/chat/chat.component.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
|  | ||||
| <!-- Chat dropdown nav item: the toggle button opens a panel with the message list and an input form. --> | ||||
| <li ngbDropdown class="nav-item me-n2" (openChange)="onOpenChange($event)"> | ||||
|   <button class="btn border-0" id="chatDropdown" ngbDropdownToggle> | ||||
|     <i-bs width="1.3em" height="1.3em" name="chatSquareDots"></i-bs> | ||||
|   </button> | ||||
|   <div ngbDropdownMenu class="dropdown-menu-end shadow p-3" aria-labelledby="chatDropdown"> | ||||
|     <div class="chat-container bg-light p-2"> | ||||
|       <div class="chat-messages font-monospace small"> | ||||
|         <!-- One row per message; messages with role 'user' are right-aligned and get a dark background. --> | ||||
|         @for (message of messages; track message) { | ||||
|           <div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'"> | ||||
|             <span class="p-2 m-2" [class.bg-dark]="message.role === 'user'"> | ||||
|               {{ message.content }} | ||||
|               <!-- Cursor shown only while the message is flagged as streaming. --> | ||||
|               @if (message.isStreaming) { <span class="blinking-cursor">|</span> } | ||||
|             </span> | ||||
|           </div> | ||||
|         } | ||||
|         <!-- Empty element at the end of the list, exposed to the component as `scrollAnchor`. --> | ||||
|         <div #scrollAnchor></div> | ||||
|       </div> | ||||
|  | ||||
|       <!-- Input and Send button are both disabled while `loading` is true. --> | ||||
|       <form class="chat-input"> | ||||
|         <div class="input-group"> | ||||
|           <input | ||||
|             #chatInput | ||||
|             class="form-control form-control-sm" name="chatInput" type="text" | ||||
|             [placeholder]="placeholder" | ||||
|             [disabled]="loading" | ||||
|             [(ngModel)]="input" | ||||
|             (keydown)="searchInputKeyDown($event)" | ||||
|             /> | ||||
|           <button class="btn btn-sm btn-secondary" type="button" (click)="sendMessage()" [disabled]="loading">Send</button> | ||||
|         </div> | ||||
|       </form> | ||||
|     </div> | ||||
|   </div> | ||||
| </li> | ||||
							
								
								
									
										37
									
								
								src-ui/src/app/components/chat/chat/chat.component.scss
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								src-ui/src/app/components/chat/chat/chat.component.scss
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| // Styles for the chat dropdown component. | ||||
|  | ||||
| // Dropdown panel width follows the --pngx-toast-max-width custom property. | ||||
| .dropdown-menu { | ||||
|   width: var(--pngx-toast-max-width); | ||||
| } | ||||
|  | ||||
| // Message list scrolls vertically once it exceeds 350px. | ||||
| .chat-messages { | ||||
|   max-height: 350px; | ||||
|   overflow-y: auto; | ||||
| } | ||||
|  | ||||
| // Hide the toggle's ::after pseudo-element (presumably the Bootstrap caret). | ||||
| .dropdown-toggle::after { | ||||
|   display: none; | ||||
| } | ||||
|  | ||||
| .dropdown-item { | ||||
|   white-space: initial; | ||||
| } | ||||
|  | ||||
| // On viewports up to 400px wide, offset the end-aligned menu by -3rem from the right. | ||||
| @media screen and (max-width: 400px) { | ||||
|   :host ::ng-deep .dropdown-menu-end { | ||||
|     right: -3rem; | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Cursor element toggled by the `blink` animation below (1s cycle, stepped). | ||||
| .blinking-cursor { | ||||
|   font-weight: bold; | ||||
|   font-size: 1.2em; | ||||
|   animation: blink 1s step-end infinite; | ||||
| } | ||||
|  | ||||
| // Invisible at the start/end of the cycle, visible at the midpoint. | ||||
| @keyframes blink { | ||||
|   from, to { | ||||
|     opacity: 0; | ||||
|   } | ||||
|   50% { | ||||
|     opacity: 1; | ||||
|   } | ||||
| } | ||||
							
								
								
									
										132
									
								
								src-ui/src/app/components/chat/chat/chat.component.spec.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								src-ui/src/app/components/chat/chat/chat.component.spec.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
| import { provideHttpClient, withInterceptorsFromDi } from '@angular/common/http' | ||||
| import { provideHttpClientTesting } from '@angular/common/http/testing' | ||||
| import { ElementRef } from '@angular/core' | ||||
| import { ComponentFixture, TestBed } from '@angular/core/testing' | ||||
| import { NavigationEnd, Router } from '@angular/router' | ||||
| import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' | ||||
| import { Subject } from 'rxjs' | ||||
| import { ChatService } from 'src/app/services/chat.service' | ||||
| import { ChatComponent } from './chat.component' | ||||
|  | ||||
| // Unit tests for ChatComponent. The router event stream and the chat stream are | ||||
| // replaced by Subjects so navigation and streamed chunks can be emitted by hand; | ||||
| // fake timers drive the component's setTimeout-based typewriter, scroll and focus. | ||||
| describe('ChatComponent', () => { | ||||
|   let component: ChatComponent | ||||
|   let fixture: ComponentFixture<ChatComponent> | ||||
|   let chatService: ChatService | ||||
|   let router: Router | ||||
|   let routerEvents$: Subject<NavigationEnd> | ||||
|   let mockStream$: Subject<string> | ||||
|  | ||||
|   beforeEach(async () => { | ||||
|     // compileComponents() is asynchronous; wait for it before creating the component. | ||||
|     await TestBed.configureTestingModule({ | ||||
|       imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent], | ||||
|       providers: [ | ||||
|         provideHttpClient(withInterceptorsFromDi()), | ||||
|         provideHttpClientTesting(), | ||||
|       ], | ||||
|     }).compileComponents() | ||||
|  | ||||
|     fixture = TestBed.createComponent(ChatComponent) | ||||
|     router = TestBed.inject(Router) | ||||
|     routerEvents$ = new Subject<NavigationEnd>() | ||||
|     jest | ||||
|       .spyOn(router, 'events', 'get') | ||||
|       .mockReturnValue(routerEvents$.asObservable()) | ||||
|     chatService = TestBed.inject(ChatService) | ||||
|     mockStream$ = new Subject<string>() | ||||
|     jest | ||||
|       .spyOn(chatService, 'streamChat') | ||||
|       .mockReturnValue(mockStream$.asObservable()) | ||||
|     component = fixture.componentInstance | ||||
|  | ||||
|     jest.useFakeTimers() | ||||
|  | ||||
|     fixture.detectChanges() | ||||
|  | ||||
|     // Stub scrollIntoView on the anchor element so scroll calls are harmless in the test DOM. | ||||
|     component.scrollAnchor.nativeElement.scrollIntoView = jest.fn() | ||||
|   }) | ||||
|  | ||||
|   afterEach(() => { | ||||
|     // Undo jest.useFakeTimers() so the next test's async setup runs on real timers. | ||||
|     jest.useRealTimers() | ||||
|   }) | ||||
|  | ||||
|   it('should update documentId on initialization', () => { | ||||
|     jest.spyOn(router, 'url', 'get').mockReturnValue('/documents/123') | ||||
|     component.ngOnInit() | ||||
|     expect(component.documentId).toBe(123) | ||||
|   }) | ||||
|  | ||||
|   it('should update documentId on navigation', () => { | ||||
|     component.ngOnInit() | ||||
|     routerEvents$.next(new NavigationEnd(1, '/documents/456', '/documents/456')) | ||||
|     expect(component.documentId).toBe(456) | ||||
|   }) | ||||
|  | ||||
|   it('should return correct placeholder based on documentId', () => { | ||||
|     component.documentId = 123 | ||||
|     expect(component.placeholder).toBe('Ask a question about this document...') | ||||
|     component.documentId = undefined | ||||
|     expect(component.placeholder).toBe('Ask a question about a document...') | ||||
|   }) | ||||
|  | ||||
|   it('should send a message and handle streaming response', () => { | ||||
|     component.input = 'Hello' | ||||
|     component.sendMessage() | ||||
|  | ||||
|     expect(component.messages.length).toBe(2) | ||||
|     expect(component.messages[0].content).toBe('Hello') | ||||
|     expect(component.loading).toBe(true) | ||||
|  | ||||
|     // The first character is written synchronously; the rest waits on timers. | ||||
|     mockStream$.next('Hi') | ||||
|     expect(component.messages[1].content).toBe('H') | ||||
|     mockStream$.next('Hi there') | ||||
|     // advance time to process the typewriter effect | ||||
|     jest.advanceTimersByTime(1000) | ||||
|     expect(component.messages[1].content).toBe('Hi there') | ||||
|  | ||||
|     mockStream$.complete() | ||||
|     expect(component.loading).toBe(false) | ||||
|     expect(component.messages[1].isStreaming).toBe(false) | ||||
|   }) | ||||
|  | ||||
|   it('should handle errors during streaming', () => { | ||||
|     component.input = 'Hello' | ||||
|     component.sendMessage() | ||||
|  | ||||
|     mockStream$.error('Error') | ||||
|     expect(component.messages[1].content).toContain( | ||||
|       '⚠️ Error receiving response.' | ||||
|     ) | ||||
|     expect(component.loading).toBe(false) | ||||
|   }) | ||||
|  | ||||
|   it('should enqueue typewriter chunks correctly', () => { | ||||
|     const message = { content: '', role: 'assistant', isStreaming: true } | ||||
|     component.enqueueTypewriter(null, message as any) // coverage for null | ||||
|     component.enqueueTypewriter('Hello', message as any) | ||||
|     // 5 characters queued, 1 already consumed by the immediate first tick. | ||||
|     expect(component['typewriterBuffer'].length).toBe(4) | ||||
|   }) | ||||
|  | ||||
|   it('should scroll to bottom after sending a message', () => { | ||||
|     const scrollSpy = jest.spyOn( | ||||
|       ChatComponent.prototype as any, | ||||
|       'scrollToBottom' | ||||
|     ) | ||||
|     component.input = 'Test' | ||||
|     component.sendMessage() | ||||
|     expect(scrollSpy).toHaveBeenCalled() | ||||
|   }) | ||||
|  | ||||
|   it('should focus chat input when dropdown is opened', () => { | ||||
|     const focus = jest.fn() | ||||
|     component.chatInput = { | ||||
|       nativeElement: { focus: focus }, | ||||
|     } as unknown as ElementRef<HTMLInputElement> | ||||
|  | ||||
|     component.onOpenChange(true) | ||||
|     jest.advanceTimersByTime(15) | ||||
|     expect(focus).toHaveBeenCalled() | ||||
|   }) | ||||
|  | ||||
|   it('should send message on Enter key press', () => { | ||||
|     jest.spyOn(component, 'sendMessage') | ||||
|     const event = new KeyboardEvent('keydown', { key: 'Enter' }) | ||||
|     component.searchInputKeyDown(event) | ||||
|     expect(component.sendMessage).toHaveBeenCalled() | ||||
|   }) | ||||
| }) | ||||
							
								
								
									
										140
									
								
								src-ui/src/app/components/chat/chat/chat.component.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								src-ui/src/app/components/chat/chat/chat.component.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,140 @@ | ||||
| import { Component, ElementRef, inject, OnInit, ViewChild } from '@angular/core' | ||||
| import { FormsModule, ReactiveFormsModule } from '@angular/forms' | ||||
| import { NavigationEnd, Router } from '@angular/router' | ||||
| import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap' | ||||
| import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' | ||||
| import { filter, map } from 'rxjs' | ||||
| import { ChatMessage, ChatService } from 'src/app/services/chat.service' | ||||
|  | ||||
| @Component({ | ||||
|   selector: 'pngx-chat', | ||||
|   imports: [ | ||||
|     FormsModule, | ||||
|     ReactiveFormsModule, | ||||
|     NgxBootstrapIconsModule, | ||||
|     NgbDropdownModule, | ||||
|   ], | ||||
|   templateUrl: './chat.component.html', | ||||
|   styleUrl: './chat.component.scss', | ||||
| }) | ||||
| export class ChatComponent implements OnInit { | ||||
|   public messages: ChatMessage[] = [] | ||||
|   public loading = false | ||||
|   public input: string = '' | ||||
|   public documentId!: number | ||||
|  | ||||
|   private chatService: ChatService = inject(ChatService) | ||||
|   private router: Router = inject(Router) | ||||
|  | ||||
|   @ViewChild('scrollAnchor') scrollAnchor!: ElementRef<HTMLDivElement> | ||||
|   @ViewChild('chatInput') chatInput!: ElementRef<HTMLInputElement> | ||||
|  | ||||
|   private typewriterBuffer: string[] = [] | ||||
|   private typewriterActive = false | ||||
|  | ||||
|   public get placeholder(): string { | ||||
|     return this.documentId | ||||
|       ? $localize`Ask a question about this document...` | ||||
|       : $localize`Ask a question about a document...` | ||||
|   } | ||||
|  | ||||
|   ngOnInit(): void { | ||||
|     this.updateDocumentId(this.router.url) | ||||
|     this.router.events | ||||
|       .pipe( | ||||
|         filter((event) => event instanceof NavigationEnd), | ||||
|         map((event) => (event as NavigationEnd).url) | ||||
|       ) | ||||
|       .subscribe((url) => { | ||||
|         this.updateDocumentId(url) | ||||
|       }) | ||||
|   } | ||||
|  | ||||
|   private updateDocumentId(url: string): void { | ||||
|     const docIdRe = url.match(/^\/documents\/(\d+)/) | ||||
|     this.documentId = docIdRe ? +docIdRe[1] : undefined | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Sends the current input as a user message and streams the assistant's | ||||
|    * reply into a new placeholder message through the typewriter effect. | ||||
|    * | ||||
|    * Does nothing when the input is blank, or while a previous reply is still | ||||
|    * loading or being typed out: the typewriter buffer is shared, so an | ||||
|    * overlapping reply would have its characters appended to the previous | ||||
|    * message. The input is left untouched in that case so it can be re-sent. | ||||
|    */ | ||||
|   sendMessage(): void { | ||||
|     if (!this.input.trim()) return | ||||
|     if (this.loading || this.typewriterActive) return | ||||
|  | ||||
|     // Capture once so the message shown and the request sent always agree. | ||||
|     const question = this.input | ||||
|  | ||||
|     const userMessage: ChatMessage = { role: 'user', content: question } | ||||
|     this.messages.push(userMessage) | ||||
|     this.scrollToBottom() | ||||
|  | ||||
|     const assistantMessage: ChatMessage = { | ||||
|       role: 'assistant', | ||||
|       content: '', | ||||
|       isStreaming: true, | ||||
|     } | ||||
|     this.messages.push(assistantMessage) | ||||
|     this.loading = true | ||||
|  | ||||
|     // Each emission is handled as the full text so far; only the part past | ||||
|     // the previously seen length is queued for typing. | ||||
|     let lastPartialLength = 0 | ||||
|  | ||||
|     this.chatService.streamChat(this.documentId, question).subscribe({ | ||||
|       next: (chunk) => { | ||||
|         const delta = chunk.substring(lastPartialLength) | ||||
|         lastPartialLength = chunk.length | ||||
|         this.enqueueTypewriter(delta, assistantMessage) | ||||
|       }, | ||||
|       error: () => { | ||||
|         assistantMessage.content += '\n\n⚠️ Error receiving response.' | ||||
|         assistantMessage.isStreaming = false | ||||
|         this.loading = false | ||||
|         // Bring the error note into view, as done on normal completion. | ||||
|         this.scrollToBottom() | ||||
|       }, | ||||
|       complete: () => { | ||||
|         assistantMessage.isStreaming = false | ||||
|         this.loading = false | ||||
|         this.scrollToBottom() | ||||
|       }, | ||||
|     }) | ||||
|  | ||||
|     this.input = '' | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Queues the characters of `chunk` for the typewriter effect and starts | ||||
|    * the typing loop for `message` if it is not already running. | ||||
|    * | ||||
|    * @param chunk text to append; an empty chunk is ignored | ||||
|    * @param message message the typing loop is started on when idle | ||||
|    */ | ||||
|   enqueueTypewriter(chunk: string, message: ChatMessage): void { | ||||
|     if (!chunk) return | ||||
|  | ||||
|     // Array.from iterates by code point, so surrogate pairs (e.g. emoji) | ||||
|     // stay whole instead of being typed as two broken halves. | ||||
|     this.typewriterBuffer.push(...Array.from(chunk)) | ||||
|  | ||||
|     if (!this.typewriterActive) { | ||||
|       this.typewriterActive = true | ||||
|       this.playTypewriter(message) | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Moves one buffered character into `message.content`, then schedules | ||||
|    * itself again; marks the typewriter inactive once the buffer is empty. | ||||
|    */ | ||||
|   playTypewriter(message: ChatMessage): void { | ||||
|     const hasPending = this.typewriterBuffer.length > 0 | ||||
|     if (!hasPending) { | ||||
|       this.typewriterActive = false | ||||
|       return | ||||
|     } | ||||
|  | ||||
|     message.content += this.typewriterBuffer.shift() | ||||
|     this.scrollToBottom() | ||||
|  | ||||
|     // 10ms per character | ||||
|     setTimeout(() => { | ||||
|       this.playTypewriter(message) | ||||
|     }, 10) | ||||
|   } | ||||
|  | ||||
|   /** Smooth-scrolls the scroll anchor into view after a 50ms delay, if it exists. */ | ||||
|   private scrollToBottom(): void { | ||||
|     const scroll = () => { | ||||
|       const anchor = this.scrollAnchor?.nativeElement | ||||
|       anchor?.scrollIntoView({ behavior: 'smooth' }) | ||||
|     } | ||||
|     setTimeout(scroll, 50) | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Focuses the chat input shortly after the dropdown opens. | ||||
|    * | ||||
|    * @param open whether the dropdown is now open; closing does nothing | ||||
|    */ | ||||
|   public onOpenChange(open: boolean): void { | ||||
|     if (!open) return | ||||
|  | ||||
|     setTimeout(() => { | ||||
|       // The input may be gone by the time the timer fires, so guard the | ||||
|       // access the same way scrollToBottom guards its anchor. | ||||
|       this.chatInput?.nativeElement?.focus() | ||||
|     }, 10) | ||||
|   } | ||||
|  | ||||
|   /** Sends the message when Enter is pressed, suppressing the key's default action. */ | ||||
|   public searchInputKeyDown(event: KeyboardEvent) { | ||||
|     const isEnter = event.key === 'Enter' | ||||
|     if (!isEnter) { | ||||
|       return | ||||
|     } | ||||
|     event.preventDefault() | ||||
|     this.sendMessage() | ||||
|   } | ||||
| } | ||||
| @@ -1,7 +1,7 @@ | ||||
| <div ngbDropdown #fieldDropdown="ngbDropdown" (openChange)="onOpenClose($event)" [popperOptions]="popperOptions" placement="bottom-end"> | ||||
|     <button class="btn btn-sm btn-outline-primary" id="customFieldsDropdown" [disabled]="disabled" ngbDropdownToggle> | ||||
| <div ngbDropdown #fieldDropdown="ngbDropdown" (openChange)="onOpenClose($event)" [popperOptions]="popperOptions"> | ||||
|     <button type="button" class="btn btn-sm btn-outline-primary" id="customFieldsDropdown" [disabled]="disabled" ngbDropdownToggle> | ||||
|       <i-bs name="ui-radios"></i-bs> | ||||
|       <div class="d-none d-sm-inline"> <ng-container i18n>Custom Fields</ng-container></div> | ||||
|       <div class="d-none d-lg-inline"> <ng-container i18n>Custom Fields</ng-container></div> | ||||
|     </button> | ||||
|     <div ngbDropdownMenu aria-labelledby="customFieldsDropdown" class="shadow custom-fields-dropdown"> | ||||
|         <div class="list-group list-group-flush" (keydown)="listKeyDown($event)"> | ||||
|   | ||||
| @@ -1,17 +1,24 @@ | ||||
| <div class="mb-3"> | ||||
|   <label class="form-label" [for]="inputId">{{title}}</label> | ||||
|   <div class="input-group" [class.is-invalid]="error"> | ||||
|     <input #inputField [type]="showReveal && textVisible ? 'text' : 'password'" class="form-control" [class.is-invalid]="error" [id]="inputId" [(ngModel)]="value" (focus)="onFocus()" (focusout)="onFocusOut()" (change)="onChange(value)" [disabled]="disabled" [autocomplete]="autocomplete"> | ||||
|     @if (showReveal) { | ||||
|       <button type="button" class="btn btn-outline-secondary" (click)="toggleVisibility()" i18n-title title="Show password" [disabled]="disabled || disableRevealToggle"> | ||||
|         <i-bs name="eye"></i-bs> | ||||
|       </button> | ||||
| <div class="mb-3" [class.pb-3]="error"> | ||||
|   <div class="row"> | ||||
|     <div class="d-flex align-items-center position-relative hidden-button-container" [class.col-md-3]="horizontal"> | ||||
|       @if (title) { | ||||
|         <label class="form-label" [class.mb-md-0]="horizontal" [for]="inputId">{{title}}</label> | ||||
|       } | ||||
|     </div> | ||||
|   <div class="position-relative" [class.col-md-9]="horizontal"> | ||||
|     <div class="input-group" [class.is-invalid]="error"> | ||||
|       <input #inputField [type]="showReveal && textVisible ? 'text' : 'password'" class="form-control" [class.is-invalid]="error" [id]="inputId" [(ngModel)]="value" (focus)="onFocus()" (focusout)="onFocusOut()" (change)="onChange(value)" [disabled]="disabled" [autocomplete]="autocomplete"> | ||||
|       @if (showReveal) { | ||||
|         <button type="button" class="btn btn-outline-secondary" (click)="toggleVisibility()" i18n-title title="Show password" [disabled]="disabled || disableRevealToggle"> | ||||
|           <i-bs name="eye"></i-bs> | ||||
|         </button> | ||||
|       } | ||||
|     </div> | ||||
|     <div class="invalid-feedback"> | ||||
|       {{error}} | ||||
|     </div> | ||||
|     @if (hint) { | ||||
|       <small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small> | ||||
|     } | ||||
|   </div> | ||||
|   <div class="invalid-feedback"> | ||||
|     {{error}} | ||||
|   </div> | ||||
|   @if (hint) { | ||||
|     <small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small> | ||||
|   } | ||||
| </div> | ||||
|   | ||||
| @@ -15,6 +15,12 @@ | ||||
|         @if (hint) { | ||||
|           <small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small> | ||||
|         } | ||||
|         <!-- Shown only while a suggestion exists and differs from the current value. --> | ||||
|         @if (getSuggestion()?.length > 0) { | ||||
|           <small> | ||||
|             <span i18n>Suggestion:</span>&nbsp; | ||||
|             <a (click)="applySuggestion()" [routerLink]="[]">{{getSuggestion()}}</a>&nbsp; | ||||
|           </small> | ||||
|         } | ||||
|         <div class="invalid-feedback position-absolute top-100"> | ||||
|           {{error}} | ||||
|         </div> | ||||
|   | ||||
| @@ -26,10 +26,20 @@ describe('TextComponent', () => { | ||||
|  | ||||
|   it('should support use of input field', () => { | ||||
|     expect(component.value).toBeUndefined() | ||||
|     // TODO: why doesn't this work? | ||||
|     // input.value = 'foo' | ||||
|     // input.dispatchEvent(new Event('change')) | ||||
|     // fixture.detectChanges() | ||||
|     // expect(component.value).toEqual('foo') | ||||
|     input.value = 'foo' | ||||
|     input.dispatchEvent(new Event('input')) | ||||
|     fixture.detectChanges() | ||||
|     expect(component.value).toBe('foo') | ||||
|   }) | ||||
|  | ||||
|   // A suggestion equal to the value is hidden; a differing one is offered | ||||
|   // and, once applied, becomes the value. | ||||
|   it('should support suggestion', () => { | ||||
|     const suggested = 'foo' | ||||
|     component.value = suggested | ||||
|     component.suggestion = suggested | ||||
|     expect(component.getSuggestion()).toBe('') | ||||
|  | ||||
|     component.value = 'bar' | ||||
|     expect(component.getSuggestion()).toBe(suggested) | ||||
|  | ||||
|     component.applySuggestion() | ||||
|     fixture.detectChanges() | ||||
|     expect(component.value).toBe(suggested) | ||||
|   }) | ||||
| }) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import { | ||||
|   NG_VALUE_ACCESSOR, | ||||
|   ReactiveFormsModule, | ||||
| } from '@angular/forms' | ||||
| import { RouterLink } from '@angular/router' | ||||
| import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' | ||||
| import { SafeHtmlPipe } from 'src/app/pipes/safehtml.pipe' | ||||
| import { AbstractInputComponent } from '../abstract-input' | ||||
| @@ -24,6 +25,7 @@ import { AbstractInputComponent } from '../abstract-input' | ||||
|     ReactiveFormsModule, | ||||
|     SafeHtmlPipe, | ||||
|     NgxBootstrapIconsModule, | ||||
|     RouterLink, | ||||
|   ], | ||||
| }) | ||||
| export class TextComponent extends AbstractInputComponent<string> { | ||||
| @@ -33,7 +35,19 @@ export class TextComponent extends AbstractInputComponent<string> { | ||||
|   @Input() | ||||
|   placeholder: string = '' | ||||
|  | ||||
|   @Input() | ||||
|   suggestion: string = '' | ||||
|  | ||||
|   constructor() { | ||||
|     super() | ||||
|   } | ||||
|  | ||||
|   /** Returns the suggestion, or an empty string when it already equals the current value. */ | ||||
|   getSuggestion() { | ||||
|     if (this.value === this.suggestion) { | ||||
|       return '' | ||||
|     } | ||||
|     return this.suggestion | ||||
|   } | ||||
|  | ||||
|   /** Replaces the current value with the suggestion and reports the change via onChange. */ | ||||
|   applySuggestion() { | ||||
|     const { suggestion } = this | ||||
|     this.value = suggestion | ||||
|     this.onChange(this.value) | ||||
|   } | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,49 @@ | ||||
| <!-- | ||||
|   Suggest button plus, when AI is enabled, a dropdown listing suggested new | ||||
|   tags, document types and correspondents. Each list is optional, so every | ||||
|   length check tolerates a missing list. | ||||
| --> | ||||
| <div class="btn-group"> | ||||
|   <button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)"> | ||||
|     @if (loading) { | ||||
|       <div class="spinner-border spinner-border-sm" role="status"></div> | ||||
|     } @else { | ||||
|       <i-bs width="1.2em" height="1.2em" name="stars"></i-bs> | ||||
|     } | ||||
|     <span class="d-none d-lg-inline ps-1" i18n>Suggest</span> | ||||
|     @if (totalSuggestions > 0) { | ||||
|       <span class="badge bg-primary ms-2">{{ totalSuggestions }}</span> | ||||
|     } | ||||
|   </button> | ||||
|  | ||||
|   @if (aiEnabled) { | ||||
|     <div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions"> | ||||
|       <!-- The id is what the menu's aria-labelledby refers to. --> | ||||
|       <button type="button" class="btn btn-sm btn-outline-primary" id="suggestionsDropdown" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-label="Suggestions dropdown"> | ||||
|         <span class="visually-hidden" i18n>Show suggestions</span> | ||||
|       </button> | ||||
|  | ||||
|       <div ngbDropdownMenu aria-labelledby="suggestionsDropdown" class="shadow suggestions-dropdown"> | ||||
|         <div class="list-group list-group-flush small pb-0"> | ||||
|           <!-- Empty arrays are truthy, so test lengths rather than presence. --> | ||||
|           @if (!suggestions?.suggested_tags?.length && !suggestions?.suggested_document_types?.length && !suggestions?.suggested_correspondents?.length) { | ||||
|             <div class="list-group-item text-muted fst-italic"> | ||||
|               <small class="text-muted small fst-italic" i18n>No novel suggestions</small> | ||||
|             </div> | ||||
|           } | ||||
|           @if (suggestions?.suggested_tags?.length > 0) { | ||||
|             <small class="list-group-item text-uppercase text-muted small">Tags</small> | ||||
|             @for (tag of suggestions.suggested_tags; track tag) { | ||||
|               <button type="button" class="list-group-item list-group-item-action bg-light" (click)="addTag.emit(tag)" i18n>{{ tag }}</button> | ||||
|             } | ||||
|           } | ||||
|           @if (suggestions?.suggested_document_types?.length > 0) { | ||||
|             <div class="list-group-item text-uppercase text-muted small">Document Types</div> | ||||
|             @for (type of suggestions.suggested_document_types; track type) { | ||||
|               <button type="button" class="list-group-item list-group-item-action bg-light" (click)="addDocumentType.emit(type)" i18n>{{ type }}</button> | ||||
|             } | ||||
|           } | ||||
|           @if (suggestions?.suggested_correspondents?.length > 0) { | ||||
|             <div class="list-group-item text-uppercase text-muted small">Correspondents</div> | ||||
|             @for (correspondent of suggestions.suggested_correspondents; track correspondent) { | ||||
|               <button type="button" class="list-group-item list-group-item-action bg-light" (click)="addCorrespondent.emit(correspondent)" i18n>{{ correspondent }}</button> | ||||
|             } | ||||
|           } | ||||
|         </div> | ||||
|       </div> | ||||
|     </div> | ||||
|   } | ||||
| </div> | ||||
| @@ -0,0 +1,3 @@ | ||||
| // Keeps the suggestions dropdown menu at least 250px wide. | ||||
| .suggestions-dropdown { | ||||
|   min-width: 250px; | ||||
| } | ||||
| @@ -0,0 +1,51 @@ | ||||
| import { ComponentFixture, TestBed } from '@angular/core/testing' | ||||
| import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap' | ||||
| import { NgxBootstrapIconsModule, allIcons } from 'ngx-bootstrap-icons' | ||||
| import { SuggestionsDropdownComponent } from './suggestions-dropdown.component' | ||||
|  | ||||
| // Unit tests for the suggestion counter and the Suggest button handler. | ||||
| describe('SuggestionsDropdownComponent', () => { | ||||
|   let fixture: ComponentFixture<SuggestionsDropdownComponent> | ||||
|   let component: SuggestionsDropdownComponent | ||||
|  | ||||
|   beforeEach(() => { | ||||
|     TestBed.configureTestingModule({ | ||||
|       providers: [], | ||||
|       imports: [ | ||||
|         NgbDropdownModule, | ||||
|         NgxBootstrapIconsModule.pick(allIcons), | ||||
|         SuggestionsDropdownComponent, | ||||
|       ], | ||||
|     }) | ||||
|     fixture = TestBed.createComponent(SuggestionsDropdownComponent) | ||||
|     component = fixture.componentInstance | ||||
|     fixture.detectChanges() | ||||
|   }) | ||||
|  | ||||
|   it('should calculate totalSuggestions', () => { | ||||
|     // 1 correspondent + 2 tags + 1 document type | ||||
|     component.suggestions = { | ||||
|       suggested_correspondents: ['John Doe'], | ||||
|       suggested_tags: ['Tag1', 'Tag2'], | ||||
|       suggested_document_types: ['Type1'], | ||||
|     } | ||||
|  | ||||
|     expect(component.totalSuggestions).toBe(4) | ||||
|   }) | ||||
|  | ||||
|   it('should emit getSuggestions when clickSuggest is called and suggestions are null', () => { | ||||
|     const emitSpy = jest.spyOn(component.getSuggestions, 'emit') | ||||
|     component.suggestions = null | ||||
|  | ||||
|     component.clickSuggest() | ||||
|  | ||||
|     expect(emitSpy).toHaveBeenCalled() | ||||
|   }) | ||||
|  | ||||
|   it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => { | ||||
|     // The dropdown is only rendered when AI is enabled. | ||||
|     component.aiEnabled = true | ||||
|     fixture.detectChanges() | ||||
|     component.suggestions = { | ||||
|       suggested_correspondents: [], | ||||
|       suggested_tags: [], | ||||
|       suggested_document_types: [], | ||||
|     } | ||||
|  | ||||
|     component.clickSuggest() | ||||
|  | ||||
|     expect(component.dropdown.open).toBeTruthy() | ||||
|   }) | ||||
| }) | ||||
| @@ -0,0 +1,64 @@ | ||||
| import { | ||||
|   Component, | ||||
|   EventEmitter, | ||||
|   Input, | ||||
|   Output, | ||||
|   ViewChild, | ||||
| } from '@angular/core' | ||||
| import { NgbDropdown, NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap' | ||||
| import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' | ||||
| import { DocumentSuggestions } from 'src/app/data/document-suggestions' | ||||
| import { pngxPopperOptions } from 'src/app/utils/popper-options' | ||||
|  | ||||
| /** | ||||
|  * Suggest button with an optional dropdown of suggested new tags, document | ||||
|  * types and correspondents. Requests suggestions via `getSuggestions` and | ||||
|  * reports a picked entry via `addTag`, `addDocumentType` or | ||||
|  * `addCorrespondent`. | ||||
|  */ | ||||
| @Component({ | ||||
|   selector: 'pngx-suggestions-dropdown', | ||||
|   imports: [NgbDropdownModule, NgxBootstrapIconsModule], | ||||
|   templateUrl: './suggestions-dropdown.component.html', | ||||
|   styleUrl: './suggestions-dropdown.component.scss', | ||||
| }) | ||||
| export class SuggestionsDropdownComponent { | ||||
|   public popperOptions = pngxPopperOptions | ||||
|  | ||||
|   // Only present while the template renders the dropdown (AI enabled). | ||||
|   @ViewChild('dropdown') dropdown: NgbDropdown | ||||
|  | ||||
|   // null until suggestions have been supplied. | ||||
|   @Input() | ||||
|   suggestions: DocumentSuggestions = null | ||||
|  | ||||
|   @Input() | ||||
|   aiEnabled: boolean = false | ||||
|  | ||||
|   @Input() | ||||
|   loading: boolean = false | ||||
|  | ||||
|   @Input() | ||||
|   disabled: boolean = false | ||||
|  | ||||
|   @Output() | ||||
|   getSuggestions: EventEmitter<SuggestionsDropdownComponent> = | ||||
|     new EventEmitter() | ||||
|  | ||||
|   @Output() | ||||
|   addTag: EventEmitter<string> = new EventEmitter() | ||||
|  | ||||
|   @Output() | ||||
|   addDocumentType: EventEmitter<string> = new EventEmitter() | ||||
|  | ||||
|   @Output() | ||||
|   addCorrespondent: EventEmitter<string> = new EventEmitter() | ||||
|  | ||||
|   /** | ||||
|    * Emits `getSuggestions` when none are set yet; otherwise toggles the | ||||
|    * dropdown if it is rendered. | ||||
|    */ | ||||
|   public clickSuggest(): void { | ||||
|     if (!this.suggestions) { | ||||
|       this.getSuggestions.emit(this) | ||||
|     } else { | ||||
|       this.dropdown?.toggle() | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Number of suggested correspondents, tags and document types combined. | ||||
|    * | ||||
|    * Each list defaults to 0 on its own: adding `undefined` would turn the | ||||
|    * whole sum into NaN and hide the entries of the lists that are present. | ||||
|    */ | ||||
|   get totalSuggestions(): number { | ||||
|     return ( | ||||
|       (this.suggestions?.suggested_correspondents?.length ?? 0) + | ||||
|       (this.suggestions?.suggested_tags?.length ?? 0) + | ||||
|       (this.suggestions?.suggested_document_types?.length ?? 0) | ||||
|     ) | ||||
|   } | ||||
| } | ||||
| @@ -266,6 +266,43 @@ | ||||
|                   } | ||||
|                 </span> | ||||
|               </dd> | ||||
|               @if (aiEnabled) { | ||||
|                 <dt i18n>AI Index</dt> | ||||
|                 <dd class="d-flex align-items-center"> | ||||
|                   <button class="btn btn-sm d-flex align-items-center btn-dark text-uppercase small" [ngbPopover]="llmIndexStatus" triggers="click mouseenter:mouseleave"> | ||||
|                     {{status.tasks.llmindex_status}} | ||||
|                     @if (status.tasks.llmindex_status === 'OK') { | ||||
|                       @if (isStale(status.tasks.llmindex_last_modified)) { | ||||
|                         <i-bs name="exclamation-triangle-fill" class="text-warning ms-2 lh-1"></i-bs> | ||||
|                       } @else { | ||||
|                         <i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs> | ||||
|                       } | ||||
|                     } @else { | ||||
|                       <i-bs name="exclamation-triangle-fill" class="ms-2 lh-1" | ||||
|                       [class.text-danger]="status.tasks.llmindex_status === SystemStatusItemStatus.ERROR" | ||||
|                       [class.text-warning]="status.tasks.llmindex_status === SystemStatusItemStatus.WARNING" | ||||
|                       [class.text-muted]="status.tasks.llmindex_status === SystemStatusItemStatus.DISABLED"></i-bs> | ||||
|                     } | ||||
|                   </button> | ||||
|                   @if (currentUserIsSuperUser) { | ||||
|                     @if (isRunning(PaperlessTaskName.LLMIndexUpdate)) { | ||||
|                       <div class="spinner-border spinner-border-sm ms-2" role="status"></div> | ||||
|                     } @else { | ||||
|                       <button class="btn btn-sm d-flex align-items-center btn-dark small ms-2" (click)="runTask(PaperlessTaskName.LLMIndexUpdate)"> | ||||
|                         <i-bs name="play-fill"></i-bs>  | ||||
|                         <ng-container i18n>Run Task</ng-container> | ||||
|                       </button> | ||||
|                     } | ||||
|                   } | ||||
|                 </dd> | ||||
|                 <ng-template #llmIndexStatus> | ||||
|                   @if (status.tasks.llmindex_status === 'OK') { | ||||
|                     <h6><ng-container i18n>Last Run</ng-container>:</h6> <span class="font-monospace small">{{status.tasks.llmindex_last_modified | customDate:'medium'}}</span> | ||||
|                   } @else { | ||||
|                     <h6><ng-container i18n>Error</ng-container>:</h6> <span class="font-monospace small">{{status.tasks.llmindex_error}}</span> | ||||
|                   } | ||||
|                 </ng-template> | ||||
|               } | ||||
|             </dl> | ||||
|           </div> | ||||
|         </div> | ||||
|   | ||||
| @@ -68,6 +68,9 @@ const status: SystemStatus = { | ||||
|     sanity_check_status: SystemStatusItemStatus.OK, | ||||
|     sanity_check_last_run: new Date().toISOString(), | ||||
|     sanity_check_error: null, | ||||
|     llmindex_status: SystemStatusItemStatus.OK, | ||||
|     llmindex_last_modified: new Date().toISOString(), | ||||
|     llmindex_error: null, | ||||
|   }, | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -13,9 +13,11 @@ import { | ||||
|   SystemStatus, | ||||
|   SystemStatusItemStatus, | ||||
| } from 'src/app/data/system-status' | ||||
| import { SETTINGS_KEYS } from 'src/app/data/ui-settings' | ||||
| import { CustomDatePipe } from 'src/app/pipes/custom-date.pipe' | ||||
| import { FileSizePipe } from 'src/app/pipes/file-size.pipe' | ||||
| import { PermissionsService } from 'src/app/services/permissions.service' | ||||
| import { SettingsService } from 'src/app/services/settings.service' | ||||
| import { SystemStatusService } from 'src/app/services/system-status.service' | ||||
| import { TasksService } from 'src/app/services/tasks.service' | ||||
| import { ToastService } from 'src/app/services/toast.service' | ||||
| @@ -44,6 +46,7 @@ export class SystemStatusDialogComponent implements OnInit, OnDestroy { | ||||
|   private toastService = inject(ToastService) | ||||
|   private permissionsService = inject(PermissionsService) | ||||
|   private websocketStatusService = inject(WebsocketStatusService) | ||||
|   private settingsService = inject(SettingsService) | ||||
|  | ||||
|   public SystemStatusItemStatus = SystemStatusItemStatus | ||||
|   public PaperlessTaskName = PaperlessTaskName | ||||
| @@ -60,6 +63,10 @@ export class SystemStatusDialogComponent implements OnInit, OnDestroy { | ||||
|     return this.permissionsService.isSuperUser() | ||||
|   } | ||||
|  | ||||
|   /** Value of the AI_ENABLED setting as returned by the settings service. */ | ||||
|   get aiEnabled(): boolean { | ||||
|     const enabled = this.settingsService.get(SETTINGS_KEYS.AI_ENABLED) | ||||
|     return enabled | ||||
|   } | ||||
|  | ||||
|   public ngOnInit() { | ||||
|     this.versionMismatch = | ||||
|       environment.production && | ||||
|   | ||||
| @@ -68,16 +68,6 @@ | ||||
|     </div> | ||||
|   </div> | ||||
|  | ||||
|   <pngx-custom-fields-dropdown | ||||
|     *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.CustomField }" | ||||
|     [documentId]="documentId" | ||||
|     [disabled]="!userCanEdit" | ||||
|     [existingFields]="document?.custom_fields" | ||||
|     (created)="refreshCustomFields()" | ||||
|     (added)="addField($event)"> | ||||
|   </pngx-custom-fields-dropdown> | ||||
|  | ||||
|  | ||||
|   <div class="ms-auto" ngbDropdown> | ||||
|     <button class="btn btn-sm btn-outline-primary" id="sendDropdown" ngbDropdownToggle> | ||||
|       <i-bs name="send"></i-bs> | ||||
| @@ -98,7 +88,7 @@ | ||||
| </pngx-page-header> | ||||
|  | ||||
| <div class="row"> | ||||
|   <div class="col-md-6 col-xl-4 mb-4"> | ||||
|   <div class="col-md-6 col-xl-5 mb-4"> | ||||
|  | ||||
|     <form [formGroup]='documentForm' (ngSubmit)="save()"> | ||||
|  | ||||
| @@ -115,6 +105,32 @@ | ||||
|           </button> | ||||
|         </div> | ||||
|  | ||||
|         <ng-container *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }"> | ||||
|           <div class="btn-group pb-3 ms-auto"> | ||||
|             <pngx-suggestions-dropdown *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }" | ||||
|               [disabled]="!userCanEdit || suggestionsLoading" | ||||
|               [loading]="suggestionsLoading" | ||||
|               [suggestions]="suggestions" | ||||
|               [aiEnabled]="aiEnabled" | ||||
|               (getSuggestions)="getSuggestions()" | ||||
|               (addTag)="createTag($event)" | ||||
|               (addDocumentType)="createDocumentType($event)" | ||||
|               (addCorrespondent)="createCorrespondent($event)"> | ||||
|             </pngx-suggestions-dropdown> | ||||
|           </div> | ||||
|  | ||||
|           <div class="btn-group pb-3 ms-2"> | ||||
|             <pngx-custom-fields-dropdown | ||||
|               *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.CustomField }" | ||||
|               [documentId]="documentId" | ||||
|               [disabled]="!userCanEdit" | ||||
|               [existingFields]="document?.custom_fields" | ||||
|               (created)="refreshCustomFields()" | ||||
|               (added)="addField($event)"> | ||||
|             </pngx-custom-fields-dropdown> | ||||
|           </div> | ||||
|         </ng-container> | ||||
|  | ||||
|         <ng-container *ngTemplateOutlet="saveButtons"></ng-container> | ||||
|       </div> | ||||
|  | ||||
| @@ -123,7 +139,7 @@ | ||||
|           <a ngbNavLink i18n>Details</a> | ||||
|           <ng-template ngbNavContent> | ||||
|             <div> | ||||
|               <pngx-input-text #inputTitle i18n-title title="Title" formControlName="title" [horizontal]="true" (keyup)="titleKeyUp($event)" [error]="error?.title"></pngx-input-text> | ||||
|               <pngx-input-text #inputTitle i18n-title title="Title" formControlName="title" [horizontal]="true" [suggestion]="suggestions?.title" (keyup)="titleKeyUp($event)" [error]="error?.title"></pngx-input-text> | ||||
|               <pngx-input-number i18n-title title="Archive serial number" [error]="error?.archive_serial_number" [horizontal]="true" formControlName='archive_serial_number'></pngx-input-number> | ||||
|               <pngx-input-date i18n-title title="Date created" formControlName="created" [suggestions]="suggestions?.dates" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event)" | ||||
|               [error]="error?.created"></pngx-input-date> | ||||
| @@ -133,7 +149,7 @@ | ||||
|               (createNew)="createDocumentType($event)" [hideAddButton]="createDisabled(DataType.DocumentType)" [suggestions]="suggestions?.document_types" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.DocumentType }"></pngx-input-select> | ||||
|               <pngx-input-select [items]="storagePaths" i18n-title title="Storage path" formControlName="storage_path" [allowNull]="true" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.StoragePath)" | ||||
|               (createNew)="createStoragePath($event)" [hideAddButton]="createDisabled(DataType.StoragePath)" [suggestions]="suggestions?.storage_paths" i18n-placeholder placeholder="Default" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.StoragePath }"></pngx-input-select> | ||||
|               <pngx-input-tags formControlName="tags" [suggestions]="suggestions?.tags" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.Tag)" [hideAddButton]="createDisabled(DataType.Tag)" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.Tag }"></pngx-input-tags> | ||||
|               <pngx-input-tags #tagsInput formControlName="tags" [suggestions]="suggestions?.tags" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.Tag)" [hideAddButton]="createDisabled(DataType.Tag)" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.Tag }"></pngx-input-tags> | ||||
|               @for (fieldInstance of document?.custom_fields; track fieldInstance.field; let i = $index) { | ||||
|                 <div [formGroup]="customFieldFormFields.controls[i]"> | ||||
|                   @switch (getCustomFieldFromInstance(fieldInstance)?.data_type) { | ||||
| @@ -355,14 +371,14 @@ | ||||
|     </form> | ||||
|   </div> | ||||
|  | ||||
|   <div class="col-md-6 col-xl-8 mb-3 d-none d-md-block position-relative" #pdfPreview> | ||||
|   <div class="col-md-6 col-xl-7 mb-3 d-none d-md-block position-relative" #pdfPreview> | ||||
|     <ng-container *ngTemplateOutlet="previewContent"></ng-container> | ||||
|   </div> | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <ng-template #saveButtons> | ||||
|   <div class="btn-group pb-3 ms-auto"> | ||||
|   <div class="btn-group pb-3 ms-4"> | ||||
|     <ng-container *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }"> | ||||
|       <button type="submit" class="order-3 btn btn-sm btn-primary" i18n [disabled]="!userCanEdit || networkActive || (isDirty$ | async) !== true">Save</button> | ||||
|       @if (hasNext()) { | ||||
|   | ||||
| @@ -139,6 +139,7 @@ describe('DocumentDetailComponent', () => { | ||||
|   let deviceDetectorService: DeviceDetectorService | ||||
|   let httpTestingController: HttpTestingController | ||||
|   let componentRouterService: ComponentRouterService | ||||
|   let tagService: TagService | ||||
|  | ||||
|   let currentUserCan = true | ||||
|   let currentUserHasObjectPermissions = true | ||||
| @@ -156,6 +157,16 @@ describe('DocumentDetailComponent', () => { | ||||
|         { | ||||
|           provide: TagService, | ||||
|           useValue: { | ||||
|             getCachedMany: (ids: number[]) => | ||||
|               of( | ||||
|                 ids.map((id) => ({ | ||||
|                   id, | ||||
|                   name: `Tag${id}`, | ||||
|                   is_inbox_tag: true, | ||||
|                   color: '#ff0000', | ||||
|                   text_color: '#000000', | ||||
|                 })) | ||||
|               ), | ||||
|             listAll: () => | ||||
|               of({ | ||||
|                 count: 3, | ||||
| @@ -278,6 +289,7 @@ describe('DocumentDetailComponent', () => { | ||||
|     fixture = TestBed.createComponent(DocumentDetailComponent) | ||||
|     httpTestingController = TestBed.inject(HttpTestingController) | ||||
|     componentRouterService = TestBed.inject(ComponentRouterService) | ||||
|     tagService = TestBed.inject(TagService) | ||||
|     component = fixture.componentInstance | ||||
|   }) | ||||
|  | ||||
| @@ -382,8 +394,75 @@ describe('DocumentDetailComponent', () => { | ||||
|     currentUserCan = true | ||||
|   }) | ||||
|  | ||||
|   it('should support creating document type', () => { | ||||
|   it('should support creating tag, remove from suggestions', () => { | ||||
|     initNormally() | ||||
|     component.suggestions = { | ||||
|       suggested_tags: ['Tag1', 'NewTag12'], | ||||
|     } | ||||
|     let openModal: NgbModalRef | ||||
|     modalService.activeInstances.subscribe((modal) => (openModal = modal[0])) | ||||
|     const modalSpy = jest.spyOn(modalService, 'open') | ||||
|     // temporarily add NewTag12 to listAll results | ||||
|     const listAllSpy = jest | ||||
|       .spyOn(tagService, 'listAll') | ||||
|       .mockImplementation(() => | ||||
|         of({ | ||||
|           count: 4, | ||||
|           all: [41, 42, 43, 12], | ||||
|           results: [ | ||||
|             { | ||||
|               id: 41, | ||||
|               name: 'Tag41', | ||||
|               is_inbox_tag: true, | ||||
|               color: '#ff0000', | ||||
|               text_color: '#000000', | ||||
|             }, | ||||
|             { | ||||
|               id: 42, | ||||
|               name: 'Tag42', | ||||
|               is_inbox_tag: true, | ||||
|               color: '#ff0000', | ||||
|               text_color: '#000000', | ||||
|             }, | ||||
|             { | ||||
|               id: 43, | ||||
|               name: 'Tag43', | ||||
|               is_inbox_tag: true, | ||||
|               color: '#ff0000', | ||||
|               text_color: '#000000', | ||||
|             }, | ||||
|             { | ||||
|               id: 12, | ||||
|               name: 'NewTag12', | ||||
|               is_inbox_tag: true, | ||||
|               color: '#ff0000', | ||||
|               text_color: '#000000', | ||||
|             }, | ||||
|           ], | ||||
|         }) | ||||
|       ) | ||||
|     try { | ||||
|       component.createTag('NewTag12') | ||||
|       expect(modalSpy).toHaveBeenCalled() | ||||
|       openModal.componentInstance.succeeded.next({ | ||||
|         id: 12, | ||||
|         name: 'NewTag12', | ||||
|         is_inbox_tag: true, | ||||
|         color: '#ff0000', | ||||
|         text_color: '#000000', | ||||
|       }) | ||||
|       expect(component.tagsInput.value.includes(12)).toBeTruthy() | ||||
|       expect(component.suggestions.suggested_tags).not.toContain('NewTag12') | ||||
|     } finally { | ||||
|       listAllSpy.mockRestore() | ||||
|     } | ||||
|   }) | ||||
|  | ||||
|   it('should support creating document type, remove from suggestions', () => { | ||||
|     initNormally() | ||||
|     component.suggestions = { | ||||
|       suggested_document_types: ['DocumentType1', 'NewDocType2'], | ||||
|     } | ||||
|     let openModal: NgbModalRef | ||||
|     modalService.activeInstances.subscribe((modal) => (openModal = modal[0])) | ||||
|     const modalSpy = jest.spyOn(modalService, 'open') | ||||
| @@ -391,10 +470,16 @@ describe('DocumentDetailComponent', () => { | ||||
|     expect(modalSpy).toHaveBeenCalled() | ||||
|     openModal.componentInstance.succeeded.next({ id: 12, name: 'NewDocType12' }) | ||||
|     expect(component.documentForm.get('document_type').value).toEqual(12) | ||||
|     expect(component.suggestions.suggested_document_types).not.toContain( | ||||
|       'NewDocType2' | ||||
|     ) | ||||
|   }) | ||||
|  | ||||
|   it('should support creating correspondent', () => { | ||||
|   it('should support creating correspondent, remove from suggestions', () => { | ||||
|     initNormally() | ||||
|     component.suggestions = { | ||||
|       suggested_correspondents: ['Correspondent1', 'NewCorrrespondent12'], | ||||
|     } | ||||
|     let openModal: NgbModalRef | ||||
|     modalService.activeInstances.subscribe((modal) => (openModal = modal[0])) | ||||
|     const modalSpy = jest.spyOn(modalService, 'open') | ||||
| @@ -405,6 +490,9 @@ describe('DocumentDetailComponent', () => { | ||||
|       name: 'NewCorrrespondent12', | ||||
|     }) | ||||
|     expect(component.documentForm.get('correspondent').value).toEqual(12) | ||||
|     expect(component.suggestions.suggested_correspondents).not.toContain( | ||||
|       'NewCorrrespondent12' | ||||
|     ) | ||||
|   }) | ||||
|  | ||||
|   it('should support creating storage path', () => { | ||||
| @@ -995,7 +1083,7 @@ describe('DocumentDetailComponent', () => { | ||||
|     expect(component.document.custom_fields).toHaveLength(initialLength - 1) | ||||
|     expect(component.customFieldFormFields).toHaveLength(initialLength - 1) | ||||
|     expect( | ||||
|       fixture.debugElement.query(By.css('form')).nativeElement.textContent | ||||
|       fixture.debugElement.query(By.css('form ul')).nativeElement.textContent | ||||
|     ).not.toContain('Field 1') | ||||
|     const patchSpy = jest.spyOn(documentService, 'patch') | ||||
|     component.save(true) | ||||
| @@ -1086,10 +1174,22 @@ describe('DocumentDetailComponent', () => { | ||||
|  | ||||
|   it('should get suggestions', () => { | ||||
|     const suggestionsSpy = jest.spyOn(documentService, 'getSuggestions') | ||||
|     suggestionsSpy.mockReturnValue(of({ tags: [42, 43] })) | ||||
|     suggestionsSpy.mockReturnValue( | ||||
|       of({ | ||||
|         tags: [42, 43], | ||||
|         suggested_tags: [], | ||||
|         suggested_document_types: [], | ||||
|         suggested_correspondents: [], | ||||
|       }) | ||||
|     ) | ||||
|     initNormally() | ||||
|     expect(suggestionsSpy).toHaveBeenCalled() | ||||
|     expect(component.suggestions).toEqual({ tags: [42, 43] }) | ||||
|     expect(component.suggestions).toEqual({ | ||||
|       tags: [42, 43], | ||||
|       suggested_tags: [], | ||||
|       suggested_document_types: [], | ||||
|       suggested_correspondents: [], | ||||
|     }) | ||||
|   }) | ||||
|  | ||||
|   it('should show error if needed for get suggestions', () => { | ||||
|   | ||||
| @@ -76,6 +76,7 @@ import { DocumentTypeService } from 'src/app/services/rest/document-type.service | ||||
| import { DocumentService } from 'src/app/services/rest/document.service' | ||||
| import { SavedViewService } from 'src/app/services/rest/saved-view.service' | ||||
| import { StoragePathService } from 'src/app/services/rest/storage-path.service' | ||||
| import { TagService } from 'src/app/services/rest/tag.service' | ||||
| import { UserService } from 'src/app/services/rest/user.service' | ||||
| import { SettingsService } from 'src/app/services/settings.service' | ||||
| import { ToastService } from 'src/app/services/toast.service' | ||||
| @@ -88,6 +89,7 @@ import { CorrespondentEditDialogComponent } from '../common/edit-dialog/correspo | ||||
| import { DocumentTypeEditDialogComponent } from '../common/edit-dialog/document-type-edit-dialog/document-type-edit-dialog.component' | ||||
| import { EditDialogMode } from '../common/edit-dialog/edit-dialog.component' | ||||
| import { StoragePathEditDialogComponent } from '../common/edit-dialog/storage-path-edit-dialog/storage-path-edit-dialog.component' | ||||
| import { TagEditDialogComponent } from '../common/edit-dialog/tag-edit-dialog/tag-edit-dialog.component' | ||||
| import { EmailDocumentDialogComponent } from '../common/email-document-dialog/email-document-dialog.component' | ||||
| import { CheckComponent } from '../common/input/check/check.component' | ||||
| import { DateComponent } from '../common/input/date/date.component' | ||||
| @@ -106,6 +108,7 @@ import { | ||||
|   PdfEditorEditMode, | ||||
| } from '../common/pdf-editor/pdf-editor.component' | ||||
| import { ShareLinksDialogComponent } from '../common/share-links-dialog/share-links-dialog.component' | ||||
| import { SuggestionsDropdownComponent } from '../common/suggestions-dropdown/suggestions-dropdown.component' | ||||
| import { DocumentHistoryComponent } from '../document-history/document-history.component' | ||||
| import { DocumentNotesComponent } from '../document-notes/document-notes.component' | ||||
| import { ComponentWithPermissions } from '../with-permissions/with-permissions.component' | ||||
| @@ -162,6 +165,7 @@ export enum ZoomSetting { | ||||
|     NumberComponent, | ||||
|     MonetaryComponent, | ||||
|     UrlComponent, | ||||
|     SuggestionsDropdownComponent, | ||||
|     CustomDatePipe, | ||||
|     FileSizePipe, | ||||
|     IfPermissionsDirective, | ||||
| @@ -183,6 +187,7 @@ export class DocumentDetailComponent | ||||
| { | ||||
|   private documentsService = inject(DocumentService) | ||||
|   private route = inject(ActivatedRoute) | ||||
|   private tagService = inject(TagService) | ||||
|   private correspondentService = inject(CorrespondentService) | ||||
|   private documentTypeService = inject(DocumentTypeService) | ||||
|   private router = inject(Router) | ||||
| @@ -205,6 +210,8 @@ export class DocumentDetailComponent | ||||
|   @ViewChild('inputTitle') | ||||
|   titleInput: TextComponent | ||||
|  | ||||
|   @ViewChild('tagsInput') tagsInput: TagsComponent | ||||
|  | ||||
|   expandOriginalMetadata = false | ||||
|   expandArchivedMetadata = false | ||||
|  | ||||
| @@ -216,6 +223,7 @@ export class DocumentDetailComponent | ||||
|   document: Document | ||||
|   metadata: DocumentMetadata | ||||
|   suggestions: DocumentSuggestions | ||||
|   suggestionsLoading: boolean = false | ||||
|   users: User[] | ||||
|  | ||||
|   title: string | ||||
| @@ -297,6 +305,10 @@ export class DocumentDetailComponent | ||||
|     return this.deviceDetectorService.isMobile() | ||||
|   } | ||||
|  | ||||
|   get aiEnabled(): boolean { | ||||
|     return this.settings.get(SETTINGS_KEYS.AI_ENABLED) | ||||
|   } | ||||
|  | ||||
|   get archiveContentRenderType(): ContentRenderType { | ||||
|     return this.document?.archived_file_name | ||||
|       ? this.getRenderType('application/pdf') | ||||
| @@ -678,25 +690,12 @@ export class DocumentDetailComponent | ||||
|         PermissionType.Document | ||||
|       ) | ||||
|     ) { | ||||
|       this.documentsService | ||||
|         .getSuggestions(doc.id) | ||||
|         .pipe( | ||||
|           first(), | ||||
|           takeUntil(this.unsubscribeNotifier), | ||||
|           takeUntil(this.docChangeNotifier) | ||||
|         ) | ||||
|         .subscribe({ | ||||
|           next: (result) => { | ||||
|             this.suggestions = result | ||||
|           }, | ||||
|           error: (error) => { | ||||
|             this.suggestions = null | ||||
|             this.toastService.showError( | ||||
|               $localize`Error retrieving suggestions.`, | ||||
|               error | ||||
|             ) | ||||
|           }, | ||||
|         }) | ||||
|       this.tagService.getCachedMany(doc.tags).subscribe((tags) => { | ||||
|         // only show suggestions if document has inbox tags | ||||
|         if (tags.some((tag) => tag.is_inbox_tag)) { | ||||
|           this.getSuggestions() | ||||
|         } | ||||
|       }) | ||||
|     } | ||||
|     this.title = this.documentTitlePipe.transform(doc.title) | ||||
|     this.prepareForm(doc) | ||||
| @@ -706,6 +705,60 @@ export class DocumentDetailComponent | ||||
|     return this.documentForm.get('custom_fields') as FormArray | ||||
|   } | ||||
|  | ||||
|   getSuggestions() { | ||||
|     this.suggestionsLoading = true | ||||
|     this.documentsService | ||||
|       .getSuggestions(this.documentId) | ||||
|       .pipe( | ||||
|         first(), | ||||
|         takeUntil(this.unsubscribeNotifier), | ||||
|         takeUntil(this.docChangeNotifier) | ||||
|       ) | ||||
|       .subscribe({ | ||||
|         next: (result) => { | ||||
|           this.suggestions = result | ||||
|           this.suggestionsLoading = false | ||||
|         }, | ||||
|         error: (error) => { | ||||
|           this.suggestions = null | ||||
|           this.suggestionsLoading = false | ||||
|           this.toastService.showError( | ||||
|             $localize`Error retrieving suggestions.`, | ||||
|             error | ||||
|           ) | ||||
|         }, | ||||
|       }) | ||||
|   } | ||||
|  | ||||
|   createTag(newName: string) { | ||||
|     var modal = this.modalService.open(TagEditDialogComponent, { | ||||
|       backdrop: 'static', | ||||
|     }) | ||||
|     modal.componentInstance.dialogMode = EditDialogMode.CREATE | ||||
|     if (newName) modal.componentInstance.object = { name: newName } | ||||
|     console.log('createTag called with', newName) | ||||
|  | ||||
|     modal.componentInstance.succeeded | ||||
|       .pipe( | ||||
|         switchMap((newTag) => { | ||||
|           return this.tagService | ||||
|             .listAll() | ||||
|             .pipe(map((tags) => ({ newTag, tags }))) | ||||
|         }) | ||||
|       ) | ||||
|       .pipe(takeUntil(this.unsubscribeNotifier)) | ||||
|       .subscribe(({ newTag, tags }) => { | ||||
|         this.tagsInput.tags = tags.results | ||||
|         this.tagsInput.addTag(newTag.id) | ||||
|         console.log(this.suggestions) | ||||
|  | ||||
|         if (this.suggestions) { | ||||
|           this.suggestions.suggested_tags = | ||||
|             this.suggestions.suggested_tags.filter((tag) => tag !== newName) | ||||
|         } | ||||
|       }) | ||||
|   } | ||||
|  | ||||
|   createDocumentType(newName: string) { | ||||
|     var modal = this.modalService.open(DocumentTypeEditDialogComponent, { | ||||
|       backdrop: 'static', | ||||
| @@ -725,6 +778,12 @@ export class DocumentDetailComponent | ||||
|         this.documentTypes = documentTypes.results | ||||
|         this.documentForm.get('document_type').setValue(newDocumentType.id) | ||||
|         this.documentForm.get('document_type').markAsDirty() | ||||
|         if (this.suggestions) { | ||||
|           this.suggestions.suggested_document_types = | ||||
|             this.suggestions.suggested_document_types.filter( | ||||
|               (dt) => dt !== newName | ||||
|             ) | ||||
|         } | ||||
|       }) | ||||
|   } | ||||
|  | ||||
| @@ -749,6 +808,12 @@ export class DocumentDetailComponent | ||||
|         this.correspondents = correspondents.results | ||||
|         this.documentForm.get('correspondent').setValue(newCorrespondent.id) | ||||
|         this.documentForm.get('correspondent').markAsDirty() | ||||
|         if (this.suggestions) { | ||||
|           this.suggestions.suggested_correspondents = | ||||
|             this.suggestions.suggested_correspondents.filter( | ||||
|               (c) => c !== newName | ||||
|             ) | ||||
|         } | ||||
|       }) | ||||
|   } | ||||
|  | ||||
|   | ||||
| @@ -1,11 +1,17 @@ | ||||
/**
 * Suggestions for a single document, as returned by
 * `DocumentService.getSuggestions`. Every field is optional.
 *
 * The `number[]` fields are presumably ids of existing objects (TODO confirm
 * against the API). The `suggested_*` string fields hold names; the document
 * detail view removes a name from them once an object with that name has been
 * created.
 */
export interface DocumentSuggestions {
  title?: string

  tags?: number[]
  suggested_tags?: string[]

  correspondents?: number[]
  suggested_correspondents?: string[]

  document_types?: number[]
  suggested_document_types?: string[]

  storage_paths?: number[]
  suggested_storage_paths?: string[]

  dates?: string[] // ISO-formatted date string e.g. 2022-11-03
}
|   | ||||
| @@ -44,12 +44,24 @@ export enum ConfigOptionType { | ||||
|   Boolean = 'boolean', | ||||
|   JSON = 'json', | ||||
|   File = 'file', | ||||
|   Password = 'password', | ||||
| } | ||||
|  | ||||
// Localized display names of the configuration categories; each ConfigOption
// refers to one of them through its `category` field.
export const ConfigCategory = {
  General: $localize`General Settings`,
  OCR: $localize`OCR Settings`,
  Barcode: $localize`Barcode Settings`,
  AI: $localize`AI Settings`,
}

// Values offered as choices for the `llm_embedding_backend` config option.
export const LLMEmbeddingBackendConfig = {
  OPENAI: 'openai',
  HUGGINGFACE: 'huggingface',
}

// Values offered as choices for the `llm_backend` config option.
export const LLMBackendConfig = {
  OPENAI: 'openai',
  OLLAMA: 'ollama',
}
|  | ||||
| export interface ConfigOption { | ||||
| @@ -59,6 +71,7 @@ export interface ConfigOption { | ||||
|   choices?: Array<{ id: string; name: string }> | ||||
|   config_key?: string | ||||
|   category: string | ||||
|   note?: string | ||||
| } | ||||
|  | ||||
| function mapToItems(enumObj: Object): Array<{ id: string; name: string }> { | ||||
| @@ -258,6 +271,58 @@ export const PaperlessConfigOptions: ConfigOption[] = [ | ||||
|     config_key: 'PAPERLESS_CONSUMER_TAG_BARCODE_MAPPING', | ||||
|     category: ConfigCategory.Barcode, | ||||
|   }, | ||||
|   { | ||||
|     key: 'ai_enabled', | ||||
|     title: $localize`AI Enabled`, | ||||
|     type: ConfigOptionType.Boolean, | ||||
|     config_key: 'PAPERLESS_AI_ENABLED', | ||||
|     category: ConfigCategory.AI, | ||||
|     note: $localize`Consider privacy implications when enabling AI features, especially if using a remote model.`, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_embedding_backend', | ||||
|     title: $localize`LLM Embedding Backend`, | ||||
|     type: ConfigOptionType.Select, | ||||
|     choices: mapToItems(LLMEmbeddingBackendConfig), | ||||
|     config_key: 'PAPERLESS_AI_LLM_EMBEDDING_BACKEND', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_embedding_model', | ||||
|     title: $localize`LLM Embedding Model`, | ||||
|     type: ConfigOptionType.String, | ||||
|     config_key: 'PAPERLESS_AI_LLM_EMBEDDING_MODEL', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_backend', | ||||
|     title: $localize`LLM Backend`, | ||||
|     type: ConfigOptionType.Select, | ||||
|     choices: mapToItems(LLMBackendConfig), | ||||
|     config_key: 'PAPERLESS_AI_LLM_BACKEND', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_model', | ||||
|     title: $localize`LLM Model`, | ||||
|     type: ConfigOptionType.String, | ||||
|     config_key: 'PAPERLESS_AI_LLM_MODEL', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_api_key', | ||||
|     title: $localize`LLM API Key`, | ||||
|     type: ConfigOptionType.Password, | ||||
|     config_key: 'PAPERLESS_AI_LLM_API_KEY', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
|   { | ||||
|     key: 'llm_endpoint', | ||||
|     title: $localize`LLM Endpoint`, | ||||
|     type: ConfigOptionType.String, | ||||
|     config_key: 'PAPERLESS_AI_LLM_ENDPOINT', | ||||
|     category: ConfigCategory.AI, | ||||
|   }, | ||||
| ] | ||||
|  | ||||
| export interface PaperlessConfig extends ObjectWithId { | ||||
| @@ -287,4 +352,11 @@ export interface PaperlessConfig extends ObjectWithId { | ||||
|   barcode_max_pages: number | ||||
|   barcode_enable_tag: boolean | ||||
|   barcode_tag_mapping: object | ||||
|   ai_enabled: boolean | ||||
|   llm_embedding_backend: string | ||||
|   llm_embedding_model: string | ||||
|   llm_backend: string | ||||
|   llm_model: string | ||||
|   llm_api_key: string | ||||
|   llm_endpoint: string | ||||
| } | ||||
|   | ||||
| @@ -11,6 +11,7 @@ export enum PaperlessTaskName { | ||||
|   TrainClassifier = 'train_classifier', | ||||
|   SanityCheck = 'check_sanity', | ||||
|   IndexOptimize = 'index_optimize', | ||||
|   LLMIndexUpdate = 'llmindex_update', | ||||
| } | ||||
|  | ||||
| export enum PaperlessTaskStatus { | ||||
|   | ||||
| @@ -7,6 +7,7 @@ export enum SystemStatusItemStatus { | ||||
|   OK = 'OK', | ||||
|   ERROR = 'ERROR', | ||||
|   WARNING = 'WARNING', | ||||
|   DISABLED = 'DISABLED', | ||||
| } | ||||
|  | ||||
| export interface SystemStatus { | ||||
| @@ -43,6 +44,9 @@ export interface SystemStatus { | ||||
|     sanity_check_status: SystemStatusItemStatus | ||||
|     sanity_check_last_run: string // ISO date string | ||||
|     sanity_check_error: string | ||||
|     llmindex_status: SystemStatusItemStatus | ||||
|     llmindex_last_modified: string // ISO date string | ||||
|     llmindex_error: string | ||||
|   } | ||||
|   websocket_connected?: SystemStatusItemStatus // added client-side | ||||
| } | ||||
|   | ||||
| @@ -76,6 +76,7 @@ export const SETTINGS_KEYS = { | ||||
|   GMAIL_OAUTH_URL: 'gmail_oauth_url', | ||||
|   OUTLOOK_OAUTH_URL: 'outlook_oauth_url', | ||||
|   EMAIL_ENABLED: 'email_enabled', | ||||
|   AI_ENABLED: 'ai_enabled', | ||||
| } | ||||
|  | ||||
| export const SETTINGS: UiSetting[] = [ | ||||
| @@ -289,4 +290,9 @@ export const SETTINGS: UiSetting[] = [ | ||||
|     type: 'string', | ||||
|     default: 'page-width', // ZoomSetting from 'document-detail.component' | ||||
|   }, | ||||
|   { | ||||
|     key: SETTINGS_KEYS.AI_ENABLED, | ||||
|     type: 'boolean', | ||||
|     default: false, | ||||
|   }, | ||||
| ] | ||||
|   | ||||
| @@ -4,15 +4,15 @@ import { | ||||
|   HttpInterceptor, | ||||
|   HttpRequest, | ||||
| } from '@angular/common/http' | ||||
| import { Injectable, inject } from '@angular/core' | ||||
| import { inject, Injectable } from '@angular/core' | ||||
| import { Meta } from '@angular/platform-browser' | ||||
| import { CookieService } from 'ngx-cookie-service' | ||||
| import { Observable } from 'rxjs' | ||||
|  | ||||
| @Injectable() | ||||
| export class CsrfInterceptor implements HttpInterceptor { | ||||
|   private cookieService = inject(CookieService) | ||||
|   private meta = inject(Meta) | ||||
|   private cookieService: CookieService = inject(CookieService) | ||||
|   private meta: Meta = inject(Meta) | ||||
|  | ||||
|   intercept( | ||||
|     request: HttpRequest<unknown>, | ||||
|   | ||||
							
								
								
									
										58
									
								
								src-ui/src/app/services/chat.service.spec.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								src-ui/src/app/services/chat.service.spec.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| import { | ||||
|   HttpEventType, | ||||
|   provideHttpClient, | ||||
|   withInterceptorsFromDi, | ||||
| } from '@angular/common/http' | ||||
| import { | ||||
|   HttpTestingController, | ||||
|   provideHttpClientTesting, | ||||
| } from '@angular/common/http/testing' | ||||
| import { TestBed } from '@angular/core/testing' | ||||
| import { environment } from 'src/environments/environment' | ||||
| import { ChatService } from './chat.service' | ||||
|  | ||||
describe('ChatService', () => {
  let chatService: ChatService
  let httpController: HttpTestingController

  beforeEach(() => {
    // Real HttpClient wired to the testing backend so requests can be
    // intercepted and answered by hand.
    TestBed.configureTestingModule({
      imports: [],
      providers: [
        ChatService,
        provideHttpClient(withInterceptorsFromDi()),
        provideHttpClientTesting(),
      ],
    })
    chatService = TestBed.inject(ChatService)
    httpController = TestBed.inject(HttpTestingController)
  })

  afterEach(() => {
    // Fails the test if a request was issued that no expectation matched.
    httpController.verify()
  })

  it('should stream chat messages', (done) => {
    const docId = 1
    const question = 'Hello, world!'
    const expectedChunk = 'Partial response text'
    const chatUrl = `${environment.apiBaseUrl}documents/chat/`

    chatService.streamChat(docId, question).subscribe((received) => {
      expect(received).toBe(expectedChunk)
      done()
    })

    const pending = httpController.expectOne(chatUrl)
    expect(pending.request.method).toBe('POST')
    expect(pending.request.body).toEqual({
      document_id: docId,
      q: question,
    })

    // Simulate one download-progress event carrying partial response text.
    pending.event({
      type: HttpEventType.DownloadProgress,
      partialText: expectedChunk,
    } as any)
  })
})
							
								
								
									
										46
									
								
								src-ui/src/app/services/chat.service.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								src-ui/src/app/services/chat.service.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | ||||
| import { | ||||
|   HttpClient, | ||||
|   HttpDownloadProgressEvent, | ||||
|   HttpEventType, | ||||
| } from '@angular/common/http' | ||||
| import { inject, Injectable } from '@angular/core' | ||||
| import { filter, map, Observable } from 'rxjs' | ||||
| import { environment } from 'src/environments/environment' | ||||
|  | ||||
/** A single entry of a chat conversation. */
export interface ChatMessage {
  // Who authored the message.
  role: 'user' | 'assistant'
  // Message text.
  content: string
  // NOTE(review): presumably set while an assistant reply is still being
  // received; confirm against the chat component that consumes this.
  isStreaming?: boolean
}
|  | ||||
@Injectable({
  providedIn: 'root',
})
export class ChatService {
  private http: HttpClient = inject(HttpClient)

  /**
   * POSTs `prompt` (as `q`) and `documentId` (as `document_id`) to the
   * `documents/chat/` endpoint and observes the raw HTTP events.
   *
   * Emits the `partialText` of every download-progress event; all other
   * events, and progress events with empty or missing text, are dropped.
   */
  streamChat(documentId: number, prompt: string): Observable<string> {
    const url = `${environment.apiBaseUrl}documents/chat/`
    const payload = {
      document_id: documentId,
      q: prompt,
    }

    const events$ = this.http.post(url, payload, {
      observe: 'events',
      reportProgress: true,
      responseType: 'text',
      withCredentials: true,
    })

    return events$.pipe(
      map((event) =>
        event.type === HttpEventType.DownloadProgress
          ? (event as HttpDownloadProgressEvent).partialText!
          : undefined
      ),
      filter((chunk) => !!chunk)
    )
  }
}
| @@ -9,6 +9,7 @@ import { DatePipe, registerLocaleData } from '@angular/common' | ||||
| import { | ||||
|   HTTP_INTERCEPTORS, | ||||
|   provideHttpClient, | ||||
|   withFetch, | ||||
|   withInterceptorsFromDi, | ||||
| } from '@angular/common/http' | ||||
| import { FormsModule, ReactiveFormsModule } from '@angular/forms' | ||||
| @@ -48,6 +49,7 @@ import { | ||||
|   caretDown, | ||||
|   caretUp, | ||||
|   chatLeftText, | ||||
|   chatSquareDots, | ||||
|   check, | ||||
|   check2All, | ||||
|   checkAll, | ||||
| @@ -121,6 +123,7 @@ import { | ||||
|   sliders2Vertical, | ||||
|   sortAlphaDown, | ||||
|   sortAlphaUpAlt, | ||||
|   stars, | ||||
|   tag, | ||||
|   tagFill, | ||||
|   tags, | ||||
| @@ -260,6 +263,7 @@ const icons = { | ||||
|   caretDown, | ||||
|   caretUp, | ||||
|   chatLeftText, | ||||
|   chatSquareDots, | ||||
|   check, | ||||
|   check2All, | ||||
|   checkAll, | ||||
| @@ -333,6 +337,7 @@ const icons = { | ||||
|   sliders2Vertical, | ||||
|   sortAlphaDown, | ||||
|   sortAlphaUpAlt, | ||||
|   stars, | ||||
|   tagFill, | ||||
|   tag, | ||||
|   tags, | ||||
| @@ -397,6 +402,6 @@ bootstrapApplication(AppComponent, { | ||||
|     CorrespondentNamePipe, | ||||
|     DocumentTypeNamePipe, | ||||
|     StoragePathNamePipe, | ||||
|     provideHttpClient(withInterceptorsFromDi()), | ||||
|     provideHttpClient(withInterceptorsFromDi(), withFetch()), | ||||
|   ], | ||||
| }).catch((err) => console.error(err)) | ||||
|   | ||||
| @@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig): | ||||
|         from documents.signals import document_consumption_finished | ||||
|         from documents.signals import document_updated | ||||
|         from documents.signals.handlers import add_inbox_tags | ||||
|         from documents.signals.handlers import add_or_update_document_in_llm_index | ||||
|         from documents.signals.handlers import add_to_index | ||||
|         from documents.signals.handlers import run_workflows_added | ||||
|         from documents.signals.handlers import run_workflows_updated | ||||
| @@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig): | ||||
|         document_consumption_finished.connect(set_storage_path) | ||||
|         document_consumption_finished.connect(add_to_index) | ||||
|         document_consumption_finished.connect(run_workflows_added) | ||||
|         document_consumption_finished.connect(add_or_update_document_in_llm_index) | ||||
|         document_updated.connect(run_workflows_updated) | ||||
|  | ||||
|         import documents.schema  # noqa: F401 | ||||
|   | ||||
| @@ -196,6 +196,56 @@ def refresh_suggestions_cache( | ||||
|     cache.touch(doc_key, timeout) | ||||
|  | ||||
|  | ||||
def get_llm_suggestion_cache(
    document_id: int,
    backend: str,
) -> SuggestionCacheData | None:
    """
    Return the cached suggestions for a document if they belong to ``backend``.

    The entry is read from the document's suggestion cache key and is only
    returned when its ``classifier_hash`` equals ``backend``. A missing entry
    or one stored under a different identifier yields ``None``.
    """
    cached: SuggestionCacheData = cache.get(get_suggestion_cache_key(document_id))

    if not cached or cached.classifier_hash != backend:
        return None

    return cached
|  | ||||
|  | ||||
def set_llm_suggestions_cache(
    document_id: int,
    suggestions: dict,
    *,
    backend: str,
    timeout: int = CACHE_50_MINUTES,
) -> None:
    """
    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').

    The entry is written under the document's suggestion cache key with
    ``classifier_hash`` set to ``backend``, which is the value
    ``get_llm_suggestion_cache`` compares against when reading it back.

    :param document_id: primary key of the document the suggestions belong to
    :param suggestions: suggestion payload to store
    :param backend: identifier of the backend that produced the suggestions
    :param timeout: cache timeout passed to ``cache.set``
    """
    # SuggestionCacheData is already a global of this module, so no
    # function-level import of the module from itself is needed.
    cache.set(
        get_suggestion_cache_key(document_id),
        SuggestionCacheData(
            classifier_version=1000,  # Unique marker for LLM-based suggestion
            classifier_hash=backend,
            suggestions=suggestions,
        ),
        timeout,
    )
|  | ||||
|  | ||||
def invalidate_llm_suggestions_cache(
    document_id: int,
) -> None:
    """
    Remove whatever is cached under the document's suggestion cache key.

    The entry is dropped regardless of which backend stored it. Deleting a
    key that is not present is a no-op, so the cache is not read first; that
    saves a round-trip and avoids a check-then-delete race.
    """
    cache.delete(get_suggestion_cache_key(document_id))
|  | ||||
|  | ||||
| def get_metadata_cache_key(document_id: int) -> str: | ||||
|     """ | ||||
|     Returns the basic key for a document's metadata | ||||
|   | ||||
							
								
								
									
										22
									
								
								src/documents/management/commands/document_llmindex.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								src/documents/management/commands/document_llmindex.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| from django.core.management import BaseCommand | ||||
| from django.db import transaction | ||||
|  | ||||
| from documents.management.commands.mixins import ProgressBarMixin | ||||
| from documents.tasks import llmindex_index | ||||
|  | ||||
|  | ||||
class Command(ProgressBarMixin, BaseCommand):
    help = "Manages the LLM-based vector index for Paperless."

    def add_arguments(self, parser):
        """Register the positional sub-command and the progress-bar option."""
        parser.add_argument("command", choices=["rebuild", "update"])
        self.add_argument_progress_bar_mixin(parser)

    def handle(self, *args, **options):
        """Call ``llmindex_index`` directly, inside a single transaction."""
        self.handle_progress_bar_mixin(**options)

        # "rebuild" requests a rebuild; "update" (the only other choice) does not.
        rebuild_requested = options["command"] == "rebuild"

        with transaction.atomic():
            llmindex_index(
                progress_bar_disable=self.no_progress_bar,
                rebuild=rebuild_requested,
                scheduled=False,
            )
| @@ -0,0 +1,30 @@ | ||||
| # Generated by Django 5.1.8 on 2025-04-30 02:38 | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db import models | ||||
|  | ||||
|  | ||||
class Migration(migrations.Migration):
    """
    Alter ``PaperlessTask.task_name`` so that its choices include the
    ``llmindex_update`` ("LLM Index Update") task.
    """

    dependencies = [
        ("documents", "1071_tag_tn_ancestors_count_tag_tn_ancestors_pks_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="paperlesstask",
            name="task_name",
            field=models.CharField(
                # Full list of choices after this migration.
                choices=[
                    ("consume_file", "Consume File"),
                    ("train_classifier", "Train Classifier"),
                    ("check_sanity", "Check Sanity"),
                    ("index_optimize", "Index Optimize"),
                    ("llmindex_update", "LLM Index Update"),
                ],
                help_text="Name of the task that was run",
                max_length=255,
                null=True,
                verbose_name="Task Name",
            ),
        ),
    ]
| @@ -598,6 +598,7 @@ class PaperlessTask(ModelWithOwner): | ||||
|         TRAIN_CLASSIFIER = ("train_classifier", _("Train Classifier")) | ||||
|         CHECK_SANITY = ("check_sanity", _("Check Sanity")) | ||||
|         INDEX_OPTIMIZE = ("index_optimize", _("Index Optimize")) | ||||
|         LLMINDEX_UPDATE = ("llmindex_update", _("LLM Index Update")) | ||||
|  | ||||
|     task_id = models.CharField( | ||||
|         max_length=255, | ||||
|   | ||||
| @@ -31,6 +31,7 @@ from guardian.shortcuts import remove_perm | ||||
|  | ||||
| from documents import matching | ||||
| from documents.caching import clear_document_caches | ||||
| from documents.caching import invalidate_llm_suggestions_cache | ||||
| from documents.file_handling import create_source_path_directory | ||||
| from documents.file_handling import delete_empty_directories | ||||
| from documents.file_handling import generate_unique_filename | ||||
| @@ -52,6 +53,7 @@ from documents.models import WorkflowTrigger | ||||
| from documents.permissions import get_objects_for_user_owner_aware | ||||
| from documents.permissions import set_permissions_for_object | ||||
| from documents.templating.workflows import parse_w_workflow_placeholders | ||||
| from paperless.config import AIConfig | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from documents.classifier import DocumentClassifier | ||||
| @@ -531,6 +533,15 @@ def update_filename_and_move_files( | ||||
|             ) | ||||
|  | ||||
|  | ||||
@receiver(models.signals.post_save, sender=Document)
def update_llm_suggestions_cache(sender, instance, **kwargs):
    """
    ``post_save`` receiver for ``Document``: invalidate the LLM suggestions
    cache of the document that was just saved.
    """
    saved_pk = instance.pk
    invalidate_llm_suggestions_cache(saved_pk)
|  | ||||
|  | ||||
| # should be disabled in /src/documents/management/commands/document_importer.py handle | ||||
| @receiver(models.signals.post_save, sender=CustomField) | ||||
| def check_paths_and_prune_custom_fields(sender, instance: CustomField, **kwargs): | ||||
| @@ -1500,3 +1511,26 @@ def close_connection_pool_on_worker_init(**kwargs): | ||||
|     for conn in connections.all(initialized_only=True): | ||||
|         if conn.alias == "default" and hasattr(conn, "pool") and conn.pool: | ||||
|             conn.close_pool() | ||||
|  | ||||
|  | ||||
| def add_or_update_document_in_llm_index(sender, document, **kwargs): | ||||
|     """ | ||||
|     Add or update a document in the LLM index when it is created or updated. | ||||
|     """ | ||||
|     ai_config = AIConfig() | ||||
|     if ai_config.llm_index_enabled(): | ||||
|         from documents.tasks import update_document_in_llm_index | ||||
|  | ||||
|         update_document_in_llm_index.delay(document) | ||||
|  | ||||
|  | ||||
| @receiver(models.signals.post_delete, sender=Document) | ||||
| def delete_document_from_llm_index(sender, instance: Document, **kwargs): | ||||
|     """ | ||||
|     Delete a document from the LLM index when it is deleted. | ||||
|     """ | ||||
|     ai_config = AIConfig() | ||||
|     if ai_config.llm_index_enabled(): | ||||
|         from documents.tasks import remove_document_from_llm_index | ||||
|  | ||||
|         remove_document_from_llm_index.delay(instance) | ||||
|   | ||||
| @@ -54,6 +54,10 @@ from documents.sanity_checker import SanityCheckFailedException | ||||
| from documents.signals import document_updated | ||||
| from documents.signals.handlers import cleanup_document_deletion | ||||
| from documents.signals.handlers import run_workflows | ||||
| from paperless.config import AIConfig | ||||
| from paperless_ai.indexing import llm_index_add_or_update_document | ||||
| from paperless_ai.indexing import llm_index_remove_document | ||||
| from paperless_ai.indexing import update_llm_index | ||||
|  | ||||
| if settings.AUDIT_LOG_ENABLED: | ||||
|     from auditlog.models import LogEntry | ||||
| @@ -242,6 +246,13 @@ def bulk_update_documents(document_ids): | ||||
|         for doc in documents: | ||||
|             index.update_document(writer, doc) | ||||
|  | ||||
|     ai_config = AIConfig() | ||||
|     if ai_config.llm_index_enabled(): | ||||
|         update_llm_index( | ||||
|             progress_bar_disable=True, | ||||
|             rebuild=False, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def update_document_content_maybe_archive_file(document_id): | ||||
| @@ -341,6 +352,10 @@ def update_document_content_maybe_archive_file(document_id): | ||||
|         with index.open_index_writer() as writer: | ||||
|             index.update_document(writer, document) | ||||
|  | ||||
|         ai_config = AIConfig() | ||||
|         if ai_config.llm_index_enabled: | ||||
|             llm_index_add_or_update_document(document) | ||||
|  | ||||
|         clear_document_caches(document.pk) | ||||
|  | ||||
|     except Exception: | ||||
| @@ -563,3 +578,55 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None: | ||||
|  | ||||
|     if affected: | ||||
|         bulk_update_documents.delay(document_ids=list(affected)) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def llmindex_index( | ||||
|     *, | ||||
|     progress_bar_disable=True, | ||||
|     rebuild=False, | ||||
|     scheduled=True, | ||||
|     auto=False, | ||||
| ): | ||||
|     ai_config = AIConfig() | ||||
|     if ai_config.llm_index_enabled(): | ||||
|         task = PaperlessTask.objects.create( | ||||
|             type=PaperlessTask.TaskType.SCHEDULED_TASK | ||||
|             if scheduled | ||||
|             else PaperlessTask.TaskType.AUTO | ||||
|             if auto | ||||
|             else PaperlessTask.TaskType.MANUAL_TASK, | ||||
|             task_id=uuid.uuid4(), | ||||
|             task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|             status=states.STARTED, | ||||
|             date_created=timezone.now(), | ||||
|             date_started=timezone.now(), | ||||
|         ) | ||||
|         from paperless_ai.indexing import update_llm_index | ||||
|  | ||||
|         try: | ||||
|             result = update_llm_index( | ||||
|                 progress_bar_disable=progress_bar_disable, | ||||
|                 rebuild=rebuild, | ||||
|             ) | ||||
|             task.status = states.SUCCESS | ||||
|             task.result = result | ||||
|         except Exception as e: | ||||
|             logger.error("LLM index error: " + str(e)) | ||||
|             task.status = states.FAILURE | ||||
|             task.result = str(e) | ||||
|  | ||||
|         task.date_done = timezone.now() | ||||
|         task.save(update_fields=["status", "result", "date_done"]) | ||||
|     else: | ||||
|         logger.info("LLM index is disabled, skipping update.") | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def update_document_in_llm_index(document): | ||||
|     llm_index_add_or_update_document(document) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def remove_document_from_llm_index(document): | ||||
|     llm_index_remove_document(document) | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import json | ||||
| from pathlib import Path | ||||
| from unittest.mock import patch | ||||
|  | ||||
| from django.contrib.auth.models import User | ||||
| from django.core.files.uploadedfile import SimpleUploadedFile | ||||
| @@ -65,6 +66,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): | ||||
|                 "barcode_max_pages": None, | ||||
|                 "barcode_enable_tag": None, | ||||
|                 "barcode_tag_mapping": None, | ||||
|                 "ai_enabled": False, | ||||
|                 "llm_embedding_backend": None, | ||||
|                 "llm_embedding_model": None, | ||||
|                 "llm_backend": None, | ||||
|                 "llm_model": None, | ||||
|                 "llm_api_key": None, | ||||
|                 "llm_endpoint": None, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @@ -231,3 +239,76 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) | ||||
|         self.assertEqual(ApplicationConfiguration.objects.count(), 1) | ||||
|  | ||||
|     def test_update_llm_api_key(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Existing config with llm_api_key specified | ||||
|         WHEN: | ||||
|             - API to update llm_api_key is called with all *s | ||||
|             - API to update llm_api_key is called with empty string | ||||
|         THEN: | ||||
|             - llm_api_key is unchanged | ||||
|             - llm_api_key is set to None | ||||
|         """ | ||||
|         config = ApplicationConfiguration.objects.first() | ||||
|         config.llm_api_key = "1234567890" | ||||
|         config.save() | ||||
|  | ||||
|         # Test with all *: a value made only of asterisks must leave the | ||||
|         # stored key untouched | ||||
|         response = self.client.patch( | ||||
|             f"{self.ENDPOINT}1/", | ||||
|             json.dumps( | ||||
|                 { | ||||
|                     "llm_api_key": "*" * 32, | ||||
|                 }, | ||||
|             ), | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         # Reload so the assertion sees the persisted value | ||||
|         config.refresh_from_db() | ||||
|         self.assertEqual(config.llm_api_key, "1234567890") | ||||
|         # Test with empty string: the stored key is expected to become None | ||||
|         response = self.client.patch( | ||||
|             f"{self.ENDPOINT}1/", | ||||
|             json.dumps( | ||||
|                 { | ||||
|                     "llm_api_key": "", | ||||
|                 }, | ||||
|             ), | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         config.refresh_from_db() | ||||
|         self.assertEqual(config.llm_api_key, None) | ||||
|  | ||||
|     def test_enable_ai_index_triggers_update(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Existing config with AI disabled | ||||
|             - vector_store_file_exists reports False | ||||
|         WHEN: | ||||
|             - Config is updated to enable AI with llm_embedding_backend | ||||
|         THEN: | ||||
|             - llmindex_index.delay is called exactly once | ||||
|         """ | ||||
|         config = ApplicationConfiguration.objects.first() | ||||
|         config.ai_enabled = False | ||||
|         config.llm_embedding_backend = None | ||||
|         config.save() | ||||
|  | ||||
|         # Replace the task's .delay with a mock so the call can be observed, | ||||
|         # and make the vector store existence check return False | ||||
|         with ( | ||||
|             patch("documents.tasks.llmindex_index.delay") as mock_update, | ||||
|             patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists, | ||||
|         ): | ||||
|             mock_exists.return_value = False | ||||
|             self.client.patch( | ||||
|                 f"{self.ENDPOINT}1/", | ||||
|                 json.dumps( | ||||
|                     { | ||||
|                         "ai_enabled": True, | ||||
|                         "llm_embedding_backend": "openai", | ||||
|                     }, | ||||
|                 ), | ||||
|                 content_type="application/json", | ||||
|             ) | ||||
|             mock_update.assert_called_once() | ||||
|   | ||||
| @@ -310,3 +310,69 @@ class TestSystemStatus(APITestCase): | ||||
|             "ERROR", | ||||
|         ) | ||||
|         self.assertIsNotNone(response.data["tasks"]["sanity_check_error"]) | ||||
|  | ||||
|     def test_system_status_ai_disabled(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - The AI feature is disabled | ||||
|         WHEN: | ||||
|             - The user requests the system status | ||||
|         THEN: | ||||
|             - llmindex_status is "DISABLED" and llmindex_error is None | ||||
|         """ | ||||
|         with override_settings(AI_ENABLED=False): | ||||
|             self.client.force_login(self.user) | ||||
|             response = self.client.get(self.ENDPOINT) | ||||
|             self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|             self.assertEqual(response.data["tasks"]["llmindex_status"], "DISABLED") | ||||
|             self.assertIsNone(response.data["tasks"]["llmindex_error"]) | ||||
|  | ||||
|     def test_system_status_ai_enabled(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - The AI index feature is enabled, but no tasks are found | ||||
|             - The AI index feature is enabled and a task is found | ||||
|         WHEN: | ||||
|             - The user requests the system status | ||||
|         THEN: | ||||
|             - With no task, llmindex_status is "WARNING" | ||||
|             - With a successful task, llmindex_status is "OK" and | ||||
|               llmindex_error is None | ||||
|         """ | ||||
|         with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"): | ||||
|             self.client.force_login(self.user) | ||||
|  | ||||
|             # No tasks found | ||||
|             response = self.client.get(self.ENDPOINT) | ||||
|             self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|             self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING") | ||||
|  | ||||
|             # Record a successful LLM index task, then query again | ||||
|             PaperlessTask.objects.create( | ||||
|                 type=PaperlessTask.TaskType.SCHEDULED_TASK, | ||||
|                 status=states.SUCCESS, | ||||
|                 task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|             ) | ||||
|             response = self.client.get(self.ENDPOINT) | ||||
|             self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|             self.assertEqual(response.data["tasks"]["llmindex_status"], "OK") | ||||
|             self.assertIsNone(response.data["tasks"]["llmindex_error"]) | ||||
|  | ||||
|     def test_system_status_ai_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - The AI index feature is enabled and a task is found with an error | ||||
|         WHEN: | ||||
|             - The user requests the system status | ||||
|         THEN: | ||||
|             - llmindex_status is "ERROR" and llmindex_error is not None | ||||
|         """ | ||||
|         with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"): | ||||
|             # Record a failed LLM index task with an error message as result | ||||
|             PaperlessTask.objects.create( | ||||
|                 type=PaperlessTask.TaskType.SCHEDULED_TASK, | ||||
|                 status=states.FAILURE, | ||||
|                 task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|                 result="AI index update failed", | ||||
|             ) | ||||
|             self.client.force_login(self.user) | ||||
|             response = self.client.get(self.ENDPOINT) | ||||
|             self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|             self.assertEqual(response.data["tasks"]["llmindex_status"], "ERROR") | ||||
|             self.assertIsNotNone(response.data["tasks"]["llmindex_error"]) | ||||
|   | ||||
| @@ -49,6 +49,7 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase): | ||||
|                     "backend_setting": "default", | ||||
|                 }, | ||||
|                 "email_enabled": False, | ||||
|                 "ai_enabled": False, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|   | ||||
| @@ -3,14 +3,17 @@ from datetime import timedelta | ||||
| from pathlib import Path | ||||
| from unittest import mock | ||||
|  | ||||
| from celery import states | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
| from django.test import override_settings | ||||
| from django.utils import timezone | ||||
|  | ||||
| from documents import tasks | ||||
| from documents.models import Correspondent | ||||
| from documents.models import Document | ||||
| from documents.models import DocumentType | ||||
| from documents.models import PaperlessTask | ||||
| from documents.models import Tag | ||||
| from documents.sanity_checker import SanityCheckFailedException | ||||
| from documents.sanity_checker import SanityCheckMessages | ||||
| @@ -270,3 +273,103 @@ class TestUpdateContent(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         tasks.update_document_content_maybe_archive_file(doc.pk) | ||||
|         self.assertNotEqual(Document.objects.get(pk=doc.pk).content, "test") | ||||
|  | ||||
|  | ||||
| class TestAIIndex(DirectoriesMixin, TestCase): | ||||
|     # Tests for the LLM index tasks in documents.tasks: llmindex_index, | ||||
|     # update_document_in_llm_index and remove_document_from_llm_index. | ||||
|  | ||||
|     @override_settings( | ||||
|         AI_ENABLED=True, | ||||
|         LLM_EMBEDDING_BACKEND="huggingface", | ||||
|     ) | ||||
|     def test_ai_index_success(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Document exists, AI is enabled, llm index backend is set | ||||
|         WHEN: | ||||
|             - llmindex_index task is called | ||||
|         THEN: | ||||
|             - update_llm_index is called, and the task is marked as success | ||||
|         """ | ||||
|         Document.objects.create( | ||||
|             title="test", | ||||
|             content="my document", | ||||
|             checksum="wow", | ||||
|         ) | ||||
|         # llmindex_index imports update_llm_index from paperless_ai.indexing | ||||
|         # inside the task body, so patch it in that module | ||||
|         with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index: | ||||
|             update_llm_index.return_value = "LLM index updated successfully." | ||||
|             # Called directly rather than through .delay | ||||
|             tasks.llmindex_index() | ||||
|             update_llm_index.assert_called_once() | ||||
|             task = PaperlessTask.objects.get( | ||||
|                 task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|             ) | ||||
|             self.assertEqual(task.status, states.SUCCESS) | ||||
|             self.assertEqual(task.result, "LLM index updated successfully.") | ||||
|  | ||||
|     @override_settings( | ||||
|         AI_ENABLED=True, | ||||
|         LLM_EMBEDDING_BACKEND="huggingface", | ||||
|     ) | ||||
|     def test_ai_index_failure(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Document exists, AI is enabled, llm index backend is set | ||||
|             - update_llm_index raises an exception | ||||
|         WHEN: | ||||
|             - llmindex_index task is called | ||||
|         THEN: | ||||
|             - The task is marked as failure and its result contains the | ||||
|               exception message | ||||
|         """ | ||||
|         Document.objects.create( | ||||
|             title="test", | ||||
|             content="my document", | ||||
|             checksum="wow", | ||||
|         ) | ||||
|         # llmindex_index imports update_llm_index from paperless_ai.indexing | ||||
|         # inside the task body, so patch it in that module | ||||
|         with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index: | ||||
|             update_llm_index.side_effect = Exception("LLM index update failed.") | ||||
|             # The task handles the exception itself, so this call does not raise | ||||
|             tasks.llmindex_index() | ||||
|             update_llm_index.assert_called_once() | ||||
|             task = PaperlessTask.objects.get( | ||||
|                 task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|             ) | ||||
|             self.assertEqual(task.status, states.FAILURE) | ||||
|             self.assertIn("LLM index update failed.", task.result) | ||||
|  | ||||
|     def test_update_document_in_llm_index(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Nothing | ||||
|         WHEN: | ||||
|             - update_document_in_llm_index task is called | ||||
|         THEN: | ||||
|             - llm_index_add_or_update_document is called | ||||
|         """ | ||||
|         doc = Document.objects.create( | ||||
|             title="test", | ||||
|             content="my document", | ||||
|             checksum="wow", | ||||
|         ) | ||||
|         # Patched under documents.tasks, the name the task calls | ||||
|         with mock.patch( | ||||
|             "documents.tasks.llm_index_add_or_update_document", | ||||
|         ) as llm_index_add_or_update_document: | ||||
|             tasks.update_document_in_llm_index(doc) | ||||
|             llm_index_add_or_update_document.assert_called_once_with(doc) | ||||
|  | ||||
|     def test_remove_document_from_llm_index(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Nothing | ||||
|         WHEN: | ||||
|             - remove_document_from_llm_index task is called | ||||
|         THEN: | ||||
|             - llm_index_remove_document is called | ||||
|         """ | ||||
|         doc = Document.objects.create( | ||||
|             title="test", | ||||
|             content="my document", | ||||
|             checksum="wow", | ||||
|         ) | ||||
|         # Patched under documents.tasks, the name the task calls | ||||
|         with mock.patch( | ||||
|             "documents.tasks.llm_index_remove_document", | ||||
|         ) as llm_index_remove_document: | ||||
|             tasks.remove_document_from_llm_index(doc) | ||||
|             llm_index_remove_document.assert_called_once_with(doc) | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| import tempfile | ||||
| from datetime import timedelta | ||||
| from pathlib import Path | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.models import Permission | ||||
| @@ -10,8 +12,15 @@ from django.test import override_settings | ||||
| from django.utils import timezone | ||||
| from rest_framework import status | ||||
|  | ||||
| from documents.caching import get_llm_suggestion_cache | ||||
| from documents.caching import set_llm_suggestions_cache | ||||
| from documents.models import Correspondent | ||||
| from documents.models import Document | ||||
| from documents.models import DocumentType | ||||
| from documents.models import ShareLink | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| from documents.signals.handlers import update_llm_suggestions_cache | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
| from paperless.models import ApplicationConfiguration | ||||
|  | ||||
| @@ -154,3 +163,186 @@ class TestViews(DirectoriesMixin, TestCase): | ||||
|         response.render() | ||||
|         self.assertEqual(response.request["PATH_INFO"], "/accounts/login/") | ||||
|         self.assertContains(response, b"Share link has expired") | ||||
|  | ||||
|  | ||||
| class TestAISuggestions(DirectoriesMixin, TestCase): | ||||
|     # Tests for /api/documents/<pk>/suggestions/ with AI enabled, and for | ||||
|     # invalidation of the LLM suggestions cache. | ||||
|  | ||||
|     def setUp(self): | ||||
|         # One document plus one tag, correspondent, document type and storage | ||||
|         # path whose names the AI-enabled test expects to be matched | ||||
|         self.user = User.objects.create_superuser(username="testuser") | ||||
|         self.document = Document.objects.create( | ||||
|             title="Test Document", | ||||
|             filename="test.pdf", | ||||
|             mime_type="application/pdf", | ||||
|         ) | ||||
|         self.tag1 = Tag.objects.create(name="tag1") | ||||
|         self.correspondent1 = Correspondent.objects.create(name="correspondent1") | ||||
|         self.document_type1 = DocumentType.objects.create(name="type1") | ||||
|         self.path1 = StoragePath.objects.create(name="path1") | ||||
|         super().setUp() | ||||
|  | ||||
|     # Decorators apply bottom-up, so the refresh_suggestions_cache mock is the | ||||
|     # first mock argument and the get_llm_suggestion_cache mock the second | ||||
|     @patch("documents.views.get_llm_suggestion_cache") | ||||
|     @patch("documents.views.refresh_suggestions_cache") | ||||
|     @override_settings( | ||||
|         AI_ENABLED=True, | ||||
|         LLM_BACKEND="mock_backend", | ||||
|     ) | ||||
|     def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache): | ||||
|         # A cached entry is returned as-is and the cache is refreshed | ||||
|         mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]}) | ||||
|  | ||||
|         self.client.force_login(user=self.user) | ||||
|         response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/") | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]}) | ||||
|         mock_refresh_cache.assert_called_once_with(self.document.pk) | ||||
|  | ||||
|     @patch("documents.views.get_ai_document_classification") | ||||
|     @override_settings( | ||||
|         AI_ENABLED=True, | ||||
|         LLM_BACKEND="mock_backend", | ||||
|     ) | ||||
|     def test_suggestions_with_ai_enabled( | ||||
|         self, | ||||
|         mock_get_ai_classification, | ||||
|     ): | ||||
|         # Names returned by the (mocked) AI classification | ||||
|         mock_get_ai_classification.return_value = { | ||||
|             "title": "AI Title", | ||||
|             "tags": ["tag1", "tag2"], | ||||
|             "correspondents": ["correspondent1"], | ||||
|             "document_types": ["type1"], | ||||
|             "storage_paths": ["path1"], | ||||
|             "dates": ["2023-01-01"], | ||||
|         } | ||||
|  | ||||
|         self.client.force_login(user=self.user) | ||||
|         response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/") | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         # Names of objects created in setUp come back as primary keys; "tag2" | ||||
|         # has no matching Tag and is expected under suggested_tags | ||||
|         self.assertEqual( | ||||
|             response.json(), | ||||
|             { | ||||
|                 "title": "AI Title", | ||||
|                 "tags": [self.tag1.pk], | ||||
|                 "suggested_tags": ["tag2"], | ||||
|                 "correspondents": [self.correspondent1.pk], | ||||
|                 "suggested_correspondents": [], | ||||
|                 "document_types": [self.document_type1.pk], | ||||
|                 "suggested_document_types": [], | ||||
|                 "storage_paths": [self.path1.pk], | ||||
|                 "suggested_storage_paths": [], | ||||
|                 "dates": ["2023-01-01"], | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_invalidate_suggestions_cache(self): | ||||
|         # Store suggestions, confirm they can be read back, then check that | ||||
|         # the post_save handler removes them | ||||
|         self.client.force_login(user=self.user) | ||||
|         suggestions = { | ||||
|             "title": "AI Title", | ||||
|             "tags": ["tag1", "tag2"], | ||||
|             "correspondents": ["correspondent1"], | ||||
|             "document_types": ["type1"], | ||||
|             "storage_paths": ["path1"], | ||||
|             "dates": ["2023-01-01"], | ||||
|         } | ||||
|         set_llm_suggestions_cache( | ||||
|             self.document.pk, | ||||
|             suggestions, | ||||
|             backend="mock_backend", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             get_llm_suggestion_cache( | ||||
|                 self.document.pk, | ||||
|                 backend="mock_backend", | ||||
|             ).suggestions, | ||||
|             suggestions, | ||||
|         ) | ||||
|         # Call the post_save handler directly instead of saving the document | ||||
|         update_llm_suggestions_cache( | ||||
|             sender=None, | ||||
|             instance=self.document, | ||||
|         ) | ||||
|         self.assertIsNone( | ||||
|             get_llm_suggestion_cache( | ||||
|                 self.document.pk, | ||||
|                 backend="mock_backend", | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class TestAIChatStreamingView(DirectoriesMixin, TestCase): | ||||
|     # Tests for POST requests to the document chat endpoint. | ||||
|     ENDPOINT = "/api/documents/chat/" | ||||
|  | ||||
|     def setUp(self): | ||||
|         # A regular (non-superuser) user, logged in, and one document | ||||
|         self.user = User.objects.create_user(username="testuser", password="pass") | ||||
|         self.client.force_login(user=self.user) | ||||
|         self.document = Document.objects.create( | ||||
|             title="Test Document", | ||||
|             filename="test.pdf", | ||||
|             mime_type="application/pdf", | ||||
|         ) | ||||
|         super().setUp() | ||||
|  | ||||
|     @override_settings(AI_ENABLED=False) | ||||
|     def test_post_ai_disabled(self): | ||||
|         # With AI disabled the endpoint answers 400 with an explanatory message | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data='{"q": "question"}', | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertIn(b"AI is required for this feature", response.content) | ||||
|  | ||||
|     @override_settings(AI_ENABLED=True) | ||||
|     def test_post_invalid_json(self): | ||||
|         # A body that is not valid JSON is rejected with 400 | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data="invalid", | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertIn(b"Invalid request", response.content) | ||||
|  | ||||
|     # Decorators apply bottom-up, so the get_objects_for_user_owner_aware mock | ||||
|     # is the first mock argument and the stream_chat_with_documents mock the | ||||
|     # second | ||||
|     @patch("documents.views.stream_chat_with_documents") | ||||
|     @patch("documents.views.get_objects_for_user_owner_aware") | ||||
|     @override_settings(AI_ENABLED=True) | ||||
|     def test_post_no_document_id(self, mock_get_objects, mock_stream_chat): | ||||
|         # Without a document_id the request succeeds as an event stream | ||||
|         mock_get_objects.return_value = [self.document] | ||||
|         mock_stream_chat.return_value = iter([b"data"]) | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data='{"q": "question"}', | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(response["Content-Type"], "text/event-stream") | ||||
|  | ||||
|     @patch("documents.views.stream_chat_with_documents") | ||||
|     @override_settings(AI_ENABLED=True) | ||||
|     def test_post_with_document_id(self, mock_stream_chat): | ||||
|         # With the id of an existing document the request succeeds as an | ||||
|         # event stream | ||||
|         mock_stream_chat.return_value = iter([b"data"]) | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data=f'{{"q": "question", "document_id": {self.document.pk}}}', | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(response["Content-Type"], "text/event-stream") | ||||
|  | ||||
|     @override_settings(AI_ENABLED=True) | ||||
|     def test_post_with_invalid_document_id(self): | ||||
|         # An id that matches no document is rejected with 400 | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data='{"q": "question", "document_id": 999999}', | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertIn(b"Document not found", response.content) | ||||
|  | ||||
|     @patch("documents.views.has_perms_owner_aware") | ||||
|     @override_settings(AI_ENABLED=True) | ||||
|     def test_post_with_document_id_no_permission(self, mock_has_perms): | ||||
|         # When the permission check reports False the request is rejected | ||||
|         # with 403 | ||||
|         mock_has_perms.return_value = False | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT, | ||||
|             data=f'{{"q": "question", "document_id": {self.document.pk}}}', | ||||
|             content_type="application/json", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 403) | ||||
|         self.assertIn(b"Insufficient permissions", response.content) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import itertools | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| import platform | ||||
| @@ -17,6 +18,7 @@ import magic | ||||
| import pathvalidate | ||||
| from celery import states | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.decorators import login_required | ||||
| from django.contrib.auth.models import Group | ||||
| from django.contrib.auth.models import User | ||||
| from django.db import connections | ||||
| @@ -40,6 +42,7 @@ from django.http import HttpResponseBadRequest | ||||
| from django.http import HttpResponseForbidden | ||||
| from django.http import HttpResponseRedirect | ||||
| from django.http import HttpResponseServerError | ||||
| from django.http import StreamingHttpResponse | ||||
| from django.shortcuts import get_object_or_404 | ||||
| from django.utils import timezone | ||||
| from django.utils.decorators import method_decorator | ||||
| @@ -47,6 +50,7 @@ from django.utils.timezone import make_aware | ||||
| from django.utils.translation import get_language | ||||
| from django.views import View | ||||
| from django.views.decorators.cache import cache_control | ||||
| from django.views.decorators.csrf import ensure_csrf_cookie | ||||
| from django.views.decorators.http import condition | ||||
| from django.views.decorators.http import last_modified | ||||
| from django.views.generic import TemplateView | ||||
| @@ -82,10 +86,12 @@ from documents import index | ||||
| from documents.bulk_download import ArchiveOnlyStrategy | ||||
| from documents.bulk_download import OriginalAndArchiveStrategy | ||||
| from documents.bulk_download import OriginalsOnlyStrategy | ||||
| from documents.caching import get_llm_suggestion_cache | ||||
| from documents.caching import get_metadata_cache | ||||
| from documents.caching import get_suggestion_cache | ||||
| from documents.caching import refresh_metadata_cache | ||||
| from documents.caching import refresh_suggestions_cache | ||||
| from documents.caching import set_llm_suggestions_cache | ||||
| from documents.caching import set_metadata_cache | ||||
| from documents.caching import set_suggestions_cache | ||||
| from documents.classifier import load_classifier | ||||
| @@ -174,12 +180,21 @@ from documents.templating.filepath import validate_filepath_template_and_render | ||||
| from documents.utils import get_boolean | ||||
| from paperless import version | ||||
| from paperless.celery import app as celery_app | ||||
| from paperless.config import AIConfig | ||||
| from paperless.config import GeneralConfig | ||||
| from paperless.db import GnuPG | ||||
| from paperless.models import ApplicationConfiguration | ||||
| from paperless.serialisers import GroupSerializer | ||||
| from paperless.serialisers import UserSerializer | ||||
| from paperless.views import StandardPagination | ||||
| from paperless_ai.ai_classifier import get_ai_document_classification | ||||
| from paperless_ai.chat import stream_chat_with_documents | ||||
| from paperless_ai.indexing import update_llm_index | ||||
| from paperless_ai.matching import extract_unmatched_names | ||||
| from paperless_ai.matching import match_correspondents_by_name | ||||
| from paperless_ai.matching import match_document_types_by_name | ||||
| from paperless_ai.matching import match_storage_paths_by_name | ||||
| from paperless_ai.matching import match_tags_by_name | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.models import MailRule | ||||
| from paperless_mail.oauth import PaperlessMailOAuth2Manager | ||||
| @@ -774,37 +789,103 @@ class DocumentViewSet( | ||||
|         ): | ||||
|             return HttpResponseForbidden("Insufficient permissions") | ||||
|  | ||||
|         document_suggestions = get_suggestion_cache(doc.pk) | ||||
|         ai_config = AIConfig() | ||||
|  | ||||
|         if document_suggestions is not None: | ||||
|             refresh_suggestions_cache(doc.pk) | ||||
|             return Response(document_suggestions.suggestions) | ||||
|  | ||||
|         classifier = load_classifier() | ||||
|  | ||||
|         dates = [] | ||||
|         if settings.NUMBER_OF_SUGGESTED_DATES > 0: | ||||
|             gen = parse_date_generator(doc.filename, doc.content) | ||||
|             dates = sorted( | ||||
|                 {i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)}, | ||||
|         if ai_config.ai_enabled: | ||||
|             cached_llm_suggestions = get_llm_suggestion_cache( | ||||
|                 doc.pk, | ||||
|                 backend=ai_config.llm_backend, | ||||
|             ) | ||||
|  | ||||
|         resp_data = { | ||||
|             "correspondents": [ | ||||
|                 c.id for c in match_correspondents(doc, classifier, request.user) | ||||
|             ], | ||||
|             "tags": [t.id for t in match_tags(doc, classifier, request.user)], | ||||
|             "document_types": [ | ||||
|                 dt.id for dt in match_document_types(doc, classifier, request.user) | ||||
|             ], | ||||
|             "storage_paths": [ | ||||
|                 dt.id for dt in match_storage_paths(doc, classifier, request.user) | ||||
|             ], | ||||
|             "dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None], | ||||
|         } | ||||
|             if cached_llm_suggestions: | ||||
|                 refresh_suggestions_cache(doc.pk) | ||||
|                 return Response(cached_llm_suggestions.suggestions) | ||||
|  | ||||
|         # Cache the suggestions and the classifier hash for later | ||||
|         set_suggestions_cache(doc.pk, resp_data, classifier) | ||||
|             llm_suggestions = get_ai_document_classification(doc, request.user) | ||||
|  | ||||
|             matched_tags = match_tags_by_name( | ||||
|                 llm_suggestions.get("tags", []), | ||||
|                 request.user, | ||||
|             ) | ||||
|             matched_correspondents = match_correspondents_by_name( | ||||
|                 llm_suggestions.get("correspondents", []), | ||||
|                 request.user, | ||||
|             ) | ||||
|             matched_types = match_document_types_by_name( | ||||
|                 llm_suggestions.get("document_types", []), | ||||
|                 request.user, | ||||
|             ) | ||||
|             matched_paths = match_storage_paths_by_name( | ||||
|                 llm_suggestions.get("storage_paths", []), | ||||
|                 request.user, | ||||
|             ) | ||||
|  | ||||
|             resp_data = { | ||||
|                 "title": llm_suggestions.get("title"), | ||||
|                 "tags": [t.id for t in matched_tags], | ||||
|                 "suggested_tags": extract_unmatched_names( | ||||
|                     llm_suggestions.get("tags", []), | ||||
|                     matched_tags, | ||||
|                 ), | ||||
|                 "correspondents": [c.id for c in matched_correspondents], | ||||
|                 "suggested_correspondents": extract_unmatched_names( | ||||
|                     llm_suggestions.get("correspondents", []), | ||||
|                     matched_correspondents, | ||||
|                 ), | ||||
|                 "document_types": [d.id for d in matched_types], | ||||
|                 "suggested_document_types": extract_unmatched_names( | ||||
|                     llm_suggestions.get("document_types", []), | ||||
|                     matched_types, | ||||
|                 ), | ||||
|                 "storage_paths": [s.id for s in matched_paths], | ||||
|                 "suggested_storage_paths": extract_unmatched_names( | ||||
|                     llm_suggestions.get("storage_paths", []), | ||||
|                     matched_paths, | ||||
|                 ), | ||||
|                 "dates": llm_suggestions.get("dates", []), | ||||
|             } | ||||
|  | ||||
|             set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend) | ||||
|         else: | ||||
|             document_suggestions = get_suggestion_cache(doc.pk) | ||||
|  | ||||
|             if document_suggestions is not None: | ||||
|                 refresh_suggestions_cache(doc.pk) | ||||
|                 return Response(document_suggestions.suggestions) | ||||
|  | ||||
|             classifier = load_classifier() | ||||
|  | ||||
|             dates = [] | ||||
|             if settings.NUMBER_OF_SUGGESTED_DATES > 0: | ||||
|                 gen = parse_date_generator(doc.filename, doc.content) | ||||
|                 dates = sorted( | ||||
|                     { | ||||
|                         i | ||||
|                         for i in itertools.islice( | ||||
|                             gen, | ||||
|                             settings.NUMBER_OF_SUGGESTED_DATES, | ||||
|                         ) | ||||
|                     }, | ||||
|                 ) | ||||
|  | ||||
|             resp_data = { | ||||
|                 "correspondents": [ | ||||
|                     c.id for c in match_correspondents(doc, classifier, request.user) | ||||
|                 ], | ||||
|                 "tags": [t.id for t in match_tags(doc, classifier, request.user)], | ||||
|                 "document_types": [ | ||||
|                     dt.id for dt in match_document_types(doc, classifier, request.user) | ||||
|                 ], | ||||
|                 "storage_paths": [ | ||||
|                     dt.id for dt in match_storage_paths(doc, classifier, request.user) | ||||
|                 ], | ||||
|                 "dates": [ | ||||
|                     date.strftime("%Y-%m-%d") for date in dates if date is not None | ||||
|                 ], | ||||
|             } | ||||
|  | ||||
|             # Cache the suggestions and the classifier hash for later | ||||
|             set_suggestions_cache(doc.pk, resp_data, classifier) | ||||
|  | ||||
|         return Response(resp_data) | ||||
|  | ||||
| @@ -1104,6 +1185,52 @@ class DocumentViewSet( | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @method_decorator( | ||||
|     [ | ||||
|         ensure_csrf_cookie, | ||||
|         login_required, | ||||
|         cache_control(no_cache=True), | ||||
|     ], | ||||
|     name="dispatch", | ||||
| ) | ||||
| class ChatStreamingView(View): | ||||
|     def post(self, request): | ||||
|         request.compress_exempt = True | ||||
|         ai_config = AIConfig() | ||||
|         if not ai_config.ai_enabled: | ||||
|             return HttpResponseBadRequest("AI is required for this feature") | ||||
|  | ||||
|         try: | ||||
|             data = json.loads(request.body) | ||||
|             question = data["q"] | ||||
|             doc_id = data.get("document_id", None) | ||||
|         except (KeyError, json.JSONDecodeError): | ||||
|             return HttpResponseBadRequest("Invalid request") | ||||
|  | ||||
|         if doc_id: | ||||
|             try: | ||||
|                 document = Document.objects.get(id=doc_id) | ||||
|             except Document.DoesNotExist: | ||||
|                 return HttpResponseBadRequest("Document not found") | ||||
|  | ||||
|             if not has_perms_owner_aware(request.user, "view_document", document): | ||||
|                 return HttpResponseForbidden("Insufficient permissions") | ||||
|  | ||||
|             documents = [document] | ||||
|         else: | ||||
|             documents = get_objects_for_user_owner_aware( | ||||
|                 request.user, | ||||
|                 "view_document", | ||||
|                 Document, | ||||
|             ) | ||||
|  | ||||
|         response = StreamingHttpResponse( | ||||
|             stream_chat_with_documents(query_str=question, documents=documents), | ||||
|             content_type="text/event-stream", | ||||
|         ) | ||||
|         return response | ||||
|  | ||||
|  | ||||
| @extend_schema_view( | ||||
|     list=extend_schema( | ||||
|         description="Document views including search", | ||||
| @@ -2238,6 +2365,10 @@ class UiSettingsView(GenericAPIView): | ||||
|  | ||||
|         ui_settings["email_enabled"] = settings.EMAIL_ENABLED | ||||
|  | ||||
|         ai_config = AIConfig() | ||||
|  | ||||
|         ui_settings["ai_enabled"] = ai_config.ai_enabled | ||||
|  | ||||
|         user_resp = { | ||||
|             "id": user.id, | ||||
|             "username": user.username, | ||||
| @@ -2376,6 +2507,10 @@ class TasksViewSet(ReadOnlyModelViewSet): | ||||
|             sanity_check, | ||||
|             {"scheduled": False, "raise_on_error": False}, | ||||
|         ), | ||||
|         PaperlessTask.TaskName.LLMINDEX_UPDATE: ( | ||||
|             update_llm_index, | ||||
|             {"scheduled": False, "rebuild": False}, | ||||
|         ), | ||||
|     } | ||||
|  | ||||
|     def get_queryset(self): | ||||
| @@ -2891,6 +3026,31 @@ class SystemStatusView(PassUserMixin): | ||||
|             last_sanity_check.date_done if last_sanity_check else None | ||||
|         ) | ||||
|  | ||||
|         ai_config = AIConfig() | ||||
|         if not ai_config.llm_index_enabled(): | ||||
|             llmindex_status = "DISABLED" | ||||
|             llmindex_error = None | ||||
|             llmindex_last_modified = None | ||||
|         else: | ||||
|             last_llmindex_update = ( | ||||
|                 PaperlessTask.objects.filter( | ||||
|                     task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, | ||||
|                 ) | ||||
|                 .order_by("-date_done") | ||||
|                 .first() | ||||
|             ) | ||||
|             llmindex_status = "OK" | ||||
|             llmindex_error = None | ||||
|             if last_llmindex_update is None: | ||||
|                 llmindex_status = "WARNING" | ||||
|                 llmindex_error = "No LLM index update tasks found" | ||||
|             elif last_llmindex_update and last_llmindex_update.status == states.FAILURE: | ||||
|                 llmindex_status = "ERROR" | ||||
|                 llmindex_error = last_llmindex_update.result | ||||
|             llmindex_last_modified = ( | ||||
|                 last_llmindex_update.date_done if last_llmindex_update else None | ||||
|             ) | ||||
|  | ||||
|         return Response( | ||||
|             { | ||||
|                 "pngx_version": current_version, | ||||
| @@ -2928,6 +3088,9 @@ class SystemStatusView(PassUserMixin): | ||||
|                     "sanity_check_status": sanity_check_status, | ||||
|                     "sanity_check_last_run": sanity_check_last_run, | ||||
|                     "sanity_check_error": sanity_check_error, | ||||
|                     "llmindex_status": llmindex_status, | ||||
|                     "llmindex_last_modified": llmindex_last_modified, | ||||
|                     "llmindex_error": llmindex_error, | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|   | ||||
| @@ -169,3 +169,36 @@ class GeneralConfig(BaseConfig): | ||||
|  | ||||
|         self.app_title = app_config.app_title or None | ||||
|         self.app_logo = app_config.app_logo.url if app_config.app_logo else None | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| class AIConfig(BaseConfig): | ||||
|     """ | ||||
|     AI related settings that require global scope | ||||
|     """ | ||||
|  | ||||
|     ai_enabled: bool = dataclasses.field(init=False) | ||||
|     llm_embedding_backend: str = dataclasses.field(init=False) | ||||
|     llm_embedding_model: str = dataclasses.field(init=False) | ||||
|     llm_backend: str = dataclasses.field(init=False) | ||||
|     llm_model: str = dataclasses.field(init=False) | ||||
|     llm_api_key: str = dataclasses.field(init=False) | ||||
|     llm_endpoint: str = dataclasses.field(init=False) | ||||
|  | ||||
|     def __post_init__(self) -> None: | ||||
|         app_config = self._get_config_instance() | ||||
|  | ||||
|         self.ai_enabled = app_config.ai_enabled or settings.AI_ENABLED | ||||
|         self.llm_embedding_backend = ( | ||||
|             app_config.llm_embedding_backend or settings.LLM_EMBEDDING_BACKEND | ||||
|         ) | ||||
|         self.llm_embedding_model = ( | ||||
|             app_config.llm_embedding_model or settings.LLM_EMBEDDING_MODEL | ||||
|         ) | ||||
|         self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND | ||||
|         self.llm_model = app_config.llm_model or settings.LLM_MODEL | ||||
|         self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY | ||||
|         self.llm_endpoint = app_config.llm_endpoint or settings.LLM_ENDPOINT | ||||
|  | ||||
|     def llm_index_enabled(self) -> bool: | ||||
|         return self.ai_enabled and self.llm_embedding_backend | ||||
|   | ||||
| @@ -0,0 +1,84 @@ | ||||
| # Generated by Django 5.1.8 on 2025-04-30 02:38 | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db import models | ||||
|  | ||||
|  | ||||
| # Adds the AI related columns to the application configuration model. | ||||
| # Every added field is nullable. | ||||
| class Migration(migrations.Migration): | ||||
|     dependencies = [ | ||||
|         ("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         # Switch for the AI features, defaults to False | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="ai_enabled", | ||||
|             field=models.BooleanField( | ||||
|                 default=False, | ||||
|                 null=True, | ||||
|                 verbose_name="Enables AI features", | ||||
|             ), | ||||
|         ), | ||||
|         # API key for the LLM backend, up to 128 characters | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_api_key", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 max_length=128, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM API key", | ||||
|             ), | ||||
|         ), | ||||
|         # LLM backend, restricted to "openai" or "ollama" | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_backend", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 choices=[("openai", "OpenAI"), ("ollama", "Ollama")], | ||||
|                 max_length=32, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM backend", | ||||
|             ), | ||||
|         ), | ||||
|         # Embedding backend, restricted to "openai" or "huggingface" | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_embedding_backend", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 choices=[("openai", "OpenAI"), ("huggingface", "Huggingface")], | ||||
|                 max_length=32, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM embedding backend", | ||||
|             ), | ||||
|         ), | ||||
|         # Free-text embedding model name, up to 32 characters | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_embedding_model", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 max_length=32, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM embedding model", | ||||
|             ), | ||||
|         ), | ||||
|         # Free-text LLM model name, up to 32 characters | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_model", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 max_length=32, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM model", | ||||
|             ), | ||||
|         ), | ||||
|         # Optional LLM endpoint, up to 128 characters | ||||
|         migrations.AddField( | ||||
|             model_name="applicationconfiguration", | ||||
|             name="llm_endpoint", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 max_length=128, | ||||
|                 null=True, | ||||
|                 verbose_name="Sets the LLM endpoint, optional", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @@ -74,6 +74,20 @@ class ColorConvertChoices(models.TextChoices): | ||||
|     CMYK = ("CMYK", _("CMYK")) | ||||
|  | ||||
|  | ||||
| class LLMEmbeddingBackend(models.TextChoices): | ||||
|     OPENAI = ("openai", _("OpenAI")) | ||||
|     HUGGINGFACE = ("huggingface", _("Huggingface")) | ||||
|  | ||||
|  | ||||
| class LLMBackend(models.TextChoices): | ||||
|     """ | ||||
|     Matches to --llm-backend | ||||
|     """ | ||||
|  | ||||
|     OPENAI = ("openai", _("OpenAI")) | ||||
|     OLLAMA = ("ollama", _("Ollama")) | ||||
|  | ||||
|  | ||||
| class ApplicationConfiguration(AbstractSingletonModel): | ||||
|     """ | ||||
|     Settings which are common across more than 1 parser | ||||
| @@ -265,6 +279,60 @@ class ApplicationConfiguration(AbstractSingletonModel): | ||||
|         null=True, | ||||
|     ) | ||||
|  | ||||
|     """ | ||||
|     AI related settings | ||||
|     """ | ||||
|  | ||||
|     ai_enabled = models.BooleanField( | ||||
|         verbose_name=_("Enables AI features"), | ||||
|         null=True, | ||||
|         default=False, | ||||
|     ) | ||||
|  | ||||
|     llm_embedding_backend = models.CharField( | ||||
|         verbose_name=_("Sets the LLM embedding backend"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=32, | ||||
|         choices=LLMEmbeddingBackend.choices, | ||||
|     ) | ||||
|  | ||||
|     llm_embedding_model = models.CharField( | ||||
|         verbose_name=_("Sets the LLM embedding model"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=32, | ||||
|     ) | ||||
|  | ||||
|     llm_backend = models.CharField( | ||||
|         verbose_name=_("Sets the LLM backend"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=32, | ||||
|         choices=LLMBackend.choices, | ||||
|     ) | ||||
|  | ||||
|     llm_model = models.CharField( | ||||
|         verbose_name=_("Sets the LLM model"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=32, | ||||
|     ) | ||||
|  | ||||
|     llm_api_key = models.CharField( | ||||
|         verbose_name=_("Sets the LLM API key"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=128, | ||||
|     ) | ||||
|  | ||||
|     llm_endpoint = models.CharField( | ||||
|         verbose_name=_("Sets the LLM endpoint, optional"), | ||||
|         null=True, | ||||
|         blank=True, | ||||
|         max_length=128, | ||||
|     ) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("paperless application settings") | ||||
|  | ||||
|   | ||||
| @@ -192,6 +192,10 @@ class ProfileSerializer(serializers.ModelSerializer): | ||||
| class ApplicationConfigurationSerializer(serializers.ModelSerializer): | ||||
|     user_args = serializers.JSONField(binary=True, allow_null=True) | ||||
|     barcode_tag_mapping = serializers.JSONField(binary=True, allow_null=True) | ||||
|     llm_api_key = ObfuscatedPasswordField( | ||||
|         required=False, | ||||
|         allow_null=True, | ||||
|     ) | ||||
|  | ||||
|     def run_validation(self, data): | ||||
|         # Empty strings treated as None to avoid unexpected behavior | ||||
| @@ -201,6 +205,11 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer): | ||||
|             data["barcode_tag_mapping"] = None | ||||
|         if "language" in data and data["language"] == "": | ||||
|             data["language"] = None | ||||
|         if "llm_api_key" in data and data["llm_api_key"] is not None: | ||||
|             if data["llm_api_key"] == "": | ||||
|                 data["llm_api_key"] = None | ||||
|             elif len(data["llm_api_key"].replace("*", "")) == 0: | ||||
|                 del data["llm_api_key"] | ||||
|         return super().run_validation(data) | ||||
|  | ||||
|     def update(self, instance, validated_data): | ||||
|   | ||||
| @@ -13,6 +13,7 @@ from typing import Final | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from celery.schedules import crontab | ||||
| from compression_middleware.middleware import CompressionMiddleware | ||||
| from dateparser.languages.loader import LocaleDataLoader | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dotenv import load_dotenv | ||||
| @@ -230,6 +231,17 @@ def _parse_beat_schedule() -> dict: | ||||
|                 "expires": 59.0 * 60.0, | ||||
|             }, | ||||
|         }, | ||||
|         { | ||||
|             "name": "Rebuild LLM index", | ||||
|             "env_key": "PAPERLESS_LLM_INDEX_TASK_CRON", | ||||
|             # Default daily at 02:10 | ||||
|             "env_default": "10 2 * * *", | ||||
|             "task": "documents.tasks.llmindex_index", | ||||
|             "options": { | ||||
|                 # 1 hour before default schedule sends again | ||||
|                 "expires": 23.0 * 60.0 * 60.0, | ||||
|             }, | ||||
|         }, | ||||
|     ] | ||||
|     for task in tasks: | ||||
|         # Either get the environment setting or use the default | ||||
| @@ -288,6 +300,7 @@ MODEL_FILE = __get_path( | ||||
|     "PAPERLESS_MODEL_FILE", | ||||
|     DATA_DIR / "classification_model.pickle", | ||||
| ) | ||||
| LLM_INDEX_DIR = DATA_DIR / "llm_index" | ||||
|  | ||||
| LOGGING_DIR = __get_path("PAPERLESS_LOGGING_DIR", DATA_DIR / "log") | ||||
|  | ||||
| @@ -380,6 +393,19 @@ MIDDLEWARE = [ | ||||
| if __get_boolean("PAPERLESS_ENABLE_COMPRESSION", "yes"):  # pragma: no cover | ||||
|     MIDDLEWARE.insert(0, "compression_middleware.middleware.CompressionMiddleware") | ||||
|  | ||||
| # Workaround to not compress streaming responses (e.g. chat). | ||||
| # See https://github.com/friedelwolff/django-compression-middleware/pull/7 | ||||
| original_process_response = CompressionMiddleware.process_response | ||||
|  | ||||
|  | ||||
| def patched_process_response(self, request, response): | ||||
|     if getattr(request, "compress_exempt", False): | ||||
|         return response | ||||
|     return original_process_response(self, request, response) | ||||
|  | ||||
|  | ||||
| CompressionMiddleware.process_response = patched_process_response | ||||
|  | ||||
| ROOT_URLCONF = "paperless.urls" | ||||
|  | ||||
|  | ||||
| @@ -590,6 +616,10 @@ X_FRAME_OPTIONS = "SAMEORIGIN" | ||||
| # The next 3 settings can also be set using just PAPERLESS_URL | ||||
| CSRF_TRUSTED_ORIGINS = __get_list("PAPERLESS_CSRF_TRUSTED_ORIGINS") | ||||
|  | ||||
| if DEBUG: | ||||
|     # Allow access from the angular development server during debugging | ||||
|     CSRF_TRUSTED_ORIGINS.append("http://localhost:4200") | ||||
|  | ||||
| # We allow CORS from localhost:8000 | ||||
| CORS_ALLOWED_ORIGINS = __get_list( | ||||
|     "PAPERLESS_CORS_ALLOWED_HOSTS", | ||||
| @@ -600,6 +630,8 @@ if DEBUG: | ||||
|     # Allow access from the angular development server during debugging | ||||
|     CORS_ALLOWED_ORIGINS.append("http://localhost:4200") | ||||
|  | ||||
| CORS_ALLOW_CREDENTIALS = True | ||||
|  | ||||
| CORS_EXPOSE_HEADERS = [ | ||||
|     "Content-Disposition", | ||||
| ] | ||||
| @@ -872,6 +904,7 @@ LOGGING = { | ||||
|     "loggers": { | ||||
|         "paperless": {"handlers": ["file_paperless"], "level": "DEBUG"}, | ||||
|         "paperless_mail": {"handlers": ["file_mail"], "level": "DEBUG"}, | ||||
|         "paperless_ai": {"handlers": ["file_paperless"], "level": "DEBUG"}, | ||||
|         "ocrmypdf": {"handlers": ["file_paperless"], "level": "INFO"}, | ||||
|         "celery": {"handlers": ["file_celery"], "level": "DEBUG"}, | ||||
|         "kombu": {"handlers": ["file_celery"], "level": "DEBUG"}, | ||||
| @@ -1389,3 +1422,16 @@ WEBHOOKS_ALLOW_INTERNAL_REQUESTS = __get_boolean( | ||||
|     "PAPERLESS_WEBHOOKS_ALLOW_INTERNAL_REQUESTS", | ||||
|     "true", | ||||
| ) | ||||
|  | ||||
| ################################################################################ | ||||
| # AI Settings                                                                  # | ||||
| ################################################################################ | ||||
| # Environment-level values for the AI features. AIConfig uses these as the | ||||
| # fallback whenever the stored application configuration value is falsy. | ||||
| # os.getenv yields None for any variable that is not set. | ||||
| AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO") | ||||
| LLM_EMBEDDING_BACKEND = os.getenv( | ||||
|     "PAPERLESS_AI_LLM_EMBEDDING_BACKEND", | ||||
| )  # "huggingface" or "openai" | ||||
| LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_MODEL") | ||||
| LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND")  # "ollama" or "openai" | ||||
| LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL") | ||||
| LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY") | ||||
| LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT") | ||||
|   | ||||
| @@ -160,6 +160,7 @@ class TestCeleryScheduleParsing(TestCase): | ||||
|     SANITY_EXPIRE_TIME = ((7.0 * 24.0) - 1.0) * 60.0 * 60.0 | ||||
|     EMPTY_TRASH_EXPIRE_TIME = 23.0 * 60.0 * 60.0 | ||||
|     RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME = 59.0 * 60.0 | ||||
|     LLM_INDEX_EXPIRE_TIME = 23.0 * 60.0 * 60.0 | ||||
|  | ||||
|     def test_schedule_configuration_default(self): | ||||
|         """ | ||||
| @@ -204,6 +205,13 @@ class TestCeleryScheduleParsing(TestCase): | ||||
|                     "schedule": crontab(minute="5", hour="*/1"), | ||||
|                     "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, | ||||
|                 }, | ||||
|                 "Rebuild LLM index": { | ||||
|                     "task": "documents.tasks.llmindex_index", | ||||
|                     "schedule": crontab(minute=10, hour=2), | ||||
|                     "options": { | ||||
|                         "expires": self.LLM_INDEX_EXPIRE_TIME, | ||||
|                     }, | ||||
|                 }, | ||||
|             }, | ||||
|             schedule, | ||||
|         ) | ||||
| @@ -256,6 +264,13 @@ class TestCeleryScheduleParsing(TestCase): | ||||
|                     "schedule": crontab(minute="5", hour="*/1"), | ||||
|                     "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, | ||||
|                 }, | ||||
|                 "Rebuild LLM index": { | ||||
|                     "task": "documents.tasks.llmindex_index", | ||||
|                     "schedule": crontab(minute=10, hour=2), | ||||
|                     "options": { | ||||
|                         "expires": self.LLM_INDEX_EXPIRE_TIME, | ||||
|                     }, | ||||
|                 }, | ||||
|             }, | ||||
|             schedule, | ||||
|         ) | ||||
| @@ -300,6 +315,13 @@ class TestCeleryScheduleParsing(TestCase): | ||||
|                     "schedule": crontab(minute="5", hour="*/1"), | ||||
|                     "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, | ||||
|                 }, | ||||
|                 "Rebuild LLM index": { | ||||
|                     "task": "documents.tasks.llmindex_index", | ||||
|                     "schedule": crontab(minute=10, hour=2), | ||||
|                     "options": { | ||||
|                         "expires": self.LLM_INDEX_EXPIRE_TIME, | ||||
|                     }, | ||||
|                 }, | ||||
|             }, | ||||
|             schedule, | ||||
|         ) | ||||
| @@ -322,6 +344,7 @@ class TestCeleryScheduleParsing(TestCase): | ||||
|                 "PAPERLESS_INDEX_TASK_CRON": "disable", | ||||
|                 "PAPERLESS_EMPTY_TRASH_TASK_CRON": "disable", | ||||
|                 "PAPERLESS_WORKFLOW_SCHEDULED_TASK_CRON": "disable", | ||||
|                 "PAPERLESS_LLM_INDEX_TASK_CRON": "disable", | ||||
|             }, | ||||
|         ): | ||||
|             schedule = _parse_beat_schedule() | ||||
|   | ||||
| @@ -18,6 +18,7 @@ from rest_framework.routers import DefaultRouter | ||||
| from documents.views import BulkDownloadView | ||||
| from documents.views import BulkEditObjectsView | ||||
| from documents.views import BulkEditView | ||||
| from documents.views import ChatStreamingView | ||||
| from documents.views import CorrespondentViewSet | ||||
| from documents.views import CustomFieldViewSet | ||||
| from documents.views import DocumentTypeViewSet | ||||
| @@ -137,6 +138,11 @@ urlpatterns = [ | ||||
|                                 SelectionDataView.as_view(), | ||||
|                                 name="selection_data", | ||||
|                             ), | ||||
|                             re_path( | ||||
|                                 "^chat/", | ||||
|                                 ChatStreamingView.as_view(), | ||||
|                                 name="chat_streaming_view", | ||||
|                             ), | ||||
|                         ], | ||||
|                     ), | ||||
|                 ), | ||||
|   | ||||
| @@ -35,6 +35,7 @@ from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from documents.index import DelayedQuery | ||||
| from documents.permissions import PaperlessObjectPermissions | ||||
| from documents.tasks import llmindex_index | ||||
| from paperless.filters import GroupFilterSet | ||||
| from paperless.filters import UserFilterSet | ||||
| from paperless.models import ApplicationConfiguration | ||||
| @@ -43,6 +44,7 @@ from paperless.serialisers import GroupSerializer | ||||
| from paperless.serialisers import PaperlessAuthTokenSerializer | ||||
| from paperless.serialisers import ProfileSerializer | ||||
| from paperless.serialisers import UserSerializer | ||||
| from paperless_ai.indexing import vector_store_file_exists | ||||
|  | ||||
|  | ||||
| class PaperlessObtainAuthTokenView(ObtainAuthToken): | ||||
| @@ -354,6 +356,30 @@ class ApplicationConfigurationViewSet(ModelViewSet): | ||||
|     def create(self, request, *args, **kwargs): | ||||
|         return Response(status=405)  # Not Allowed | ||||
|  | ||||
|     def perform_update(self, serializer): | ||||
|         old_instance = ApplicationConfiguration.objects.all().first() | ||||
|         old_ai_index_enabled = ( | ||||
|             old_instance.ai_enabled and old_instance.llm_embedding_backend | ||||
|         ) | ||||
|  | ||||
|         new_instance: ApplicationConfiguration = serializer.save() | ||||
|         new_ai_index_enabled = ( | ||||
|             new_instance.ai_enabled and new_instance.llm_embedding_backend | ||||
|         ) | ||||
|  | ||||
|         if ( | ||||
|             not old_ai_index_enabled | ||||
|             and new_ai_index_enabled | ||||
|             and not vector_store_file_exists() | ||||
|         ): | ||||
|             # AI index was just enabled and vector store file does not exist | ||||
|             llmindex_index.delay( | ||||
|                 progress_bar_disable=True, | ||||
|                 rebuild=True, | ||||
|                 scheduled=False, | ||||
|                 auto=True, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @extend_schema_view( | ||||
|     post=extend_schema( | ||||
|   | ||||
							
								
								
									
										0
									
								
								src/paperless_ai/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/paperless_ai/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										106
									
								
								src/paperless_ai/ai_classifier.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								src/paperless_ai/ai_classifier.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | ||||
| import logging | ||||
|  | ||||
| from django.contrib.auth.models import User | ||||
|  | ||||
| from documents.models import Document | ||||
| from documents.permissions import get_objects_for_user_owner_aware | ||||
| from paperless.config import AIConfig | ||||
| from paperless_ai.client import AIClient | ||||
| from paperless_ai.indexing import query_similar_documents | ||||
| from paperless_ai.indexing import truncate_content | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.rag_classifier") | ||||
|  | ||||
|  | ||||
def build_prompt_without_rag(document: Document) -> str:
    """
    Build the classification prompt from the document's filename and content.

    Only the first 4000 characters of the content are considered, and that
    excerpt is additionally passed through truncate_content().
    """
    filename = document.filename or ""
    # Apply the empty-string fallback *before* slicing: slicing a missing
    # (None) content would raise TypeError before the fallback could apply.
    content = truncate_content((document.content or "")[:4000])

    return f"""
    You are a document classification assistant.

    Analyze the following document and extract the following information:
    - A short descriptive title
    - Tags that reflect the content
    - Names of people or organizations mentioned
    - The type or category of the document
    - Suggested folder paths for storing the document
    - Up to 3 relevant dates in YYYY-MM-DD format

    Filename:
    {filename}

    Content:
    {content}
    """.strip()
|  | ||||
|  | ||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
    """
    Build the classification prompt and append context taken from similar
    documents (restricted to what `user` may view, when a user is given).
    """
    base_prompt = build_prompt_without_rag(document)
    similar_context = get_context_for_document(document, user)
    context = truncate_content(similar_context)

    prompt = f"""{base_prompt}

    Additional context from similar documents:
    {context}
    """
    return prompt.strip()
|  | ||||
|  | ||||
def get_context_for_document(
    doc: Document,
    user: User | None = None,
    max_docs: int = 5,
) -> str:
    """
    Return a text block describing documents similar to `doc`.

    When `user` is given, only documents that user has "view_document"
    permission for are eligible. At most `max_docs` documents are used and
    each contributes its title plus the first 1000 characters of content.
    """
    visible_ids = None
    if user:
        visible_documents = get_objects_for_user_owner_aware(
            user,
            "view_document",
            Document,
        )
        visible_ids = [document.pk for document in visible_documents]
        if not visible_ids:
            # The user may not view any document. Passing an empty/None id
            # list on would lift the restriction, so stop here instead.
            return ""

    similar_docs = query_similar_documents(
        document=doc,
        document_ids=visible_ids,
    )[:max_docs]

    context_blocks = []
    for similar in similar_docs:
        # Fall back to "" before slicing so missing content cannot raise
        text = (similar.content or "")[:1000]
        title = similar.title or similar.filename or "Untitled"
        context_blocks.append(f"TITLE: {title}\n{text}")
    return "\n\n".join(context_blocks)
|  | ||||
|  | ||||
def parse_ai_response(raw: dict) -> dict:
    """
    Reduce the raw LLM result to the known suggestion keys, filling in an
    empty string for a missing title and an empty list for any missing list.
    """
    list_fields = (
        "tags",
        "correspondents",
        "document_types",
        "storage_paths",
        "dates",
    )
    parsed = {"title": raw.get("title", "")}
    for field in list_fields:
        parsed[field] = raw.get(field, [])
    return parsed
|  | ||||
|  | ||||
def get_ai_document_classification(
    document: Document,
    user: User | None = None,
) -> dict:
    """
    Ask the configured LLM for classification suggestions for `document`.

    The prompt includes context from similar documents when an embedding
    backend is configured. Any failure while querying the LLM or parsing its
    answer is logged and then re-raised to the caller.
    """
    ai_config = AIConfig()

    if ai_config.llm_embedding_backend:
        prompt = build_prompt_with_rag(document, user)
    else:
        prompt = build_prompt_without_rag(document)

    try:
        client = AIClient()
        result = client.run_llm_query(prompt)
        return parse_ai_response(result)
    except Exception:
        logger.exception("Failed AI classification")
        # Bare raise re-raises the active exception with its traceback intact
        raise
							
								
								
									
										10
									
								
								src/paperless_ai/base_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								src/paperless_ai/base_model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| from llama_index.core.bridge.pydantic import BaseModel | ||||
|  | ||||
|  | ||||
# Structured result requested from the LLM for document classification.
# NOTE(review): documented with comments rather than a class docstring, since
# a docstring would presumably end up in the generated tool schema — confirm.
class DocumentClassifierSchema(BaseModel):
    # Suggested document title
    title: str
    # Suggested tag names
    tags: list[str]
    # Suggested correspondent names
    correspondents: list[str]
    # Suggested document type names
    document_types: list[str]
    # Suggested storage path names
    storage_paths: list[str]
    # Suggested dates, as strings
    dates: list[str]
							
								
								
									
										77
									
								
								src/paperless_ai/chat.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/paperless_ai/chat.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,77 @@ | ||||
| import logging | ||||
| import sys | ||||
|  | ||||
| from llama_index.core import VectorStoreIndex | ||||
| from llama_index.core.prompts import PromptTemplate | ||||
| from llama_index.core.query_engine import RetrieverQueryEngine | ||||
|  | ||||
| from documents.models import Document | ||||
| from paperless_ai.client import AIClient | ||||
| from paperless_ai.indexing import load_or_build_index | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.chat") | ||||
|  | ||||
# Prompt used for document chat; `context_str` and `query_str` are filled in
# by stream_chat_with_documents() before the prompt is sent to the LLM.
CHAT_PROMPT_TMPL = PromptTemplate(
    template="""Context information is below.
    ---------------------
    {context_str}
    ---------------------
    Given the context information and not prior knowledge, answer the query.
    Query: {query_str}
    Answer:""",
)
|  | ||||
|  | ||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
    """
    Answer `query_str` about the given documents, yielding the response in
    chunks as the LLM produces it.

    Yields a single apology message if none of the documents has nodes in
    the index.
    """
    client = AIClient()
    index = load_or_build_index()

    wanted_ids = [str(doc.pk) for doc in documents]

    # Keep only the node(s) that belong to the requested documents
    matching_nodes = []
    for candidate in index.docstore.docs.values():
        if candidate.metadata.get("document_id") in wanted_ids:
            matching_nodes.append(candidate)

    if not matching_nodes:
        logger.warning("No nodes found for the given documents.")
        yield "Sorry, I couldn't find any content to answer your question."
        return

    single_document = len(documents) == 1

    local_index = VectorStoreIndex(nodes=matching_nodes)
    retriever = local_index.as_retriever(
        similarity_top_k=3 if single_document else 5,
    )

    if single_document:
        # Just one doc — provide full content
        only_doc = documents[0]
        # TODO: include document metadata in the context
        context = (
            f"TITLE: {only_doc.title or only_doc.filename}\n{only_doc.content or ''}"
        )
    else:
        # Several docs — use short excerpts of the best matching nodes
        excerpts = [
            f"TITLE: {node.metadata.get('title')}\n{node.text[:500]}"
            for node in retriever.retrieve(query_str)
        ]
        context = "\n\n".join(excerpts)

    filled_template = CHAT_PROMPT_TMPL.partial_format(
        context_str=context,
        query_str=query_str,
    )
    prompt = filled_template.format(llm=client.llm)

    query_engine = RetrieverQueryEngine.from_args(
        retriever=retriever,
        llm=client.llm,
        streaming=True,
    )

    logger.debug("Document chat prompt: %s", prompt)

    response_stream = query_engine.query(prompt)

    for chunk in response_stream.response_gen:
        yield chunk
        sys.stdout.flush()
							
								
								
									
										68
									
								
								src/paperless_ai/client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								src/paperless_ai/client.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,68 @@ | ||||
| import logging | ||||
|  | ||||
| from llama_index.core.llms import ChatMessage | ||||
| from llama_index.core.program.function_program import get_function_tool | ||||
| from llama_index.llms.ollama import Ollama | ||||
| from llama_index.llms.openai import OpenAI | ||||
|  | ||||
| from paperless.config import AIConfig | ||||
| from paperless_ai.base_model import DocumentClassifierSchema | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.client") | ||||
|  | ||||
|  | ||||
class AIClient:
    """
    A client for interacting with an LLM backend.

    The backend, model, endpoint and API key are read from AIConfig when the
    client is created.
    """

    def __init__(self):
        self.settings = AIConfig()
        self.llm = self.get_llm()

    def get_llm(self) -> Ollama | OpenAI:
        """
        Create the LLM object for the configured backend.

        Raises ValueError for any backend other than "ollama" or "openai".
        """
        if self.settings.llm_backend == "ollama":
            return Ollama(
                model=self.settings.llm_model or "llama3",
                base_url=self.settings.llm_endpoint or "http://localhost:11434",
                request_timeout=120,
            )
        elif self.settings.llm_backend == "openai":
            return OpenAI(
                model=self.settings.llm_model or "gpt-3.5-turbo",
                api_key=self.settings.llm_api_key,
            )
        else:
            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

    def run_llm_query(self, prompt: str) -> dict:
        """
        Send `prompt` to the LLM together with the DocumentClassifierSchema
        tool and return the arguments of the first tool call, validated by
        the schema, as a plain dict.
        """
        logger.debug(
            "Running LLM query against %s with model %s",
            self.settings.llm_backend,
            self.settings.llm_model,
        )

        user_msg = ChatMessage(role="user", content=prompt)
        tool = get_function_tool(DocumentClassifierSchema)
        result = self.llm.chat_with_tools(
            tools=[tool],
            user_msg=user_msg,
            chat_history=[],
        )
        # error_on_no_tool_calls=True: a reply without a tool call is an error
        tool_calls = self.llm.get_tool_calls_from_response(
            result,
            error_on_no_tool_calls=True,
        )
        logger.debug("LLM query result: %s", tool_calls)
        parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
        return parsed.model_dump()

    # NOTE(review): the value returned is whatever `self.llm.chat()` returns,
    # which is presumably a chat response object rather than `str` — confirm
    # and adjust the annotation.
    def run_chat(self, messages: list[ChatMessage]) -> str:
        """
        Send the chat `messages` to the LLM and return its result unchanged.
        """
        logger.debug(
            "Running chat query against %s with model %s",
            self.settings.llm_backend,
            self.settings.llm_model,
        )
        result = self.llm.chat(messages)
        logger.debug("Chat result: %s", result)
        return result
							
								
								
									
										92
									
								
								src/paperless_ai/embedding.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								src/paperless_ai/embedding.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| import json | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from pathlib import Path | ||||
|  | ||||
| from django.conf import settings | ||||
| from llama_index.core.base.embeddings.base import BaseEmbedding | ||||
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | ||||
| from llama_index.embeddings.openai import OpenAIEmbedding | ||||
|  | ||||
| from documents.models import Document | ||||
| from documents.models import Note | ||||
| from paperless.config import AIConfig | ||||
| from paperless.models import LLMEmbeddingBackend | ||||
|  | ||||
|  | ||||
def get_embedding_model() -> BaseEmbedding:
    """
    Create the embedding model for the configured embedding backend.

    Raises ValueError when the backend is neither OpenAI nor HuggingFace.
    """
    config = AIConfig()
    backend = config.llm_embedding_backend

    if backend == LLMEmbeddingBackend.OPENAI:
        return OpenAIEmbedding(
            model=config.llm_embedding_model or "text-embedding-3-small",
            api_key=config.llm_api_key,
        )

    if backend == LLMEmbeddingBackend.HUGGINGFACE:
        return HuggingFaceEmbedding(
            model_name=config.llm_embedding_model
            or "sentence-transformers/all-MiniLM-L6-v2",
        )

    raise ValueError(
        f"Unsupported embedding backend: {config.llm_embedding_backend}",
    )
|  | ||||
|  | ||||
def get_embedding_dim() -> int:
    """
    Loads embedding dimension from meta.json if available, otherwise infers it
    from a dummy embedding and stores it for future use.

    Raises RuntimeError when meta.json was written for a different embedding
    model than the one currently configured.
    """
    config = AIConfig()
    # Same default model names as get_embedding_model(); compare against the
    # enum member, as get_embedding_model() does, rather than a raw string.
    model = config.llm_embedding_model or (
        "text-embedding-3-small"
        if config.llm_embedding_backend == LLMEmbeddingBackend.OPENAI
        else "sentence-transformers/all-MiniLM-L6-v2"
    )

    meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
    if meta_path.exists():
        with meta_path.open() as f:
            meta = json.load(f)
        stored_model = meta.get("embedding_model")
        if stored_model != model:
            raise RuntimeError(
                f"Embedding model changed from {stored_model} to {model}. "
                "You must rebuild the index.",
            )
        return meta["dim"]

    # No cached value: embed a dummy text and measure the vector length
    embedding_model = get_embedding_model()
    dim = len(embedding_model.get_text_embedding("test"))

    # The index directory may not exist yet on a first run
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    with meta_path.open("w") as f:
        json.dump({"embedding_model": model, "dim": dim}, f)

    return dim
|  | ||||
|  | ||||
def build_llm_index_text(doc: Document) -> str:
    """
    Render the document's metadata, notes, custom fields and content as the
    plain text that gets embedded into the LLM index.
    """
    tag_names = ", ".join(tag.name for tag in doc.tags.all())
    document_type = doc.document_type.name if doc.document_type else ""
    correspondent = doc.correspondent.name if doc.correspondent else ""
    storage_path = doc.storage_path.name if doc.storage_path else ""
    notes = ",".join([str(c.note) for c in Note.objects.filter(document=doc)])

    header = [
        f"Title: {doc.title}",
        f"Filename: {doc.filename}",
        f"Created: {doc.created}",
        f"Added: {doc.added}",
        f"Modified: {doc.modified}",
        f"Tags: {tag_names}",
        f"Document Type: {document_type}",
        f"Correspondent: {correspondent}",
        f"Storage Path: {storage_path}",
        f"Archive Serial Number: {doc.archive_serial_number or ''}",
        f"Notes: {notes}",
    ]

    custom_fields = [
        f"Custom Field - {instance.field.name}: {instance}"
        for instance in doc.custom_fields.all()
    ]

    return "\n".join(
        [*header, *custom_fields, "\nContent:\n", doc.content or ""],
    )
							
								
								
									
										283
									
								
								src/paperless_ai/indexing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								src/paperless_ai/indexing.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,283 @@ | ||||
| import logging | ||||
| import shutil | ||||
| from pathlib import Path | ||||
|  | ||||
| import faiss | ||||
| import llama_index.core.settings as llama_settings | ||||
| import tqdm | ||||
| from django.conf import settings | ||||
| from llama_index.core import Document as LlamaDocument | ||||
| from llama_index.core import StorageContext | ||||
| from llama_index.core import VectorStoreIndex | ||||
| from llama_index.core import load_index_from_storage | ||||
| from llama_index.core.indices.prompt_helper import PromptHelper | ||||
| from llama_index.core.node_parser import SimpleNodeParser | ||||
| from llama_index.core.prompts import PromptTemplate | ||||
| from llama_index.core.retrievers import VectorIndexRetriever | ||||
| from llama_index.core.schema import BaseNode | ||||
| from llama_index.core.storage.docstore import SimpleDocumentStore | ||||
| from llama_index.core.storage.index_store import SimpleIndexStore | ||||
| from llama_index.core.text_splitter import TokenTextSplitter | ||||
| from llama_index.vector_stores.faiss import FaissVectorStore | ||||
|  | ||||
| from documents.models import Document | ||||
| from paperless_ai.embedding import build_llm_index_text | ||||
| from paperless_ai.embedding import get_embedding_dim | ||||
| from paperless_ai.embedding import get_embedding_model | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.indexing") | ||||
|  | ||||
|  | ||||
def get_or_create_storage_context(*, rebuild=False):
    """
    Loads or creates the StorageContext (vector store, docstore, index store).
    If rebuild=True, deletes and recreates everything.

    A fresh, empty context is also created when the index directory does not
    exist yet; otherwise the stores are loaded from that directory.
    """
    index_dir = settings.LLM_INDEX_DIR

    if rebuild:
        shutil.rmtree(index_dir, ignore_errors=True)

    start_fresh = rebuild or not index_dir.exists()

    if start_fresh:
        # get_embedding_dim() stores meta.json inside the index directory, so
        # the directory has to exist before it is called.
        index_dir.mkdir(parents=True, exist_ok=True)
        embedding_dim = get_embedding_dim()
        faiss_index = faiss.IndexFlatL2(embedding_dim)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        docstore = SimpleDocumentStore()
        index_store = SimpleIndexStore()
    else:
        vector_store = FaissVectorStore.from_persist_dir(index_dir)
        docstore = SimpleDocumentStore.from_persist_dir(index_dir)
        index_store = SimpleIndexStore.from_persist_dir(index_dir)

    return StorageContext.from_defaults(
        docstore=docstore,
        index_store=index_store,
        vector_store=vector_store,
        persist_dir=index_dir,
    )
|  | ||||
|  | ||||
def build_document_node(document: Document) -> list[BaseNode]:
    """
    Turn a Document into the parsed nodes that are stored in the index.

    Each node carries the document id (as a string), title, tag names,
    correspondent, document type and ISO-formatted timestamps as metadata.
    """
    correspondent = document.correspondent
    document_type = document.document_type
    created = document.created
    added = document.added

    llama_document = LlamaDocument(
        text=build_llm_index_text(document),
        metadata={
            "document_id": str(document.id),
            "title": document.title,
            "tags": [t.name for t in document.tags.all()],
            "correspondent": correspondent.name if correspondent else None,
            "document_type": document_type.name if document_type else None,
            "created": created.isoformat() if created else None,
            "added": added.isoformat() if added else None,
            "modified": document.modified.isoformat(),
        },
    )
    return SimpleNodeParser().get_nodes_from_documents([llama_document])
|  | ||||
|  | ||||
def load_or_build_index(nodes=None):
    """
    Return the persisted VectorStoreIndex, or build a new one from `nodes`
    when loading fails with a ValueError.

    The ValueError is re-raised if loading fails and no nodes were provided.
    """
    embed_model = get_embedding_model()
    # Also register the model globally for llama_index
    llama_settings.Settings.embed_model = embed_model
    storage_context = get_or_create_storage_context()

    try:
        index = load_index_from_storage(storage_context=storage_context)
    except ValueError as e:
        logger.warning("Failed to load index from storage: %s", e)
        if nodes:
            return VectorStoreIndex(
                nodes=nodes,
                storage_context=storage_context,
                embed_model=embed_model,
            )
        logger.info("No nodes provided for index creation.")
        raise
    return index
|  | ||||
|  | ||||
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
    """
    Delete every docstore node that belongs to the given document.

    Only the docstore is touched, because FAISS IndexFlatL2 is append-only.
    """
    target_id = str(document.id)
    every_node_id = list(index.docstore.docs.keys())

    # Collect first, delete afterwards, so the docstore is not modified
    # while its nodes are being read.
    stale_node_ids = []
    for node in index.docstore.get_nodes(every_node_id):
        if node.metadata.get("document_id") == target_id:
            stale_node_ids.append(node.node_id)

    for stale_id in stale_node_ids:
        index.docstore.delete_document(stale_id)
|  | ||||
|  | ||||
def vector_store_file_exists():
    """
    Tell whether default__vector_store.json is present in the LLM index
    directory.
    """
    vector_store_path = Path(settings.LLM_INDEX_DIR / "default__vector_store.json")
    return vector_store_path.exists()
|  | ||||
|  | ||||
def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
    """
    Rebuild or update the LLM index.

    A full rebuild happens when `rebuild` is set or no vector store file
    exists; otherwise only new documents and documents whose `modified`
    timestamp differs from the indexed one are (re-)indexed.

    Returns a short status message describing what was done.
    """
    nodes = []

    documents = Document.objects.all()
    if not documents.exists():
        msg = "No documents found to index."
        logger.warning(msg)
        return msg

    if rebuild or not vector_store_file_exists():
        # remove meta.json to force re-detection of embedding dim
        (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
        # Rebuild index from scratch
        logger.info("Rebuilding LLM index.")
        embed_model = get_embedding_model()
        llama_settings.Settings.embed_model = embed_model
        storage_context = get_or_create_storage_context(rebuild=True)
        for document in tqdm.tqdm(documents, disable=progress_bar_disable):
            document_nodes = build_document_node(document)
            nodes.extend(document_nodes)

        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
            show_progress=not progress_bar_disable,
        )
        msg = "LLM index rebuilt successfully."
    else:
        # Update existing index
        index = load_or_build_index()
        all_node_ids = list(index.docstore.docs.keys())

        # A document can be represented by several nodes, so collect *all*
        # nodes per document id instead of keeping just one of them.
        nodes_by_document: dict = {}
        for node in index.docstore.get_nodes(all_node_ids):
            nodes_by_document.setdefault(
                node.metadata.get("document_id"),
                [],
            ).append(node)

        for document in tqdm.tqdm(documents, disable=progress_bar_disable):
            doc_id = str(document.id)
            document_modified = document.modified.isoformat()
            indexed_nodes = nodes_by_document.get(doc_id, [])

            # Already indexed and unchanged: nothing to do
            if indexed_nodes and all(
                node.metadata.get("modified") == document_modified
                for node in indexed_nodes
            ):
                continue

            # Again, delete from docstore, FAISS IndexFlatL2 are append-only.
            # Every old node of the document is removed so none goes stale.
            for node in indexed_nodes:
                index.docstore.delete_document(node.node_id)

            # New or changed document, (re-)add it
            nodes.extend(build_document_node(document))

        if nodes:
            msg = "LLM index updated successfully."
            logger.info(
                "Updating %d nodes in LLM index.",
                len(nodes),
            )
            index.insert_nodes(nodes)
        else:
            msg = "No changes detected in LLM index."
            logger.info(msg)

    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
    return msg
|  | ||||
|  | ||||
def llm_index_add_or_update_document(document: Document):
    """
    Index a single document, replacing any nodes the index already holds
    for it, and persist the result.
    """
    fresh_nodes = build_document_node(document)

    # The new nodes double as seed data if no index can be loaded
    llm_index = load_or_build_index(nodes=fresh_nodes)

    # Drop the document's previous nodes before inserting the new ones
    remove_document_docstore_nodes(document, llm_index)
    llm_index.insert_nodes(fresh_nodes)

    llm_index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|  | ||||
|  | ||||
def llm_index_remove_document(document: Document):
    """
    Delete the document's nodes from the LLM index docstore and persist the
    index.
    """
    llm_index = load_or_build_index()
    remove_document_docstore_nodes(document, llm_index)
    llm_index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|  | ||||
|  | ||||
def truncate_content(content: str) -> str:
    """
    Split `content` into token chunks and truncate them with a PromptHelper
    configured for an 8192-token context window, returning the remaining
    chunks joined by single spaces.
    """
    helper = PromptHelper(
        context_window=8192,
        num_output=512,
        chunk_overlap_ratio=0.1,
        chunk_size_limit=None,
    )
    chunks = TokenTextSplitter(
        separator=" ",
        chunk_size=512,
        chunk_overlap=50,
    ).split_text(content)

    kept_chunks = helper.truncate(
        prompt=PromptTemplate(template="{content}"),
        text_chunks=chunks,
        padding=5,
    )
    return " ".join(kept_chunks)
|  | ||||
|  | ||||
def query_similar_documents(
    document: Document,
    top_k: int = 5,
    document_ids: list[int] | None = None,
) -> list[Document]:
    """
    Runs a similarity query and returns top-k similar Document objects.

    When `document_ids` is a non-empty list, only documents with those ids
    can be returned. `None` or an empty list means no restriction.
    """
    index = load_or_build_index()

    allowed_ids = None
    doc_node_ids = None
    if document_ids:
        # build_document_node() stores the id in node metadata as a string,
        # while callers pass primary keys; compare both as strings.
        allowed_ids = {str(pk) for pk in document_ids}
        # constrain only the node(s) that match the document IDs
        doc_node_ids = [
            node.node_id
            for node in index.docstore.docs.values()
            if node.metadata.get("document_id") in allowed_ids
        ]
        if not doc_node_ids:
            # None of the permitted documents is indexed: nothing to return
            return []

    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=top_k,
        doc_ids=doc_node_ids,
    )

    query_text = truncate_content(
        (document.title or "") + "\n" + (document.content or ""),
    )
    results = retriever.retrieve(query_text)

    # Enforce the restriction on the results as well.
    # NOTE(review): defensive, in case the vector store does not apply
    # `doc_ids` itself — confirm for the FAISS store.
    matched_ids = [
        int(node.metadata["document_id"])
        for node in results
        if "document_id" in node.metadata
        and (
            allowed_ids is None or str(node.metadata["document_id"]) in allowed_ids
        )
    ]

    return list(Document.objects.filter(pk__in=matched_ids))
							
								
								
									
										100
									
								
								src/paperless_ai/matching.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								src/paperless_ai/matching.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| import difflib | ||||
| import logging | ||||
| import re | ||||
|  | ||||
| from django.contrib.auth.models import User | ||||
|  | ||||
| from documents.models import Correspondent | ||||
| from documents.models import DocumentType | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| from documents.permissions import get_objects_for_user_owner_aware | ||||
|  | ||||
| MATCH_THRESHOLD = 0.8 | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.matching") | ||||
|  | ||||
|  | ||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
    """
    Match the given names against the tags `user` is allowed to view.
    """
    viewable_tags = get_objects_for_user_owner_aware(user, ["view_tag"], Tag)
    return _match_names_to_queryset(names, viewable_tags, "name")
|  | ||||
|  | ||||
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
    """
    Match the given names against the correspondents `user` is allowed to view.
    """
    viewable_correspondents = get_objects_for_user_owner_aware(
        user,
        ["view_correspondent"],
        Correspondent,
    )
    return _match_names_to_queryset(names, viewable_correspondents, "name")
|  | ||||
|  | ||||
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
    """
    Match the given names against the document types `user` is allowed to view.
    """
    viewable_document_types = get_objects_for_user_owner_aware(
        user,
        ["view_documenttype"],
        DocumentType,
    )
    return _match_names_to_queryset(names, viewable_document_types, "name")
|  | ||||
|  | ||||
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
    """
    Match the given names against the storage paths `user` is allowed to view.
    """
    viewable_storage_paths = get_objects_for_user_owner_aware(
        user,
        ["view_storagepath"],
        StoragePath,
    )
    return _match_names_to_queryset(names, viewable_storage_paths, "name")
|  | ||||
|  | ||||
def _normalize(s: str) -> str:
    """
    Lower-case `s`, drop punctuation and strip surrounding whitespace.
    """
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    return s.strip()


def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """
    Return the objects from `queryset` whose `attr` matches one of `names`.

    Names are compared in normalized form. An exact match is preferred; if
    there is none, the closest fuzzy match with a similarity of at least
    MATCH_THRESHOLD is used. Empty names are skipped, names without a match
    contribute nothing.
    """
    results = []
    # Parallel lists: object_names[i] is the normalized name of objects[i]
    objects = list(queryset)
    object_names = [_normalize(getattr(obj, attr)) for obj in objects]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in object_names:
            index = object_names.index(target)
            # Remove the match from *both* lists: this keeps it out of later
            # fuzzy matching and keeps the two lists index-aligned.
            object_names.pop(index)
            results.append(objects.pop(index))
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            object_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            results.append(objects[object_names.index(matches[0])])
    return results
|  | ||||
|  | ||||
def extract_unmatched_names(
    names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    """Return the entries of *names* that no matched object accounts for.

    The comparison is case-insensitive against the ``attr`` value of each
    object in *matched_objects*.

    Args:
        names: Candidate names; falsy entries (``None``, ``""``) are dropped,
            mirroring how ``_match_names_to_queryset`` skips them.
        matched_objects: Objects that were already matched.
        attr: Attribute holding each object's name.

    Returns:
        The remaining names, in their original order and spelling.
    """
    matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
    return [
        name for name in names if name and name.lower() not in matched_names
    ]
							
								
								
									
										0
									
								
								src/paperless_ai/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/paperless_ai/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										186
									
								
								src/paperless_ai/tests/test_ai_classifier.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								src/paperless_ai/tests/test_ai_classifier.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,186 @@ | ||||
| import json | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| from django.test import override_settings | ||||
|  | ||||
| from documents.models import Document | ||||
| from paperless_ai.ai_classifier import build_prompt_with_rag | ||||
| from paperless_ai.ai_classifier import build_prompt_without_rag | ||||
| from paperless_ai.ai_classifier import get_ai_document_classification | ||||
| from paperless_ai.ai_classifier import get_context_for_document | ||||
|  | ||||
|  | ||||
@pytest.fixture
def mock_document():
    """Build a ``MagicMock`` specced on ``Document`` with populated metadata.

    The mock carries two tags, a document type, a correspondent, an archive
    serial number, content and two custom field instances.
    """
    doc = MagicMock(spec=Document)
    doc.title = "Test Title"
    doc.filename = "test_file.pdf"
    doc.created = "2023-01-01"
    doc.added = "2023-01-02"
    doc.modified = "2023-01-03"
    doc.archive_serial_number = "12345"
    doc.content = "This is the document content."

    # ``name`` has to be assigned after construction: as a constructor keyword
    # it would name the mock itself instead of setting an attribute.
    tags = []
    for tag_name in ("Tag1", "Tag2"):
        tag = MagicMock()
        tag.name = tag_name
        tags.append(tag)
    doc.tags.all = MagicMock(return_value=tags)

    doc.document_type = MagicMock()
    doc.document_type.name = "Invoice"
    doc.correspondent = MagicMock()
    doc.correspondent.name = "Test Correspondent"

    custom_fields = []
    for field_name, field_value in (("Field1", "Value1"), ("Field2", "Value2")):
        # Bind the value as a default so each instance keeps its own string.
        instance = MagicMock(__str__=lambda x, text=field_value: text)
        instance.field = MagicMock()
        instance.field.name = field_name
        instance.value = field_value
        custom_fields.append(instance)
    doc.custom_fields.all = MagicMock(return_value=custom_fields)

    return doc
|  | ||||
|  | ||||
@pytest.fixture
def mock_similar_documents():
    """Return three mock documents with progressively fewer populated fields.

    The first has content, title and filename; the second has no title; the
    third has none of the three.
    """
    field_values = (
        ("Content of document 1", "Title 1", "file1.txt"),
        ("Content of document 2", None, "file2.txt"),
        (None, None, None),
    )
    similar = []
    for content, title, filename in field_values:
        entry = MagicMock()
        entry.content = content
        entry.title = title
        entry.filename = filename
        similar.append(entry)
    return similar
|  | ||||
|  | ||||
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
    """The classification result exposes every field returned by the LLM query."""
    mock_run_llm_query.return_value = {
        "title": "Test Title",
        "tags": ["test", "document"],
        "correspondents": ["John Doe"],
        "document_types": ["report"],
        "storage_paths": ["Reports"],
        "dates": ["2023-01-01"],
    }

    classification = get_ai_document_classification(mock_document)

    observed = (
        classification["title"],
        classification["tags"],
        classification["correspondents"],
        classification["document_types"],
        classification["storage_paths"],
        classification["dates"],
    )
    assert observed == (
        "Test Title",
        ["test", "document"],
        ["John Doe"],
        ["report"],
        ["Reports"],
        ["2023-01-01"],
    )
|  | ||||
|  | ||||
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
    """A failing LLM query surfaces as an exception from the classifier."""
    llm_error = Exception("LLM query failed")
    mock_run_llm_query.side_effect = llm_error

    with pytest.raises(Exception):
        get_ai_document_classification(mock_document)
|  | ||||
|  | ||||
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_EMBEDDING_MODEL="some_model",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
    mock_build_prompt_with_rag,
    mock_run_llm_query,
    mock_document,
):
    """With an embedding backend configured, the RAG prompt builder is used once."""
    empty_payload = json.dumps({})
    mock_run_llm_query.return_value.text = empty_payload
    mock_build_prompt_with_rag.return_value = "Prompt with RAG"

    get_ai_document_classification(mock_document)

    mock_build_prompt_with_rag.assert_called_once()
|  | ||||
|  | ||||
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
    mock_ai_config,
    mock_build_prompt_without_rag,
    mock_run_llm_query,
    mock_document,
):
    """Without an embedding backend, the non-RAG prompt builder is used once."""
    # NOTE(review): this sets the attribute on the patched class object, not on
    # ``mock_ai_config.return_value`` (what ``AIConfig()`` would hand back) --
    # confirm which of the two the code under test actually reads.
    mock_ai_config.llm_embedding_backend = None
    empty_payload = json.dumps({})
    mock_run_llm_query.return_value.text = empty_payload
    mock_build_prompt_without_rag.return_value = "Prompt without RAG"

    get_ai_document_classification(mock_document)

    mock_build_prompt_without_rag.assert_called_once()
|  | ||||
|  | ||||
@pytest.mark.django_db
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
    """Only the RAG prompt contains the similar-documents context header."""
    context_header = "Additional context from similar documents:"
    stubbed_context = patch(
        "paperless_ai.ai_classifier.get_context_for_document",
        return_value="Context from similar documents",
    )
    with stubbed_context:
        plain_prompt = build_prompt_without_rag(mock_document)
        assert context_header not in plain_prompt

        rag_prompt = build_prompt_with_rag(mock_document)
        assert context_header in rag_prompt
|  | ||||
|  | ||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
    mock_query_similar_documents,
    mock_document,
    mock_similar_documents,
):
    """Context is built from the first ``max_docs`` similar documents.

    The second document has no title, and the expected text shows its filename
    in the ``TITLE:`` line instead.
    """
    mock_query_similar_documents.return_value = mock_similar_documents

    context = get_context_for_document(mock_document, max_docs=2)

    expected_sections = [
        "TITLE: Title 1\nContent of document 1",
        "TITLE: file2.txt\nContent of document 2",
    ]
    assert context == "\n\n".join(expected_sections)
    mock_query_similar_documents.assert_called_once()
|  | ||||
|  | ||||
def test_get_context_for_document_no_similar_docs(mock_document):
    """No similar documents yields an empty context string."""
    no_results = patch(
        "paperless_ai.ai_classifier.query_similar_documents",
        return_value=[],
    )
    with no_results:
        context = get_context_for_document(mock_document)
        assert context == ""
							
								
								
									
										334
									
								
								src/paperless_ai/tests/test_ai_indexing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										334
									
								
								src/paperless_ai/tests/test_ai_indexing.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,334 @@ | ||||
| import json | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| from django.test import override_settings | ||||
| from django.utils import timezone | ||||
| from llama_index.core.base.embeddings.base import BaseEmbedding | ||||
|  | ||||
| from documents.models import Document | ||||
| from paperless_ai import indexing | ||||
|  | ||||
|  | ||||
@pytest.fixture
def temp_llm_index_dir(tmp_path):
    """Point ``indexing.settings.LLM_INDEX_DIR`` at ``tmp_path`` for one test.

    Yields the temporary directory and puts the previous setting back afterwards.
    """
    previous_index_dir = indexing.settings.LLM_INDEX_DIR
    indexing.settings.LLM_INDEX_DIR = tmp_path
    yield tmp_path
    indexing.settings.LLM_INDEX_DIR = previous_index_dir
|  | ||||
|  | ||||
@pytest.fixture
def real_document(db):
    """Create and return a persisted ``Document`` with a title and content."""
    document_fields = {
        "title": "Test Document",
        "content": "This is some test content.",
        "added": timezone.now(),
    }
    return Document.objects.create(**document_fields)
|  | ||||
|  | ||||
@pytest.fixture
def mock_embed_model():
    """Replace ``get_embedding_model`` in both modules with a ``FakeEmbedding``.

    Yields the mock that stands in for ``paperless_ai.indexing.get_embedding_model``.
    """
    fake_model = FakeEmbedding()
    indexing_patch = patch(
        "paperless_ai.indexing.get_embedding_model",
        return_value=fake_model,
    )
    embedding_patch = patch(
        "paperless_ai.embedding.get_embedding_model",
        return_value=fake_model,
    )
    with indexing_patch as mock_index, embedding_patch:
        yield mock_index
|  | ||||
|  | ||||
class FakeEmbedding(BaseEmbedding):
    """Constant-vector embedding model used by the indexing tests.

    Every query or text is embedded as ``get_query_embedding_dim()`` copies of
    ``0.1``.
    """

    # TODO: maybe a better way to do this?
    async def _aget_query_embedding(self, query: str) -> list[float]:
        # The base class declares this hook as a coroutine, so it must be
        # ``async`` for callers that await it to receive the vector.
        return [0.1] * self.get_query_embedding_dim()

    def _get_query_embedding(self, query: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def _get_text_embedding(self, text: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def get_query_embedding_dim(self) -> int:
        return 384  # Match your real FAISS config
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_build_document_node(real_document):
    """Building nodes yields at least one node tagged with the document id."""
    built_nodes = indexing.build_document_node(real_document)

    assert len(built_nodes) > 0
    first_node = built_nodes[0]
    assert first_node.metadata["document_id"] == str(real_document.id)
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_update_llm_index(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """A rebuild leaves at least one JSON file in the index directory."""
    with patch("documents.models.Document.objects.all") as mock_all:
        all_documents = MagicMock()
        all_documents.exists.return_value = True
        all_documents.__iter__.return_value = iter([real_document])
        mock_all.return_value = all_documents

        indexing.update_llm_index(rebuild=True)

        first_json = next(temp_llm_index_dir.glob("*.json"), None)
        assert first_json is not None
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """A rebuild replaces a stale ``meta.json`` with the current model and dim."""
    meta_path = temp_llm_index_dir / "meta.json"
    # Pre-create a meta.json with incorrect data
    stale_meta = json.dumps({"embedding_model": "old", "dim": 1})
    meta_path.write_text(stale_meta)

    with patch("documents.models.Document.objects.all") as mock_all:
        all_documents = MagicMock()
        all_documents.exists.return_value = True
        all_documents.__iter__.return_value = iter([real_document])
        mock_all.return_value = all_documents
        indexing.update_llm_index(rebuild=True)

    written_meta = json.loads(meta_path.read_text())
    from paperless.config import AIConfig

    config = AIConfig()
    configured_model = config.llm_embedding_model
    if configured_model:
        expected_model = configured_model
    elif config.llm_embedding_backend == "openai":
        expected_model = "text-embedding-3-small"
    else:
        expected_model = "sentence-transformers/all-MiniLM-L6-v2"
    assert written_meta == {"embedding_model": expected_model, "dim": 384}
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_update_llm_index_partial_update(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """A non-rebuild update after one modified and one new document logs 2 nodes."""
    second_document = Document.objects.create(
        title="Test Document 2",
        content="This is some test content 2.",
        added=timezone.now(),
        checksum="1234567890abcdef",
    )

    # Build the initial index from the first two documents.
    with patch("documents.models.Document.objects.all") as mock_all:
        initial_documents = MagicMock()
        initial_documents.exists.return_value = True
        initial_documents.__iter__.return_value = iter(
            [real_document, second_document],
        )
        mock_all.return_value = initial_documents

        indexing.update_llm_index(rebuild=True)

    # Simulate a modification of the first document.
    modified_document = real_document
    modified_document.modified = timezone.now()

    # Add a brand-new document.
    third_document = Document.objects.create(
        title="Test Document 3",
        content="This is some test content 3.",
        added=timezone.now(),
        checksum="abcdef1234567890",
    )

    with patch("documents.models.Document.objects.all") as mock_all:
        current_documents = MagicMock()
        current_documents.exists.return_value = True
        current_documents.__iter__.return_value = iter(
            [modified_document, second_document, third_document],
        )
        mock_all.return_value = current_documents

        # The incremental update should log exactly one info message that
        # reports two updated nodes.
        with patch("paperless_ai.indexing.logger") as mock_logger:
            indexing.update_llm_index(rebuild=False)
            mock_logger.info.assert_called_once_with(
                "Updating %d nodes in LLM index.",
                2,
            )
        indexing.update_llm_index(rebuild=False)

    first_json = next(temp_llm_index_dir.glob("*.json"), None)
    assert first_json is not None
|  | ||||
|  | ||||
def test_get_or_create_storage_context_raises_exception(
    temp_llm_index_dir,
    mock_embed_model,
):
    """Requesting a storage context without a rebuild raises on an empty directory."""
    expect_failure = pytest.raises(Exception)
    with expect_failure:
        indexing.get_or_create_storage_context(rebuild=False)
|  | ||||
|  | ||||
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """When loading from storage fails and nodes are given, a new index is built."""
    with (
        patch(
            "paperless_ai.indexing.load_index_from_storage",
            side_effect=ValueError("Index not found"),
        ),
        patch(
            "paperless_ai.indexing.VectorStoreIndex",
            return_value=MagicMock(),
        ) as mock_index_cls,
        patch(
            "paperless_ai.indexing.get_or_create_storage_context",
            return_value=MagicMock(),
        ) as mock_storage,
    ):
        mock_storage.return_value.persist_dir = temp_llm_index_dir
        # ``build_document_node`` already returns a sequence of nodes, so it is
        # passed as-is; wrapping it in another list would hand the index a
        # list of lists.
        indexing.load_or_build_index(
            nodes=indexing.build_document_node(real_document),
        )
        mock_index_cls.assert_called_once()
|  | ||||
|  | ||||
def test_load_or_build_index_raises_exception_when_no_nodes(
    temp_llm_index_dir,
    mock_embed_model,
):
    """When loading from storage fails and no nodes are given, an exception is raised."""
    failing_load = patch(
        "paperless_ai.indexing.load_index_from_storage",
        side_effect=ValueError("Index not found"),
    )
    stub_storage = patch(
        "paperless_ai.indexing.get_or_create_storage_context",
        return_value=MagicMock(),
    )
    with failing_load, stub_storage:
        with pytest.raises(Exception):
            indexing.load_or_build_index()
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_load_or_build_index_succeeds_when_nodes_given(
    temp_llm_index_dir,
    mock_embed_model,
):
    """With a failing load and one mock node, ``VectorStoreIndex`` is built once."""
    failing_load = patch(
        "paperless_ai.indexing.load_index_from_storage",
        side_effect=ValueError("Index not found"),
    )
    stub_index_cls = patch(
        "paperless_ai.indexing.VectorStoreIndex",
        return_value=MagicMock(),
    )
    stub_storage = patch(
        "paperless_ai.indexing.get_or_create_storage_context",
        return_value=MagicMock(),
    )
    with failing_load, stub_index_cls as mock_index_cls, stub_storage as mock_storage:
        mock_storage.return_value.persist_dir = temp_llm_index_dir
        placeholder_nodes = [MagicMock()]
        indexing.load_or_build_index(
            nodes=placeholder_nodes,
        )
        mock_index_cls.assert_called_once()
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """Re-adding an already indexed document leaves JSON files in the index dir."""
    indexing.update_llm_index(rebuild=True)
    indexing.llm_index_add_or_update_document(real_document)

    first_json = next(temp_llm_index_dir.glob("*.json"), None)
    assert first_json is not None
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """Removing the only indexed document empties the docstore."""
    indexing.update_llm_index(rebuild=True)
    docs_before = indexing.load_or_build_index().docstore.docs
    assert len(docs_before) == 1

    indexing.llm_index_remove_document(real_document)
    docs_after = indexing.load_or_build_index().docstore.docs
    assert len(docs_after) == 0
|  | ||||
|  | ||||
@pytest.mark.django_db
def test_update_llm_index_no_documents(
    temp_llm_index_dir,
    mock_embed_model,
):
    """A rebuild with no documents logs a single warning."""
    with patch("documents.models.Document.objects.all") as mock_all:
        no_documents = MagicMock()
        no_documents.exists.return_value = False
        no_documents.__iter__.return_value = iter([])
        mock_all.return_value = no_documents

        # The only observable effect checked here is the warning message.
        with patch("paperless_ai.indexing.logger") as mock_logger:
            indexing.update_llm_index(rebuild=True)
            mock_logger.warning.assert_called_once_with(
                "No documents found to index.",
            )
|  | ||||
|  | ||||
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_BACKEND="ollama",
)
def test_query_similar_documents(
    temp_llm_index_dir,
    real_document,
):
    """Retrieved node ids are turned into a ``Document`` filter on ``pk__in``.

    Storage, index, retriever and the ORM filter are all mocked; the test
    checks the query text given to the retriever and the ids given to the filter.
    """
    with (
        patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
        patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
        patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
        patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
    ):
        mock_storage.return_value = MagicMock(persist_dir=temp_llm_index_dir)
        mock_load_or_build_index.return_value = MagicMock()

        retriever = MagicMock()
        mock_retriever_cls.return_value = retriever

        # Two retrieved nodes pointing at document ids 1 and 2.
        retrieved_nodes = [
            MagicMock(metadata={"document_id": document_id})
            for document_id in (1, 2)
        ]
        retriever.retrieve.return_value = retrieved_nodes

        filtered_documents = [MagicMock(pk=1), MagicMock(pk=2)]
        mock_filter.return_value = filtered_documents

        similar = indexing.query_similar_documents(real_document, top_k=3)

        mock_load_or_build_index.assert_called_once()
        mock_retriever_cls.assert_called_once()
        retriever.retrieve.assert_called_once_with(
            "Test Document\nThis is some test content.",
        )
        mock_filter.assert_called_once_with(pk__in=[1, 2])

        assert similar == filtered_documents
							
								
								
									
										142
									
								
								src/paperless_ai/tests/test_chat.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								src/paperless_ai/tests/test_chat.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,142 @@ | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| from llama_index.core import VectorStoreIndex | ||||
| from llama_index.core.schema import TextNode | ||||
|  | ||||
| from paperless_ai.chat import stream_chat_with_documents | ||||
|  | ||||
|  | ||||
@pytest.fixture(autouse=True)
def patch_embed_model():
    """Install a mock embed model on the llama_index ``Settings`` for each test.

    The mock returns a single 1536-element vector from
    ``_get_text_embedding_batch``. Whatever embed model was set before the test
    is put back afterwards instead of being discarded.
    """
    from llama_index.core import settings as llama_settings

    previous_embed_model = llama_settings.Settings._embed_model

    mock_embed_model = MagicMock()
    mock_embed_model._get_text_embedding_batch.return_value = [
        [0.1] * 1536,
    ]  # 1 vector per input
    llama_settings.Settings._embed_model = mock_embed_model
    yield
    llama_settings.Settings._embed_model = previous_embed_model
|  | ||||
|  | ||||
@pytest.fixture(autouse=True)
def patch_embed_nodes():
    """Stub llama_index's ``embed_nodes`` with a constant vector per node id."""

    def constant_embeddings(nodes, *_args, **_kwargs):
        return {node.node_id: [0.1] * 1536 for node in nodes}

    with patch(
        "llama_index.core.indices.vector_store.base.embed_nodes",
    ) as mock_embed_nodes:
        mock_embed_nodes.side_effect = constant_embeddings
        yield
|  | ||||
|  | ||||
@pytest.fixture
def mock_document():
    """Return a mock document with pk 1, a title, a filename and content."""
    document = MagicMock()
    document.pk = 1
    document.title = "Test Document"
    document.filename = "test_file.pdf"
    document.content = "This is the document content."
    return document
|  | ||||
|  | ||||
def test_stream_chat_with_one_document_full_content(mock_document):
    """Chatting about one document streams the query engine's chunks unchanged."""
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
        patch(
            "paperless_ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
    ):
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # One real node whose metadata points at the mock document.
        document_node = TextNode(
            text="This is node content.",
            metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
        )
        index = MagicMock()
        index.docstore.docs.values.return_value = [document_node]
        mock_load_index.return_value = index

        response_stream = MagicMock(response_gen=iter(["chunk1", "chunk2"]))
        query_engine = MagicMock()
        query_engine.query.return_value = response_stream
        mock_query_engine_cls.return_value = query_engine

        streamed = list(stream_chat_with_documents("What is this?", [mock_document]))

        assert streamed == ["chunk1", "chunk2"]
|  | ||||
|  | ||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
    """Chatting about two documents streams the query engine's chunks unchanged."""
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
        patch(
            "paperless_ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
        patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
    ):
        # AIClient and its LLM are mocks.
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # Two real TextNodes, one per document id.
        first_node = TextNode(
            text="Content for doc 1.",
            metadata={"document_id": "1", "title": "Document 1"},
        )
        second_node = TextNode(
            text="Content for doc 2.",
            metadata={"document_id": "2", "title": "Document 2"},
        )
        index = MagicMock()
        index.docstore.docs.values.return_value = [first_node, second_node]
        mock_load_index.return_value = index

        # ``as_retriever`` hands back a retriever that returns both nodes.
        retriever = MagicMock()
        retriever.retrieve.return_value = [first_node, second_node]
        mock_as_retriever.return_value = retriever

        # The query engine answers with a two-chunk response stream.
        response_stream = MagicMock(response_gen=iter(["chunk1", "chunk2"]))
        query_engine = MagicMock()
        query_engine.query.return_value = response_stream
        mock_query_engine_cls.return_value = query_engine

        # Stand-in documents carrying only a primary key.
        documents = [MagicMock(pk=1), MagicMock(pk=2)]

        streamed = list(stream_chat_with_documents("What's up?", documents))

        assert streamed == ["chunk1", "chunk2"]
|  | ||||
|  | ||||
def test_stream_chat_no_matching_nodes():
    """With an empty docstore the chat yields a single apology message."""
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
    ):
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # The index holds no nodes at all.
        empty_index = MagicMock()
        empty_index.docstore.docs.values.return_value = []
        mock_load_index.return_value = empty_index

        documents = [MagicMock(pk=1)]
        streamed = list(stream_chat_with_documents("Any info?", documents))

        assert streamed == ["Sorry, I couldn't find any content to answer your question."]
							
								
								
									
										109
									
								
								src/paperless_ai/tests/test_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								src/paperless_ai/tests/test_client.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| from llama_index.core.llms import ChatMessage | ||||
| from llama_index.core.llms.llm import ToolSelection | ||||
|  | ||||
| from paperless_ai.client import AIClient | ||||
|  | ||||
|  | ||||
@pytest.fixture
def mock_ai_config():
    """Patch ``paperless_ai.client.AIConfig`` and yield the instance it returns."""
    fake_config = MagicMock()
    with patch("paperless_ai.client.AIConfig", return_value=fake_config):
        yield fake_config
|  | ||||
|  | ||||
@pytest.fixture
def mock_ollama_llm():
    """Patch ``paperless_ai.client.Ollama`` and yield the replacement mock."""
    ollama_patch = patch("paperless_ai.client.Ollama")
    with ollama_patch as ollama_cls:
        yield ollama_cls
|  | ||||
|  | ||||
@pytest.fixture
def mock_openai_llm():
    """Patch ``paperless_ai.client.OpenAI`` and yield the replacement mock."""
    openai_patch = patch("paperless_ai.client.OpenAI")
    with openai_patch as openai_cls:
        yield openai_cls
|  | ||||
|  | ||||
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    """The ``ollama`` backend builds an ``Ollama`` LLM from model and endpoint."""
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_endpoint = "http://test-url"

    client = AIClient()

    expected_kwargs = {
        "model": "test_model",
        "base_url": "http://test-url",
        "request_timeout": 120,
    }
    mock_ollama_llm.assert_called_once_with(**expected_kwargs)
    assert client.llm == mock_ollama_llm.return_value
|  | ||||
|  | ||||
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    """The ``openai`` backend builds an ``OpenAI`` LLM from model and API key."""
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_api_key = "test_api_key"

    client = AIClient()

    expected_kwargs = {
        "model": "test_model",
        "api_key": "test_api_key",
    }
    mock_openai_llm.assert_called_once_with(**expected_kwargs)
    assert client.llm == mock_openai_llm.return_value
|  | ||||
|  | ||||
def test_get_llm_unsupported_backend(mock_ai_config):
    """Constructing ``AIClient`` with an unknown backend raises ``ValueError``."""
    mock_ai_config.llm_backend = "unsupported"

    expected_message = "Unsupported LLM backend: unsupported"
    with pytest.raises(ValueError, match=expected_message):
        AIClient()
|  | ||||
|  | ||||
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    """``run_llm_query`` returns the fields carried by the LLM's tool call.

    The mocked LLM reports a single ``DocumentClassifierSchema`` tool call;
    the test checks that its ``title`` is present in the returned result.
    """
    mock_ai_config.configure_mock(
        llm_backend="ollama",
        llm_model="test_model",
        llm_endpoint="http://test-url",
    )

    classifier_kwargs = {
        "title": "Test Title",
        "tags": ["test", "document"],
        "correspondents": ["John Doe"],
        "document_types": ["report"],
        "storage_paths": ["Reports"],
        "dates": ["2023-01-01"],
    }
    fake_tool_call = ToolSelection(
        tool_id="call_test",
        tool_name="DocumentClassifierSchema",
        tool_kwargs=classifier_kwargs,
    )

    # The instance AIClient will receive when it calls the patched Ollama class.
    llm = mock_ollama_llm.return_value
    llm.chat_with_tools.return_value = MagicMock()
    llm.get_tool_calls_from_response.return_value = [fake_tool_call]

    ai_client = AIClient()
    query_result = ai_client.run_llm_query("test_prompt")

    assert query_result["title"] == "Test Title"
|  | ||||
|  | ||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
    """``run_chat`` passes the messages to ``llm.chat`` and returns its result."""
    mock_ai_config.configure_mock(
        llm_backend="ollama",
        llm_model="test_model",
        llm_endpoint="http://test-url",
    )

    # The instance AIClient will receive when it calls the patched Ollama class.
    llm = mock_ollama_llm.return_value
    llm.chat.return_value = "test_chat_result"

    ai_client = AIClient()
    conversation = [ChatMessage(role="user", content="Hello")]
    chat_result = ai_client.run_chat(conversation)

    llm.chat.assert_called_once_with(conversation)
    assert chat_result == "test_chat_result"
							
								
								
									
										169
									
								
								src/paperless_ai/tests/test_embedding.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								src/paperless_ai/tests/test_embedding.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,169 @@ | ||||
| import json | ||||
| from unittest.mock import MagicMock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| import paperless_ai.embedding as embedding | ||||
| from documents.models import Document | ||||
| from paperless.models import LLMEmbeddingBackend | ||||
| from paperless_ai.embedding import build_llm_index_text | ||||
| from paperless_ai.embedding import get_embedding_dim | ||||
| from paperless_ai.embedding import get_embedding_model | ||||
|  | ||||
|  | ||||
@pytest.fixture
def mock_ai_config():
    """Patch ``AIConfig`` as seen by ``paperless_ai.embedding``.

    Yields the mock *class*; tests configure the instance through
    ``mock_ai_config.return_value``.
    """
    config_patcher = patch("paperless_ai.embedding.AIConfig")
    with config_patcher as fake_config_cls:
        yield fake_config_cls
|  | ||||
|  | ||||
@pytest.fixture
def temp_llm_index_dir(tmp_path):
    """Point ``settings.LLM_INDEX_DIR`` at pytest's ``tmp_path`` for one test.

    Yields the temporary directory and puts the previous setting back once
    the test has finished.
    """
    settings = embedding.settings
    previous_index_dir = settings.LLM_INDEX_DIR
    settings.LLM_INDEX_DIR = tmp_path
    yield tmp_path
    settings.LLM_INDEX_DIR = previous_index_dir
|  | ||||
|  | ||||
@pytest.fixture
def mock_document():
    """Build a ``Document``-spec'd mock populated with sample metadata.

    The mock carries scalar fields, two tags, a document type, a
    correspondent and two custom-field instances.
    """

    def _named_mock(name):
        # ``name`` is assigned after construction: passing it to the
        # MagicMock constructor would name the mock itself instead of
        # setting a ``.name`` attribute.
        named = MagicMock()
        named.name = name
        return named

    def _custom_field(field_name, value):
        instance = MagicMock(__str__=lambda x: value)
        instance.field = _named_mock(field_name)
        instance.value = value
        return instance

    doc = MagicMock(spec=Document)

    scalar_fields = {
        "title": "Test Title",
        "filename": "test_file.pdf",
        "created": "2023-01-01",
        "added": "2023-01-02",
        "modified": "2023-01-03",
        "archive_serial_number": "12345",
        "content": "This is the document content.",
    }
    for attribute, value in scalar_fields.items():
        setattr(doc, attribute, value)

    tags = [_named_mock("Tag1"), _named_mock("Tag2")]
    doc.tags.all = MagicMock(return_value=tags)

    doc.document_type = _named_mock("Invoice")
    doc.correspondent = _named_mock("Test Correspondent")

    custom_fields = [
        _custom_field("Field1", "Value1"),
        _custom_field("Field2", "Value2"),
    ]
    doc.custom_fields.all = MagicMock(return_value=custom_fields)

    return doc
|  | ||||
|  | ||||
def test_get_embedding_model_openai(mock_ai_config):
    """The OpenAI backend builds ``OpenAIEmbedding`` with the model and API key."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"

    openai_patcher = patch("paperless_ai.embedding.OpenAIEmbedding")
    with openai_patcher as fake_embedding_cls:
        embedding_model = get_embedding_model()
        fake_embedding_cls.assert_called_once_with(
            model="text-embedding-3-small",
            api_key="test_api_key",
        )
        assert embedding_model == fake_embedding_cls.return_value
|  | ||||
|  | ||||
def test_get_embedding_model_huggingface(mock_ai_config):
    """The HuggingFace backend builds ``HuggingFaceEmbedding`` with ``model_name``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
    config.llm_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"

    huggingface_patcher = patch("paperless_ai.embedding.HuggingFaceEmbedding")
    with huggingface_patcher as fake_embedding_cls:
        embedding_model = get_embedding_model()
        fake_embedding_cls.assert_called_once_with(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
        )
        assert embedding_model == fake_embedding_cls.return_value
|  | ||||
|  | ||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
    """An unrecognised embedding backend makes ``get_embedding_model`` raise."""
    mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"

    expected_message = "Unsupported embedding backend: INVALID_BACKEND"
    with pytest.raises(ValueError, match=expected_message):
        get_embedding_model()
|  | ||||
|  | ||||
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
    """Without a ``meta.json``, the dimension is probed from the model and saved.

    A stub model returning 7-element vectors stands in for the real one. The
    test expects ``dim == 7`` and a ``meta.json`` recording that dimension
    under the model name "text-embedding-3-small" (the configured model is
    ``None`` here).
    """
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai"
    config.llm_embedding_model = None

    class SevenDimEmbedding:
        def get_text_embedding(self, text):
            return [0.0] * 7

    model_patcher = patch(
        "paperless_ai.embedding.get_embedding_model",
        return_value=SevenDimEmbedding(),
    )
    with model_patcher as fake_get_model:
        dim = get_embedding_dim()
        fake_get_model.assert_called_once()

    assert dim == 7
    meta_file = temp_llm_index_dir / "meta.json"
    saved_meta = json.loads(meta_file.read_text())
    assert saved_meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
|  | ||||
|  | ||||
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
    """An existing ``meta.json`` for the same model is used without loading the model."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai"
    config.llm_embedding_model = None

    existing_meta = {"embedding_model": "text-embedding-3-small", "dim": 11}
    meta_file = temp_llm_index_dir / "meta.json"
    meta_file.write_text(json.dumps(existing_meta))

    model_patcher = patch("paperless_ai.embedding.get_embedding_model")
    with model_patcher as fake_get_model:
        dim = get_embedding_dim()
        assert dim == 11
        fake_get_model.assert_not_called()
|  | ||||
|  | ||||
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
    """A ``meta.json`` naming a different model makes ``get_embedding_dim`` raise."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai"
    config.llm_embedding_model = None

    stale_meta = {"embedding_model": "old", "dim": 11}
    meta_file = temp_llm_index_dir / "meta.json"
    meta_file.write_text(json.dumps(stale_meta))

    expected_message = "Embedding model changed from old to text-embedding-3-small"
    with pytest.raises(RuntimeError, match=expected_message):
        get_embedding_dim()
|  | ||||
|  | ||||
def test_build_llm_index_text(mock_document):
    """``build_llm_index_text`` includes every metadata section in its output.

    ``Note.objects.filter`` is patched to return two notes so that the test
    does not query the database for them.
    """
    expected_fragments = (
        "Title: Test Title",
        "Filename: test_file.pdf",
        "Created: 2023-01-01",
        "Tags: Tag1, Tag2",
        "Document Type: Invoice",
        "Correspondent: Test Correspondent",
        "Notes: Note1,Note2",
        "Content:\n\nThis is the document content.",
        "Custom Field - Field1: Value1\nCustom Field - Field2: Value2",
    )

    notes_patcher = patch("documents.models.Note.objects.filter")
    with notes_patcher as fake_notes_filter:
        fake_notes_filter.return_value = [
            MagicMock(note="Note1"),
            MagicMock(note="Note2"),
        ]

        index_text = build_llm_index_text(mock_document)

        for fragment in expected_fragments:
            assert fragment in index_text
							
								
								
									
										86
									
								
								src/paperless_ai/tests/test_matching.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								src/paperless_ai/tests/test_matching.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| from unittest.mock import patch | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from documents.models import Correspondent | ||||
| from documents.models import DocumentType | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| from paperless_ai.matching import extract_unmatched_names | ||||
| from paperless_ai.matching import match_correspondents_by_name | ||||
| from paperless_ai.matching import match_document_types_by_name | ||||
| from paperless_ai.matching import match_storage_paths_by_name | ||||
| from paperless_ai.matching import match_tags_by_name | ||||
|  | ||||
|  | ||||
class TestAIMatching(TestCase):
    """Tests for the name-matching helpers in ``paperless_ai.matching``.

    Each matcher test patches ``get_objects_for_user_owner_aware`` so that it
    returns the full queryset of the relevant model, regardless of the user.
    """

    def setUp(self):
        """Create two objects of every model the matchers operate on."""
        # Tags
        self.tag1 = Tag.objects.create(name="Test Tag 1")
        self.tag2 = Tag.objects.create(name="Test Tag 2")

        # Correspondents
        self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
        self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")

        # Document types
        self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
        self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")

        # Storage paths
        self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
        self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_by_name(self, mock_get_objects):
        """Only the tag whose name exists is returned."""
        mock_get_objects.return_value = Tag.objects.all()
        candidate_names = ["Test Tag 1", "Nonexistent Tag"]

        matched = match_tags_by_name(candidate_names, user=None)

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].name, "Test Tag 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_correspondents_by_name(self, mock_get_objects):
        """Only the correspondent whose name exists is returned."""
        mock_get_objects.return_value = Correspondent.objects.all()
        candidate_names = ["Test Correspondent 1", "Nonexistent Correspondent"]

        matched = match_correspondents_by_name(candidate_names, user=None)

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].name, "Test Correspondent 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_document_types_by_name(self, mock_get_objects):
        """Only the document type whose name exists is returned."""
        mock_get_objects.return_value = DocumentType.objects.all()
        candidate_names = ["Test Document Type 1", "Nonexistent Document Type"]

        matched = match_document_types_by_name(candidate_names, user=None)

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].name, "Test Document Type 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_storage_paths_by_name(self, mock_get_objects):
        """Only the storage path whose name exists is returned."""
        mock_get_objects.return_value = StoragePath.objects.all()
        candidate_names = ["Test Storage Path 1", "Nonexistent Storage Path"]

        matched = match_storage_paths_by_name(candidate_names, user=None)

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].name, "Test Storage Path 1")

    def test_extract_unmatched_names(self):
        """Names without a corresponding matched object are reported as unmatched."""
        suggested_names = ["Test Tag 1", "Nonexistent Tag"]
        already_matched = [self.tag1]

        leftover = extract_unmatched_names(suggested_names, already_matched)

        self.assertEqual(leftover, ["Nonexistent Tag"])

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
        """``None``, empty and whitespace-only names match nothing."""
        mock_get_objects.return_value = Tag.objects.all()
        blank_names = [None, "", "   "]

        matched = match_tags_by_name(blank_names, user=None)

        self.assertEqual(matched, [])

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
        """Slightly misspelled names resolve to the existing tags, in input order."""
        mock_get_objects.return_value = Tag.objects.all()
        misspelled_names = ["Test Taag 1", "Teest Tag 2"]

        matched = match_tags_by_name(misspelled_names, user=None)

        self.assertEqual(len(matched), 2)
        self.assertEqual(matched[0].name, "Test Tag 1")
        self.assertEqual(matched[1].name, "Test Tag 2")
		Reference in New Issue
	
	Block a user