// File-viewer metadata (not part of the module): 257 lines, 11 KiB, TypeScript.
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
|
import { UploadedImage } from '../types';
|
|
import { segmentSubject } from '../services/geminiService';
|
|
import { ArrowLeftIcon } from './icons/ArrowLeftIcon';
|
|
import { ArrowRightIcon } from './icons/ArrowRightIcon';
|
|
import { BrushIcon } from './icons/BrushIcon';
|
|
import { EraserIcon } from './icons/EraserIcon';
|
|
import { useLogger } from '../contexts/LoggingContext';
|
|
|
|
/** Props accepted by the `ImageSegmenter` component. */
interface ImageSegmenterProps {
  /** Images to segment; used as the initial value of the component's working copy. */
  images: UploadedImage[];
  /** Invoked with the component's working copy of the images when the user proceeds. */
  onComplete: (images: UploadedImage[]) => void;
  /** Invoked when the user presses the Back button. */
  onBack: () => void;
}
|
|
|
|
/** Active mask-editing tool: `brush` paints white into the mask, `eraser` paints black. */
type EditorMode = 'brush' | 'eraser';
|
|
|
|
/** Mouse or touch event raised by the editor canvas. */
type CanvasPointerEvent =
  | React.MouseEvent<HTMLCanvasElement>
  | React.TouchEvent<HTMLCanvasElement>;

/** The image/mask pair currently loaded into the canvas editor. */
interface EditorSurface {
  /** Index into the component's image list that this surface was loaded for. */
  index: number;
  /** Preview URL the `image` element was loaded from. */
  previewUrl: string;
  /** Mask data URL (as held in React state) that `maskCanvas` corresponds to. */
  committedMaskUrl: string;
  /** Decoded original image. */
  image: HTMLImageElement;
  /** Offscreen working copy of the mask; brush strokes are painted directly onto it. */
  maskCanvas: HTMLCanvasElement;
}

/**
 * Loads `src` into a fresh image element.
 * Handlers are attached before `src` is assigned so a load event cannot be missed.
 */
const loadImageElement = (src: string): Promise<HTMLImageElement> =>
  new Promise<HTMLImageElement>((resolve, reject) => {
    const img = new Image();
    img.onload = () => resolve(img);
    img.onerror = () => reject(new Error('The browser could not load or decode the image.'));
    img.src = src;
  });

/**
 * Requests a segmentation mask for every uploaded image, shows the results as
 * thumbnails, and lets the user refine the active image's mask with a brush or
 * eraser on a canvas before continuing.
 *
 * @param images     Initial images; copied into internal state once on mount.
 * @param onComplete Called with the internal images (including masks) on "Next".
 * @param onBack     Called when the Back button is pressed.
 */
