mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-18 17:34:39 -05:00
Compare commits
82 Commits
87e5d82c46
...
3fa4fad811
Author | SHA1 | Date | |
---|---|---|---|
![]() |
3fa4fad811 | ||
![]() |
8b0e97d3ec | ||
![]() |
2ba5c0c391 | ||
![]() |
1367453d2c | ||
![]() |
f92982217f | ||
![]() |
701b254b86 | ||
![]() |
a092dd3ab0 | ||
![]() |
9fdfdf9b48 | ||
![]() |
81bacec484 | ||
![]() |
5b13e3594e | ||
![]() |
3f051fb27b | ||
![]() |
d1294c4183 | ||
![]() |
ca87f43262 | ||
![]() |
c20273f46f | ||
![]() |
48d0315cc4 | ||
![]() |
553bfeb9fc | ||
![]() |
1b9d775508 | ||
![]() |
c75ec8dfc3 | ||
![]() |
f8890bd14a | ||
![]() |
8fc77d92a9 | ||
![]() |
92837f86c0 | ||
![]() |
4b43e39fbb | ||
![]() |
51a89b0cde | ||
![]() |
6536a9c874 | ||
![]() |
1e04ce1e57 | ||
![]() |
199f328999 | ||
![]() |
495a6fe2fe | ||
![]() |
03f183712b | ||
![]() |
f8c6989eaf | ||
![]() |
5c0903b6da | ||
![]() |
d49982a5ba | ||
![]() |
db0dc337bd | ||
![]() |
fd1554fb96 | ||
![]() |
404dbae431 | ||
![]() |
0f1aee3a3c | ||
![]() |
3f8dbc630a | ||
![]() |
5180651400 | ||
![]() |
f0ac80a08a | ||
![]() |
d439b58aaf | ||
![]() |
37745e846d | ||
![]() |
9c00b48dc7 | ||
![]() |
bdaae882a6 | ||
![]() |
37e1290e00 | ||
![]() |
183d369350 | ||
![]() |
d431f1af15 | ||
![]() |
b4ea2b7521 | ||
![]() |
46df529c3a | ||
![]() |
3ed877b301 | ||
![]() |
1e79795fbf | ||
![]() |
cee5a3b62d | ||
![]() |
0807e32278 | ||
![]() |
6bdf396083 | ||
![]() |
5c88a7207d | ||
![]() |
3e8a9958a5 | ||
![]() |
b1b2d03644 | ||
![]() |
2c4b8c9afe | ||
![]() |
0a19a5500c | ||
![]() |
edeb9a7534 | ||
![]() |
06b0817cc2 | ||
![]() |
3051ea5fbb | ||
![]() |
c3175c2cd6 | ||
![]() |
77796ac3f4 | ||
![]() |
45451ac110 | ||
![]() |
2e34223ead | ||
![]() |
c416bca3df | ||
![]() |
b8dc6665dc | ||
![]() |
2fe901cd8d | ||
![]() |
199834ee8f | ||
![]() |
5849c6fff6 | ||
![]() |
de6e43738c | ||
![]() |
4e23a072d4 | ||
![]() |
b600e27f90 | ||
![]() |
2c83e0d07f | ||
![]() |
be327c1b72 | ||
![]() |
4013f15e51 | ||
![]() |
a53014ab89 | ||
![]() |
214420ba1a | ||
![]() |
453dc5062b | ||
![]() |
1b6a4a3be4 | ||
![]() |
3d33149f03 | ||
![]() |
455ff068cb | ||
![]() |
0b06ded65a |
@ -11,6 +11,7 @@ for command in decrypt_documents \
|
||||
mail_fetcher \
|
||||
document_create_classifier \
|
||||
document_index \
|
||||
document_llmindex \
|
||||
document_renamer \
|
||||
document_retagger \
|
||||
document_thumbnails \
|
||||
|
14
docker/rootfs/usr/local/bin/document_llmindex
Executable file
14
docker/rootfs/usr/local/bin/document_llmindex
Executable file
@ -0,0 +1,14 @@
|
||||
#!/command/with-contenv /usr/bin/bash
|
||||
# shellcheck shell=bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "${PAPERLESS_SRC_DIR}"
|
||||
|
||||
if [[ $(id -u) == 0 ]]; then
|
||||
s6-setuidgid paperless python3 manage.py document_llmindex "$@"
|
||||
elif [[ $(id -un) == "paperless" ]]; then
|
||||
python3 manage.py document_llmindex "$@"
|
||||
else
|
||||
echo "Unknown user."
|
||||
fi
|
@ -1708,3 +1708,67 @@ password. All of these options come from their similarly-named [Django settings]
|
||||
#### [`PAPERLESS_EMAIL_USE_SSL=<bool>`](#PAPERLESS_EMAIL_USE_SSL) {#PAPERLESS_EMAIL_USE_SSL}
|
||||
|
||||
: Defaults to false.
|
||||
|
||||
## AI {#ai}
|
||||
|
||||
#### [`PAPERLESS_ENABLE_AI=<bool>`](#PAPERLESS_ENABLE_AI) {#PAPERLESS_ENABLE_AI}
|
||||
|
||||
: Enables the AI features in Paperless. This includes the AI-based
|
||||
suggestions. This setting is required to be set to true in order to use the AI features.
|
||||
|
||||
Defaults to false.
|
||||
|
||||
#### [`PAPERLESS_LLM_EMBEDDING_BACKEND=<str>`](#PAPERLESS_LLM_EMBEDDING_BACKEND) {#PAPERLESS_LLM_EMBEDDING_BACKEND}
|
||||
|
||||
: The embedding backend to use for RAG. This can be either "openai" or "huggingface".
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_LLM_EMBEDDING_MODEL=<str>`](#PAPERLESS_LLM_EMBEDDING_MODEL) {#PAPERLESS_LLM_EMBEDDING_MODEL}
|
||||
|
||||
: The model to use for the embedding backend for RAG. This can be set to any of the embedding models supported by the current embedding backend. If not supplied, defaults to "text-embedding-3-small" for OpenAI and "sentence-transformers/all-MiniLM-L6-v2" for Huggingface.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_AI_BACKEND=<str>`](#PAPERLESS_AI_BACKEND) {#PAPERLESS_AI_BACKEND}
|
||||
|
||||
: The AI backend to use. This can be either "openai" or "ollama". If set to "ollama", the AI
|
||||
features will be run locally on your machine. If set to "openai", the AI features will be run
|
||||
using the OpenAI API. This setting is required to be set to use the AI features.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
!!! note
|
||||
|
||||
The OpenAI API is a paid service. You will need to set up an OpenAI account and
|
||||
will be charged for usage incurred by Paperless-ngx features and your document data
|
||||
will (of course) be sent to the OpenAI API. Paperless-ngx does not endorse the use of the
|
||||
OpenAI API in any way.
|
||||
|
||||
Refer to the OpenAI terms of service, and use at your own risk.
|
||||
|
||||
#### [`PAPERLESS_LLM_MODEL=<str>`](#PAPERLESS_LLM_MODEL) {#PAPERLESS_LLM_MODEL}
|
||||
|
||||
: The model to use for the AI backend, i.e. "gpt-3.5-turbo", "gpt-4" or any of the models supported by the
|
||||
current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "llama3" for Ollama.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_LLM_API_KEY=<str>`](#PAPERLESS_LLM_API_KEY) {#PAPERLESS_LLM_API_KEY}
|
||||
|
||||
: The API key to use for the AI backend. This is required for the OpenAI backend only.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_LLM_URL=<str>`](#PAPERLESS_LLM_URL) {#PAPERLESS_LLM_URL}
|
||||
|
||||
: The URL to use for the AI backend. This is required for the Ollama backend only.
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_LLM_INDEX_TASK_CRON) {#PAPERLESS_LLM_INDEX_TASK_CRON}
|
||||
|
||||
: Configures the schedule to update the AI embeddings for all documents. Only performed if
|
||||
AI is enabled and the LLM embedding backend is set.
|
||||
|
||||
Defaults to `10 2 * * *`, once per day.
|
||||
|
@ -25,11 +25,12 @@ physical documents into a searchable online archive so you can keep, well, _less
|
||||
## Features
|
||||
|
||||
- **Organize and index** your scanned documents with tags, correspondents, types, and more.
|
||||
- _Your_ data is stored locally on _your_ server and is never transmitted or shared in any way.
|
||||
- _Your_ data is stored locally on _your_ server and is never transmitted or shared in any way, unless you explicitly choose to do so.
|
||||
- Performs **OCR** on your documents, adding searchable and selectable text, even to documents scanned with only images.
|
||||
- Utilizes the open-source Tesseract engine to recognize more than 100 languages.
|
||||
- Documents are saved as PDF/A format which is designed for long term storage, alongside the unaltered originals.
|
||||
- Uses machine-learning to automatically add tags, correspondents and document types to your documents.
|
||||
- **New**: Paperless-ngx can now leverage AI (Large Language Models or LLMs) for document suggestions. This is an optional feature that can be enabled (and is disabled by default).
|
||||
- Supports PDF documents, images, plain text files, Office documents (Word, Excel, Powerpoint, and LibreOffice equivalents)[^1] and more.
|
||||
- Paperless stores your documents plain on disk. Filenames and folders are managed by paperless and their format can be configured freely with different configurations assigned to different documents.
|
||||
- **Beautiful, modern web application** that features:
|
||||
|
@ -261,6 +261,22 @@ Once setup, navigating to the email settings page in Paperless-ngx will allow yo
|
||||
You can also submit a document using the REST API, see [POSTing documents](api.md#file-uploads)
|
||||
for details.
|
||||
|
||||
## Document Suggestions
|
||||
|
||||
Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user.
|
||||
|
||||
## AI Features
|
||||
|
||||
Paperless-ngx includes several features that use AI to enhance the document management experience. These features are optional and can be enabled or disabled in the settings. If you are using the AI features, you may want to also enable the "LLM index" feature, which supports Retrieval-Augmented Generation (RAG) designed to improve the quality of AI responses. The LLM index feature is not enabled by default and requires additional configuration.
|
||||
|
||||
### Document Chat
|
||||
|
||||
Paperless-ngx can use an AI LLM model to answer questions about a document or across multiple documents. Again, this feature works best when RAG is enabled. The chat feature is available in the upper app toolbar and will switch between chatting across multiple documents or a single document based on the current view.
|
||||
|
||||
### AI-Enhanced Suggestions
|
||||
|
||||
If enabled, Paperless-ngx can use an AI LLM model to suggest document titles, dates, tags, correspondents and document types for documents. This feature will always be "opt-in" and does not disable the existing classifier-based suggestion system. Currently, both remote (via the OpenAI API) and local (via Ollama) models are supported, see [configuration](configuration.md#ai) for details.
|
||||
|
||||
## Sharing documents from Paperless-ngx
|
||||
|
||||
Paperless-ngx supports sharing documents with other users by assigning them [permissions](#object-permissions)
|
||||
|
@ -39,6 +39,7 @@ dependencies = [
|
||||
"drf-spectacular~=0.28",
|
||||
"drf-spectacular-sidecar~=2025.4.1",
|
||||
"drf-writable-nested~=0.7.1",
|
||||
"faiss-cpu>=1.10",
|
||||
"filelock~=3.18.0",
|
||||
"flower~=2.0.1",
|
||||
"gotenberg-client~=0.10.0",
|
||||
@ -47,8 +48,15 @@ dependencies = [
|
||||
"inotifyrecursive~=0.3",
|
||||
"jinja2~=3.1.5",
|
||||
"langdetect~=1.0.9",
|
||||
"llama-index-core>=0.12.33.post1",
|
||||
"llama-index-embeddings-huggingface>=0.5.3",
|
||||
"llama-index-embeddings-openai>=0.3.1",
|
||||
"llama-index-llms-ollama>=0.5.4",
|
||||
"llama-index-llms-openai>=0.3.38",
|
||||
"llama-index-vector-stores-faiss>=0.3",
|
||||
"nltk~=3.9.1",
|
||||
"ocrmypdf~=16.10.0",
|
||||
"openai>=1.76",
|
||||
"pathvalidate~=3.2.3",
|
||||
"pdf2image~=1.17.0",
|
||||
"python-dateutil~=2.9.0",
|
||||
@ -60,6 +68,7 @@ dependencies = [
|
||||
"rapidfuzz~=3.13.0",
|
||||
"redis[hiredis]~=5.2.1",
|
||||
"scikit-learn~=1.6.1",
|
||||
"sentence-transformers>=4.1",
|
||||
"setproctitle~=1.3.4",
|
||||
"tika-client~=0.9.0",
|
||||
"tqdm~=4.67.1",
|
||||
@ -240,6 +249,7 @@ testpaths = [
|
||||
"src/paperless_mail/tests/",
|
||||
"src/paperless_tesseract/tests/",
|
||||
"src/paperless_tika/tests",
|
||||
"src/paperless_ai/tests",
|
||||
]
|
||||
addopts = [
|
||||
"--pythonwarnings=all",
|
||||
|
@ -35,6 +35,7 @@
|
||||
@case (ConfigOptionType.String) { <pngx-input-text [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-text> }
|
||||
@case (ConfigOptionType.JSON) { <pngx-input-text [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-text> }
|
||||
@case (ConfigOptionType.File) { <pngx-input-file [formControlName]="option.key" (upload)="uploadFile($event, option.key)" [error]="errors[option.key]"></pngx-input-file> }
|
||||
@case (ConfigOptionType.Password) { <pngx-input-password [formControlName]="option.key" [error]="errors[option.key]"></pngx-input-password> }
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
|
@ -29,6 +29,7 @@ import { SettingsService } from 'src/app/services/settings.service'
|
||||
import { ToastService } from 'src/app/services/toast.service'
|
||||
import { FileComponent } from '../../common/input/file/file.component'
|
||||
import { NumberComponent } from '../../common/input/number/number.component'
|
||||
import { PasswordComponent } from '../../common/input/password/password.component'
|
||||
import { SelectComponent } from '../../common/input/select/select.component'
|
||||
import { SwitchComponent } from '../../common/input/switch/switch.component'
|
||||
import { TextComponent } from '../../common/input/text/text.component'
|
||||
@ -46,6 +47,7 @@ import { LoadingComponentWithPermissions } from '../../loading-component/loading
|
||||
TextComponent,
|
||||
NumberComponent,
|
||||
FileComponent,
|
||||
PasswordComponent,
|
||||
AsyncPipe,
|
||||
NgbNavModule,
|
||||
FormsModule,
|
||||
|
@ -314,6 +314,9 @@ describe('SettingsComponent', () => {
|
||||
sanity_check_status: SystemStatusItemStatus.ERROR,
|
||||
sanity_check_last_run: new Date().toISOString(),
|
||||
sanity_check_error: 'Error running sanity check.',
|
||||
llmindex_status: SystemStatusItemStatus.DISABLED,
|
||||
llmindex_last_modified: new Date().toISOString(),
|
||||
llmindex_error: null,
|
||||
},
|
||||
}
|
||||
jest.spyOn(systemStatusService, 'get').mockReturnValue(of(status))
|
||||
|
@ -30,6 +30,9 @@
|
||||
</div>
|
||||
</div>
|
||||
<ul ngbNav class="order-sm-3">
|
||||
@if (aiEnabled) {
|
||||
<pngx-chat></pngx-chat>
|
||||
}
|
||||
<pngx-toasts-dropdown></pngx-toasts-dropdown>
|
||||
<li ngbDropdown class="nav-item dropdown">
|
||||
<button class="btn ps-1 border-0" id="userDropdown" ngbDropdownToggle>
|
||||
|
@ -44,6 +44,7 @@ import { SettingsService } from 'src/app/services/settings.service'
|
||||
import { TasksService } from 'src/app/services/tasks.service'
|
||||
import { ToastService } from 'src/app/services/toast.service'
|
||||
import { environment } from 'src/environments/environment'
|
||||
import { ChatComponent } from '../chat/chat/chat.component'
|
||||
import { ProfileEditDialogComponent } from '../common/profile-edit-dialog/profile-edit-dialog.component'
|
||||
import { DocumentDetailComponent } from '../document-detail/document-detail.component'
|
||||
import { ComponentWithPermissions } from '../with-permissions/with-permissions.component'
|
||||
@ -59,6 +60,7 @@ import { ToastsDropdownComponent } from './toasts-dropdown/toasts-dropdown.compo
|
||||
DocumentTitlePipe,
|
||||
IfPermissionsDirective,
|
||||
ToastsDropdownComponent,
|
||||
ChatComponent,
|
||||
RouterModule,
|
||||
NgClass,
|
||||
NgbDropdownModule,
|
||||
@ -168,6 +170,10 @@ export class AppFrameComponent
|
||||
})
|
||||
}
|
||||
|
||||
get aiEnabled(): boolean {
|
||||
return this.settingsService.get(SETTINGS_KEYS.AI_ENABLED)
|
||||
}
|
||||
|
||||
closeMenu() {
|
||||
this.isMenuCollapsed = true
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
|
||||
<li ngbDropdown class="nav-item" (openChange)="onOpenChange($event)">
|
||||
<li ngbDropdown class="nav-item mx-1" (openChange)="onOpenChange($event)">
|
||||
@if (toasts.length) {
|
||||
<span class="badge rounded-pill z-3 pe-none bg-secondary me-2 position-absolute top-0 left-0">{{ toasts.length }}</span>
|
||||
}
|
||||
|
35
src-ui/src/app/components/chat/chat/chat.component.html
Normal file
35
src-ui/src/app/components/chat/chat/chat.component.html
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
<li ngbDropdown class="nav-item me-n2" (openChange)="onOpenChange($event)">
|
||||
<button class="btn border-0" id="chatDropdown" ngbDropdownToggle>
|
||||
<i-bs width="1.3em" height="1.3em" name="chatSquareDots"></i-bs>
|
||||
</button>
|
||||
<div ngbDropdownMenu class="dropdown-menu-end shadow p-3" aria-labelledby="chatDropdown">
|
||||
<div class="chat-container bg-light p-2">
|
||||
<div class="chat-messages font-monospace small">
|
||||
@for (message of messages; track message) {
|
||||
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
|
||||
<span class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
|
||||
{{ message.content }}
|
||||
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
|
||||
</span>
|
||||
</div>
|
||||
}
|
||||
<div #scrollAnchor></div>
|
||||
</div>
|
||||
|
||||
<form class="chat-input">
|
||||
<div class="input-group">
|
||||
<input
|
||||
#chatInput
|
||||
class="form-control form-control-sm" name="chatInput" type="text"
|
||||
[placeholder]="placeholder"
|
||||
[disabled]="loading"
|
||||
[(ngModel)]="input"
|
||||
(keydown)="searchInputKeyDown($event)"
|
||||
/>
|
||||
<button class="btn btn-sm btn-secondary" type="button" (click)="sendMessage()" [disabled]="loading">Send</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
37
src-ui/src/app/components/chat/chat/chat.component.scss
Normal file
37
src-ui/src/app/components/chat/chat/chat.component.scss
Normal file
@ -0,0 +1,37 @@
|
||||
.dropdown-menu {
|
||||
width: var(--pngx-toast-max-width);
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
max-height: 350px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.dropdown-toggle::after {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.dropdown-item {
|
||||
white-space: initial;
|
||||
}
|
||||
|
||||
@media screen and (max-width: 400px) {
|
||||
:host ::ng-deep .dropdown-menu-end {
|
||||
right: -3rem;
|
||||
}
|
||||
}
|
||||
|
||||
.blinking-cursor {
|
||||
font-weight: bold;
|
||||
font-size: 1.2em;
|
||||
animation: blink 1s step-end infinite;
|
||||
}
|
||||
|
||||
@keyframes blink {
|
||||
from, to {
|
||||
opacity: 0;
|
||||
}
|
||||
50% {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
132
src-ui/src/app/components/chat/chat/chat.component.spec.ts
Normal file
132
src-ui/src/app/components/chat/chat/chat.component.spec.ts
Normal file
@ -0,0 +1,132 @@
|
||||
import { provideHttpClient, withInterceptorsFromDi } from '@angular/common/http'
|
||||
import { provideHttpClientTesting } from '@angular/common/http/testing'
|
||||
import { ElementRef } from '@angular/core'
|
||||
import { ComponentFixture, TestBed } from '@angular/core/testing'
|
||||
import { NavigationEnd, Router } from '@angular/router'
|
||||
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { Subject } from 'rxjs'
|
||||
import { ChatService } from 'src/app/services/chat.service'
|
||||
import { ChatComponent } from './chat.component'
|
||||
|
||||
describe('ChatComponent', () => {
|
||||
let component: ChatComponent
|
||||
let fixture: ComponentFixture<ChatComponent>
|
||||
let chatService: ChatService
|
||||
let router: Router
|
||||
let routerEvents$: Subject<NavigationEnd>
|
||||
let mockStream$: Subject<string>
|
||||
|
||||
beforeEach(async () => {
|
||||
TestBed.configureTestingModule({
|
||||
imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent],
|
||||
providers: [
|
||||
provideHttpClient(withInterceptorsFromDi()),
|
||||
provideHttpClientTesting(),
|
||||
],
|
||||
}).compileComponents()
|
||||
|
||||
fixture = TestBed.createComponent(ChatComponent)
|
||||
router = TestBed.inject(Router)
|
||||
routerEvents$ = new Subject<any>()
|
||||
jest
|
||||
.spyOn(router, 'events', 'get')
|
||||
.mockReturnValue(routerEvents$.asObservable())
|
||||
chatService = TestBed.inject(ChatService)
|
||||
mockStream$ = new Subject<string>()
|
||||
jest
|
||||
.spyOn(chatService, 'streamChat')
|
||||
.mockReturnValue(mockStream$.asObservable())
|
||||
component = fixture.componentInstance
|
||||
|
||||
jest.useFakeTimers()
|
||||
|
||||
fixture.detectChanges()
|
||||
|
||||
component.scrollAnchor.nativeElement.scrollIntoView = jest.fn()
|
||||
})
|
||||
|
||||
it('should update documentId on initialization', () => {
|
||||
jest.spyOn(router, 'url', 'get').mockReturnValue('/documents/123')
|
||||
component.ngOnInit()
|
||||
expect(component.documentId).toBe(123)
|
||||
})
|
||||
|
||||
it('should update documentId on navigation', () => {
|
||||
component.ngOnInit()
|
||||
routerEvents$.next(new NavigationEnd(1, '/documents/456', '/documents/456'))
|
||||
expect(component.documentId).toBe(456)
|
||||
})
|
||||
|
||||
it('should return correct placeholder based on documentId', () => {
|
||||
component.documentId = 123
|
||||
expect(component.placeholder).toBe('Ask a question about this document...')
|
||||
component.documentId = undefined
|
||||
expect(component.placeholder).toBe('Ask a question about a document...')
|
||||
})
|
||||
|
||||
it('should send a message and handle streaming response', () => {
|
||||
component.input = 'Hello'
|
||||
component.sendMessage()
|
||||
|
||||
expect(component.messages.length).toBe(2)
|
||||
expect(component.messages[0].content).toBe('Hello')
|
||||
expect(component.loading).toBe(true)
|
||||
|
||||
mockStream$.next('Hi')
|
||||
expect(component.messages[1].content).toBe('H')
|
||||
mockStream$.next('Hi there')
|
||||
// advance time to process the typewriter effect
|
||||
jest.advanceTimersByTime(1000)
|
||||
expect(component.messages[1].content).toBe('Hi there')
|
||||
|
||||
mockStream$.complete()
|
||||
expect(component.loading).toBe(false)
|
||||
expect(component.messages[1].isStreaming).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle errors during streaming', () => {
|
||||
component.input = 'Hello'
|
||||
component.sendMessage()
|
||||
|
||||
mockStream$.error('Error')
|
||||
expect(component.messages[1].content).toContain(
|
||||
'⚠️ Error receiving response.'
|
||||
)
|
||||
expect(component.loading).toBe(false)
|
||||
})
|
||||
|
||||
it('should enqueue typewriter chunks correctly', () => {
|
||||
const message = { content: '', role: 'assistant', isStreaming: true }
|
||||
component.enqueueTypewriter(null, message as any) // coverage for null
|
||||
component.enqueueTypewriter('Hello', message as any)
|
||||
expect(component['typewriterBuffer'].length).toBe(4)
|
||||
})
|
||||
|
||||
it('should scroll to bottom after sending a message', () => {
|
||||
const scrollSpy = jest.spyOn(
|
||||
ChatComponent.prototype as any,
|
||||
'scrollToBottom'
|
||||
)
|
||||
component.input = 'Test'
|
||||
component.sendMessage()
|
||||
expect(scrollSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should focus chat input when dropdown is opened', () => {
|
||||
const focus = jest.fn()
|
||||
component.chatInput = {
|
||||
nativeElement: { focus: focus },
|
||||
} as unknown as ElementRef<HTMLInputElement>
|
||||
|
||||
component.onOpenChange(true)
|
||||
jest.advanceTimersByTime(15)
|
||||
expect(focus).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should send message on Enter key press', () => {
|
||||
jest.spyOn(component, 'sendMessage')
|
||||
const event = new KeyboardEvent('keydown', { key: 'Enter' })
|
||||
component.searchInputKeyDown(event)
|
||||
expect(component.sendMessage).toHaveBeenCalled()
|
||||
})
|
||||
})
|
144
src-ui/src/app/components/chat/chat/chat.component.ts
Normal file
144
src-ui/src/app/components/chat/chat/chat.component.ts
Normal file
@ -0,0 +1,144 @@
|
||||
import { Component, ElementRef, OnInit, ViewChild } from '@angular/core'
|
||||
import { FormsModule, ReactiveFormsModule } from '@angular/forms'
|
||||
import { NavigationEnd, Router } from '@angular/router'
|
||||
import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { filter, map } from 'rxjs'
|
||||
import { ChatMessage, ChatService } from 'src/app/services/chat.service'
|
||||
|
||||
@Component({
|
||||
selector: 'pngx-chat',
|
||||
imports: [
|
||||
FormsModule,
|
||||
ReactiveFormsModule,
|
||||
NgxBootstrapIconsModule,
|
||||
NgbDropdownModule,
|
||||
],
|
||||
templateUrl: './chat.component.html',
|
||||
styleUrl: './chat.component.scss',
|
||||
})
|
||||
export class ChatComponent implements OnInit {
|
||||
public messages: ChatMessage[] = []
|
||||
public loading = false
|
||||
public input: string = ''
|
||||
public documentId!: number
|
||||
|
||||
@ViewChild('scrollAnchor') scrollAnchor!: ElementRef<HTMLDivElement>
|
||||
@ViewChild('chatInput') chatInput!: ElementRef<HTMLInputElement>
|
||||
|
||||
private typewriterBuffer: string[] = []
|
||||
private typewriterActive = false
|
||||
|
||||
public get placeholder(): string {
|
||||
return this.documentId
|
||||
? $localize`Ask a question about this document...`
|
||||
: $localize`Ask a question about a document...`
|
||||
}
|
||||
|
||||
constructor(
|
||||
private chatService: ChatService,
|
||||
private router: Router
|
||||
) {}
|
||||
|
||||
ngOnInit(): void {
|
||||
this.updateDocumentId(this.router.url)
|
||||
this.router.events
|
||||
.pipe(
|
||||
filter((event) => event instanceof NavigationEnd),
|
||||
map((event) => (event as NavigationEnd).url)
|
||||
)
|
||||
.subscribe((url) => {
|
||||
console.log('URL changed:', url)
|
||||
|
||||
this.updateDocumentId(url)
|
||||
})
|
||||
}
|
||||
|
||||
private updateDocumentId(url: string): void {
|
||||
const docIdRe = url.match(/^\/documents\/(\d+)/)
|
||||
this.documentId = docIdRe ? +docIdRe[1] : undefined
|
||||
}
|
||||
|
||||
sendMessage(): void {
|
||||
if (!this.input.trim()) return
|
||||
|
||||
const userMessage: ChatMessage = { role: 'user', content: this.input }
|
||||
this.messages.push(userMessage)
|
||||
this.scrollToBottom()
|
||||
|
||||
const assistantMessage: ChatMessage = {
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
isStreaming: true,
|
||||
}
|
||||
this.messages.push(assistantMessage)
|
||||
this.loading = true
|
||||
|
||||
let lastPartialLength = 0
|
||||
|
||||
this.chatService.streamChat(this.documentId, this.input).subscribe({
|
||||
next: (chunk) => {
|
||||
const delta = chunk.substring(lastPartialLength)
|
||||
lastPartialLength = chunk.length
|
||||
this.enqueueTypewriter(delta, assistantMessage)
|
||||
},
|
||||
error: () => {
|
||||
assistantMessage.content += '\n\n⚠️ Error receiving response.'
|
||||
assistantMessage.isStreaming = false
|
||||
this.loading = false
|
||||
},
|
||||
complete: () => {
|
||||
assistantMessage.isStreaming = false
|
||||
this.loading = false
|
||||
this.scrollToBottom()
|
||||
},
|
||||
})
|
||||
|
||||
this.input = ''
|
||||
}
|
||||
|
||||
enqueueTypewriter(chunk: string, message: ChatMessage): void {
|
||||
if (!chunk) return
|
||||
|
||||
this.typewriterBuffer.push(...chunk.split(''))
|
||||
|
||||
if (!this.typewriterActive) {
|
||||
this.typewriterActive = true
|
||||
this.playTypewriter(message)
|
||||
}
|
||||
}
|
||||
|
||||
playTypewriter(message: ChatMessage): void {
|
||||
if (this.typewriterBuffer.length === 0) {
|
||||
this.typewriterActive = false
|
||||
return
|
||||
}
|
||||
|
||||
const nextChar = this.typewriterBuffer.shift()!
|
||||
message.content += nextChar
|
||||
this.scrollToBottom()
|
||||
|
||||
setTimeout(() => this.playTypewriter(message), 10) // 10ms per character
|
||||
}
|
||||
|
||||
private scrollToBottom(): void {
|
||||
setTimeout(() => {
|
||||
this.scrollAnchor?.nativeElement?.scrollIntoView({ behavior: 'smooth' })
|
||||
}, 50)
|
||||
}
|
||||
|
||||
public onOpenChange(open: boolean): void {
|
||||
if (open) {
|
||||
setTimeout(() => {
|
||||
this.chatInput.nativeElement.focus()
|
||||
}, 10)
|
||||
}
|
||||
}
|
||||
|
||||
public searchInputKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter') {
|
||||
event.preventDefault()
|
||||
this.sendMessage()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
<div ngbDropdown #fieldDropdown="ngbDropdown" (openChange)="onOpenClose($event)" [popperOptions]="popperOptions" placement="bottom-end">
|
||||
<button class="btn btn-sm btn-outline-primary" id="customFieldsDropdown" [disabled]="disabled" ngbDropdownToggle>
|
||||
<div ngbDropdown #fieldDropdown="ngbDropdown" (openChange)="onOpenClose($event)" [popperOptions]="popperOptions">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" id="customFieldsDropdown" [disabled]="disabled" ngbDropdownToggle>
|
||||
<i-bs name="ui-radios"></i-bs>
|
||||
<div class="d-none d-sm-inline"> <ng-container i18n>Custom Fields</ng-container></div>
|
||||
<div class="d-none d-lg-inline"> <ng-container i18n>Custom Fields</ng-container></div>
|
||||
</button>
|
||||
<div ngbDropdownMenu aria-labelledby="customFieldsDropdown" class="shadow custom-fields-dropdown">
|
||||
<div class="list-group list-group-flush" (keydown)="listKeyDown($event)">
|
||||
|
@ -1,17 +1,24 @@
|
||||
<div class="mb-3">
|
||||
<label class="form-label" [for]="inputId">{{title}}</label>
|
||||
<div class="input-group" [class.is-invalid]="error">
|
||||
<input #inputField [type]="showReveal && textVisible ? 'text' : 'password'" class="form-control" [class.is-invalid]="error" [id]="inputId" [(ngModel)]="value" (focus)="onFocus()" (focusout)="onFocusOut()" (change)="onChange(value)" [disabled]="disabled" [autocomplete]="autocomplete">
|
||||
@if (showReveal) {
|
||||
<button type="button" class="btn btn-outline-secondary" (click)="toggleVisibility()" i18n-title title="Show password" [disabled]="disabled || disableRevealToggle">
|
||||
<i-bs name="eye"></i-bs>
|
||||
</button>
|
||||
<div class="mb-3" [class.pb-3]="error">
|
||||
<div class="row">
|
||||
<div class="d-flex align-items-center position-relative hidden-button-container" [class.col-md-3]="horizontal">
|
||||
@if (title) {
|
||||
<label class="form-label" [class.mb-md-0]="horizontal" [for]="inputId">{{title}}</label>
|
||||
}
|
||||
</div>
|
||||
<div class="position-relative" [class.col-md-9]="horizontal">
|
||||
<div class="input-group" [class.is-invalid]="error">
|
||||
<input #inputField [type]="showReveal && textVisible ? 'text' : 'password'" class="form-control" [class.is-invalid]="error" [id]="inputId" [(ngModel)]="value" (focus)="onFocus()" (focusout)="onFocusOut()" (change)="onChange(value)" [disabled]="disabled" [autocomplete]="autocomplete">
|
||||
@if (showReveal) {
|
||||
<button type="button" class="btn btn-outline-secondary" (click)="toggleVisibility()" i18n-title title="Show password" [disabled]="disabled || disableRevealToggle">
|
||||
<i-bs name="eye"></i-bs>
|
||||
</button>
|
||||
}
|
||||
</div>
|
||||
<div class="invalid-feedback">
|
||||
{{error}}
|
||||
</div>
|
||||
@if (hint) {
|
||||
<small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small>
|
||||
}
|
||||
</div>
|
||||
<div class="invalid-feedback">
|
||||
{{error}}
|
||||
</div>
|
||||
@if (hint) {
|
||||
<small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small>
|
||||
}
|
||||
</div>
|
||||
|
@ -15,6 +15,12 @@
|
||||
@if (hint) {
|
||||
<small class="form-text text-muted" [innerHTML]="hint | safeHtml"></small>
|
||||
}
|
||||
@if (getSuggestion()?.length > 0) {
|
||||
<small>
|
||||
<span i18n>Suggestion:</span>
|
||||
<a (click)="applySuggestion(s)" [routerLink]="[]">{{getSuggestion()}}</a>
|
||||
</small>
|
||||
}
|
||||
<div class="invalid-feedback position-absolute top-100">
|
||||
{{error}}
|
||||
</div>
|
||||
|
@ -26,10 +26,20 @@ describe('TextComponent', () => {
|
||||
|
||||
it('should support use of input field', () => {
|
||||
expect(component.value).toBeUndefined()
|
||||
// TODO: why doesn't this work?
|
||||
// input.value = 'foo'
|
||||
// input.dispatchEvent(new Event('change'))
|
||||
// fixture.detectChanges()
|
||||
// expect(component.value).toEqual('foo')
|
||||
input.value = 'foo'
|
||||
input.dispatchEvent(new Event('input'))
|
||||
fixture.detectChanges()
|
||||
expect(component.value).toBe('foo')
|
||||
})
|
||||
|
||||
it('should support suggestion', () => {
|
||||
component.value = 'foo'
|
||||
component.suggestion = 'foo'
|
||||
expect(component.getSuggestion()).toBe('')
|
||||
component.value = 'bar'
|
||||
expect(component.getSuggestion()).toBe('foo')
|
||||
component.applySuggestion()
|
||||
fixture.detectChanges()
|
||||
expect(component.value).toBe('foo')
|
||||
})
|
||||
})
|
||||
|
@ -4,6 +4,7 @@ import {
|
||||
NG_VALUE_ACCESSOR,
|
||||
ReactiveFormsModule,
|
||||
} from '@angular/forms'
|
||||
import { RouterLink } from '@angular/router'
|
||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { SafeHtmlPipe } from 'src/app/pipes/safehtml.pipe'
|
||||
import { AbstractInputComponent } from '../abstract-input'
|
||||
@ -24,6 +25,7 @@ import { AbstractInputComponent } from '../abstract-input'
|
||||
ReactiveFormsModule,
|
||||
SafeHtmlPipe,
|
||||
NgxBootstrapIconsModule,
|
||||
RouterLink,
|
||||
],
|
||||
})
|
||||
export class TextComponent extends AbstractInputComponent<string> {
|
||||
@ -33,7 +35,19 @@ export class TextComponent extends AbstractInputComponent<string> {
|
||||
@Input()
|
||||
placeholder: string = ''
|
||||
|
||||
@Input()
|
||||
suggestion: string = ''
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
}
|
||||
|
||||
getSuggestion() {
|
||||
return this.value !== this.suggestion ? this.suggestion : ''
|
||||
}
|
||||
|
||||
applySuggestion() {
|
||||
this.value = this.suggestion
|
||||
this.onChange(this.value)
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,49 @@
|
||||
<div class="btn-group">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
|
||||
@if (loading) {
|
||||
<div class="spinner-border spinner-border-sm" role="status"></div>
|
||||
} @else {
|
||||
<i-bs width="1.2em" height="1.2em" name="stars"></i-bs>
|
||||
}
|
||||
<span class="d-none d-lg-inline ps-1" i18n>Suggest</span>
|
||||
@if (totalSuggestions > 0) {
|
||||
<span class="badge bg-primary ms-2">{{ totalSuggestions }}</span>
|
||||
}
|
||||
</button>
|
||||
|
||||
@if (aiEnabled) {
|
||||
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||
<span class="visually-hidden" i18n>Show suggestions</span>
|
||||
</button>
|
||||
|
||||
<div ngbDropdownMenu aria-labelledby="suggestionsDropdown" class="shadow suggestions-dropdown">
|
||||
<div class="list-group list-group-flush small pb-0">
|
||||
@if (!suggestions?.suggested_tags && !suggestions?.suggested_document_types && !suggestions?.suggested_correspondents) {
|
||||
<div class="list-group-item text-muted fst-italic">
|
||||
<small class="text-muted small fst-italic" i18n>No novel suggestions</small>
|
||||
</div>
|
||||
}
|
||||
@if (suggestions?.suggested_tags.length > 0) {
|
||||
<small class="list-group-item text-uppercase text-muted small">Tags</small>
|
||||
@for (tag of suggestions.suggested_tags; track tag) {
|
||||
<button type="button" class="list-group-item list-group-item-action bg-light" (click)="addTag.emit(tag)" i18n>{{ tag }}</button>
|
||||
}
|
||||
}
|
||||
@if (suggestions?.suggested_document_types.length > 0) {
|
||||
<div class="list-group-item text-uppercase text-muted small">Document Types</div>
|
||||
@for (type of suggestions.suggested_document_types; track type) {
|
||||
<button type="button" class="list-group-item list-group-item-action bg-light" (click)="addDocumentType.emit(type)" i18n>{{ type }}</button>
|
||||
}
|
||||
}
|
||||
@if (suggestions?.suggested_correspondents.length > 0) {
|
||||
<div class="list-group-item text-uppercase text-muted small">Correspondents</div>
|
||||
@for (correspondent of suggestions.suggested_correspondents; track correspondent) {
|
||||
<button type="button" class="list-group-item list-group-item-action bg-light" (click)="addCorrespondent.emit(correspondent)" i18n>{{ correspondent }}</button>
|
||||
}
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
</div>
|
@ -0,0 +1,3 @@
|
||||
.suggestions-dropdown {
|
||||
min-width: 250px;
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import { ComponentFixture, TestBed } from '@angular/core/testing'
|
||||
import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { NgxBootstrapIconsModule, allIcons } from 'ngx-bootstrap-icons'
|
||||
import { SuggestionsDropdownComponent } from './suggestions-dropdown.component'
|
||||
|
||||
describe('SuggestionsDropdownComponent', () => {
|
||||
let component: SuggestionsDropdownComponent
|
||||
let fixture: ComponentFixture<SuggestionsDropdownComponent>
|
||||
|
||||
beforeEach(() => {
|
||||
TestBed.configureTestingModule({
|
||||
imports: [
|
||||
NgbDropdownModule,
|
||||
NgxBootstrapIconsModule.pick(allIcons),
|
||||
SuggestionsDropdownComponent,
|
||||
],
|
||||
providers: [],
|
||||
})
|
||||
fixture = TestBed.createComponent(SuggestionsDropdownComponent)
|
||||
component = fixture.componentInstance
|
||||
fixture.detectChanges()
|
||||
})
|
||||
|
||||
it('should calculate totalSuggestions', () => {
|
||||
component.suggestions = {
|
||||
suggested_correspondents: ['John Doe'],
|
||||
suggested_tags: ['Tag1', 'Tag2'],
|
||||
suggested_document_types: ['Type1'],
|
||||
}
|
||||
expect(component.totalSuggestions).toBe(4)
|
||||
})
|
||||
|
||||
it('should emit getSuggestions when clickSuggest is called and suggestions are null', () => {
|
||||
jest.spyOn(component.getSuggestions, 'emit')
|
||||
component.suggestions = null
|
||||
component.clickSuggest()
|
||||
expect(component.getSuggestions.emit).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
|
||||
component.aiEnabled = true
|
||||
fixture.detectChanges()
|
||||
component.suggestions = {
|
||||
suggested_correspondents: [],
|
||||
suggested_tags: [],
|
||||
suggested_document_types: [],
|
||||
}
|
||||
component.clickSuggest()
|
||||
expect(component.dropdown.open).toBeTruthy()
|
||||
})
|
||||
})
|
@ -0,0 +1,64 @@
|
||||
import {
|
||||
Component,
|
||||
EventEmitter,
|
||||
Input,
|
||||
Output,
|
||||
ViewChild,
|
||||
} from '@angular/core'
|
||||
import { NgbDropdown, NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { DocumentSuggestions } from 'src/app/data/document-suggestions'
|
||||
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
||||
|
||||
@Component({
|
||||
selector: 'pngx-suggestions-dropdown',
|
||||
imports: [NgbDropdownModule, NgxBootstrapIconsModule],
|
||||
templateUrl: './suggestions-dropdown.component.html',
|
||||
styleUrl: './suggestions-dropdown.component.scss',
|
||||
})
|
||||
export class SuggestionsDropdownComponent {
|
||||
public popperOptions = pngxPopperOptions
|
||||
|
||||
@ViewChild('dropdown') dropdown: NgbDropdown
|
||||
|
||||
@Input()
|
||||
suggestions: DocumentSuggestions = null
|
||||
|
||||
@Input()
|
||||
aiEnabled: boolean = false
|
||||
|
||||
@Input()
|
||||
loading: boolean = false
|
||||
|
||||
@Input()
|
||||
disabled: boolean = false
|
||||
|
||||
@Output()
|
||||
getSuggestions: EventEmitter<SuggestionsDropdownComponent> =
|
||||
new EventEmitter()
|
||||
|
||||
@Output()
|
||||
addTag: EventEmitter<string> = new EventEmitter()
|
||||
|
||||
@Output()
|
||||
addDocumentType: EventEmitter<string> = new EventEmitter()
|
||||
|
||||
@Output()
|
||||
addCorrespondent: EventEmitter<string> = new EventEmitter()
|
||||
|
||||
public clickSuggest(): void {
|
||||
if (!this.suggestions) {
|
||||
this.getSuggestions.emit(this)
|
||||
} else {
|
||||
this.dropdown?.toggle()
|
||||
}
|
||||
}
|
||||
|
||||
get totalSuggestions(): number {
|
||||
return (
|
||||
this.suggestions?.suggested_correspondents?.length +
|
||||
this.suggestions?.suggested_tags?.length +
|
||||
this.suggestions?.suggested_document_types?.length || 0
|
||||
)
|
||||
}
|
||||
}
|
@ -254,6 +254,43 @@
|
||||
<h6><ng-container i18n>Error</ng-container>:</h6> <span class="font-monospace small">{{status.tasks.sanity_check_error}}</span>
|
||||
}
|
||||
</ng-template>
|
||||
@if (aiEnabled) {
|
||||
<dt i18n>AI Index</dt>
|
||||
<dd class="d-flex align-items-center">
|
||||
<button class="btn btn-sm d-flex align-items-center btn-dark text-uppercase small" [ngbPopover]="llmIndexStatus" triggers="click mouseenter:mouseleave">
|
||||
{{status.tasks.llmindex_status}}
|
||||
@if (status.tasks.llmindex_status === 'OK') {
|
||||
@if (isStale(status.tasks.llmindex_last_modified)) {
|
||||
<i-bs name="exclamation-triangle-fill" class="text-warning ms-2 lh-1"></i-bs>
|
||||
} @else {
|
||||
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
|
||||
}
|
||||
} @else {
|
||||
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
|
||||
[class.text-danger]="status.tasks.llmindex_status === SystemStatusItemStatus.ERROR"
|
||||
[class.text-warning]="status.tasks.llmindex_status === SystemStatusItemStatus.WARNING"
|
||||
[class.text-muted]="status.tasks.llmindex_status === SystemStatusItemStatus.DISABLED"></i-bs>
|
||||
}
|
||||
</button>
|
||||
@if (currentUserIsSuperUser) {
|
||||
@if (isRunning(PaperlessTaskName.LLMIndexUpdate)) {
|
||||
<div class="spinner-border spinner-border-sm ms-2" role="status"></div>
|
||||
} @else {
|
||||
<button class="btn btn-sm d-flex align-items-center btn-dark small ms-2" (click)="runTask(PaperlessTaskName.LLMIndexUpdate)">
|
||||
<i-bs name="play-fill"></i-bs>
|
||||
<ng-container i18n>Run Task</ng-container>
|
||||
</button>
|
||||
}
|
||||
}
|
||||
</dd>
|
||||
<ng-template #llmIndexStatus>
|
||||
@if (status.tasks.llmindex_status === 'OK') {
|
||||
<h6><ng-container i18n>Last Run</ng-container>:</h6> <span class="font-monospace small">{{status.tasks.llmindex_last_modified | customDate:'medium'}}</span>
|
||||
} @else {
|
||||
<h6><ng-container i18n>Error</ng-container>:</h6> <span class="font-monospace small">{{status.tasks.llmindex_error}}</span>
|
||||
}
|
||||
</ng-template>
|
||||
}
|
||||
</dl>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -67,6 +67,9 @@ const status: SystemStatus = {
|
||||
sanity_check_status: SystemStatusItemStatus.OK,
|
||||
sanity_check_last_run: new Date().toISOString(),
|
||||
sanity_check_error: null,
|
||||
llmindex_status: SystemStatusItemStatus.OK,
|
||||
llmindex_last_modified: new Date().toISOString(),
|
||||
llmindex_error: null,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -12,9 +12,11 @@ import {
|
||||
SystemStatus,
|
||||
SystemStatusItemStatus,
|
||||
} from 'src/app/data/system-status'
|
||||
import { SETTINGS_KEYS } from 'src/app/data/ui-settings'
|
||||
import { CustomDatePipe } from 'src/app/pipes/custom-date.pipe'
|
||||
import { FileSizePipe } from 'src/app/pipes/file-size.pipe'
|
||||
import { PermissionsService } from 'src/app/services/permissions.service'
|
||||
import { SettingsService } from 'src/app/services/settings.service'
|
||||
import { SystemStatusService } from 'src/app/services/system-status.service'
|
||||
import { TasksService } from 'src/app/services/tasks.service'
|
||||
import { ToastService } from 'src/app/services/toast.service'
|
||||
@ -49,13 +51,18 @@ export class SystemStatusDialogComponent implements OnInit {
|
||||
return this.permissionsService.isSuperUser()
|
||||
}
|
||||
|
||||
get aiEnabled(): boolean {
|
||||
return this.settingsService.get(SETTINGS_KEYS.AI_ENABLED)
|
||||
}
|
||||
|
||||
constructor(
|
||||
public activeModal: NgbActiveModal,
|
||||
private clipboard: Clipboard,
|
||||
private systemStatusService: SystemStatusService,
|
||||
private tasksService: TasksService,
|
||||
private toastService: ToastService,
|
||||
private permissionsService: PermissionsService
|
||||
private permissionsService: PermissionsService,
|
||||
private settingsService: SettingsService
|
||||
) {}
|
||||
|
||||
public ngOnInit() {
|
||||
|
@ -72,16 +72,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<pngx-custom-fields-dropdown
|
||||
*pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.CustomField }"
|
||||
[documentId]="documentId"
|
||||
[disabled]="!userCanEdit"
|
||||
[existingFields]="document?.custom_fields"
|
||||
(created)="refreshCustomFields()"
|
||||
(added)="addField($event)">
|
||||
</pngx-custom-fields-dropdown>
|
||||
|
||||
|
||||
<div class="ms-auto" ngbDropdown>
|
||||
<button class="btn btn-sm btn-outline-primary" id="sendDropdown" ngbDropdownToggle>
|
||||
<i-bs name="send"></i-bs>
|
||||
@ -102,7 +92,7 @@
|
||||
</pngx-page-header>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-6 col-xl-4 mb-4">
|
||||
<div class="col-md-6 col-xl-5 mb-4">
|
||||
|
||||
<form [formGroup]='documentForm' (ngSubmit)="save()">
|
||||
|
||||
@ -119,6 +109,32 @@
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<ng-container *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }">
|
||||
<div class="btn-group pb-3 ms-auto">
|
||||
<pngx-suggestions-dropdown *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }"
|
||||
[disabled]="!userCanEdit || suggestionsLoading"
|
||||
[loading]="suggestionsLoading"
|
||||
[suggestions]="suggestions"
|
||||
[aiEnabled]="aiEnabled"
|
||||
(getSuggestions)="getSuggestions()"
|
||||
(addTag)="createTag($event)"
|
||||
(addDocumentType)="createDocumentType($event)"
|
||||
(addCorrespondent)="createCorrespondent($event)">
|
||||
</pngx-suggestions-dropdown>
|
||||
</div>
|
||||
|
||||
<div class="btn-group pb-3 ms-2">
|
||||
<pngx-custom-fields-dropdown
|
||||
*pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.CustomField }"
|
||||
[documentId]="documentId"
|
||||
[disabled]="!userCanEdit"
|
||||
[existingFields]="document?.custom_fields"
|
||||
(created)="refreshCustomFields()"
|
||||
(added)="addField($event)">
|
||||
</pngx-custom-fields-dropdown>
|
||||
</div>
|
||||
</ng-container>
|
||||
|
||||
<ng-container *ngTemplateOutlet="saveButtons"></ng-container>
|
||||
</div>
|
||||
|
||||
@ -127,7 +143,7 @@
|
||||
<a ngbNavLink i18n>Details</a>
|
||||
<ng-template ngbNavContent>
|
||||
<div>
|
||||
<pngx-input-text #inputTitle i18n-title title="Title" formControlName="title" [horizontal]="true" (keyup)="titleKeyUp($event)" [error]="error?.title"></pngx-input-text>
|
||||
<pngx-input-text #inputTitle i18n-title title="Title" formControlName="title" [horizontal]="true" [suggestion]="suggestions?.title" (keyup)="titleKeyUp($event)" [error]="error?.title"></pngx-input-text>
|
||||
<pngx-input-number i18n-title title="Archive serial number" [error]="error?.archive_serial_number" [horizontal]="true" formControlName='archive_serial_number'></pngx-input-number>
|
||||
<pngx-input-date i18n-title title="Date created" formControlName="created" [suggestions]="suggestions?.dates" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event)"
|
||||
[error]="error?.created"></pngx-input-date>
|
||||
@ -137,7 +153,7 @@
|
||||
(createNew)="createDocumentType($event)" [hideAddButton]="createDisabled(DataType.DocumentType)" [suggestions]="suggestions?.document_types" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.DocumentType }"></pngx-input-select>
|
||||
<pngx-input-select [items]="storagePaths" i18n-title title="Storage path" formControlName="storage_path" [allowNull]="true" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.StoragePath)"
|
||||
(createNew)="createStoragePath($event)" [hideAddButton]="createDisabled(DataType.StoragePath)" [suggestions]="suggestions?.storage_paths" i18n-placeholder placeholder="Default" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.StoragePath }"></pngx-input-select>
|
||||
<pngx-input-tags formControlName="tags" [suggestions]="suggestions?.tags" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.Tag)" [hideAddButton]="createDisabled(DataType.Tag)" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.Tag }"></pngx-input-tags>
|
||||
<pngx-input-tags #tagsInput formControlName="tags" [suggestions]="suggestions?.tags" [showFilter]="true" [horizontal]="true" (filterDocuments)="filterDocuments($event, DataType.Tag)" [hideAddButton]="createDisabled(DataType.Tag)" *pngxIfPermissions="{ action: PermissionAction.View, type: PermissionType.Tag }"></pngx-input-tags>
|
||||
@for (fieldInstance of document?.custom_fields; track fieldInstance.field; let i = $index) {
|
||||
<div [formGroup]="customFieldFormFields.controls[i]">
|
||||
@switch (getCustomFieldFromInstance(fieldInstance)?.data_type) {
|
||||
@ -351,14 +367,14 @@
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6 col-xl-8 mb-3 d-none d-md-block position-relative" #pdfPreview>
|
||||
<div class="col-md-6 col-xl-7 mb-3 d-none d-md-block position-relative" #pdfPreview>
|
||||
<ng-container *ngTemplateOutlet="previewContent"></ng-container>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<ng-template #saveButtons>
|
||||
<div class="btn-group pb-3 ms-auto">
|
||||
<div class="btn-group pb-3 ms-4">
|
||||
<ng-container *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.Document }">
|
||||
<button type="submit" class="order-3 btn btn-sm btn-primary" i18n [disabled]="!userCanEdit || networkActive || (isDirty$ | async) !== true">Save</button>
|
||||
@if (hasNext()) {
|
||||
|
@ -156,6 +156,16 @@ describe('DocumentDetailComponent', () => {
|
||||
{
|
||||
provide: TagService,
|
||||
useValue: {
|
||||
getCachedMany: (ids: number[]) =>
|
||||
of(
|
||||
ids.map((id) => ({
|
||||
id,
|
||||
name: `Tag${id}`,
|
||||
is_inbox_tag: true,
|
||||
color: '#ff0000',
|
||||
text_color: '#000000',
|
||||
}))
|
||||
),
|
||||
listAll: () =>
|
||||
of({
|
||||
count: 3,
|
||||
@ -382,8 +392,32 @@ describe('DocumentDetailComponent', () => {
|
||||
currentUserCan = true
|
||||
})
|
||||
|
||||
it('should support creating document type', () => {
|
||||
it('should support creating tag, remove from suggestions', () => {
|
||||
initNormally()
|
||||
component.suggestions = {
|
||||
suggested_tags: ['Tag1', 'NewTag12'],
|
||||
}
|
||||
let openModal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((modal) => (openModal = modal[0]))
|
||||
const modalSpy = jest.spyOn(modalService, 'open')
|
||||
component.createTag('NewTag12')
|
||||
expect(modalSpy).toHaveBeenCalled()
|
||||
openModal.componentInstance.succeeded.next({
|
||||
id: 12,
|
||||
name: 'NewTag12',
|
||||
is_inbox_tag: true,
|
||||
color: '#ff0000',
|
||||
text_color: '#000000',
|
||||
})
|
||||
expect(component.documentForm.get('tags').value).toContain(12)
|
||||
expect(component.suggestions.suggested_tags).not.toContain('NewTag12')
|
||||
})
|
||||
|
||||
it('should support creating document type, remove from suggestions', () => {
|
||||
initNormally()
|
||||
component.suggestions = {
|
||||
suggested_document_types: ['DocumentType1', 'NewDocType2'],
|
||||
}
|
||||
let openModal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((modal) => (openModal = modal[0]))
|
||||
const modalSpy = jest.spyOn(modalService, 'open')
|
||||
@ -391,10 +425,16 @@ describe('DocumentDetailComponent', () => {
|
||||
expect(modalSpy).toHaveBeenCalled()
|
||||
openModal.componentInstance.succeeded.next({ id: 12, name: 'NewDocType12' })
|
||||
expect(component.documentForm.get('document_type').value).toEqual(12)
|
||||
expect(component.suggestions.suggested_document_types).not.toContain(
|
||||
'NewDocType2'
|
||||
)
|
||||
})
|
||||
|
||||
it('should support creating correspondent', () => {
|
||||
it('should support creating correspondent, remove from suggestions', () => {
|
||||
initNormally()
|
||||
component.suggestions = {
|
||||
suggested_correspondents: ['Correspondent1', 'NewCorrrespondent12'],
|
||||
}
|
||||
let openModal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((modal) => (openModal = modal[0]))
|
||||
const modalSpy = jest.spyOn(modalService, 'open')
|
||||
@ -405,6 +445,9 @@ describe('DocumentDetailComponent', () => {
|
||||
name: 'NewCorrrespondent12',
|
||||
})
|
||||
expect(component.documentForm.get('correspondent').value).toEqual(12)
|
||||
expect(component.suggestions.suggested_correspondents).not.toContain(
|
||||
'NewCorrrespondent12'
|
||||
)
|
||||
})
|
||||
|
||||
it('should support creating storage path', () => {
|
||||
@ -983,7 +1026,7 @@ describe('DocumentDetailComponent', () => {
|
||||
expect(component.document.custom_fields).toHaveLength(initialLength - 1)
|
||||
expect(component.customFieldFormFields).toHaveLength(initialLength - 1)
|
||||
expect(
|
||||
fixture.debugElement.query(By.css('form')).nativeElement.textContent
|
||||
fixture.debugElement.query(By.css('form ul')).nativeElement.textContent
|
||||
).not.toContain('Field 1')
|
||||
const patchSpy = jest.spyOn(documentService, 'patch')
|
||||
component.save(true)
|
||||
@ -1058,10 +1101,22 @@ describe('DocumentDetailComponent', () => {
|
||||
|
||||
it('should get suggestions', () => {
|
||||
const suggestionsSpy = jest.spyOn(documentService, 'getSuggestions')
|
||||
suggestionsSpy.mockReturnValue(of({ tags: [42, 43] }))
|
||||
suggestionsSpy.mockReturnValue(
|
||||
of({
|
||||
tags: [42, 43],
|
||||
suggested_tags: [],
|
||||
suggested_document_types: [],
|
||||
suggested_correspondents: [],
|
||||
})
|
||||
)
|
||||
initNormally()
|
||||
expect(suggestionsSpy).toHaveBeenCalled()
|
||||
expect(component.suggestions).toEqual({ tags: [42, 43] })
|
||||
expect(component.suggestions).toEqual({
|
||||
tags: [42, 43],
|
||||
suggested_tags: [],
|
||||
suggested_document_types: [],
|
||||
suggested_correspondents: [],
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error if needed for get suggestions', () => {
|
||||
|
@ -74,6 +74,7 @@ import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service
|
||||
import { DocumentTypeService } from 'src/app/services/rest/document-type.service'
|
||||
import { DocumentService } from 'src/app/services/rest/document.service'
|
||||
import { StoragePathService } from 'src/app/services/rest/storage-path.service'
|
||||
import { TagService } from 'src/app/services/rest/tag.service'
|
||||
import { UserService } from 'src/app/services/rest/user.service'
|
||||
import { SettingsService } from 'src/app/services/settings.service'
|
||||
import { ToastService } from 'src/app/services/toast.service'
|
||||
@ -89,6 +90,7 @@ import { CorrespondentEditDialogComponent } from '../common/edit-dialog/correspo
|
||||
import { DocumentTypeEditDialogComponent } from '../common/edit-dialog/document-type-edit-dialog/document-type-edit-dialog.component'
|
||||
import { EditDialogMode } from '../common/edit-dialog/edit-dialog.component'
|
||||
import { StoragePathEditDialogComponent } from '../common/edit-dialog/storage-path-edit-dialog/storage-path-edit-dialog.component'
|
||||
import { TagEditDialogComponent } from '../common/edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
|
||||
import { EmailDocumentDialogComponent } from '../common/email-document-dialog/email-document-dialog.component'
|
||||
import { CheckComponent } from '../common/input/check/check.component'
|
||||
import { DateComponent } from '../common/input/date/date.component'
|
||||
@ -102,6 +104,7 @@ import { TextComponent } from '../common/input/text/text.component'
|
||||
import { UrlComponent } from '../common/input/url/url.component'
|
||||
import { PageHeaderComponent } from '../common/page-header/page-header.component'
|
||||
import { ShareLinksDialogComponent } from '../common/share-links-dialog/share-links-dialog.component'
|
||||
import { SuggestionsDropdownComponent } from '../common/suggestions-dropdown/suggestions-dropdown.component'
|
||||
import { DocumentHistoryComponent } from '../document-history/document-history.component'
|
||||
import { DocumentNotesComponent } from '../document-notes/document-notes.component'
|
||||
import { ComponentWithPermissions } from '../with-permissions/with-permissions.component'
|
||||
@ -158,6 +161,7 @@ export enum ZoomSetting {
|
||||
NumberComponent,
|
||||
MonetaryComponent,
|
||||
UrlComponent,
|
||||
SuggestionsDropdownComponent,
|
||||
CustomDatePipe,
|
||||
FileSizePipe,
|
||||
IfPermissionsDirective,
|
||||
@ -179,6 +183,8 @@ export class DocumentDetailComponent
|
||||
@ViewChild('inputTitle')
|
||||
titleInput: TextComponent
|
||||
|
||||
@ViewChild('tagsInput') tagsInput: TagsComponent
|
||||
|
||||
expandOriginalMetadata = false
|
||||
expandArchivedMetadata = false
|
||||
|
||||
@ -190,6 +196,7 @@ export class DocumentDetailComponent
|
||||
document: Document
|
||||
metadata: DocumentMetadata
|
||||
suggestions: DocumentSuggestions
|
||||
suggestionsLoading: boolean = false
|
||||
users: User[]
|
||||
|
||||
title: string
|
||||
@ -262,6 +269,7 @@ export class DocumentDetailComponent
|
||||
constructor(
|
||||
private documentsService: DocumentService,
|
||||
private route: ActivatedRoute,
|
||||
private tagService: TagService,
|
||||
private correspondentService: CorrespondentService,
|
||||
private documentTypeService: DocumentTypeService,
|
||||
private router: Router,
|
||||
@ -291,6 +299,10 @@ export class DocumentDetailComponent
|
||||
return this.settings.get(SETTINGS_KEYS.USE_NATIVE_PDF_VIEWER)
|
||||
}
|
||||
|
||||
get aiEnabled(): boolean {
|
||||
return this.settings.get(SETTINGS_KEYS.AI_ENABLED)
|
||||
}
|
||||
|
||||
get archiveContentRenderType(): ContentRenderType {
|
||||
return this.document?.archived_file_name
|
||||
? this.getRenderType('application/pdf')
|
||||
@ -645,25 +657,12 @@ export class DocumentDetailComponent
|
||||
PermissionType.Document
|
||||
)
|
||||
) {
|
||||
this.documentsService
|
||||
.getSuggestions(doc.id)
|
||||
.pipe(
|
||||
first(),
|
||||
takeUntil(this.unsubscribeNotifier),
|
||||
takeUntil(this.docChangeNotifier)
|
||||
)
|
||||
.subscribe({
|
||||
next: (result) => {
|
||||
this.suggestions = result
|
||||
},
|
||||
error: (error) => {
|
||||
this.suggestions = null
|
||||
this.toastService.showError(
|
||||
$localize`Error retrieving suggestions.`,
|
||||
error
|
||||
)
|
||||
},
|
||||
})
|
||||
this.tagService.getCachedMany(doc.tags).subscribe((tags) => {
|
||||
// only show suggestions if document has inbox tags
|
||||
if (tags.some((tag) => tag.is_inbox_tag)) {
|
||||
this.getSuggestions()
|
||||
}
|
||||
})
|
||||
}
|
||||
this.title = this.documentTitlePipe.transform(doc.title)
|
||||
const docFormValues = Object.assign({}, doc)
|
||||
@ -680,6 +679,56 @@ export class DocumentDetailComponent
|
||||
return this.documentForm.get('custom_fields') as FormArray
|
||||
}
|
||||
|
||||
getSuggestions() {
|
||||
this.suggestionsLoading = true
|
||||
this.documentsService
|
||||
.getSuggestions(this.documentId)
|
||||
.pipe(
|
||||
first(),
|
||||
takeUntil(this.unsubscribeNotifier),
|
||||
takeUntil(this.docChangeNotifier)
|
||||
)
|
||||
.subscribe({
|
||||
next: (result) => {
|
||||
this.suggestions = result
|
||||
this.suggestionsLoading = false
|
||||
},
|
||||
error: (error) => {
|
||||
this.suggestions = null
|
||||
this.suggestionsLoading = false
|
||||
this.toastService.showError(
|
||||
$localize`Error retrieving suggestions.`,
|
||||
error
|
||||
)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
createTag(newName: string) {
|
||||
var modal = this.modalService.open(TagEditDialogComponent, {
|
||||
backdrop: 'static',
|
||||
})
|
||||
modal.componentInstance.dialogMode = EditDialogMode.CREATE
|
||||
if (newName) modal.componentInstance.object = { name: newName }
|
||||
modal.componentInstance.succeeded
|
||||
.pipe(
|
||||
switchMap((newTag) => {
|
||||
return this.tagService
|
||||
.listAll()
|
||||
.pipe(map((tags) => ({ newTag, tags })))
|
||||
})
|
||||
)
|
||||
.pipe(takeUntil(this.unsubscribeNotifier))
|
||||
.subscribe(({ newTag, tags }) => {
|
||||
this.tagsInput.tags = tags.results
|
||||
this.tagsInput.addTag(newTag.id)
|
||||
if (this.suggestions) {
|
||||
this.suggestions.suggested_tags =
|
||||
this.suggestions.suggested_tags.filter((tag) => tag !== newName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
createDocumentType(newName: string) {
|
||||
var modal = this.modalService.open(DocumentTypeEditDialogComponent, {
|
||||
backdrop: 'static',
|
||||
@ -699,6 +748,12 @@ export class DocumentDetailComponent
|
||||
this.documentTypes = documentTypes.results
|
||||
this.documentForm.get('document_type').setValue(newDocumentType.id)
|
||||
this.documentForm.get('document_type').markAsDirty()
|
||||
if (this.suggestions) {
|
||||
this.suggestions.suggested_document_types =
|
||||
this.suggestions.suggested_document_types.filter(
|
||||
(dt) => dt !== newName
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -723,6 +778,12 @@ export class DocumentDetailComponent
|
||||
this.correspondents = correspondents.results
|
||||
this.documentForm.get('correspondent').setValue(newCorrespondent.id)
|
||||
this.documentForm.get('correspondent').markAsDirty()
|
||||
if (this.suggestions) {
|
||||
this.suggestions.suggested_correspondents =
|
||||
this.suggestions.suggested_correspondents.filter(
|
||||
(c) => c !== newName
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,17 @@
|
||||
export interface DocumentSuggestions {
|
||||
title?: string
|
||||
|
||||
tags?: number[]
|
||||
suggested_tags?: string[]
|
||||
|
||||
correspondents?: number[]
|
||||
suggested_correspondents?: string[]
|
||||
|
||||
document_types?: number[]
|
||||
suggested_document_types?: string[]
|
||||
|
||||
storage_paths?: number[]
|
||||
suggested_storage_paths?: string[]
|
||||
|
||||
dates?: string[] // ISO-formatted date string e.g. 2022-11-03
|
||||
}
|
||||
|
@ -44,12 +44,24 @@ export enum ConfigOptionType {
|
||||
Boolean = 'boolean',
|
||||
JSON = 'json',
|
||||
File = 'file',
|
||||
Password = 'password',
|
||||
}
|
||||
|
||||
export const ConfigCategory = {
|
||||
General: $localize`General Settings`,
|
||||
OCR: $localize`OCR Settings`,
|
||||
Barcode: $localize`Barcode Settings`,
|
||||
AI: $localize`AI Settings`,
|
||||
}
|
||||
|
||||
export const LLMEmbeddingBackendConfig = {
|
||||
OPENAI: 'openai',
|
||||
HUGGINGFACE: 'huggingface',
|
||||
}
|
||||
|
||||
export const LLMBackendConfig = {
|
||||
OPENAI: 'openai',
|
||||
OLLAMA: 'ollama',
|
||||
}
|
||||
|
||||
export interface ConfigOption {
|
||||
@ -258,6 +270,57 @@ export const PaperlessConfigOptions: ConfigOption[] = [
|
||||
config_key: 'PAPERLESS_CONSUMER_TAG_BARCODE_MAPPING',
|
||||
category: ConfigCategory.Barcode,
|
||||
},
|
||||
{
|
||||
key: 'ai_enabled',
|
||||
title: $localize`AI Enabled`,
|
||||
type: ConfigOptionType.Boolean,
|
||||
config_key: 'PAPERLESS_AI_ENABLED',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_embedding_backend',
|
||||
title: $localize`LLM Embedding Backend`,
|
||||
type: ConfigOptionType.Select,
|
||||
choices: mapToItems(LLMEmbeddingBackendConfig),
|
||||
config_key: 'PAPERLESS_LLM_EMBEDDING_BACKEND',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_embedding_model',
|
||||
title: $localize`LLM Embedding Model`,
|
||||
type: ConfigOptionType.String,
|
||||
config_key: 'PAPERLESS_LLM_EMBEDDING_MODEL',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_backend',
|
||||
title: $localize`LLM Backend`,
|
||||
type: ConfigOptionType.Select,
|
||||
choices: mapToItems(LLMBackendConfig),
|
||||
config_key: 'PAPERLESS_LLM_BACKEND',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_model',
|
||||
title: $localize`LLM Model`,
|
||||
type: ConfigOptionType.String,
|
||||
config_key: 'PAPERLESS_LLM_MODEL',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_api_key',
|
||||
title: $localize`LLM API Key`,
|
||||
type: ConfigOptionType.Password,
|
||||
config_key: 'PAPERLESS_LLM_API_KEY',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
{
|
||||
key: 'llm_url',
|
||||
title: $localize`LLM URL`,
|
||||
type: ConfigOptionType.String,
|
||||
config_key: 'PAPERLESS_LLM_URL',
|
||||
category: ConfigCategory.AI,
|
||||
},
|
||||
]
|
||||
|
||||
export interface PaperlessConfig extends ObjectWithId {
|
||||
@ -287,4 +350,11 @@ export interface PaperlessConfig extends ObjectWithId {
|
||||
barcode_max_pages: number
|
||||
barcode_enable_tag: boolean
|
||||
barcode_tag_mapping: object
|
||||
ai_enabled: boolean
|
||||
llm_embedding_backend: string
|
||||
llm_embedding_model: string
|
||||
llm_backend: string
|
||||
llm_model: string
|
||||
llm_api_key: string
|
||||
llm_url: string
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ export enum PaperlessTaskName {
|
||||
TrainClassifier = 'train_classifier',
|
||||
SanityCheck = 'check_sanity',
|
||||
IndexOptimize = 'index_optimize',
|
||||
LLMIndexUpdate = 'llmindex_update',
|
||||
}
|
||||
|
||||
export enum PaperlessTaskStatus {
|
||||
|
@ -7,6 +7,7 @@ export enum SystemStatusItemStatus {
|
||||
OK = 'OK',
|
||||
ERROR = 'ERROR',
|
||||
WARNING = 'WARNING',
|
||||
DISABLED = 'DISABLED',
|
||||
}
|
||||
|
||||
export interface SystemStatus {
|
||||
@ -43,5 +44,8 @@ export interface SystemStatus {
|
||||
sanity_check_status: SystemStatusItemStatus
|
||||
sanity_check_last_run: string // ISO date string
|
||||
sanity_check_error: string
|
||||
llmindex_status: SystemStatusItemStatus
|
||||
llmindex_last_modified: string // ISO date string
|
||||
llmindex_error: string
|
||||
}
|
||||
}
|
||||
|
@ -74,6 +74,7 @@ export const SETTINGS_KEYS = {
|
||||
GMAIL_OAUTH_URL: 'gmail_oauth_url',
|
||||
OUTLOOK_OAUTH_URL: 'outlook_oauth_url',
|
||||
EMAIL_ENABLED: 'email_enabled',
|
||||
AI_ENABLED: 'ai_enabled',
|
||||
}
|
||||
|
||||
export const SETTINGS: UiSetting[] = [
|
||||
@ -282,4 +283,9 @@ export const SETTINGS: UiSetting[] = [
|
||||
type: 'string',
|
||||
default: 'page-width', // ZoomSetting from 'document-detail.component'
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.AI_ENABLED,
|
||||
type: 'boolean',
|
||||
default: false,
|
||||
},
|
||||
]
|
||||
|
58
src-ui/src/app/services/chat.service.spec.ts
Normal file
58
src-ui/src/app/services/chat.service.spec.ts
Normal file
@ -0,0 +1,58 @@
|
||||
import {
|
||||
HttpEventType,
|
||||
provideHttpClient,
|
||||
withInterceptorsFromDi,
|
||||
} from '@angular/common/http'
|
||||
import {
|
||||
HttpTestingController,
|
||||
provideHttpClientTesting,
|
||||
} from '@angular/common/http/testing'
|
||||
import { TestBed } from '@angular/core/testing'
|
||||
import { environment } from 'src/environments/environment'
|
||||
import { ChatService } from './chat.service'
|
||||
|
||||
describe('ChatService', () => {
|
||||
let service: ChatService
|
||||
let httpMock: HttpTestingController
|
||||
|
||||
beforeEach(() => {
|
||||
TestBed.configureTestingModule({
|
||||
imports: [],
|
||||
providers: [
|
||||
ChatService,
|
||||
provideHttpClient(withInterceptorsFromDi()),
|
||||
provideHttpClientTesting(),
|
||||
],
|
||||
})
|
||||
service = TestBed.inject(ChatService)
|
||||
httpMock = TestBed.inject(HttpTestingController)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
httpMock.verify()
|
||||
})
|
||||
|
||||
it('should stream chat messages', (done) => {
|
||||
const documentId = 1
|
||||
const prompt = 'Hello, world!'
|
||||
const mockResponse = 'Partial response text'
|
||||
const apiUrl = `${environment.apiBaseUrl}documents/chat/`
|
||||
|
||||
service.streamChat(documentId, prompt).subscribe((chunk) => {
|
||||
expect(chunk).toBe(mockResponse)
|
||||
done()
|
||||
})
|
||||
|
||||
const req = httpMock.expectOne(apiUrl)
|
||||
expect(req.request.method).toBe('POST')
|
||||
expect(req.request.body).toEqual({
|
||||
document_id: documentId,
|
||||
q: prompt,
|
||||
})
|
||||
|
||||
req.event({
|
||||
type: HttpEventType.DownloadProgress,
|
||||
partialText: mockResponse,
|
||||
} as any)
|
||||
})
|
||||
})
|
46
src-ui/src/app/services/chat.service.ts
Normal file
46
src-ui/src/app/services/chat.service.ts
Normal file
@ -0,0 +1,46 @@
|
||||
import {
|
||||
HttpClient,
|
||||
HttpDownloadProgressEvent,
|
||||
HttpEventType,
|
||||
} from '@angular/common/http'
|
||||
import { Injectable } from '@angular/core'
|
||||
import { filter, map, Observable } from 'rxjs'
|
||||
import { environment } from 'src/environments/environment'
|
||||
|
||||
export interface ChatMessage {
|
||||
role: 'user' | 'assistant'
|
||||
content: string
|
||||
isStreaming?: boolean
|
||||
}
|
||||
|
||||
@Injectable({
|
||||
providedIn: 'root',
|
||||
})
|
||||
export class ChatService {
|
||||
constructor(private http: HttpClient) {}
|
||||
|
||||
streamChat(documentId: number, prompt: string): Observable<string> {
|
||||
return this.http
|
||||
.post(
|
||||
`${environment.apiBaseUrl}documents/chat/`,
|
||||
{
|
||||
document_id: documentId,
|
||||
q: prompt,
|
||||
},
|
||||
{
|
||||
observe: 'events',
|
||||
reportProgress: true,
|
||||
responseType: 'text',
|
||||
withCredentials: true,
|
||||
}
|
||||
)
|
||||
.pipe(
|
||||
map((event) => {
|
||||
if (event.type === HttpEventType.DownloadProgress) {
|
||||
return (event as HttpDownloadProgressEvent).partialText!
|
||||
}
|
||||
}),
|
||||
filter((chunk) => !!chunk)
|
||||
)
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ import { DatePipe, registerLocaleData } from '@angular/common'
|
||||
import {
|
||||
HTTP_INTERCEPTORS,
|
||||
provideHttpClient,
|
||||
withFetch,
|
||||
withInterceptorsFromDi,
|
||||
} from '@angular/common/http'
|
||||
import { FormsModule, ReactiveFormsModule } from '@angular/forms'
|
||||
@ -48,6 +49,7 @@ import {
|
||||
caretDown,
|
||||
caretUp,
|
||||
chatLeftText,
|
||||
chatSquareDots,
|
||||
check,
|
||||
check2All,
|
||||
checkAll,
|
||||
@ -118,6 +120,7 @@ import {
|
||||
sliders2Vertical,
|
||||
sortAlphaDown,
|
||||
sortAlphaUpAlt,
|
||||
stars,
|
||||
tag,
|
||||
tagFill,
|
||||
tags,
|
||||
@ -255,6 +258,7 @@ const icons = {
|
||||
caretDown,
|
||||
caretUp,
|
||||
chatLeftText,
|
||||
chatSquareDots,
|
||||
check,
|
||||
check2All,
|
||||
checkAll,
|
||||
@ -325,6 +329,7 @@ const icons = {
|
||||
sliders2Vertical,
|
||||
sortAlphaDown,
|
||||
sortAlphaUpAlt,
|
||||
stars,
|
||||
tagFill,
|
||||
tag,
|
||||
tags,
|
||||
@ -389,6 +394,6 @@ bootstrapApplication(AppComponent, {
|
||||
CorrespondentNamePipe,
|
||||
DocumentTypeNamePipe,
|
||||
StoragePathNamePipe,
|
||||
provideHttpClient(withInterceptorsFromDi()),
|
||||
provideHttpClient(withInterceptorsFromDi(), withFetch()),
|
||||
],
|
||||
}).catch((err) => console.error(err))
|
||||
|
@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
|
||||
from documents.signals import document_consumption_finished
|
||||
from documents.signals import document_updated
|
||||
from documents.signals.handlers import add_inbox_tags
|
||||
from documents.signals.handlers import add_or_update_document_in_llm_index
|
||||
from documents.signals.handlers import add_to_index
|
||||
from documents.signals.handlers import run_workflows_added
|
||||
from documents.signals.handlers import run_workflows_updated
|
||||
@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
|
||||
document_consumption_finished.connect(set_storage_path)
|
||||
document_consumption_finished.connect(add_to_index)
|
||||
document_consumption_finished.connect(run_workflows_added)
|
||||
document_consumption_finished.connect(add_or_update_document_in_llm_index)
|
||||
document_updated.connect(run_workflows_updated)
|
||||
|
||||
import documents.schema # noqa: F401
|
||||
|
@ -115,6 +115,56 @@ def refresh_suggestions_cache(
|
||||
cache.touch(doc_key, timeout)
|
||||
|
||||
|
||||
def get_llm_suggestion_cache(
|
||||
document_id: int,
|
||||
backend: str,
|
||||
) -> SuggestionCacheData | None:
|
||||
doc_key = get_suggestion_cache_key(document_id)
|
||||
data: SuggestionCacheData = cache.get(doc_key)
|
||||
|
||||
if data and data.classifier_hash == backend:
|
||||
return data
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def set_llm_suggestions_cache(
|
||||
document_id: int,
|
||||
suggestions: dict,
|
||||
*,
|
||||
backend: str,
|
||||
timeout: int = CACHE_50_MINUTES,
|
||||
) -> None:
|
||||
"""
|
||||
Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
|
||||
"""
|
||||
from documents.caching import SuggestionCacheData
|
||||
|
||||
doc_key = get_suggestion_cache_key(document_id)
|
||||
cache.set(
|
||||
doc_key,
|
||||
SuggestionCacheData(
|
||||
classifier_version=1000, # Unique marker for LLM-based suggestion
|
||||
classifier_hash=backend,
|
||||
suggestions=suggestions,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
|
||||
|
||||
def invalidate_llm_suggestions_cache(
|
||||
document_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Invalidate the LLM suggestions cache for a specific document and backend.
|
||||
"""
|
||||
doc_key = get_suggestion_cache_key(document_id)
|
||||
data: SuggestionCacheData = cache.get(doc_key)
|
||||
|
||||
if data:
|
||||
cache.delete(doc_key)
|
||||
|
||||
|
||||
def get_metadata_cache_key(document_id: int) -> str:
|
||||
"""
|
||||
Returns the basic key for a document's metadata
|
||||
|
22
src/documents/management/commands/document_llmindex.py
Normal file
22
src/documents/management/commands/document_llmindex.py
Normal file
@ -0,0 +1,22 @@
|
||||
from django.core.management import BaseCommand
|
||||
from django.db import transaction
|
||||
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.tasks import llmindex_index
|
||||
|
||||
|
||||
class Command(ProgressBarMixin, BaseCommand):
|
||||
help = "Manages the LLM-based vector index for Paperless."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("command", choices=["rebuild", "update"])
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
with transaction.atomic():
|
||||
llmindex_index(
|
||||
progress_bar_disable=self.no_progress_bar,
|
||||
rebuild=options["command"] == "rebuild",
|
||||
scheduled=False,
|
||||
)
|
@ -0,0 +1,30 @@
|
||||
# Generated by Django 5.1.8 on 2025-04-30 02:38
|
||||
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("documents", "1068_alter_document_created"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="paperlesstask",
|
||||
name="task_name",
|
||||
field=models.CharField(
|
||||
choices=[
|
||||
("consume_file", "Consume File"),
|
||||
("train_classifier", "Train Classifier"),
|
||||
("check_sanity", "Check Sanity"),
|
||||
("index_optimize", "Index Optimize"),
|
||||
("llmindex_update", "LLM Index Update"),
|
||||
],
|
||||
help_text="Name of the task that was run",
|
||||
max_length=255,
|
||||
null=True,
|
||||
verbose_name="Task Name",
|
||||
),
|
||||
),
|
||||
]
|
@ -543,6 +543,7 @@ class PaperlessTask(ModelWithOwner):
|
||||
TRAIN_CLASSIFIER = ("train_classifier", _("Train Classifier"))
|
||||
CHECK_SANITY = ("check_sanity", _("Check Sanity"))
|
||||
INDEX_OPTIMIZE = ("index_optimize", _("Index Optimize"))
|
||||
LLMINDEX_UPDATE = ("llmindex_update", _("LLM Index Update"))
|
||||
|
||||
task_id = models.CharField(
|
||||
max_length=255,
|
||||
|
@ -26,6 +26,7 @@ from guardian.shortcuts import remove_perm
|
||||
|
||||
from documents import matching
|
||||
from documents.caching import clear_document_caches
|
||||
from documents.caching import invalidate_llm_suggestions_cache
|
||||
from documents.file_handling import create_source_path_directory
|
||||
from documents.file_handling import delete_empty_directories
|
||||
from documents.file_handling import generate_unique_filename
|
||||
@ -47,6 +48,7 @@ from documents.models import WorkflowTrigger
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.permissions import set_permissions_for_object
|
||||
from documents.templating.workflows import parse_w_workflow_placeholders
|
||||
from paperless.config import AIConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
@ -525,6 +527,15 @@ def update_filename_and_move_files(
|
||||
)
|
||||
|
||||
|
||||
@receiver(models.signals.post_save, sender=Document)
|
||||
def update_llm_suggestions_cache(sender, instance, **kwargs):
|
||||
"""
|
||||
Invalidate the LLM suggestions cache when a document is saved.
|
||||
"""
|
||||
# Invalidate the cache for the document
|
||||
invalidate_llm_suggestions_cache(instance.pk)
|
||||
|
||||
|
||||
# should be disabled in /src/documents/management/commands/document_importer.py handle
|
||||
@receiver(models.signals.post_save, sender=CustomField)
|
||||
def check_paths_and_prune_custom_fields(sender, instance: CustomField, **kwargs):
|
||||
@ -1439,3 +1450,26 @@ def task_failure_handler(
|
||||
task_instance.save()
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
|
||||
|
||||
def add_or_update_document_in_llm_index(sender, document, **kwargs):
|
||||
"""
|
||||
Add or update a document in the LLM index when it is created or updated.
|
||||
"""
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled():
|
||||
from documents.tasks import update_document_in_llm_index
|
||||
|
||||
update_document_in_llm_index.delay(document)
|
||||
|
||||
|
||||
@receiver(models.signals.post_delete, sender=Document)
|
||||
def delete_document_from_llm_index(sender, instance: Document, **kwargs):
|
||||
"""
|
||||
Delete a document from the LLM index when it is deleted.
|
||||
"""
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled():
|
||||
from documents.tasks import remove_document_from_llm_index
|
||||
|
||||
remove_document_from_llm_index.delay(instance)
|
||||
|
@ -54,6 +54,10 @@ from documents.sanity_checker import SanityCheckFailedException
|
||||
from documents.signals import document_updated
|
||||
from documents.signals.handlers import cleanup_document_deletion
|
||||
from documents.signals.handlers import run_workflows
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.indexing import llm_index_add_or_update_document
|
||||
from paperless_ai.indexing import llm_index_remove_document
|
||||
from paperless_ai.indexing import update_llm_index
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.models import LogEntry
|
||||
@ -242,6 +246,13 @@ def bulk_update_documents(document_ids):
|
||||
for doc in documents:
|
||||
index.update_document(writer, doc)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled():
|
||||
update_llm_index(
|
||||
progress_bar_disable=True,
|
||||
rebuild=False,
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_document_content_maybe_archive_file(document_id):
|
||||
@ -341,6 +352,10 @@ def update_document_content_maybe_archive_file(document_id):
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, document)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled:
|
||||
llm_index_add_or_update_document(document)
|
||||
|
||||
clear_document_caches(document.pk)
|
||||
|
||||
except Exception:
|
||||
@ -517,3 +532,53 @@ def check_scheduled_workflows():
|
||||
workflow_to_run=workflow,
|
||||
document=document,
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
def llmindex_index(
|
||||
*,
|
||||
progress_bar_disable=True,
|
||||
rebuild=False,
|
||||
scheduled=True,
|
||||
auto=False,
|
||||
):
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled():
|
||||
task = PaperlessTask.objects.create(
|
||||
type=PaperlessTask.TaskType.SCHEDULED_TASK
|
||||
if scheduled
|
||||
else PaperlessTask.TaskType.AUTO
|
||||
if auto
|
||||
else PaperlessTask.TaskType.MANUAL_TASK,
|
||||
task_id=uuid.uuid4(),
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
status=states.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
from paperless_ai.indexing import update_llm_index
|
||||
|
||||
try:
|
||||
result = update_llm_index(
|
||||
progress_bar_disable=progress_bar_disable,
|
||||
rebuild=rebuild,
|
||||
)
|
||||
task.status = states.SUCCESS
|
||||
task.result = result
|
||||
except Exception as e:
|
||||
logger.error("LLM index error: " + str(e))
|
||||
task.status = states.FAILURE
|
||||
task.result = str(e)
|
||||
|
||||
task.date_done = timezone.now()
|
||||
task.save(update_fields=["status", "result", "date_done"])
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_document_in_llm_index(document):
|
||||
llm_index_add_or_update_document(document)
|
||||
|
||||
|
||||
@shared_task
|
||||
def remove_document_from_llm_index(document):
|
||||
llm_index_remove_document(document)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework import status
|
||||
@ -64,6 +65,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
"barcode_max_pages": None,
|
||||
"barcode_enable_tag": None,
|
||||
"barcode_tag_mapping": None,
|
||||
"ai_enabled": False,
|
||||
"llm_embedding_backend": None,
|
||||
"llm_embedding_model": None,
|
||||
"llm_backend": None,
|
||||
"llm_model": None,
|
||||
"llm_api_key": None,
|
||||
"llm_url": None,
|
||||
},
|
||||
)
|
||||
|
||||
@ -189,3 +197,76 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
self.assertEqual(ApplicationConfiguration.objects.count(), 1)
|
||||
|
||||
def test_update_llm_api_key(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Existing config with llm_api_key specified
|
||||
WHEN:
|
||||
- API to update llm_api_key is called with all *s
|
||||
- API to update llm_api_key is called with empty string
|
||||
THEN:
|
||||
- llm_api_key is unchanged
|
||||
- llm_api_key is set to None
|
||||
"""
|
||||
config = ApplicationConfiguration.objects.first()
|
||||
config.llm_api_key = "1234567890"
|
||||
config.save()
|
||||
|
||||
# Test with all *
|
||||
response = self.client.patch(
|
||||
f"{self.ENDPOINT}1/",
|
||||
json.dumps(
|
||||
{
|
||||
"llm_api_key": "*" * 32,
|
||||
},
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
config.refresh_from_db()
|
||||
self.assertEqual(config.llm_api_key, "1234567890")
|
||||
# Test with empty string
|
||||
response = self.client.patch(
|
||||
f"{self.ENDPOINT}1/",
|
||||
json.dumps(
|
||||
{
|
||||
"llm_api_key": "",
|
||||
},
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
config.refresh_from_db()
|
||||
self.assertEqual(config.llm_api_key, None)
|
||||
|
||||
def test_enable_ai_index_triggers_update(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Existing config with AI disabled
|
||||
WHEN:
|
||||
- Config is updated to enable AI with llm_embedding_backend
|
||||
THEN:
|
||||
- LLM index is triggered to update
|
||||
"""
|
||||
config = ApplicationConfiguration.objects.first()
|
||||
config.ai_enabled = False
|
||||
config.llm_embedding_backend = None
|
||||
config.save()
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.delay") as mock_update,
|
||||
patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = False
|
||||
self.client.patch(
|
||||
f"{self.ENDPOINT}1/",
|
||||
json.dumps(
|
||||
{
|
||||
"ai_enabled": True,
|
||||
"llm_embedding_backend": "openai",
|
||||
},
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
mock_update.assert_called_once()
|
||||
|
@ -310,3 +310,69 @@ class TestSystemStatus(APITestCase):
|
||||
"ERROR",
|
||||
)
|
||||
self.assertIsNotNone(response.data["tasks"]["sanity_check_error"])
|
||||
|
||||
def test_system_status_ai_disabled(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- The AI feature is disabled
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains the correct AI status
|
||||
"""
|
||||
with override_settings(AI_ENABLED=False):
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["llmindex_status"], "DISABLED")
|
||||
self.assertIsNone(response.data["tasks"]["llmindex_error"])
|
||||
|
||||
def test_system_status_ai_enabled(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- The AI index feature is enabled, but no tasks are found
|
||||
- The AI index feature is enabled and a task is found
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains the correct AI status
|
||||
"""
|
||||
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
|
||||
self.client.force_login(self.user)
|
||||
|
||||
# No tasks found
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING")
|
||||
|
||||
PaperlessTask.objects.create(
|
||||
type=PaperlessTask.TaskType.SCHEDULED_TASK,
|
||||
status=states.SUCCESS,
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["llmindex_status"], "OK")
|
||||
self.assertIsNone(response.data["tasks"]["llmindex_error"])
|
||||
|
||||
def test_system_status_ai_error(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- The AI index feature is enabled and a task is found with an error
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains the correct AI status
|
||||
"""
|
||||
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
|
||||
PaperlessTask.objects.create(
|
||||
type=PaperlessTask.TaskType.SCHEDULED_TASK,
|
||||
status=states.FAILURE,
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
result="AI index update failed",
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["llmindex_status"], "ERROR")
|
||||
self.assertIsNotNone(response.data["tasks"]["llmindex_error"])
|
||||
|
@ -49,6 +49,7 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase):
|
||||
"backend_setting": "default",
|
||||
},
|
||||
"email_enabled": False,
|
||||
"ai_enabled": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -3,14 +3,17 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
from documents import tasks
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import PaperlessTask
|
||||
from documents.models import Tag
|
||||
from documents.sanity_checker import SanityCheckFailedException
|
||||
from documents.sanity_checker import SanityCheckMessages
|
||||
@ -270,3 +273,103 @@ class TestUpdateContent(DirectoriesMixin, TestCase):
|
||||
|
||||
tasks.update_document_content_maybe_archive_file(doc.pk)
|
||||
self.assertNotEqual(Document.objects.get(pk=doc.pk).content, "test")
|
||||
|
||||
|
||||
class TestAIIndex(DirectoriesMixin, TestCase):
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_ai_index_success(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Document exists, AI is enabled, llm index backend is set
|
||||
WHEN:
|
||||
- llmindex_index task is called
|
||||
THEN:
|
||||
- update_llm_index is called, and the task is marked as success
|
||||
"""
|
||||
Document.objects.create(
|
||||
title="test",
|
||||
content="my document",
|
||||
checksum="wow",
|
||||
)
|
||||
# lazy-loaded so mock the actual function
|
||||
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
|
||||
update_llm_index.return_value = "LLM index updated successfully."
|
||||
tasks.llmindex_index()
|
||||
update_llm_index.assert_called_once()
|
||||
task = PaperlessTask.objects.get(
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
self.assertEqual(task.status, states.SUCCESS)
|
||||
self.assertEqual(task.result, "LLM index updated successfully.")
|
||||
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_ai_index_failure(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Document exists, AI is enabled, llm index backend is set
|
||||
WHEN:
|
||||
- llmindex_index task is called
|
||||
THEN:
|
||||
- update_llm_index raises an exception, and the task is marked as failure
|
||||
"""
|
||||
Document.objects.create(
|
||||
title="test",
|
||||
content="my document",
|
||||
checksum="wow",
|
||||
)
|
||||
# lazy-loaded so mock the actual function
|
||||
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
|
||||
update_llm_index.side_effect = Exception("LLM index update failed.")
|
||||
tasks.llmindex_index()
|
||||
update_llm_index.assert_called_once()
|
||||
task = PaperlessTask.objects.get(
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
self.assertEqual(task.status, states.FAILURE)
|
||||
self.assertIn("LLM index update failed.", task.result)
|
||||
|
||||
def test_update_document_in_llm_index(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Nothing
|
||||
WHEN:
|
||||
- update_document_in_llm_index task is called
|
||||
THEN:
|
||||
- llm_index_add_or_update_document is called
|
||||
"""
|
||||
doc = Document.objects.create(
|
||||
title="test",
|
||||
content="my document",
|
||||
checksum="wow",
|
||||
)
|
||||
with mock.patch(
|
||||
"documents.tasks.llm_index_add_or_update_document",
|
||||
) as llm_index_add_or_update_document:
|
||||
tasks.update_document_in_llm_index(doc)
|
||||
llm_index_add_or_update_document.assert_called_once_with(doc)
|
||||
|
||||
def test_remove_document_from_llm_index(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Nothing
|
||||
WHEN:
|
||||
- remove_document_from_llm_index task is called
|
||||
THEN:
|
||||
- llm_index_remove_document is called
|
||||
"""
|
||||
doc = Document.objects.create(
|
||||
title="test",
|
||||
content="my document",
|
||||
checksum="wow",
|
||||
)
|
||||
with mock.patch(
|
||||
"documents.tasks.llm_index_remove_document",
|
||||
) as llm_index_remove_document:
|
||||
tasks.remove_document_from_llm_index(doc)
|
||||
llm_index_remove_document.assert_called_once_with(doc)
|
||||
|
@ -1,6 +1,8 @@
|
||||
import tempfile
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Permission
|
||||
@ -10,8 +12,15 @@ from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from rest_framework import status
|
||||
|
||||
from documents.caching import get_llm_suggestion_cache
|
||||
from documents.caching import set_llm_suggestions_cache
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import ShareLink
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.signals.handlers import update_llm_suggestions_cache
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from paperless.models import ApplicationConfiguration
|
||||
|
||||
@ -154,3 +163,104 @@ class TestViews(DirectoriesMixin, TestCase):
|
||||
response.render()
|
||||
self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
|
||||
self.assertContains(response, b"Share link has expired")
|
||||
|
||||
|
||||
class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_superuser(username="testuser")
|
||||
self.document = Document.objects.create(
|
||||
title="Test Document",
|
||||
filename="test.pdf",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
self.tag1 = Tag.objects.create(name="tag1")
|
||||
self.correspondent1 = Correspondent.objects.create(name="correspondent1")
|
||||
self.document_type1 = DocumentType.objects.create(name="type1")
|
||||
self.path1 = StoragePath.objects.create(name="path1")
|
||||
super().setUp()
|
||||
|
||||
@patch("documents.views.get_llm_suggestion_cache")
|
||||
@patch("documents.views.refresh_suggestions_cache")
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_BACKEND="mock_backend",
|
||||
)
|
||||
def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache):
|
||||
mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]})
|
||||
|
||||
self.client.force_login(user=self.user)
|
||||
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
|
||||
mock_refresh_cache.assert_called_once_with(self.document.pk)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_BACKEND="mock_backend",
|
||||
)
|
||||
def test_suggestions_with_ai_enabled(
|
||||
self,
|
||||
mock_get_ai_classification,
|
||||
):
|
||||
mock_get_ai_classification.return_value = {
|
||||
"title": "AI Title",
|
||||
"tags": ["tag1", "tag2"],
|
||||
"correspondents": ["correspondent1"],
|
||||
"document_types": ["type1"],
|
||||
"storage_paths": ["path1"],
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
|
||||
self.client.force_login(user=self.user)
|
||||
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(
|
||||
response.json(),
|
||||
{
|
||||
"title": "AI Title",
|
||||
"tags": [self.tag1.pk],
|
||||
"suggested_tags": ["tag2"],
|
||||
"correspondents": [self.correspondent1.pk],
|
||||
"suggested_correspondents": [],
|
||||
"document_types": [self.document_type1.pk],
|
||||
"suggested_document_types": [],
|
||||
"storage_paths": [self.path1.pk],
|
||||
"suggested_storage_paths": [],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalidate_suggestions_cache(self):
|
||||
self.client.force_login(user=self.user)
|
||||
suggestions = {
|
||||
"title": "AI Title",
|
||||
"tags": ["tag1", "tag2"],
|
||||
"correspondents": ["correspondent1"],
|
||||
"document_types": ["type1"],
|
||||
"storage_paths": ["path1"],
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
set_llm_suggestions_cache(
|
||||
self.document.pk,
|
||||
suggestions,
|
||||
backend="mock_backend",
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
self.document.pk,
|
||||
backend="mock_backend",
|
||||
).suggestions,
|
||||
suggestions,
|
||||
)
|
||||
# post_save signal triggered
|
||||
update_llm_suggestions_cache(
|
||||
sender=None,
|
||||
instance=self.document,
|
||||
)
|
||||
self.assertIsNone(
|
||||
get_llm_suggestion_cache(
|
||||
self.document.pk,
|
||||
backend="mock_backend",
|
||||
),
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@ -16,6 +17,7 @@ import httpx
|
||||
import pathvalidate
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import connections
|
||||
@ -38,6 +40,7 @@ from django.http import HttpResponseBadRequest
|
||||
from django.http import HttpResponseForbidden
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.http import HttpResponseServerError
|
||||
from django.http import StreamingHttpResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils import timezone
|
||||
from django.utils.decorators import method_decorator
|
||||
@ -45,6 +48,7 @@ from django.utils.timezone import make_aware
|
||||
from django.utils.translation import get_language
|
||||
from django.views import View
|
||||
from django.views.decorators.cache import cache_control
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
from django.views.decorators.http import condition
|
||||
from django.views.decorators.http import last_modified
|
||||
from django.views.generic import TemplateView
|
||||
@ -80,10 +84,12 @@ from documents import index
|
||||
from documents.bulk_download import ArchiveOnlyStrategy
|
||||
from documents.bulk_download import OriginalAndArchiveStrategy
|
||||
from documents.bulk_download import OriginalsOnlyStrategy
|
||||
from documents.caching import get_llm_suggestion_cache
|
||||
from documents.caching import get_metadata_cache
|
||||
from documents.caching import get_suggestion_cache
|
||||
from documents.caching import refresh_metadata_cache
|
||||
from documents.caching import refresh_suggestions_cache
|
||||
from documents.caching import set_llm_suggestions_cache
|
||||
from documents.caching import set_metadata_cache
|
||||
from documents.caching import set_suggestions_cache
|
||||
from documents.classifier import load_classifier
|
||||
@ -171,11 +177,20 @@ from documents.templating.filepath import validate_filepath_template_and_render
|
||||
from documents.utils import get_boolean
|
||||
from paperless import version
|
||||
from paperless.celery import app as celery_app
|
||||
from paperless.config import AIConfig
|
||||
from paperless.config import GeneralConfig
|
||||
from paperless.db import GnuPG
|
||||
from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
from paperless.views import StandardPagination
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
from paperless_ai.indexing import update_llm_index
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||
@ -763,37 +778,103 @@ class DocumentViewSet(
|
||||
):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
|
||||
document_suggestions = get_suggestion_cache(doc.pk)
|
||||
ai_config = AIConfig()
|
||||
|
||||
if document_suggestions is not None:
|
||||
refresh_suggestions_cache(doc.pk)
|
||||
return Response(document_suggestions.suggestions)
|
||||
|
||||
classifier = load_classifier()
|
||||
|
||||
dates = []
|
||||
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
|
||||
gen = parse_date_generator(doc.filename, doc.content)
|
||||
dates = sorted(
|
||||
{i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
|
||||
if ai_config.ai_enabled:
|
||||
cached_llm_suggestions = get_llm_suggestion_cache(
|
||||
doc.pk,
|
||||
backend=ai_config.llm_backend,
|
||||
)
|
||||
|
||||
resp_data = {
|
||||
"correspondents": [
|
||||
c.id for c in match_correspondents(doc, classifier, request.user)
|
||||
],
|
||||
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
|
||||
"document_types": [
|
||||
dt.id for dt in match_document_types(doc, classifier, request.user)
|
||||
],
|
||||
"storage_paths": [
|
||||
dt.id for dt in match_storage_paths(doc, classifier, request.user)
|
||||
],
|
||||
"dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
|
||||
}
|
||||
if cached_llm_suggestions:
|
||||
refresh_suggestions_cache(doc.pk)
|
||||
return Response(cached_llm_suggestions.suggestions)
|
||||
|
||||
# Cache the suggestions and the classifier hash for later
|
||||
set_suggestions_cache(doc.pk, resp_data, classifier)
|
||||
llm_suggestions = get_ai_document_classification(doc, request.user)
|
||||
|
||||
matched_tags = match_tags_by_name(
|
||||
llm_suggestions.get("tags", []),
|
||||
request.user,
|
||||
)
|
||||
matched_correspondents = match_correspondents_by_name(
|
||||
llm_suggestions.get("correspondents", []),
|
||||
request.user,
|
||||
)
|
||||
matched_types = match_document_types_by_name(
|
||||
llm_suggestions.get("document_types", []),
|
||||
request.user,
|
||||
)
|
||||
matched_paths = match_storage_paths_by_name(
|
||||
llm_suggestions.get("storage_paths", []),
|
||||
request.user,
|
||||
)
|
||||
|
||||
resp_data = {
|
||||
"title": llm_suggestions.get("title"),
|
||||
"tags": [t.id for t in matched_tags],
|
||||
"suggested_tags": extract_unmatched_names(
|
||||
llm_suggestions.get("tags", []),
|
||||
matched_tags,
|
||||
),
|
||||
"correspondents": [c.id for c in matched_correspondents],
|
||||
"suggested_correspondents": extract_unmatched_names(
|
||||
llm_suggestions.get("correspondents", []),
|
||||
matched_correspondents,
|
||||
),
|
||||
"document_types": [d.id for d in matched_types],
|
||||
"suggested_document_types": extract_unmatched_names(
|
||||
llm_suggestions.get("document_types", []),
|
||||
matched_types,
|
||||
),
|
||||
"storage_paths": [s.id for s in matched_paths],
|
||||
"suggested_storage_paths": extract_unmatched_names(
|
||||
llm_suggestions.get("storage_paths", []),
|
||||
matched_paths,
|
||||
),
|
||||
"dates": llm_suggestions.get("dates", []),
|
||||
}
|
||||
|
||||
set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
|
||||
else:
|
||||
document_suggestions = get_suggestion_cache(doc.pk)
|
||||
|
||||
if document_suggestions is not None:
|
||||
refresh_suggestions_cache(doc.pk)
|
||||
return Response(document_suggestions.suggestions)
|
||||
|
||||
classifier = load_classifier()
|
||||
|
||||
dates = []
|
||||
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
|
||||
gen = parse_date_generator(doc.filename, doc.content)
|
||||
dates = sorted(
|
||||
{
|
||||
i
|
||||
for i in itertools.islice(
|
||||
gen,
|
||||
settings.NUMBER_OF_SUGGESTED_DATES,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
resp_data = {
|
||||
"correspondents": [
|
||||
c.id for c in match_correspondents(doc, classifier, request.user)
|
||||
],
|
||||
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
|
||||
"document_types": [
|
||||
dt.id for dt in match_document_types(doc, classifier, request.user)
|
||||
],
|
||||
"storage_paths": [
|
||||
dt.id for dt in match_storage_paths(doc, classifier, request.user)
|
||||
],
|
||||
"dates": [
|
||||
date.strftime("%Y-%m-%d") for date in dates if date is not None
|
||||
],
|
||||
}
|
||||
|
||||
# Cache the suggestions and the classifier hash for later
|
||||
set_suggestions_cache(doc.pk, resp_data, classifier)
|
||||
|
||||
return Response(resp_data)
|
||||
|
||||
@ -1093,6 +1174,52 @@ class DocumentViewSet(
|
||||
)
|
||||
|
||||
|
||||
@method_decorator(
|
||||
[
|
||||
ensure_csrf_cookie,
|
||||
login_required,
|
||||
cache_control(no_cache=True),
|
||||
],
|
||||
name="dispatch",
|
||||
)
|
||||
class ChatStreamingView(View):
|
||||
def post(self, request):
|
||||
request.compress_exempt = True
|
||||
ai_config = AIConfig()
|
||||
if not ai_config.ai_enabled:
|
||||
return HttpResponseBadRequest("AI is required for this feature")
|
||||
|
||||
try:
|
||||
data = json.loads(request.body)
|
||||
question = data["q"]
|
||||
doc_id = data.get("document_id", None)
|
||||
except (KeyError, json.JSONDecodeError):
|
||||
return HttpResponseBadRequest("Invalid request")
|
||||
|
||||
if doc_id:
|
||||
try:
|
||||
document = Document.objects.get(id=doc_id)
|
||||
except Document.DoesNotExist:
|
||||
return HttpResponseBadRequest("Document not found")
|
||||
|
||||
if not has_perms_owner_aware(request.user, "view_document", document):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
|
||||
documents = [document]
|
||||
else:
|
||||
documents = get_objects_for_user_owner_aware(
|
||||
request.user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
stream_chat_with_documents(query_str=question, documents=documents),
|
||||
content_type="text/event-stream",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
parameters=[
|
||||
@ -2208,6 +2335,10 @@ class UiSettingsView(GenericAPIView):
|
||||
|
||||
ui_settings["email_enabled"] = settings.EMAIL_ENABLED
|
||||
|
||||
ai_config = AIConfig()
|
||||
|
||||
ui_settings["ai_enabled"] = ai_config.ai_enabled
|
||||
|
||||
user_resp = {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
@ -2335,6 +2466,10 @@ class TasksViewSet(ReadOnlyModelViewSet):
|
||||
sanity_check,
|
||||
{"scheduled": False, "raise_on_error": False},
|
||||
),
|
||||
PaperlessTask.TaskName.LLMINDEX_UPDATE: (
|
||||
update_llm_index,
|
||||
{"scheduled": False, "rebuild": False},
|
||||
),
|
||||
}
|
||||
|
||||
def get_queryset(self):
|
||||
@ -2840,6 +2975,31 @@ class SystemStatusView(PassUserMixin):
|
||||
last_sanity_check.date_done if last_sanity_check else None
|
||||
)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if not ai_config.llm_index_enabled():
|
||||
llmindex_status = "DISABLED"
|
||||
llmindex_error = None
|
||||
llmindex_last_modified = None
|
||||
else:
|
||||
last_llmindex_update = (
|
||||
PaperlessTask.objects.filter(
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
.order_by("-date_done")
|
||||
.first()
|
||||
)
|
||||
llmindex_status = "OK"
|
||||
llmindex_error = None
|
||||
if last_llmindex_update is None:
|
||||
llmindex_status = "WARNING"
|
||||
llmindex_error = "No LLM index update tasks found"
|
||||
elif last_llmindex_update and last_llmindex_update.status == states.FAILURE:
|
||||
llmindex_status = "ERROR"
|
||||
llmindex_error = last_llmindex_update.result
|
||||
llmindex_last_modified = (
|
||||
last_llmindex_update.date_done if last_llmindex_update else None
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"pngx_version": current_version,
|
||||
@ -2877,6 +3037,9 @@ class SystemStatusView(PassUserMixin):
|
||||
"sanity_check_status": sanity_check_status,
|
||||
"sanity_check_last_run": sanity_check_last_run,
|
||||
"sanity_check_error": sanity_check_error,
|
||||
"llmindex_status": llmindex_status,
|
||||
"llmindex_last_modified": llmindex_last_modified,
|
||||
"llmindex_error": llmindex_error,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -169,3 +169,36 @@ class GeneralConfig(BaseConfig):
|
||||
|
||||
self.app_title = app_config.app_title or None
|
||||
self.app_logo = app_config.app_logo.url if app_config.app_logo else None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AIConfig(BaseConfig):
|
||||
"""
|
||||
AI related settings that require global scope
|
||||
"""
|
||||
|
||||
ai_enabled: bool = dataclasses.field(init=False)
|
||||
llm_embedding_backend: str = dataclasses.field(init=False)
|
||||
llm_embedding_model: str = dataclasses.field(init=False)
|
||||
llm_backend: str = dataclasses.field(init=False)
|
||||
llm_model: str = dataclasses.field(init=False)
|
||||
llm_api_key: str = dataclasses.field(init=False)
|
||||
llm_url: str = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
app_config = self._get_config_instance()
|
||||
|
||||
self.ai_enabled = app_config.ai_enabled or settings.AI_ENABLED
|
||||
self.llm_embedding_backend = (
|
||||
app_config.llm_embedding_backend or settings.LLM_EMBEDDING_BACKEND
|
||||
)
|
||||
self.llm_embedding_model = (
|
||||
app_config.llm_embedding_model or settings.LLM_EMBEDDING_MODEL
|
||||
)
|
||||
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
|
||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||
self.llm_url = app_config.llm_url or settings.LLM_URL
|
||||
|
||||
def llm_index_enabled(self) -> bool:
|
||||
return self.ai_enabled and self.llm_embedding_backend
|
||||
|
@ -0,0 +1,84 @@
|
||||
# Generated by Django 5.1.8 on 2025-04-30 02:38
|
||||
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="ai_enabled",
|
||||
field=models.BooleanField(
|
||||
default=False,
|
||||
null=True,
|
||||
verbose_name="Enables AI features",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_api_key",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
max_length=128,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM API key",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_backend",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
choices=[("openai", "OpenAI"), ("ollama", "Ollama")],
|
||||
max_length=32,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM backend",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_embedding_backend",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
choices=[("openai", "OpenAI"), ("huggingface", "Huggingface")],
|
||||
max_length=32,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM embedding backend",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_embedding_model",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
max_length=32,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM embedding model",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_model",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
max_length=32,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM model",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_url",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
max_length=128,
|
||||
null=True,
|
||||
verbose_name="Sets the LLM URL, optional",
|
||||
),
|
||||
),
|
||||
]
|
@ -74,6 +74,20 @@ class ColorConvertChoices(models.TextChoices):
|
||||
CMYK = ("CMYK", _("CMYK"))
|
||||
|
||||
|
||||
class LLMEmbeddingBackend(models.TextChoices):
|
||||
OPENAI = ("openai", _("OpenAI"))
|
||||
HUGGINGFACE = ("huggingface", _("Huggingface"))
|
||||
|
||||
|
||||
class LLMBackend(models.TextChoices):
|
||||
"""
|
||||
Matches to --llm-backend
|
||||
"""
|
||||
|
||||
OPENAI = ("openai", _("OpenAI"))
|
||||
OLLAMA = ("ollama", _("Ollama"))
|
||||
|
||||
|
||||
class ApplicationConfiguration(AbstractSingletonModel):
|
||||
"""
|
||||
Settings which are common across more than 1 parser
|
||||
@ -265,6 +279,60 @@ class ApplicationConfiguration(AbstractSingletonModel):
|
||||
null=True,
|
||||
)
|
||||
|
||||
"""
|
||||
AI related settings
|
||||
"""
|
||||
|
||||
ai_enabled = models.BooleanField(
|
||||
verbose_name=_("Enables AI features"),
|
||||
null=True,
|
||||
default=False,
|
||||
)
|
||||
|
||||
llm_embedding_backend = models.CharField(
|
||||
verbose_name=_("Sets the LLM embedding backend"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=32,
|
||||
choices=LLMEmbeddingBackend.choices,
|
||||
)
|
||||
|
||||
llm_embedding_model = models.CharField(
|
||||
verbose_name=_("Sets the LLM embedding model"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=32,
|
||||
)
|
||||
|
||||
llm_backend = models.CharField(
|
||||
verbose_name=_("Sets the LLM backend"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=32,
|
||||
choices=LLMBackend.choices,
|
||||
)
|
||||
|
||||
llm_model = models.CharField(
|
||||
verbose_name=_("Sets the LLM model"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=32,
|
||||
)
|
||||
|
||||
llm_api_key = models.CharField(
|
||||
verbose_name=_("Sets the LLM API key"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=128,
|
||||
)
|
||||
|
||||
llm_url = models.CharField(
|
||||
verbose_name=_("Sets the LLM URL, optional"),
|
||||
null=True,
|
||||
blank=True,
|
||||
max_length=128,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("paperless application settings")
|
||||
|
||||
|
@ -190,6 +190,10 @@ class ProfileSerializer(serializers.ModelSerializer):
|
||||
class ApplicationConfigurationSerializer(serializers.ModelSerializer):
|
||||
user_args = serializers.JSONField(binary=True, allow_null=True)
|
||||
barcode_tag_mapping = serializers.JSONField(binary=True, allow_null=True)
|
||||
llm_api_key = ObfuscatedPasswordField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
)
|
||||
|
||||
def run_validation(self, data):
|
||||
# Empty strings treated as None to avoid unexpected behavior
|
||||
@ -199,6 +203,11 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer):
|
||||
data["barcode_tag_mapping"] = None
|
||||
if "language" in data and data["language"] == "":
|
||||
data["language"] = None
|
||||
if "llm_api_key" in data and data["llm_api_key"] is not None:
|
||||
if data["llm_api_key"] == "":
|
||||
data["llm_api_key"] = None
|
||||
elif len(data["llm_api_key"].replace("*", "")) == 0:
|
||||
del data["llm_api_key"]
|
||||
return super().run_validation(data)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
|
@ -11,6 +11,7 @@ from typing import Final
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from celery.schedules import crontab
|
||||
from compression_middleware.middleware import CompressionMiddleware
|
||||
from concurrent_log_handler.queue import setup_logging_queues
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dotenv import load_dotenv
|
||||
@ -226,6 +227,17 @@ def _parse_beat_schedule() -> dict:
|
||||
"expires": 59.0 * 60.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "Rebuild LLM index",
|
||||
"env_key": "PAPERLESS_LLM_INDEX_TASK_CRON",
|
||||
# Default daily at 02:10
|
||||
"env_default": "10 2 * * *",
|
||||
"task": "documents.tasks.llmindex_index",
|
||||
"options": {
|
||||
# 1 hour before default schedule sends again
|
||||
"expires": 23.0 * 60.0 * 60.0,
|
||||
},
|
||||
},
|
||||
]
|
||||
for task in tasks:
|
||||
# Either get the environment setting or use the default
|
||||
@ -284,6 +296,7 @@ MODEL_FILE = __get_path(
|
||||
"PAPERLESS_MODEL_FILE",
|
||||
DATA_DIR / "classification_model.pickle",
|
||||
)
|
||||
LLM_INDEX_DIR = DATA_DIR / "llm_index"
|
||||
|
||||
LOGGING_DIR = __get_path("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
|
||||
|
||||
@ -375,6 +388,19 @@ MIDDLEWARE = [
|
||||
if __get_boolean("PAPERLESS_ENABLE_COMPRESSION", "yes"): # pragma: no cover
|
||||
MIDDLEWARE.insert(0, "compression_middleware.middleware.CompressionMiddleware")
|
||||
|
||||
# Workaround to not compress streaming responses (e.g. chat).
|
||||
# See https://github.com/friedelwolff/django-compression-middleware/pull/7
|
||||
original_process_response = CompressionMiddleware.process_response
|
||||
|
||||
|
||||
def patched_process_response(self, request, response):
|
||||
if getattr(request, "compress_exempt", False):
|
||||
return response
|
||||
return original_process_response(self, request, response)
|
||||
|
||||
|
||||
CompressionMiddleware.process_response = patched_process_response
|
||||
|
||||
ROOT_URLCONF = "paperless.urls"
|
||||
|
||||
|
||||
@ -584,6 +610,10 @@ X_FRAME_OPTIONS = "SAMEORIGIN"
|
||||
# The next 3 settings can also be set using just PAPERLESS_URL
|
||||
CSRF_TRUSTED_ORIGINS = __get_list("PAPERLESS_CSRF_TRUSTED_ORIGINS")
|
||||
|
||||
if DEBUG:
|
||||
# Allow access from the angular development server during debugging
|
||||
CSRF_TRUSTED_ORIGINS.append("http://localhost:4200")
|
||||
|
||||
# We allow CORS from localhost:8000
|
||||
CORS_ALLOWED_ORIGINS = __get_list(
|
||||
"PAPERLESS_CORS_ALLOWED_HOSTS",
|
||||
@ -594,6 +624,8 @@ if DEBUG:
|
||||
# Allow access from the angular development server during debugging
|
||||
CORS_ALLOWED_ORIGINS.append("http://localhost:4200")
|
||||
|
||||
CORS_ALLOW_CREDENTIALS = True
|
||||
|
||||
CORS_EXPOSE_HEADERS = [
|
||||
"Content-Disposition",
|
||||
]
|
||||
@ -855,6 +887,7 @@ LOGGING = {
|
||||
"loggers": {
|
||||
"paperless": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"paperless_mail": {"handlers": ["file_mail"], "level": "DEBUG"},
|
||||
"paperless_ai": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"ocrmypdf": {"handlers": ["file_paperless"], "level": "INFO"},
|
||||
"celery": {"handlers": ["file_celery"], "level": "DEBUG"},
|
||||
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
|
||||
@ -1278,3 +1311,16 @@ OUTLOOK_OAUTH_ENABLED = bool(
|
||||
and OUTLOOK_OAUTH_CLIENT_ID
|
||||
and OUTLOOK_OAUTH_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# AI Settings #
|
||||
################################################################################
|
||||
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
|
||||
LLM_EMBEDDING_BACKEND = os.getenv(
|
||||
"PAPERLESS_LLM_EMBEDDING_BACKEND",
|
||||
) # "huggingface" or "openai"
|
||||
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
|
||||
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND") # "ollama" or "openai"
|
||||
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
|
||||
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
|
||||
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
|
||||
|
@ -158,6 +158,7 @@ class TestCeleryScheduleParsing(TestCase):
|
||||
SANITY_EXPIRE_TIME = ((7.0 * 24.0) - 1.0) * 60.0 * 60.0
|
||||
EMPTY_TRASH_EXPIRE_TIME = 23.0 * 60.0 * 60.0
|
||||
RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME = 59.0 * 60.0
|
||||
LLM_INDEX_EXPIRE_TIME = 23.0 * 60.0 * 60.0
|
||||
|
||||
def test_schedule_configuration_default(self):
|
||||
"""
|
||||
@ -202,6 +203,13 @@ class TestCeleryScheduleParsing(TestCase):
|
||||
"schedule": crontab(minute="5", hour="*/1"),
|
||||
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
|
||||
},
|
||||
"Rebuild LLM index": {
|
||||
"task": "documents.tasks.llmindex_index",
|
||||
"schedule": crontab(minute=10, hour=2),
|
||||
"options": {
|
||||
"expires": self.LLM_INDEX_EXPIRE_TIME,
|
||||
},
|
||||
},
|
||||
},
|
||||
schedule,
|
||||
)
|
||||
@ -254,6 +262,13 @@ class TestCeleryScheduleParsing(TestCase):
|
||||
"schedule": crontab(minute="5", hour="*/1"),
|
||||
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
|
||||
},
|
||||
"Rebuild LLM index": {
|
||||
"task": "documents.tasks.llmindex_index",
|
||||
"schedule": crontab(minute=10, hour=2),
|
||||
"options": {
|
||||
"expires": self.LLM_INDEX_EXPIRE_TIME,
|
||||
},
|
||||
},
|
||||
},
|
||||
schedule,
|
||||
)
|
||||
@ -298,6 +313,13 @@ class TestCeleryScheduleParsing(TestCase):
|
||||
"schedule": crontab(minute="5", hour="*/1"),
|
||||
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
|
||||
},
|
||||
"Rebuild LLM index": {
|
||||
"task": "documents.tasks.llmindex_index",
|
||||
"schedule": crontab(minute=10, hour=2),
|
||||
"options": {
|
||||
"expires": self.LLM_INDEX_EXPIRE_TIME,
|
||||
},
|
||||
},
|
||||
},
|
||||
schedule,
|
||||
)
|
||||
@ -320,6 +342,7 @@ class TestCeleryScheduleParsing(TestCase):
|
||||
"PAPERLESS_INDEX_TASK_CRON": "disable",
|
||||
"PAPERLESS_EMPTY_TRASH_TASK_CRON": "disable",
|
||||
"PAPERLESS_WORKFLOW_SCHEDULED_TASK_CRON": "disable",
|
||||
"PAPERLESS_LLM_INDEX_TASK_CRON": "disable",
|
||||
},
|
||||
):
|
||||
schedule = _parse_beat_schedule()
|
||||
|
@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
|
||||
from documents.views import BulkDownloadView
|
||||
from documents.views import BulkEditObjectsView
|
||||
from documents.views import BulkEditView
|
||||
from documents.views import ChatStreamingView
|
||||
from documents.views import CorrespondentViewSet
|
||||
from documents.views import CustomFieldViewSet
|
||||
from documents.views import DocumentTypeViewSet
|
||||
@ -139,6 +140,11 @@ urlpatterns = [
|
||||
SelectionDataView.as_view(),
|
||||
name="selection_data",
|
||||
),
|
||||
re_path(
|
||||
"^chat/",
|
||||
ChatStreamingView.as_view(),
|
||||
name="chat_streaming_view",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
|
@ -35,6 +35,7 @@ from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from documents.index import DelayedQuery
|
||||
from documents.permissions import PaperlessObjectPermissions
|
||||
from documents.tasks import llmindex_index
|
||||
from paperless.filters import GroupFilterSet
|
||||
from paperless.filters import UserFilterSet
|
||||
from paperless.models import ApplicationConfiguration
|
||||
@ -43,6 +44,7 @@ from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import PaperlessAuthTokenSerializer
|
||||
from paperless.serialisers import ProfileSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
from paperless_ai.indexing import vector_store_file_exists
|
||||
|
||||
|
||||
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||
@ -353,6 +355,30 @@ class ApplicationConfigurationViewSet(ModelViewSet):
|
||||
def create(self, request, *args, **kwargs):
|
||||
return Response(status=405) # Not Allowed
|
||||
|
||||
def perform_update(self, serializer):
|
||||
old_instance = ApplicationConfiguration.objects.all().first()
|
||||
old_ai_index_enabled = (
|
||||
old_instance.ai_enabled and old_instance.llm_embedding_backend
|
||||
)
|
||||
|
||||
new_instance: ApplicationConfiguration = serializer.save()
|
||||
new_ai_index_enabled = (
|
||||
new_instance.ai_enabled and new_instance.llm_embedding_backend
|
||||
)
|
||||
|
||||
if (
|
||||
not old_ai_index_enabled
|
||||
and new_ai_index_enabled
|
||||
and not vector_store_file_exists()
|
||||
):
|
||||
# AI index was just enabled and vector store file does not exist
|
||||
llmindex_index.delay(
|
||||
progress_bar_disable=True,
|
||||
rebuild=True,
|
||||
scheduled=False,
|
||||
auto=True,
|
||||
)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
post=extend_schema(
|
||||
|
0
src/paperless_ai/__init__.py
Normal file
0
src/paperless_ai/__init__.py
Normal file
153
src/paperless_ai/ai_classifier.py
Normal file
153
src/paperless_ai/ai_classifier.py
Normal file
@ -0,0 +1,153 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from llama_index.core.base.llms.types import CompletionResponse
|
||||
|
||||
from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
|
||||
def build_prompt_without_rag(document: Document) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
|
||||
prompt = f"""
|
||||
You are an assistant that extracts structured information from documents.
|
||||
Only respond with the JSON object as described below.
|
||||
Never ask for further information, additional content or ask questions. Never include any other text.
|
||||
Suggested tags and document types must be strictly based on the content of the document.
|
||||
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
|
||||
Each field must be a list of plain strings.
|
||||
|
||||
The JSON object must contain the following fields:
|
||||
- title: A short, descriptive title
|
||||
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
|
||||
- correspondents: A list of names or organizations mentioned in the document
|
||||
- document_types: The type/category of the document (e.g. "invoice", "medical record")
|
||||
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
|
||||
- dates: List up to 3 relevant dates in YYYY-MM-DD format
|
||||
|
||||
The format of the JSON object is as follows:
|
||||
{{
|
||||
"title": "xxxxx",
|
||||
"tags": ["xxxx", "xxxx"],
|
||||
"correspondents": ["xxxx", "xxxx"],
|
||||
"document_types": ["xxxx", "xxxx"],
|
||||
"storage_paths": ["xxxx", "xxxx"],
|
||||
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
||||
}}
|
||||
---------
|
||||
|
||||
FILENAME:
|
||||
{filename}
|
||||
|
||||
CONTENT:
|
||||
{content}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
prompt = build_prompt_without_rag(document)
|
||||
|
||||
prompt += f"""
|
||||
|
||||
CONTEXT FROM SIMILAR DOCUMENTS:
|
||||
{context}
|
||||
|
||||
---------
|
||||
|
||||
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_context_for_document(
|
||||
doc: Document,
|
||||
user: User | None = None,
|
||||
max_docs: int = 5,
|
||||
) -> str:
|
||||
visible_documents = (
|
||||
get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
if user
|
||||
else None
|
||||
)
|
||||
similar_docs = query_similar_documents(
|
||||
document=doc,
|
||||
document_ids=[document.pk for document in visible_documents]
|
||||
if visible_documents
|
||||
else None,
|
||||
)[:max_docs]
|
||||
context_blocks = []
|
||||
for similar in similar_docs:
|
||||
text = similar.content[:1000] or ""
|
||||
title = similar.title or similar.filename or "Untitled"
|
||||
context_blocks.append(f"TITLE: {title}\n{text}")
|
||||
return "\n\n".join(context_blocks)
|
||||
|
||||
|
||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
||||
try:
|
||||
raw = json.loads(response.text)
|
||||
return {
|
||||
"title": raw.get("title"),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
|
||||
try:
|
||||
# search for a valid json string like { ... } in the response
|
||||
start = response.text.index("{")
|
||||
end = response.text.rindex("}") + 1
|
||||
json_str = response.text[start:end]
|
||||
raw = json.loads(json_str)
|
||||
return {
|
||||
"title": raw.get("title"),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
logger.exception("Failed to parse AI response")
|
||||
return {}
|
||||
|
||||
|
||||
def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, user)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document)
|
||||
)
|
||||
|
||||
try:
|
||||
client = AIClient()
|
||||
result = client.run_llm_query(prompt)
|
||||
return parse_ai_response(result)
|
||||
except Exception as e:
|
||||
logger.exception("Failed AI classification")
|
||||
raise e
|
77
src/paperless_ai/chat.py
Normal file
77
src/paperless_ai/chat.py
Normal file
@ -0,0 +1,77 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
|
||||
logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||
template="""Context information is below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given the context information and not prior knowledge, answer the query.
|
||||
Query: {query_str}
|
||||
Answer:""",
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [str(doc.pk) for doc in documents]
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
local_index = VectorStoreIndex(nodes=nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
)
|
||||
|
||||
if len(documents) == 1:
|
||||
# Just one doc — provide full content
|
||||
doc = documents[0]
|
||||
# TODO: include document metadata in the context
|
||||
context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
|
||||
else:
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
context = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:500]}"
|
||||
for node in top_nodes
|
||||
)
|
||||
|
||||
prompt = CHAT_PROMPT_TMPL.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
).format(llm=client.llm)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
logger.debug("Document chat prompt: %s", prompt)
|
||||
|
||||
response_stream = query_engine.query(prompt)
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
54
src/paperless_ai/client.py
Normal file
54
src/paperless_ai/client.py
Normal file
@ -0,0 +1,54 @@
|
||||
import logging
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
|
||||
|
||||
class AIClient:
|
||||
"""
|
||||
A client for interacting with an LLM backend.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = AIConfig()
|
||||
self.llm = self.get_llm()
|
||||
|
||||
def get_llm(self):
|
||||
if self.settings.llm_backend == "ollama":
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3",
|
||||
base_url=self.settings.llm_url or "http://localhost:11434",
|
||||
request_timeout=120,
|
||||
)
|
||||
elif self.settings.llm_backend == "openai":
|
||||
return OpenAI(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_key=self.settings.llm_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
|
||||
|
||||
def run_llm_query(self, prompt: str) -> str:
|
||||
logger.debug(
|
||||
"Running LLM query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
result = self.llm.complete(prompt)
|
||||
logger.debug("LLM query result: %s", result)
|
||||
return result
|
||||
|
||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||
logger.debug(
|
||||
"Running chat query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
result = self.llm.chat(messages)
|
||||
logger.debug("Chat result: %s", result)
|
||||
return result
|
69
src/paperless_ai/embedding.py
Normal file
69
src/paperless_ai/embedding.py
Normal file
@ -0,0 +1,69 @@
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import Note
|
||||
from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
|
||||
EMBEDDING_DIMENSIONS = {
|
||||
"text-embedding-3-small": 1536,
|
||||
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_model() -> BaseEmbedding:
|
||||
config = AIConfig()
|
||||
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI:
|
||||
return OpenAIEmbedding(
|
||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
)
|
||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||
return HuggingFaceEmbedding(
|
||||
model_name=config.llm_embedding_model
|
||||
or "sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding backend: {config.llm_embedding_backend}",
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_dim() -> int:
|
||||
config = AIConfig()
|
||||
model = config.llm_embedding_model or (
|
||||
"text-embedding-3-small"
|
||||
if config.llm_embedding_backend == "openai"
|
||||
else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
if model not in EMBEDDING_DIMENSIONS:
|
||||
raise ValueError(f"Unknown embedding model: {model}")
|
||||
return EMBEDDING_DIMENSIONS[model]
|
||||
|
||||
|
||||
def build_llm_index_text(doc: Document) -> str:
|
||||
lines = [
|
||||
f"Title: {doc.title}",
|
||||
f"Filename: {doc.filename}",
|
||||
f"Created: {doc.created}",
|
||||
f"Added: {doc.added}",
|
||||
f"Modified: {doc.modified}",
|
||||
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
||||
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
||||
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
||||
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
||||
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
||||
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
||||
]
|
||||
|
||||
for instance in doc.custom_fields.all():
|
||||
lines.append(f"Custom Field - {instance.field.name}: {instance}")
|
||||
|
||||
lines.append("\nContent:\n")
|
||||
lines.append(doc.content or "")
|
||||
|
||||
return "\n".join(lines)
|
281
src/paperless_ai/indexing.py
Normal file
281
src/paperless_ai/indexing.py
Normal file
@ -0,0 +1,281 @@
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import llama_index.core.settings as llama_settings
|
||||
import tqdm
|
||||
from django.conf import settings
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
logger = logging.getLogger("paperless_ai.indexing")
|
||||
|
||||
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
else:
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
return StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
vector_store=vector_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
)
|
||||
|
||||
|
||||
def build_document_node(document: Document) -> list[BaseNode]:
|
||||
"""
|
||||
Given a Document, returns parsed Nodes ready for indexing.
|
||||
"""
|
||||
text = build_llm_index_text(document)
|
||||
metadata = {
|
||||
"document_id": str(document.id),
|
||||
"title": document.title,
|
||||
"tags": [t.name for t in document.tags.all()],
|
||||
"correspondent": document.correspondent.name
|
||||
if document.correspondent
|
||||
else None,
|
||||
"document_type": document.document_type.name
|
||||
if document.document_type
|
||||
else None,
|
||||
"created": document.created.isoformat() if document.created else None,
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
"modified": document.modified.isoformat(),
|
||||
}
|
||||
doc = LlamaDocument(text=text, metadata=metadata)
|
||||
parser = SimpleNodeParser()
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(nodes=None):
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
try:
|
||||
return load_index_from_storage(storage_context=storage_context)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to load index from storage: %s", e)
|
||||
if not nodes:
|
||||
logger.info("No nodes provided for index creation.")
|
||||
raise
|
||||
return VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
|
||||
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
|
||||
"""
|
||||
Removes existing documents from docstore for a given document from the index.
|
||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||
"""
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = [
|
||||
node.node_id
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
if node.metadata.get("document_id") == str(document.id)
|
||||
]
|
||||
for node_id in existing_nodes:
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node_id)
|
||||
|
||||
|
||||
def vector_store_file_exists():
|
||||
"""
|
||||
Check if the vector store file exists in the LLM index directory.
|
||||
"""
|
||||
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||
|
||||
|
||||
def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
|
||||
"""
|
||||
Rebuild or update the LLM index.
|
||||
"""
|
||||
nodes = []
|
||||
|
||||
documents = Document.objects.all()
|
||||
if not documents.exists():
|
||||
msg = "No documents found to index."
|
||||
logger.warning(msg)
|
||||
return msg
|
||||
|
||||
if rebuild or not vector_store_file_exists():
|
||||
# Rebuild index from scratch
|
||||
logger.info("Rebuilding LLM index.")
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=True)
|
||||
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
|
||||
document_nodes = build_document_node(document)
|
||||
nodes.extend(document_nodes)
|
||||
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
show_progress=not progress_bar_disable,
|
||||
)
|
||||
msg = "LLM index rebuilt successfully."
|
||||
else:
|
||||
# Update existing index
|
||||
index = load_or_build_index()
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = {
|
||||
node.metadata.get("document_id"): node
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
}
|
||||
|
||||
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
|
||||
doc_id = str(document.id)
|
||||
document_modified = document.modified.isoformat()
|
||||
|
||||
if doc_id in existing_nodes:
|
||||
node = existing_nodes[doc_id]
|
||||
node_modified = node.metadata.get("modified")
|
||||
|
||||
if node_modified == document_modified:
|
||||
continue
|
||||
|
||||
# Again, delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node.node_id)
|
||||
nodes.extend(build_document_node(document))
|
||||
else:
|
||||
# New document, add it
|
||||
nodes.extend(build_document_node(document))
|
||||
|
||||
if nodes:
|
||||
msg = "LLM index updated successfully."
|
||||
logger.info(
|
||||
"Updating %d nodes in LLM index.",
|
||||
len(nodes),
|
||||
)
|
||||
index.insert_nodes(nodes)
|
||||
else:
|
||||
msg = "No changes detected in LLM index."
|
||||
logger.info(msg)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
return msg
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
"""
|
||||
Adds or updates a document in the LLM index.
|
||||
If the document already exists, it will be replaced.
|
||||
"""
|
||||
new_nodes = build_document_node(document)
|
||||
|
||||
index = load_or_build_index(nodes=new_nodes)
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.insert_nodes(new_nodes)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_remove_document(document: Document):
|
||||
"""
|
||||
Removes a document from the LLM index.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def truncate_content(content: str) -> str:
|
||||
prompt_helper = PromptHelper(
|
||||
context_window=8192,
|
||||
num_output=512,
|
||||
chunk_overlap_ratio=0.1,
|
||||
chunk_size_limit=None,
|
||||
)
|
||||
splitter = TokenTextSplitter(separator=" ", chunk_size=512, chunk_overlap=50)
|
||||
content_chunks = splitter.split_text(content)
|
||||
truncated_chunks = prompt_helper.truncate(
|
||||
prompt=PromptTemplate(template="{content}"),
|
||||
text_chunks=content_chunks,
|
||||
padding=5,
|
||||
)
|
||||
return " ".join(truncated_chunks)
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: list[int] | None = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
|
||||
# constrain only the node(s) that match the document IDs, if given
|
||||
doc_node_ids = (
|
||||
[
|
||||
node.node_id
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in document_ids
|
||||
]
|
||||
if document_ids
|
||||
else None
|
||||
)
|
||||
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=top_k,
|
||||
doc_ids=doc_node_ids,
|
||||
)
|
||||
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
)
|
||||
results = retriever.retrieve(query_text)
|
||||
|
||||
document_ids = [
|
||||
int(node.metadata["document_id"])
|
||||
for node in results
|
||||
if "document_id" in node.metadata
|
||||
]
|
||||
|
||||
return list(Document.objects.filter(pk__in=document_ids))
|
100
src/paperless_ai/matching.py
Normal file
100
src/paperless_ai/matching.py
Normal file
@ -0,0 +1,100 @@
|
||||
import difflib
|
||||
import logging
|
||||
import re
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
|
||||
MATCH_THRESHOLD = 0.8
|
||||
|
||||
logger = logging.getLogger("paperless_ai.matching")
|
||||
|
||||
|
||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_tag"],
|
||||
Tag,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_correspondent"],
|
||||
Correspondent,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_documenttype"],
|
||||
DocumentType,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_storagepath"],
|
||||
StoragePath,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
s = s.lower()
|
||||
s = re.sub(r"[^\w\s]", "", s) # remove punctuation
|
||||
s = s.strip()
|
||||
return s
|
||||
|
||||
|
||||
def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
results = []
|
||||
objects = list(queryset)
|
||||
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
continue
|
||||
target = _normalize(name)
|
||||
|
||||
# First try exact match
|
||||
if target in object_names:
|
||||
index = object_names.index(target)
|
||||
results.append(objects[index])
|
||||
# Remove the matched name from the list to avoid fuzzy matching later
|
||||
object_names.remove(target)
|
||||
continue
|
||||
|
||||
# Fuzzy match fallback
|
||||
matches = difflib.get_close_matches(
|
||||
target,
|
||||
object_names,
|
||||
n=1,
|
||||
cutoff=MATCH_THRESHOLD,
|
||||
)
|
||||
if matches:
|
||||
index = object_names.index(matches[0])
|
||||
results.append(objects[index])
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
def extract_unmatched_names(
|
||||
names: list[str],
|
||||
matched_objects: list,
|
||||
attr="name",
|
||||
) -> list[str]:
|
||||
matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
|
||||
return [name for name in names if name.lower() not in matched_names]
|
0
src/paperless_ai/tests/__init__.py
Normal file
0
src/paperless_ai/tests/__init__.py
Normal file
248
src/paperless_ai/tests/test_ai_classifier.py
Normal file
248
src/paperless_ai/tests/test_ai_classifier.py
Normal file
@ -0,0 +1,248 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
from paperless_ai.ai_classifier import parse_ai_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similar_documents():
|
||||
doc1 = MagicMock()
|
||||
doc1.content = "Content of document 1"
|
||||
doc1.title = "Title 1"
|
||||
doc1.filename = "file1.txt"
|
||||
|
||||
doc2 = MagicMock()
|
||||
doc2.content = "Content of document 2"
|
||||
doc2.title = None
|
||||
doc2.filename = "file2.txt"
|
||||
|
||||
doc3 = MagicMock()
|
||||
doc3.content = None
|
||||
doc3.title = None
|
||||
doc3.filename = None
|
||||
|
||||
return [doc1, doc2, doc3]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value.text = json.dumps(
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["report"]
|
||||
assert result["storage_paths"] == ["Reports"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_fallback_parse_success(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = """
|
||||
There is some text before the JSON.
|
||||
```json
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"]
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["report"]
|
||||
assert result["storage_paths"] == ["Reports"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_parse_failure(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = "Invalid JSON response"
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
||||
|
||||
# assert raises an exception
|
||||
with pytest.raises(Exception):
|
||||
get_ai_document_classification(mock_document)
|
||||
|
||||
|
||||
def test_parse_llm_classification_response_invalid_json():
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Invalid JSON response"
|
||||
|
||||
result = parse_ai_response(mock_response)
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_EMBEDDING_MODEL="some_model",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_rag_if_configured(
|
||||
mock_build_prompt_with_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_with_rag.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
|
||||
@patch("paperless.config.AIConfig")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_without_rag_if_not_configured(
|
||||
mock_ai_config,
|
||||
mock_build_prompt_without_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_ai_config.llm_embedding_backend = None
|
||||
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_without_rag.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_prompt_with_without_rag(mock_document):
|
||||
with patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
mock_query_similar_documents,
|
||||
mock_document,
|
||||
mock_similar_documents,
|
||||
):
|
||||
mock_query_similar_documents.return_value = mock_similar_documents
|
||||
|
||||
result = get_context_for_document(mock_document, max_docs=2)
|
||||
|
||||
expected_result = (
|
||||
"TITLE: Title 1\nContent of document 1\n\n"
|
||||
"TITLE: file2.txt\nContent of document 2"
|
||||
)
|
||||
assert result == expected_result
|
||||
mock_query_similar_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
295
src/paperless_ai/tests/test_ai_indexing.py
Normal file
295
src/paperless_ai/tests/test_ai_indexing.py
Normal file
@ -0,0 +1,295 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai import indexing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path):
|
||||
original_dir = indexing.settings.LLM_INDEX_DIR
|
||||
indexing.settings.LLM_INDEX_DIR = tmp_path
|
||||
yield tmp_path
|
||||
indexing.settings.LLM_INDEX_DIR = original_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_document(db):
|
||||
return Document.objects.create(
|
||||
title="Test Document",
|
||||
content="This is some test content.",
|
||||
added=timezone.now(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
with patch("paperless_ai.indexing.get_embedding_model") as mock:
|
||||
mock.return_value = FakeEmbedding()
|
||||
yield mock
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document):
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_partial_update(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
doc2 = Document.objects.create(
|
||||
title="Test Document 2",
|
||||
content="This is some test content 2.",
|
||||
added=timezone.now(),
|
||||
checksum="1234567890abcdef",
|
||||
)
|
||||
# Initial index
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document, doc2])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# modify document
|
||||
updated_document = real_document
|
||||
updated_document.modified = timezone.now() # simulate modification
|
||||
|
||||
# new doc
|
||||
doc3 = Document.objects.create(
|
||||
title="Test Document 3",
|
||||
content="This is some test content 3.",
|
||||
added=timezone.now(),
|
||||
checksum="abcdef1234567890",
|
||||
)
|
||||
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Updating %d nodes in LLM index.",
|
||||
2,
|
||||
)
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_get_or_create_storage_context_raises_exception(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_succeeds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[MagicMock()],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_add_or_update_document_updates_existing_entry(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_remove_document_deletes_node_from_docstore(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_no_documents(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = False
|
||||
mock_queryset.__iter__.return_value = iter([])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# check log message
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"No documents found to index.",
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
):
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
mock_node1 = MagicMock()
|
||||
mock_node1.metadata = {"document_id": 1}
|
||||
|
||||
mock_node2 = MagicMock()
|
||||
mock_node2.metadata = {"document_id": 2}
|
||||
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
|
||||
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
|
||||
mock_filter.return_value = mock_filtered_docs
|
||||
|
||||
result = indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
mock_load_or_build_index.assert_called_once()
|
||||
mock_retriever_cls.assert_called_once()
|
||||
mock_retriever.retrieve.assert_called_once_with(
|
||||
"Test Document\nThis is some test content.",
|
||||
)
|
||||
mock_filter.assert_called_once_with(pk__in=[1, 2])
|
||||
|
||||
assert result == mock_filtered_docs
|
142
src/paperless_ai/tests/test_chat.py
Normal file
142
src/paperless_ai/tests/test_chat.py
Normal file
@ -0,0 +1,142 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_model():
|
||||
from llama_index.core import settings as llama_settings
|
||||
|
||||
mock_embed_model = MagicMock()
|
||||
mock_embed_model._get_text_embedding_batch.return_value = [
|
||||
[0.1] * 1536,
|
||||
] # 1 vector per input
|
||||
llama_settings.Settings._embed_model = mock_embed_model
|
||||
yield
|
||||
llama_settings.Settings._embed_model = None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_nodes():
|
||||
with patch(
|
||||
"llama_index.core.indices.vector_store.base.embed_nodes",
|
||||
) as mock_embed_nodes:
|
||||
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
|
||||
node.node_id: [0.1] * 1536 for node in nodes
|
||||
}
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock()
|
||||
doc.pk = 1
|
||||
doc.title = "Test Document"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.content = "This is the document content."
|
||||
return doc
|
||||
|
||||
|
||||
def test_stream_chat_with_one_document_full_content(mock_document):
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
|
||||
|
||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
)
|
||||
mock_node2 = TextNode(
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
mock_as_retriever.return_value = mock_retriever
|
||||
|
||||
# Mock response stream
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1)
|
||||
doc2 = MagicMock(pk=2)
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes():
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
94
src/paperless_ai/tests/test_client.py
Normal file
94
src/paperless_ai/tests/test_client.py
Normal file
@ -0,0 +1,94 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage
|
||||
|
||||
from paperless_ai.client import AIClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||
mock_config = MagicMock()
|
||||
MockAIConfig.return_value = mock_config
|
||||
yield mock_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_llm():
|
||||
with patch("paperless_ai.client.Ollama") as MockOllama:
|
||||
yield MockOllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_llm():
|
||||
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
|
||||
yield MockOpenAI
|
||||
|
||||
|
||||
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_url = "http://test-url"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
mock_ollama_llm.assert_called_once_with(
|
||||
model="test_model",
|
||||
base_url="http://test-url",
|
||||
request_timeout=120,
|
||||
)
|
||||
assert client.llm == mock_ollama_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
mock_ai_config.llm_backend = "openai"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
mock_openai_llm.assert_called_once_with(
|
||||
model="test_model",
|
||||
api_key="test_api_key",
|
||||
)
|
||||
assert client.llm == mock_openai_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "unsupported"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_url = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.complete.return_value = "test_result"
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query("test_prompt")
|
||||
|
||||
mock_llm_instance.complete.assert_called_once_with("test_prompt")
|
||||
assert result == "test_result"
|
||||
|
||||
|
||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_url = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.chat.return_value = "test_chat_result"
|
||||
|
||||
client = AIClient()
|
||||
messages = [ChatMessage(role="user", content="Hello")]
|
||||
result = client.run_chat(messages)
|
||||
|
||||
mock_llm_instance.chat.assert_called_once_with(messages)
|
||||
assert result == "test_chat_result"
|
133
src/paperless_ai/tests/test_embedding.py
Normal file
133
src/paperless_ai/tests/test_embedding.py
Normal file
@ -0,0 +1,133 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
def test_get_embedding_model_openai(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
|
||||
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
)
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||
mock_ai_config.return_value.llm_embedding_model = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
assert model == MockHuggingFaceEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Unsupported embedding backend: INVALID_BACKEND",
|
||||
):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_dim_openai(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
assert get_embedding_dim() == 1536
|
||||
|
||||
|
||||
def test_get_embedding_dim_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "huggingface"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
assert get_embedding_dim() == 384
|
||||
|
||||
|
||||
def test_get_embedding_dim_unknown_model(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = "unknown-model"
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
|
||||
get_embedding_dim()
|
||||
|
||||
|
||||
def test_build_llm_index_text(mock_document):
|
||||
with patch("documents.models.Note.objects.filter") as mock_notes_filter:
|
||||
mock_notes_filter.return_value = [
|
||||
MagicMock(note="Note1"),
|
||||
MagicMock(note="Note2"),
|
||||
]
|
||||
|
||||
result = build_llm_index_text(mock_document)
|
||||
|
||||
assert "Title: Test Title" in result
|
||||
assert "Filename: test_file.pdf" in result
|
||||
assert "Created: 2023-01-01" in result
|
||||
assert "Tags: Tag1, Tag2" in result
|
||||
assert "Document Type: Invoice" in result
|
||||
assert "Correspondent: Test Correspondent" in result
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
86
src/paperless_ai/tests/test_matching.py
Normal file
86
src/paperless_ai/tests/test_matching.py
Normal file
@ -0,0 +1,86 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
|
||||
|
||||
class TestAIMatching(TestCase):
|
||||
def setUp(self):
|
||||
# Create test data for Tag
|
||||
self.tag1 = Tag.objects.create(name="Test Tag 1")
|
||||
self.tag2 = Tag.objects.create(name="Test Tag 2")
|
||||
|
||||
# Create test data for Correspondent
|
||||
self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
|
||||
self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")
|
||||
|
||||
# Create test data for DocumentType
|
||||
self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
|
||||
self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")
|
||||
|
||||
# Create test data for StoragePath
|
||||
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
|
||||
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_correspondents_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Correspondent.objects.all()
|
||||
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
|
||||
result = match_correspondents_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Correspondent 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_document_types_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = DocumentType.objects.all()
|
||||
names = ["Test Document Type 1", "Nonexistent Document Type"]
|
||||
result = match_document_types_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Document Type 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_storage_paths_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = StoragePath.objects.all()
|
||||
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
|
||||
result = match_storage_paths_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Storage Path 1")
|
||||
|
||||
def test_extract_unmatched_names(self):
|
||||
llm_names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
matched_objects = [self.tag1]
|
||||
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
|
||||
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = [None, "", " "]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Taag 1", "Teest Tag 2"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
self.assertEqual(result[1].name, "Test Tag 2")
|
Loading…
x
Reference in New Issue
Block a user