Catalog/imagecatalog/Services/AiExtractionService.cs
MaddoScientisto f57dc1edba
Some checks failed
Build Windows Avalonia / build (push) Failing after 1m38s
Build Windows Avalonia / release (push) Has been skipped
feat: Enhance AI extraction summaries and worker allocation for GPU support
2026-05-09 19:31:21 +02:00

294 lines
11 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using AIFotoONLUS.Core;
using ImageCatalog_2.Models;
using Microsoft.Extensions.Logging;
namespace ImageCatalog_2.Services;
/// <summary>
/// Runs number-recognition OCR over the images under a folder, streaming each result and a
/// progress update to the caller, and optionally writing the results to a CSV file.
/// </summary>
public class AiExtractionService : IAiExtractionService
{
    // File extensions (compared case-insensitively) that are treated as images.
    private static readonly HashSet<string> ImageExtensions = new(StringComparer.OrdinalIgnoreCase)
    {
        ".jpg", ".jpeg", ".png", ".bmp", ".gif"
    };

    private readonly ILogger<AiExtractionService> _logger;

    public AiExtractionService(ILogger<AiExtractionService> logger)
    {
        _logger = logger;
    }

    /// <summary>
    /// Scans <c>request.SearchRoot</c> for images and runs OCR on each of them.
    /// </summary>
    /// <param name="request">
    /// Search root, recursion and thumbnail options (files starting with <c>tn_</c> are skipped unless
    /// <c>IncludeThumbnails</c> is set), models folder, GPU flag, workload level and optional CSV path.
    /// </param>
    /// <param name="token">Cancels file dispatch, OCR processing and result reporting.</param>
    /// <param name="onResult">
    /// Invoked once per processed image, sequentially, from a background reporter task
    /// (not on the caller's synchronization context).
    /// </param>
    /// <param name="onProgress">
    /// Invoked after each result from the same reporter task; invoked once with 100% when no images are found.
    /// </param>
    /// <returns>Totals, failure count, average throughput and the effective workload/worker settings.</returns>
    /// <exception cref="InvalidOperationException">
    /// The models folder is not configured, or OCR failed for every image (CPU path only).
    /// </exception>
    /// <exception cref="DirectoryNotFoundException">The models folder does not exist.</exception>
    public async Task<AiExtractionRunSummary> RunAsync(
        AiExtractionRequest request,
        CancellationToken token,
        Func<AiResultItem, Task> onResult,
        Func<AiExtractionProgressUpdate, Task> onProgress)
    {
        var imageFiles = EnumerateImageFiles(request);
        var extractedResults = new List<AiResultItem>();
        var modelConfiguration = BuildModelConfiguration(request.ModelsFolderPath, request.UseGpu);
        var workloadLevel = NormalizeWorkloadLevel(request.WorkloadLevel);
        var workerCount = ResolveWorkerCount(request.UseGpu, workloadLevel);
        var total = imageFiles.Count;
        if (total == 0)
        {
            var emptySummary = new AiExtractionRunSummary(0, 0, 0, 0, workloadLevel, workerCount, request.UseGpu);
            await onProgress(new AiExtractionProgressUpdate(0, 0, 100, 0, workloadLevel, workerCount, request.UseGpu)).ConfigureAwait(false);
            return emptySummary;
        }

        // processed and lastLoggedElapsed are only touched by the single reporter task; their final
        // values are read after that task has been awaited, so no extra synchronization is needed.
        var processed = 0;
        var lastLoggedElapsed = TimeSpan.Zero;
        // failed and firstFailure are written by several CPU workers concurrently.
        var failed = 0;
        Exception? firstFailure = null;
        var failureLock = new object();
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();

        var resultChannel = Channel.CreateUnbounded<AiResultItem>(new UnboundedChannelOptions
        {
            SingleReader = true,
            SingleWriter = false
        });
        var fileChannel = Channel.CreateBounded<string>(new BoundedChannelOptions(Math.Max(workerCount * 2, 1))
        {
            SingleReader = false,
            SingleWriter = true,
            FullMode = BoundedChannelFullMode.Wait
        });

        // Single consumer of the result channel: records results and forwards them and progress to the caller.
        var reporterTask = Task.Run(async () =>
        {
            await foreach (var result in resultChannel.Reader.ReadAllAsync(token).ConfigureAwait(false))
            {
                extractedResults.Add(result);
                await onResult(result).ConfigureAwait(false);
                processed++;
                var averageImagesPerSecond = CalculateAverageImagesPerSecond(processed, stopwatch.Elapsed);
                var percent = processed * 100.0 / total;
                await onProgress(new AiExtractionProgressUpdate(total, processed, percent, averageImagesPerSecond, workloadLevel, workerCount, request.UseGpu)).ConfigureAwait(false);

                // Throttle progress logging to roughly one line every two seconds, plus the final item.
                var now = stopwatch.Elapsed;
                if (processed == total || now - lastLoggedElapsed >= TimeSpan.FromSeconds(2))
                {
                    lastLoggedElapsed = now;
                    _logger.LogInformation(
                        "Number AI progress: {Processed}/{Total} ({Percent:F1}%), {ImagesPerSecond:F2} img/s avg, workload {WorkloadLevel} ({WorkerCount} {ExecutionUnit})",
                        processed,
                        total,
                        percent,
                        averageImagesPerSecond,
                        workloadLevel,
                        workerCount,
                        request.UseGpu ? "batch" : "workers");
                }
            }
        }, token);

        try
        {
            if (request.UseGpu)
            {
                // A single engine processes the whole list; workerCount is passed as its degree of parallelism.
                // NOTE(review): failures are never counted on this path, so FailedFiles is always 0 and the
                // "all images failed" check below cannot trigger — confirm how ProcessFilesAsync reports errors.
                using var engine = new NumberRecognitionEngine(modelConfiguration, _logger);
                var resultProgress = new SynchronousProgress<ImageResult>(result =>
                {
                    resultChannel.Writer.TryWrite(new AiResultItem { Path = result.FilePath, Text = result.Text });
                });
                await engine.ProcessFilesAsync(
                    imageFiles,
                    skipTextNegative: false,
                    maxDegreeOfParallelism: workerCount,
                    progress: null,
                    resultProgress: resultProgress,
                    cancellationToken: token).ConfigureAwait(false);
            }
            else
            {
                // One engine per worker; workers pull paths from the bounded file channel.
                var workerTasks = Enumerable.Range(0, workerCount)
                    .Select(_ => Task.Run(async () =>
                    {
                        try
                        {
                            using var engine = new NumberRecognitionEngine(modelConfiguration, _logger);
                            await foreach (var file in fileChannel.Reader.ReadAllAsync(token).ConfigureAwait(false))
                            {
                                var extracted = string.Empty;
                                try
                                {
                                    extracted = engine.ProcessImage(file).Text;
                                }
                                catch (Exception ex)
                                {
                                    lock (failureLock)
                                    {
                                        failed++;
                                        firstFailure ??= ex;
                                    }
                                    _logger.LogWarning(ex, "Error processing AI OCR for {File}", file);
                                }
                                // Failed images are still reported (with empty text) so progress reaches 100%.
                                await resultChannel.Writer.WriteAsync(new AiResultItem { Path = file, Text = extracted }, token).ConfigureAwait(false);
                            }
                        }
                        catch
                        {
                            // This worker can no longer drain the bounded file channel. Close it so the producer
                            // loop below stops instead of waiting forever for free capacity (e.g. when every
                            // worker fails to construct its engine because a model file is missing).
                            fileChannel.Writer.TryComplete();
                            throw;
                        }
                    }, token))
                    .ToArray();

                foreach (var file in imageFiles)
                {
                    try
                    {
                        await fileChannel.Writer.WriteAsync(file, token).ConfigureAwait(false);
                    }
                    catch (ChannelClosedException)
                    {
                        // A worker failed and closed the channel; its exception is rethrown by WhenAll below.
                        break;
                    }
                }
                fileChannel.Writer.TryComplete();
                await Task.WhenAll(workerTasks).ConfigureAwait(false);
            }
        }
        finally
        {
            fileChannel.Writer.TryComplete();
            resultChannel.Writer.TryComplete();
            await reporterTask.ConfigureAwait(false);
        }

        if (failed == total)
        {
            throw new InvalidOperationException($"AI OCR failed for all {imageFiles.Count} image(s). See previous log entries for details.", firstFailure);
        }

        var summary = new AiExtractionRunSummary(
            total,
            processed,
            failed,
            CalculateAverageImagesPerSecond(processed, stopwatch.Elapsed),
            workloadLevel,
            workerCount,
            request.UseGpu);
        _logger.LogInformation(
            "Number AI completed: {Processed}/{Total} processed, {Failed} failures, {ImagesPerSecond:F2} img/s avg, workload {WorkloadLevel} ({WorkerCount} {ExecutionUnit})",
            summary.ProcessedFiles,
            summary.TotalFiles,
            summary.FailedFiles,
            summary.AverageImagesPerSecond,
            summary.WorkloadLevel,
            summary.WorkerCount,
            request.UseGpu ? "batch" : "workers");

        if (!string.IsNullOrWhiteSpace(request.CsvOutputPath))
        {
            TryWriteCsv(request.CsvOutputPath, extractedResults);
        }
        return summary;
    }