const ImageSegmenter: React.FC<ImageSegmenterProps> = ({ images, onComplete, onBack }) => {
  const [internalImages, setInternalImages] = useState<UploadedImage[]>(images);
  const [loadingStates, setLoadingStates] = useState<Record<number, boolean>>({});
  const [errorStates, setErrorStates] = useState<Record<number, string | null>>({});
  const [activeIndex, setActiveIndex] = useState<number>(0);

  const [mode, setMode] = useState<EditorMode>('brush');
  const [brushSize, setBrushSize] = useState<number>(20);

  const canvasRef = useRef<HTMLCanvasElement>(null);
  // What the editor canvas is currently showing; null while nothing is editable.
  const surfaceRef = useRef<EditorSurface | null>(null);
  // Scratch canvas for the dimming overlay, reused between redraws.
  const overlayRef = useRef<HTMLCanvasElement | null>(null);
  const isDrawing = useRef<boolean>(false);
  // True when the mask canvas holds paint that has not been written to state yet.
  const strokeDirty = useRef<boolean>(false);
  // Indices with a segmentation request in flight. A ref (not state) so that a
  // second effect run on the same state snapshot cannot issue a duplicate request.
  const inFlightMasks = useRef<Set<number>>(new Set());
  const { log } = useLogger();

  /** Requests a mask for the image at `index` and stores it as a PNG data URL. */
  const generateMask = useCallback(async (index: number) => {
    const target = internalImages[index];
    if (!target || inFlightMasks.current.has(index)) return;
    inFlightMasks.current.add(index);

    setLoadingStates(prev => ({ ...prev, [index]: true }));
    setErrorStates(prev => ({ ...prev, [index]: null }));
    log('info', `Requesting segmentation mask for image ${index + 1} ("${target.subjectDescription}").`);
    try {
      const maskBase64 = await segmentSubject(target.file, target.subjectDescription);
      const maskDataUrl = `data:image/png;base64,${maskBase64}`;
      // Replace the entry instead of mutating it: the objects in `prev` are
      // shared with the previous state snapshot.
      setInternalImages(prev =>
        prev.map((img, i) => (i === index ? { ...img, maskDataUrl } : img)),
      );
      log('success', `Successfully received segmentation mask for image ${index + 1}.`);
    } catch (e) {
      const errorMessage = e instanceof Error ? e.message : 'Mask generation failed';

      // Map known failure messages to a short label that fits in the thumbnail.
      let displayError = "Failed";
      if (errorMessage.includes("The AI returned a message")) {
        displayError = "AI Response Error";
      } else if (errorMessage.includes("No segmentation mask")) {
        displayError = "No Mask Found";
      }

      setErrorStates(prev => ({ ...prev, [index]: displayError }));
      log('error', `Failed to generate mask for image ${index + 1}: ${errorMessage}`);
    } finally {
      inFlightMasks.current.delete(index);
      setLoadingStates(prev => ({ ...prev, [index]: false }));
    }
  }, [internalImages, log]);

  // Kick off segmentation for every image that has no mask, is not already
  // loading, and has not failed.
  useEffect(() => {
    internalImages.forEach((image, index) => {
      if (!image.maskDataUrl && !loadingStates[index] && !errorStates[index]) {
        void generateMask(index);
      }
    });
  }, [internalImages, generateMask, loadingStates, errorStates]);

  /** Repaints the editor canvas: the original image dimmed everywhere the mask does not cover. */
  const draw = useCallback(() => {
    const canvas = canvasRef.current;
    const surface = surfaceRef.current;
    if (!canvas || !surface) return;

    const ctx = canvas.getContext('2d');
    if (!ctx) return;

    const { image, maskCanvas } = surface;
    const { naturalWidth: w, naturalHeight: h } = image;
    // drawImage throws on a zero-sized canvas source, so bail out on empty inputs.
    if (w === 0 || h === 0 || maskCanvas.width === 0 || maskCanvas.height === 0) {
      return;
    }

    if (canvas.width !== w) canvas.width = w;
    if (canvas.height !== h) canvas.height = h;

    ctx.clearRect(0, 0, w, h);

    // Build the overlay on a separate canvas so the main canvas's state is untouched.
    let overlay = overlayRef.current;
    if (!overlay) {
      overlay = document.createElement('canvas');
      overlayRef.current = overlay;
    }
    if (overlay.width !== w) overlay.width = w;
    if (overlay.height !== h) overlay.height = h;
    const overlayCtx = overlay.getContext('2d');
    if (!overlayCtx) return;

    // Start from a uniform semi-transparent black layer.
    overlayCtx.globalCompositeOperation = 'source-over';
    overlayCtx.clearRect(0, 0, w, h);
    overlayCtx.fillStyle = 'rgba(0, 0, 0, 0.6)';
    overlayCtx.fillRect(0, 0, w, h);

    // Punch a hole in the overlay wherever the mask has opaque pixels.
    // NOTE(review): 'destination-out' keys off the mask's alpha, not its colour,
    // so opaque black pixels (including eraser strokes) also punch a hole.
    // Confirm the mask format returned by segmentSubject before changing this.
    overlayCtx.globalCompositeOperation = 'destination-out';
    overlayCtx.drawImage(maskCanvas, 0, 0, w, h);

    // Original image first, then the punched-out overlay on top.
    ctx.drawImage(image, 0, 0, w, h);
    ctx.drawImage(overlay, 0, 0);
  }, []);

  const activeImage = internalImages[activeIndex];
  const activePreviewUrl = activeImage?.previewUrl;
  const activeMaskUrl = activeImage?.maskDataUrl;

  // Load the active image and its mask into the editor. Depending on the two URLs
  // (rather than the whole image array) keeps unrelated updates, such as another
  // image's mask arriving, from reloading the surface in the middle of a stroke.
  useEffect(() => {
    if (!activePreviewUrl || !activeMaskUrl) {
      // Nothing editable: drop the previous surface so strokes cannot land on,
      // or be committed for, a different image.
      surfaceRef.current = null;
      strokeDirty.current = false;
      const canvas = canvasRef.current;
      const ctx = canvas ? canvas.getContext('2d') : null;
      if (canvas && ctx) ctx.clearRect(0, 0, canvas.width, canvas.height);
      return;
    }

    const current = surfaceRef.current;
    if (
      current &&
      current.index === activeIndex &&
      current.previewUrl === activePreviewUrl &&
      current.committedMaskUrl === activeMaskUrl
    ) {
      // State already matches the working canvas (e.g. a stroke was just committed).
      return;
    }

    let cancelled = false;
    Promise.all([loadImageElement(activePreviewUrl), loadImageElement(activeMaskUrl)])
      .then(([image, maskImage]) => {
        if (cancelled) return;

        const maskCanvas = document.createElement('canvas');
        maskCanvas.width = maskImage.naturalWidth;
        maskCanvas.height = maskImage.naturalHeight;
        const maskCtx = maskCanvas.getContext('2d');
        if (!maskCtx) return;
        maskCtx.drawImage(maskImage, 0, 0);

        surfaceRef.current = {
          index: activeIndex,
          previewUrl: activePreviewUrl,
          committedMaskUrl: activeMaskUrl,
          image,
          maskCanvas,
        };
        strokeDirty.current = false;
        draw();
      })
      .catch((err: unknown) => {
        if (cancelled) return;
        const reason = err instanceof Error ? err.message : String(err);
        console.error("Error loading images for canvas: ", err);
        log('error', `Canvas Error: Failed to load images for editor view. ${reason}`);
      });

    // Ignore results of a load that was superseded or outlived the component.
    return () => {
      cancelled = true;
    };
  }, [activeIndex, activePreviewUrl, activeMaskUrl, draw, log]);

  /** Paints one brush/eraser dab onto the working mask at the event position and redraws. */
  const paintAt = (e: CanvasPointerEvent) => {
    const canvas = canvasRef.current;
    const surface = surfaceRef.current;
    // Refuse to paint while the surface still belongs to a previously active image.
    if (!canvas || !surface || surface.index !== activeIndex) return;

    // `touches` is empty on touchend, so fall back to `changedTouches`.
    const point = 'touches' in e ? (e.touches[0] ?? e.changedTouches[0]) : e;
    if (!point) return;

    const maskCtx = surface.maskCanvas.getContext('2d');
    if (!maskCtx) return;

    const rect = canvas.getBoundingClientRect();
    if (rect.width === 0 || rect.height === 0) return;

    // Convert from displayed (CSS) pixels to mask pixels.
    const scaleX = surface.maskCanvas.width / rect.width;
    const scaleY = surface.maskCanvas.height / rect.height;
    const x = (point.clientX - rect.left) * scaleX;
    const y = (point.clientY - rect.top) * scaleY;

    maskCtx.fillStyle = mode === 'brush' ? '#FFFFFF' : '#000000';
    maskCtx.beginPath();
    maskCtx.arc(x, y, (brushSize / 2) * scaleX, 0, 2 * Math.PI);
    maskCtx.fill();

    strokeDirty.current = true;
    draw();
  };

  /** Writes the working mask back to React state once per stroke. */
  const commitStroke = () => {
    const surface = surfaceRef.current;
    if (!surface || !strokeDirty.current) return;
    strokeDirty.current = false;

    const newMaskUrl = surface.maskCanvas.toDataURL();
    // Record the URL first so the load effect sees the surface as up to date.
    surface.committedMaskUrl = newMaskUrl;
    const targetIndex = surface.index;
    setInternalImages(prev =>
      prev.map((img, i) => (i === targetIndex ? { ...img, maskDataUrl: newMaskUrl } : img)),
    );
  };

  /** Pointer moved over the canvas: paint only while a stroke is in progress. */
  const handleCanvasInteraction = (e: CanvasPointerEvent) => {
    if (!isDrawing.current) return;
    paintAt(e);
  };

  /** Pointer pressed: begin a stroke and paint immediately so a plain click leaves a dab. */
  const startDrawing = (e: CanvasPointerEvent) => {
    isDrawing.current = true;
    paintAt(e);
  };

  /** Pointer released or left the canvas: end the stroke and persist it. */
  const stopDrawing = () => {
    isDrawing.current = false;
    commitStroke();
  };

  const canProceed = internalImages.every(img => img.maskDataUrl);

  return (
    <div className="w-full">
      <div className="text-center mb-6">
        <h2 className="text-2xl font-bold text-gray-100">Review & Refine Subjects</h2>
        <p className="text-gray-400">The AI has extracted the subjects. Use the tools to refine the selection if needed.</p>
      </div>

      <div className="flex flex-col lg:flex-row gap-8">
        {/* Thumbnails */}
        <div className="lg:w-1/4 flex lg:flex-col gap-2 overflow-x-auto lg:overflow-y-auto lg:max-h-[500px] p-2 bg-gray-900/50 rounded-lg">
          {internalImages.map((image, index) => (
            <button
              key={index}
              type="button"
              onClick={() => setActiveIndex(index)}
              className={`rounded-lg border-2 transition-all p-1 flex-shrink-0 ${activeIndex === index ? 'border-purple-500' : 'border-transparent hover:border-gray-600'}`}
            >
              <div className="relative w-24 h-24">
                {loadingStates[index] && (
                  <div className="absolute inset-0 bg-black/70 flex items-center justify-center rounded-md">
                    <div className="animate-spin rounded-full h-8 w-8 border-b-2 border-purple-400"></div>
                  </div>
                )}
                {errorStates[index] && (
                  <div className="absolute inset-0 bg-red-900/80 text-white text-xs text-center flex items-center justify-center p-1 rounded-md">
                    {errorStates[index]}
                  </div>
                )}
                {image.maskDataUrl && (
                  <img src={image.maskDataUrl} alt={`mask preview ${index}`} className="w-full h-full object-contain rounded-md bg-black" />
                )}
                {!image.maskDataUrl && !loadingStates[index] && !errorStates[index] && (
                  <div className="w-full h-full bg-gray-700 rounded-md flex items-center justify-center text-xs text-gray-400">Waiting...</div>
                )}
              </div>
            </button>
          ))}
        </div>

        {/* Editor */}
        <div className="lg:w-3/4 flex flex-col items-center">
          <div className="w-full flex justify-center items-center mb-4 p-2 bg-gray-700/50 rounded-lg">
            <div className="flex items-center gap-4">
              <button
                type="button"
                aria-label="Brush"
                aria-pressed={mode === 'brush'}
                onClick={() => setMode('brush')}
                className={`p-2 rounded-md transition-colors ${mode === 'brush' ? 'bg-purple-600' : 'bg-gray-600 hover:bg-gray-500'}`}
              >
                <BrushIcon className="w-6 h-6"/>
              </button>
              <button
                type="button"
                aria-label="Eraser"
                aria-pressed={mode === 'eraser'}
                onClick={() => setMode('eraser')}
                className={`p-2 rounded-md transition-colors ${mode === 'eraser' ? 'bg-purple-600' : 'bg-gray-600 hover:bg-gray-500'}`}
              >
                <EraserIcon className="w-6 h-6"/>
              </button>
              <div className="flex items-center gap-2">
                <label htmlFor="brushSize" className="text-sm">Size:</label>
                <input
                  type="range"
                  id="brushSize"
                  min="2"
                  max="100"
                  value={brushSize}
                  onChange={e => setBrushSize(Number(e.target.value))}
                  className="w-32 cursor-pointer"
                />
              </div>
            </div>
          </div>
          <canvas
            ref={canvasRef}
            className="rounded-lg max-w-full h-auto"
            // Keep touch strokes from scrolling or zooming the page while painting.
            style={{ touchAction: 'none' }}
            onMouseDown={startDrawing}
            onMouseUp={stopDrawing}
            onMouseMove={handleCanvasInteraction}
            onMouseLeave={stopDrawing}
            onTouchStart={startDrawing}
            onTouchEnd={stopDrawing}
            onTouchMove={handleCanvasInteraction}
          />
        </div>
      </div>

      <div className="flex flex-col sm:flex-row justify-between items-center gap-4 pt-8 mt-6 border-t border-gray-700">
        <button
          type="button"
          onClick={onBack}
          className="bg-gray-600 hover:bg-gray-500 text-white font-bold py-3 px-6 rounded-lg transition-colors duration-300 flex items-center gap-2 w-full sm:w-auto justify-center"
        >
          <ArrowLeftIcon className="w-5 h-5"/> Back
        </button>
        <button
          type="button"
          onClick={() => onComplete(internalImages)}
          disabled={!canProceed}
          className="bg-purple-600 hover:bg-purple-700 disabled:bg-gray-600 disabled:cursor-not-allowed text-white font-bold py-3 px-8 rounded-lg transition-all duration-300 text-lg flex items-center gap-2 mx-auto w-full sm:w-auto justify-center"
        >
          Next: Customize Prompt <ArrowRightIcon className="w-5 h-5" />
        </button>
      </div>
    </div>
  );
};
|
|
|
|
export default ImageSegmenter;