    /// <summary>
    /// Lists image files under <c>request.SearchRoot</c> by extension, optionally recursing and
    /// excluding <c>tn_</c>-prefixed thumbnails.
    /// </summary>
    private static List<string> EnumerateImageFiles(AiExtractionRequest request)
    {
        var options = new EnumerationOptions
        {
            RecurseSubdirectories = request.Recursive,
            // Skip unreadable subfolders instead of aborting the whole scan with UnauthorizedAccessException.
            IgnoreInaccessible = true,
            // Keep hidden/system files, as the SearchOption-based overload does.
            AttributesToSkip = 0
        };
        return Directory.EnumerateFiles(request.SearchRoot, "*", options)
            .Where(f => ImageExtensions.Contains(Path.GetExtension(f)))
            .Where(f => request.IncludeThumbnails || !Path.GetFileName(f).StartsWith("tn_", StringComparison.OrdinalIgnoreCase))
            .ToList();
    }

    /// <summary>
    /// Writes a UTF-8 "Path,Text" CSV, one row per result, creating the target directory if needed.
    /// Errors are logged rather than thrown so a bad CSV path does not discard a completed run.
    /// </summary>
    private void TryWriteCsv(string csvOutputPath, IEnumerable<AiResultItem> results)
    {
        try
        {
            var dir = Path.GetDirectoryName(csvOutputPath) ?? string.Empty;
            if (!string.IsNullOrWhiteSpace(dir) && !Directory.Exists(dir))
            {
                Directory.CreateDirectory(dir);
            }
            using var sw = new StreamWriter(csvOutputPath, false, Encoding.UTF8);
            sw.WriteLine("Path,Text");
            foreach (var r in results)
            {
                // The "Path" column holds only the file name, not the full path.
                var csvFileName = Path.GetFileName(r.Path ?? string.Empty);
                sw.WriteLine($"{EscapeCsvField(csvFileName)},{EscapeCsvField(r.Text)}");
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to write CSV to {CsvOutputPath}", csvOutputPath);
        }
    }

    // Quotes a CSV field and doubles any embedded quotes (RFC 4180); null becomes an empty field.
    private static string EscapeCsvField(string? value)
    {
        return "\"" + (value ?? string.Empty).Replace("\"", "\"\"") + "\"";
    }

    // Images per second over the elapsed time; 0 when no time has elapsed.
    private static double CalculateAverageImagesPerSecond(int processed, TimeSpan elapsed)
    {
        return elapsed.TotalSeconds > 0 ? processed / elapsed.TotalSeconds : 0;
    }

    // Clamps the requested workload level into the supported 1..5 range.
    private static int NormalizeWorkloadLevel(int workloadLevel)
    {
        return Math.Clamp(workloadLevel, 1, 5);
    }

    /// <summary>
    /// Maps a workload level to a worker count: on GPU it is the engine's parallelism (4..32);
    /// on CPU it is the number of worker tasks (1..5), capped at the processor count.
    /// </summary>
    private static int ResolveWorkerCount(bool useGpu, int workloadLevel)
    {
        var normalized = NormalizeWorkloadLevel(workloadLevel);
        var maxWorkers = Math.Max(1, Environment.ProcessorCount);
        var requestedWorkers = useGpu
            ? normalized switch
            {
                1 => 4,
                2 => 8,
                3 => 16,
                4 => 24,
                _ => 32
            }
            : normalized switch
            {
                1 => 1,
                2 => 2,
                3 => 3,
                4 => 4,
                _ => 5
            };
        return useGpu ? requestedWorkers : Math.Min(requestedWorkers, maxWorkers);
    }

    // IProgress<T> that invokes the handler inline on the reporting thread
    // (unlike Progress<T>, which posts to a captured synchronization context).
    private sealed class SynchronousProgress<T> : IProgress<T>
    {
        private readonly Action<T> _handler;

        public SynchronousProgress(Action<T> handler)
        {
            _handler = handler;
        }

        public void Report(T value) => _handler(value);
    }

    /// <summary>
    /// Builds the model configuration from a models folder (surrounding whitespace and quotes are trimmed).
    /// Only the folder's existence is checked here; the individual .cfg/.weights files are not.
    /// </summary>
    /// <exception cref="InvalidOperationException">The folder path is empty.</exception>
    /// <exception cref="DirectoryNotFoundException">The folder does not exist.</exception>
    private static ModelConfiguration BuildModelConfiguration(string modelsFolderPath, bool useGpu)
    {
        if (string.IsNullOrWhiteSpace(modelsFolderPath))
        {
            throw new InvalidOperationException("AI models folder is not configured.");
        }
        var modelsRoot = Path.GetFullPath(modelsFolderPath.Trim().Trim('"'));
        if (!Directory.Exists(modelsRoot))
        {
            throw new DirectoryNotFoundException($"AI models folder not found: {modelsRoot}");
        }
        return new ModelConfiguration
        {
            DetectionCfg = Path.Combine(modelsRoot, "detection.cfg"),
            DetectionWeights = Path.Combine(modelsRoot, "detection.weights"),
            RecognitionCfg = Path.Combine(modelsRoot, "recognition.cfg"),
            RecognitionWeights = Path.Combine(modelsRoot, "recognition.weights"),
            UseGpu = useGpu
        };
    }
}