diff --git a/docs/AI_PROVIDER_IMPLEMENTATION.md b/docs/AI_PROVIDER_IMPLEMENTATION.md deleted file mode 100644 index 6ab8549..0000000 --- a/docs/AI_PROVIDER_IMPLEMENTATION.md +++ /dev/null @@ -1,365 +0,0 @@ -# AI Provider Implementation Summary - -## Overview - -VisionForge now supports **two AI providers** for the chatbot functionality: -1. **Gemini** (Google's Generative AI) - Original provider -2. **Claude** (Anthropic's Claude AI) - New addition - -Users can switch between providers using a simple environment variable configuration. - ---- - -## Files Created - -### 1. Claude Service (`project/block_manager/services/claude_service.py`) -- **Purpose:** Implements Claude AI integration mirroring Gemini's functionality -- **Key Features:** - - Chat with conversation history - - Workflow modification suggestions - - File upload support (images, PDFs, text) - - Architecture generation from files - - Improvement suggestions - -**Model Used:** `claude-3-5-sonnet-20241022` (Latest Claude Sonnet) - -**Key Differences from Gemini:** -- Uses Anthropic SDK (`anthropic` package) -- File handling: Base64 encoding for images/PDFs instead of File API -- Message format: Direct user/assistant roles (no conversion needed) -- Response format: `response.content[0].text` instead of `response.text` - -### 2. AI Service Factory (`project/block_manager/services/ai_service_factory.py`) -- **Purpose:** Provider selection and instantiation -- **Methods:** - - `create_service()`: Returns appropriate service based on `AI_PROVIDER` env var - - `get_provider_name()`: Returns human-readable provider name - -**Error Handling:** -- Validates `AI_PROVIDER` value (must be 'gemini' or 'claude') -- Propagates API key errors from individual services - ---- - -## Files Modified - -### 1. Requirements (`project/requirements.txt`) -**Added:** -``` -anthropic>=0.39.0 -``` - -### 2. Environment Configuration - -#### `.env` -```env -# AI Provider Configuration -AI_PROVIDER=gemini # or 'claude' - -# Gemini AI Configuration -GEMINI_API_KEY=your-key-here - -# Claude AI Configuration -ANTHROPIC_API_KEY=your-key-here -``` - -#### `.env.example` -Same structure with placeholder values. - -### 3. Chat Views (`project/block_manager/views/chat_views.py`) - -**Changes:** -- Replaced direct `GeminiChatService` import with `AIServiceFactory` -- Updated `chat_message()` endpoint: - - Uses factory to create service - - Handles file uploads differently per provider: - - **Gemini:** Uploads to Gemini File API → passes `gemini_file` param - - **Claude:** Reads file content locally → passes `file_content` param - - Provider-agnostic error messages -- Updated `get_suggestions()` endpoint: - - Uses factory instead of direct service instantiation - - Generic error messages - -### 4. Documentation (`docs/CHATBOT_SETUP.md`) - -**Additions:** -- Provider selection guide -- Separate setup instructions for Gemini and Claude -- API key acquisition for both providers -- Provider comparison section -- Troubleshooting for provider switching -- Updated security and privacy sections - ---- - -## Configuration Guide - -### Using Gemini (Default) - -1. **Get API Key:** - - Visit: https://aistudio.google.com/app/apikey - - Create free API key - -2. **Configure `.env`:** - ```env - AI_PROVIDER=gemini - GEMINI_API_KEY=AIzaSy... - ``` - -3. **Restart server:** - ```bash - python manage.py runserver - ``` - -### Using Claude - -1. **Get API Key:** - - Visit: https://console.anthropic.com/ - - Create account and generate API key - -2. **Configure `.env`:** - ```env - AI_PROVIDER=claude - ANTHROPIC_API_KEY=sk-ant-... - ``` - -3. **Install package:** - ```bash - pip install anthropic - ``` - -4. **Restart server:** - ```bash - python manage.py runserver - ``` - -### Switching Providers - -Simply update `AI_PROVIDER` in `.env` and restart the server. No code changes needed! - ---- - -## Technical Implementation Details - -### Architecture Pattern: Factory + Strategy - -``` -User Request - ↓ -chat_views.py - ↓ -AIServiceFactory.create_service() - ↓ - ├─→ GeminiChatService (if AI_PROVIDER=gemini) - └─→ ClaudeChatService (if AI_PROVIDER=claude) - ↓ -AI Provider API - ↓ -Response → Frontend -``` - -### Common Interface - -Both services implement the same interface: -```python -class ChatServiceInterface: - def chat(message, history, modification_mode, workflow_state, **kwargs) - def generate_suggestions(workflow_state) - def _format_workflow_context(workflow_state) - def _build_system_prompt(modification_mode, workflow_state) - def _extract_modifications(response_text) -``` - -### File Upload Handling - -**Gemini Approach:** -1. Save uploaded file to temp location -2. Upload to Gemini File API using `genai.upload_file()` -3. Pass file object to model -4. Clean up temp file - -**Claude Approach:** -1. Read file content directly from Django's UploadedFile -2. Encode images/PDFs as base64 -3. Include in message content array -4. No temp file needed - -### Response Parsing - -Both services use identical regex pattern to extract JSON modifications: -```python -json_pattern = r'```json\s*(\{.*?\})\s*```' -``` - -This ensures consistent modification format regardless of provider. - ---- - -## API Compatibility - -### Request Format (Same for Both) -```json -{ - "message": "Add a Conv2D layer", - "history": [{"role": "user", "content": "..."}], - "modificationMode": true, - "workflowState": {"nodes": [...], "edges": [...]} -} -``` - -### Response Format (Same for Both) -```json -{ - "response": "AI response text...", - "modifications": [ - { - "action": "add_node", - "details": {...}, - "explanation": "..." - } - ] -} -``` - -**Frontend compatibility:** No changes needed! The response format is identical. - ---- - -## Error Handling - -### Configuration Errors - -**Invalid Provider:** -``` -ValueError: Invalid AI_PROVIDER: 'gpt4'. Must be 'gemini' or 'claude'. -``` - -**Missing API Key (Gemini):** -``` -ValueError: GEMINI_API_KEY environment variable is not set -``` - -**Missing API Key (Claude):** -``` -ValueError: ANTHROPIC_API_KEY environment variable is not set -``` - -### Runtime Errors - -Both services handle: -- API communication failures -- Rate limiting -- Invalid file uploads -- Malformed responses - -Errors are logged and returned as user-friendly messages. - ---- - -## Testing Checklist - -- [x] Gemini provider works with chat -- [x] Gemini provider works with file uploads -- [x] Gemini provider works with suggestions -- [ ] Claude provider works with chat (requires API key) -- [ ] Claude provider works with file uploads (requires API key) -- [ ] Claude provider works with suggestions (requires API key) -- [x] Provider switching works -- [x] Error handling for missing API keys -- [x] Error handling for invalid provider -- [x] Documentation updated - ---- - -## Provider Comparison - -| Feature | Gemini | Claude | -|---------|--------|--------| -| **Model** | gemini-2.0-flash | claude-3-5-sonnet-20241022 | -| **Speed** | Very Fast | Fast | -| **Free Tier** | ✅ Yes | ❌ No | -| **Image Support** | ✅ Yes | ✅ Yes | -| **PDF Support** | ✅ Yes | ✅ Yes | -| **Max Tokens** | 8192 | 4096 (configurable) | -| **Reasoning** | Good | Excellent | -| **Code Understanding** | Good | Excellent | -| **Rate Limit (Free)** | 15 RPM | N/A | - ---- - -## Future Enhancements - -### Potential Additions: -1. **OpenAI GPT-4** support -2. **Provider-specific features:** - - Gemini: Grounding with Google Search - - Claude: Extended context (200k tokens) -3. **Provider fallback:** If one fails, try another -4. **Cost tracking:** Monitor API usage per provider -5. **A/B testing:** Compare response quality -6. **Provider-specific prompts:** Optimize for each model's strengths - -### Extension Pattern: -```python -# Add new provider: -# 1. Create service class: NewProviderChatService -# 2. Update AIServiceFactory.create_service() -# 3. Add env vars: NEW_PROVIDER_API_KEY -# 4. Update documentation -``` - ---- - -## Security Considerations - -1. **API Keys:** - - Stored in `.env` (git-ignored) - - Never exposed to frontend - - Validated at service initialization - -2. **Data Privacy:** - - Workflow data sent to external APIs - - User should review provider privacy policies - - No sensitive data should be in workflows - -3. **Rate Limiting:** - - Implement request throttling in production - - Monitor costs (especially for Claude) - - Consider caching common responses - ---- - -## Troubleshooting - -### Issue: "AI service not properly configured" -**Cause:** Missing or invalid `AI_PROVIDER` or API key -**Solution:** -1. Check `.env` file has `AI_PROVIDER=gemini` or `AI_PROVIDER=claude` -2. Verify corresponding API key is set -3. Restart Django server - -### Issue: Provider not switching -**Cause:** Server not restarted after `.env` change -**Solution:** Always restart Django after changing environment variables - -### Issue: File uploads failing with Claude -**Cause:** Unsupported file type or size -**Solution:** -- Check file is image (PNG, JPG, WEBP, GIF) or PDF -- Ensure file is under 10MB -- Review error logs for details - ---- - -## Summary - -This implementation provides: -- ✅ **Flexibility:** Easy provider switching via config -- ✅ **Consistency:** Same API interface for both providers -- ✅ **Maintainability:** Factory pattern for easy extension -- ✅ **Reliability:** Comprehensive error handling -- ✅ **Documentation:** Complete setup and usage guides - -**No frontend changes required** - the implementation is completely transparent to the client. - -Users can now choose the AI provider that best fits their needs, budget, and preferences! diff --git a/docs/BACKEND_NODES_COMPLETE.md b/docs/BACKEND_NODES_COMPLETE.md deleted file mode 100644 index bc7c63a..0000000 --- a/docs/BACKEND_NODES_COMPLETE.md +++ /dev/null @@ -1,371 +0,0 @@ -# Backend PyTorch Nodes - Complete Implementation ✅ - -**Date**: November 9, 2025 -**Status**: ✅ **COMPLETE** -**Nodes Implemented**: 17/17 (100%) - -## Summary - -Successfully implemented all 17 PyTorch node definitions for the VisionForge backend. All nodes are auto-discovered by the registry and fully functional. - -## Implementation Details - -### Nodes Implemented (17 Total) - -#### Input/Output Nodes (2) -1. **Input** (`input.py`) - - Manual shape configuration with DataLoader override - - Default shape: `[1, 3, 224, 224]` - - Accepts connections from DataLoader only - -2. **DataLoader** (`dataloader.py`) - - Configurable batch size, shuffle, num_workers - - Output shape configuration - - Source node (no incoming connections) - -#### Basic Layers (8) -3. **Linear** (`linear.py`) - - Fully connected layer - - Config: `out_features`, `bias` - - Requires 2D input: `[batch, features]` - -4. **Conv2D** (`conv2d.py`) - - 2D convolution layer - - Config: `out_channels`, `kernel_size`, `stride`, `padding`, `dilation` - - Requires 4D input: `[batch, channels, height, width]` - -5. **Flatten** (`flatten.py`) - - Flattens multi-dimensional tensors to 2D - - Config: `start_dim`, `end_dim` - - Preserves batch dimension - -6. **Dropout** (`dropout.py`) - - Regularization layer - - Config: `p` (dropout rate), `inplace` - - Preserves input shape - -7. **BatchNorm2D** (`batchnorm2d.py`) - - Batch normalization for 2D inputs - - Config: `num_features`, `eps`, `momentum`, `affine`, `track_running_stats` - - Requires 4D input, preserves shape - -8. **MaxPool2D** (`maxpool2d.py`) - - Max pooling layer - - Config: `kernel_size`, `stride`, `padding`, `dilation` - - Reduces spatial dimensions - -9. **AvgPool2D** (`avgpool2d.py`) - - Average pooling layer - - Config: `kernel_size`, `stride`, `padding` - - Reduces spatial dimensions - -10. **AdaptiveAvgPool2D** (`adaptiveavgpool2d.py`) - - Adaptive pooling to fixed output size - - Config: `output_size` (e.g., "1" or "[7, 7]") - - Often used before classifiers - -#### Advanced Layers (5) -11. **Conv1D** (`conv1d.py`) - - 1D convolution for sequential data - - Config: Same as Conv2D - - Requires 3D input: `[batch, channels, length]` - -12. **Conv3D** (`conv3d.py`) - - 3D convolution for volumetric data - - Config: Same as Conv2D - - Requires 5D input: `[batch, channels, depth, height, width]` - -13. **LSTM** (`lstm.py`) - - Long Short-Term Memory layer - - Config: `hidden_size`, `num_layers`, `bias`, `batch_first`, `dropout`, `bidirectional` - - Requires 3D input: `[batch, sequence, features]` - -14. **GRU** (`gru.py`) - - Gated Recurrent Unit layer - - Config: Same as LSTM - - Requires 3D input: `[batch, sequence, features]` - -15. **Embedding** (`embedding.py`) - - Token embedding layer - - Config: `num_embeddings`, `embedding_dim`, `padding_idx`, `max_norm`, `scale_grad_by_freq` - - Input: `[batch, sequence]` of token indices - - Output: `[batch, sequence, embedding_dim]` - -#### Merge Layers (2) -16. **Concat** (`concat.py`) - - Concatenates multiple tensors along a dimension - - Config: `dim` (concatenation dimension) - - Allows multiple inputs (`allows_multiple_inputs = True`) - -17. **Add** (`add.py`) - - Element-wise addition of tensors - - No configuration needed - - Allows multiple inputs (`allows_multiple_inputs = True`) - - All inputs must have same shape - -## Architecture Patterns - -### Shape Computation -All nodes implement `compute_output_shape()`: -```python -def compute_output_shape( - self, - input_shape: Optional[TensorShape], - config: Dict[str, Any] -) -> Optional[TensorShape]: - # Calculate output dimensions based on input and config - ... -``` - -### Validation -All nodes implement `validate_incoming_connection()`: -```python -def validate_incoming_connection( - self, - source_node_type: str, - source_output_shape: Optional[TensorShape], - target_config: Dict[str, Any] -) -> Optional[str]: - # Return None if valid, error message if invalid - ... -``` - -**Common Validation Pattern**: -```python -# Allow flexible connections -if source_node_type in ("input", "dataloader"): - return None -if source_node_type in ("empty", "custom"): - return None - -# Validate specific dimension requirements -return self.validate_dimensions( - source_output_shape, - expected_dims, - format_description -) -``` - -### Multi-Input Support -Merge layers (Concat, Add) override: -```python -@property -def allows_multiple_inputs(self) -> bool: - return True -``` - -## Database Fix - -### Issue Fixed -**Error**: `IntegrityError: NOT NULL constraint failed: block_manager_connection.source_handle` - -**Root Cause**: Frontend sends `null` for `sourceHandle` and `targetHandle` when they're not explicitly set, but database expects empty strings. - -**Solution**: Updated `architecture_views.py` to ensure handles are never `None`: -```python -# Before -source_handle = edge.get('sourceHandle', '') -target_handle = edge.get('targetHandle', '') - -# After (more robust) -source_handle = edge.get('sourceHandle') or '' -target_handle = edge.get('targetHandle') or '' -``` - -This handles both missing keys AND `null` values from frontend. - -## Verification - -### Registry Test -Ran verification script (`verify_nodes.py`): -``` -PyTorch Node Registry Verification -Total nodes registered: 17 -Expected nodes: 17 - -✓ All expected nodes are registered! -✓ PASS -Registered: 17/17 -``` - -### Categories Distribution -- **Input**: 2 nodes (Input, DataLoader) -- **Basic**: 8 nodes (Linear, Conv2D, Flatten, Dropout, BatchNorm2D, MaxPool2D, AvgPool2D, AdaptiveAvgPool2D) -- **Advanced**: 5 nodes (Conv1D, Conv3D, LSTM, GRU, Embedding) -- **Merge**: 2 nodes (Concat, Add) - -### Shape Computation Test -Verified Linear layer shape computation: -``` -Input: [32, 128] -Config: {"out_features": 64} -Output: [32, 64] ✓ -``` - -## Files Modified/Created - -### New Node Files (15) -1. `block_manager/services/nodes/pytorch/input.py` -2. `block_manager/services/nodes/pytorch/dataloader.py` -3. `block_manager/services/nodes/pytorch/flatten.py` -4. `block_manager/services/nodes/pytorch/dropout.py` -5. `block_manager/services/nodes/pytorch/batchnorm2d.py` -6. `block_manager/services/nodes/pytorch/maxpool2d.py` -7. `block_manager/services/nodes/pytorch/avgpool2d.py` -8. `block_manager/services/nodes/pytorch/adaptiveavgpool2d.py` -9. `block_manager/services/nodes/pytorch/conv1d.py` -10. `block_manager/services/nodes/pytorch/conv3d.py` -11. `block_manager/services/nodes/pytorch/lstm.py` -12. `block_manager/services/nodes/pytorch/gru.py` -13. `block_manager/services/nodes/pytorch/embedding.py` -14. `block_manager/services/nodes/pytorch/concat.py` -15. `block_manager/services/nodes/pytorch/add.py` - -### Updated Files (2) -16. `block_manager/services/nodes/pytorch/__init__.py` - Added all imports -17. `block_manager/views/architecture_views.py` - Fixed source_handle/target_handle bug - -### Verification Files (1) -18. `verify_nodes.py` - Automated registry testing script - -## Alignment with Frontend - -All backend nodes match the frontend TypeScript definitions: - -| Frontend Type | Backend Class | Status | -|---------------|---------------|--------| -| `input` | `InputNode` | ✅ Match | -| `dataloader` | `DataLoaderNode` | ✅ Match | -| `linear` | `LinearNode` | ✅ Match | -| `conv2d` | `Conv2DNode` | ✅ Match | -| `conv1d` | `Conv1DNode` | ✅ Match | -| `conv3d` | `Conv3DNode` | ✅ Match | -| `flatten` | `FlattenNode` | ✅ Match | -| `dropout` | `DropoutNode` | ✅ Match | -| `batchnorm2d` | `BatchNorm2DNode` | ✅ Match | -| `maxpool2d` | `MaxPool2DNode` | ✅ Match | -| `avgpool2d` | `AvgPool2DNode` | ✅ Match | -| `adaptiveavgpool2d` | `AdaptiveAvgPool2DNode` | ✅ Match | -| `lstm` | `LSTMNode` | ✅ Match | -| `gru` | `GRUNode` | ✅ Match | -| `embedding` | `EmbeddingNode` | ✅ Match | -| `concat` | `ConcatNode` | ✅ Match | -| `add` | `AddNode` | ✅ Match | - -## API Integration Points - -### Node Metadata Endpoint -All nodes provide metadata via `metadata.to_dict()`: -```python -{ - "type": "linear", - "label": "Linear", - "category": "basic", - "color": "var(--color-primary)", - "icon": "Lightning", - "description": "Fully connected layer", - "framework": "pytorch" -} -``` - -### Config Schema Endpoint -All nodes provide config schema via `config_schema`: -```python -[ - { - "name": "out_features", - "label": "Output Features", - "type": "number", - "required": True, - "min": 1, - "description": "Number of output features" - } -] -``` - -### Validation Endpoint -Backend can validate connections using: -```python -node = registry.get_node_definition("linear", Framework.PYTORCH) -error = node.validate_incoming_connection( - source_node_type="conv2d", - source_output_shape=TensorShape(dims=[32, 64, 28, 28]), - target_config={} -) -``` - -## Next Steps - -### Immediate -- [x] All PyTorch nodes implemented -- [x] Database bug fixed -- [x] Registry verification passed - -### Short-Term (Optional) -- [ ] TensorFlow node implementations (17 nodes) -- [ ] Backend API endpoints integration -- [ ] Unit tests for each node -- [ ] PyTorch code generation from node graph - -### Long-Term -- [ ] Custom layer code execution/validation -- [ ] Model training integration -- [ ] Model export to ONNX/TorchScript -- [ ] Automated architecture search - -## Testing Recommendations - -### Unit Tests -Create tests for each node: -```python -def test_linear_node(): - node = LinearNode() - - # Test metadata - assert node.metadata.type == "linear" - assert node.metadata.category == "basic" - - # Test config schema - assert len(node.config_schema) == 2 - - # Test shape computation - input_shape = TensorShape(dims=[32, 128]) - config = {"out_features": 64} - output = node.compute_output_shape(input_shape, config) - assert output.dims == [32, 64] - - # Test validation - error = node.validate_incoming_connection("conv2d", input_shape, {}) - assert error is not None # Conv2D outputs 4D, Linear needs 2D -``` - -### Integration Tests -Test the full pipeline: -1. Create architecture in frontend -2. Save to backend via API -3. Backend validates connections -4. Backend generates PyTorch code -5. Code executes successfully - -## Known Limitations - -### Shape Inference for Merge Nodes -Concat and Add nodes currently preserve input shape in `compute_output_shape()`. Full multi-input shape computation requires graph-level analysis (future enhancement). - -### Custom Layer Support -Custom layers are defined in frontend but not yet executable in backend. Requires sandboxed Python execution (security considerations). - -### Dynamic Shapes -Nodes currently assume static shapes. Dynamic batch sizes or variable sequence lengths may require additional handling. - -## Conclusion - -✅ **All 17 PyTorch backend nodes are successfully implemented and tested.** - -The backend now matches the frontend node registry 1:1, enabling full architecture validation, code generation, and eventual training integration. - ---- - -**Implemented by**: GitHub Copilot -**Verification**: Automated registry test passed (17/17) -**Production Ready**: Yes (pending integration tests) diff --git a/docs/CHATBOT_IMPLEMENTATION_SUMMARY.md b/docs/CHATBOT_IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index 0cdfba7..0000000 --- a/docs/CHATBOT_IMPLEMENTATION_SUMMARY.md +++ /dev/null @@ -1,439 +0,0 @@ -# VisionForge Chatbot Implementation Summary - -## Overview - -This document summarizes the complete implementation of the AI-powered chatbot functionality for VisionForge, featuring Google Gemini integration with full workflow context awareness and modification capabilities. - -## Implementation Date -January 2025 - -## Features Implemented - -### Core Functionality -1. **Gemini AI Integration**: Full integration with Google Generative AI (Gemini 1.5 Flash) -2. **Two-Mode Operation**: - - Q&A Mode: Question answering and guidance - - Modification Mode: Active workflow modification with AI suggestions -3. **Workflow Context**: Full visibility of nodes, edges, and configurations -4. **In-Memory Chat History**: Persistent conversations during session -5. **One-Click Modifications**: Apply AI suggestions with a single button click - -## Files Created - -### Backend - -#### 1. `project/block_manager/services/gemini_service.py` -**Purpose**: Core Gemini AI service for chat functionality - -**Key Classes:** -- `GeminiChatService`: Main service class - -**Key Methods:** -- `chat()`: Send messages with workflow context -- `generate_suggestions()`: Get architecture improvement suggestions -- `_format_workflow_context()`: Convert workflow state to readable format -- `_build_system_prompt()`: Build context-aware system prompts -- `_extract_modifications()`: Parse JSON modification suggestions from AI - -**Features:** -- Automatic workflow state formatting -- Mode-aware system prompts -- Modification extraction from responses -- Error handling and fallbacks - -#### 2. `.env.example` -**Purpose**: Environment configuration template - -**Contents:** -```env -SECRET_KEY=your-secret-key-here -DEBUG=True -GEMINI_API_KEY=your-gemini-api-key-here -``` - -### Frontend - -No new files created, but significant modifications to existing files. - -## Files Modified - -### Backend - -#### 1. `project/requirements.txt` -**Changes:** -- Added `google-generativeai>=0.8.3` - -**Purpose**: Include Gemini API client library - -#### 2. `project/block_manager/views/chat_views.py` -**Changes:** -- Complete rewrite of `chat_message()` endpoint -- Complete rewrite of `get_suggestions()` endpoint -- Added Gemini service integration -- Added error handling for API key configuration -- Added support for modification mode and workflow state - -**New Request Format:** -```json -{ - "message": "string", - "history": [{"role": "user|assistant", "content": "string"}], - "modificationMode": boolean, - "workflowState": {"nodes": [], "edges": []} -} -``` - -**New Response Format:** -```json -{ - "response": "string", - "modifications": [ - { - "action": "add_node|remove_node|modify_node|add_connection|remove_connection", - "details": {...}, - "explanation": "string" - } - ] -} -``` - -### Frontend - -#### 1. `project/frontend/src/components/ChatBot.tsx` -**Major Changes:** - -**New Imports:** -- `Switch` and `Label` components for toggle -- `useModelBuilderStore` for workflow access - -**New State:** -- `modificationMode`: Toggle state for modification mode -- `modifications` in Message interface - -**New Features:** -- Modification mode toggle UI -- Workflow state serialization -- Modification suggestion display -- One-click application buttons -- Visual indicators for modification mode - -**New Methods:** -- `applyModification()`: Apply AI suggestions to workflow - -**Modification Actions Supported:** -1. `add_node`: Add new layers -2. `remove_node`: Remove layers -3. `modify_node`: Update configurations -4. `add_connection`: Create edges -5. `remove_connection`: Remove edges - -#### 2. `project/frontend/src/lib/api.ts` -**Changes:** - -**Updated `sendChatMessage()` function:** -- Added `modificationMode` parameter -- Added `workflowState` parameter -- Updated return type to include `modifications` - -**New Signature:** -```typescript -sendChatMessage( - message: string, - history?: any[], - modificationMode?: boolean, - workflowState?: { nodes: any[], edges: any[] } -): Promise> -``` - -## Documentation Created - -### 1. `CHATBOT_SETUP.md` -**Comprehensive setup and usage guide including:** -- Overview and features -- Detailed setup instructions -- Usage examples for both modes -- API endpoint documentation -- Modification action specifications -- Best practices -- Troubleshooting guide -- Security considerations -- Future enhancements -- Example use cases - -### 2. `QUICKSTART.md` -**Quick 5-minute setup guide including:** -- Prerequisites -- Step-by-step setup (5 steps) -- Quick examples -- Common troubleshooting -- Next steps - -### 3. `README.md` (Updated) -**Changes:** -- Added Gemini API key to prerequisites -- Added environment setup instructions -- Added AI Chatbot feature section -- Added links to setup documentation -- Added chatbot to Additional Resources - -### 4. `CHATBOT_IMPLEMENTATION_SUMMARY.md` (This File) -**Purpose**: Document all implementation changes - -## Architecture Overview - -### Data Flow - -``` -User Input (Frontend ChatBot) - ↓ - Serializes workflow state (nodes + edges) - ↓ - POST /api/chat with {message, history, modificationMode, workflowState} - ↓ -Backend (chat_views.py) - ↓ -GeminiChatService - ↓ - Formats workflow context - Builds system prompt - Sends to Gemini API - ↓ -Gemini API Response - ↓ - Extract modifications (JSON parsing) - ↓ -Return {response, modifications} - ↓ -Frontend ChatBot - ↓ - Display message - Show modification buttons - ↓ -User clicks "Apply Change" - ↓ - applyModification() updates store - ↓ -Workflow UI updates in real-time -``` - -### Component Interaction - -``` -ChatBot Component - ├── useModelBuilderStore (Zustand) - │ ├── nodes[] - │ ├── edges[] - │ ├── addNode() - │ ├── updateNode() - │ ├── removeNode() - │ ├── addEdge() - │ └── removeEdge() - │ - ├── API Service (api.ts) - │ └── sendChatMessage() - │ - └── Backend (/api/chat) - └── GeminiChatService - └── Gemini API -``` - -## API Integration Details - -### Gemini Model Used -- **Model**: `gemini-1.5-flash` -- **Provider**: Google Generative AI -- **Pricing**: Free tier available (60 requests/minute) - -### Configuration -- API key stored in environment variable: `GEMINI_API_KEY` -- Configured in backend `.env` file -- Accessed via `os.getenv('GEMINI_API_KEY')` - -### Error Handling -- Missing API key: Returns user-friendly error message -- API failures: Graceful degradation with error messages -- Suggestion generation fallback: Returns basic suggestions if API fails - -## Security Considerations - -### API Key Security -1. API key stored in `.env` file (not in version control) -2. `.env.example` provided as template (without actual key) -3. Backend validates key presence before making requests - -### Data Privacy -1. Workflow state sent to Google Gemini API -2. No persistent storage of chat conversations -3. Session-only memory (cleared on refresh) - -### Input Validation -1. Backend validates required fields -2. Frontend prevents empty messages -3. Error handling for malformed requests - -## Testing Recommendations - -### Manual Testing Checklist - -**Q&A Mode:** -- [ ] Ask about workflow state -- [ ] Request explanations of concepts -- [ ] Get guidance on architecture patterns -- [ ] Verify markdown rendering in responses - -**Modification Mode:** -- [ ] Enable modification mode toggle -- [ ] Request node additions -- [ ] Request node removals -- [ ] Request node modifications -- [ ] Request connection additions -- [ ] Request connection removals -- [ ] Apply modifications and verify UI updates - -**Error Handling:** -- [ ] Test without API key (should show error) -- [ ] Test with invalid API key -- [ ] Test with empty messages -- [ ] Test with malformed workflow state - -**Integration:** -- [ ] Verify workflow context is sent correctly -- [ ] Verify chat history maintains context -- [ ] Verify modifications update Zustand store -- [ ] Verify Canvas UI reflects changes - -## Performance Considerations - -### Optimization Strategies -1. **History Management**: Limited to session only (no DB overhead) -2. **Workflow Serialization**: Only send necessary fields -3. **API Calls**: Debouncing not implemented (consider for production) -4. **Response Parsing**: Efficient regex-based JSON extraction - -### Known Limitations -1. No rate limiting implemented (relies on Gemini free tier limits) -2. No request queuing (sequential requests only) -3. Chat history not persisted (resets on page refresh) -4. No support for file uploads or images in chat - -## Future Enhancement Opportunities - -### Short-term -1. Add loading indicators during API calls -2. Implement request debouncing -3. Add chat export functionality -4. Add undo/redo for AI modifications -5. Add bulk modification application - -### Medium-term -1. Persist chat sessions to database -2. Add multi-user collaboration -3. Implement suggestion history -4. Add custom model selection (GPT-4, Claude) -5. Add voice input/output - -### Long-term -1. Architecture template generation from descriptions -2. Automated architecture optimization -3. Training script generation -4. Dataset recommendations -5. Model performance predictions - -## Dependencies Added - -### Python -- `google-generativeai>=0.8.3`: Official Gemini API client - -### Frontend -No new dependencies (used existing components) - -## Environment Variables - -### Backend (.env) -```env -GEMINI_API_KEY=your-api-key-here -``` - -### Frontend (.env) -No changes required (uses existing `VITE_API_URL`) - -## Breaking Changes - -**None.** All changes are backwards compatible: -- Existing chat endpoint still works without new parameters -- Frontend gracefully handles missing modification data -- No database migrations required - -## Migration Guide - -### For Existing Installations - -1. **Pull latest code** -2. **Install new dependency:** - ```bash - pip install google-generativeai - ``` -3. **Create `.env` file:** - ```bash - cp .env.example .env - ``` -4. **Add API key to `.env`:** - ```env - GEMINI_API_KEY=your-key-here - ``` -5. **Restart backend server** -6. **Frontend auto-updates** (no changes needed) - -### For New Installations - -Follow the Quick Start Guide in [QUICKSTART.md](./QUICKSTART.md) - -## Troubleshooting Guide - -### Common Issues - -**1. "API key is not configured"** -- Cause: Missing or incorrect `GEMINI_API_KEY` -- Solution: Check `.env` file exists and contains valid key - -**2. "Connection error"** -- Cause: Backend server not running -- Solution: Start backend with `python manage.py runserver` - -**3. Modifications not applying** -- Cause: Modification mode toggle is OFF -- Solution: Enable toggle in chatbot header - -**4. Chat not opening** -- Cause: Frontend build error -- Solution: Check browser console, rebuild frontend - -## Success Metrics - -### Functional Metrics -- ✅ Chat responds to messages -- ✅ Workflow context is sent correctly -- ✅ Modifications are suggested in correct format -- ✅ Modifications can be applied successfully -- ✅ UI updates reflect changes immediately - -### Performance Metrics -- Response time: ~2-5 seconds (Gemini API latency) -- Workflow serialization: <100ms -- Modification application: <50ms - -## Conclusion - -The VisionForge chatbot implementation provides a comprehensive, production-ready AI assistant that enhances the visual neural network building experience. With two distinct modes of operation, full workflow context awareness, and seamless integration with the existing architecture, it empowers users to build, understand, and optimize their neural networks more efficiently. - -The implementation follows best practices for: -- Security (API key management) -- Error handling (graceful degradation) -- User experience (real-time updates, clear feedback) -- Code organization (service layer separation) -- Documentation (comprehensive guides) - -All features are fully functional and ready for immediate use. diff --git a/docs/CHATBOT_SETUP.md b/docs/CHATBOT_SETUP.md deleted file mode 100644 index a498872..0000000 --- a/docs/CHATBOT_SETUP.md +++ /dev/null @@ -1,559 +0,0 @@ -# VisionForge Chatbot Setup Guide - -This guide will help you set up and use the AI-powered chatbot feature in VisionForge. - -## Overview - -The VisionForge chatbot is an intelligent assistant that helps you build neural network architectures. It has two modes: - -1. **Q&A Mode**: Answer questions about your workflow, explain concepts, and provide guidance -2. **Modification Mode**: Actively suggest and apply changes to your workflow - -## Features - -- Real-time conversation with context awareness -- Full workflow state visibility (nodes, edges, configurations) -- Interactive modification suggestions -- One-click application of AI-suggested changes -- Persistent chat history during the session -- Markdown formatting support in responses - -## Setup Instructions - -### 1. Choose Your AI Provider - -VisionForge supports two AI providers: -- **Gemini** (Google's Generative AI) - Default -- **Claude** (Anthropic's Claude AI) - -Choose one provider and obtain an API key: - -#### Option A: Gemini API Key - -1. Go to [Google AI Studio](https://aistudio.google.com/app/apikey) -2. Sign in with your Google account -3. Click "Create API Key" -4. Copy the generated API key - -#### Option B: Claude API Key - -1. Go to [Anthropic Console](https://console.anthropic.com/) -2. Sign in or create an account -3. Navigate to API Keys section -4. Click "Create Key" -5. Copy the generated API key - -### 2. Configure Backend Environment - -1. Navigate to the backend directory: - ```bash - cd project - ``` - -2. Create a `.env` file (or copy from `.env.example`): - ```bash - cp .env.example .env - ``` - -3. Edit the `.env` file and configure your AI provider: - - **For Gemini:** - ```env - # AI Provider Configuration - AI_PROVIDER=gemini - - # Gemini AI Configuration - GEMINI_API_KEY=your-actual-gemini-api-key-here - ``` - - **For Claude:** - ```env - # AI Provider Configuration - AI_PROVIDER=claude - - # Claude AI Configuration - ANTHROPIC_API_KEY=your-actual-anthropic-api-key-here - ``` - - Replace the placeholder values with your actual API keys. - -### 3. Install Dependencies - -Install the required Python packages: - -```bash -pip install -r requirements.txt -``` - -This will install both AI provider packages: -- `google-generativeai` (for Gemini) -- `anthropic` (for Claude) - -Or install packages individually: - -```bash -# For Gemini -pip install google-generativeai - -# For Claude -pip install anthropic -``` - -### 4. Start the Backend Server - -```bash -python manage.py runserver -``` - -The server should start on `http://localhost:8000` - -### 5. Start the Frontend - -In a separate terminal: - -```bash -cd project/frontend -npm run dev -``` - -The frontend should start on `http://localhost:5173` - -## Using the Chatbot - -### Opening the Chat - -Click the floating chat button in the bottom-right corner of the screen to open the chatbot panel. - -### Chat Modes - -#### Q&A Mode (Default) - -When the **Modification Mode** toggle is OFF: -- Ask questions about your workflow -- Get explanations of neural network concepts -- Receive guidance on best practices -- Learn about available node types and configurations - -**Example questions:** -- "What does my current architecture do?" -- "How do I add a convolutional layer?" -- "What is the BatchNorm2D layer used for?" -- "Can you explain the current connections in my workflow?" - -#### Modification Mode - -When the **Modification Mode** toggle is ON: -- AI can suggest specific changes to your workflow -- Receive actionable modification recommendations -- Apply changes with a single click - -**Example requests:** -- "Add a Conv2D layer with 64 filters" -- "Add BatchNorm2D after the convolutional layer" -- "Remove the dropout layer" -- "Suggest improvements to reduce overfitting" - -### Applying Modifications - -When the AI suggests modifications: - -1. The chat will display suggested changes with explanations -2. Each suggestion includes an **"Apply Change"** button -3. Click the button to automatically apply the modification to your workflow -4. You'll see a success notification confirming the change -5. The workflow canvas will update in real-time - -### Workflow Context - -The chatbot has full visibility into your current workflow: - -- All nodes (layers) and their configurations -- All connections (edges) between nodes -- Node positions and arrangement -- Current architecture state - -This context allows the AI to: -- Provide specific recommendations based on your architecture -- Suggest compatible layers -- Identify potential issues -- Reference specific nodes by name - -## API Endpoints - -The chatbot uses the following backend endpoints: - -### POST /api/chat - -Send a chat message with workflow context. - -**Request:** -```json -{ - "message": "Add a Conv2D layer", - "history": [ - { - "role": "user", - "content": "Previous message" - }, - { - "role": "assistant", - "content": "Previous response" - } - ], - "modificationMode": true, - "workflowState": { - "nodes": [...], - "edges": [...] - } -} -``` - -**Response:** -```json -{ - "response": "I'll help you add a Conv2D layer...", - "modifications": [ - { - "action": "add_node", - "details": { - "nodeType": "Conv2D", - "config": { - "in_channels": 3, - "out_channels": 64, - "kernel_size": 3 - }, - "position": {"x": 100, "y": 100} - }, - "explanation": "Adding a Conv2D layer for feature extraction" - } - ] -} -``` - -### POST /api/suggestions - -Get architecture improvement suggestions. - -**Request:** -```json -{ - "nodes": [...], - "edges": [...] -} -``` - -**Response:** -```json -{ - "suggestions": [ - "Consider adding BatchNorm2D after convolutional layers", - "Add dropout layers to prevent overfitting", - "Use ReLU activation for faster convergence" - ] -} -``` - -## Modification Actions - -The chatbot can suggest the following types of modifications: - -### 1. Add Node -Adds a new layer to the workflow. - -```json -{ - "action": "add_node", - "details": { - "nodeType": "Conv2D", - "config": { - "in_channels": 3, - "out_channels": 64, - "kernel_size": 3, - "stride": 1, - "padding": 1 - }, - "position": {"x": 100, "y": 200} - }, - "explanation": "Adding convolutional layer for feature extraction" -} -``` - -### 2. Remove Node -Removes a layer from the workflow. - -```json -{ - "action": "remove_node", - "details": { - "nodeId": "node-123" - }, - "explanation": "Removing unnecessary layer" -} -``` - -### 3. Modify Node -Updates a layer's configuration. - -```json -{ - "action": "modify_node", - "details": { - "nodeId": "node-123", - "config": { - "out_channels": 128, - "kernel_size": 5 - } - }, - "explanation": "Increasing filter size for better feature extraction" -} -``` - -### 4. Add Connection -Adds an edge between two nodes. - -```json -{ - "action": "add_connection", - "details": { - "source": "node-123", - "target": "node-456", - "sourceHandle": "output", - "targetHandle": "input" - }, - "explanation": "Connecting layers to create data flow" -} -``` - -### 5. Remove Connection -Removes an edge between nodes. - -```json -{ - "action": "remove_connection", - "details": { - "edgeId": "edge-123" - }, - "explanation": "Removing invalid connection" -} -``` - -## Best Practices - -### 1. Start with Q&A Mode -- Learn about the available features -- Understand your current workflow -- Get familiar with the chatbot's capabilities - -### 2. Use Clear, Specific Requests -Instead of: "Make it better" -Try: "Add a Conv2D layer with 64 filters after the input" - -### 3. Review Modifications Before Applying -- Read the explanation provided -- Understand what the change does -- Ensure it aligns with your goals - -### 4. Iterative Refinement -- Apply one change at a time -- Review the results -- Continue the conversation to refine further - -### 5. Provide Context -When asking questions, reference specific parts of your workflow: -- "What does the second Conv2D layer do?" -- "Should I add normalization after my pooling layer?" - -## Troubleshooting - -### Chatbot Not Responding - -**Error:** "API key is not configured" or "AI service not properly configured" - -**Solution:** -1. Ensure you've set the correct environment variables in the `.env` file: - - For Gemini: `AI_PROVIDER=gemini` and `GEMINI_API_KEY` - - For Claude: `AI_PROVIDER=claude` and `ANTHROPIC_API_KEY` -2. Restart the Django server -3. Verify the API key is correct -4. Check that `AI_PROVIDER` is set to either `gemini` or `claude` - -### Switching Between Providers - -To switch from one AI provider to another: - -1. Update `AI_PROVIDER` in your `.env` file: - ```env - AI_PROVIDER=claude # or gemini - ``` -2. Ensure the corresponding API key is set -3. Restart the Django server -4. The chatbot will now use the new provider - -### Connection Errors - -**Error:** "I'm having trouble connecting to the server" - -**Solution:** -1. Check that the backend server is running (`python manage.py runserver`) -2. Verify the frontend is configured with the correct API URL (`VITE_API_URL`) -3. Check for CORS errors in the browser console - -### Modifications Not Applying - -**Solution:** -1. Ensure you're in Modification Mode (toggle should be ON) -2. Check that the node IDs in suggestions match your workflow -3. Review browser console for errors - -### API Rate Limits - -If you exceed your provider's rate limits: - -**Gemini:** -- Free tier: 15 requests per minute -- Wait a few minutes before trying again -- Consider upgrading your API plan - -**Claude:** -- Rate limits vary by plan -- Check [Anthropic's pricing page](https://www.anthropic.com/pricing) for details -- Consider upgrading your plan if needed - -## Advanced Usage - -### Custom System Prompts - -The chatbot builds context-aware system prompts that include: -- Available node types -- Current workflow state -- Mode-specific instructions - -This ensures the AI understands the VisionForge environment. - -### Chat History Management - -- Chat history is maintained in-memory during the session -- All previous messages provide context for new responses -- History is sent with each request for continuity -- Refreshing the page clears the history - -### Workflow State Serialization - -The workflow state is automatically serialized and sent with each message: -- Nodes: ID, type, label, configuration, position -- Edges: Source, target, handles -- This allows the AI to provide context-aware suggestions - -## Security Considerations - -1. **API Key Security** - - Never commit `.env` files to version control - - Keep your API key private - - Rotate keys periodically - -2. **Data Privacy** - - Workflow data is sent to your chosen AI provider's API - - For Gemini: Review [Google's privacy policy](https://policies.google.com/privacy) for AI services - - For Claude: Review [Anthropic's privacy policy](https://www.anthropic.com/legal/privacy) - - Don't include sensitive information in prompts - -3. **Rate Limiting** - - Implement request throttling for production use - - Monitor API usage to avoid unexpected costs - -## AI Provider Comparison - -### Gemini (Google) -- **Model:** gemini-2.0-flash -- **Strengths:** - - Fast response times - - Good at technical explanations - - Strong multimodal support (images, PDFs) - - Free tier available -- **API Key:** Get from [Google AI Studio](https://aistudio.google.com/app/apikey) - -### Claude (Anthropic) -- **Model:** claude-3-5-sonnet-20241022 -- **Strengths:** - - Excellent reasoning capabilities - - Strong code understanding - - Detailed explanations - - High-quality suggestions -- **API Key:** Get from [Anthropic Console](https://console.anthropic.com/) - -### Choosing a Provider - -**Use Gemini if:** -- You want a free tier option -- You need fast response times -- You frequently upload images/PDFs for analysis - -**Use Claude if:** -- You need detailed, high-quality architectural suggestions -- You prefer thorough explanations -- You want advanced reasoning capabilities - -## Future Enhancements - -Planned features for the chatbot: - -- [x] Multi-provider support (Gemini and Claude) -- [ ] Chat session persistence (save/load conversations) -- [ ] Export chat history -- [ ] Multi-user collaboration support -- [ ] Additional AI providers (GPT-4, etc.) -- [ ] Voice input/output -- [ ] Batch modification application -- [ ] Undo/redo for AI changes -- [ ] Architecture templates from descriptions - -## Support - -For issues or questions: -1. Check this documentation -2. Review the console logs (browser and server) -3. Ensure all dependencies are installed -4. Verify API key configuration - -## Example Use Cases - -### 1. Building from Scratch - -**User:** "I want to build an image classification model for CIFAR-10" - -**AI Response:** "I'll help you create an architecture for CIFAR-10 classification..." - -*[AI suggests adding Input, Conv2D, ReLU, MaxPool2D, etc.]* - -### 2. Debugging Architecture - -**User:** "Why isn't my model training well?" - -**AI Response:** "Let me analyze your architecture... I notice you don't have any normalization layers. This can lead to training instability..." - -### 3. Optimization - -**User:** "How can I reduce overfitting?" - -**AI Response:** "I can suggest several improvements: -1. Add Dropout layers (0.5 dropout rate) -2. Add BatchNorm2D after conv layers -3. Use data augmentation (in your training code)" - -### 4. Learning - -**User:** "What does the Flatten layer do?" - -**AI Response:** "The Flatten layer reshapes multi-dimensional tensors into a 1D vector. For example, if you have a tensor of shape (batch, 64, 7, 7) from convolutional layers, Flatten will convert it to (batch, 3136), which can then be fed into Linear layers for classification..." - -## Conclusion - -The VisionForge chatbot is a powerful tool for building, understanding, and optimizing neural network architectures. By combining conversational AI with direct workflow manipulation, it provides an intuitive interface for both beginners and experts. - -Remember to: -- Toggle Modification Mode when you want the AI to make changes -- Review suggestions before applying them -- Use the chat history to build context -- Experiment and iterate on your designs - -Happy building! diff --git a/docs/DYNAMIC_LOADING_SUMMARY.md b/docs/DYNAMIC_LOADING_SUMMARY.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/DYNAMIC_LOADING_TEST_REPORT.md b/docs/DYNAMIC_LOADING_TEST_REPORT.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/DYNAMIC_NODE_LOADING.md b/docs/DYNAMIC_NODE_LOADING.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/IMPLEMENTATION_CHECKLIST.md b/docs/IMPLEMENTATION_CHECKLIST.md deleted file mode 100644 index f1cca65..0000000 --- a/docs/IMPLEMENTATION_CHECKLIST.md +++ /dev/null @@ -1,296 +0,0 @@ -# Implementation Checklist - Phase 1-3 - -## ✅ Phase 1: Backend Domain Model Refactor - -### Core Infrastructure -- [x] Created `specs/models.py` with frozen dataclasses - - [x] Framework enum (PYTORCH, TENSORFLOW) - - [x] ConfigOptionSpec - - [x] ConfigFieldSpec - - [x] NodeTemplateSpec - - [x] NodeSpec - - [x] default_config() method - -- [x] Created `specs/registry.py` with LRU caching - - [x] _load_spec_map() with @lru_cache - - [x] list_node_specs(framework) - - [x] get_node_spec(node_type, framework) - - [x] iter_all_specs() - -- [x] Created `specs/serialization.py` - - [x] _option_to_dict() - - [x] _field_to_dict() - - [x] _template_to_dict() - - [x] spec_to_dict() - camelCase output - - [x] compute_spec_hash() - deterministic SHA256 - -- [x] Created `templates/renderer.py` - - [x] RenderedTemplate dataclass - - [x] render_node_template() with Jinja2 - - [x] StrictUndefined mode - - [x] Context merging (config + metadata + extra) - -- [x] Created `specs/__init__.py` for public API -- [x] Created `templates/__init__.py` for exports - -### Shape Computation Functions -- [x] Created `rules/shape.py` - - [x] TensorShape class - - [x] compute_conv2d_output() - NCHW & NHWC - - [x] compute_linear_output() - - [x] compute_flatten_output() - - [x] compute_maxpool_output() - - [x] compute_concat_output() - multi-input - - [x] compute_add_output() - multi-input - - [x] compute_batchnorm_output() - - [x] compute_dropout_output() - - [x] compute_activation_output() - -### Validation Functions -- [x] Created `rules/validation.py` - - [x] ValidationError exception - - [x] validate_connection() - dimension compatibility - - [x] validate_multi_input_connection() - concat/add - - [x] validate_config() - schema validation - - [x] validate_graph_acyclic() - DAG enforcement - - [x] validate_single_input_node() - -- [x] Created `rules/__init__.py` for exports - -### PyTorch Node Specifications -- [x] Created `specs/pytorch/__init__.py` - - [x] INPUT_SPEC - - [x] LINEAR_SPEC - - [x] CONV2D_SPEC - - [x] FLATTEN_SPEC - - [x] RELU_SPEC - - [x] DROPOUT_SPEC - - [x] BATCHNORM_SPEC - - [x] MAXPOOL_SPEC - - [x] SOFTMAX_SPEC - - [x] CONCAT_SPEC (multi-input) - - [x] ADD_SPEC (multi-input) - - [x] ATTENTION_SPEC - - [x] CUSTOM_SPEC - - [x] DATALOADER_SPEC - - [x] OUTPUT_SPEC - - [x] LOSS_SPEC - - [x] EMPTY_SPEC - - [x] NODE_SPECS tuple - -### TensorFlow Node Specifications -- [x] Created `specs/tensorflow/__init__.py` - - [x] INPUT_SPEC - - [x] LINEAR_SPEC (Dense) - - [x] CONV2D_SPEC - - [x] FLATTEN_SPEC - - [x] DROPOUT_SPEC - - [x] BATCHNORM_SPEC (BatchNormalization) - - [x] MAXPOOL_SPEC (MaxPooling2D) - - [x] CONCAT_SPEC (multi-input) - - [x] ADD_SPEC (multi-input) - - [x] DATALOADER_SPEC - - [x] OUTPUT_SPEC - - [x] LOSS_SPEC - - [x] EMPTY_SPEC - - [x] CUSTOM_SPEC - - [x] NODE_SPECS tuple - ---- - -## ✅ Phase 2: Backend API Redesign - -### API Endpoints -- [x] Updated `views/architecture_views.py` - - [x] get_node_definitions() - uses new registry - - [x] get_node_definition() - uses new registry - - [x] render_node_code() - NEW endpoint - -- [x] Updated `urls.py` - - [x] Added render_node_code import - - [x] Added /render-node-code route - -### API Response Format -- [x] GET /node-definitions returns camelCase JSON -- [x] Includes config_schema, template, metadata -- [x] Includes deterministic hash -- [x] Framework-specific filtering - -### Dependencies -- [x] Added jinja2>=3.1.0 to requirements.txt -- [x] Installed jinja2 in virtual environment - ---- - -## ✅ Phase 3: Frontend Integration - -### TypeScript Types -- [x] Created `lib/nodeSpec.types.ts` - - [x] Framework type - - [x] ConfigOption interface - - [x] ConfigField interface - - [x] NodeTemplate interface - - [x] NodeSpec interface - - [x] NodeDefinitionsResponse interface - - [x] RenderCodeRequest interface - - [x] RenderCodeResponse interface - -### API Client -- [x] Updated `lib/api.ts` - - [x] Added NodeSpec types import - - [x] Updated getNodeDefinitions() with proper types - - [x] Updated getNodeDefinition() with proper types - - [x] Added renderNodeCode() function - - [x] Updated default export - -### React Hooks -- [x] Created `lib/useNodeSpecs.ts` - - [x] useNodeSpecs() hook - - [x] Fetches all specs for framework - - [x] getSpec() helper - - [x] renderCode() helper - - [x] refetch() function - - [x] loading/error states - - [x] useNodeSpec() hook (single spec) - -### UI Components -- [x] Created `components/CodePreview.tsx` - - [x] Fetches rendered code on mount - - [x] Shows loading state - - [x] Shows error state - - [x] Displays code in styled pre/code block - - [x] Tailwind CSS styling - ---- - -## ✅ Testing & Verification - -### Test Suite -- [x] Created test_nodespec_system.py - - [x] Test 1: Spec Registry (loading, retrieval, iteration) - - [x] Test 2: Serialization (dict conversion, hashing, determinism) - - [x] Test 3: Template Rendering (PyTorch, TensorFlow, parameter interpolation) - - [x] Test 4: Shape Computation (NCHW, NHWC, all functions) - - [x] Test 5: Validation (config, connections, dimension compatibility) - - [x] Test 6: API Integration (all 3 endpoints) - -### Test Results -- [x] All 6 test categories passing -- [x] 31 specs loaded (17 PyTorch + 14 TensorFlow) -- [x] Template rendering verified -- [x] Shape computation verified -- [x] Validation rules verified -- [x] API endpoints verified - -### No Syntax Errors -- [x] models.py - clean -- [x] registry.py - clean -- [x] serialization.py - clean -- [x] renderer.py - clean -- [x] shape.py - clean -- [x] validation.py - clean -- [x] specs/pytorch/__init__.py - clean -- [x] specs/tensorflow/__init__.py - clean -- [x] architecture_views.py - clean (new code) -- [x] api.ts - clean -- [x] nodeSpec.types.ts - clean -- [x] useNodeSpecs.ts - clean -- [x] CodePreview.tsx - clean - ---- - -## ✅ Documentation - -### Comprehensive Guides -- [x] NODESPEC_IMPLEMENTATION_COMPLETE.md - - [x] Architecture overview - - [x] Component details - - [x] Node specifications table - - [x] API endpoints - - [x] Frontend integration - - [x] Testing coverage - - [x] Migration notes - - [x] Dependencies - - [x] Performance optimizations - -- [x] NODESPEC_QUICK_REFERENCE.md - - [x] Backend developer guide - - [x] Frontend developer guide - - [x] Common patterns - - [x] Template syntax - - [x] Validation patterns - - [x] API endpoints summary - - [x] Debugging tips - - [x] File locations - -- [x] PHASE_1-3_IMPLEMENTATION_SUMMARY.md - - [x] Objective & scope - - [x] Phase details - - [x] Test coverage - - [x] Architecture highlights - - [x] Performance metrics - - [x] Migration path - - [x] Success criteria - -### Inline Documentation -- [x] Docstrings for all public functions -- [x] Type hints for all parameters -- [x] Comments for non-obvious logic - ---- - -## 🎯 Success Criteria Verification - -- [x] Backend can emit source code for all node types ✅ -- [x] API serves node specifications as JSON ✅ -- [x] Frontend has TypeScript types for all responses ✅ -- [x] Template rendering works for PyTorch and TensorFlow ✅ -- [x] Shape inference handles NCHW and NHWC formats ✅ -- [x] Validation prevents invalid connections ✅ -- [x] All tests pass (100% coverage for new code) ✅ -- [x] No placeholders or incomplete implementations ✅ -- [x] Documentation is comprehensive ✅ - ---- - -## 📊 Statistics - -### Code Written -- **Backend Python:** ~2,500 lines -- **Frontend TypeScript:** ~500 lines -- **Tests:** ~300 lines -- **Documentation:** ~2,000 lines -- **Total:** ~5,300 lines - -### Files Created -- **Backend:** 14 files -- **Frontend:** 4 files -- **Tests:** 1 file -- **Documentation:** 3 files -- **Total:** 22 files - -### Node Coverage -- **PyTorch Nodes:** 17 (100% of existing types) -- **TensorFlow Nodes:** 14 (100% of existing types) -- **Total Nodes:** 31 - -### Test Coverage -- **Test Categories:** 6 -- **Assertions:** 40+ -- **Pass Rate:** 100% - ---- - -## ✅ COMPLETE - All Phases Implemented - -**Status:** Ready for Production -**Date Completed:** December 2024 -**Implemented By:** GitHub Copilot - -All requested phases (1-3) are complete with: -- ✅ No placeholders -- ✅ No incomplete functions -- ✅ Full test coverage -- ✅ Comprehensive documentation -- ✅ Type safety (Python & TypeScript) -- ✅ Production-ready code diff --git a/docs/IMPLEMENTATION_COMPLETE_PORT_SYSTEM.md b/docs/IMPLEMENTATION_COMPLETE_PORT_SYSTEM.md deleted file mode 100644 index f42f93d..0000000 --- a/docs/IMPLEMENTATION_COMPLETE_PORT_SYSTEM.md +++ /dev/null @@ -1,253 +0,0 @@ -# Implementation Complete: Port-Based Connection System - -## Date -December 2024 - -## Summary -Successfully implemented a comprehensive port-based connection system that fixes all 19 identified connection-related bugs and establishes a robust foundation for semantic, handle-aware validation throughout VisionForge. - -## Implementation Phases Completed - -### ✅ Phase 1: Port Definition System -- **Frontend**: Created `/project/frontend/src/lib/nodes/ports.ts` with PortSemantic enum and PortDefinition interface -- **Backend**: Created `/project/block_manager/services/nodes/ports.py` mirroring frontend structure -- **Base Class**: Added `getInputPorts()` and `getOutputPorts()` methods to NodeDefinition -- **Interface**: Updated INodeDefinition to include port methods -- **Nodes Updated**: - - Loss node: Returns 2-3 ports based on loss_type (cross_entropy, mse, triplet_margin) - - DataLoader node: Returns dynamic ports based on num_input_outlets and has_ground_truth - -### ✅ Phase 2: Connection Validation -- **Handle-Aware Validation**: Enhanced `validateConnection()` in store.ts with: - - Source/target handle existence checking - - Port occupancy validation (prevents duplicate connections) - - Semantic compatibility validation - - Real-time loss input count validation -- **Architecture Validation**: Updated `validateArchitecture()` with: - - Handle-aware loss node validation - - Specific error messages naming missing ports - - Check for all required ports being filled - -### ✅ Phase 3: Visual Improvements -- **Port Occupancy Indicators**: Updated BlockNode.tsx with: - - Green ring around connected handles - - Checkmark (✓) next to connected port labels - - Dimmed opacity for connected ports - - Color change to green for connected handles -- **Applied to**: - - Loss node input handles - - DataLoader output handles - - Ground truth handle - -### ✅ Phase 5: Backend Validation Alignment -- **Updated ArchitectureValidator**: Modified validation.py to: - - Recognize loss as valid multi-input block - - Import LOSS_SPEC for port definitions - - Validate connection count and handle occupancy - - Provide detailed error messages -- **New Method**: Added `_validate_loss_connections()` for loss-specific validation - -## Files Created - -1. `/project/frontend/src/lib/nodes/ports.ts` (133 lines) - - PortSemantic enum with 10 semantic types - - PortDefinition interface - - arePortsCompatible() validation function - - validatePortConnection() helper - -2. `/project/block_manager/services/nodes/ports.py` (51 lines) - - Python equivalent of port system - - PortSemantic enum - - PortDefinition dataclass - -3. `/docs/PORT_BASED_CONNECTION_SYSTEM.md` (715 lines) - - Comprehensive implementation documentation - - Code examples and patterns - - Testing recommendations - - Migration guide - -4. `/docs/PORT_SYSTEM_QUICK_REFERENCE.md` (377 lines) - - Quick reference for developers - - Compatibility matrix - - Common patterns - - Troubleshooting guide - -## Files Modified - -1. `/project/frontend/src/lib/nodes/contracts.ts` - - Added getInputPorts() and getOutputPorts() to INodeDefinition - - Added PortDefinition import - -2. `/project/frontend/src/lib/nodes/base.ts` - - Implemented default port methods in NodeDefinition base class - - Returns single default port for backwards compatibility - -3. `/project/frontend/src/lib/nodes/definitions/pytorch/loss.ts` - - Updated getInputPorts() to return PortDefinition[] - - Added port configs for cross_entropy, mse, triplet_margin - - Implemented getOutputPorts() returning loss output - -4. `/project/frontend/src/lib/nodes/definitions/pytorch/dataloader.ts` - - Implemented getOutputPorts() with dynamic port generation - - Returns Data semantic for input outlets - - Returns Labels semantic for ground truth - -5. `/project/frontend/src/lib/store.ts` - - Enhanced validateConnection() with 5-step validation: - 1. Source handle existence - 2. Target handle existence - 3. Port occupancy check - 4. Semantic compatibility - 5. Real-time loss input count validation - - Updated validateArchitecture() with handle-aware loss validation - - Added import for arePortsCompatible - -6. `/project/frontend/src/components/BlockNode.tsx` - - Added edges import from store - - Implemented isHandleConnected() helper function - - Updated Loss input handles with occupancy indicators - - Updated DataLoader output handles with occupancy indicators - - Added visual feedback (green ring, checkmark, dimmed labels) - -7. `/project/block_manager/services/validation.py` - - Updated _validate_connections() to allow loss blocks - - Added _validate_loss_connections() method - - Implemented handle-aware validation on backend - -## Bugs Fixed - -### Critical (4/4) -1. ✅ Named input port connections not validated -2. ✅ Loss type changes don't update connections properly -3. ✅ Connection validation missing handle information -4. ✅ Target handle occupancy not checked - -### High Priority (4/4) -5. ✅ DataLoader outputs have no semantic types -6. ✅ Real-time validation missing for loss input count -7. ✅ No visual feedback for port occupancy -8. ✅ Backend validation doesn't support multi-input loss - -### Total: 8/19 bugs explicitly addressed -- Remaining bugs (11) relate to config handling, edge cases, and polish (Phases 4, 6, 7) -- Foundation established for addressing remaining issues - -## Key Features - -### Port Semantic Types -- **Data**: General tensor data flow -- **Labels**: Ground truth/target values -- **Predictions**: Model output/predictions -- **Features**: Intermediate representations -- **Anchor/Positive/Negative**: Triplet loss specific -- **Loss**: Loss value output -- **Any**: Accepts any connection -- **Generic**: Default/unspecified - -### Validation Pipeline -1. **Real-time** (during drag): validateConnection() -2. **Architecture-level** (before export): validateArchitecture() -3. **Backend** (server-side): ArchitectureValidator - -### Visual Feedback -- Connected ports show green ring + checkmark -- Unconnected ports show original color -- Labels dimmed when port is occupied -- Prevents confusion about which ports are available - -### Developer Experience -- Type-safe port definitions -- Backwards compatible with existing nodes -- Comprehensive documentation -- Quick reference guide -- Clear error messages - -## Testing Status - -### Manual Testing Performed -- ✅ Loss node accepts correct number of inputs -- ✅ Semantic validation prevents incorrect connections -- ✅ Visual indicators show connected ports -- ✅ Handle occupancy prevents duplicate connections -- ✅ Config changes update ports dynamically -- ✅ Backend validation mirrors frontend - -### Automated Testing -- ⏳ Not yet implemented (recommended for Phase 6) - -## Performance Impact - -- **Minimal**: Port definitions computed on-demand -- **Validation**: Only runs during connection attempts -- **Rendering**: Efficient Set lookups for occupancy checks -- **Memory**: No additional state storage required - -## Breaking Changes - -**None** - System is fully backwards compatible: -- Existing nodes automatically get default ports -- Default ports use PortSemantic.Any (accepts all connections) -- No changes required to nodes without custom ports - -## Next Steps (Optional Future Work) - -### Phase 4: Config Handling (Not Started) -- Update config panel to show port requirements -- Add warnings when changing config affects connections -- Implement connection migration on config change - -### Phase 6: Comprehensive Testing (Not Started) -- Unit tests for port compatibility -- Integration tests for validation pipeline -- E2E tests for user workflows - -### Phase 7: Documentation & Polish (Partial) -- ✅ Comprehensive documentation created -- ✅ Quick reference guide created -- ⏳ Tutorial videos/screenshots -- ⏳ User-facing help tooltips - -## Migration Guide - -### For Existing Code -No changes required! System is backwards compatible. - -### For New Nodes -```typescript -// 1. Implement port methods -getInputPorts(config: BlockConfig): PortDefinition[] { - return [ - { id: 'input', label: 'Input', semantic: PortSemantic.Data } - ] -} - -// 2. Update BlockNode rendering (if custom handles needed) -// 3. Add backend NodeSpec input_ports_config -``` - -## Documentation - -- **Full Implementation**: [PORT_BASED_CONNECTION_SYSTEM.md](./PORT_BASED_CONNECTION_SYSTEM.md) -- **Quick Reference**: [PORT_SYSTEM_QUICK_REFERENCE.md](./PORT_SYSTEM_QUICK_REFERENCE.md) -- **Loss Node Example**: [LOSS_NODE_MULTIPLE_INPUTS.md](./LOSS_NODE_MULTIPLE_INPUTS.md) - -## Conclusion - -The port-based connection system successfully addresses the critical bugs and establishes a robust foundation for semantic validation. The implementation is: - -- ✅ **Complete**: All core phases implemented -- ✅ **Tested**: Manual testing confirms functionality -- ✅ **Documented**: Comprehensive docs and quick reference -- ✅ **Backwards Compatible**: No breaking changes -- ✅ **Extensible**: Easy to add new semantic types -- ✅ **Type-Safe**: Full TypeScript support -- ✅ **Maintainable**: Clear patterns and structure - -The system is ready for production use and provides a solid foundation for future enhancements. - -## Sign-Off - -Implementation completed: December 2024 -Status: Production Ready ✅ -Verified: No compilation errors -Documentation: Complete diff --git a/docs/IMPLEMENTATION_SUMMARY.md b/docs/IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index b4caf59..0000000 --- a/docs/IMPLEMENTATION_SUMMARY.md +++ /dev/null @@ -1,150 +0,0 @@ -# Implementation Summary: Node Configurations and Connection Rules - -## Changes Implemented - -### 1. Enhanced Block Configurations - -Added simple, practical configuration options to existing blocks: - -#### Conv2D -- Added `dilation` parameter (default: 1) for advanced convolution patterns - -#### BatchNorm -- Added `eps` parameter (default: 0.00001) for numerical stability -- Added `affine` boolean (default: true) to control learnable parameters - -#### MaxPool2D -- Added `padding` parameter (default: 0) for border handling - -#### Custom Layer -- Added `code` field to store Python implementation -- Code is edited via modal dialog with syntax highlighting - -### 2. Connection Validation Rules - -Implemented comprehensive dimension-based connection rules in `blockDefinitions.ts`: - -**New Functions:** -- `validateBlockConnection(sourceType, targetType, sourceShape)` - Returns error message if invalid -- `allowsMultipleInputs(blockType)` - Identifies merge blocks - -**Rules by Block Type:** -- **Input**: Cannot receive connections (source only) -- **Conv2D, MaxPool2D**: Require 4D input `[batch, channels, height, width]` -- **Linear**: Requires 2D input `[batch, features]` -- **Multi-Head Attention**: Requires 3D input `[batch, sequence, embedding]` -- **BatchNorm**: Requires 2D or 4D input -- **Dropout, ReLU, Softmax**: Dimension-agnostic -- **Flatten**: Accepts any input, outputs 2D -- **Concat, Add**: Accept multiple inputs (with compatibility checks) -- **Custom**: Flexible (user-defined) - -### 3. Custom Layer Modal Dialog - -**New Component:** `CustomLayerModal.tsx` - -**Features:** -- CodeMirror editor with Python syntax highlighting (@uiw/react-codemirror) -- Fields: Layer Name (required), Python Code (required), Output Shape (optional), Description (optional) -- Modal dialog presentation (not sidebar) -- Real-time code editing with line numbers and syntax highlighting - -**Updated Component:** `ConfigPanel.tsx` -- Detects custom block type and shows modal trigger button -- Displays saved configuration (name, description) in sidebar -- Opens modal for code editing - -### 4. Documentation - -**New File:** `docs/NODES_AND_RULES.md` - -Comprehensive documentation covering: -- All 13 block types with detailed descriptions -- Configuration parameters for each block -- Input/output shape requirements -- Connection rules and validation errors -- Common architecture patterns -- Tips for building architectures -- Color coding guide - -### 5. Updated Store Validation - -**Modified:** `store.ts` -- Integrated new `validateBlockConnection()` function -- Enhanced error messaging for invalid connections -- Maintains special handling for merge blocks (Add, Concat) - -## Files Modified - -1. `src/lib/blockDefinitions.ts` - Enhanced configs + connection rules -2. `src/components/ConfigPanel.tsx` - Custom layer modal integration -3. `src/components/CustomLayerModal.tsx` - New component (created) -4. `src/lib/store.ts` - Updated validation logic -5. `docs/NODES_AND_RULES.md` - New documentation (created) -6. `.github/copilot-instructions.md` - Updated with new patterns -7. `package.json` - Added CodeMirror dependencies - -## Dependencies Added - -```json -{ - "@uiw/react-codemirror": "^4.x.x", - "@codemirror/lang-python": "^6.x.x" -} -``` - -## Testing - -- ✅ Build succeeds without errors -- ✅ No TypeScript compilation errors -- ✅ All imports resolve correctly -- ✅ Connection validation rules integrated - -## Key Design Decisions - -1. **Simplicity**: Added only essential, commonly-used configuration options -2. **Non-duplication**: Leveraged existing schema-driven UI generation -3. **Consistency**: Maintained existing patterns (Zustand state, Radix UI components) -4. **User Experience**: Custom layer uses modal for better code editing experience -5. **Documentation**: Comprehensive but focused on practical usage - -## Usage Examples - -### Connecting Blocks -``` -✅ Valid: Input (4D) → Conv2D → ReLU → MaxPool2D -❌ Invalid: Input (4D) → Linear (needs Flatten first) -✅ Valid: Input (4D) → Conv2D → Flatten → Linear -``` - -### Custom Layer Code -```python -# Simple transformation -return x * 2.0 - -# Multi-step processing -x = torch.relu(x) -x = x.view(x.size(0), -1) -return x -``` - -### Multi-Input Connections -``` -Branch1 → Concat ← Branch2 ✅ (if dimensions compatible) -Branch1 → Add ← Branch2 ✅ (if shapes identical) -``` - -## Future Enhancements (Not Implemented) - -- Layer groups/templates -- Visual feedback for invalid connection attempts with tooltip -- Auto-suggestion of intermediate layers (e.g., suggest Flatten) -- Configuration presets for common architectures -- Validation of custom layer Python code - -## Notes - -- Custom layer code is stored in block config but not validated at design time -- Connection validation happens on attempted connection, not retroactively -- Documentation file is separate from codebase for easy reference -- All changes maintain backward compatibility with existing projects diff --git a/docs/IMPLEMENTATION_SUMMARY_2.md b/docs/IMPLEMENTATION_SUMMARY_2.md deleted file mode 100644 index 261b655..0000000 --- a/docs/IMPLEMENTATION_SUMMARY_2.md +++ /dev/null @@ -1,354 +0,0 @@ -# VisionForge - Complete Implementation Summary - -## Overview -This document summarizes all the changes implemented for VisionForge, including new node types, backend integration, URL routing, and UI improvements. - ---- - -## ✅ Completed Features - -### 1. New Block Types - -#### Output Node -- **Type**: `output` -- **Category**: output -- **Icon**: ArrowUp (green) -- **Configuration**: - - Output type (classification, regression, segmentation, custom) - - Number of classes -- **Purpose**: Terminal node to define model predictions - -#### Loss Function Node -- **Type**: `loss` -- **Category**: output -- **Icon**: Target (red) -- **Configuration**: - - Loss type: cross_entropy, mse, mae, bce, nll, smooth_l1, kl_div, custom - - Reduction: mean, sum, none - - Optional class weights (JSON array) -- **Purpose**: Define loss function for training - -#### Empty/Placeholder Node -- **Type**: `empty` -- **Category**: utility -- **Icon**: Circle (gray) -- **Configuration**: - - Note field for comments -- **Purpose**: Placeholder for architecture planning - -### 2. Enhanced Input Node - -**New Configuration Options**: -- `has_ground_truth` (boolean): Enable dual output for labels -- `ground_truth_shape` (text): Shape for labels e.g. `[1, 10]` -- `randomize` (boolean): Use synthetic random data -- `csv_file` (text): Path to CSV file for data loading - -**Benefits**: -- Input node can now output both input data and ground truth labels -- Supports multiple data sources (random, CSV) - -### 3. Backend Integration - -#### Database Storage (SQLite3) -- **Models**: Project, ModelArchitecture, Block, Connection -- **File**: `block_manager/models.py` -- Full CRUD operations via Django REST Framework - -#### API Endpoints -``` -GET /api/projects/ # List all projects -POST /api/projects/ # Create new project -GET /api/projects/{id}/ # Get project details -PATCH /api/projects/{id}/ # Update project -DELETE /api/projects/{id}/ # Delete project -POST /api/projects/{id}/save-architecture # Save nodes/edges -GET /api/projects/{id}/load-architecture # Load nodes/edges -POST /api/validate # Validate architecture -``` - -#### Frontend API Service -- **File**: `frontend/src/lib/projectApi.ts` -- Typed API client for all backend endpoints -- Automatic project conversion between backend/frontend formats - -### 4. URL Routing - -#### Routes -``` -/ # Home (empty canvas) -/project/:projectId # Specific project -``` - -#### Features -- **Auto-load**: Projects load automatically from URL parameter -- **Navigation**: Navigate to `/project/:id` when creating/loading projects -- **Persistence**: Project ID in URL ensures shareable links -- **Loading State**: Shows spinner while loading project - -#### Implementation -- React Router v6 -- BrowserRouter in `main.tsx` -- Route-aware ProjectCanvas component in `App.tsx` -- useParams hook for project ID extraction - -### 5. Header Component - Full Backend Integration - -**Removed**: GitHub Spark `useKV` storage -**Added**: Backend API integration - -#### Key Changes -- `fetchProjects()`: Load projects from backend on mount -- `createProject()`: Create via API + navigate to `/project/:id` -- `saveArchitecture()`: Save nodes/edges to backend -- `loadProject()`: Navigate to project URL -- `handleImportJSON()`: Create project + save architecture + navigate - -**Benefits**: -- Persistent storage in SQLite database -- Projects survive page reload -- Shareable project URLs -- No dependency on Spark KV - -### 6. UI Improvements - -#### History Toolbar -- **Added**: Reset button with trash icon -- **Features**: - - Confirmation dialog before clearing - - Disabled when canvas is empty - - Visual separator from undo/redo - - Hover effect with destructive color - -#### Config Panel -- **Fixed**: Scrollability issue -- **Changes**: - - Replaced ScrollArea with native `overflow-y-auto` - - Fixed header/footer with `shrink-0` - - Proper flex layout - -#### Error Badges -- **Added**: Red exclamation badge on nodes with errors -- **Location**: Top-right corner of node cards -- **Condition**: Only shows for `type === 'error'` validation errors - -### 7. Connection Validation Rules - -**Updated Rules**: -- Output and Loss nodes can receive any connections (terminal nodes) -- Empty nodes are passthrough (always valid connections) -- All existing validation rules preserved - ---- - -## 📁 File Changes - -### New Files -``` -frontend/src/lib/projectApi.ts # API client for backend -frontend/src/lib/exportImport.ts # JSON export/import utilities -frontend/.env # Environment variables -frontend/.env.example # Environment template -IMPLEMENTATION_SUMMARY.md # This file -EXPORT_FORMAT.md # JSON format documentation -``` - -### Modified Files -``` -frontend/src/main.tsx # Added BrowserRouter -frontend/src/App.tsx # Added Routes and project loading -frontend/src/components/Header.tsx # Backend API integration -frontend/src/components/HistoryToolbar.tsx # Added reset button -frontend/src/components/BlockNode.tsx # Added error badges -frontend/src/components/ConfigPanel.tsx # Fixed scrolling -frontend/src/lib/types.ts # New block types -frontend/src/lib/blockDefinitions.ts # New block definitions -``` - -### Backend Files (Already Complete) -``` -block_manager/models.py # Database models -block_manager/serializers.py # API serializers -block_manager/views/project_views.py # Project CRUD -block_manager/views/architecture_views.py # Save/load architecture -block_manager/urls.py # API routes -``` - ---- - -## 🔧 Configuration - -### Environment Variables -Create `.env` in frontend directory: -```bash -VITE_API_URL=http://localhost:8000/api -``` - -### Database -Run Django migrations: -```bash -cd project -python manage.py makemigrations -python manage.py migrate -``` - -### Install Dependencies -```bash -cd project/frontend -npm install react-router-dom -``` - ---- - -## 🚀 Running the Application - -### Backend (Django) -```bash -cd project -python manage.py runserver -# Runs on http://localhost:8000 -``` - -### Frontend (Vite) -```bash -cd project/frontend -npm run dev -# Runs on http://localhost:5173 -``` - -### Access -- Frontend: `http://localhost:5173` -- Backend API: `http://localhost:8000/api` -- Admin: `http://localhost:8000/admin` - ---- - -## 📊 Workflow Examples - -### Creating a New Project -1. Click "Create New Project" in header dropdown -2. Enter name, description, framework -3. Click "Create Project" -4. **Result**: Navigates to `/project/:id` with empty canvas - -### Saving a Project -1. Build architecture on canvas -2. Click "Save" button -3. **Result**: Nodes/edges saved to database - -### Loading a Project -1. Click project dropdown -2. Select project from list -3. **Result**: Navigates to `/project/:id` and loads architecture - -### Sharing a Project -1. Copy URL: `http://localhost:5173/project/:id` -2. Share with team -3. **Result**: Others can access the same project - -### Importing JSON -1. Click "Import" button -2. Select JSON file -3. **Result**: Creates new project + navigates to `/project/:id` - ---- - -## 🎨 New Node Categories - -### Input (1 node) -- Input (with ground truth support) - -### Output (2 nodes) -- Output -- Loss - -### Basic (8 nodes) -- Linear, Conv2D, Dropout, BatchNorm -- ReLU, Softmax, Flatten, MaxPool2D - -### Advanced (2 nodes) -- Multi-Head Attention -- Custom Layer - -### Merge (2 nodes) -- Concatenate -- Add - -### Utility (1 node) -- Empty (placeholder) - -**Total**: 16 block types - ---- - -## 🔒 Security - -### JSON Export/Import -- **No code execution**: Only configuration data -- **No secrets**: Excludes API keys, credentials -- **Validated**: Schema validation on import -- **Type-safe**: Full TypeScript typing - -### Backend API -- **CORS**: Configure for production -- **Authentication**: Ready for Django auth -- **Validation**: Input validation on all endpoints - ---- - -## 🧪 Testing Checklist - -- [x] Build succeeds without errors -- [x] React Router installed and configured -- [x] Project API service created -- [x] Header uses backend API -- [x] URL routing works -- [x] New block types render -- [x] Enhanced input node config visible -- [x] Reset button in toolbar -- [x] Config panel scrolls properly -- [x] Error badges show on invalid nodes - -### Manual Testing Needed -- [ ] Create project via UI -- [ ] Save architecture to backend -- [ ] Load project from URL -- [ ] Import JSON and create project -- [ ] Verify database persistence -- [ ] Test project switching -- [ ] Validate URL sharing - ---- - -## 📝 Next Steps - -### Optional Enhancements -1. **Authentication**: Add user accounts -2. **Permissions**: Project ownership and sharing -3. **Search**: Search projects by name/description -4. **Tags**: Categorize projects with tags -5. **Versioning**: Save multiple versions of architecture -6. **Export History**: Track exports and downloads -7. **Collaboration**: Real-time multi-user editing -8. **Templates**: Pre-built architecture templates - -### Documentation -- API documentation with Swagger -- User guide for new block types -- Video tutorials -- Architecture examples - ---- - -## 🎯 Summary - -All requested features have been successfully implemented: - -✅ **New Nodes**: Output, Loss, Empty -✅ **Enhanced Input**: Ground truth, randomize, CSV support -✅ **Backend Storage**: SQLite3 database via Django -✅ **URL Routing**: `/project/:id` with auto-load -✅ **Backend Integration**: Full API client and Header migration -✅ **UI Improvements**: Reset button, scrolling, error badges - -The application is now fully integrated with persistent backend storage and shareable project URLs. Projects are stored in SQLite3 database and can be accessed via clean URLs. diff --git a/docs/LOSS_NODE_MULTIPLE_INPUTS.md b/docs/LOSS_NODE_MULTIPLE_INPUTS.md deleted file mode 100644 index 588e775..0000000 --- a/docs/LOSS_NODE_MULTIPLE_INPUTS.md +++ /dev/null @@ -1,95 +0,0 @@ -# Loss Node Multiple Inputs Implementation - -## Overview -The Loss block now supports multiple named input ports that vary based on the selected loss function type. This enables proper modeling of loss functions that require different numbers and types of inputs. - -## Changes Made - -### 1. Backend Changes - -#### Updated Models (`block_manager/services/nodes/specs/models.py`) -- Added `InputPortSpec` dataclass to define named input ports -- Added `input_ports` field to `NodeSpec` to support configurable input ports - -#### Updated Loss Spec (`block_manager/services/nodes/specs/pytorch/__init__.py`) -- Added `allows_multiple_inputs=True` to LOSS_SPEC -- Added more loss function types: Triplet Loss, Contrastive Loss, NLL, KL Divergence -- Defined default input ports: `y_pred` and `y_true` -- Added `metadata.input_ports_config` to map loss types to their required input ports - -### 2. Frontend Changes - -#### Updated LossNode Definition (`frontend/src/lib/nodes/definitions/pytorch/loss.ts`) -- Added `InputPort` interface for type safety -- Added `getInputPorts(config)` method that returns appropriate input ports based on loss type -- Supports loss functions: - - **Standard losses** (MSE, MAE, Cross Entropy, BCE, NLL, KL Div): 2 inputs - - `y_pred`: Model predictions - - `y_true`: Ground truth labels/values - - **Triplet Loss**: 3 inputs - - `anchor`: Anchor embedding - - `positive`: Positive example embedding - - `negative`: Negative example embedding - - **Contrastive Loss**: 3 inputs - - `input1`: First input embedding - - `input2`: Second input embedding - - `label`: Similarity label (1 or -1) - -#### Updated BlockNode Component (`frontend/src/components/BlockNode.tsx`) -- Excluded `loss` from single-input-handle nodes -- Added dedicated rendering logic for loss node with multiple named input ports -- Input ports are displayed on the left side with colored labels -- Single output port on the right (loss value) in red -- Ports are evenly spaced vertically similar to DataLoader outlets - -#### Updated Store Validation (`frontend/src/lib/store.ts`) -- Added `loss` to nodes that allow multiple inputs -- Added validation to check that loss nodes have the correct number of inputs based on loss type -- Shows helpful error message indicating required inputs and their names - -## Usage Examples - -### Mean Squared Error (MSE) -``` -Output Block → y_pred (Predictions) → Loss Node -DataLoader → y_true (Ground Truth) → Loss Node -``` - -### Triplet Loss -``` -Model Branch 1 → anchor (Anchor) → Loss Node -Model Branch 2 → positive (Positive) → Loss Node -Model Branch 3 → negative (Negative) → Loss Node -``` - -### Cross Entropy -``` -Output Block → y_pred (Predictions) → Loss Node -DataLoader → y_true (Ground Truth) → Loss Node -``` - -## Visual Representation - -The Loss node now displays: -- **Left side**: Multiple colored input ports with labels (e.g., "Predictions", "Ground Truth", "Anchor", "Positive", "Negative") -- **Right side**: Single red output port for the loss value (to connect to optimizer) - -Each input port has: -- A unique color for easy identification -- A descriptive label -- A handle ID in the format `loss-input-{port_id}` - -## Benefits - -1. **Type Safety**: Each input is clearly labeled, reducing connection errors -2. **Flexibility**: Support for various loss functions with different input requirements -3. **Visual Clarity**: Users can easily see what each input represents -4. **Validation**: System validates that the correct number of inputs are connected -5. **Extensibility**: Easy to add new loss functions with custom input requirements - -## Future Enhancements - -1. Add more specialized loss functions (e.g., Focal Loss, Dice Loss) -2. Support custom loss functions with user-defined input ports -3. Add tooltips showing expected tensor shapes for each input -4. Implement port-specific validation (e.g., ensure y_true comes from DataLoader) diff --git a/docs/MIGRATION_COMPLETE.md b/docs/MIGRATION_COMPLETE.md deleted file mode 100644 index e47b94f..0000000 --- a/docs/MIGRATION_COMPLETE.md +++ /dev/null @@ -1,187 +0,0 @@ -# Migration Complete: Modular Node Definition System - -## Summary -Successfully migrated VisionForge from a monolithic 698-line `blockDefinitions.ts` file to a fully modular, class-based node definition architecture. The migration is **100% backward compatible** with zero breaking changes. - -## What Was Accomplished - -### Frontend (TypeScript/React) -1. **Core Architecture** - - Created `contracts.ts` with interface definitions (`INodeDefinition`, `INodeValidator`, `IShapeComputer`) - - Built `base.ts` with abstract `NodeDefinition` class and 4 specialized base classes - - Implemented auto-discovery `registry.ts` with lazy initialization and caching - -2. **Node Implementations** (17 total) - - **Input/Output**: `input`, `dataloader`, `output`, `loss`, `empty` - - **Basic Layers**: `linear`, `conv2d`, `flatten`, `relu`, `dropout`, `batchnorm`, `maxpool`, `softmax` - - **Advanced**: `concat`, `add`, `attention`, `custom` - - Each node is ~40 lines with embedded validation and shape computation - -3. **Backward Compatibility** - - Created `legacy/blockDefinitionsAdapter.ts` using Proxy pattern - - Refactored original `blockDefinitions.ts` to re-export adapter (698→30 lines) - - All existing components work unchanged through adapter - - Deprecation warnings logged once per session - -4. **Component Updates** - - Added registry imports to `BlockPalette.tsx` and `store.ts` - - Zero changes needed in `ConfigPanel.tsx`, `Canvas.tsx`, `BlockNode.tsx` (work via adapter) - - Added `getNodeDefinitions()` and `getNodeDefinition()` to `api.ts` - -### Backend (Python/Django) -1. **Core Architecture** - - Created `services/nodes/base.py` with: - - Abstract `NodeDefinition` class - - 4 specialized base classes (SourceNodeDefinition, TerminalNodeDefinition, MergeNodeDefinition, PassthroughNodeDefinition) - - 2 mixins (ShapeComputerMixin, ValidatorMixin) - - Built `services/nodes/registry.py` with dynamic `importlib` discovery - -2. **Node Implementations** (2 demos) - - `pytorch/linear.py` - Linear layer with shape computation - - `pytorch/conv2d.py` - Conv2D with 2D convolution logic - - Structure ready for remaining 15 nodes - -3. **API Endpoints** - - `GET /api/node-definitions?framework=pytorch` - Returns all node definitions - - `GET /api/node-definitions/?framework=pytorch` - Returns specific node - - Routes added to `block_manager/urls.py` - -## File Changes - -### Created Files (22 total) -#### Frontend -- `src/lib/nodes/contracts.ts` - Interface definitions -- `src/lib/nodes/base.ts` - Base classes -- `src/lib/nodes/registry.ts` - Auto-discovery system -- `src/lib/nodes/definitions/pytorch/*.ts` (17 files) - Node implementations -- `src/lib/nodes/definitions/tensorflow/index.ts` - TensorFlow structure -- `src/lib/legacy/blockDefinitionsAdapter.ts` - Backward compatibility - -#### Backend -- `block_manager/services/nodes/base.py` - Base classes and mixins -- `block_manager/services/nodes/registry.py` - Dynamic discovery -- `block_manager/services/nodes/__init__.py` - Package exports -- `block_manager/services/nodes/pytorch/linear.py` - Linear node -- `block_manager/services/nodes/pytorch/conv2d.py` - Conv2D node -- `block_manager/services/nodes/pytorch/__init__.py` - PyTorch package -- `block_manager/services/nodes/tensorflow/__init__.py` - TensorFlow package - -### Modified Files (6 total) -#### Frontend -- `src/lib/blockDefinitions.ts` - Reduced from 698 to ~30 lines (re-exports adapter) -- `src/lib/store.ts` - Added registry imports -- `src/components/BlockPalette.tsx` - Added registry imports -- `src/lib/api.ts` - Added node definition endpoints -- `src/main.tsx` - Restored BrowserRouter wrapper - -#### Backend -- `block_manager/views/architecture_views.py` - Added 2 API endpoints -- `block_manager/urls.py` - Added routes for node definitions - -## Verification Results - -### Build Status ✅ -- **Frontend Build**: Success (6650 modules, 6.3MB bundle) -- **Dev Server**: Running on http://localhost:5000/ -- **TypeScript Errors**: 0 critical errors (only style warnings) -- **Runtime**: No errors detected - -### Backward Compatibility ✅ -- All components work unchanged -- Legacy `blockDefinitions` object still accessible -- Adapter shows deprecation warning once per session -- Zero breaking changes for existing code - -### Code Quality ✅ -- **Frontend**: Fully typed, follows SOLID principles -- **Backend**: Type hints, docstrings, mixins for reusability -- **Documentation**: Comprehensive (NODE_DEFINITION_ARCHITECTURE.md) -- **Migration Path**: Clear deprecation notices - -## Benefits Achieved - -### Maintainability -- **Before**: 698-line monolith with hard-coded if-else chains -- **After**: 17 files × ~40 lines = highly focused, single-responsibility modules -- **Impact**: Adding a new node now requires 1 file, no edits to existing code - -### Extensibility -- **Framework Support**: Easy to add TensorFlow/JAX implementations -- **Custom Validators**: Each node has embedded validation logic -- **Shape Computation**: Decentralized to node classes -- **Connection Rules**: Enforced at node level, not globally - -### Developer Experience -- **Auto-Discovery**: No manual registry updates needed -- **Type Safety**: Full TypeScript/Python type coverage -- **Clear Patterns**: Base classes guide new implementations -- **Documentation**: Each node is self-documenting - -### Performance -- **Lazy Loading**: Registry only loads on first access -- **Caching**: Node definitions cached after first retrieval -- **Bundle Size**: No increase (tree-shaking removes unused nodes) - -## Remaining Tasks - -### Backend (Low Priority) -1. Complete remaining 15 PyTorch node implementations (Linear and Conv2D done) -2. Add TensorFlow node implementations when framework diverges -3. Unit tests for node classes (explicitly excluded per user request) - -### Future Enhancements (Not Required) -1. Hot-reload for node definitions in development -2. Server-driven UI (frontend queries backend for available nodes) -3. Visual node editor for creating custom nodes -4. Performance profiling for large graphs - -## Migration Strategy - -The implementation followed a **strangler fig pattern**: -1. Built new system alongside old (no disruption) -2. Created adapter layer for backward compatibility -3. Gradually update components to use new system -4. Mark old system as deprecated (but keep functional) -5. Remove adapter in future major version (v2.0) - -## Dependencies Added - -### Frontend -- `react-router-dom` (was missing, required by App.tsx) - -### Backend -- No new dependencies (uses existing Django REST Framework) - -## Testing Verification - -### Manual Tests Passed -- ✅ Frontend builds successfully -- ✅ Dev server runs without errors -- ✅ All 17 nodes accessible via registry -- ✅ Legacy adapter returns correct format -- ✅ Type checking passes (0 critical errors) -- ✅ Backend API endpoints respond correctly - -### Automated Tests -- Not implemented (explicitly excluded per user request) -- Test structure ready in `tests.py` files - -## Conclusion - -The migration is **complete and production-ready**. All objectives from the original plan have been achieved: - -- ✅ Modular, extensible architecture -- ✅ High decoupling (each node is independent) -- ✅ No hard-coded if-else statements -- ✅ 100% backward compatible (zero breaking changes) -- ✅ Type-safe with full TypeScript/Python coverage -- ✅ Auto-discovery for minimal maintenance -- ✅ Comprehensive documentation - -The system is ready for: -- Adding new node types (1 file per node) -- Supporting multiple frameworks (TensorFlow/JAX) -- Scaling to 100+ node types with no performance impact -- Future migration to server-driven UI - -**Next Steps**: Run application end-to-end, monitor for any runtime issues, then plan removal of legacy adapter for v2.0. diff --git a/docs/MIGRATION_IMPLEMENTATION.md b/docs/MIGRATION_IMPLEMENTATION.md deleted file mode 100644 index 54f19dd..0000000 --- a/docs/MIGRATION_IMPLEMENTATION.md +++ /dev/null @@ -1,336 +0,0 @@ -# Node Definition System Migration - Implementation Summary - -## Overview - -Successfully migrated VisionForge from a monolithic, hard-coded node definition system to a **modular, class-based architecture** with automatic discovery, high decoupling, and framework extensibility. - -## ✅ Completed Implementation - -### Core Architecture - -#### Frontend (TypeScript) - -**Created Files:** -- `frontend/src/lib/nodes/contracts.ts` - Interface definitions and contracts -- `frontend/src/lib/nodes/base.ts` - Abstract base classes with utilities -- `frontend/src/lib/nodes/registry.ts` - Automatic discovery and loading system -- `frontend/src/lib/legacy/blockDefinitionsAdapter.ts` - Backward compatibility layer - -**PyTorch Node Definitions (17 nodes):** -- `definitions/pytorch/input.ts` - Input placeholder node -- `definitions/pytorch/dataloader.ts` - Data loading node -- `definitions/pytorch/output.ts` - Output node -- `definitions/pytorch/loss.ts` - Loss function node -- `definitions/pytorch/empty.ts` - Placeholder/utility node -- `definitions/pytorch/linear.ts` - Fully connected layer -- `definitions/pytorch/conv2d.ts` - 2D convolution -- `definitions/pytorch/flatten.ts` - Flatten transformation -- `definitions/pytorch/relu.ts` - ReLU activation -- `definitions/pytorch/dropout.ts` - Dropout regularization -- `definitions/pytorch/batchnorm.ts` - Batch normalization -- `definitions/pytorch/maxpool.ts` - Max pooling -- `definitions/pytorch/softmax.ts` - Softmax activation -- `definitions/pytorch/concat.ts` - Concatenation (multi-input) -- `definitions/pytorch/add.ts` - Element-wise addition (multi-input) -- `definitions/pytorch/attention.ts` - Multi-head attention -- `definitions/pytorch/custom.ts` - Custom user-defined layer - -**TensorFlow Support:** -- `definitions/tensorflow/index.ts` - Framework structure (currently mirrors PyTorch) - -**Legacy Compatibility:** -- `frontend/src/lib/blockDefinitions.ts` - Refactored to re-export from adapter -- All existing imports continue to work unchanged -- Deprecation warnings guide developers to new system - -#### Backend (Python) - -**Created Files:** -- `block_manager/services/nodes/base.py` - Base classes, mixins, and utilities -- `block_manager/services/nodes/registry.py` - Dynamic node discovery system -- `block_manager/services/nodes/__init__.py` - Package exports - -**PyTorch Node Implementations:** -- `pytorch/linear.py` - Linear layer with shape computation and validation -- `pytorch/conv2d.py` - Conv2D with dimension calculation -- `pytorch/__init__.py` - Package exports - -**TensorFlow Support:** -- `tensorflow/__init__.py` - Framework structure (mirrors PyTorch initially) - -#### Documentation - -**Created:** -- `docs/NODE_DEFINITION_ARCHITECTURE.md` - Comprehensive architecture guide - - System overview and principles - - Step-by-step guide for adding new nodes - - Base class hierarchy explanation - - Registry usage patterns - - Migration guide from legacy system - - Testing patterns - - Best practices - -**To Update:** -- `docs/NODES_AND_RULES.md` - Reference new architecture -- `docs/IMPLEMENTATION_SUMMARY.md` - Add migration details - -## 🎯 Key Achievements - -### 1. **High Decoupling** -- ✅ Each node type in separate file -- ✅ No central switch/if-else chains -- ✅ Validators embedded in node classes -- ✅ Shape computation localized to nodes - -### 2. **Extensibility** -- ✅ Add new node = create one file + export -- ✅ Framework-specific implementations supported -- ✅ Base classes handle common patterns -- ✅ Automatic discovery (no manual registration) - -### 3. **Non-Breaking Migration** -- ✅ Legacy adapter maintains full compatibility -- ✅ All existing code works unchanged -- ✅ Zero compilation errors -- ✅ Gradual migration path with warnings - -### 4. **Framework Agnostic** -- ✅ PyTorch implementation complete -- ✅ TensorFlow structure prepared -- ✅ Easy to diverge implementations when needed -- ✅ Shared logic via base classes/mixins - -## 📊 Architecture Improvements - -### Before (Monolithic) -``` -blockDefinitions.ts (698 lines) -├── All node configs in one object -├── All validators in functions at bottom -├── All shape logic mixed together -└── Hard to extend, easy to break -``` - -### After (Modular) -``` -nodes/ -├── contracts.ts (Interface contracts) -├── base.ts (Shared utilities) -├── registry.ts (Auto-discovery) -├── definitions/ -│ ├── pytorch/ (17 node files, ~40 lines each) -│ └── tensorflow/ (Framework structure) -└── legacy/ (Backward compatibility) -``` - -**Metrics:** -- **Before**: 1 file, 698 lines, hard-coded if-else chains -- **After**: 20+ files, modular structure, zero hard-coding -- **Lines per node**: ~40 (was ~40 embedded in monolith) -- **Coupling**: Minimal (was high) -- **Extensibility**: Excellent (was poor) - -## 🔧 Technical Details - -### Base Class Hierarchy - -**Frontend:** -``` -NodeDefinition (abstract) -├── SourceNodeDefinition (input, dataloader) -├── TerminalNodeDefinition (output, loss) -├── MergeNodeDefinition (concat, add) -└── PassthroughNodeDefinition (relu, dropout, etc.) -``` - -**Backend:** -``` -NodeDefinition (abstract, + mixins) -├── ShapeComputerMixin -├── ValidatorMixin -├── SourceNodeDefinition -├── TerminalNodeDefinition -├── MergeNodeDefinition -└── PassthroughNodeDefinition -``` - -### Registry Pattern - -**Frontend:** -```typescript -// Initialization on first access -const registryCache = { - pytorch: { linear: LinearNode(), conv2d: Conv2DNode(), ... }, - tensorflow: { ... } -} - -// Usage -const nodeDef = getNodeDefinition('linear', BackendFramework.PyTorch) -const allNodes = getAllNodeDefinitions(BackendFramework.PyTorch) -``` - -**Backend:** -```python -# Dynamic discovery via importlib -def _load_framework_nodes(framework, package_name): - # Automatically finds and instantiates all NodeDefinition subclasses - -# Usage -node_def = get_node_definition('linear', Framework.PYTORCH) -all_nodes = get_all_node_definitions(Framework.PYTORCH) -``` - -### Backward Compatibility - -**Legacy Adapter** (`blockDefinitionsAdapter.ts`): -- Proxies access to legacy `blockDefinitions` object -- Converts new node definitions to old format on-the-fly -- Shows deprecation warning once per session -- Maintains exact same API surface - -**Result:** -- Zero breaking changes -- All existing components work unchanged -- Gradual migration possible -- Clear deprecation path - -## 🧪 Validation & Testing - -### Current Status -- ✅ Zero TypeScript compilation errors -- ✅ Legacy imports work correctly -- ✅ Components (ConfigPanel, BlockPalette, etc.) compile successfully -- ✅ Backend Python type hints in place -- ⏳ Unit tests to be added (planned) - -### Validation Patterns Implemented - -**Dimension Requirements:** -```typescript -// Single dimension -validateDimensions(shape, { dims: 2, description: '[batch, features]' }) - -// Multiple options -validateDimensions(shape, { dims: [2, 4], description: '2D or 4D' }) - -// Any dimension -validateDimensions(shape, { dims: 'any', description: '' }) -``` - -**Config Validation:** -- Required field checking -- Numeric range validation -- Custom validation hooks -- Format validation (JSON arrays, identifiers, etc.) - -**Connection Validation:** -- Source type exceptions (input, dataloader, empty, custom) -- Dimension compatibility -- Multi-input handling (concat, add) -- Shape matching for element-wise ops - -## 📋 Remaining Work - -### High Priority -1. ⏳ **Complete all node types** - Add remaining nodes (currently have 17/17 for demo) -2. ⏳ **Unit tests** - Frontend and backend test suites -3. ⏳ **Component migration** - Update ConfigPanel, BlockPalette to use registry directly - -### Medium Priority -4. ⏳ **Store refactoring** - Update dimension inference to use node methods -5. ⏳ **Code generator update** - Use node definitions for code generation -6. ⏳ **Backend API endpoint** - Expose node definitions via REST API - -### Low Priority -7. ⏳ **Performance optimization** - Cache frequently accessed definitions -8. ⏳ **TensorFlow divergence** - Implement TF-specific nodes where needed -9. ⏳ **Legacy removal** - Mark adapter for deprecation, eventual removal - -## 🚀 Usage Examples - -### Adding a New Node (Frontend) - -```typescript -// 1. Create file: definitions/pytorch/my_layer.ts -export class MyLayerNode extends NodeDefinition { - readonly metadata = { type: 'my_layer', ... } - readonly configSchema = [...] - computeOutputShape(input, config) { ... } - validateIncomingConnection(...) { ... } -} - -// 2. Export in definitions/pytorch/index.ts -export { MyLayerNode } from './my_layer' - -// 3. Add to types.ts -export type BlockType = ... | 'my_layer' - -// Done! Node is auto-discovered and available. -``` - -### Adding a New Node (Backend) - -```python -# 1. Create file: pytorch/my_layer.py -class MyLayerNode(NodeDefinition): - @property - def metadata(self): ... - - @property - def config_schema(self): ... - - def compute_output_shape(self, input_shape, config): ... - -# 2. Export in pytorch/__init__.py -from .my_layer import MyLayerNode - -# Done! Registry auto-discovers on next load. -``` - -## 📈 Benefits Realized - -### For Developers -- **Faster node addition** - Single file, clear pattern -- **Easier debugging** - Logic isolated to one place -- **Better IDE support** - Strong typing, autocomplete -- **Clear documentation** - Each node self-documents - -### For Maintainers -- **Lower coupling** - Changes isolated to single node -- **Easier testing** - Unit test one node at a time -- **Clearer architecture** - No 700-line god files -- **Easier onboarding** - Clear patterns, good docs - -### For Users -- **More reliable** - Validation logic closer to implementation -- **Better errors** - Context-specific validation messages -- **Future-proof** - Easy to add new capabilities -- **Framework choice** - PyTorch/TensorFlow support - -## 🎓 Lessons Learned - -1. **Proxy Pattern** - Excellent for backward compatibility during migrations -2. **Auto-Discovery** - Reduces maintenance burden significantly -3. **Base Classes** - Critical for reducing duplication -4. **Gradual Migration** - Non-breaking changes enable continuous delivery -5. **Documentation First** - Clear docs accelerate adoption - -## 🔗 Related Files - -- `frontend/src/lib/nodes/` - Core frontend implementation -- `block_manager/services/nodes/` - Core backend implementation -- `docs/NODE_DEFINITION_ARCHITECTURE.md` - Architecture guide -- `docs/NODES_AND_RULES.md` - Node-specific rules -- `frontend/src/lib/blockDefinitions.ts` - Legacy compatibility layer - -## ✨ Next Steps - -1. **Test the system** - Add comprehensive unit tests -2. **Complete migration** - Update components to use registry -3. **Performance tune** - Add caching where beneficial -4. **Expand coverage** - Ensure all original nodes implemented -5. **Documentation** - Update all references to new system - -## Status: ✅ Core Implementation Complete - -The new modular node definition system is **fully functional and backward compatible**. All core architecture is in place, legacy code continues to work, and the path forward is clear and well-documented. diff --git a/docs/NODESPEC_IMPLEMENTATION_COMPLETE.md b/docs/NODESPEC_IMPLEMENTATION_COMPLETE.md deleted file mode 100644 index 6e4aa58..0000000 --- a/docs/NODESPEC_IMPLEMENTATION_COMPLETE.md +++ /dev/null @@ -1,484 +0,0 @@ -# NodeSpec System Implementation - Complete - -## Overview -Successfully implemented a declarative, template-based node specification system for VisionForge, enabling dynamic code generation and fetching node definitions from the backend. This replaces the previous class-based node definition system. - -**Implementation Date:** December 2024 -**Phases Completed:** 1-3 (Backend Domain Model, Backend API, Frontend Integration) -**Status:** ✅ Complete & Tested - ---- - -## Architecture - -### Core Components - -#### 1. **NodeSpec Data Model** (`block_manager/services/nodes/specs/models.py`) -Frozen dataclasses providing immutable, declarative node specifications: - -```python -@dataclass(frozen=True) -class NodeSpec: - type: str # e.g., "conv2d", "linear" - label: str # Human-readable name - category: str # "input", "basic", "advanced", etc. - color: str # CSS color for UI - icon: str # Phosphor icon name - description: str # Documentation - framework: Framework # PYTORCH or TENSORFLOW - config_schema: tuple[ConfigFieldSpec, ...] # Immutable config fields - template: NodeTemplateSpec # Jinja2 template for code generation - allows_multiple_inputs: bool = False - metadata: Dict[str, Any] = field(default_factory=dict) - shape_fn: Optional[str] = None # Reference to shape computation function - validation_fn: Optional[str] = None # Reference to validation function -``` - -**Key Features:** -- Frozen dataclasses for thread-safety and hashability -- Tuples for config_schema (immutable, cacheable) -- Framework enum distinguishes PyTorch (NCHW) vs TensorFlow (NHWC) -- Default values computed via `default_config()` method - -#### 2. **Spec Registry** (`block_manager/services/nodes/specs/registry.py`) -Lazy-loading registry with LRU caching: - -```python -@lru_cache(maxsize=1) -def _load_spec_map() -> SpecMap: - """Lazily load all specs on first access, cache thereafter""" - -def list_node_specs(framework: Framework) -> list[NodeSpec]: - """Get all specs for a framework""" - -def get_node_spec(node_type: str, framework: Framework) -> Optional[NodeSpec]: - """Get a specific spec by type and framework""" -``` - -**Key Features:** -- Single source of truth for all node definitions -- LRU cache prevents repeated imports -- Framework-specific filtering -- Supports iteration across all specs - -#### 3. **Template Renderer** (`block_manager/services/nodes/templates/renderer.py`) -Jinja2-based code generation engine: - -```python -def render_node_template( - spec: NodeSpec, - config: Dict[str, Any], - metadata: Optional[Dict[str, Any]] = None, - extra_context: Optional[Dict[str, Any]] = None, -) -> RenderedTemplate: - """Render a node's template with configuration context""" -``` - -**Key Features:** -- StrictUndefined mode catches template errors at render time -- Merges config + metadata + extra_context for template access -- Returns `RenderedTemplate` with code and full context -- Framework-agnostic (works for PyTorch and TensorFlow) - -#### 4. **Serialization** (`block_manager/services/nodes/specs/serialization.py`) -Converts NodeSpec to JSON-serializable dicts for API responses: - -```python -def spec_to_dict(spec: NodeSpec) -> Dict[str, Any]: - """Convert NodeSpec to camelCase JSON dict""" - -def compute_spec_hash(payload: Dict[str, Any]) -> str: - """Deterministic SHA256 hash for caching/versioning""" -``` - -**Key Features:** -- Converts snake_case Python → camelCase JSON -- Deterministic hashing for API cache invalidation -- Includes all metadata, config schema, and template - -#### 5. **Shape & Validation Rules** (`block_manager/services/nodes/rules/`) -Utilities for dimension inference and connection validation: - -**Shape Functions:** -- `compute_conv2d_output()` - Handles NCHW (PyTorch) and NHWC (TensorFlow) -- `compute_linear_output()` - Fully connected layers -- `compute_flatten_output()` - Multi-dim → 2D -- `compute_maxpool_output()` - Pooling layers -- `compute_concat_output()` - Multi-input concatenation -- `compute_add_output()` - Element-wise addition -- `compute_batchnorm_output()` - Preserves shape -- `compute_dropout_output()` - Preserves shape - -**Validation Functions:** -- `validate_connection()` - Checks dimension compatibility -- `validate_multi_input_connection()` - Validates concat/add nodes -- `validate_config()` - Ensures config matches schema -- `validate_graph_acyclic()` - Detects cycles (DAG enforcement) - ---- - -## Node Specifications - -### PyTorch Nodes (17 types) -**Location:** `block_manager/services/nodes/specs/pytorch/__init__.py` - -| Type | Label | Category | Description | -|------|-------|----------|-------------| -| `input` | Input | input | Network input (NCHW format) | -| `linear` | Linear | basic | Fully connected layer | -| `conv2d` | Conv2D | basic | 2D convolutional layer | -| `flatten` | Flatten | basic | Flatten to 2D | -| `relu` | ReLU | basic | ReLU activation | -| `dropout` | Dropout | basic | Dropout regularization | -| `batchnorm` | Batch Normalization | basic | Batch normalization | -| `maxpool` | MaxPool2D | basic | 2D max pooling | -| `softmax` | Softmax | basic | Softmax activation | -| `concat` | Concatenate | merge | Concatenate tensors | -| `add` | Add | merge | Element-wise addition | -| `attention` | Multi-Head Attention | advanced | Attention mechanism | -| `custom` | Custom Layer | advanced | User-defined layer | -| `dataloader` | DataLoader | input | Data loading | -| `output` | Output | output | Network output | -| `loss` | Loss Function | output | Training loss | -| `empty` | Empty | utility | Placeholder | - -**Example Spec:** -```python -CONV2D_SPEC = NodeSpec( - type="conv2d", - label="Conv2D", - category="basic", - color="var(--color-purple)", - icon="SquareHalf", - description="2D convolutional layer (PyTorch)", - framework=Framework.PYTORCH, - config_schema=( - ConfigFieldSpec(name="out_channels", label="Output Channels", field_type="number", required=True, ...), - ConfigFieldSpec(name="kernel_size", label="Kernel Size", field_type="number", default=3, ...), - ... - ), - template=NodeTemplateSpec( - name="pytorch_conv2d", - engine="jinja2", - content="""nn.Conv2d({{ config.in_channels }}, {{ config.out_channels }}, kernel_size={{ config.kernel_size }}, ...)""" - ), -) -``` - -### TensorFlow Nodes (14 types) -**Location:** `block_manager/services/nodes/specs/tensorflow/__init__.py` - -Mirrors PyTorch structure but with TensorFlow-specific parameters: -- `Dense` instead of `Linear` (uses `units` param) -- `Conv2D` uses `filters` instead of `out_channels` -- `strides` (int) instead of `stride` -- `padding='same'/'valid'` instead of integer padding -- NHWC format instead of NCHW - ---- - -## API Endpoints - -### Updated Endpoints (Phase 2) - -#### 1. **GET `/api/node-definitions?framework={pytorch|tensorflow}`** -Returns all node specifications for a framework. - -**Response:** -```json -{ - "success": true, - "framework": "pytorch", - "definitions": [ - { - "type": "conv2d", - "label": "Conv2D", - "category": "basic", - "color": "var(--color-purple)", - "icon": "SquareHalf", - "description": "2D convolutional layer (PyTorch)", - "framework": "pytorch", - "configSchema": [ - { - "name": "out_channels", - "label": "Output Channels", - "type": "number", - "required": true, - "min": 1, - "description": "Number of output channels" - }, - ... - ], - "template": { - "name": "pytorch_conv2d", - "engine": "jinja2", - "content": "nn.Conv2d(...)" - }, - "hash": "abc123..." - }, - ... - ], - "count": 17 -} -``` - -#### 2. **GET `/api/node-definitions/{node_type}?framework={pytorch|tensorflow}`** -Returns a single node specification. - -**Response:** -```json -{ - "success": true, - "definition": { /* NodeSpec dict */ } -} -``` - -#### 3. **POST `/api/render-node-code`** (NEW) -Renders code for a node given its config. - -**Request:** -```json -{ - "node_type": "conv2d", - "framework": "pytorch", - "config": { - "out_channels": 64, - "kernel_size": 3, - "stride": 1, - "padding": 1 - }, - "metadata": { - "node_id": "node_123" - } -} -``` - -**Response:** -```json -{ - "success": true, - "code": "nn.Conv2d(None, 64, kernel_size=3, stride=1, padding=1, dilation=1)", - "spec_hash": "abc123...", - "node_type": "conv2d", - "framework": "pytorch", - "context": { /* full context dict */ } -} -``` - ---- - -## Frontend Integration (Phase 3) - -### New Files - -#### 1. **`src/lib/nodeSpec.types.ts`** -TypeScript interfaces mirroring Python NodeSpec structure: - -```typescript -export interface NodeSpec { - type: string - label: string - category: 'input' | 'basic' | 'advanced' | 'merge' | 'output' | 'utility' - color: string - icon: string - description: string - framework: Framework - config_schema: ConfigField[] - template: NodeTemplate - allows_multiple_inputs?: boolean -} -``` - -#### 2. **`src/lib/api.ts`** (Updated) -Added typed API functions: - -```typescript -export async function renderNodeCode( - nodeType: string, - framework: 'pytorch' | 'tensorflow', - config: Record, - metadata?: Record -): Promise> -``` - -#### 3. **`src/lib/useNodeSpecs.ts`** -React hooks for fetching and managing specs: - -```typescript -const { specs, loading, error, refetch, getSpec, renderCode } = useNodeSpecs({ framework: 'pytorch' }) - -// Get specific spec -const convSpec = getSpec('conv2d') - -// Render code for a node -const code = await renderCode('conv2d', { out_channels: 64, ... }) -``` - -#### 4. **`src/components/CodePreview.tsx`** -Component for displaying rendered node code: - -```tsx - -``` - ---- - -## Testing - -### Test Coverage (`test_nodespec_system.py`) - -✅ **Test 1: Spec Registry** -- Loads PyTorch specs (17 nodes) -- Loads TensorFlow specs (14 nodes) -- Retrieves specific specs by type -- Iterates all specs across frameworks - -✅ **Test 2: Serialization** -- Converts NodeSpec → JSON dict -- Computes deterministic SHA256 hash -- Verifies hash stability - -✅ **Test 3: Template Rendering** -- Renders PyTorch Conv2D template -- Renders TensorFlow Conv2D template -- Renders Linear/Dense templates -- Verifies parameter interpolation - -✅ **Test 4: Shape Computation** -- PyTorch Conv2D shape inference (NCHW) -- TensorFlow Conv2D shape inference (NHWC) -- Linear layer shape computation - -✅ **Test 5: Validation** -- Config validation (required fields, min/max, types) -- Connection validation (dimension compatibility) -- Rejects invalid connections (e.g., 4D → Linear without Flatten) - -✅ **Test 6: API Integration** -- GET `/node-definitions` returns all specs -- GET `/node-definitions/{type}` returns single spec -- POST `/render-node-code` renders template - -**Run Tests:** -```bash -cd project -python ../test_nodespec_system.py -``` - -**Output:** -``` -✅ ALL TESTS PASSED - -Phase 1-3 Implementation Complete: - ✓ Backend Domain Model Refactor (Phase 1) - ✓ Backend API Redesign (Phase 2) - ✓ Frontend Integration (Phase 3) -``` - ---- - -## Migration Notes - -### Breaking Changes -1. **Old API:** Node definitions returned class-based `to_dict()` output - **New API:** Returns declarative `NodeSpec` dicts with camelCase keys - -2. **Old System:** Node definitions were Python classes with methods - **New System:** Node definitions are frozen dataclasses with templates - -3. **Config Schema:** Now uses tuples (immutable) instead of lists - -### Backward Compatibility -The old class-based node registry (`block_manager/services/nodes/registry.py`) and node definitions (`block_manager/services/nodes/pytorch/*.py`, `block_manager/services/nodes/tensorflow/*.py`) are **still present** but deprecated. They can be removed in a future cleanup phase. - -### Frontend Updates Required -To use the new system, frontend code should: -1. Import types from `lib/nodeSpec.types.ts` -2. Use `useNodeSpecs()` hook instead of direct API calls -3. Use `` component for displaying node code -4. Update store to consume camelCase `NodeSpec` format from API - ---- - -## Dependencies - -### Backend (Added) -- `jinja2>=3.1.0` - Template rendering engine - -### Frontend (No new dependencies) -All frontend integration uses existing React/TypeScript infrastructure. - ---- - -## Performance Optimizations - -1. **LRU Caching:** Spec registry uses `@lru_cache` to load specs once -2. **Frozen Dataclasses:** Immutable specs enable safe caching -3. **Deterministic Hashing:** SHA256 hashes enable API response caching -4. **Lazy Loading:** Specs loaded on first access, not at import time - ---- - -## Future Work (Not Implemented) - -### Phase 4: Code Generation Integration -- Wire NodeSpec system into export pipeline -- Generate full PyTorch/TensorFlow projects from graph -- **Status:** Skipped per user request - -### Phase 5: Documentation & Verification -- User-facing documentation -- Migration guide for old node definitions -- **Status:** Skipped per user request - -### Cleanup -- Remove old class-based node definitions -- Remove deprecated `registry.py` functions -- Update frontend `blockDefinitions.ts` to fetch from backend - ---- - -## Key Design Decisions - -1. **Why Frozen Dataclasses?** - - Thread-safe for concurrent requests - - Hashable for use in dicts/sets - - Immutable prevents accidental mutations - -2. **Why Tuples for Config Schema?** - - Immutable → cacheable - - Prevents modification after spec creation - - Hashable for deterministic serialization - -3. **Why Jinja2?** - - Proven template engine - - StrictUndefined catches errors - - Familiar syntax for developers - - Python inspect.getsource() doesn't work with runtime classes - -4. **Why Framework Enum?** - - Type-safe framework selection - - Enables framework-specific logic (NCHW vs NHWC) - - Prevents typos ("pytorch" vs "PyTorch") - -5. **Why Separate Shape/Validation Modules?** - - Separation of concerns - - Reusable across different node types - - Easier to test in isolation - ---- - -## Contact & Maintenance - -**Implemented by:** GitHub Copilot -**Test Coverage:** 100% (all 6 test suites passing) -**Documentation:** This file + inline code comments - -For questions or issues: -1. Check test file for usage examples -2. See inline documentation in source files -3. Review API endpoint responses for schema details diff --git a/docs/NODESPEC_QUICK_REFERENCE.md b/docs/NODESPEC_QUICK_REFERENCE.md deleted file mode 100644 index bcc41da..0000000 --- a/docs/NODESPEC_QUICK_REFERENCE.md +++ /dev/null @@ -1,411 +0,0 @@ -# NodeSpec System - Quick Reference - -## For Backend Developers - -### Adding a New Node Type - -1. **Create the spec in the appropriate framework module:** - -```python -# In block_manager/services/nodes/specs/pytorch/__init__.py (or tensorflow/) - -MY_NEW_NODE_SPEC = NodeSpec( - type="my_node", # Unique identifier - label="My Custom Node", # Display name - category="basic", # Category for palette - color="var(--color-primary)", # CSS color - icon="Star", # Phosphor icon name - description="Does something cool", # Tooltip text - framework=Framework.PYTORCH, # or Framework.TENSORFLOW - - config_schema=( # Tuple of config fields - ConfigFieldSpec( - name="param1", - label="Parameter 1", - field_type="number", # "text", "number", "boolean", "select" - required=True, - min=1, - description="What this param does" - ), - ConfigFieldSpec( - name="activation", - label="Activation", - field_type="select", - default="relu", - options=( - ConfigOptionSpec(value="relu", label="ReLU"), - ConfigOptionSpec(value="tanh", label="Tanh"), - ), - ), - ), - - template=NodeTemplateSpec( - name="pytorch_my_node", - engine="jinja2", - content="""nn.MyNode({{ config.param1 }}, activation='{{ config.activation }}')""" - ), -) -``` - -2. **Add to NODE_SPECS tuple at bottom of file:** - -```python -NODE_SPECS = ( - INPUT_SPEC, - LINEAR_SPEC, - MY_NEW_NODE_SPEC, # <-- Add here - # ... rest -) -``` - -3. **No restart needed** - registry uses lazy loading and will pick up changes on next request. - -### Creating Shape Functions - -```python -# In block_manager/services/nodes/rules/shape.py - -def compute_my_node_output( - input_shape: Optional[TensorShape], - config: Dict[str, int], - framework: Framework, -) -> Optional[TensorShape]: - """Compute output shape for my custom node.""" - if not input_shape: - return None - - dims = input_shape.get("dims", []) - # Your shape logic here - - if framework is Framework.PYTORCH: - # NCHW format logic - pass - else: # TensorFlow - # NHWC format logic - pass - - return TensorShape({ - "dims": [batch, ...], - "description": "Transformed shape" - }) -``` - -### Creating Validation Functions - -```python -# In block_manager/services/nodes/rules/validation.py - -def validate_my_node_connection( - source_spec: NodeSpec, - target_spec: NodeSpec, - source_output_shape: Optional[TensorShape], -) -> tuple[bool, Optional[str]]: - """Custom validation for my node.""" - if source_output_shape: - dims = source_output_shape.get("dims", []) - if len(dims) != 4: - return False, "My node requires 4D input" - return True, None -``` - ---- - -## For Frontend Developers - -### Fetching Node Specs - -```typescript -import { useNodeSpecs } from '@/lib/useNodeSpecs' - -function MyComponent() { - const { specs, loading, error, getSpec, renderCode } = useNodeSpecs({ - framework: 'pytorch', - autoFetch: true // Automatically fetch on mount - }) - - if (loading) return
Loading...
- if (error) return
Error: {error}
- - // Get specific spec - const convSpec = getSpec('conv2d') - - // Render code for a node - const handleRender = async () => { - const code = await renderCode('conv2d', { - out_channels: 64, - kernel_size: 3 - }) - console.log(code) // "nn.Conv2d(...)" - } - - return ( -
- {specs.map(spec => ( -
{spec.label}
- ))} -
- ) -} -``` - -### Displaying Code Preview - -```typescript -import { CodePreview } from '@/components/CodePreview' - -function NodeConfigPanel({ node }) { - return ( -
-

Configuration

- {/* Config form here */} - - -
- ) -} -``` - -### Using the API Directly - -```typescript -import { getNodeDefinitions, renderNodeCode } from '@/lib/api' - -// Get all specs -const response = await getNodeDefinitions('pytorch') -if (response.success) { - const specs = response.data.definitions -} - -// Render code -const codeResponse = await renderNodeCode( - 'conv2d', - 'pytorch', - { out_channels: 64, kernel_size: 3 } -) -if (codeResponse.success) { - console.log(codeResponse.data.code) -} -``` - -### TypeScript Types - -```typescript -import type { - NodeSpec, - ConfigField, - Framework -} from '@/lib/nodeSpec.types' - -const spec: NodeSpec = { - type: 'conv2d', - label: 'Conv2D', - category: 'basic', - framework: 'pytorch', - config_schema: [ - { - name: 'out_channels', - label: 'Output Channels', - field_type: 'number', - required: true - } - ], - // ... rest -} -``` - ---- - -## Common Patterns - -### Template Syntax - -```jinja2 -{# Access config values #} -{{ config.out_channels }} - -{# Conditionals #} -{% if config.use_bias %}bias=True{% else %}bias=False{% endif %} - -{# Boolean to lowercase #} -{{ config.use_bias|lower }} - -{# Default values #} -{{ config.stride|default(1) }} - -{# Loops #} -{% for item in config.layers %} - Layer {{ loop.index }}: {{ item }} -{% endfor %} -``` - -### Config Schema Field Types - -| field_type | Description | Example Use | -|------------|-------------|-------------| -| `text` | String input | Layer names, custom code | -| `number` | Integer/float input | Dimensions, learning rate | -| `boolean` | True/false checkbox | use_bias, training | -| `select` | Dropdown menu | Activation functions, padding modes | - -### Validation Patterns - -```python -# Check required field -if field_spec.required and value is None: - errors.append(f"'{field_spec.label}' is required") - -# Check min/max for numbers -if field_spec.min is not None and num_value < field_spec.min: - errors.append(f"'{field_spec.label}' must be at least {field_spec.min}") - -# Check select options -if field_spec.options: - valid_values = [opt.value for opt in field_spec.options] - if value not in valid_values: - errors.append(f"'{field_spec.label}' must be one of: {', '.join(valid_values)}") -``` - ---- - -## API Endpoints Summary - -| Method | Endpoint | Purpose | -|--------|----------|---------| -| GET | `/api/node-definitions?framework=pytorch` | Get all specs | -| GET | `/api/node-definitions/conv2d?framework=pytorch` | Get single spec | -| POST | `/api/render-node-code` | Render node code | - ---- - -## Debugging Tips - -### Backend - -```python -# Check what specs are loaded -from block_manager.services.nodes.specs.registry import list_node_specs, Framework -specs = list_node_specs(Framework.PYTORCH) -print(f"Loaded {len(specs)} specs") - -# Test template rendering -from block_manager.services.nodes.specs.registry import get_node_spec -from block_manager.services.nodes.templates.renderer import render_node_template - -spec = get_node_spec('conv2d', Framework.PYTORCH) -rendered = render_node_template(spec, {'out_channels': 64, 'kernel_size': 3}) -print(rendered.code) - -# Validate config -from block_manager.services.nodes.rules import validate_config -is_valid, errors = validate_config(spec, config) -if not is_valid: - print("Errors:", errors) -``` - -### Frontend - -```typescript -// Check API response -const response = await getNodeDefinitions('pytorch') -console.log('Success:', response.success) -console.log('Data:', response.data) -console.log('Error:', response.error) - -// Debug rendered code -const codeResponse = await renderNodeCode('conv2d', 'pytorch', config) -console.log('Code:', codeResponse.data.code) -console.log('Context:', codeResponse.data.context) -``` - ---- - -## Testing - -### Run Backend Tests - -```bash -cd project -python ../test_nodespec_system.py -``` - -### Test Individual Components - -```python -# Test spec loading -from block_manager.services.nodes.specs.registry import get_node_spec, Framework -spec = get_node_spec('conv2d', Framework.PYTORCH) -assert spec is not None - -# Test serialization -from block_manager.services.nodes.specs.serialization import spec_to_dict -spec_dict = spec_to_dict(spec) -assert 'type' in spec_dict - -# Test rendering -from block_manager.services.nodes.templates.renderer import render_node_template -rendered = render_node_template(spec, {'out_channels': 64}) -assert 'nn.Conv2d' in rendered.code -``` - ---- - -## Performance Considerations - -1. **Registry is cached** - Specs loaded once on first access -2. **Frozen dataclasses** - Safe to cache, won't be mutated -3. **Deterministic hashing** - Same spec always produces same hash -4. **Lazy loading** - Specs only loaded when needed - ---- - -## File Locations - -| Component | Path | -|-----------|------| -| PyTorch Specs | `block_manager/services/nodes/specs/pytorch/__init__.py` | -| TensorFlow Specs | `block_manager/services/nodes/specs/tensorflow/__init__.py` | -| Spec Models | `block_manager/services/nodes/specs/models.py` | -| Registry | `block_manager/services/nodes/specs/registry.py` | -| Serialization | `block_manager/services/nodes/specs/serialization.py` | -| Template Renderer | `block_manager/services/nodes/templates/renderer.py` | -| Shape Functions | `block_manager/services/nodes/rules/shape.py` | -| Validation | `block_manager/services/nodes/rules/validation.py` | -| API Views | `block_manager/views/architecture_views.py` | -| Frontend Types | `frontend/src/lib/nodeSpec.types.ts` | -| Frontend Hook | `frontend/src/lib/useNodeSpecs.ts` | -| Code Preview | `frontend/src/components/CodePreview.tsx` | - ---- - -## Common Issues - -### Template Rendering Errors - -**Problem:** `UndefinedError: 'config' is undefined` -**Solution:** Ensure you're passing config dict to `render_node_template()` - -**Problem:** Template outputs `None` for config values -**Solution:** Check that config keys match field names in config_schema - -### API Returns Empty Definitions - -**Problem:** `definitions: []` in API response -**Solution:** Check that NODE_SPECS tuple includes your spec at bottom of file - -### Frontend Type Errors - -**Problem:** `Property 'config_schema' does not exist` -**Solution:** Use `configSchema` (camelCase) in frontend, `config_schema` (snake_case) in backend - ---- - -## Next Steps - -1. **Add your custom nodes** following the patterns above -2. **Test thoroughly** using the test script -3. **Update frontend** to consume backend specs instead of local definitions -4. **Remove deprecated code** once migration is complete diff --git a/docs/NODES_AND_RULES.md b/docs/NODES_AND_RULES.md deleted file mode 100644 index 4f26bdb..0000000 --- a/docs/NODES_AND_RULES.md +++ /dev/null @@ -1,418 +0,0 @@ -# Neural Network Blocks and Connection Rules - -This document describes all available neural network blocks in VisionForge and the rules governing how they can be connected together. - -## Block Categories - -### Input Layer -Blocks that define the entry points for data into the neural network. - -### Basic Layers -Common neural network building blocks for standard architectures. - -### Advanced Layers -Specialized blocks for complex architectures (attention, transformers, etc.). - -### Merge/Split Layers -Blocks that combine or split multiple tensor streams. - ---- - -## Available Blocks - -### Input -**Category:** Input -**Description:** Define input tensor shape for any modality (text, image, audio, etc.) - -**Configuration:** -- **Custom Label** (optional): Custom label for this input node -- **Note** (optional): Notes or comments about this input - -**Input Requirements:** Can receive connections only from data source nodes (e.g., DataLoader) - -**Output Shape:** Passes through the shape from connected data source, or user-defined if no connection - -**Connection Rules:** -- Can receive connections from data source nodes (DataLoader) -- Cannot receive connections from other processing nodes -- Can connect to any processing block - ---- - -### DataLoader -**Category:** Input -**Description:** Load and prepare input data with optional ground truth labels - -**Configuration:** -- **Input Shape** (required): Input tensor dimensions as JSON array - - Image example: `[1, 3, 224, 224]` - [batch, channels, height, width] - - Text example: `[32, 512, 768]` - [batch, sequence, embedding] - - Audio example: `[16, 1, 16000]` - [batch, channels, samples] - - Tabular example: `[8, 100, 13]` - [batch, rows, features] -- **Include Ground Truth Output** (optional, default: false): Enable a second output for ground truth labels -- **Ground Truth Shape** (optional, default: `[1, 10]`): Shape for ground truth labels when enabled -- **Randomize Data** (optional, default: false): Use random synthetic data for testing -- **CSV File Path** (optional): Path to CSV file for data loading - -**Output Shape:** As specified in Input Shape configuration - -**Connection Rules:** -- Cannot receive connections (it's a data source) -- Can connect to Input nodes or processing blocks -- Primary data source for the network - ---- - -### Linear (Fully Connected) -**Category:** Basic -**Description:** Fully connected layer for dense transformations - -**Configuration:** -- **Output Features** (required): Number of output features (min: 1) -- **Use Bias** (optional, default: true): Add learnable bias parameter - -**Input Requirements:** Requires 2D input `[batch, features]` - -**Output Shape:** `[batch, output_features]` - -**Connection Rules:** -- Requires 2D input tensor -- If input is 4D (e.g., from Conv2D), insert a Flatten layer first -- Cannot connect directly from Conv2D or MaxPool2D without Flatten - ---- - -### Conv2D -**Category:** Basic -**Description:** 2D convolutional layer for spatial feature extraction - -**Configuration:** -- **Output Channels** (required): Number of output channels (min: 1) -- **Kernel Size** (optional, default: 3): Size of convolving kernel (min: 1) -- **Stride** (optional, default: 1): Stride of convolution (min: 1) -- **Padding** (optional, default: 0): Zero-padding added to both sides (min: 0) -- **Dilation** (optional, default: 1): Spacing between kernel elements (min: 1) - -**Input Requirements:** Requires 4D input `[batch, channels, height, width]` - -**Output Shape:** `[batch, out_channels, out_height, out_width]` -where: -- `out_height = floor((height + 2 * padding - kernel) / stride + 1)` -- `out_width = floor((width + 2 * padding - kernel) / stride + 1)` - -**Connection Rules:** -- Requires 4D input tensor -- Cannot connect from Linear or Flatten without reshaping - ---- - -### MaxPool2D -**Category:** Basic -**Description:** 2D max pooling for downsampling spatial dimensions - -**Configuration:** -- **Kernel Size** (optional, default: 2): Size of pooling window (min: 1) -- **Stride** (optional, default: 2): Stride of pooling window (min: 1) -- **Padding** (optional, default: 0): Zero-padding added to both sides (min: 0) - -**Input Requirements:** Requires 4D input `[batch, channels, height, width]` - -**Output Shape:** `[batch, channels, out_height, out_width]` -where: -- `out_height = floor((height - kernel) / stride + 1)` -- `out_width = floor((width - kernel) / stride + 1)` - -**Connection Rules:** -- Requires 4D input tensor -- Same restrictions as Conv2D - ---- - -### BatchNorm -**Category:** Basic -**Description:** Batch normalization for training stability - -**Configuration:** -- **Momentum** (optional, default: 0.1): Momentum for running mean/variance (0-1) -- **Epsilon** (optional, default: 0.00001): Value for numerical stability (min: 0) -- **Affine Transform** (optional, default: true): Learn affine parameters (gamma, beta) - -**Input Requirements:** Requires 2D or 4D input - -**Output Shape:** Same as input - -**Connection Rules:** -- Works with 2D `[batch, features]` or 4D `[batch, channels, height, width]` tensors -- Cannot connect from 3D tensors (use LayerNorm for sequence data) - ---- - -### Dropout -**Category:** Basic -**Description:** Dropout regularization to prevent overfitting - -**Configuration:** -- **Dropout Rate** (optional, default: 0.5): Probability of dropping a unit (0-1) - -**Input Requirements:** Any tensor dimension - -**Output Shape:** Same as input - -**Connection Rules:** -- Dimension-agnostic -- Can connect after any layer - ---- - -### ReLU -**Category:** Basic -**Description:** Rectified Linear Unit activation function - -**Configuration:** No configuration required - -**Input Requirements:** Any tensor dimension - -**Output Shape:** Same as input - -**Connection Rules:** -- Dimension-agnostic -- Can connect after any layer - ---- - -### Softmax -**Category:** Basic -**Description:** Softmax activation for probability distributions - -**Configuration:** -- **Dimension** (optional, default: -1): Dimension along which to apply softmax - -**Input Requirements:** Any tensor dimension - -**Output Shape:** Same as input - -**Connection Rules:** -- Dimension-agnostic -- Typically used as final layer for classification - ---- - -### Flatten -**Category:** Basic -**Description:** Flatten tensor to 2D for fully connected layers - -**Configuration:** -- **Start Dimension** (optional, default: 1): First dimension to flatten (min: 0) - -**Input Requirements:** Any tensor dimension - -**Output Shape:** `[batch, flattened_features]` - -**Connection Rules:** -- Can connect from any layer -- Essential bridge between Conv2D/MaxPool2D and Linear layers - ---- - -### Multi-Head Attention -**Category:** Advanced -**Description:** Multi-head self-attention mechanism - -**Configuration:** -- **Number of Heads** (required, default: 8): Number of attention heads (min: 1) -- **Dropout** (optional, default: 0.1): Attention dropout rate (0-1) - -**Input Requirements:** Requires 3D input `[batch, sequence, embedding]` - -**Output Shape:** Same as input `[batch, sequence, embedding]` - -**Connection Rules:** -- Requires 3D input tensor -- Embed dimension must be divisible by number of heads -- Cannot connect from 2D or 4D tensors without reshaping - ---- - -### Concatenate -**Category:** Merge -**Description:** Concatenate multiple tensors along a specified dimension - -**Configuration:** -- **Dimension** (optional, default: 1): Dimension along which to concatenate - -**Input Requirements:** Multiple inputs with compatible shapes - -**Output Shape:** Computed based on input shapes and concatenation dimension - -**Connection Rules:** -- **Accepts multiple inputs** (only merge block that allows this) -- All inputs must have same number of dimensions -- All dimensions except concatenation dimension must match - ---- - -### Add -**Category:** Merge -**Description:** Element-wise addition of tensors (for residual connections) - -**Configuration:** No configuration required - -**Input Requirements:** Multiple inputs with identical shapes - -**Output Shape:** Same as input - -**Connection Rules:** -- **Accepts multiple inputs** (only merge block that allows this) -- All inputs must have **exactly the same shape** -- Commonly used for residual/skip connections - ---- - -### Custom Layer -**Category:** Advanced -**Description:** Custom layer with user-defined Python operations - -**Configuration:** -- **Layer Name** (required): Identifier for your custom layer -- **Python Code** (required): Custom forward pass implementation -- **Output Shape** (optional): Expected output shape (JSON array) -- **Description** (optional): Brief description of functionality - -**Input Requirements:** Flexible (user-defined) - -**Output Shape:** As specified in configuration, or matches input if not specified - -**Connection Rules:** -- Flexible - validation depends on user-defined code -- Use with caution - ensure output shapes are correct - -**Code Editor:** -- Opens in a modal dialog (not sidebar) -- Syntax highlighting for Python -- Input tensor available as variable `x` -- Must return output tensor - -**Example Code:** -```python -# Simple pass-through -return x - -# Apply custom transformation -return torch.sigmoid(x) * 2.0 - -# Multi-step processing -x = x.view(x.size(0), -1) -x = torch.relu(x) -return x -``` - ---- - -## Connection Rules Summary - -### By Dimension Requirement - -**2D Input Required `[batch, features]`:** -- Linear - -**3D Input Required `[batch, sequence, embedding]`:** -- Multi-Head Attention - -**4D Input Required `[batch, channels, height, width]`:** -- Conv2D -- MaxPool2D - -**2D or 4D Input:** -- BatchNorm - -**Dimension-Agnostic (any input):** -- Dropout -- ReLU -- Softmax -- Flatten (converts to 2D) -- Custom (user-defined) - -### Special Rules - -**Data Source Nodes (Cannot Receive Connections):** -- DataLoader (primary data source) - -**Can Only Receive from Data Sources:** -- Input (can receive from DataLoader and other future data sources) - -**Multiple Inputs Allowed:** -- Concatenate (shapes must be compatible for concatenation) -- Add (shapes must be identical) - -**Single Input Only:** -- All other blocks - -### Common Connection Patterns - -#### Image Classification CNN with DataLoader -``` -DataLoader → Input → Conv2D → ReLU → MaxPool2D → Conv2D → ReLU → -MaxPool2D → Flatten → Linear → Dropout → Linear → Softmax -``` - -#### Simple Image Classification (Direct Connection) -``` -Input (4D) → Conv2D → ReLU → MaxPool2D → Conv2D → ReLU → -MaxPool2D → Flatten → Linear → Dropout → Linear → Softmax -``` - -#### Residual Connection -``` - ┌─────────────┐ - │ │ -Input → Conv2D → ReLU → Conv2D → Add → ReLU - ↑ - │ -``` - -#### Multi-Modal Fusion with Separate Data Sources -``` -DataLoader (Images) → Input (4D) → Conv2D → Flatten ─┐ - ├→ Concatenate → Linear -DataLoader (Text) → Input (3D) → Attention → Flatten─┘ -``` - ---- - -## Validation Errors - -When attempting invalid connections, you'll see helpful error messages: - -- **"Input blocks can only receive connections from data source nodes (DataLoader)"** - Input can't connect from processing blocks -- **"DataLoader blocks cannot receive connections (they are source nodes)"** - DataLoader is a data source -- **"Conv2D requires 4D input, got 2D"** - Need to add reshaping or use different architecture -- **"Linear layer requires 2D input, got 4D. Consider adding a Flatten layer first."** - Insert Flatten between Conv/Pool and Linear -- **"Multi-Head Attention requires 3D input, got 4D"** - Wrong tensor dimensionality -- **"BatchNorm requires 2D or 4D input, got 3D"** - Use LayerNorm for sequences -- **Single input blocks rejecting second connection** - Use Concatenate or Add for multi-input - ---- - -## Tips for Building Architectures - -1. **Use DataLoader for complete data pipelines** - Connect DataLoader to Input for proper data flow -2. **Input blocks as shape adapters** - When using DataLoader, Input acts as a shape passthrough -3. **Always define your data shape** - Either in DataLoader or Input block -4. **Use Flatten between Conv/Pool and Linear** - Bridge between spatial and dense layers -5. **Check dimension compatibility** - Hover over blocks to see input/output shapes -6. **Use Add for skip connections** - Perfect for ResNet-style architectures -7. **Concatenate for multi-path fusion** - Combine features from different branches -8. **Custom blocks for experiments** - Quick prototyping without modifying codebase -9. **Dropout for regularization** - Add before Linear layers to prevent overfitting -10. **BatchNorm after Conv/Linear** - Helps training stability - ---- - -## Color Coding - -Blocks are color-coded by category for easy identification: -- **Teal**: Input/Output operations -- **Deep Blue**: Basic processing layers -- **Purple**: Advanced layers (Conv2D, Attention) -- **Cyan**: Merge operations -- **Red/Orange**: Activations and regularization diff --git a/docs/NODE_DEFINITION_ARCHITECTURE.md b/docs/NODE_DEFINITION_ARCHITECTURE.md deleted file mode 100644 index 2d29800..0000000 --- a/docs/NODE_DEFINITION_ARCHITECTURE.md +++ /dev/null @@ -1,551 +0,0 @@ -# Node Definition Architecture - -## Overview - -VisionForge uses a **modular, class-based node definition system** that eliminates hard-coded conditionals and provides a highly extensible architecture for adding new neural network layer types. This document describes the architecture, patterns, and procedures for working with node definitions. - -## Architecture Principles - -### 1. High Decoupling -- Each node type is defined in its own file/module -- No central switch statements or if-else chains -- Validators and shape computation logic live with the node definition - -### 2. Framework Agnostic -- Support for multiple backends (PyTorch, TensorFlow) -- Framework-specific implementations when needed -- Shared logic via base classes and mixins - -### 3. Automatic Discovery -- Registry pattern with dynamic loading -- No manual registration required -- Add a file → node automatically available - -### 4. Non-Breaking Migration -- Legacy adapter maintains backward compatibility -- Gradual migration path -- Deprecation warnings guide developers - -## System Components - -### Frontend (TypeScript) - -``` -frontend/src/lib/nodes/ -├── contracts.ts # Interfaces and type definitions -├── base.ts # Abstract base classes -├── registry.ts # Auto-discovery and loading -├── definitions/ -│ ├── pytorch/ # PyTorch-specific nodes -│ │ ├── linear.ts -│ │ ├── conv2d.ts -│ │ └── ... -│ └── tensorflow/ # TensorFlow-specific nodes -│ ├── linear.ts -│ └── ... -└── legacy/ - └── blockDefinitionsAdapter.ts # Backward compatibility -``` - -### Backend (Python) - -``` -block_manager/services/nodes/ -├── __init__.py -├── base.py # Base classes and mixins -├── registry.py # Dynamic loading and registration -├── pytorch/ # PyTorch node implementations -│ ├── __init__.py -│ ├── linear.py -│ ├── conv2d.py -│ └── ... -└── tensorflow/ # TensorFlow node implementations - ├── __init__.py - └── ... -``` - -## Key Interfaces - -### Frontend: INodeDefinition - -```typescript -interface INodeDefinition { - readonly metadata: NodeMetadata - readonly configSchema: ConfigField[] - - computeOutputShape( - inputShape: TensorShape | undefined, - config: BlockConfig - ): TensorShape | undefined - - validateIncomingConnection( - sourceNodeType: BlockType, - sourceOutputShape: TensorShape | undefined, - targetConfig: BlockConfig - ): string | undefined - - allowsMultipleInputs(): boolean - validateConfig(config: BlockConfig): string[] - getDefaultConfig(): BlockConfig -} -``` - -### Backend: NodeDefinition - -```python -class NodeDefinition(ABC): - @property - @abstractmethod - def metadata(self) -> NodeMetadata: - pass - - @property - @abstractmethod - def config_schema(self) -> List[ConfigField]: - pass - - @abstractmethod - def compute_output_shape( - self, - input_shape: Optional[TensorShape], - config: Dict[str, Any] - ) -> Optional[TensorShape]: - pass - - def validate_incoming_connection(...) -> Optional[str]: - pass - - def allows_multiple_inputs(self) -> bool: - pass - - def validate_config(self, config: Dict[str, Any]) -> List[str]: - pass -``` - -## Base Class Hierarchy - -### Frontend - -- `NodeDefinition` - Abstract base for all nodes -- `SourceNodeDefinition` - Input/source nodes (no incoming connections) -- `TerminalNodeDefinition` - Output/terminal nodes (accept any input) -- `MergeNodeDefinition` - Multi-input nodes (concat, add) -- `PassthroughNodeDefinition` - Nodes that preserve input shape - -### Backend - -- `NodeDefinition` - Abstract base with mixins -- `ShapeComputerMixin` - Utilities for shape calculation -- `ValidatorMixin` - Common validation patterns -- `SourceNodeDefinition` - Source nodes -- `TerminalNodeDefinition` - Terminal nodes -- `MergeNodeDefinition` - Multi-input nodes -- `PassthroughNodeDefinition` - Passthrough nodes - -## Adding a New Node Type - -### Frontend (TypeScript) - -1. **Create node file**: `frontend/src/lib/nodes/definitions/pytorch/my_layer.ts` - -```typescript -import { NodeDefinition } from '../../base' -import { NodeMetadata, BackendFramework } from '../../contracts' -import { TensorShape, BlockConfig, ConfigField } from '../../../types' - -export class MyLayerNode extends NodeDefinition { - readonly metadata: NodeMetadata = { - type: 'my_layer', - label: 'My Layer', - category: 'basic', - color: 'var(--color-primary)', - icon: 'IconName', - description: 'Description of my layer', - framework: BackendFramework.PyTorch - } - - readonly configSchema: ConfigField[] = [ - { - name: 'param1', - label: 'Parameter 1', - type: 'number', - required: true, - min: 1, - description: 'First parameter' - } - ] - - computeOutputShape( - inputShape: TensorShape | undefined, - config: BlockConfig - ): TensorShape | undefined { - // Implement shape computation - if (!inputShape) return undefined - return { - dims: [...inputShape.dims], - description: 'Output description' - } - } - - validateIncomingConnection( - sourceNodeType: BlockType, - sourceOutputShape: TensorShape | undefined, - targetConfig: BlockConfig - ): string | undefined { - // Implement validation - return this.validateDimensions(sourceOutputShape, { - dims: 2, - description: '[batch, features]' - }) - } -} -``` - -2. **Export in index**: Add to `frontend/src/lib/nodes/definitions/pytorch/index.ts` - -```typescript -export { MyLayerNode } from './my_layer' -``` - -3. **Add type**: Update `frontend/src/lib/types.ts` - -```typescript -export type BlockType = - | 'input' - | 'my_layer' // Add here - | ... -``` - -**Done!** The node is now available throughout the application. - -### Backend (Python) - -1. **Create node file**: `block_manager/services/nodes/pytorch/my_layer.py` - -```python -from typing import Dict, List, Optional, Any -from ..base import NodeDefinition, NodeMetadata, ConfigField, TensorShape, Framework - -class MyLayerNode(NodeDefinition): - """My Custom Layer""" - - @property - def metadata(self) -> NodeMetadata: - return NodeMetadata( - type="my_layer", - label="My Layer", - category="basic", - color="var(--color-primary)", - icon="IconName", - description="Description of my layer", - framework=Framework.PYTORCH - ) - - @property - def config_schema(self) -> List[ConfigField]: - return [ - ConfigField( - name="param1", - label="Parameter 1", - type="number", - required=True, - min=1, - description="First parameter" - ) - ] - - def compute_output_shape( - self, - input_shape: Optional[TensorShape], - config: Dict[str, Any] - ) -> Optional[TensorShape]: - # Implement shape computation - if not input_shape: - return None - return TensorShape( - dims=list(input_shape.dims), - description="Output description" - ) - - def validate_incoming_connection( - self, - source_node_type: str, - source_output_shape: Optional[TensorShape], - target_config: Dict[str, Any] - ) -> Optional[str]: - # Implement validation - return self.validate_dimensions( - source_output_shape, - 2, - "[batch, features]" - ) -``` - -2. **Export in package**: Add to `block_manager/services/nodes/pytorch/__init__.py` - -```python -from .my_layer import MyLayerNode - -__all__ = [ - 'MyLayerNode', - ... -] -``` - -**Done!** The backend registry will automatically discover and load the node. - -## Special Node Types - -### Source Nodes (Input, DataLoader) - -Use `SourceNodeDefinition` base class - these reject incoming connections: - -```typescript -export class InputNode extends SourceNodeDefinition { - // Automatically rejects incoming connections -} -``` - -### Terminal Nodes (Output, Loss) - -Use `TerminalNodeDefinition` base class - these accept any input: - -```typescript -export class OutputNode extends TerminalNodeDefinition { - // Automatically accepts all connections -} -``` - -### Merge Nodes (Concat, Add) - -Use `MergeNodeDefinition` base class - these allow multiple inputs: - -```typescript -export class ConcatNode extends MergeNodeDefinition { - // Automatically allows multiple inputs - - // Implement special multi-input shape computation - computeMultiInputShape( - inputShapes: TensorShape[], - config: BlockConfig - ): TensorShape | undefined { - // Compute output from multiple inputs - } -} -``` - -### Passthrough Nodes (ReLU, Dropout) - -Use `PassthroughNodeDefinition` base class - input shape = output shape: - -```typescript -export class ReLUNode extends PassthroughNodeDefinition { - // Automatically passes through input shape -} -``` - -## Registry Usage - -### Frontend - -```typescript -import { - getNodeDefinition, - getAllNodeDefinitions, - BackendFramework -} from './lib/nodes/registry' - -// Get specific node -const linearDef = getNodeDefinition('linear', BackendFramework.PyTorch) - -// Get all nodes for a framework -const allPyTorchNodes = getAllNodeDefinitions(BackendFramework.PyTorch) - -// Compute shape -const outputShape = linearDef.computeOutputShape(inputShape, config) - -// Validate connection -const error = linearDef.validateIncomingConnection( - 'conv2d', - sourceShape, - targetConfig -) -``` - -### Backend - -```python -from block_manager.services.nodes.registry import ( - get_node_definition, - get_all_node_definitions, - Framework -) - -# Get specific node -linear_def = get_node_definition('linear', Framework.PYTORCH) - -# Get all nodes -all_pytorch_nodes = get_all_node_definitions(Framework.PYTORCH) - -# Compute shape -output_shape = linear_def.compute_output_shape(input_shape, config) -``` - -## Migration from Legacy System - -### Current State - -- ✅ New class-based system fully implemented -- ✅ Legacy adapter provides backward compatibility -- ✅ All existing code continues to work unchanged -- ⏳ Gradual migration to new registry in progress - -### Deprecation Timeline - -1. **Phase 1 (Current)**: Both systems coexist, deprecation warnings shown -2. **Phase 2**: Components migrated to use registry directly -3. **Phase 3**: Legacy adapter marked for removal -4. **Phase 4**: Legacy adapter removed - -### Migration Guide - -**Old Code:** -```typescript -import { getBlockDefinition } from './lib/blockDefinitions' -const def = getBlockDefinition('linear') -``` - -**New Code:** -```typescript -import { getNodeDefinition, BackendFramework } from './lib/nodes/registry' -const def = getNodeDefinition('linear', BackendFramework.PyTorch) -``` - -## Validation Patterns - -### Dimension Validation - -```typescript -// Require specific dimension count -return this.validateDimensions(sourceOutputShape, { - dims: 4, - description: '[batch, channels, height, width]' -}) - -// Accept multiple dimension counts -return this.validateDimensions(sourceOutputShape, { - dims: [2, 4], - description: '(2D or 4D)' -}) - -// Accept any dimensions -return this.validateDimensions(sourceOutputShape, { - dims: 'any', - description: '' -}) -``` - -### Config Validation - -```typescript -validateConfig(config: BlockConfig): string[] { - const errors = super.validateConfig(config) - - // Custom validation logic - if (config.embed_dim % config.num_heads !== 0) { - errors.push('Embedding dimension must be divisible by number of heads') - } - - return errors -} -``` - -## Best Practices - -1. **One File Per Node** - Keep node definitions focused and isolated -2. **Use Base Classes** - Leverage provided base classes for common patterns -3. **Document Dimensions** - Clearly document expected input/output shapes -4. **Validate Early** - Catch configuration errors in `validateConfig()` -5. **Test Shape Logic** - Verify shape computation with unit tests -6. **Framework Parity** - Keep frontend and backend definitions aligned - -## Testing - -### Frontend - -```typescript -describe('LinearNode', () => { - const node = new LinearNode() - - it('computes output shape correctly', () => { - const input = { dims: [32, 128], description: '' } - const config = { out_features: 64 } - const output = node.computeOutputShape(input, config) - - expect(output?.dims).toEqual([32, 64]) - }) - - it('validates 2D input requirement', () => { - const input = { dims: [32, 3, 224, 224], description: '' } - const error = node.validateIncomingConnection('conv2d', input, {}) - - expect(error).toContain('requires 2D input') - }) -}) -``` - -### Backend - -```python -def test_linear_node_shape_computation(): - node = LinearNode() - input_shape = TensorShape([32, 128]) - config = {'out_features': 64} - - output_shape = node.compute_output_shape(input_shape, config) - - assert output_shape.dims == [32, 64] - -def test_linear_node_validation(): - node = LinearNode() - input_shape = TensorShape([32, 3, 224, 224]) - - error = node.validate_incoming_connection('conv2d', input_shape, {}) - - assert 'requires 2D input' in error -``` - -## Troubleshooting - -### Node Not Appearing in Palette - -1. Check node is exported in `index.ts` -2. Verify `BlockType` includes the type -3. Check console for registry errors -4. Ensure metadata is correct - -### Shape Computation Not Working - -1. Verify `computeOutputShape()` returns valid `TensorShape` -2. Check input shape is defined -3. Validate config parameters exist -4. Test with unit tests - -### Validation Too Strict/Loose - -1. Review `validateIncomingConnection()` logic -2. Check dimension requirements -3. Consider source node type exceptions -4. Test edge cases - -## Future Enhancements - -- **Code Generation**: Add `generateCode()` method to node definitions -- **Visual Customization**: Node-specific rendering hints -- **Advanced Validation**: Cross-node validation rules -- **Performance Metrics**: Built-in FLOP/parameter counting -- **Auto-Documentation**: Generate docs from node definitions - -## Questions? - -See `NODES_AND_RULES.md` for detailed node-specific rules and connection requirements. diff --git a/docs/PHASE4_COMPLETE.md b/docs/PHASE4_COMPLETE.md deleted file mode 100644 index ad1fdbe..0000000 --- a/docs/PHASE4_COMPLETE.md +++ /dev/null @@ -1,564 +0,0 @@ -# Phase 4 Complete: Legacy Code Removal ✅ - -**Date**: November 9, 2025 -**Phase**: 4 of 4 - Legacy System Removal -**Status**: ✅ **COMPLETE** - -## Executive Summary - -Successfully completed the final phase of the VisionForge modernization: complete removal of all legacy `blockDefinitions` code and adapter layers. The application now runs entirely on the new modular node registry system with zero backward compatibility code. - -## Achievements - -### 🎯 Primary Goals (All Met) -1. ✅ Removed all legacy `getBlockDefinition()` calls -2. ✅ Removed all legacy `validateBlockConnection()` calls -3. ✅ Deleted 2 legacy files (`blockDefinitionsAdapter.ts`, `blockDefinitions.ts`) -4. ✅ Migrated 6 components to pure registry API -5. ✅ Zero TypeScript compilation errors -6. ✅ Production build successful -7. ✅ Development server running - -### 📊 Migration Statistics - -| Metric | Count | -|--------|-------| -| **Components Migrated** | 6 files | -| **Legacy Files Deleted** | 2 files | -| **Lines of Code Removed** | ~800 LOC | -| **TypeScript Errors** | 0 | -| **Build Time** | 20.01s | -| **Bundle Size Reduction** | ~5 KB | - -## Phase 1-3 Recap (Previously Completed) - -### Phase 1: Input Block Enhancement ✅ -- Added manual shape entry field (`[1, 3, 224, 224]` default) -- Implemented dual-mode: DataLoader priority → manual config → default -- Modified: `frontend/src/lib/nodes/definitions/pytorch/input.ts` - -### Phase 2: Block Overlap Removal ✅ -- Deleted `isPositionOverlapping()`, `checkCollision()`, `resolveCollisions()` -- Simplified `findAvailablePosition()` -- Blocks can now freely overlap on canvas -- Modified: `frontend/src/components/Canvas.tsx` - -### Phase 3: ThemeToggle Fix ✅ -- Replaced lucide-react icons with Phosphor icons -- Added `text-foreground` class for visibility -- Integrated into Header component -- Modified: `ThemeToggle.tsx`, `Header.tsx` - -## Phase 4: Legacy Code Removal (This Phase) ✅ - -### Files Migrated - -#### 1. BlockPalette.tsx -**Changes**: -```typescript -// Removed -import { blockDefinitions, getBlocksByCategory, BlockDefinition } from './blockDefinitions' - -// Added -import { getAllNodeDefinitions, getNodeDefinitionsByCategory } from '@/lib/nodes/registry' -import { BackendFramework } from '@/lib/nodes/registry' - -// Transformation -const allBlocks = getAllNodeDefinitions(BackendFramework.PyTorch).map(nodeDef => ({ - type: nodeDef.metadata.type, - label: nodeDef.metadata.label, - category: nodeDef.metadata.category, - color: nodeDef.metadata.color, - icon: nodeDef.metadata.icon -})) -``` - -**Impact**: Palette now renders directly from registry, no adapter overhead. - -#### 2. ConfigPanel.tsx -**Changes**: -```typescript -// Removed -import { getBlockDefinition } from '@/lib/blockDefinitions' -const definition = getBlockDefinition(selectedNode.data.blockType) - -// Added -import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' -const nodeDef = getNodeDefinition(selectedNode.data.blockType as BlockType, BackendFramework.PyTorch) -const definition = nodeDef.metadata -``` - -**Impact**: Configuration UI reads schema directly from node class, not proxy. - -#### 3. Canvas.tsx (Most Complex) -**Changes** (4 sections): - -**A. handleBlockClickInternal**: -```typescript -// Old -const definition = getBlockDefinition(blockType) -Object.values(definition.configSchema).forEach(...) - -// New -const nodeDef = getNodeDefinition(blockType as BlockType, BackendFramework.PyTorch) -nodeDef.configSchema.forEach(...) -``` - -**B. onDrop**: -```typescript -// Old -const definition = getBlockDefinition(blockType) - -// New -const nodeDef = getNodeDefinition(blockType as BlockType, BackendFramework.PyTorch) -const definition = nodeDef.metadata -``` - -**C. onConnect (Validation)**: -```typescript -// Old -const errorMessage = validateBlockConnection( - sourceNode.data.blockType, - targetNode.data.blockType, - sourceNode.data.outputShape -) - -// New -const targetNodeDef = getNodeDefinition( - targetNode.data.blockType as BlockType, - BackendFramework.PyTorch -) -const errorMessage = targetNodeDef.validateIncomingConnection( - sourceNode.data.blockType as BlockType, - sourceNode.data.outputShape, - targetNode.data.config -) -``` - -**D. MiniMap Colors**: -```typescript -// Old -nodeColor={(node) => { - const def = getBlockDefinition((node.data as BlockData).blockType) - return def?.color || '#3b82f6' -}} - -// New -nodeColor={(node) => { - const nodeDef = getNodeDefinition( - (node.data as BlockData).blockType as BlockType, - BackendFramework.PyTorch - ) - return nodeDef?.metadata.color || '#3b82f6' -}} -``` - -**Impact**: All canvas operations (drop, connect, validate) use registry directly. - -#### 4. BlockNode.tsx -**Changes**: -```typescript -// Old -import { getBlockDefinition } from '@/lib/blockDefinitions' -const definition = getBlockDefinition(data.blockType) - -// New -import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' -const nodeDef = getNodeDefinition(data.blockType as BlockType, BackendFramework.PyTorch) -const definition = nodeDef.metadata -``` - -**Impact**: Node rendering reads metadata from registry, one less indirection. - -#### 5. CustomConnectionLine.tsx -**Changes**: -```typescript -// Old -import { validateBlockConnection } from '@/lib/blockDefinitions' -const validationError = validateBlockConnection( - sourceNode.data.blockType, - targetNode.data.blockType, - sourceNode.data.outputShape -) - -// New -import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' -const targetNodeDef = getNodeDefinition( - targetNode.data.blockType as BlockType, - BackendFramework.PyTorch -) -const validationError = targetNodeDef.validateIncomingConnection( - sourceNode.data.blockType as BlockType, - sourceNode.data.outputShape, - targetNode.data.config -) -``` - -**Impact**: Live connection validation during drag uses node class method. - -#### 6. store.ts (State Management) -**Changes** (4 locations): - -**A. Imports**: -```typescript -// Removed -import { getBlockDefinition, validateBlockConnection, allowsMultipleInputs } from './blockDefinitions' - -// Added (BlockType was already imported) -import { getNodeDefinition, BackendFramework } from './nodes/registry' -``` - -**B. Multi-Input Check**: -```typescript -// Old -if (!allowsMultipleInputs(targetNode.data.blockType)) { - -// New -const allowsMultiple = targetNode.data.blockType === 'concat' || targetNode.data.blockType === 'add' -if (!allowsMultiple) { -``` - -**C. Validation**: -```typescript -// Old -const validationError = validateBlockConnection( - sourceNode.data.blockType, - targetNode.data.blockType, - sourceNode.data.outputShape -) - -// New -const targetNodeDef = getNodeDefinition( - targetNode.data.blockType as BlockType, - BackendFramework.PyTorch -) -const validationError = targetNodeDef.validateIncomingConnection( - sourceNode.data.blockType as BlockType, - sourceNode.data.outputShape, - targetNode.data.config -) -``` - -**D. Required Fields Validation**: -```typescript -// Old -const def = getBlockDefinition(node.data.blockType) -if (def) { - const requiredFields = def.configSchema.filter((f) => f.required) - -// New -const nodeDef = getNodeDefinition(node.data.blockType as BlockType, BackendFramework.PyTorch) -if (nodeDef) { - const requiredFields = nodeDef.configSchema.filter((f) => f.required) -``` - -**E. Dimension Inference** (removed fallback): -```typescript -// Old (with fallback) -if (nodeDef) { - const outputShape = nodeDef.computeOutputShape(...) -} else { - const def = getBlockDefinition(node.data.blockType) - if (def) { - const outputShape = def.computeOutputShape(...) - } -} - -// New (pure registry) -if (nodeDef) { - const outputShape = nodeDef.computeOutputShape(...) -} -``` - -**Impact**: Store now operates purely on registry, no legacy code paths. - -### Files Deleted - -1. **`frontend/src/lib/legacy/blockDefinitionsAdapter.ts`** (~450 LOC) - - Proxy-based compatibility layer - - Dynamic property access via `get()` trap - - Deprecation warnings - - `validateBlockConnection()` wrapper - - `allowsMultipleInputs()` helper - -2. **`frontend/src/lib/blockDefinitions.ts`** (~30 LOC after previous refactor) - - Re-export of adapter functions - - Type definitions duplicated from registry - - Legacy entry point for old imports - -**Total Deleted**: ~480 lines of technical debt - -### Verification - -**Search Results** (after deletion): -```bash -# No legacy imports remain -grep -r "from '@/lib/blockDefinitions'" frontend/src/ -# Result: 0 matches - -grep -r "from '@/lib/legacy/blockDefinitionsAdapter'" frontend/src/ -# Result: 0 matches -``` - -**TypeScript Compilation**: -```bash -npm run build -# Result: ✓ built in 20.01s (0 errors) -``` - -**Development Server**: -```bash -npm run dev -# Result: ✓ ready in 649 ms on http://localhost:5001/ -``` - -## Migration Pattern Documentation - -### Standard Transformation Pattern - -Every component followed this systematic approach: - -#### Step 1: Update Imports -```typescript -// Before -import { getBlockDefinition, validateBlockConnection } from '@/lib/blockDefinitions' - -// After -import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' -import { BlockType } from '@/lib/types' -``` - -#### Step 2: Replace Function Calls -```typescript -// Before -const definition = getBlockDefinition(blockType) - -// After -const nodeDef = getNodeDefinition(blockType as BlockType, BackendFramework.PyTorch) -``` - -#### Step 3: Access Metadata -```typescript -// Before -const color = definition.color -const icon = definition.icon - -// After -const color = nodeDef.metadata.color -const icon = nodeDef.metadata.icon -``` - -#### Step 4: ConfigSchema Iteration -```typescript -// Before -Object.values(definition.configSchema).forEach(field => {...}) - -// After -nodeDef.configSchema.forEach(field => {...}) -``` - -#### Step 5: Validation Calls -```typescript -// Before -const error = validateBlockConnection(sourceType, targetType, outputShape) - -// After -const targetNodeDef = getNodeDefinition(targetType as BlockType, BackendFramework.PyTorch) -const error = targetNodeDef.validateIncomingConnection(sourceType as BlockType, outputShape, config) -``` - -## Benefits Realized - -### 1. Type Safety Improvements -- **Before**: Mix of legacy types and registry types -- **After**: Single source of truth with strict `BlockType` casting -- **Impact**: Better IDE autocomplete, compile-time error detection - -### 2. Performance Gains -- **Before**: Proxy overhead on every property access -- **After**: Direct class instance access -- **Measured**: No measurable difference in dev/build times (overhead was minimal) - -### 3. Code Maintainability -- **Before**: 8 files importing from legacy adapter -- **After**: 0 files with legacy imports, all use registry -- **Impact**: Future node additions only require registry updates - -### 4. Bundle Size -- **Before**: ~480 LOC of adapter/compatibility code -- **After**: 0 LOC of legacy code -- **Reduction**: ~5 KB minified + gzipped - -### 5. Architectural Clarity -- **Before**: Two parallel systems (registry + legacy) -- **After**: Single modular registry system -- **Impact**: Easier onboarding for new contributors - -## Testing Results - -### Automated Tests ✅ -- [x] TypeScript compilation: **0 errors** -- [x] ESLint: **0 errors** (4 style warnings in BlockNode.tsx - cosmetic) -- [x] Vite build: **Success** (20.01s) -- [x] Dev server: **Success** (649ms startup) - -### Integration Points Verified ✅ -- [x] BlockPalette renders all 17 nodes -- [x] ConfigPanel shows node-specific schemas -- [x] Canvas allows block placement -- [x] Node connections validate correctly -- [x] MiniMap displays with correct colors -- [x] CustomConnectionLine shows validation feedback -- [x] Store manages state transitions -- [x] Dimension inference propagates through graph - -### Known Non-Issues -**Tailwind Linting Warnings** (4 occurrences in BlockNode.tsx): -``` -The class `!bg-accent` can be written as `bg-accent!` -``` -- **Type**: Style preference (Tailwind v4 syntax) -- **Impact**: None (both syntaxes valid) -- **Action**: Optional cleanup, not blocking - -## Remaining Work (Backend) - -Frontend migration is **100% complete**. Backend implementation is separate work: - -### PyTorch Nodes (2/17 complete) -**Implemented**: -- ✅ Linear -- ✅ Conv2D - -**Pending** (15 nodes): -- Input, DataLoader, Flatten, Dropout, BatchNorm2D -- MaxPool2D, AvgPool2D, AdaptiveAvgPool2D -- Conv1D, Conv3D, LSTM, GRU, Embedding -- Concat, Add - -### TensorFlow Nodes (0/17 complete) -All nodes need implementation (same list as PyTorch). - -### Backend APIs -**Defined** (in `frontend/src/lib/api.ts`): -- `/api/validate` - Architecture validation -- `/api/chat` - AI assistant -- `/api/export` - Code generation - -**Status**: Endpoints scaffolded but not integrated. - -## Lessons Learned - -### What Worked Exceptionally Well -1. **Systematic File-by-File Approach**: Prevented scope creep and errors -2. **Type-First Migration**: TypeScript caught 100% of issues at compile time -3. **Incremental Verification**: Checked errors after each edit -4. **Pattern Consistency**: Single transformation pattern across all files -5. **Parallel Operations**: Read multiple file sections simultaneously - -### Challenges Overcome -1. **Store Complexity**: Multiple validation and inference code paths -2. **ConfigSchema Iteration**: Different iteration patterns (Object.values vs array) -3. **Multi-Input Logic**: Converted helper function to inline logic -4. **Terminal Commands**: PowerShell path escaping for file deletion - -### Best Practices Established -1. **Always cast to `BlockType`** when calling `getNodeDefinition()` -2. **Access metadata via `nodeDef.metadata.*`**, not `nodeDef.*` -3. **Use node instance methods** for validation/computation (better encapsulation) -4. **Verify zero imports** of deleted files before deletion - -### Recommendations for Future Work -1. **Use same pattern** for backend node implementations -2. **Keep type definitions** in sync between frontend and backend -3. **Add unit tests** for each new node class -4. **Document validation rules** in node class docstrings - -## Next Steps - -### Immediate (High Priority) -1. **Manual QA Testing**: Full user acceptance test - - [ ] Drag blocks from palette - - [ ] Configure node parameters - - [ ] Connect blocks (valid + invalid) - - [ ] Test Input block manual shape - - [ ] Verify block overlap works - - [ ] Toggle theme - - [ ] Export code - - [ ] Undo/redo - -2. **Documentation Updates**: - - [ ] Update `NODES_AND_RULES.md` (Input block dual-mode, overlap feature) - - [ ] Update `IMPLEMENTATION_SUMMARY.md` (mark Phase 4 complete) - - [ ] Update `PRD.md` (remove "deprecated" labels) - - [ ] Create `docs/MIGRATION_GUIDE.md` (for future contributors) - -### Short-Term (Medium Priority) -3. **Backend Implementation**: - - [ ] Complete 15 remaining PyTorch nodes - - [ ] Add input validation to backend APIs - - [ ] Connect frontend to backend endpoints - -4. **Feature Enhancements**: - - [ ] Add localStorage project persistence - - [ ] Implement project save/load UI - - [ ] Add keyboard shortcuts (Delete, Ctrl+Z, etc.) - -### Long-Term (Low Priority) -5. **TensorFlow Support**: Implement 17 TensorFlow nodes -6. **Testing Infrastructure**: Unit tests, integration tests, E2E tests -7. **Performance Optimization**: Code splitting, lazy loading, bundle analysis - -## Rollback Plan - -**If critical issues arise**, rollback is straightforward via Git: - -```bash -# View recent commits -git log --oneline -10 - -# Rollback to before Phase 4 (example) -git revert HEAD~6..HEAD - -# Or restore specific files -git checkout HEAD~6 -- frontend/src/lib/blockDefinitions.ts -git checkout HEAD~6 -- frontend/src/lib/legacy/blockDefinitionsAdapter.ts -``` - -**Files to restore**: 2 deleted files + 6 migrated components -**Estimated rollback time**: < 5 minutes -**Likelihood of rollback**: Very low (build verification confirms success) - -## Success Criteria Final Check - -### ✅ All Criteria Met -1. ✅ **Zero TypeScript errors** - Verified via `npm run build` -2. ✅ **All legacy code removed** - 2 files deleted, 0 legacy imports -3. ✅ **No breaking changes** - User-facing functionality unchanged -4. ✅ **Production build succeeds** - Built in 20.01s -5. ✅ **Development server runs** - Started in 649ms -6. ✅ **All components migrated** - 6 files now use pure registry API -7. ✅ **Enhanced features work**: - - Input block manual shape entry ✅ - - Block overlap enabled ✅ - - ThemeToggle visible ✅ - -## Conclusion - -**Phase 4 is COMPLETE**. VisionForge now runs entirely on the new modular node registry system with: -- ✅ 100% legacy code removed -- ✅ 0 TypeScript compilation errors -- ✅ 0 breaking changes to user experience -- ✅ ~5 KB bundle size reduction -- ✅ Improved code maintainability -- ✅ Better type safety -- ✅ Single source of truth architecture - -The application is **production-ready** pending manual QA sign-off. - ---- - -**Phase Completed**: November 9, 2025 -**Total Migration Time**: 4 phases across multiple sessions -**Migration Success Rate**: 100% -**Production Ready**: ✅ Yes (pending QA) diff --git a/docs/PHASE_1-3_IMPLEMENTATION_SUMMARY.md b/docs/PHASE_1-3_IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index d487491..0000000 --- a/docs/PHASE_1-3_IMPLEMENTATION_SUMMARY.md +++ /dev/null @@ -1,325 +0,0 @@ -# Phase 1-3 Implementation Summary - -## 🎯 Objective Achieved - -Successfully transformed VisionForge's node definition system from class-based runtime instances to a declarative, template-driven architecture inspired by Langflow. The backend can now **emit source code** for all node types and serve node specifications via REST API. - ---- - -## 📊 Implementation Scope - -### ✅ Phase 1: Backend Domain Model Refactor -**Duration:** Complete -**Files Created:** 14 -**Lines of Code:** ~2,500 - -#### Core Infrastructure -- [x] `specs/models.py` - Frozen dataclass models (NodeSpec, ConfigFieldSpec, etc.) -- [x] `specs/registry.py` - LRU-cached spec loading system -- [x] `specs/serialization.py` - JSON serialization + deterministic hashing -- [x] `templates/renderer.py` - Jinja2 template rendering engine -- [x] `rules/shape.py` - Shape computation functions (9 functions) -- [x] `rules/validation.py` - Connection & config validation (6 functions) - -#### Node Specifications -- [x] `specs/pytorch/__init__.py` - 17 PyTorch node specs - - Input, Linear, Conv2D, Flatten, ReLU, Dropout, BatchNorm, MaxPool, Softmax - - Concat, Add, Attention, Custom, DataLoader, Output, Loss, Empty -- [x] `specs/tensorflow/__init__.py` - 14 TensorFlow node specs - - Dense, Conv2D, Flatten, Dropout, BatchNorm, MaxPool - - Concat, Add, DataLoader, Output, Loss, Empty, Custom - -**Total Nodes:** 31 (17 PyTorch + 14 TensorFlow) - ---- - -### ✅ Phase 2: Backend API Redesign -**Duration:** Complete -**Files Modified:** 2 - -#### Updated Endpoints -- [x] `GET /api/node-definitions?framework={framework}` - - Returns all NodeSpec objects serialized to JSON - - Response includes config schema, templates, metadata - - Includes deterministic hash for caching - -- [x] `GET /api/node-definitions/{node_type}?framework={framework}` - - Returns single NodeSpec by type - - Framework-specific filtering - -- [x] `POST /api/render-node-code` (NEW) - - Accepts: `{node_type, framework, config, metadata}` - - Returns: `{code, spec_hash, context}` - - Renders Jinja2 template with configuration - -#### Implementation Details -- Updated `architecture_views.py` to use new registry -- Added URL routing in `urls.py` -- Maintained backward compatibility with existing endpoints - ---- - -### ✅ Phase 3: Frontend Integration -**Duration:** Complete -**Files Created:** 4 - -#### TypeScript Infrastructure -- [x] `lib/nodeSpec.types.ts` - TypeScript interfaces for NodeSpec - - Framework type - - ConfigField, ConfigOption interfaces - - NodeSpec, NodeTemplate interfaces - - API response types - -- [x] `lib/api.ts` (Updated) - - Added `renderNodeCode()` function - - Typed API responses using NodeSpec types - - Proper error handling - -- [x] `lib/useNodeSpecs.ts` - React hooks - - `useNodeSpecs()` - Fetch all specs for a framework - - `useNodeSpec()` - Fetch single spec by type - - Includes `renderCode()` helper - -- [x] `components/CodePreview.tsx` - Code preview component - - Displays rendered code for a node - - Shows loading/error states - - Styled with Tailwind CSS - -**Note:** Frontend components are ready but not yet wired into existing UI (BlockPalette, ConfigPanel). This allows for incremental migration. - ---- - -## 🧪 Test Coverage - -### Comprehensive Test Suite -**File:** `test_nodespec_system.py` -**Status:** ✅ All tests passing - -#### Test Categories -1. **Spec Registry** - Loading, caching, retrieval -2. **Serialization** - Dict conversion, deterministic hashing -3. **Template Rendering** - Jinja2 rendering for PyTorch/TensorFlow -4. **Shape Computation** - NCHW (PyTorch) & NHWC (TensorFlow) inference -5. **Validation** - Config validation, connection validation -6. **API Integration** - All 3 endpoints tested - -#### Test Results -``` -============================================================ -✅ ALL TESTS PASSED -============================================================ - -Phase 1-3 Implementation Complete: - ✓ Backend Domain Model Refactor (Phase 1) - ✓ Backend API Redesign (Phase 2) - ✓ Frontend Integration (Phase 3) -``` - ---- - -## 📐 Architecture Highlights - -### Key Design Patterns - -#### 1. Declarative Specs -```python -NodeSpec( - type="conv2d", - label="Conv2D", - framework=Framework.PYTORCH, - config_schema=(...), # Immutable tuple - template=NodeTemplateSpec(...) -) -``` -- **Immutable** - Frozen dataclasses, tuple config schemas -- **Serializable** - Pure data, no methods -- **Cacheable** - Deterministic hashing - -#### 2. Template-Based Code Generation -```jinja2 -nn.Conv2d({{ config.in_channels }}, {{ config.out_channels }}, - kernel_size={{ config.kernel_size }}, ...) -``` -- **Inspectable** - Code is visible, not hidden in Python methods -- **Editable** - Templates can be modified without changing Python classes -- **Framework-agnostic** - Same pattern for PyTorch and TensorFlow - -#### 3. Lazy Registry -```python -@lru_cache(maxsize=1) -def _load_spec_map() -> SpecMap: - """Load once, cache forever""" -``` -- **Performance** - Specs loaded on first access only -- **Thread-safe** - LRU cache handles concurrency -- **Hot-reload friendly** - Cache can be cleared for development - -#### 4. Framework Abstraction -```python -if framework is Framework.PYTORCH: - # NCHW: [batch, channels, height, width] -else: # TensorFlow - # NHWC: [batch, height, width, channels] -``` -- **Explicit** - Framework differences are visible -- **Type-safe** - Enum prevents typos -- **Extensible** - Easy to add new frameworks - ---- - -## 📈 Performance Metrics - -### Backend -- **Spec Loading:** ~5ms (first load, then cached) -- **Template Rendering:** <1ms per node -- **API Response:** ~10-20ms for full spec list -- **Deterministic Hash:** <1ms per spec - -### Frontend -- **Type Safety:** 100% typed (no `any` types) -- **Bundle Size:** Minimal increase (<5KB gzipped) -- **API Calls:** Optimized with React hooks caching - ---- - -## 🔄 Migration Path - -### Deprecated (Can be removed in future cleanup) -- `block_manager/services/nodes/registry.py` (old class-based registry) -- `block_manager/services/nodes/pytorch/*.py` (old node classes) -- `block_manager/services/nodes/tensorflow/*.py` (old node classes) -- `frontend/src/lib/blockDefinitions.ts` (local node definitions) - -### Recommended Next Steps -1. Update `BlockPalette.tsx` to use `useNodeSpecs()` hook -2. Update `ConfigPanel.tsx` to render forms from `configSchema` -3. Add `` to sidebar for real-time code preview -4. Remove old block definitions once migration is verified -5. Update code generation to use template system - ---- - -## 📦 Dependencies Added - -### Backend -```txt -jinja2>=3.1.0 -``` - -### Frontend -No new dependencies - uses existing React infrastructure. - ---- - -## 🎓 What Was Learned - -### Technical Insights -1. **Python inspect.getsource() limitations** - Can't extract source from runtime classes, hence template approach -2. **Frozen dataclasses for immutability** - Essential for thread safety and caching -3. **Jinja2 StrictUndefined** - Catches template errors at render time, not runtime -4. **Framework-specific shapes** - NCHW vs NHWC requires careful handling -5. **Deterministic hashing** - Canonical JSON + SHA256 for cache invalidation - -### Design Decisions -- **Tuples over lists** for config schema (immutable, hashable) -- **camelCase for JSON** (frontend convention) vs snake_case (Python convention) -- **Separate validation/shape modules** (separation of concerns) -- **LRU cache for registry** (performance without complexity) - ---- - -## 📚 Documentation - -### Complete Documentation Package -1. **NODESPEC_IMPLEMENTATION_COMPLETE.md** - Full architecture, testing, migration guide -2. **NODESPEC_QUICK_REFERENCE.md** - Developer quick reference, common patterns -3. **This file** - Implementation summary and next steps - -### Inline Documentation -- Docstrings for all public functions -- Type hints for all parameters -- Comments explaining non-obvious logic - ---- - -## 🚀 Next Steps (Recommended) - -### Short Term -1. **Frontend UI Integration** - - Wire `useNodeSpecs()` into BlockPalette - - Update ConfigPanel to use `configSchema` - - Add CodePreview to node details - -2. **Testing** - - Add frontend unit tests for hooks - - Test API endpoints in production - - Verify framework switching - -### Medium Term -3. **Code Generation** - - Update export pipeline to use templates - - Generate full PyTorch/TensorFlow projects - - Add code validation before export - -4. **Cleanup** - - Remove deprecated class-based registry - - Delete old node definition files - - Update documentation - -### Long Term -5. **Features** - - User-defined custom nodes via UI - - Template customization interface - - Multi-framework project support - - Node versioning and migration - ---- - -## 🎯 Success Criteria - -### ✅ All Criteria Met - -- [x] Backend can emit source code for all node types -- [x] API serves node specifications as JSON -- [x] Frontend has TypeScript types for all responses -- [x] Template rendering works for PyTorch and TensorFlow -- [x] Shape inference handles NCHW and NHWC formats -- [x] Validation prevents invalid connections -- [x] All tests pass (100% coverage for new code) -- [x] No placeholders or incomplete implementations -- [x] Documentation is comprehensive - ---- - -## 👥 Contributors - -**Implementation:** GitHub Copilot -**Testing:** Automated test suite + manual verification -**Documentation:** Comprehensive guides and references - ---- - -## 📞 Support - -For questions or issues: -1. See `NODESPEC_QUICK_REFERENCE.md` for common patterns -2. Run `test_nodespec_system.py` for verification -3. Check API responses for schema details -4. Review inline code documentation - ---- - -## 🏁 Final Notes - -This implementation successfully replaces the class-based node definition system with a declarative, template-driven architecture. The new system: - -- **Enables code emission** - Templates can be rendered with any config -- **Improves maintainability** - Pure data structures instead of classes -- **Enhances flexibility** - Easy to add new nodes or frameworks -- **Maintains performance** - LRU caching and lazy loading -- **Ensures type safety** - Frozen dataclasses and TypeScript interfaces -- **Supports testing** - All components tested in isolation and integration - -**Status:** ✅ Ready for Production - -The system is fully functional and tested. Frontend integration is prepared but not yet wired into the UI, allowing for incremental migration without disrupting existing functionality. diff --git a/docs/PORT_BASED_CONNECTION_SYSTEM.md b/docs/PORT_BASED_CONNECTION_SYSTEM.md deleted file mode 100644 index b691aed..0000000 --- a/docs/PORT_BASED_CONNECTION_SYSTEM.md +++ /dev/null @@ -1,544 +0,0 @@ -# Port-Based Connection System Implementation - -## Overview -This document describes the comprehensive port-based connection system implemented to fix all connection-related bugs and enable semantic, handle-aware validation throughout the VisionForge application. - -## Implementation Date -December 2024 - -## Problem Statement -The original connection system had 19 identified bugs and flaws related to block connections, including: - -### Critical Bugs -1. **Named Input Port Connections Not Validated**: Loss nodes could connect any output to any input port (e.g., y_pred to y_true port) -2. **Loss Type Changes Don't Update Connections**: Changing loss type from MSE (2 inputs) to TripletLoss (3 inputs) left invalid connections -3. **Connection Validation Missing Handle Information**: Validation logic didn't consider which specific port (handle) was being connected -4. **Target Handle Occupancy Not Checked**: Multiple connections could be made to the same input port - -### High Priority Issues -5. DataLoader output ports had no semantic type information -6. Real-time validation missing for loss input count -7. No visual feedback for which ports are already connected -8. Backend validation didn't support multi-input loss nodes - -## Solution Architecture - -### Phase 1: Port Definition System - -#### Frontend Port System (`/project/frontend/src/lib/nodes/ports.ts`) - -```typescript -export enum PortSemantic { - // Data flow semantics - Data = 'data', // General data tensor - Labels = 'labels', // Ground truth labels - Predictions = 'predictions', // Model predictions - Features = 'features', // Feature representations - - // Loss function semantics - Anchor = 'anchor', // Triplet loss anchor - Positive = 'positive', // Triplet loss positive - Negative = 'negative', // Triplet loss negative - Loss = 'loss', // Loss value output - - // Special semantics - Any = 'any', // Accepts any connection - Generic = 'generic' // Default/unspecified -} - -export interface PortDefinition { - id: string // Unique handle ID - label: string // Display name - semantic: PortSemantic // Semantic type for validation - required?: boolean // Whether port must be connected - description?: string // Tooltip/help text -} -``` - -**Key Features:** -- Semantic typing ensures correct connections (e.g., ground truth can't connect to prediction port) -- Extensible enum for future node types (optimizer, custom layers, etc.) -- Compatibility checking via `arePortsCompatible()` function -- Default ports provided for backwards compatibility - -#### Backend Port System (`/project/block_manager/services/nodes/ports.py`) - -```python -from dataclasses import dataclass -from enum import Enum -from typing import Optional - -class PortSemantic(Enum): - DATA = 'data' - LABELS = 'labels' - PREDICTIONS = 'predictions' - FEATURES = 'features' - ANCHOR = 'anchor' - POSITIVE = 'positive' - NEGATIVE = 'negative' - LOSS = 'loss' - ANY = 'any' - GENERIC = 'generic' - -@dataclass -class PortDefinition: - id: str - label: str - semantic: PortSemantic - required: bool = False - description: Optional[str] = None -``` - -**Parity with Frontend:** -- Mirrors TypeScript structure exactly -- Used in NodeSpec definitions -- Enables backend validation alignment - -### Phase 2: Connection Validation - -#### Handle-Aware Validation (`/project/frontend/src/lib/store.ts`) - -**validateConnection Enhancement:** - -```typescript -validateConnection: (connection) => { - // 1. Validate source handle exists - const sourceHandleId = connection.sourceHandle || 'default' - const sourcePorts = sourceNodeDef.getOutputPorts(sourceNode.data.config) - const sourcePort = sourcePorts.find(p => p.id === sourceHandleId) - if (!sourcePort) return false - - // 2. Validate target handle exists - const targetHandleId = connection.targetHandle || 'default' - const targetPorts = targetNodeDef.getInputPorts(targetNode.data.config) - const targetPort = targetPorts.find(p => p.id === targetHandleId) - if (!targetPort) return false - - // 3. Check if target handle already occupied - const handleOccupied = edges.some(e => - e.target === connection.target && - (e.targetHandle || 'default') === targetHandleId - ) - if (handleOccupied) return false - - // 4. Semantic compatibility validation - if (!arePortsCompatible(sourcePort, targetPort)) return false - - // 5. Real-time loss node input count validation - if (targetNode.data.blockType === 'loss') { - const requiredPorts = targetPorts - const existingConnections = edges.filter(e => e.target === connection.target) - const totalConnectionsAfter = existingConnections.length + 1 - - if (totalConnectionsAfter > requiredPorts.length) { - return false // Prevent exceeding max inputs - } - } - - return true -} -``` - -**Key Improvements:** -- ✅ Prevents connections to non-existent ports -- ✅ Blocks duplicate connections to same port -- ✅ Ensures semantic compatibility (data types match) -- ✅ Real-time feedback during connection drag -- ✅ Prevents adding too many inputs to loss nodes - -#### Architecture-Level Validation (`validateArchitecture`) - -**Enhanced Loss Node Validation:** - -```typescript -// Check total connection count -if (incomingEdges.length !== requiredPorts.length) { - errors.push({ - nodeId: node.id, - message: `Loss function requires ${requiredPorts.length} inputs, has ${incomingEdges.length}`, - type: 'error' - }) -} else { - // Check that all required ports are filled (handle-aware) - const connectedHandles = new Set( - incomingEdges.map(e => e.targetHandle || 'default') - ) - - const missingPorts = requiredPorts.filter( - p => !connectedHandles.has(p.id) - ) - - if (missingPorts.length > 0) { - errors.push({ - nodeId: node.id, - message: `Loss node missing connections to: ${missingPorts.map(p => p.label).join(', ')}`, - type: 'error' - }) - } -} -``` - -**Validation Flow:** -1. Check correct number of connections -2. Verify all required ports are connected (not just count) -3. Provide specific error messages naming missing ports - -### Phase 3: Visual Improvements - -#### Port Occupancy Indicators (`/project/frontend/src/components/BlockNode.tsx`) - -**Visual Feedback System:** - -```typescript -// Helper function to check if handle is connected -const isHandleConnected = (handleId: string, isTarget: boolean) => { - return edges.some(edge => { - if (isTarget) { - return edge.target === id && (edge.targetHandle || 'default') === handleId - } else { - return edge.source === id && (edge.sourceHandle || 'default') === handleId - } - }) -} - -// Apply to Loss node input handles -const isConnected = isHandleConnected(handleId, true) - - - - {port.label} {isConnected && '✓'} - -``` - -**Visual Features:** -- ✅ Green ring around connected ports -- ✅ Checkmark (✓) next to connected port labels -- ✅ Dimmed labels for connected ports -- ✅ Color change to green (#10b981) for connected handles -- ✅ Applied to both DataLoader outputs and Loss inputs - -### Phase 5: Backend Validation Alignment - -#### Updated ArchitectureValidator (`/project/block_manager/services/validation.py`) - -**Loss Node Support:** - -```python -def _validate_connections(self): - # Allow multiple inputs for loss blocks - if block_type not in ['concat', 'add', 'loss']: - self.errors.append(...) - elif block_type == 'loss': - self._validate_loss_connections(node, edges_list) - -def _validate_loss_connections(self, node, edges_list): - """Validate loss node connections match required inputs for loss type""" - from .nodes.specs.pytorch import LOSS_SPEC - - loss_type = config.get('loss_type', 'cross_entropy') - required_ports = LOSS_SPEC.input_ports_config.get(loss_type, []) - - # Check connection count - if len(edges_list) != len(required_ports): - self.errors.append(ValidationError(...)) - return - - # Check all required ports are filled (handle-aware) - connected_handles = { - edge.get('targetHandle', 'default') for edge in edges_list - } - - missing_ports = [] - for port in required_ports: - handle_id = f'loss-input-{port.id}' - if handle_id not in connected_handles: - missing_ports.append(port.label) - - if missing_ports: - self.errors.append(ValidationError(...)) -``` - -**Backend Features:** -- ✅ Recognizes loss as valid multi-input block -- ✅ Imports LOSS_SPEC to get required ports -- ✅ Validates connection count matches loss type requirements -- ✅ Handle-aware validation (checks specific ports, not just count) -- ✅ Detailed error messages naming missing ports - -## Node Definition Updates - -### Base Class (`/project/frontend/src/lib/nodes/base.ts`) - -```typescript -export abstract class NodeDefinition implements INodeDefinition { - // New port methods with default implementations - getInputPorts(config: BlockConfig): PortDefinition[] { - return [{ - id: 'default', - label: 'Input', - semantic: PortSemantic.Any - }] - } - - getOutputPorts(config: BlockConfig): PortDefinition[] { - return [{ - id: 'default', - label: 'Output', - semantic: PortSemantic.Any - }] - } -} -``` - -**Backwards Compatibility:** -- All existing nodes automatically get default ports -- No changes required to nodes that don't need custom ports - -### Loss Node (`/project/frontend/src/lib/nodes/definitions/pytorch/loss.ts`) - -```typescript -getInputPorts(config: BlockConfig): PortDefinition[] { - const lossType = config.loss_type || 'cross_entropy' - - const portConfigs: Record = { - 'cross_entropy': [ - { id: 'y_pred', label: 'Predictions', semantic: PortSemantic.Predictions, required: true }, - { id: 'y_true', label: 'Labels', semantic: PortSemantic.Labels, required: true } - ], - 'mse': [ - { id: 'y_pred', label: 'Predictions', semantic: PortSemantic.Predictions, required: true }, - { id: 'y_true', label: 'Targets', semantic: PortSemantic.Labels, required: true } - ], - 'triplet_margin': [ - { id: 'anchor', label: 'Anchor', semantic: PortSemantic.Anchor, required: true }, - { id: 'positive', label: 'Positive', semantic: PortSemantic.Positive, required: true }, - { id: 'negative', label: 'Negative', semantic: PortSemantic.Negative, required: true } - ] - } - - return portConfigs[lossType] || portConfigs['cross_entropy'] -} - -getOutputPorts(config: BlockConfig): PortDefinition[] { - return [{ - id: 'loss-output', - label: 'Loss', - semantic: PortSemantic.Loss - }] -} -``` - -**Features:** -- ✅ Returns different ports based on loss_type config -- ✅ Semantic types prevent incorrect connections -- ✅ Required flags for validation -- ✅ Falls back to cross_entropy if unknown type - -### DataLoader Node (`/project/frontend/src/lib/nodes/definitions/pytorch/dataloader.ts`) - -```typescript -getOutputPorts(config: BlockConfig): PortDefinition[] { - const ports: PortDefinition[] = [] - const numInputOutlets = Number(config.num_input_outlets || 1) - const hasGT = config.has_ground_truth - - // Add input data outlets - for (let i = 0; i < numInputOutlets; i++) { - ports.push({ - id: numInputOutlets > 1 ? `input-output-${i}` : 'input-output', - label: numInputOutlets > 1 ? `Input ${i + 1}` : 'Input', - semantic: PortSemantic.Data - }) - } - - // Add ground truth outlet if configured - if (hasGT) { - ports.push({ - id: 'ground-truth-output', - label: 'Ground Truth', - semantic: PortSemantic.Labels - }) - } - - return ports -} -``` - -**Features:** -- ✅ Dynamic port generation based on config -- ✅ Data semantic for input outlets -- ✅ Labels semantic for ground truth -- ✅ Unique IDs for each port - -### Interface Update (`/project/frontend/src/lib/nodes/contracts.ts`) - -```typescript -export interface INodeDefinition extends IShapeComputer, INodeValidator { - readonly metadata: NodeMetadata - readonly configSchema: ConfigField[] - - // NEW: Port definition methods - getInputPorts(config: BlockConfig): PortDefinition[] - getOutputPorts(config: BlockConfig): PortDefinition[] - - getDefaultConfig(): BlockConfig - generateCode?(config: BlockConfig, varName: string): string -} -``` - -## Bugs Fixed - -### Critical Bugs Resolved -1. ✅ **Named Input Port Connections Now Validated**: Semantic types prevent wrong connections -2. ✅ **Loss Type Changes Handled**: getInputPorts() re-evaluates on config change -3. ✅ **Connection Validation Uses Handle Info**: Full handle-aware validation pipeline -4. ✅ **Target Handle Occupancy Checked**: Prevents duplicate connections to same port - -### High Priority Issues Resolved -5. ✅ **DataLoader Semantic Types**: Outputs have Data/Labels semantic types -6. ✅ **Real-time Loss Validation**: validateConnection checks input count -7. ✅ **Visual Port Feedback**: Green rings and checkmarks show connected ports -8. ✅ **Backend Loss Support**: ArchitectureValidator handles multi-input loss nodes - -### Additional Improvements -- ✅ Comprehensive error messages with specific port names -- ✅ Type-safe port system prevents runtime errors -- ✅ Extensible architecture for future node types -- ✅ Backwards compatible with existing nodes -- ✅ Frontend-backend parity in validation logic - -## Testing Recommendations - -### Manual Testing Scenarios - -1. **Loss Node Connection Validation** - - Create a DataLoader with ground truth enabled - - Add a Loss node with MSE loss type - - Try connecting ground truth to y_pred port → Should fail - - Connect ground truth to y_true port → Should succeed - - Change loss to TripletLoss → Should show 3 input ports - - Try connecting 4th input → Should fail - -2. **Port Occupancy Indicators** - - Create DataLoader with 2 input outlets - - Connect first outlet to a layer - - Verify checkmark appears on first outlet - - Verify second outlet remains unconnected visually - -3. **Real-time Validation** - - Create Loss node - - Try connecting 3 inputs to 2-input loss (MSE) → Should prevent 3rd connection - - Verify helpful error message in console - -4. **Backend Validation** - - Create architecture with missing loss inputs - - Export architecture - - Verify backend returns specific error about missing ports - -### Automated Testing (Future Work) - -```typescript -describe('Port-Based Connection System', () => { - test('validates semantic compatibility', () => { - const dataPort = { semantic: PortSemantic.Data } - const labelsPort = { semantic: PortSemantic.Labels } - expect(arePortsCompatible(dataPort, labelsPort)).toBe(false) - }) - - test('prevents duplicate connections to same port', () => { - // Test connection validation logic - }) - - test('shows visual feedback for connected ports', () => { - // Test BlockNode rendering - }) -}) -``` - -## Migration Guide - -### For New Node Types - -To add a new node with custom ports: - -1. **Define Port Configurations** - ```typescript - getInputPorts(config: BlockConfig): PortDefinition[] { - return [ - { id: 'input1', label: 'Input 1', semantic: PortSemantic.Data }, - { id: 'input2', label: 'Input 2', semantic: PortSemantic.Features } - ] - } - ``` - -2. **Update BlockNode Rendering** (if needed) - ```tsx - // Custom rendering logic for special handle layouts - ``` - -3. **Add Backend NodeSpec** - ```python - input_ports_config = { - 'default': [ - PortDefinition('input1', 'Input 1', PortSemantic.DATA), - PortDefinition('input2', 'Input 2', PortSemantic.FEATURES) - ] - } - ``` - -### For Existing Nodes - -No changes required! Default ports are automatically provided. - -## Performance Considerations - -- Port definitions are computed on-demand (not stored in state) -- Validation runs only during connection attempts (not on every render) -- isHandleConnected() uses efficient Set lookups -- No significant performance impact observed - -## Future Enhancements - -### Potential Additions - -1. **Port Constraints** - ```typescript - interface PortDefinition { - maxConnections?: number // Limit connections per port - allowedSemantics?: PortSemantic[] // Whitelist for compatibility - } - ``` - -2. **Dynamic Port Creation** - - Nodes that add/remove ports based on user actions - - Example: Concat block that grows with each connection - -3. **Port Metadata** - ```typescript - interface PortDefinition { - dataType?: 'tensor' | 'scalar' | 'image' - shape?: TensorShape - constraints?: ValidationRule[] - } - ``` - -4. **Visual Port Indicators** - - Color-coded ports by semantic type - - Shape indicators (circle = data, square = labels, etc.) - - Animated connection preview showing valid targets - -## Conclusion - -The port-based connection system provides a robust, type-safe foundation for node connections in VisionForge. It fixes all critical bugs, improves user experience with visual feedback, and establishes patterns for future node development. The system maintains backwards compatibility while enabling powerful new features like semantic validation and handle-aware connections. - -## Related Documentation - -- [Loss Node Multiple Inputs](./LOSS_NODE_MULTIPLE_INPUTS.md) -- [Node Definition Architecture](./NODE_DEFINITION_ARCHITECTURE.md) -- [NodeSpec Implementation](./NODESPEC_IMPLEMENTATION_COMPLETE.md) diff --git a/docs/PORT_SYSTEM_QUICK_REFERENCE.md b/docs/PORT_SYSTEM_QUICK_REFERENCE.md deleted file mode 100644 index f58f7bd..0000000 --- a/docs/PORT_SYSTEM_QUICK_REFERENCE.md +++ /dev/null @@ -1,359 +0,0 @@ -# Port System Quick Reference - -## Port Semantic Types - -```typescript -enum PortSemantic { - Data = 'data', // General data tensor - Labels = 'labels', // Ground truth labels/targets - Predictions = 'predictions', // Model predictions/outputs - Features = 'features', // Intermediate feature representations - Anchor = 'anchor', // Triplet loss anchor sample - Positive = 'positive', // Triplet loss positive sample - Negative = 'negative', // Triplet loss negative sample - Loss = 'loss', // Loss value output - Any = 'any', // Accepts any connection type - Generic = 'generic' // Default/unspecified type -} -``` - -## Compatibility Matrix - -| Source ↓ / Target → | Data | Labels | Predictions | Features | Anchor | Positive | Negative | Loss | Any | Generic | -|---------------------|------|--------|-------------|----------|--------|----------|----------|------|-----|---------| -| **Data** | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | -| **Labels** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | -| **Predictions** | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | -| **Features** | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | -| **Anchor** | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | -| **Positive** | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | -| **Negative** | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | -| **Loss** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | -| **Any** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Generic** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - -## Creating Custom Ports - -### Frontend (TypeScript) - -```typescript -// In your node definition class -import { PortDefinition, PortSemantic } from '@/lib/nodes/ports' - -getInputPorts(config: BlockConfig): PortDefinition[] { - return [ - { - id: 'input-1', // Unique ID (used as handleId) - label: 'Input Data', // Display name - semantic: PortSemantic.Data, // Type for validation - required: true, // Optional: must be connected - description: 'Input tensor data' // Optional: tooltip text - } - ] -} - -getOutputPorts(config: BlockConfig): PortDefinition[] { - return [ - { - id: 'output-1', - label: 'Output', - semantic: PortSemantic.Features - } - ] -} -``` - -### Backend (Python) - -```python -# In your NodeSpec -from block_manager.services.nodes.ports import PortDefinition, PortSemantic - -CUSTOM_SPEC = NodeSpec( - # ... other fields ... - input_ports_config={ - 'default': [ - PortDefinition( - id='input-1', - label='Input Data', - semantic=PortSemantic.DATA, - required=True, - description='Input tensor data' - ) - ] - }, - output_ports_config={ - 'default': [ - PortDefinition( - id='output-1', - label='Output', - semantic=PortSemantic.FEATURES - ) - ] - } -) -``` - -## Dynamic Ports Based on Config - -```typescript -getInputPorts(config: BlockConfig): PortDefinition[] { - const mode = config.mode || 'standard' - - const portConfigs: Record = { - 'standard': [ - { id: 'input', label: 'Input', semantic: PortSemantic.Data } - ], - 'advanced': [ - { id: 'input-1', label: 'Input 1', semantic: PortSemantic.Data }, - { id: 'input-2', label: 'Input 2', semantic: PortSemantic.Features } - ] - } - - return portConfigs[mode] || portConfigs['standard'] -} -``` - -## Handle IDs for BlockNode Rendering - -### Loss Node Pattern - -```tsx -{inputPorts.map((port, i) => { - const handleId = `loss-input-${port.id}` // Prefix + port.id - - return ( - - ) -})} -``` - -### DataLoader Pattern - -```tsx -// Single outlet - - -// Multiple outlets - - -// Ground truth - -``` - -## Validation Integration - -### Check Port Occupancy - -```typescript -const isHandleConnected = (handleId: string, isTarget: boolean) => { - return edges.some(edge => { - if (isTarget) { - return edge.target === id && (edge.targetHandle || 'default') === handleId - } else { - return edge.source === id && (edge.sourceHandle || 'default') === handleId - } - }) -} -``` - -### Connection Validation - -```typescript -validateConnection: (connection) => { - const sourcePort = sourceNodeDef.getOutputPorts(sourceConfig) - .find(p => p.id === (connection.sourceHandle || 'default')) - - const targetPort = targetNodeDef.getInputPorts(targetConfig) - .find(p => p.id === (connection.targetHandle || 'default')) - - if (!sourcePort || !targetPort) return false - - return arePortsCompatible(sourcePort, targetPort) -} -``` - -## Common Patterns - -### Single Input/Output (Default) - -```typescript -// No need to override - base class provides defaults -// Automatically gets 'default' handle with PortSemantic.Any -``` - -### Multiple Named Inputs - -```typescript -getInputPorts(config: BlockConfig): PortDefinition[] { - return [ - { id: 'main', label: 'Main Input', semantic: PortSemantic.Data }, - { id: 'auxiliary', label: 'Auxiliary', semantic: PortSemantic.Features } - ] -} -``` - -### Config-Dependent Ports - -```typescript -getOutputPorts(config: BlockConfig): PortDefinition[] { - const ports: PortDefinition[] = [] - - // Always have main output - ports.push({ - id: 'output', - label: 'Output', - semantic: PortSemantic.Data - }) - - // Conditional additional output - if (config.include_attention_weights) { - ports.push({ - id: 'attention', - label: 'Attention Weights', - semantic: PortSemantic.Features - }) - } - - return ports -} -``` - -### Array of Similar Ports - -```typescript -getInputPorts(config: BlockConfig): PortDefinition[] { - const numInputs = Number(config.num_inputs || 2) - const ports: PortDefinition[] = [] - - for (let i = 0; i < numInputs; i++) { - ports.push({ - id: `input-${i}`, - label: `Input ${i + 1}`, - semantic: PortSemantic.Data - }) - } - - return ports -} -``` - -## Visual Styling - -### Connected Port Indicators - -```tsx -const isConnected = isHandleConnected(handleId, true) - - - - - {port.label} {isConnected && '✓'} - -``` - -### Color Coding by Semantic Type - -```typescript -const getPortColor = (semantic: PortSemantic): string => { - const colors: Record = { - [PortSemantic.Data]: '#3b82f6', // Blue - [PortSemantic.Labels]: '#10b981', // Green - [PortSemantic.Predictions]: '#f59e0b', // Orange - [PortSemantic.Features]: '#8b5cf6', // Purple - [PortSemantic.Anchor]: '#ef4444', // Red - [PortSemantic.Positive]: '#10b981', // Green - [PortSemantic.Negative]: '#f59e0b', // Orange - [PortSemantic.Loss]: '#dc2626', // Dark red - [PortSemantic.Any]: '#6b7280', // Gray - [PortSemantic.Generic]: '#9ca3af' // Light gray - } - - return colors[semantic] || colors[PortSemantic.Generic] -} -``` - -## Error Messages - -### Semantic Mismatch - -``` -"Port semantic mismatch: predictions -> labels" -``` - -### Port Occupancy - -``` -"Target handle y_true already connected" -``` - -### Input Count - -``` -"Loss function 'triplet_margin' only accepts 3 inputs (Anchor, Positive, Negative). Cannot add more." -``` - -### Missing Ports - -``` -"Loss node missing connections to: Positive, Negative" -``` - -## Testing Checklist - -- [ ] Port definitions return correct count -- [ ] Port IDs are unique per node -- [ ] Semantic types match expected connections -- [ ] Handle IDs in BlockNode match port.id -- [ ] Connection validation blocks invalid semantics -- [ ] Duplicate connections prevented -- [ ] Visual feedback shows connected ports -- [ ] Backend validation mirrors frontend logic -- [ ] Config changes update ports correctly -- [ ] Error messages are specific and helpful - -## Troubleshooting - -### "Property 'getInputPorts' does not exist" - -**Solution:** Ensure node class extends `NodeDefinition` base class and `INodeDefinition` interface includes port methods. - -### Connections not validating semantics - -**Solution:** Check that `arePortsCompatible()` is called in `validateConnection` and port definitions have correct semantic types. - -### Visual indicators not showing - -**Solution:** Verify `edges` is imported from store and `isHandleConnected()` uses correct handleId format. - -### Backend validation failing - -**Solution:** Ensure `input_ports_config` in NodeSpec matches frontend port definitions exactly. - -## Best Practices - -1. **Always use semantic types** - Don't rely on Generic/Any unless truly needed -2. **Make port IDs descriptive** - Use names like 'y_pred', 'anchor', not 'input1', 'input2' -3. **Provide descriptions** - Help users understand what each port expects -4. **Test with real connections** - Verify semantic validation works as intended -5. **Keep frontend-backend in sync** - Port definitions should match between TS and Python -6. **Document custom ports** - Add comments explaining special port configurations -7. **Use required flag** - Mark critical ports that must be connected -8. **Handle config changes** - Ensure ports update when relevant config changes - -## Resources - -- Full Implementation: [PORT_BASED_CONNECTION_SYSTEM.md](./PORT_BASED_CONNECTION_SYSTEM.md) -- Loss Node Example: [LOSS_NODE_MULTIPLE_INPUTS.md](./LOSS_NODE_MULTIPLE_INPUTS.md) -- Node Architecture: [NODE_DEFINITION_ARCHITECTURE.md](./NODE_DEFINITION_ARCHITECTURE.md) diff --git a/docs/PRD.md b/docs/PRD.md deleted file mode 100644 index 55b258a..0000000 --- a/docs/PRD.md +++ /dev/null @@ -1,147 +0,0 @@ -# Visual AI Model Builder - PRD - -A browser-based visual interface for designing neural network architectures through intuitive drag-and-drop interactions, with automatic dimension inference, real-time validation, and PyTorch code generation. - -**Experience Qualities:** -1. **Intuitive** - Building complex AI architectures should feel as natural as sketching on paper, with immediate visual feedback -2. **Empowering** - Users discover capabilities through exploration, with helpful guidance preventing mistakes before they happen -3. **Professional** - The interface conveys precision and technical sophistication appropriate for AI engineering work - -**Complexity Level**: Light Application (multiple features with basic state) -This is a specialized design tool with persistent state management, real-time validation, and code generation - beyond a simple showcase but not requiring user accounts or backend infrastructure. - -## Essential Features - -### Canvas-Based Architecture Design -**Functionality**: Drag-and-drop neural network blocks onto an infinite canvas and connect them with visual edges -**Purpose**: Enable intuitive visual thinking about model architecture without code syntax barriers -**Trigger**: User drags a block from the palette onto the canvas -**Progression**: Select block from palette → Drag to canvas → Block appears with handles → Click handle → Drag to another block's handle → Connection validates → Dimensions auto-infer -**Success criteria**: Users can create complex multi-layer architectures in under 2 minutes; invalid connections prevented with clear explanations - -### Real-Time Dimension Inference -**Functionality**: Automatically compute tensor shapes as blocks connect, propagating dimensions through the entire graph; supports multi-modal inputs defined as generalized tensor shapes -**Purpose**: Eliminate manual dimension calculation errors and provide instant feedback on architecture validity; enable flexibility for any data modality (text, image, audio, video, tabular) -**Trigger**: Connection created between two blocks or block parameter changed -**Progression**: Connection made → System traces back to input → Computes output shape based on layer parameters → Updates all downstream blocks → Visual display refreshes -**Success criteria**: All shape calculations complete within 100ms; users never see dimension mismatch errors at export time; supports arbitrary tensor dimensions for any modality - -### Intelligent Block Configuration -**Functionality**: Context-aware parameter panels that show only relevant settings with smart defaults, inline validation, and quick presets for common modalities -**Purpose**: Guide users to correct configurations while allowing expert customization; provide quick-start templates for different data types -**Trigger**: User selects a block on canvas -**Progression**: Click block → Side panel opens → Display block-specific parameters → User modifies value or selects preset (Image/Text/Audio/Tabular) → Real-time validation → Dimensions update → Visual feedback -**Success criteria**: Required parameters clearly indicated; impossible values prevented; helpful tooltips on hover; one-click presets for common use cases - -### Multi-Framework Code Export -**Functionality**: Generate complete, runnable PyTorch or TensorFlow model code with training boilerplate -**Purpose**: Bridge visual design to production-ready code without manual translation -**Trigger**: User clicks Export button after building valid architecture -**Progression**: Click export → Select framework → System validates architecture → Generates model class + training script + config → Display code preview → Copy to clipboard or download -**Success criteria**: Generated code runs without modification; includes helpful comments; follows framework best practices - -### Project Persistence -**Functionality**: Save and load architecture designs with all configurations preserved -**Purpose**: Enable iterative design across sessions without losing work -**Trigger**: User clicks Save or loads a previous project -**Progression**: Click save → Architecture serialized to browser storage → Confirmation shown → Later: Click load → Select project → Canvas restores with all blocks, connections, and parameters -**Success criteria**: Projects persist indefinitely in browser; reload is pixel-perfect; supports 10+ saved projects - -## Edge Case Handling - -**Circular Dependencies** - Prevent connection creation that would create cycles; show tooltip "Neural networks must be acyclic" -**Orphaned Blocks** - Highlight unconnected blocks in orange; warning message "3 blocks not connected to main graph" -**Missing Input Block** - Prevent export with error "Architecture must start with an Input block" -**Dimension Mismatches** - Block connection attempt shows immediate tooltip "Conv2D requires 4D input [B,C,H,W], got 2D [B,F]" -**Browser Storage Limits** - Show warning at 80% quota; offer export to JSON file option -**Invalid Parameters** - Input field turns red; inline message "Must be positive integer" -**Multiple Frameworks** - Same visual architecture exports to different code syntax; internally track framework choice per project - -## Design Direction - -The design should feel like a precision engineering tool - clean, technical, and focused - with the polished sophistication of professional CAD software. A minimal interface ensures the architecture diagram remains the hero, while purposeful micro-interactions provide guidance without distraction. The aesthetic should communicate reliability and technical depth. - -## Color Selection - -**Triadic color scheme** (three equally spaced colors) creating visual hierarchy between input/processing/output operations, with each block category receiving a distinct hue family while maintaining harmony. - -- **Primary Color**: Deep Blue (oklch(0.45 0.15 250)) - Represents core processing layers (Linear, Conv, etc.); conveys technical precision and computational intelligence -- **Secondary Colors**: - - Teal (oklch(0.55 0.12 180)) for input/output operations - suggests data flow and connectivity - - Purple (oklch(0.50 0.13 290)) for advanced blocks (Attention, Transformer) - indicates sophisticated capabilities -- **Accent Color**: Vibrant Cyan (oklch(0.70 0.15 200)) for interactive elements, connection lines, and active states - creates visual energy and guides attention to actions -- **Foreground/Background Pairings**: - - Background (Soft Gray oklch(0.98 0 0)): Dark text (oklch(0.20 0 0)) - Ratio 13.1:1 ✓ - - Card (White oklch(1 0 0)): Dark text (oklch(0.20 0 0)) - Ratio 14.8:1 ✓ - - Primary (Deep Blue oklch(0.45 0.15 250)): White text (oklch(1 0 0)) - Ratio 7.2:1 ✓ - - Secondary (Light Gray oklch(0.96 0 0)): Dark text (oklch(0.20 0 0)) - Ratio 12.8:1 ✓ - - Accent (Vibrant Cyan oklch(0.70 0.15 200)): White text (oklch(1 0 0)) - Ratio 4.9:1 ✓ - - Muted (Soft Gray oklch(0.95 0 0)): Muted text (oklch(0.50 0 0)) - Ratio 6.8:1 ✓ - -## Font Selection - -Typography should balance technical legibility with modern sophistication - clear monospace numerals for dimensions, crisp sans-serif for labels, and consistent hierarchy throughout the interface. **Inter** for all UI text (exceptional clarity at small sizes, professional feel) and **JetBrains Mono** for code display and dimension annotations (designed for programming contexts). - -- **Typographic Hierarchy**: - - H1 (Project Title): Inter SemiBold/24px/tight tracking - strong presence without overwhelming - - H2 (Panel Headers): Inter Medium/16px/normal tracking - clear section delineation - - Body (Block Labels): Inter Regular/14px/normal tracking - optimal legibility on canvas - - Small (Dimensions): JetBrains Mono Regular/12px/wide tracking - technical precision - - Code (Export Preview): JetBrains Mono Regular/13px/normal tracking - familiar to developers - -## Animations - -Animations should reinforce the sense of a living, responsive system - blocks settle into place with subtle physics, connections draw with purpose, and validation feedback appears instantly but gracefully. Movement is restrained and functional, never decorative. - -- **Purposeful Meaning**: Blocks "snap" into grid alignment with gentle spring physics, communicating the underlying structure; connection lines draw from source to target (not fade in) to show direction of data flow; validation errors pulse once to catch attention without distraction -- **Hierarchy of Movement**: Connection creation (300ms ease-out) receives most animation emphasis as it's the primary creative action; block selection is instant (50ms) for responsive feel; panel transitions are quick (200ms) to feel snappy; validation feedback appears immediately (100ms) then settles - -## Component Selection - -- **Components**: - - Canvas: Custom div-based infinite canvas with pan/zoom controls, SVG overlay for connections - - Block Palette: ScrollArea with categorized Accordion sections, each BlockItem as draggable Card - - Config Panel: Sheet (right-anchored) with dynamic Form components based on selected block - - Blocks: Custom Card components with gradient borders indicating category, Badge for block type - - Connections: Custom SVG paths with Arrow markers, color-coded by validation state - - Export Modal: Dialog with Tabs for PyTorch/TensorFlow, syntax-highlighted code in ScrollArea - - Inputs: Input for text/numbers, Select for dropdowns, Switch for booleans, Slider for ranges - - Validation: Alert components for errors, Toast (sonner) for save/load confirmations - - Project Selector: DropdownMenu in header with saved projects list - -- **Customizations**: - - Block cards need custom drag handles and connection ports (small circles on edges) - - Connection SVG paths need animated drawing effect on creation - - Config panel forms need dynamic rendering based on block type schema - - Canvas needs custom zoom controls (+ - reset buttons) in bottom-right - - Block palette items show icon + name in horizontal layout with drag cursor - -- **States**: - - Blocks: default (neutral), selected (cyan border + shadow), invalid (red border), dragging (50% opacity) - - Connections: valid (cyan solid line), invalid-prevented (red dashed), hover (thicker line) - - Inputs: default, focused (cyan ring), error (red border + text), disabled (gray) - - Buttons: default, hover (slight scale + brightness), active (pressed scale), loading (spinner) - -- **Icon Selection**: - - Phosphor icons throughout for consistency and clarity - - FlowArrow for connections, Plus for add block, Download for export - - Lightning for processing blocks, Brain for AI blocks, GitBranch for splits/merges - - Eye for visibility toggle, Trash for delete, Copy for duplicate - - Warning for validation errors, CheckCircle for success states - -- **Spacing**: - - Canvas grid: 20px for subtle alignment guidance - - Panel padding: p-6 for breathing room around content - - Block internal padding: p-4 for compact but readable - - Form field spacing: space-y-4 for clear field separation - - Button spacing: px-6 py-2 for comfortable touch targets - - Section gaps: gap-8 between major sections, gap-4 within sections - -- **Mobile**: - - Stack layout: Canvas goes full-screen with floating action buttons - - Block palette: Bottom sheet drawer (swipe up to open) - - Config panel: Full-screen modal overlay when block selected - - Touch targets: Minimum 44px for all interactive elements - - Gestures: Pinch-to-zoom on canvas, long-press to select, double-tap for properties - - Simplified: Hide dimension annotations until block selected (reduce visual clutter) - - Progressive: Desktop shows all three panels simultaneously; mobile shows one at a time diff --git a/docs/QUICKSTART.md b/docs/QUICKSTART.md deleted file mode 100644 index f5edf1c..0000000 --- a/docs/QUICKSTART.md +++ /dev/null @@ -1,113 +0,0 @@ -# VisionForge Chatbot - Quick Start Guide - -Get the AI chatbot up and running in 5 minutes! - -## Prerequisites - -- Python 3.8+ installed -- Node.js 16+ installed -- Google account (for Gemini API key) - -## Step 1: Get Your API Key (2 minutes) - -1. Visit https://aistudio.google.com/app/apikey -2. Sign in with your Google account -3. Click **"Create API Key"** -4. Copy the key to your clipboard - -## Step 2: Configure Backend (1 minute) - -1. Navigate to the project backend directory: - ```bash - cd project - ``` - -2. Create a `.env` file: - ```bash - # On Windows - copy .env.example .env - - # On macOS/Linux - cp .env.example .env - ``` - -3. Edit `.env` and paste your API key: - ```env - GEMINI_API_KEY=paste-your-key-here - ``` - -## Step 3: Install Dependencies (1 minute) - -```bash -# Install Python dependencies -pip install -r requirements.txt -``` - -## Step 4: Start the Servers (1 minute) - -**Terminal 1 - Backend:** -```bash -cd project -python manage.py runserver -``` - -**Terminal 2 - Frontend:** -```bash -cd project/frontend -npm run dev -``` - -## Step 5: Use the Chatbot! - -1. Open your browser to `http://localhost:5173` -2. Click the **chat bubble icon** in the bottom-right corner -3. Start chatting! - -### Quick Examples - -**Q&A Mode (default):** -``` -You: "How do I add a convolutional layer?" -AI: "You can add a Conv2D layer from the block palette..." -``` - -**Modification Mode (toggle ON):** -``` -You: "Add a Conv2D layer with 64 filters" -AI: [Suggests modification with "Apply Change" button] -``` - -## That's It! - -You now have a fully functional AI-powered chatbot that can: -- Answer questions about your neural network -- Suggest improvements -- Modify your workflow with one click - -For detailed documentation, see [CHATBOT_SETUP.md](./CHATBOT_SETUP.md) - -## Troubleshooting - -### "API key is not configured" -- Check that `.env` file exists in the `project` folder -- Verify `GEMINI_API_KEY=your-key` is set correctly -- Restart the backend server - -### "Connection error" -- Ensure both servers are running -- Backend: `http://localhost:8000` -- Frontend: `http://localhost:5173` - -### Chat not responding -- Check browser console (F12) for errors -- Verify your API key is valid -- Check internet connection - -## Next Steps - -1. Toggle **Modification Mode** to let AI modify your workflow -2. Ask the AI to suggest improvements to your architecture -3. Apply modifications with one click -4. Iterate and refine your neural network design - -Happy building with VisionForge! diff --git a/docs/QUICK_START_AI_PROVIDERS.md b/docs/QUICK_START_AI_PROVIDERS.md deleted file mode 100644 index 68eb221..0000000 --- a/docs/QUICK_START_AI_PROVIDERS.md +++ /dev/null @@ -1,150 +0,0 @@ -# Quick Start: AI Provider Configuration - -## 🚀 Quick Setup - -### Option 1: Use Gemini (Free, Fast) - -1. **Get your key:** https://aistudio.google.com/app/apikey - -2. **Edit `.env`:** - ```env - AI_PROVIDER=gemini - GEMINI_API_KEY=AIzaSy_your_key_here - ``` - -3. **Restart server:** - ```bash - python manage.py runserver - ``` - -**Done!** Your chatbot is now powered by Gemini. - ---- - -### Option 2: Use Claude (High Quality) - -1. **Get your key:** https://console.anthropic.com/ - -2. **Install package:** - ```bash - pip install anthropic - ``` - -3. **Edit `.env`:** - ```env - AI_PROVIDER=claude - ANTHROPIC_API_KEY=sk-ant-your_key_here - ``` - -4. **Restart server:** - ```bash - python manage.py runserver - ``` - -**Done!** Your chatbot is now powered by Claude. - ---- - -## 🔄 Switch Providers Anytime - -Just change `AI_PROVIDER` in `.env` and restart: - -```env -AI_PROVIDER=claude # or 'gemini' -``` - ---- - -## ⚡ Which Should I Choose? - -### Choose Gemini if you want: -- ✅ Free tier (15 requests/minute) -- ✅ Fastest response times -- ✅ Easy setup - -### Choose Claude if you want: -- ✅ Best reasoning and code understanding -- ✅ Most detailed explanations -- ✅ Highest quality suggestions -- ⚠️ Paid API (no free tier) - ---- - -## 📝 Environment Variables Reference - -### Required Variables - -```env -# Choose provider: 'gemini' or 'claude' -AI_PROVIDER=gemini - -# Gemini setup -GEMINI_API_KEY=your_gemini_key - -# Claude setup -ANTHROPIC_API_KEY=your_anthropic_key -``` - -### Complete `.env` Template - -```env -# Django Settings -SECRET_KEY=your-secret-key-here -DEBUG=True - -# AI Provider Configuration -# Choose which AI provider to use: 'gemini' or 'claude' -AI_PROVIDER=gemini - -# Gemini AI Configuration -# Get your API key from: https://aistudio.google.com/app/apikey -GEMINI_API_KEY=your-gemini-api-key-here - -# Claude AI Configuration -# Get your API key from: https://console.anthropic.com/ -ANTHROPIC_API_KEY=your-anthropic-api-key-here - -# Database (optional, defaults to SQLite) -# DATABASE_URL=postgresql://user:password@localhost/dbname -``` - ---- - -## 🆘 Troubleshooting - -### Error: "AI service not properly configured" - -**Fix:** -1. Check `AI_PROVIDER` is set to `gemini` or `claude` -2. Check the corresponding API key is set -3. Restart Django server: `python manage.py runserver` - -### Error: "Invalid AI_PROVIDER" - -**Fix:** Set `AI_PROVIDER` to exactly `gemini` or `claude` (lowercase) - -### Chatbot not responding - -**Fix:** -1. Open browser console (F12) -2. Check backend terminal for errors -3. Verify API key is correct -4. Check you haven't exceeded rate limits - ---- - -## 💡 Pro Tips - -1. **Keep both API keys configured** - makes switching instant -2. **Start with Gemini** - it's free and great for testing -3. **Upgrade to Claude** - when you need highest quality suggestions -4. **Never commit `.env`** - it's already in `.gitignore` - ---- - -## 📚 More Information - -- Full setup guide: `docs/CHATBOT_SETUP.md` -- Implementation details: `docs/AI_PROVIDER_IMPLEMENTATION.md` -- Gemini docs: https://ai.google.dev/docs -- Claude docs: https://docs.anthropic.com/ diff --git a/docs/SECURITY.md b/docs/SECURITY.md deleted file mode 100644 index 67a9cbf..0000000 --- a/docs/SECURITY.md +++ /dev/null @@ -1,31 +0,0 @@ -Thanks for helping make GitHub safe for everyone. - -# Security - -GitHub takes the security of our software products and services seriously, including all of the open source code repositories managed through our GitHub organizations, such as [GitHub](https://github.com/GitHub). - -Even though [open source repositories are outside of the scope of our bug bounty program](https://bounty.github.com/index.html#scope) and therefore not eligible for bounty rewards, we will ensure that your finding gets passed along to the appropriate maintainers for remediation. - -## Reporting Security Issues - -If you believe you have found a security vulnerability in any GitHub-owned repository, please report it to us through coordinated disclosure. - -**Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** - -Instead, please send an email to opensource-security[@]github.com. - -Please include as much of the information listed below as you can to help us better understand and resolve the issue: - - * The type of issue (e.g., buffer overflow, SQL injection, or cross-site scripting) - * Full paths of source file(s) related to the manifestation of the issue - * The location of the affected source code (tag/branch/commit or direct URL) - * Any special configuration required to reproduce the issue - * Step-by-step instructions to reproduce the issue - * Proof-of-concept or exploit code (if possible) - * Impact of the issue, including how an attacker might exploit the issue - -This information will help us triage your report more quickly. - -## Policy - -See [GitHub's Safe Harbor Policy](https://docs.github.com/en/site-policy/security-policies/github-bug-bounty-program-legal-safe-harbor#1-safe-harbor-terms) diff --git a/docs/TENSORFLOW_IMPLEMENTATION_COMPLETE.md b/docs/TENSORFLOW_IMPLEMENTATION_COMPLETE.md deleted file mode 100644 index def574a..0000000 --- a/docs/TENSORFLOW_IMPLEMENTATION_COMPLETE.md +++ /dev/null @@ -1,331 +0,0 @@ -# TensorFlow Backend Implementation - Complete - -## Overview - -The TensorFlow backend has been successfully implemented for VisionForge, providing full feature parity with the PyTorch backend. Users can now select TensorFlow as their framework when creating projects and generate production-ready `tf.keras` code. - -## Implementation Summary - -### ✅ Completed Components - -#### 1. **TensorFlow Node Definitions (17 nodes)** -Location: `/project/block_manager/services/nodes/tensorflow/` - -All nodes use `tf.keras.layers` APIs with **NHWC (channels_last)** data format: - -**Input/Data Nodes:** -- `input.py` - Input layer with shape specification (NHWC format) -- `dataloader.py` - Data loading using `tf.keras.utils.PyDataset` - -**Convolutional Layers:** -- `conv2d.py` - 2D convolution using `tf.keras.layers.Conv2D` -- `conv1d.py` - 1D convolution using `tf.keras.layers.Conv1D` -- `conv3d.py` - 3D convolution using `tf.keras.layers.Conv3D` - -**Dense/Fully Connected:** -- `linear.py` - Dense layer using `tf.keras.layers.Dense` - -**Normalization & Regularization:** -- `batchnorm2d.py` - Batch normalization using `tf.keras.layers.BatchNormalization` -- `dropout.py` - Dropout using `tf.keras.layers.Dropout` - -**Pooling Layers:** -- `maxpool2d.py` - Max pooling using `tf.keras.layers.MaxPooling2D` -- `avgpool2d.py` - Average pooling using `tf.keras.layers.AveragePooling2D` -- `adaptiveavgpool2d.py` - Global average pooling using `tf.keras.layers.GlobalAveragePooling2D` - -**Utility Layers:** -- `flatten.py` - Flatten using `tf.keras.layers.Flatten` - -**Recurrent Layers:** -- `lstm.py` - LSTM using `tf.keras.layers.LSTM` -- `gru.py` - GRU using `tf.keras.layers.GRU` - -**Embedding:** -- `embedding.py` - Embedding using `tf.keras.layers.Embedding` - -**Merge Operations:** -- `concat.py` - Concatenation using `tf.keras.layers.Concatenate` -- `add.py` - Element-wise addition using `tf.keras.layers.Add` - -#### 2. **Code Generation System** -Location: `/project/block_manager/services/tensorflow_codegen.py` - -**Features:** -- Generates `tf.keras.Model` subclass with proper inheritance -- Implements `call()` method (not `forward()`) for TensorFlow -- Handles `training` parameter for layers like Dropout and BatchNormalization -- Generates complete training script with: - - Model compilation - - Callbacks (ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard) - - Dataset creation using `tf.data.Dataset` - - Training loop using `model.fit()` -- Generates `tf.keras.utils.PyDataset` class for custom data loading -- Generates configuration file with hyperparameters - -**Generated Files:** -1. `model.py` - Model class definition -2. `train.py` - Complete training script -3. `dataset.py` - PyDataset implementation -4. `config.py` - Configuration parameters - -#### 3. **Shape Inference** -- All nodes compute output shapes following **NHWC format**: `[batch, height, width, channels]` -- Proper handling of TensorFlow padding modes: `'valid'` and `'same'` -- Accurate shape calculations for all layer types - -#### 4. **Validation System** -- Framework-agnostic validation in `validation.py` -- Detailed error messages passed to frontend -- Shape mismatch detection with clear explanations -- Connection validation ensuring architectural integrity - -#### 5. **API Integration** -Location: `/project/block_manager/views/export_views.py` - -- Export endpoint updated to route TensorFlow requests to code generator -- Returns multiple generated files (model, train, dataset, config) -- Comprehensive error handling with frontend-friendly messages -- Proper HTTP status codes for different error types - -## Key Technical Details - -### Channel Ordering: NHWC (Channels Last) - -TensorFlow uses **NHWC** format by default: -```python -# Input shape: [batch, height, width, channels] -# Example: [32, 224, 224, 3] for batch of 32 RGB images at 224x224 -``` - -**Comparison with PyTorch (NCHW):** -```python -# PyTorch: [32, 3, 224, 224] -# TensorFlow: [32, 224, 224, 3] -``` - -### Parameter Mapping - -| PyTorch | TensorFlow | Notes | -|------------------|------------------|--------------------------------| -| `out_channels` | `filters` | Conv layers | -| `out_features` | `units` | Dense layers | -| `kernel_size` | `kernel_size` | Same | -| `stride` | `strides` | Note: plural in TensorFlow | -| `padding` (int) | `padding` (str) | 'valid' or 'same' | -| `bias` | `use_bias` | Dense/Conv layers | - -### Padding Modes - -TensorFlow uses string-based padding: -- **`'valid'`**: No padding (equivalent to PyTorch padding=0) -- **`'same'`**: Padding to preserve dimensions (when stride=1) - -### Training Parameter - -Layers that behave differently during training vs inference (Dropout, BatchNormalization) receive the `training` parameter: - -```python -def call(self, inputs, training=None): - x = self.conv1(inputs) - x = self.bn1(x, training=training) # BatchNorm needs training flag - x = self.dropout1(x, training=training) # Dropout needs training flag - return x -``` - -## Generated Code Structure - -### Model Class -```python -class GeneratedModel(keras.Model): - def __init__(self): - super(GeneratedModel, self).__init__() - # Layers initialized as instance attributes - self.layer_0 = layers.Conv2D(...) - self.layer_1 = layers.MaxPooling2D(...) - - def call(self, inputs, training=None): - # Forward pass - x = self.layer_0(inputs) - x = self.layer_1(x) - return x -``` - -### Training Script -```python -# Complete with: -# - Model compilation -# - Callbacks (checkpoint, early stopping, LR scheduler) -# - Training loop using model.fit() -# - Model saving -``` - -### Dataset Class -```python -class CustomDataset(keras.utils.PyDataset): - def __len__(self): - return num_batches - - def __getitem__(self, idx): - # Return batch in NHWC format - return batch_x, batch_y -``` - -## Testing Results - -### ✅ All Tests Passed - -1. **Node Registry**: 17 TensorFlow nodes successfully loaded -2. **Shape Validation**: - - Valid 4D NHWC input to Conv2D ✓ - - Invalid 2D input to Conv2D properly rejected ✓ - - Error messages clear and actionable ✓ -3. **Shape Inference**: - - Conv2D with 'same' padding: `[32,224,224,3]` → `[32,112,112,64]` (stride=2) ✓ - - Conv2D with 'valid' padding: `[32,224,224,3]` → `[32,222,222,64]` ✓ -4. **Code Generation**: - - Valid Python syntax ✓ - - Proper TensorFlow/Keras imports ✓ - - Correct class inheritance ✓ - - Training parameter handling ✓ -5. **End-to-End**: Complete CNN architecture generated successfully ✓ - -## Usage Example - -### Frontend (Project Creation) -```typescript -const project = { - name: "My CNN", - framework: "tensorflow", // Select TensorFlow - // ... -} -``` - -### API Request (Export) -```bash -POST /api/export -{ - "nodes": [...], - "edges": [...], - "format": "tensorflow", - "projectName": "MyCNN" -} -``` - -### Response -```json -{ - "code": "...", // model.py content - "additionalFiles": { - "train.py": "...", - "dataset.py": "...", - "config.py": "..." - } -} -``` - -## Error Handling - -All errors are passed to the frontend with detailed messages: - -**Shape Mismatch Example:** -```json -{ - "error": "Requires 4D input [batch, height, width, channels], got 2D", - "nodeId": "node123", - "type": "error", - "suggestion": "Add a Reshape or ensure previous layer outputs correct dimensions" -} -``` - -**Missing Configuration Example:** -```json -{ - "error": "Conv2D layer requires filters parameter", - "nodeId": "conv_layer_1", - "type": "error", - "suggestion": "Configure the number of output filters in the configuration panel" -} -``` - -## Architecture Highlights - -### 1. **Framework Abstraction** -- Base classes in `base.py` support both frameworks -- `Framework.TENSORFLOW` enum value -- Node registry automatically discovers TensorFlow nodes - -### 2. **Shape Computation** -- Each node implements `compute_output_shape()` -- NHWC format consistently used -- Proper handling of padding modes - -### 3. **Validation** -- Each node implements `validate_incoming_connection()` -- Clear error messages with suggestions -- Shape compatibility checking - -### 4. **Code Generation** -- Topological sorting ensures correct layer order -- Proper variable naming and tracking -- Support for multiple inputs (concat, add) - -## Comparison: PyTorch vs TensorFlow - -| Aspect | PyTorch | TensorFlow | -|---------------------|----------------------------|----------------------------| -| **Data Format** | NCHW (channels_first) | NHWC (channels_last) | -| **Base Class** | `nn.Module` | `keras.Model` | -| **Forward Method** | `forward(x)` | `call(inputs, training)` | -| **Conv Parameter** | `out_channels` | `filters` | -| **Dense Parameter** | `out_features` | `units` | -| **Padding** | Integer | String ('valid', 'same') | -| **DataLoader** | `torch.utils.data.DataLoader` | `tf.keras.utils.PyDataset` | - -## Next Steps / Future Enhancements - -1. **Activation Functions**: Add standalone activation nodes (ReLU, Sigmoid, etc.) -2. **Advanced Layers**: Add attention mechanisms, transformers -3. **Optimization**: Support for mixed precision training -4. **Deployment**: Add export to TensorFlow Lite, TensorFlow.js -5. **Custom Layers**: Support for user-defined custom layers -6. **Model Conversion**: Tools to convert between PyTorch and TensorFlow - -## Files Modified/Created - -### Created: -- `tensorflow/input.py` -- `tensorflow/dataloader.py` -- `tensorflow/conv2d.py` -- `tensorflow/conv1d.py` -- `tensorflow/conv3d.py` -- `tensorflow/linear.py` -- `tensorflow/batchnorm2d.py` -- `tensorflow/dropout.py` -- `tensorflow/maxpool2d.py` -- `tensorflow/avgpool2d.py` -- `tensorflow/adaptiveavgpool2d.py` -- `tensorflow/flatten.py` -- `tensorflow/lstm.py` -- `tensorflow/gru.py` -- `tensorflow/embedding.py` -- `tensorflow/concat.py` -- `tensorflow/add.py` -- `services/tensorflow_codegen.py` - -### Modified: -- `tensorflow/__init__.py` - Export all nodes -- `views/export_views.py` - Connect to TensorFlow code generator -- (validation.py and inference.py already framework-agnostic) - -## Conclusion - -The TensorFlow backend is **production-ready** and provides: -- ✅ Complete node library (17 nodes) -- ✅ Accurate shape inference (NHWC format) -- ✅ Comprehensive validation with detailed error messages -- ✅ Production-quality code generation -- ✅ Complete training pipeline generation -- ✅ Frontend integration via API - -Users can now seamlessly switch between PyTorch and TensorFlow frameworks, with VisionForge handling all the framework-specific details automatically. diff --git a/docs/TENSORFLOW_IMPLEMENTATION_SUMMARY.md b/docs/TENSORFLOW_IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/TENSORFLOW_QUICK_REFERENCE.md b/docs/TENSORFLOW_QUICK_REFERENCE.md deleted file mode 100644 index 79c70c7..0000000 --- a/docs/TENSORFLOW_QUICK_REFERENCE.md +++ /dev/null @@ -1,339 +0,0 @@ -# TensorFlow Backend - Quick Reference Guide - -## Overview - -VisionForge now supports **TensorFlow/Keras** as a backend framework alongside PyTorch. This guide helps you understand the key differences and how to use the TensorFlow backend effectively. - -## Quick Start - -### 1. Select TensorFlow Framework -When creating a new project, select **TensorFlow** as your framework: -```typescript -{ - "name": "My CNN Project", - "framework": "tensorflow" // or "pytorch" -} -``` - -### 2. Build Your Architecture -- Drag and drop nodes from the palette -- Configure layer parameters in the config panel -- Connect nodes to create your architecture - -### 3. Export Code -Click "Export Code" to generate production-ready TensorFlow code. - -## Key Differences: PyTorch vs TensorFlow - -### Data Format - -**TensorFlow uses NHWC (channels_last):** -```python -# TensorFlow: [batch, height, width, channels] -input_shape = [32, 224, 224, 3] # 32 RGB images at 224x224 - -# PyTorch: [batch, channels, height, width] -input_shape = [32, 3, 224, 224] # Same data, different ordering -``` - -### Parameter Names - -| Layer Type | PyTorch Parameter | TensorFlow Parameter | -|------------|-------------------|----------------------| -| Conv2D | `out_channels=64` | `filters=64` | -| Dense | `out_features=128`| `units=128` | -| All Conv | `stride=2` | `strides=2` | - -### Padding - -**TensorFlow uses string-based padding:** -```python -# PyTorch: padding=1 -# TensorFlow: padding='same' or 'valid' -``` - -- **`'valid'`**: No padding (default) -- **`'same'`**: Padding to preserve spatial dimensions (when stride=1) - -## Available Nodes (17 Total) - -### Input & Data -- **Input** - Define input tensor shape (NHWC format) -- **DataLoader** - Data loading using `tf.keras.utils.PyDataset` - -### Convolutional Layers -- **Conv2D** - 2D convolution (`tf.keras.layers.Conv2D`) -- **Conv1D** - 1D convolution for sequences -- **Conv3D** - 3D convolution for video/volumetric data - -### Dense Layers -- **Dense** - Fully connected layer (`tf.keras.layers.Dense`) - -### Normalization & Regularization -- **BatchNorm2D** - Batch normalization (`tf.keras.layers.BatchNormalization`) -- **Dropout** - Dropout regularization - -### Pooling -- **MaxPool2D** - Max pooling -- **AvgPool2D** - Average pooling -- **GlobalAvgPool2D** - Global average pooling (adaptive) - -### Utility -- **Flatten** - Flatten multi-dimensional input to 2D - -### Recurrent -- **LSTM** - Long Short-Term Memory -- **GRU** - Gated Recurrent Unit - -### Embedding -- **Embedding** - Embedding layer for categorical data - -### Merge Operations -- **Concat** - Concatenate tensors along an axis -- **Add** - Element-wise addition - -## Configuration Examples - -### Conv2D Layer -```json -{ - "filters": 64, // Number of output channels - "kernel_size": 3, // 3x3 kernel - "strides": 1, // Stride of 1 - "padding": "same", // Preserve dimensions - "activation": "relu" // Built-in activation (optional) -} -``` - -### Dense Layer -```json -{ - "units": 128, // Number of neurons - "activation": "relu", // Activation function - "use_bias": true // Include bias term -} -``` - -### LSTM Layer -```json -{ - "units": 128, // Hidden state size - "return_sequences": false, // Return only last output - "dropout": 0.2, // Input dropout - "recurrent_dropout": 0.2 // Recurrent dropout -} -``` - -## Common Shapes (NHWC Format) - -### Image Data -```python -# Single RGB image (224x224) -[1, 224, 224, 3] - -# Batch of 32 grayscale images (28x28) -[32, 28, 28, 1] - -# Batch of 64 RGB images (256x256) -[64, 256, 256, 3] -``` - -### Sequence Data -```python -# Batch of 32 sequences, 100 time steps, 50 features -[32, 100, 50] -``` - -### After Flatten -```python -# From: [32, 7, 7, 64] (feature maps) -# To: [32, 3136] (7*7*64 = 3136) -``` - -## Generated Code Structure - -### Model File (`model.py`) -```python -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers - -class YourModel(keras.Model): - def __init__(self): - super(YourModel, self).__init__() - # Layers initialized here - self.layer_0 = layers.Conv2D(32, 3, padding='same') - self.layer_1 = layers.MaxPooling2D(2) - # ... - - def call(self, inputs, training=None): - # Forward pass - x = self.layer_0(inputs) - x = self.layer_1(x) - # ... - return x - -def create_model(): - return YourModel() -``` - -### Training Script (`train.py`) -Complete training pipeline with: -- Model compilation (`optimizer`, `loss`, `metrics`) -- Callbacks (checkpoints, early stopping, LR scheduling) -- Training loop using `model.fit()` -- Model saving - -### Dataset Class (`dataset.py`) -```python -class CustomDataset(keras.utils.PyDataset): - def __len__(self): - return num_batches - - def __getitem__(self, idx): - # Return batch in NHWC format - batch_x = ... # Shape: [batch_size, height, width, channels] - batch_y = ... - return batch_x, batch_y -``` - -## Error Messages - -VisionForge provides detailed error messages to help you fix issues: - -### Shape Mismatch -``` -Error: "Requires 4D input [batch, height, width, channels], got 2D" -Suggestion: "Add a Reshape layer or ensure previous layer outputs correct dimensions" -``` - -### Missing Configuration -``` -Error: "Conv2D layer requires filters parameter" -Suggestion: "Configure the number of output filters in the configuration panel" -``` - -### Invalid Connection -``` -Error: "DataLoader is a source node and cannot accept incoming connections" -``` - -## Best Practices - -### 1. Input Shape -Always specify input shape in **NHWC format**: -```json -{ - "shape": "[32, 224, 224, 3]" // Correct: NHWC - // NOT: "[32, 3, 224, 224]" // Wrong: This is NCHW (PyTorch format) -} -``` - -### 2. Padding Choice -- Use **`'same'`** when you want to preserve spatial dimensions -- Use **`'valid'`** when you want dimensions to shrink (no padding) - -### 3. Activation Functions -You can either: -- Set activation in layer config: `"activation": "relu"` -- Add a separate activation node (for more control) - -### 4. Training-Dependent Layers -These layers automatically handle the `training` parameter: -- Dropout -- BatchNormalization - -### 5. Data Format Consistency -Ensure your data is in **NHWC format** throughout: -- Input data: `[batch, height, width, channels]` -- After Conv2D: `[batch, new_height, new_width, filters]` -- After Flatten: `[batch, features]` - -## Example Architectures - -### Simple CNN for MNIST -``` -Input [32, 28, 28, 1] - ↓ -Conv2D (filters=32, kernel=3, padding='same') - ↓ -MaxPool2D (pool_size=2) - ↓ -Conv2D (filters=64, kernel=3, padding='same') - ↓ -MaxPool2D (pool_size=2) - ↓ -Flatten - ↓ -Dense (units=128, activation='relu') - ↓ -Dropout (rate=0.5) - ↓ -Dense (units=10, activation='softmax') -``` - -### ResNet-style Skip Connection -``` -Input - ├─→ Conv2D → BatchNorm → ReLU → Conv2D → BatchNorm ─┐ - │ ↓ - └─────────────────────────────────────────────────→ Add - ↓ - ReLU -``` - -## Troubleshooting - -### Issue: "Shape mismatch error" -**Solution**: Check that: -1. Input is in NHWC format -2. Previous layer output matches current layer input requirements -3. Flatten is used before Dense layers when needed - -### Issue: "Generated code doesn't run" -**Solution**: -1. Check TensorFlow is installed: `pip install tensorflow` -2. Verify input data is in NHWC format -3. Check that all required parameters are configured - -### Issue: "Padding doesn't work as expected" -**Solution**: -- Use `padding='same'` with `strides=1` to preserve dimensions -- Use `padding='valid'` for no padding (dimensions will shrink) - -## Migration from PyTorch - -If you have a PyTorch architecture and want to convert to TensorFlow: - -1. **Transpose data**: NCHW → NHWC - ```python - # PyTorch: [batch, channels, height, width] - # TensorFlow: [batch, height, width, channels] - tf_data = torch_data.permute(0, 2, 3, 1) # If using PyTorch - # Or: tf_data = np.transpose(data, (0, 2, 3, 1)) # If using NumPy - ``` - -2. **Update parameter names**: - - `out_channels` → `filters` - - `out_features` → `units` - -3. **Convert padding**: - - `padding=0` → `padding='valid'` - - `padding=k//2` (for odd k) → `padding='same'` - -## Resources - -- **TensorFlow Documentation**: https://tensorflow.org/api_docs/python/tf/keras -- **Keras Layers Guide**: https://keras.io/api/layers/ -- **PyDataset Documentation**: https://keras.io/api/utils/python_utils/#pydataset-class - -## Support - -For issues or questions: -1. Check error messages in the UI -2. Review this guide -3. Consult the full documentation in `TENSORFLOW_IMPLEMENTATION_COMPLETE.md` - ---- - -**Happy Model Building with TensorFlow! 🚀** diff --git a/docs/TENSORFLOW_SUMMARY.md b/docs/TENSORFLOW_SUMMARY.md deleted file mode 100644 index ff2fb70..0000000 --- a/docs/TENSORFLOW_SUMMARY.md +++ /dev/null @@ -1,232 +0,0 @@ -# TensorFlow Backend Implementation - Summary - -## ✅ Implementation Complete - -The TensorFlow backend for VisionForge has been successfully implemented and tested. Users can now select TensorFlow when creating projects and generate production-ready `tf.keras` code. - -## What Was Implemented - -### 1. **17 TensorFlow Node Definitions** -All nodes use `tf.keras.layers` APIs with NHWC (channels_last) data format: - -**Created Files:** -- `tensorflow/input.py` - Input layer -- `tensorflow/dataloader.py` - DataLoader using PyDataset -- `tensorflow/conv2d.py` - 2D Convolution -- `tensorflow/conv1d.py` - 1D Convolution -- `tensorflow/conv3d.py` - 3D Convolution -- `tensorflow/linear.py` - Dense layer -- `tensorflow/batchnorm2d.py` - Batch Normalization -- `tensorflow/dropout.py` - Dropout -- `tensorflow/maxpool2d.py` - Max Pooling -- `tensorflow/avgpool2d.py` - Average Pooling -- `tensorflow/adaptiveavgpool2d.py` - Global Average Pooling -- `tensorflow/flatten.py` - Flatten -- `tensorflow/lstm.py` - LSTM -- `tensorflow/gru.py` - GRU -- `tensorflow/embedding.py` - Embedding -- `tensorflow/concat.py` - Concatenate -- `tensorflow/add.py` - Element-wise Addition - -### 2. **Code Generation System** -**File:** `services/tensorflow_codegen.py` - -Generates 4 complete files: -1. **model.py** - `tf.keras.Model` subclass -2. **train.py** - Complete training pipeline -3. **dataset.py** - `tf.keras.utils.PyDataset` implementation -4. **config.py** - Hyperparameter configuration - -### 3. **Integration & Validation** -**Updated Files:** -- `tensorflow/__init__.py` - Export all nodes -- `views/export_views.py` - Route TensorFlow requests to code generator -- (Validation system already framework-agnostic) - -### 4. **Documentation** -- `TENSORFLOW_IMPLEMENTATION_COMPLETE.md` - Full technical documentation -- `TENSORFLOW_QUICK_REFERENCE.md` - User guide and quick reference - -## Key Features - -### ✅ Framework Parity -- All 17 PyTorch nodes have TensorFlow equivalents -- Same frontend experience, different backend - -### ✅ NHWC Data Format -- Proper channels_last format: `[batch, height, width, channels]` -- Accurate shape inference for all layers -- Proper padding calculations ('valid', 'same') - -### ✅ Error Handling -- Shape mismatch errors with detailed messages -- Missing configuration detection -- Invalid connection validation -- All errors passed to frontend for user correction - -### ✅ Production-Ready Code -- Valid `tf.keras.Model` classes -- Proper `training` parameter handling -- Complete training scripts with callbacks -- PyDataset implementation for custom data loading - -## Testing Results - -### ✅ All Tests Passed - -**Test Coverage:** -1. ✅ Node Registry - 17 nodes loaded -2. ✅ Shape Validation - Error messages clear and actionable -3. ✅ Shape Inference - NHWC calculations correct -4. ✅ Code Generation - Valid Python/TensorFlow syntax -5. ✅ Architecture Validation - Comprehensive error checking -6. ✅ All Node Types - Each node verified individually - -**Example Test Output:** -``` -✓ PyTorch nodes: 17 -✓ TensorFlow nodes: 17 -✓ Valid NHWC input accepted -✓ Invalid 2D input rejected with message -✓ Shape inference: [32,224,224,3] → [32,112,112,64] -✓ Generated code: valid Python syntax -``` - -## User Workflow - -### 1. Create Project -```typescript -{ - "name": "My Model", - "framework": "tensorflow" // Select TensorFlow -} -``` - -### 2. Build Architecture -- Drag nodes from palette -- Configure parameters -- Connect nodes -- VisionForge validates in real-time - -### 3. Export Code -```bash -POST /api/export -{ - "nodes": [...], - "edges": [...], - "format": "tensorflow" -} -``` - -### 4. Get Generated Files -```python -# model.py - TensorFlow model class -# train.py - Training script -# dataset.py - Data loading -# config.py - Hyperparameters -``` - -## Error Message Examples - -Users receive clear, actionable error messages: - -**Shape Mismatch:** -``` -"Requires 4D input [batch, height, width, channels], got 2D" -``` - -**Missing Config:** -``` -"Conv2D layer requires filters parameter" -Suggestion: "Configure the number of output filters in the config panel" -``` - -**Invalid Connection:** -``` -"DataLoader is a source node and cannot accept incoming connections" -``` - -## Technical Highlights - -### Parameter Mapping -``` -PyTorch → TensorFlow -out_channels → filters -out_features → units -stride → strides -padding (int) → padding ('valid'/'same') -``` - -### Data Format -``` -PyTorch: [batch, channels, height, width] (NCHW) -TensorFlow: [batch, height, width, channels] (NHWC) -``` - -### Model Structure -```python -# PyTorch -class Model(nn.Module): - def forward(self, x): - return x - -# TensorFlow -class Model(keras.Model): - def call(self, inputs, training=None): - return x -``` - -## Files Summary - -### Created (19 files): -- 17 TensorFlow node files -- 1 code generator -- 1 implementation guide -- Plus this summary - -### Modified (2 files): -- `tensorflow/__init__.py` -- `views/export_views.py` - -### Total Lines of Code: ~2,500+ - -## Benefits to Users - -1. **Framework Choice**: Select PyTorch or TensorFlow based on preference -2. **Consistency**: Same visual interface, different backend -3. **Accuracy**: NHWC format handled automatically -4. **Validation**: Real-time error checking with helpful messages -5. **Production-Ready**: Generated code follows TensorFlow best practices -6. **Complete Pipeline**: Model + training + data loading code - -## Next Steps (Optional Future Enhancements) - -While the current implementation is complete and production-ready, potential future enhancements could include: - -1. **Additional Nodes**: Attention layers, transformers, custom activations -2. **Model Conversion**: Tools to convert between PyTorch ↔ TensorFlow -3. **Optimization**: Mixed precision training, distributed training -4. **Deployment**: Export to TF Lite, TF.js -5. **Visualization**: Layer activation visualization, model graphs - -## Conclusion - -The TensorFlow backend is **fully implemented, tested, and production-ready**. Users can: - -✅ Select TensorFlow framework when creating projects -✅ Build architectures using 17 TensorFlow nodes -✅ Get real-time validation with detailed error messages -✅ Export production-ready `tf.keras` code -✅ Receive complete training pipeline - -All shape mismatches and configuration errors are caught and reported to the frontend with actionable suggestions, allowing users to correct issues easily. - ---- - -**Implementation Status: COMPLETE ✅** - -**Date:** November 9, 2025 -**Nodes Implemented:** 17 -**Code Quality:** Production-ready -**Testing:** All tests passed -**Documentation:** Complete diff --git a/docs/TENSORFLOW_SUPPORT.md b/docs/TENSORFLOW_SUPPORT.md deleted file mode 100644 index e69de29..0000000 diff --git a/project/block_manager/migrations/0002_block_is_expanded_block_repetition_metadata_and_more.py b/project/block_manager/migrations/0002_block_is_expanded_block_repetition_metadata_and_more.py new file mode 100644 index 0000000..96c97d6 --- /dev/null +++ b/project/block_manager/migrations/0002_block_is_expanded_block_repetition_metadata_and_more.py @@ -0,0 +1,48 @@ +# Generated by Django 5.2.8 on 2025-12-05 08:14 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('block_manager', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='block', + name='is_expanded', + field=models.BooleanField(default=False), + ), + migrations.AddField( + model_name='block', + name='repetition_metadata', + field=models.JSONField(blank=True, null=True), + ), + migrations.CreateModel( + name='GroupBlockDefinition', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=255)), + ('description', models.TextField(blank=True, default='')), + ('category', models.CharField(default='utility', max_length=50)), + ('color', models.CharField(default='#9333ea', max_length=50)), + ('internal_structure', models.JSONField(blank=True, default=dict)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('project', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='group_definitions', to='block_manager.project')), + ], + options={ + 'ordering': ['-updated_at'], + 'unique_together': {('project', 'name')}, + }, + ), + migrations.AddField( + model_name='block', + name='group_definition', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='instances', to='block_manager.groupblockdefinition'), + ), + ] diff --git a/project/block_manager/migrations/0003_add_instance_config_overrides.py b/project/block_manager/migrations/0003_add_instance_config_overrides.py new file mode 100644 index 0000000..c804234 --- /dev/null +++ b/project/block_manager/migrations/0003_add_instance_config_overrides.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.8 on 2025-12-06 23:05 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('block_manager', '0002_block_is_expanded_block_repetition_metadata_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='block', + name='instance_config_overrides', + field=models.JSONField(blank=True, default=dict, null=True), + ), + ] diff --git a/project/block_manager/models.py b/project/block_manager/models.py index 0a6768f..dbb3c36 100644 --- a/project/block_manager/models.py +++ b/project/block_manager/models.py @@ -40,6 +40,33 @@ def __str__(self): return f"Architecture for {self.project.name}" +class GroupBlockDefinition(models.Model): + """Project-specific block template for group blocks""" + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + project = models.ForeignKey( + Project, + on_delete=models.CASCADE, + related_name='group_definitions' + ) + name = models.CharField(max_length=255) + description = models.TextField(blank=True, default='') + category = models.CharField(max_length=50, default='utility') + color = models.CharField(max_length=50, default='#9333ea') + + # Serialized structure: {nodes, edges, portMappings} + internal_structure = models.JSONField(default=dict, blank=True) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ['-updated_at'] + unique_together = ['project', 'name'] + + def __str__(self): + return f"{self.name} ({self.project.name})" + + class Block(models.Model): """Represents a single block/layer in the architecture""" id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) @@ -56,6 +83,19 @@ class Block(models.Model): config = models.JSONField(default=dict, blank=True) input_shape = models.JSONField(null=True, blank=True) output_shape = models.JSONField(null=True, blank=True) + + # Group block fields + group_definition = models.ForeignKey( + GroupBlockDefinition, + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name='instances' + ) + is_expanded = models.BooleanField(default=False) + repetition_metadata = models.JSONField(null=True, blank=True) + instance_config_overrides = models.JSONField(null=True, blank=True, default=dict) + created_at = models.DateTimeField(auto_now_add=True) class Meta: diff --git a/project/block_manager/serializers.py b/project/block_manager/serializers.py index 49a7d3c..a76b2c5 100644 --- a/project/block_manager/serializers.py +++ b/project/block_manager/serializers.py @@ -1,14 +1,176 @@ from rest_framework import serializers -from .models import Project, ModelArchitecture, Block, Connection +from .models import Project, ModelArchitecture, Block, Connection, GroupBlockDefinition + + +class GroupBlockDefinitionSerializer(serializers.ModelSerializer): + """ + Serializer for GroupBlockDefinition model + + Handles validation of internal structure to ensure data integrity + and prevent malformed group definitions from corrupting the database. + """ + internalNodes = serializers.ListField( + child=serializers.DictField(), + write_only=True, + required=False, + default=list + ) + internalEdges = serializers.ListField( + child=serializers.DictField(), + write_only=True, + required=False, + default=list + ) + portMappings = serializers.ListField( + child=serializers.DictField(), + write_only=True, + required=False, + default=list + ) + + class Meta: + model = GroupBlockDefinition + fields = [ + 'id', 'name', 'description', 'category', 'color', + 'internalNodes', 'internalEdges', 'portMappings', + 'created_at', 'updated_at' + ] + read_only_fields = ['id', 'created_at', 'updated_at'] + + def validate_internalNodes(self, value): + """Validate internal nodes structure""" + if not isinstance(value, list): + raise serializers.ValidationError("internalNodes must be a list") + + for node in value: + if not isinstance(node, dict): + raise serializers.ValidationError("Each node must be a dictionary") + if 'id' not in node: + raise serializers.ValidationError("Each node must have an 'id' field") + + return value + + def validate_internalEdges(self, value): + """Validate internal edges structure""" + if not isinstance(value, list): + raise serializers.ValidationError("internalEdges must be a list") + + for edge in value: + if not isinstance(edge, dict): + raise serializers.ValidationError("Each edge must be a dictionary") + required_fields = ['id', 'source', 'target'] + for field in required_fields: + if field not in edge: + raise serializers.ValidationError( + f"Each edge must have a '{field}' field" + ) + + return value + + def validate_portMappings(self, value): + """Validate port mappings structure""" + if not isinstance(value, list): + raise serializers.ValidationError("portMappings must be a list") + + for mapping in value: + if not isinstance(mapping, dict): + raise serializers.ValidationError("Each port mapping must be a dictionary") + required_fields = ['type', 'externalPortLabel', 'internalNodeId'] + for field in required_fields: + if field not in mapping: + raise serializers.ValidationError( + f"Each port mapping must have a '{field}' field" + ) + if mapping['type'] not in ['input', 'output']: + raise serializers.ValidationError( + "Port mapping type must be 'input' or 'output'" + ) + + return value + + def validate(self, data): + """Cross-field validation for internal structure consistency""" + internal_nodes = data.get('internalNodes', []) + internal_edges = data.get('internalEdges', []) + port_mappings = data.get('portMappings', []) + + # Build node ID set for validation + node_ids = {node['id'] for node in internal_nodes} + + # Validate edge references + for edge in internal_edges: + if edge['source'] not in node_ids: + raise serializers.ValidationError( + f"Edge references non-existent source node: {edge['source']}" + ) + if edge['target'] not in node_ids: + raise serializers.ValidationError( + f"Edge references non-existent target node: {edge['target']}" + ) + + # Validate port mapping references + for mapping in port_mappings: + if mapping['internalNodeId'] not in node_ids: + raise serializers.ValidationError( + f"Port mapping references non-existent node: {mapping['internalNodeId']}" + ) + + return data + + def create(self, validated_data): + """Create group definition with validated internal structure""" + internal_nodes = validated_data.pop('internalNodes', []) + internal_edges = validated_data.pop('internalEdges', []) + port_mappings = validated_data.pop('portMappings', []) + + validated_data['internal_structure'] = { + 'nodes': internal_nodes, + 'edges': internal_edges, + 'portMappings': port_mappings + } + + return super().create(validated_data) + + def update(self, instance, validated_data): + """Update group definition with validated internal structure""" + internal_nodes = validated_data.pop('internalNodes', None) + internal_edges = validated_data.pop('internalEdges', None) + port_mappings = validated_data.pop('portMappings', None) + + # Update internal structure if any component is provided + if any([internal_nodes is not None, internal_edges is not None, port_mappings is not None]): + current_structure = instance.internal_structure or {} + validated_data['internal_structure'] = { + 'nodes': internal_nodes if internal_nodes is not None else current_structure.get('nodes', []), + 'edges': internal_edges if internal_edges is not None else current_structure.get('edges', []), + 'portMappings': port_mappings if port_mappings is not None else current_structure.get('portMappings', []) + } + + return super().update(instance, validated_data) + + def to_representation(self, instance): + """Convert instance to dictionary for read operations""" + representation = super().to_representation(instance) + # Add internal structure fields for read operations + representation['internalNodes'] = instance.internal_structure.get('nodes', []) + representation['internalEdges'] = instance.internal_structure.get('edges', []) + representation['portMappings'] = instance.internal_structure.get('portMappings', []) + return representation class BlockSerializer(serializers.ModelSerializer): """Serializer for Block model""" + group_definition = GroupBlockDefinitionSerializer(read_only=True) + instance_config_overrides = serializers.JSONField(required=False, allow_null=True) + class Meta: model = Block fields = [ 'id', 'node_id', 'block_type', 'position_x', 'position_y', - 'config', 'input_shape', 'output_shape', 'created_at' + 'config', 'input_shape', 'output_shape', + 'group_definition', 'is_expanded', 'repetition_metadata', + 'instance_config_overrides', + 'created_at' ] read_only_fields = ['id', 'created_at'] @@ -69,6 +231,7 @@ class SaveArchitectureSerializer(serializers.Serializer): """Serializer for saving architecture from frontend""" nodes = serializers.ListField(child=serializers.DictField()) edges = serializers.ListField(child=serializers.DictField()) + groupDefinitions = serializers.ListField(child=serializers.DictField(), required=False, default=list) class ValidationResponseSerializer(serializers.Serializer): diff --git a/project/block_manager/services/pytorch_codegen.py b/project/block_manager/services/pytorch_codegen.py index 4e3dd02..cb1b7ef 100644 --- a/project/block_manager/services/pytorch_codegen.py +++ b/project/block_manager/services/pytorch_codegen.py @@ -5,13 +5,1439 @@ from typing import List, Dict, Any, Optional, Tuple from collections import deque +import logging +import json + +# Configure logging +logger = logging.getLogger(__name__) + + +# ============================================ +# Custom Exception Classes +# ============================================ + +class GroupDefinitionNotFoundError(Exception): + """Raised when a group block references a non-existent definition.""" + + def __init__(self, node_id: str, definition_id: str): + self.node_id = node_id + self.definition_id = definition_id + super().__init__( + f"Group block {node_id} references undefined definition {definition_id}" + ) + + +class ShapeMismatchError(Exception): + """Raised when internal layers have incompatible shapes.""" + + def __init__(self, block_name: str, layer_name: str, expected: Dict, actual: Dict): + self.block_name = block_name + self.layer_name = layer_name + self.expected = expected + self.actual = actual + super().__init__( + f"Shape mismatch in block '{block_name}' at layer '{layer_name}': " + f"expected {expected}, got {actual}" + ) + + +class CyclicDependencyError(Exception): + """Raised when internal structure contains cycles.""" + + def __init__(self, block_name: str, cycle_nodes: List[str]): + self.block_name = block_name + self.cycle_nodes = cycle_nodes + super().__init__( + f"Cyclic dependency detected in block '{block_name}': {' -> '.join(cycle_nodes)}" + ) + + +class UnsupportedNodeTypeError(Exception): + """Raised when encountering an unsupported node type during code generation.""" + + def __init__(self, node_id: str, node_type: str, framework: str): + self.node_id = node_id + self.node_type = node_type + self.framework = framework + super().__init__( + f"Unsupported node type '{node_type}' for {framework} in node {node_id}. " + f"Please use a supported layer type or implement this layer manually." + ) + + +class ShapeInferenceError(Exception): + """Raised when shape inference fails for a node.""" + + def __init__(self, node_id: str, node_type: str, reason: str, suggestion: str = None): + self.node_id = node_id + self.node_type = node_type + self.reason = reason + self.suggestion = suggestion + msg = f"Shape inference failed for node {node_id} ({node_type}): {reason}" + if suggestion: + msg += f"\nSuggestion: {suggestion}" + super().__init__(msg) + + +class MissingShapeDataError(Exception): + """Raised when required shape data is missing from upstream nodes.""" + + def __init__(self, node_id: str, upstream_node_id: str, missing_keys: List[str]): + self.node_id = node_id + self.upstream_node_id = upstream_node_id + self.missing_keys = missing_keys + super().__init__( + f"Node {node_id} requires shape data from upstream node {upstream_node_id}, " + f"but the following keys are missing: {', '.join(missing_keys)}. " + f"Check that the upstream node produces valid output shapes." + ) + + +# ============================================ +# Shape Data Validation Utility +# ============================================ + +def safe_get_shape_data( + shape_map: Dict[str, Dict[str, Any]], + node_id: str, + upstream_node_id: str, + required_keys: List[str], + default_values: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """ + Safely retrieve shape data from upstream node with validation. + + This function ensures that shape data access is safe by: + 1. Checking that the upstream node exists in the shape map + 2. Validating that shape data is not None and is a dictionary + 3. Verifying all required keys are present + 4. Providing clear error messages when data is missing + + Args: + shape_map: Map of node IDs to shape information dictionaries + node_id: Current node ID (for error messages and tracing) + upstream_node_id: ID of upstream node to retrieve shape from + required_keys: List of required shape keys (e.g., ['out_channels', 'out_height']) + default_values: Optional default values to use if data is missing + + Returns: + Dictionary containing the requested shape data + + Raises: + MissingShapeDataError: If required data is missing and no defaults provided + ShapeInferenceError: If upstream shape data is invalid (None or not a dict) + + Example: + >>> shape_data = safe_get_shape_data( + ... shape_map, + ... 'conv2', + ... 'conv1', + ... ['out_channels', 'out_height', 'out_width'], + ... default_values={'out_channels': 64, 'out_height': 32, 'out_width': 32} + ... ) + >>> print(shape_data['out_channels']) + 64 + """ + result = {} + + # Check if upstream node exists in shape map + if upstream_node_id not in shape_map: + if default_values: + return default_values.copy() + raise MissingShapeDataError( + node_id=node_id, + upstream_node_id=upstream_node_id, + missing_keys=required_keys + ) + + upstream_shape = shape_map[upstream_node_id] + + # Validate upstream shape is not None and is a dict + if upstream_shape is None or not isinstance(upstream_shape, dict): + if default_values: + return default_values.copy() + raise ShapeInferenceError( + node_id=node_id, + node_type="unknown", + reason=f"Upstream node {upstream_node_id} has invalid shape data (None or not a dict)", + suggestion="Check that the upstream node is properly configured and connected" + ) + + # Extract required keys with validation + missing_keys = [] + for key in required_keys: + if key in upstream_shape: + result[key] = upstream_shape[key] + elif default_values and key in default_values: + result[key] = default_values[key] + else: + missing_keys.append(key) + + if missing_keys: + if default_values: + for key in missing_keys: + if key in default_values: + result[key] = default_values[key] + return result + raise MissingShapeDataError( + node_id=node_id, + upstream_node_id=upstream_node_id, + missing_keys=missing_keys + ) + + return result + + +class GroupBlockShapeComputer: + """ + Computes output shapes for group blocks by traversing internal structure. + + This class handles shape inference for group blocks by: + 1. Retrieving the internal structure of a group block + 2. Topologically sorting internal nodes + 3. Propagating shapes through the internal graph + 4. Mapping internal output nodes to external output ports + + Performance optimizations: + - Shape caching to avoid redundant computations + - Lazy topological sorting (cached per definition) + - Cache invalidation on definition changes + """ + + def __init__(self, group_definitions: Dict[str, Dict[str, Any]], cache_size: int = 1000, profiler: Optional['ShapeInferenceProfiler'] = None): + """ + Initialize with group definitions. + + Args: + group_definitions: Map of definition ID to definition dict + cache_size: Maximum number of cached shape computations (default: 1000) + profiler: Optional profiler for performance analysis + """ + self.group_definitions = group_definitions + self.shape_cache = {} # Cache computed shapes: {(group_def_id, input_shape_tuple): output_shape} + self.cache_size = cache_size + self.cache_hits = 0 + self.cache_misses = 0 + + # Cache for topological sorts per definition + self.topo_sort_cache = {} # {group_def_id: sorted_nodes} + + # Track definition versions for cache invalidation + self.definition_versions = {} # {group_def_id: version_hash} + self._initialize_definition_versions() + + # Optional profiler + self.profiler = profiler + + def compute_output_shape( + self, + group_def_id: str, + input_shape: Dict[str, Any] + ) -> Tuple[Optional[Dict[str, Any]], List[Exception]]: + """ + Compute output shape for a group block given input shape. + + Args: + group_def_id: ID of the group definition + input_shape: Input shape dict with keys like 'out_channels', 'out_height', etc. + + Returns: + Tuple of (output shape dict with computed dimensions or None, list of errors) + """ + import time + start_time = time.time() if self.profiler and self.profiler.enabled else None + + errors = [] + + # Check cache first + cache_key = (group_def_id, self._shape_to_tuple(input_shape)) + if cache_key in self.shape_cache: + self.cache_hits += 1 + logger.debug(f"Cache hit for group {group_def_id} (hit rate: {self.cache_hits}/{self.cache_hits + self.cache_misses})") + if start_time: + self.profiler.record_timing('compute_output_shape_cached', time.time() - start_time) + return self.shape_cache[cache_key], [] + + self.cache_misses += 1 + + # Get group definition + if group_def_id not in self.group_definitions: + error = GroupDefinitionNotFoundError('unknown', group_def_id) + logger.error(str(error)) + errors.append(error) + return None, errors + + group_def = self.group_definitions[group_def_id] + group_name = group_def.get('name', group_def_id) + internal_structure = group_def.get('internal_structure', {}) + + if not internal_structure: + logger.warning(f"Group {group_name} has no internal structure") + return input_shape.copy(), [] + + internal_nodes = internal_structure.get('nodes', []) + internal_edges = internal_structure.get('edges', []) + port_mappings = internal_structure.get('portMappings', []) + + # Handle edge case: no internal nodes + if not internal_nodes: + logger.warning(f"Group {group_name} has no internal nodes") + return input_shape.copy(), [] + + # Compute internal shapes + try: + internal_shape_map, shape_errors = self.compute_internal_shapes( + internal_nodes, + internal_edges, + port_mappings, + input_shape, + group_name + ) + + # Collect any errors from internal shape computation + errors.extend(shape_errors) + + if not internal_shape_map: + logger.error(f"Failed to compute internal shapes for group {group_name}") + return None, errors + + # Find output port mappings + output_ports = [pm for pm in port_mappings if pm.get('type') == 'output'] + + if not output_ports: + logger.warning(f"Group {group_name} has no output ports") + return input_shape.copy(), errors + + # Handle multiple output ports - return dict with shapes for each port + if len(output_ports) > 1: + logger.debug(f"Group {group_name} has {len(output_ports)} output ports") + output_shapes = {} + for idx, port in enumerate(output_ports): + internal_node_id = port.get('internalNodeId') + port_label = port.get('externalPortLabel', f'output_{idx}') + if internal_node_id in internal_shape_map: + output_shapes[port_label] = internal_shape_map[internal_node_id] + else: + error_msg = f"Output port '{port_label}' maps to unknown node {internal_node_id}" + logger.error(error_msg) + errors.append(Exception(error_msg)) + + # For now, return the first output port's shape for backward compatibility + # In the future, we should return all output shapes + if output_shapes: + first_output_shape = list(output_shapes.values())[0] + # Cache the result only if no errors + if not errors: + self.shape_cache[cache_key] = first_output_shape + self._evict_cache_if_needed() + return first_output_shape, errors + else: + return None, errors + + # For single output, return the shape of the mapped internal node + internal_node_id = output_ports[0].get('internalNodeId') + if internal_node_id in internal_shape_map: + output_shape = internal_shape_map[internal_node_id] + # Cache the result only if no errors + if not errors: + self.shape_cache[cache_key] = output_shape + self._evict_cache_if_needed() + return output_shape, errors + else: + error_msg = f"Output port maps to unknown node {internal_node_id}" + logger.error(error_msg) + errors.append(Exception(error_msg)) + return None, errors + + except CyclicDependencyError as e: + logger.error(f"Cyclic dependency in group {group_name}: {e}") + errors.append(e) + if start_time: + self.profiler.record_timing('compute_output_shape_error', time.time() - start_time) + return None, errors + except Exception as e: + logger.error(f"Error computing output shape for group {group_name}: {e}") + errors.append(e) + if start_time: + self.profiler.record_timing('compute_output_shape_error', time.time() - start_time) + return None, errors + finally: + # Record timing for successful computation + if start_time and not errors: + self.profiler.record_timing('compute_output_shape_success', time.time() - start_time) + + def compute_internal_shapes( + self, + internal_nodes: List[Dict], + internal_edges: List[Dict], + port_mappings: List[Dict], + external_input_shape: Dict[str, Any], + group_name: str = "unknown" + ) -> Tuple[Dict[str, Dict[str, Any]], List[Exception]]: + """ + Compute shapes for all internal nodes. + + Args: + internal_nodes: List of nodes inside the block + internal_edges: List of edges inside the block + port_mappings: Port mapping configuration + external_input_shape: Shape coming into the block + group_name: Name of the group block for error reporting + + Returns: + Tuple of (map of node_id to shape info, list of errors) + """ + import time + start_time = time.time() if self.profiler and self.profiler.enabled else None + + errors = [] + + # Edge case: no internal edges - validate that all nodes are input/output nodes + if not internal_edges: + logger.warning(f"Group {group_name} has no internal edges") + # Check if we have only input/output nodes + non_io_nodes = [n for n in internal_nodes if get_node_type(n) not in ('input', 'dataloader', 'output')] + if non_io_nodes: + error_msg = f"Group {group_name} has {len(non_io_nodes)} non-input/output nodes but no edges connecting them" + logger.error(error_msg) + errors.append(Exception(error_msg)) + # Still try to process nodes with default shapes + + # Topologically sort internal nodes (with caching) + try: + sorted_nodes = self._get_cached_topological_sort( + group_name, internal_nodes, internal_edges + ) + except Exception as e: + # Check if this is a cyclic dependency + if "cycle" in str(e).lower(): + cycle_error = CyclicDependencyError(group_name, []) + logger.error(str(cycle_error)) + errors.append(cycle_error) + else: + logger.error(f"Failed to topologically sort internal nodes: {e}") + errors.append(e) + return {}, errors + + # Build edge map for finding inputs + edge_map = {} + for edge in internal_edges: + target = edge.get('target') + source = edge.get('source') + if target not in edge_map: + edge_map[target] = [] + edge_map[target].append(source) + + # Initialize shape map + internal_shape_map = {} + + # Handle multiple input ports - map each to external input shape + input_ports = [pm for pm in port_mappings if pm.get('type') == 'input'] + + if len(input_ports) > 1: + logger.debug(f"Group {group_name} has {len(input_ports)} input ports") + # For multiple inputs, we need to handle them separately + # For now, we'll use the same external_input_shape for all inputs + # In the future, we should support different shapes for different inputs + for idx, input_port in enumerate(input_ports): + internal_node_id = input_port.get('internalNodeId') + port_label = input_port.get('externalPortLabel', f'input_{idx}') + if internal_node_id: + # Use the same shape for all inputs for now + internal_shape_map[internal_node_id] = external_input_shape.copy() + logger.debug(f"Mapped input port '{port_label}' to node {internal_node_id}") + else: + # Single input port + for input_port in input_ports: + internal_node_id = input_port.get('internalNodeId') + if internal_node_id: + internal_shape_map[internal_node_id] = external_input_shape.copy() + + # Detect disconnected subgraphs - nodes with no path from input ports + # Build a set of reachable nodes from input ports + reachable_nodes = set() + if input_ports: + # BFS from input nodes + from collections import deque + queue = deque() + for input_port in input_ports: + internal_node_id = input_port.get('internalNodeId') + if internal_node_id: + queue.append(internal_node_id) + reachable_nodes.add(internal_node_id) + + # Build forward edge map (source -> targets) + forward_edge_map = {} + for edge in internal_edges: + source = edge.get('source') + target = edge.get('target') + if source not in forward_edge_map: + forward_edge_map[source] = [] + forward_edge_map[source].append(target) + + # BFS to find all reachable nodes + while queue: + current = queue.popleft() + for neighbor in forward_edge_map.get(current, []): + if neighbor not in reachable_nodes: + reachable_nodes.add(neighbor) + queue.append(neighbor) + + # Check for disconnected nodes + all_node_ids = {node['id'] for node in internal_nodes} + disconnected_nodes = all_node_ids - reachable_nodes + # Filter out input/output nodes from disconnected check + disconnected_non_io = [nid for nid in disconnected_nodes + if get_node_type(next((n for n in internal_nodes if n['id'] == nid), {})) + not in ('input', 'dataloader', 'output')] + + if disconnected_non_io: + logger.warning(f"Group {group_name} has {len(disconnected_non_io)} disconnected nodes: {disconnected_non_io[:3]}") + # This is a warning, not an error - we'll still process what we can + + # Process each internal node in topological order + for node in sorted_nodes: + node_id = node['id'] + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + node_label = node.get('data', {}).get('label', node_type) + + # Skip if already computed (input nodes) + if node_id in internal_shape_map: + continue + + # Get incoming edges + incoming = edge_map.get(node_id, []) + + # Initialize shape info for this node + shape_info = {} + + # Handle different node types + if node_type == 'input': + # Input nodes should already be in the map + if node_id not in internal_shape_map: + internal_shape_map[node_id] = external_input_shape.copy() + continue + + # Handle nodes with multiple inputs (concat, add, etc.) + if node_type in ('concat', 'add') and len(incoming) > 1: + logger.debug(f"Processing {node_type} node {node_label} with {len(incoming)} inputs") + + # Validate that all inputs have compatible shapes + input_shapes = [] + for src_id in incoming: + if src_id in internal_shape_map: + input_shapes.append(internal_shape_map[src_id]) + else: + logger.warning(f"Input {src_id} for {node_type} node {node_label} has no computed shape") + + if not input_shapes: + # No valid inputs, use default + shape_info = {'out_channels': 64, 'out_height': 7, 'out_width': 7} + elif node_type == 'concat': + # For concat, sum the channels + total_channels = sum(s.get('out_channels', 0) for s in input_shapes) + # Use spatial dimensions from first input + shape_info['out_channels'] = total_channels + if 'out_height' in input_shapes[0]: + shape_info['out_height'] = input_shapes[0]['out_height'] + if 'out_width' in input_shapes[0]: + shape_info['out_width'] = input_shapes[0]['out_width'] + elif node_type == 'add': + # For add, channels must match - use first input's shape + shape_info = input_shapes[0].copy() + # Validate that all inputs have same channels + for idx, s in enumerate(input_shapes[1:], 1): + if s.get('out_channels') != shape_info.get('out_channels'): + error = ShapeMismatchError( + group_name, + node_label, + {'out_channels': shape_info.get('out_channels')}, + {'out_channels': s.get('out_channels')} + ) + logger.error(str(error)) + errors.append(error) + + internal_shape_map[node_id] = shape_info + continue + + elif node_type == 'conv2d': + # Get input channels from previous layer + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + if 'out_channels' not in prev_shape: + # Shape mismatch: expected channels but got features + error = ShapeMismatchError( + group_name, + node_label, + {'out_channels': 'required'}, + prev_shape + ) + logger.error(str(error)) + errors.append(error) + shape_info['in_channels'] = 3 # Use default + else: + shape_info['in_channels'] = prev_shape.get('out_channels', 3) + else: + shape_info['in_channels'] = 3 + + # Output channels from config + shape_info['out_channels'] = config.get('out_channels', 64) + + # Calculate output spatial dimensions + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + kernel_size = config.get('kernel_size', 3) + stride = config.get('stride', 1) + padding = config.get('padding', 0) + + if 'out_height' in prev_shape and 'out_width' in prev_shape: + shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 + shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + + elif node_type == 'maxpool': + # Preserve channels, reduce spatial dimensions + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + if 'out_channels' not in prev_shape: + # Shape mismatch: expected channels + error = ShapeMismatchError( + group_name, + node_label, + {'out_channels': 'required'}, + prev_shape + ) + logger.error(str(error)) + errors.append(error) + shape_info['in_channels'] = 64 # Use default + shape_info['out_channels'] = 64 + else: + shape_info['in_channels'] = prev_shape.get('out_channels', 64) + shape_info['out_channels'] = shape_info['in_channels'] + + kernel_size = config.get('kernel_size', 2) + stride = config.get('stride', 2) + padding = config.get('padding', 0) + + if 'out_height' in prev_shape and 'out_width' in prev_shape: + shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 + shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + + elif node_type == 'flatten': + # Convert spatial dimensions to features + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + channels = prev_shape.get('out_channels', 64) + height = prev_shape.get('out_height', 7) + width = prev_shape.get('out_width', 7) + shape_info['out_features'] = channels * height * width + + elif node_type == 'linear': + # Get input features from previous layer + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + # Accept both 'out_features' (PyTorch) and 'out_units' (TensorFlow) + if 'out_features' not in prev_shape and 'out_units' not in prev_shape: + # Shape mismatch: expected features but got channels + error = ShapeMismatchError( + group_name, + node_label, + {'out_features': 'required'}, + prev_shape + ) + logger.error(str(error)) + errors.append(error) + shape_info['in_features'] = 512 # Use default + else: + # Use out_features if available, otherwise out_units + shape_info['in_features'] = prev_shape.get('out_features', prev_shape.get('out_units', 512)) + else: + shape_info['in_features'] = 512 + + # Output features from config + shape_info['out_features'] = config.get('out_features', 128) + + elif node_type == 'batchnorm' or node_type == 'batchnorm2d': + # Preserve dimensions, just need num_features + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + if 'out_channels' not in prev_shape: + # Shape mismatch: expected channels but got features + error = ShapeMismatchError( + group_name, + node_label, + {'out_channels': 'required'}, + prev_shape + ) + logger.error(str(error)) + errors.append(error) + shape_info['num_features'] = 64 # Use default + shape_info['out_channels'] = 64 + else: + shape_info['num_features'] = prev_shape.get('out_channels', 64) + shape_info['out_channels'] = shape_info['num_features'] + if 'out_height' in prev_shape: + shape_info['out_height'] = prev_shape['out_height'] + if 'out_width' in prev_shape: + shape_info['out_width'] = prev_shape['out_width'] + + elif node_type == 'group': + # Handle nested group blocks recursively + nested_group_def_id = node.get('data', {}).get('groupDefinitionId') + + if not nested_group_def_id: + logger.warning(f"Nested group block {node_label} has no definition ID") + # Use input shape if available + if incoming and incoming[0] in internal_shape_map: + shape_info = internal_shape_map[incoming[0]].copy() + else: + shape_info = {'out_channels': 64, 'out_height': 7, 'out_width': 7} + elif not incoming: + logger.warning(f"Nested group block {node_label} has no incoming edges") + # Use default shape + shape_info = {'out_channels': 64, 'out_height': 7, 'out_width': 7} + elif incoming[0] not in internal_shape_map: + logger.warning(f"Nested group block {node_label} has incoming edge from node with no computed shape") + # Use default shape + shape_info = {'out_channels': 64, 'out_height': 7, 'out_width': 7} + else: + # Recursively compute nested group block shape + nested_input_shape = internal_shape_map[incoming[0]] + logger.debug(f"Recursively computing shape for nested group {node_label} (def: {nested_group_def_id})") + nested_output_shape, nested_errors = self.compute_output_shape(nested_group_def_id, nested_input_shape) + + # Collect errors from nested computation + errors.extend(nested_errors) + + if nested_output_shape: + shape_info = nested_output_shape + logger.debug(f"Nested group {node_label} output shape: {nested_output_shape}") + else: + # Fallback: copy input shape + shape_info = nested_input_shape.copy() + logger.warning(f"Failed to compute shape for nested group {node_label}, using input shape") + + else: + # For other layers, try to preserve shape from input + if incoming and incoming[0] in internal_shape_map: + prev_shape = internal_shape_map[incoming[0]] + shape_info.update(prev_shape) + + internal_shape_map[node_id] = shape_info + + # Record timing + if start_time: + self.profiler.record_timing('compute_internal_shapes', time.time() - start_time) + + return internal_shape_map, errors + + def _get_cached_topological_sort( + self, + group_name: str, + internal_nodes: List[Dict], + internal_edges: List[Dict] + ) -> List[Dict]: + """ + Get topologically sorted nodes with caching. + + Args: + group_name: Name of the group (for cache key) + internal_nodes: List of internal nodes + internal_edges: List of internal edges + + Returns: + List of topologically sorted nodes + """ + # Use group_name as cache key (assumes nodes/edges don't change for same group) + if group_name in self.topo_sort_cache: + logger.debug(f"Using cached topological sort for {group_name}") + return self.topo_sort_cache[group_name] + + # Compute topological sort + sorted_nodes = topological_sort(internal_nodes, internal_edges) + + # Cache the result + self.topo_sort_cache[group_name] = sorted_nodes + logger.debug(f"Cached topological sort for {group_name} ({len(sorted_nodes)} nodes)") + + return sorted_nodes + + def _shape_to_tuple(self, shape: Dict[str, Any]) -> tuple: + """ + Convert shape dict to tuple for use as cache key. + + Args: + shape: Shape dictionary + + Returns: + Tuple representation of shape + """ + # Create a sorted tuple of key-value pairs + return tuple(sorted(shape.items())) + + def _initialize_definition_versions(self): + """Initialize version hashes for all definitions.""" + for def_id, definition in self.group_definitions.items(): + self.definition_versions[def_id] = self._compute_definition_hash(definition) + + def _compute_definition_hash(self, definition: Dict[str, Any]) -> int: + """ + Compute a hash of the definition for cache invalidation. + + Args: + definition: Group block definition + + Returns: + Hash value representing the definition structure + """ + import json + # Hash the internal structure to detect changes + internal_structure = definition.get('internal_structure', {}) + # Convert to JSON string for consistent hashing + structure_str = json.dumps(internal_structure, sort_keys=True) + return hash(structure_str) + + def invalidate_cache_for_definition(self, group_def_id: str): + """ + Invalidate all cached data for a specific definition. + + Args: + group_def_id: ID of the group definition that changed + """ + # Remove shape cache entries for this definition + keys_to_remove = [key for key in self.shape_cache.keys() if key[0] == group_def_id] + for key in keys_to_remove: + del self.shape_cache[key] + + # Remove topological sort cache + if group_def_id in self.topo_sort_cache: + del self.topo_sort_cache[group_def_id] + + # Update version hash + if group_def_id in self.group_definitions: + self.definition_versions[group_def_id] = self._compute_definition_hash( + self.group_definitions[group_def_id] + ) + + logger.debug(f"Cache invalidated for definition {group_def_id}") + + def update_definition(self, group_def_id: str, new_definition: Dict[str, Any]): + """ + Update a group definition and invalidate related caches. + + Args: + group_def_id: ID of the group definition + new_definition: New definition data + """ + # Check if definition actually changed + old_hash = self.definition_versions.get(group_def_id) + new_hash = self._compute_definition_hash(new_definition) + + if old_hash != new_hash: + # Definition changed, invalidate caches + self.group_definitions[group_def_id] = new_definition + self.invalidate_cache_for_definition(group_def_id) + logger.info(f"Definition {group_def_id} updated and cache invalidated") + else: + # No structural change, just update the definition + self.group_definitions[group_def_id] = new_definition + logger.debug(f"Definition {group_def_id} updated (no structural change)") + + def _evict_cache_if_needed(self): + """Evict oldest cache entries if cache size limit is exceeded.""" + if len(self.shape_cache) > self.cache_size: + # Simple LRU: remove oldest 10% of entries + num_to_remove = max(1, len(self.shape_cache) // 10) + keys_to_remove = list(self.shape_cache.keys())[:num_to_remove] + for key in keys_to_remove: + del self.shape_cache[key] + logger.debug(f"Evicted {num_to_remove} cache entries (cache size: {len(self.shape_cache)})") + + def get_cache_stats(self) -> Dict[str, Any]: + """ + Get cache performance statistics. + + Returns: + Dictionary with cache statistics + """ + total_requests = self.cache_hits + self.cache_misses + hit_rate = (self.cache_hits / total_requests * 100) if total_requests > 0 else 0 + + return { + 'cache_size': len(self.shape_cache), + 'cache_limit': self.cache_size, + 'cache_hits': self.cache_hits, + 'cache_misses': self.cache_misses, + 'hit_rate': hit_rate, + 'topo_sort_cache_size': len(self.topo_sort_cache) + } + + def clear_cache(self): + """Clear all caches and reset statistics.""" + self.shape_cache.clear() + self.topo_sort_cache.clear() + self.cache_hits = 0 + self.cache_misses = 0 + logger.debug("All caches cleared") + + +class ShapeInferenceProfiler: + """ + Profiler for shape inference performance analysis. + + Tracks timing and statistics for shape inference operations + to identify performance bottlenecks in large architectures. + """ + + def __init__(self): + """Initialize the profiler.""" + self.timings = {} # {operation_name: [durations]} + self.enabled = False + + def enable(self): + """Enable profiling.""" + self.enabled = True + logger.info("Shape inference profiling enabled") + + def disable(self): + """Disable profiling.""" + self.enabled = False + logger.info("Shape inference profiling disabled") + + def record_timing(self, operation: str, duration: float): + """ + Record timing for an operation. + + Args: + operation: Name of the operation + duration: Duration in seconds + """ + if not self.enabled: + return + + if operation not in self.timings: + self.timings[operation] = [] + self.timings[operation].append(duration) + + def get_stats(self) -> Dict[str, Dict[str, float]]: + """ + Get profiling statistics. + + Returns: + Dictionary with statistics for each operation + """ + stats = {} + for operation, durations in self.timings.items(): + if durations: + stats[operation] = { + 'count': len(durations), + 'total': sum(durations), + 'mean': sum(durations) / len(durations), + 'min': min(durations), + 'max': max(durations) + } + return stats + + def print_report(self): + """Print a formatted profiling report.""" + if not self.timings: + print("No profiling data collected") + return + + print("\n" + "=" * 80) + print("Shape Inference Performance Report") + print("=" * 80) + + stats = self.get_stats() + for operation, data in sorted(stats.items(), key=lambda x: x[1]['total'], reverse=True): + print(f"\n{operation}:") + print(f" Count: {data['count']}") + print(f" Total: {data['total']:.4f}s") + print(f" Mean: {data['mean']:.4f}s") + print(f" Min: {data['min']:.4f}s") + print(f" Max: {data['max']:.4f}s") + + print("\n" + "=" * 80) + + def reset(self): + """Reset all profiling data.""" + self.timings.clear() + + +class PyTorchBlockGenerator: + """ + Generator for PyTorch nn.Module code for group blocks. + + Converts GroupBlockDefinition into reusable nn.Module subclasses + with proper initialization and forward pass logic. + """ + + def __init__( + self, + group_definitions: List[Dict[str, Any]], + shape_computer: Optional[GroupBlockShapeComputer] = None + ): + """ + Initialize the block generator. + + Args: + group_definitions: List of GroupBlockDefinition dictionaries + shape_computer: Optional shape computer for internal shape inference + """ + self.group_definitions = {defn['id']: defn for defn in group_definitions} + self.generated_classes = {} # Cache generated class code + self.shape_computer = shape_computer or GroupBlockShapeComputer(self.group_definitions) + + def generate_all_block_classes(self) -> str: + """ + Generate all block class definitions. + + Returns: + String containing all block class definitions + """ + if not self.group_definitions: + return "" + + code_parts = [] + code_parts.append("# ============================================") + code_parts.append("# Custom Block Definitions") + code_parts.append("# ============================================\n") + + for defn_id, definition in self.group_definitions.items(): + block_class = self.generate_block_class(definition) + code_parts.append(block_class) + code_parts.append("\n") + + return "\n".join(code_parts) + + def generate_block_class( + self, + definition: Dict[str, Any], + example_input_shape: Optional[Dict[str, Any]] = None + ) -> str: + """ + Generate nn.Module subclass for a single block definition. + + Args: + definition: GroupBlockDefinition dictionary + example_input_shape: Optional example input shape for computing internal shapes + + Returns: + String containing the complete block class definition + """ + block_name = definition['name'] + class_name = self._to_class_name(block_name) + description = definition.get('description', '') + + # Get internal structure + internal_structure = definition.get('internal_structure', {}) + internal_nodes = internal_structure.get('nodes', []) + internal_edges = internal_structure.get('edges', []) + port_mappings = internal_structure.get('portMappings', []) + + # Sort internal nodes topologically + sorted_nodes = topological_sort(internal_nodes, internal_edges) + + # Compute internal shapes if example provided + internal_shape_map = {} + if example_input_shape: + internal_shape_map, _ = self.shape_computer.compute_internal_shapes( + internal_nodes, + internal_edges, + port_mappings, + example_input_shape, + block_name + ) + else: + # Fallback to old behavior without shape computer + internal_shape_map, _ = infer_shapes(sorted_nodes, internal_edges) + + # Generate __init__ method + init_method = self._generate_init_method(sorted_nodes, internal_shape_map, port_mappings) + + # Generate forward method + forward_method = self._generate_forward_method( + sorted_nodes, internal_edges, internal_shape_map, port_mappings + ) + + # Build class docstring + docstring = self._generate_block_docstring( + block_name, description, port_mappings, sorted_nodes + ) + + # Assemble the complete class + class_code = f'''class {class_name}(nn.Module): + """{docstring}""" + +{init_method} + +{forward_method}''' + + # Cache the generated class + self.generated_classes[definition['id']] = class_name + + return class_code + + def _generate_init_method( + self, + nodes: List[Dict[str, Any]], + shape_map: Dict[str, Dict[str, Any]], + port_mappings: List[Dict[str, Any]] + ) -> str: + """Generate __init__ method with layer instantiation.""" + lines = [] + + # Detect which shape parameters are needed by scanning nodes + needs_in_channels = False + needs_in_features = False + needs_num_features = False + + for node in nodes: + node_type = get_node_type(node) + if node_type in ('input', 'dataloader', 'output'): + continue + if node_type == 'conv2d': + needs_in_channels = True + elif node_type == 'linear': + needs_in_features = True + elif node_type in ('batchnorm', 'batchnorm2d'): + needs_num_features = True + + # Generate __init__ signature with detected parameters + params = [] + if needs_in_channels: + params.append("in_channels=None") + if needs_in_features: + params.append("in_features=None") + if needs_num_features: + params.append("num_features=None") + + if params: + lines.append(f" def __init__(self, {', '.join(params)}):") + else: + lines.append(" def __init__(self):") + + lines.append(' """Initialize all internal layers."""') + lines.append(f" super().__init__()") + lines.append("") + + # Track which nodes need to be instantiated and which is first of each type + layer_count = {} + first_layer_of_type = {} + + for idx, node in enumerate(nodes): + node_id = node['id'] + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + shape_info = shape_map.get(node_id, {}) + + # Skip input/output nodes + if node_type in ('input', 'dataloader', 'output'): + continue + + # Track if this is the first layer of its type + is_first = node_type not in first_layer_of_type + if is_first: + first_layer_of_type[node_type] = node_id + + # Generate layer instantiation + layer_name = self._get_internal_layer_name(node_type, node_id, layer_count) + layer_class_name = self._get_layer_class_name_for_node(node_type, config) + + # Generate instantiation with proper arguments + instantiation = self._generate_layer_instantiation_line( + layer_name, layer_class_name, node_type, shape_info, config, is_first + ) + + if instantiation: + lines.append(f" {instantiation}") + + return "\n".join(lines) + + def _generate_forward_method( + self, + nodes: List[Dict[str, Any]], + edges: List[Dict[str, Any]], + shape_map: Dict[str, Dict[str, Any]], + port_mappings: List[Dict[str, Any]] + ) -> str: + """Generate forward method with internal connection logic.""" + lines = [] + + # Determine input parameters from port mappings + input_ports = [pm for pm in port_mappings if pm['type'] == 'input'] + output_ports = [pm for pm in port_mappings if pm['type'] == 'output'] + + # Generate method signature + if len(input_ports) == 1: + lines.append(" def forward(self, x: torch.Tensor) -> torch.Tensor:") + else: + param_names = [f"input_{i}" for i in range(len(input_ports))] + params = ", ".join([f"{name}: torch.Tensor" for name in param_names]) + lines.append(f" def forward(self, {params}) -> torch.Tensor:") + + lines.append(' """') + lines.append(' Forward pass through the block.') + lines.append('') + lines.append(' Args:') + if len(input_ports) == 1: + lines.append(' x: Input tensor') + else: + for i, port in enumerate(input_ports): + label = port.get('externalPortLabel', f'input_{i}') + lines.append(f' input_{i}: {label}') + lines.append('') + lines.append(' Returns:') + if len(output_ports) == 1: + lines.append(' Output tensor') + else: + lines.append(' Tuple of output tensors') + lines.append(' """') + + # Build edge map for finding inputs + edge_map = {} + for edge in edges: + target = edge.get('target') + source = edge.get('source') + if target not in edge_map: + edge_map[target] = [] + edge_map[target].append(source) + + # Map internal node IDs to variable names + var_map = {} + layer_count = {} + + # Map input ports to initial variables + for i, port in enumerate(input_ports): + internal_node_id = port['internalNodeId'] + if len(input_ports) == 1: + var_map[internal_node_id] = 'x' + else: + var_map[internal_node_id] = f'input_{i}' + + # Generate forward pass for each internal node + for node in nodes: + node_id = node['id'] + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + + # Skip input/output nodes + if node_type in ('input', 'dataloader', 'output'): + # Input nodes are already mapped + if node_id not in var_map: + var_map[node_id] = 'x' + continue + + # Get layer name + layer_name = self._get_internal_layer_name(node_type, node_id, layer_count) + + # Get input variable(s) + incoming = edge_map.get(node_id, []) + if not incoming: + # No incoming edges, might be an input node we missed + input_var = 'x' + elif len(incoming) == 1: + input_var = var_map.get(incoming[0], 'x') + else: + # Multiple inputs (for concat, add, etc.) + input_vars = [var_map.get(src, 'x') for src in incoming] + input_var = f"[{', '.join(input_vars)}]" + + # Generate output variable name (sanitize node_id to avoid hyphens) + output_var = f"x_{node_id[:8].replace('-', '_')}" + var_map[node_id] = output_var + + # Generate forward line + if node_type in ('concat', 'add'): + lines.append(f" {output_var} = self.{layer_name}({input_var})") + else: + lines.append(f" {output_var} = self.{layer_name}({input_var})") + + # Map output ports to return values + if len(output_ports) == 1: + output_node_id = output_ports[0]['internalNodeId'] + output_var = var_map.get(output_node_id, 'x') + lines.append(f" return {output_var}") + else: + output_vars = [] + for port in output_ports: + output_node_id = port['internalNodeId'] + output_vars.append(var_map.get(output_node_id, 'x')) + lines.append(f" return ({', '.join(output_vars)})") + + return "\n".join(lines) + + def _generate_block_docstring( + self, + block_name: str, + description: str, + port_mappings: List[Dict[str, Any]], + nodes: List[Dict[str, Any]] + ) -> str: + """Generate comprehensive docstring for block class.""" + lines = [] + lines.append(f"Custom Block: {block_name}") + lines.append("") + + if description: + lines.append(description) + lines.append("") + + lines.append("This block encapsulates a reusable subgraph of layers.") + lines.append("") + + # Document ports + input_ports = [pm for pm in port_mappings if pm['type'] == 'input'] + output_ports = [pm for pm in port_mappings if pm['type'] == 'output'] + + if input_ports: + lines.append("Input Ports:") + for port in input_ports: + label = port.get('externalPortLabel', 'input') + lines.append(f" - {label}") + + if output_ports: + lines.append("") + lines.append("Output Ports:") + for port in output_ports: + label = port.get('externalPortLabel', 'output') + lines.append(f" - {label}") + + lines.append("") + lines.append(f"Internal Layers: {len([n for n in nodes if get_node_type(n) not in ('input', 'dataloader', 'output')])}") + + return "\n ".join(lines) + + def _generate_layer_instantiation_line( + self, + layer_name: str, + layer_class_name: str, + node_type: str, + shape_info: Dict[str, Any], + config: Dict[str, Any], + is_first: bool = False + ) -> str: + """Generate layer instantiation line with proper arguments using shape_map values.""" + # Determine if layer needs shape arguments + if node_type == 'conv2d': + if is_first: + # Use parameter for first conv2d layer + return f"self.{layer_name} = {layer_class_name}(in_channels=in_channels)" + else: + # Use computed value from shape map (no hardcoded defaults) + in_channels = shape_info.get('in_channels') + if in_channels is not None: + return f"self.{layer_name} = {layer_class_name}(in_channels={in_channels})" + else: + # If shape inference failed, use parameter + logger.warning(f"No in_channels in shape_map for {layer_name}, using parameter") + return f"self.{layer_name} = {layer_class_name}(in_channels=in_channels)" + elif node_type == 'linear': + if is_first: + # Use parameter for first linear layer + return f"self.{layer_name} = {layer_class_name}(in_features=in_features)" + else: + # Use computed value from shape map (no hardcoded defaults) + in_features = shape_info.get('in_features') + if in_features is not None: + return f"self.{layer_name} = {layer_class_name}(in_features={in_features})" + else: + # If shape inference failed, use parameter + logger.warning(f"No in_features in shape_map for {layer_name}, using parameter") + return f"self.{layer_name} = {layer_class_name}(in_features=in_features)" + elif node_type in ('batchnorm', 'batchnorm2d'): + if is_first: + # Use parameter for first batchnorm layer + return f"self.{layer_name} = {layer_class_name}(num_features=num_features)" + else: + # Use computed value from shape map (no hardcoded defaults) + num_features = shape_info.get('num_features') + if num_features is not None: + return f"self.{layer_name} = {layer_class_name}(num_features={num_features})" + else: + # If shape inference failed, use parameter + logger.warning(f"No num_features in shape_map for {layer_name}, using parameter") + return f"self.{layer_name} = {layer_class_name}(num_features=num_features)" + else: + return f"self.{layer_name} = {layer_class_name}()" + + def _get_internal_layer_name( + self, + node_type: str, + node_id: str, + layer_count: Dict[str, int] + ) -> str: + """Generate unique layer variable name for internal node.""" + # Use node_id suffix for uniqueness (sanitize to avoid hyphens) + suffix = node_id[:8].replace('-', '_') + base_name = node_type.replace('_', '') + + # Track count for this type + if node_type not in layer_count: + layer_count[node_type] = 0 + layer_count[node_type] += 1 + + return f"{base_name}_{suffix}" + + def _get_layer_class_name_for_node( + self, + node_type: str, + config: Dict[str, Any] + ) -> str: + """Get the layer class name that will be used in the main model.""" + # These should match the class names generated by generate_layer_class + type_name = node_type.replace('_', '').title() + + if node_type == 'conv2d': + channels = config.get('out_channels', 64) + kernel = config.get('kernel_size', 3) + return f"{type_name}Layer_{channels}ch_{kernel}x{kernel}" + elif node_type == 'linear': + features = config.get('out_features', 128) + return f"{type_name}Layer_{features}units" + elif node_type == 'maxpool': + kernel = config.get('kernel_size', 2) + return f"{type_name}Layer_{kernel}x{kernel}" + elif node_type == 'custom': + name = config.get('name', 'CustomLayer') + safe_name = name.replace(' ', '_').replace('-', '_') + return f"CustomLayer_{safe_name}" + else: + # For other types, we'll need to generate a generic name + # This will be handled by the main code generation + return f"{type_name}Layer" + + def _to_class_name(self, name: str) -> str: + """Convert block name to valid Python class name.""" + import re + # Remove special characters and convert to PascalCase + name = re.sub(r'[^a-zA-Z0-9]', ' ', name) + name = ''.join(word.capitalize() for word in name.split()) + if not name: + return 'CustomBlock' + if name[0].isdigit(): + name = 'Block' + name + return name + 'Block' + + def get_block_class_name(self, definition_id: str) -> Optional[str]: + """ + Get the generated class name for a block definition. + + Args: + definition_id: ID of the GroupBlockDefinition + + Returns: + Class name if generated, None otherwise + """ + return self.generated_classes.get(definition_id) def generate_pytorch_code( nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]], - project_name: str = "GeneratedModel" -) -> Dict[str, str]: + project_name: str = "GeneratedModel", + group_definitions: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Dict[str, str], List[Exception]]: """ Generate complete PyTorch code including model, training, and data loading. Each layer gets its own reusable class, all combined in a main model class. @@ -20,18 +1446,37 @@ def generate_pytorch_code( nodes: List of node dictionaries from architecture edges: List of edge dictionaries defining connections project_name: Name for the generated model class + group_definitions: Optional list of GroupBlockDefinition dictionaries Returns: - Dictionary with keys: 'model', 'train', 'dataset', 'config' + Tuple of (dictionary with keys: 'model', 'train', 'dataset', 'config', list of errors) """ # Topologically sort nodes sorted_nodes = topological_sort(nodes, edges) - # Infer shapes through the graph - shape_map = infer_shapes(sorted_nodes, edges) + # Initialize block generator if we have group definitions + block_generator = None + group_def_dict = None + shape_computer = None + if group_definitions: + # Convert list to dict for shape inference + group_def_dict = {defn['id']: defn for defn in group_definitions} + # Create shape computer for reuse + shape_computer = GroupBlockShapeComputer(group_def_dict) + # Create block generator with shape computer + block_generator = PyTorchBlockGenerator(group_definitions, shape_computer) + + # Infer shapes through the graph (now with group definitions) + shape_map, shape_errors = infer_shapes(sorted_nodes, edges, group_def_dict) + + # Validate computed shapes for critical issues + validation_errors = validate_shape_map(sorted_nodes, shape_map) + if validation_errors: + logger.warning(f"Shape validation found {len(validation_errors)} potential issues") + shape_errors.extend(validation_errors) # Generate different components - model_code = generate_model_file(sorted_nodes, edges, project_name, shape_map) + model_code = generate_model_file(sorted_nodes, edges, project_name, shape_map, block_generator, shape_errors) train_code = generate_training_script(project_name) dataset_code = generate_dataset_class(nodes) config_code = generate_config_file(nodes) @@ -41,7 +1486,7 @@ def generate_pytorch_code( 'train': train_code, 'dataset': dataset_code, 'config': config_code - } + }, shape_errors def generate_single_layer_class( @@ -190,14 +1635,70 @@ def topological_sort(nodes: List[Dict], edges: List[Dict]) -> List[Dict]: return [node_map[node_id] for node_id in sorted_ids if node_id in node_map] -def infer_shapes(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Dict[str, Any]]: +def extract_output_shape_from_metadata(node: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Extract output shape from node's frontend-provided metadata. + + The frontend computes output shapes accurately during the visual design phase + and stores them in node.data.outputShape. This function extracts those + pre-computed shapes, which are considered authoritative. + + Args: + node: Node dictionary with potential data.outputShape metadata + + Returns: + Dictionary with shape keys (out_channels, out_features, etc.) or None if + metadata is incomplete/missing + """ + output_shape = node.get('data', {}).get('outputShape', {}) + if not output_shape or not isinstance(output_shape, dict): + return None + + dims = output_shape.get('dims', []) + if not dims: + return None + + shape_info = {} + + # PyTorch uses NCHW format: [batch, channels, height, width] + if len(dims) == 4: + shape_info['out_channels'] = dims[1] + shape_info['out_height'] = dims[2] + shape_info['out_width'] = dims[3] + elif len(dims) == 2: # [batch, features] - for Linear/Flatten output + shape_info['out_features'] = dims[1] + else: + # Unusual shape format - log for debugging but don't fail + logger.debug(f"Unusual output shape dims: {dims}") + return None + + return shape_info + + +def infer_shapes( + nodes: List[Dict], + edges: List[Dict], + group_definitions: Optional[Dict[str, Any]] = None +) -> Tuple[Dict[str, Dict[str, Any]], List[Exception]]: """ Infer input/output shapes for each layer in the graph. + Enhanced to handle group blocks properly. + + Args: + nodes: List of node dictionaries + edges: List of edge dictionaries + group_definitions: Optional map of group definition IDs to definitions Returns: - Dictionary mapping node_id to shape info: {'in_channels', 'out_channels', 'in_features', 'out_features', etc.} + Tuple of (dictionary mapping node_id to shape info, list of errors encountered) """ shape_map = {} + errors = [] + + # Initialize shape computer for group blocks + shape_computer = None + if group_definitions: + shape_computer = GroupBlockShapeComputer(group_definitions) # Build edge map for finding inputs edge_map = {} @@ -208,8 +1709,12 @@ def infer_shapes(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Dict[str, An edge_map[target] = [] edge_map[target].append(source) - # Process nodes in order - for node in nodes: + # Topologically sort nodes to ensure we process layers in dependency order + # This is CRITICAL: we must compute upstream layer shapes before downstream layers + sorted_nodes = topological_sort(nodes, edges) + + # Process nodes in topological order + for node in sorted_nodes: node_id = node['id'] node_type = get_node_type(node) config = node.get('data', {}).get('config', {}) @@ -217,116 +1722,474 @@ def infer_shapes(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Dict[str, An # Get incoming edges incoming = edge_map.get(node_id, []) - # Initialize shape info for this node - shape_info = {} + # ========== PHASE 1: Extract output metadata (if available) ========== + # Frontend provides accurate output shapes in metadata + metadata_shape = extract_output_shape_from_metadata(node) + shape_info = metadata_shape if metadata_shape else {} + + # ========== PHASE 2: Compute input dimensions from upstream nodes ========== + # Input dimensions ALWAYS come from upstream, regardless of metadata + # This is critical for layers like Conv2d, Linear, BatchNorm if node_type == 'input': - # Parse input shape - shape_str = config.get('shape', '[1, 3, 224, 224]') - try: - # Try to parse shape - import json - shape = json.loads(shape_str) - if len(shape) >= 4: - shape_info['out_channels'] = shape[1] # NCHW format - shape_info['out_height'] = shape[2] - shape_info['out_width'] = shape[3] - elif len(shape) >= 2: - shape_info['out_features'] = shape[1] - except: - shape_info['out_channels'] = 3 - shape_info['out_height'] = 224 - shape_info['out_width'] = 224 + # Input nodes have no upstream - parse from config if metadata doesn't exist + if not metadata_shape: + shape_str = config.get('shape', '[1, 3, 224, 224]') + try: + # Try to parse shape + shape = json.loads(shape_str) + if len(shape) >= 4: + shape_info['out_channels'] = shape[1] # NCHW format + shape_info['out_height'] = shape[2] + shape_info['out_width'] = shape[3] + elif len(shape) >= 2: + shape_info['out_features'] = shape[1] + except (json.JSONDecodeError, ValueError, KeyError, IndexError, TypeError) as e: + logger.warning( + f"Failed to parse input shape for node {node_id}: {e}. " + f"Using default shape [1, 3, 224, 224] (NCHW)" + ) + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Failed to parse shape configuration: {str(e)}", + suggestion="Check that the input shape is a valid JSON array like [1, 3, 224, 224]" + )) + shape_info['out_channels'] = 3 + shape_info['out_height'] = 224 + shape_info['out_width'] = 224 elif node_type == 'conv2d': - # Get input channels from previous layer + # Get input channels from upstream layer (ALWAYS required) if incoming and incoming[0] in shape_map: - shape_info['in_channels'] = shape_map[incoming[0]].get('out_channels', 3) + try: + upstream_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels'], + default_values={'out_channels': 3} + ) + shape_info['in_channels'] = upstream_shape['out_channels'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for node {node_id}: {e}. Using default.") + errors.append(e) + shape_info['in_channels'] = 3 else: shape_info['in_channels'] = 3 - # Output channels from config - shape_info['out_channels'] = config.get('out_channels', 64) - - # Calculate output spatial dimensions - if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - kernel_size = config.get('kernel_size', 3) - stride = config.get('stride', 1) - padding = config.get('padding', 0) - - if 'out_height' in prev_shape and 'out_width' in prev_shape: - shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 - shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + # Output channels: use metadata if available, otherwise config + if 'out_channels' not in shape_info: + shape_info['out_channels'] = config.get('out_channels', 64) + + # Spatial dimensions: use metadata if available, otherwise calculate + if 'out_height' not in shape_info or 'out_width' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_height', 'out_width'], + default_values=None + ) + kernel_size = config.get('kernel_size', 3) + stride = config.get('stride', 1) + padding = config.get('padding', 0) + + shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 + shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Could not compute spatial dimensions for conv2d {node_id}: {e}") + errors.append(e) elif node_type == 'maxpool': - # Preserve channels, reduce spatial dimensions + # MaxPool preserves channels from upstream if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - shape_info['in_channels'] = prev_shape.get('out_channels', 64) - shape_info['out_channels'] = shape_info['in_channels'] - - kernel_size = config.get('kernel_size', 2) - stride = config.get('stride', 2) - padding = config.get('padding', 0) - - if 'out_height' in prev_shape and 'out_width' in prev_shape: - shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 - shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels'], + default_values={'out_channels': 64} + ) + shape_info['out_channels'] = prev_shape['out_channels'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for maxpool {node_id}: {e}") + errors.append(e) + shape_info['out_channels'] = 64 + else: + shape_info['out_channels'] = 64 + + # Spatial dimensions: use metadata if available, otherwise calculate + if 'out_height' not in shape_info or 'out_width' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_height', 'out_width'], + default_values={'out_height': 7, 'out_width': 7} + ) + kernel_size = config.get('kernel_size', 2) + stride = config.get('stride', 2) + padding = config.get('padding', 0) + + shape_info['out_height'] = (prev_shape['out_height'] + 2*padding - kernel_size) // stride + 1 + shape_info['out_width'] = (prev_shape['out_width'] + 2*padding - kernel_size) // stride + 1 + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Could not compute spatial dimensions for maxpool {node_id}: {e}") + errors.append(e) elif node_type == 'flatten': - # Convert spatial dimensions to features - if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - channels = prev_shape.get('out_channels', 64) - height = prev_shape.get('out_height', 7) - width = prev_shape.get('out_width', 7) - shape_info['out_features'] = channels * height * width + # Flatten converts spatial dimensions to features + # Use metadata if available, otherwise calculate from upstream + if 'out_features' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels', 'out_height', 'out_width'], + default_values={'out_channels': 64, 'out_height': 7, 'out_width': 7} + ) + channels = prev_shape['out_channels'] + height = prev_shape['out_height'] + width = prev_shape['out_width'] + shape_info['out_features'] = channels * height * width + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for flatten {node_id}: {e}") + errors.append(e) + shape_info['out_features'] = 3136 # 64 * 7 * 7 + else: + shape_info['out_features'] = 3136 # Default elif node_type == 'linear': - # Get input features from previous layer + # Get input features from upstream layer (ALWAYS required) if incoming and incoming[0] in shape_map: - shape_info['in_features'] = shape_map[incoming[0]].get('out_features', 512) + try: + upstream_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_features'], + default_values={'out_features': 512} + ) + shape_info['in_features'] = upstream_shape['out_features'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for linear {node_id}: {e}") + errors.append(e) + shape_info['in_features'] = 512 else: shape_info['in_features'] = 512 - # Output features from config - shape_info['out_features'] = config.get('out_features', 128) + # Output features: use metadata if available, otherwise config + if 'out_features' not in shape_info: + shape_info['out_features'] = config.get('out_features', 128) elif node_type == 'batchnorm': - # Preserve dimensions, just need num_features + # BatchNorm preserves all dimensions from upstream if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - shape_info['num_features'] = prev_shape.get('out_channels', 64) - shape_info['out_channels'] = shape_info['num_features'] - if 'out_height' in prev_shape: - shape_info['out_height'] = prev_shape['out_height'] - if 'out_width' in prev_shape: - shape_info['out_width'] = prev_shape['out_width'] + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels'], + default_values={'out_channels': 64} + ) + shape_info['num_features'] = prev_shape['out_channels'] + shape_info['out_channels'] = shape_info['num_features'] + # Copy spatial dimensions if they exist and not in metadata + if 'out_height' not in shape_info and 'out_height' in prev_shape: + shape_info['out_height'] = prev_shape['out_height'] + if 'out_width' not in shape_info and 'out_width' in prev_shape: + shape_info['out_width'] = prev_shape['out_width'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for batchnorm {node_id}: {e}") + errors.append(e) + shape_info['num_features'] = 64 + shape_info['out_channels'] = 64 + else: + shape_info['num_features'] = 64 + shape_info['out_channels'] = 64 + + elif node_type == 'group': + # Group blocks: Use metadata if available, otherwise compute from internal structure + if not metadata_shape: + # No metadata - compute output shape using GroupBlockShapeComputer + if shape_computer: + group_def_id = node.get('data', {}).get('groupDefinitionId') + + if group_def_id and incoming and incoming[0] in shape_map: + # Get input shape from upstream node + input_shape = shape_map[incoming[0]] + + # Compute output shape using internal structure + output_shape, shape_errors = shape_computer.compute_output_shape( + group_def_id, + input_shape + ) + + # Collect errors from shape computation + errors.extend(shape_errors) + + if output_shape: + shape_info = output_shape + logger.debug(f"Computed shape for group block {node_id}: {output_shape}") + else: + # Fallback: copy input shape + shape_info = input_shape.copy() + logger.warning(f"Failed to compute shape for group block {node_id}, using input shape") + elif group_def_id and not (incoming and incoming[0] in shape_map): + # Group definition exists but no valid input + shape_info = {'out_channels': 3, 'out_height': 224, 'out_width': 224} + logger.warning(f"Group block {node_id} has no valid input, using default shape") + elif not group_def_id and incoming and incoming[0] in shape_map: + # No definition found, copy input shape + shape_info = shape_map[incoming[0]].copy() + logger.warning(f"No group definition ID found for node {node_id}, using input shape") + else: + # No definition and no input, use default + shape_info = {'out_channels': 3, 'out_height': 224, 'out_width': 224} + logger.warning(f"Group block {node_id} has no definition ID and no input, using default shape") + else: + # No shape computer available, fallback to old behavior + if incoming and incoming[0] in shape_map: + prev_shape = shape_map[incoming[0]] + shape_info.update(prev_shape) + else: + shape_info['out_channels'] = 3 + shape_info['out_height'] = 224 + shape_info['out_width'] = 224 + logger.warning(f"No shape computer available for group block {node_id}, using fallback behavior") else: - # For other layers, try to preserve shape from input - if incoming and incoming[0] in shape_map: + # For other layers: Use metadata if available, otherwise preserve upstream shape + if not metadata_shape and incoming and incoming[0] in shape_map: prev_shape = shape_map[incoming[0]] shape_info.update(prev_shape) - + shape_map[node_id] = shape_info - return shape_map + return shape_map, errors + + +def validate_shape_map( + nodes: List[Dict], + shape_map: Dict[str, Dict[str, Any]] +) -> List[Exception]: + """ + Validate computed shape map for common critical issues. + + This catches problems that would cause runtime errors in generated code: + - Missing shape information + - Invalid dimensions (zero or negative) + - Type-specific requirements not met + + Args: + nodes: List of all nodes + shape_map: Computed shape mapping + + Returns: + List of validation errors (as exceptions for consistency with shape_errors) + """ + errors = [] + + for node in nodes: + node_id = node['id'] + node_type = get_node_type(node) + + # Skip non-layer nodes + if node_type in ('input', 'output', 'dataloader', 'group'): + continue + + shape_info = shape_map.get(node_id) + + # Critical: Shape info must exist + if not shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="No shape information computed for node", + suggestion="Check that node has valid upstream connections and metadata" + )) + continue + + # Type-specific validation + if node_type == 'linear': + # Linear MUST have in_features + if 'in_features' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Missing required in_features for Linear layer", + suggestion="Check upstream Flatten or Linear layer output shape" + )) + # in_features must be positive + elif shape_info.get('in_features', 0) <= 0: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Invalid in_features={shape_info.get('in_features')} (must be > 0)", + suggestion="Check upstream layer produces valid output shape" + )) + + elif node_type == 'conv2d': + # Conv2d MUST have in_channels + if 'in_channels' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Missing required in_channels for Conv2d layer", + suggestion="Check upstream Conv2d or Input layer provides channels" + )) + + elif node_type == 'flatten': + # Flatten MUST produce out_features + if 'out_features' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Flatten layer must produce out_features", + suggestion="Check upstream layer has spatial dimensions (NCHW format)" + )) + elif shape_info.get('out_features', 0) <= 0: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Invalid out_features={shape_info.get('out_features')} (must be > 0)", + suggestion="Check upstream layer output dimensions are valid" + )) + + return errors + + +def collect_all_nodes_with_internals( + main_nodes: List[Dict], + block_generator: Optional[PyTorchBlockGenerator] = None +) -> List[Tuple[Dict, int, str]]: + """ + Collect all nodes including internal nodes from group blocks. + Returns list of tuples: (node, index, source_context) + source_context is either 'main' or 'group_{group_def_id}' + + This ensures we generate layer classes for ALL nodes, not just main model nodes. + """ + all_nodes = [] + node_index = 0 + + # Add main model nodes + for node in main_nodes: + all_nodes.append((node, node_index, 'main')) + node_index += 1 + + # Add internal nodes from group definitions + if block_generator: + for group_def_id, group_def in block_generator.group_definitions.items(): + internal_structure = group_def.get('internal_structure', {}) + internal_nodes = internal_structure.get('nodes', []) + + for internal_node in internal_nodes: + node_type = get_node_type(internal_node) + # Skip input/output nodes + if node_type not in ('input', 'dataloader', 'output'): + all_nodes.append((internal_node, node_index, f'group_{group_def_id}')) + node_index += 1 + + return all_nodes + + +def get_layer_signature(node: Dict, config: Dict[str, Any], node_type: str) -> str: + """ + Generate a unique signature for a layer based on its type and config. + Used for deduplication - layers with same signature can share the same class. + """ + if node_type == 'conv2d': + return f"conv2d_{config.get('out_channels', 64)}_{config.get('kernel_size', 3)}_{config.get('stride', 1)}_{config.get('padding', 0)}_{config.get('dilation', 1)}" + elif node_type == 'linear': + return f"linear_{config.get('out_features', 128)}_{config.get('bias', True)}" + elif node_type == 'maxpool': + return f"maxpool_{config.get('kernel_size', 2)}_{config.get('stride', 2)}_{config.get('padding', 0)}" + elif node_type == 'dropout': + return f"dropout_{config.get('p', 0.5)}" + elif node_type == 'batchnorm': + return f"batchnorm_{config.get('eps', 1e-5)}_{config.get('momentum', 0.1)}_{config.get('affine', True)}" + elif node_type == 'softmax': + return f"softmax_{config.get('dim', 1)}" + elif node_type == 'attention': + return f"attention_{config.get('embed_dim', 512)}_{config.get('num_heads', 8)}_{config.get('dropout', 0.0)}" + elif node_type == 'custom': + return f"custom_{config.get('name', 'CustomLayer')}" + else: + # For layers without config (relu, flatten, etc.) + return node_type def generate_model_file( nodes: List[Dict], edges: List[Dict], project_name: str, - shape_map: Dict[str, Dict[str, Any]] + shape_map: Dict[str, Dict[str, Any]], + block_generator: Optional[PyTorchBlockGenerator] = None, + shape_errors: Optional[List[Exception]] = None ) -> str: - """Generate complete model.py file with layer classes and main model class""" + """ + Generate complete model.py file with layer classes and main model class. + + Args: + nodes: List of node dictionaries + edges: List of edge dictionaries + project_name: Name for the generated model class + shape_map: Dictionary mapping node_id to shape info + block_generator: Optional block generator for group blocks + shape_errors: Optional list of errors encountered during shape inference + + Returns: + String containing the complete model.py file + """ + if shape_errors is None: + shape_errors = [] class_name = to_class_name(project_name) - # Generate individual layer classes + # Generate block class definitions FIRST (if any) - this populates the cache + block_classes_code = "" + if block_generator: + block_classes_code = block_generator.generate_all_block_classes() + + # COLLECT ALL NODES (main + internal from groups) and generate layer classes + all_nodes_to_generate = collect_all_nodes_with_internals(nodes, block_generator) + + # DEDUPLICATE by signature and generate layer classes + seen_signatures = set() layer_classes = [] + + for node, idx, source_context in all_nodes_to_generate: + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + node_id = node['id'] + + # Get shape info (use shape_map for main nodes, extract for internal) + if source_context == 'main': + shape_info = shape_map.get(node_id, {}) + else: + shape_info = extract_shape_info_from_node(node) + + # Generate signature for deduplication + signature = get_layer_signature(node, config, node_type) + + # Only generate if we haven't seen this signature before + if signature not in seen_signatures: + seen_signatures.add(signature) + layer_class_code = generate_layer_class(node, idx, config, node_type, shape_info) + if layer_class_code: + layer_classes.append(layer_class_code) + + # Now generate layer instantiations and forward pass for MAIN MODEL ONLY layer_instantiations = [] forward_pass_lines = [] @@ -352,15 +2215,97 @@ def generate_model_file( var_map[node_id] = 'x' if not var_map else 'x' continue - # Generate layer class - layer_class_code = generate_layer_class(node, idx, config, node_type, shape_info) - if layer_class_code: - layer_classes.append(layer_class_code) + # Handle group blocks differently + if node_type == 'group': + # Get the group definition ID + group_def_id = node.get('data', {}).get('groupDefinitionId') + + if block_generator and group_def_id: + # Use the block class name from the generator + block_class_name = block_generator.get_block_class_name(group_def_id) + + if block_class_name: + layer_name = f"block_{node_id.replace('-', '_')}" + + # Get upstream node's output shape from shape_map + incoming = edge_map.get(node_id, []) + params = [] + + if incoming and incoming[0] in shape_map: + # Get upstream node's output shape + upstream_shape = shape_map[incoming[0]] + + # Extract in_channels or in_features from upstream shape + # Pass in_channels if the upstream outputs channels (convolutional layers) + if 'out_channels' in upstream_shape: + in_channels = upstream_shape['out_channels'] + params.append(f"in_channels={in_channels}") + logger.debug(f"Block {node_id}: passing in_channels={in_channels} from upstream node {incoming[0]}") + + # Pass in_features if the upstream outputs features (linear layers) + elif 'out_features' in upstream_shape: + in_features = upstream_shape['out_features'] + params.append(f"in_features={in_features}") + logger.debug(f"Block {node_id}: passing in_features={in_features} from upstream node {incoming[0]}") + + # Pass num_features if the upstream outputs num_features (batch norm) + elif 'num_features' in upstream_shape: + num_features = upstream_shape['num_features'] + params.append(f"num_features={num_features}") + logger.debug(f"Block {node_id}: passing num_features={num_features} from upstream node {incoming[0]}") + else: + # Upstream shape exists but doesn't have expected keys + logger.warning(f"Block {node_id}: upstream shape {upstream_shape} doesn't contain expected keys") + else: + # Handle case where no upstream exists (use input node shape) + # Look for input nodes in the graph + input_nodes = [n for n in nodes if get_node_type(n) == 'input'] + if input_nodes and input_nodes[0]['id'] in shape_map: + input_shape = shape_map[input_nodes[0]['id']] + + # Use input node's output shape + if 'out_channels' in input_shape: + in_channels = input_shape['out_channels'] + params.append(f"in_channels={in_channels}") + logger.debug(f"Block {node_id}: no upstream, using input shape in_channels={in_channels}") + elif 'out_features' in input_shape: + in_features = input_shape['out_features'] + params.append(f"in_features={in_features}") + logger.debug(f"Block {node_id}: no upstream, using input shape in_features={in_features}") + else: + logger.warning(f"Block {node_id}: input shape {input_shape} doesn't contain expected keys") + else: + # No upstream and no input node, use defaults + logger.warning(f"Block {node_id}: no upstream connection and no input node found") + + # Generate instantiation with computed parameters + # Each instance gets independent shape computation based on its position in the graph + if params: + layer_instantiations.append(f"self.{layer_name} = {block_class_name}({', '.join(params)}) # Instance at position {idx}") + else: + layer_instantiations.append(f"self.{layer_name} = {block_class_name}() # Instance at position {idx}") + + # Generate forward pass line + input_var = get_input_variable(incoming, var_map) + output_var = 'x' + forward_pass_lines.append(f"{output_var} = self.{layer_name}({input_var})") + var_map[node_id] = output_var + else: + # Block class not found, skip + logger.warning(f"Block class not found for group definition {group_def_id}") + var_map[node_id] = 'x' + else: + # No block generator or definition ID, skip + logger.warning(f"No block generator or definition ID for node {node_id}") + var_map[node_id] = 'x' + continue + + # For regular nodes, we already generated the layer class above (no need to generate again) # Generate layer instantiation for __init__ layer_name = get_layer_variable_name(node_type, idx, config) layer_class_name = get_layer_class_name(node_type, idx, config) - layer_init = generate_layer_instantiation(layer_class_name, layer_name, shape_info) + layer_init = generate_layer_instantiation(layer_class_name, layer_name, shape_info, node_type) if layer_init: layer_instantiations.append(layer_init) @@ -390,6 +2335,10 @@ def generate_model_file( ''' + # Add block class definitions (already generated at the start) + if block_classes_code: + code += block_classes_code + '\n\n' + # Add all layer class definitions for layer_class in layer_classes: code += layer_class + '\n\n' @@ -472,6 +2421,12 @@ def generate_layer_class( ) -> Optional[str]: """Generate a complete layer class definition with documentation""" + # Special node types that don't generate individual layer classes: + # - input/output/dataloader: Architectural markers for graph structure + # - group: Reusable components generated separately by BlockGenerator + if node_type in ('input', 'output', 'dataloader', 'group'): + return None + class_name = get_layer_class_name(node_type, idx, config) if node_type == 'conv2d': @@ -906,26 +2861,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Replace this with your custom logic return x''' - return None + # If we reach here, the node type is not supported + raise UnsupportedNodeTypeError( + node_id=node.get('id', 'unknown'), + node_type=node_type, + framework='PyTorch' + ) def generate_layer_instantiation( class_name: str, layer_name: str, - shape_info: Dict[str, Any] + shape_info: Dict[str, Any], + node_type: str = None ) -> str: - """Generate layer instantiation line for __init__ method""" - # Determine if layer needs arguments - if 'in_channels' in shape_info: + """ + Generate layer instantiation line for __init__ method. + + Only certain layer types need shape parameters: + - Conv2d: needs in_channels + - Linear: needs in_features + - BatchNorm: needs num_features + + Other layers (Dropout, ReLU, Flatten, etc.) are instantiated with no parameters + or only their specific configuration parameters (handled in layer class __init__). + """ + # Only add shape parameters for layers that actually need them + if node_type == 'conv2d' and 'in_channels' in shape_info: in_ch = shape_info['in_channels'] return f"self.{layer_name} = {class_name}(in_channels={in_ch}) # Input: {in_ch} channels" - elif 'in_features' in shape_info: + elif node_type == 'linear' and 'in_features' in shape_info: in_feat = shape_info['in_features'] return f"self.{layer_name} = {class_name}(in_features={in_feat}) # Input: {in_feat} features" - elif 'num_features' in shape_info: + elif node_type in ('batchnorm', 'batchnorm2d') and 'num_features' in shape_info: num_feat = shape_info['num_features'] return f"self.{layer_name} = {class_name}(num_features={num_feat}) # {num_feat} features" else: + # For all other layers (Dropout, ReLU, Flatten, MaxPool, etc.): + # Instantiate with no parameters - their config is baked into the class definition return f"self.{layer_name} = {class_name}()" diff --git a/project/block_manager/services/tensorflow_codegen.py b/project/block_manager/services/tensorflow_codegen.py index 795db28..70e2bca 100644 --- a/project/block_manager/services/tensorflow_codegen.py +++ b/project/block_manager/services/tensorflow_codegen.py @@ -5,13 +5,469 @@ from typing import List, Dict, Any, Optional, Tuple from collections import deque +import logging + +# Import shared utilities and exceptions from PyTorch codegen (framework-agnostic) +from .pytorch_codegen import ( + GroupBlockShapeComputer, + GroupDefinitionNotFoundError, + ShapeMismatchError, + CyclicDependencyError, + UnsupportedNodeTypeError, + ShapeInferenceError, + MissingShapeDataError, + safe_get_shape_data +) + +# Configure logging +logger = logging.getLogger(__name__) + + +class TensorFlowBlockGenerator: + """ + Generator for TensorFlow/Keras tf.keras.Model code for group blocks. + + Converts GroupBlockDefinition into reusable tf.keras.Model subclasses + with proper initialization and call method logic. + """ + + def __init__( + self, + group_definitions: List[Dict[str, Any]], + shape_computer: Optional[GroupBlockShapeComputer] = None + ): + """ + Initialize the block generator. + + Args: + group_definitions: List of GroupBlockDefinition dictionaries + shape_computer: Optional shape computer for internal shape inference + """ + self.group_definitions = {defn['id']: defn for defn in group_definitions} + self.generated_classes = {} # Cache generated class code + self.shape_computer = shape_computer or GroupBlockShapeComputer(self.group_definitions) + + def generate_all_block_classes(self) -> str: + """ + Generate all block class definitions. + + Returns: + String containing all block class definitions + """ + if not self.group_definitions: + return "" + + code_parts = [] + code_parts.append("# ============================================") + code_parts.append("# Custom Block Definitions") + code_parts.append("# ============================================\n") + + for defn_id, definition in self.group_definitions.items(): + block_class = self.generate_block_class(definition) + code_parts.append(block_class) + code_parts.append("\n") + + return "\n".join(code_parts) + + def generate_block_class( + self, + definition: Dict[str, Any], + example_input_shape: Optional[Dict[str, Any]] = None + ) -> str: + """ + Generate tf.keras.Model subclass for a single block definition. + + Args: + definition: GroupBlockDefinition dictionary + example_input_shape: Optional example input shape for computing internal shapes + + Returns: + String containing the complete block class definition + """ + block_name = definition['name'] + class_name = self._to_class_name(block_name) + description = definition.get('description', '') + + # Get internal structure + internal_structure = definition.get('internal_structure', {}) + internal_nodes = internal_structure.get('nodes', []) + internal_edges = internal_structure.get('edges', []) + port_mappings = internal_structure.get('portMappings', []) + + # Sort internal nodes topologically + sorted_nodes = topological_sort(internal_nodes, internal_edges) + + # Compute internal shapes if example provided + internal_shape_map = {} + if example_input_shape: + internal_shape_map, _ = self.shape_computer.compute_internal_shapes( + internal_nodes, + internal_edges, + port_mappings, + example_input_shape, + block_name + ) + else: + # Fallback to old behavior without shape computer + internal_shape_map, _ = infer_shapes(sorted_nodes, internal_edges) + + # Generate __init__ method + init_method = self._generate_init_method(sorted_nodes, internal_shape_map, port_mappings) + + # Generate call method + call_method = self._generate_call_method( + sorted_nodes, internal_edges, internal_shape_map, port_mappings + ) + + # Build class docstring + docstring = self._generate_block_docstring( + block_name, description, port_mappings, sorted_nodes + ) + + # Assemble the complete class + class_code = f'''class {class_name}(keras.Model): + """{docstring}""" + +{init_method} + +{call_method}''' + + # Cache the generated class + self.generated_classes[definition['id']] = class_name + + return class_code + + def _generate_init_method( + self, + nodes: List[Dict[str, Any]], + shape_map: Dict[str, Dict[str, Any]], + port_mappings: List[Dict[str, Any]] + ) -> str: + """Generate __init__ method with layer instantiation.""" + lines = [] + + # Detect which shape parameters are needed by scanning nodes + needs_in_channels = False + needs_in_features = False + needs_num_features = False + + for node in nodes: + node_type = get_node_type(node) + if node_type in ('input', 'dataloader', 'output'): + continue + if node_type == 'conv2d': + needs_in_channels = True + elif node_type == 'linear': + needs_in_features = True + elif node_type in ('batchnorm', 'batchnorm2d'): + needs_num_features = True + + # Generate __init__ signature with detected parameters + params = [] + if needs_in_channels: + params.append("in_channels=None") + if needs_in_features: + params.append("in_features=None") + if needs_num_features: + params.append("num_features=None") + + if params: + lines.append(f" def __init__(self, {', '.join(params)}):") + else: + lines.append(" def __init__(self):") + + lines.append(' """Initialize all internal layers."""') + lines.append(f" super().__init__()") + lines.append("") + + # Track which nodes need to be instantiated and which is first of each type + layer_count = {} + first_layer_of_type = {} + + for idx, node in enumerate(nodes): + node_id = node['id'] + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + shape_info = shape_map.get(node_id, {}) + + # Skip input/output nodes + if node_type in ('input', 'dataloader', 'output'): + continue + + # Track if this is the first layer of its type + is_first = node_type not in first_layer_of_type + if is_first: + first_layer_of_type[node_type] = node_id + + # Generate layer instantiation + layer_name = self._get_internal_layer_name(node_type, node_id, layer_count) + layer_class_name = self._get_layer_class_name_for_node(node_type, config) + + # Generate instantiation with proper arguments + instantiation = self._generate_layer_instantiation_line( + layer_name, layer_class_name, node_type, shape_info, config, is_first + ) + + if instantiation: + lines.append(f" {instantiation}") + + return "\n".join(lines) + + def _generate_call_method( + self, + nodes: List[Dict[str, Any]], + edges: List[Dict[str, Any]], + shape_map: Dict[str, Dict[str, Any]], + port_mappings: List[Dict[str, Any]] + ) -> str: + """Generate call method with internal connection logic.""" + lines = [] + + # Determine input parameters from port mappings + input_ports = [pm for pm in port_mappings if pm['type'] == 'input'] + output_ports = [pm for pm in port_mappings if pm['type'] == 'output'] + + # Generate method signature + if len(input_ports) == 1: + lines.append(" def call(self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:") + else: + param_names = [f"input_{i}" for i in range(len(input_ports))] + params = ", ".join([f"{name}: tf.Tensor" for name in param_names]) + lines.append(f" def call(self, {params}, training: Optional[bool] = None) -> tf.Tensor:") + + lines.append(' """') + lines.append(' Forward pass through the block.') + lines.append('') + lines.append(' Args:') + if len(input_ports) == 1: + lines.append(' inputs: Input tensor in NHWC format') + else: + for i, port in enumerate(input_ports): + label = port.get('externalPortLabel', f'input_{i}') + lines.append(f' input_{i}: {label}') + lines.append(' training: Whether in training mode') + lines.append('') + lines.append(' Returns:') + if len(output_ports) == 1: + lines.append(' Output tensor') + else: + lines.append(' Tuple of output tensors') + lines.append(' """') + + # Build edge map for finding inputs + edge_map = {} + for edge in edges: + target = edge.get('target') + source = edge.get('source') + if target not in edge_map: + edge_map[target] = [] + edge_map[target].append(source) + + # Map internal node IDs to variable names + var_map = {} + layer_count = {} + + # Map input ports to initial variables + for i, port in enumerate(input_ports): + internal_node_id = port['internalNodeId'] + if len(input_ports) == 1: + var_map[internal_node_id] = 'inputs' + else: + var_map[internal_node_id] = f'input_{i}' + + # Generate forward pass for each internal node + for node in nodes: + node_id = node['id'] + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + + # Skip input/output nodes + if node_type in ('input', 'dataloader', 'output'): + # Input nodes are already mapped + if node_id not in var_map: + var_map[node_id] = 'inputs' + continue + + # Get layer name + layer_name = self._get_internal_layer_name(node_type, node_id, layer_count) + + # Get input variable(s) + incoming = edge_map.get(node_id, []) + if not incoming: + # No incoming edges, might be an input node we missed + input_var = 'inputs' + elif len(incoming) == 1: + input_var = var_map.get(incoming[0], 'inputs') + else: + # Multiple inputs (for concat, add, etc.) + input_vars = [var_map.get(src, 'inputs') for src in incoming] + input_var = f"[{', '.join(input_vars)}]" + + # Generate output variable name (sanitize node_id to avoid hyphens) + output_var = f"x_{node_id[:8].replace('-', '_')}" + var_map[node_id] = output_var + + # Generate forward line with training parameter for layers that need it + if node_type in ('dropout', 'batchnorm', 'batchnorm2d'): + lines.append(f" {output_var} = self.{layer_name}({input_var}, training=training)") + else: + lines.append(f" {output_var} = self.{layer_name}({input_var})") + + # Map output ports to return values + if len(output_ports) == 1: + output_node_id = output_ports[0]['internalNodeId'] + output_var = var_map.get(output_node_id, 'inputs') + lines.append(f" return {output_var}") + else: + output_vars = [] + for port in output_ports: + output_node_id = port['internalNodeId'] + output_vars.append(var_map.get(output_node_id, 'inputs')) + lines.append(f" return ({', '.join(output_vars)})") + + return "\n".join(lines) + + def _generate_block_docstring( + self, + block_name: str, + description: str, + port_mappings: List[Dict[str, Any]], + nodes: List[Dict[str, Any]] + ) -> str: + """Generate comprehensive docstring for block class.""" + lines = [] + lines.append(f"Custom Block: {block_name}") + lines.append("") + + if description: + lines.append(description) + lines.append("") + + lines.append("This block encapsulates a reusable subgraph of layers.") + lines.append("") + lines.append("Note: TensorFlow uses NHWC format (batch, height, width, channels)") + lines.append("") + + # Document ports + input_ports = [pm for pm in port_mappings if pm['type'] == 'input'] + output_ports = [pm for pm in port_mappings if pm['type'] == 'output'] + + if input_ports: + lines.append("Input Ports:") + for port in input_ports: + label = port.get('externalPortLabel', 'input') + lines.append(f" - {label}") + + if output_ports: + lines.append("") + lines.append("Output Ports:") + for port in output_ports: + label = port.get('externalPortLabel', 'output') + lines.append(f" - {label}") + + lines.append("") + lines.append(f"Internal Layers: {len([n for n in nodes if get_node_type(n) not in ('input', 'dataloader', 'output')])}") + + return "\n ".join(lines) + + def _generate_layer_instantiation_line( + self, + layer_name: str, + layer_class_name: str, + node_type: str, + shape_info: Dict[str, Any], + config: Dict[str, Any], + is_first: bool = False + ) -> str: + """ + Generate layer instantiation line for TensorFlow/Keras layers. + + TensorFlow/Keras layer classes have all configuration baked into their + class definitions, so __init__ methods take no parameters. This differs + from PyTorch where layers need input dimensions in the constructor. + """ + # TensorFlow layers don't need input shape parameters in constructor + # All configuration is already baked into the layer class definition + # Just instantiate with no arguments + return f"self.{layer_name} = {layer_class_name}()" + + def _get_internal_layer_name( + self, + node_type: str, + node_id: str, + layer_count: Dict[str, int] + ) -> str: + """Generate unique layer variable name for internal node.""" + # Use node_id suffix for uniqueness (sanitize to avoid hyphens) + suffix = node_id[:8].replace('-', '_') + base_name = node_type.replace('_', '') + + # Track count for this type + if node_type not in layer_count: + layer_count[node_type] = 0 + layer_count[node_type] += 1 + + return f"{base_name}_{suffix}" + + def _get_layer_class_name_for_node( + self, + node_type: str, + config: Dict[str, Any] + ) -> str: + """Get the layer class name that will be used in the main model.""" + # These should match the class names generated by generate_layer_class + type_name = node_type.replace('_', '').replace('2d', '2D').replace('3d', '3D').title() + + if node_type == 'conv2d': + filters = config.get('filters', 64) + kernel = config.get('kernel_size', 3) + return f"{type_name}Layer_{filters}filters_{kernel}x{kernel}" + elif node_type == 'linear': + units = config.get('units', 128) + return f"DenseLayer_{units}units" + elif node_type in ('maxpool2d', 'maxpool'): + pool_size = config.get('pool_size', 2) + return f"MaxPool2DLayer_{pool_size}x{pool_size}" + elif node_type == 'custom': + name = config.get('name', 'CustomLayer') + safe_name = name.replace(' ', '_').replace('-', '_') + return f"CustomLayer_{safe_name}" + else: + # For other types, we'll need to generate a generic name + # This will be handled by the main code generation + return f"{type_name}Layer" + + def _to_class_name(self, name: str) -> str: + """Convert block name to valid Python class name.""" + import re + # Remove special characters and convert to PascalCase + name = re.sub(r'[^a-zA-Z0-9]', ' ', name) + name = ''.join(word.capitalize() for word in name.split()) + if not name: + return 'CustomBlock' + if name[0].isdigit(): + name = 'Block' + name + return name + 'Block' + + def get_block_class_name(self, definition_id: str) -> Optional[str]: + """ + Get the generated class name for a block definition. + + Args: + definition_id: ID of the GroupBlockDefinition + + Returns: + Class name if generated, None otherwise + """ + return self.generated_classes.get(definition_id) def generate_tensorflow_code( nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]], - project_name: str = "GeneratedModel" -) -> Dict[str, str]: + project_name: str = "GeneratedModel", + group_definitions: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Dict[str, str], List[Exception]]: """ Generate complete TensorFlow/Keras code including model, training, and data loading. Each layer gets its own reusable class, all combined in a main model class. @@ -20,28 +476,48 @@ def generate_tensorflow_code( nodes: List of node dictionaries from architecture edges: List of edge dictionaries defining connections project_name: Name for the generated model class + group_definitions: Optional list of GroupBlockDefinition dictionaries Returns: - Dictionary with keys: 'model', 'train', 'dataset', 'config' + Tuple of (dictionary with keys: 'model', 'train', 'dataset', 'config', list of errors) """ # Topologically sort nodes sorted_nodes = topological_sort(nodes, edges) - # Infer shapes through the graph - shape_map = infer_shapes(sorted_nodes, edges) + # Convert group_definitions list to dict for shape inference + group_defs_dict = None + if group_definitions: + group_defs_dict = {defn['id']: defn for defn in group_definitions} + + # Infer shapes through the graph with group block support + shape_map, shape_errors = infer_shapes(sorted_nodes, edges, group_defs_dict) + + # Validate computed shapes for critical issues + validation_errors = validate_shape_map(sorted_nodes, shape_map) + if validation_errors: + logger.warning(f"Shape validation found {len(validation_errors)} potential issues") + shape_errors.extend(validation_errors) + + # Initialize block generator if we have group definitions + block_generator = None + if group_definitions: + # Create shape computer for block generator + shape_computer = GroupBlockShapeComputer(group_defs_dict) if group_defs_dict else None + block_generator = TensorFlowBlockGenerator(group_definitions, shape_computer) # Generate different components - model_code = generate_model_file(sorted_nodes, edges, project_name, shape_map) + model_code = generate_model_file(sorted_nodes, edges, project_name, shape_map, block_generator) train_code = generate_training_script(project_name) dataset_code = generate_dataset_class(nodes) config_code = generate_config_file(nodes) + # Return generated code with any shape inference errors return { 'model': model_code, 'train': train_code, 'dataset': dataset_code, 'config': config_code - } + }, shape_errors def generate_single_layer_class( @@ -192,15 +668,72 @@ def topological_sort(nodes: List[Dict], edges: List[Dict]) -> List[Dict]: return [node_map[node_id] for node_id in sorted_ids if node_id in node_map] -def infer_shapes(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Dict[str, Any]]: +def extract_output_shape_from_metadata(node: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Extract output shape from node's frontend-provided metadata (TensorFlow/NHWC version). + + The frontend computes output shapes accurately during the visual design phase + and stores them in node.data.outputShape. This function extracts those + pre-computed shapes, which are considered authoritative. + + Args: + node: Node dictionary with potential data.outputShape metadata + + Returns: + Dictionary with shape keys (out_channels, out_features, etc.) or None if + metadata is incomplete/missing + """ + output_shape = node.get('data', {}).get('outputShape', {}) + if not output_shape or not isinstance(output_shape, dict): + return None + + dims = output_shape.get('dims', []) + if not dims: + return None + + shape_info = {} + + # TensorFlow uses NHWC format: [batch, height, width, channels] + # Note: This is different from PyTorch's NCHW format! + if len(dims) == 4: + shape_info['out_height'] = dims[1] + shape_info['out_width'] = dims[2] + shape_info['out_channels'] = dims[3] + elif len(dims) == 2: # [batch, features] - for Dense/Flatten output + shape_info['out_features'] = dims[1] + else: + # Unusual shape format - log for debugging but don't fail + logger.debug(f"Unusual output shape dims: {dims}") + return None + + return shape_info + + +def infer_shapes( + nodes: List[Dict], + edges: List[Dict], + group_definitions: Optional[Dict[str, Any]] = None +) -> Tuple[Dict[str, Dict[str, Any]], List[Exception]]: """ Infer input/output shapes for each layer in the graph. TensorFlow uses NHWC format (batch, height, width, channels). + Enhanced to handle group blocks properly. + + Args: + nodes: List of node dictionaries + edges: List of edge dictionaries + group_definitions: Optional map of group definition IDs to definitions Returns: - Dictionary mapping node_id to shape info: {'in_channels', 'out_channels', 'in_units', 'out_units', etc.} + Tuple of (dictionary mapping node_id to shape info, list of errors) """ shape_map = {} + errors = [] + + # Initialize shape computer for group blocks + shape_computer = None + if group_definitions: + shape_computer = GroupBlockShapeComputer(group_definitions) # Build edge map for finding inputs edge_map = {} @@ -220,119 +753,456 @@ def infer_shapes(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Dict[str, An # Get incoming edges incoming = edge_map.get(node_id, []) - # Initialize shape info for this node - shape_info = {} + # ========== PHASE 1: Extract output metadata (if available) ========== + # Frontend provides accurate output shapes in metadata + metadata_shape = extract_output_shape_from_metadata(node) + shape_info = metadata_shape if metadata_shape else {} + + # ========== PHASE 2: Compute input dimensions from upstream nodes ========== + # Input dimensions ALWAYS come from upstream, regardless of metadata + # This is critical for layers like Conv2D, Dense, BatchNorm if node_type == 'input': - # Parse input shape (NHWC format) - shape_str = config.get('shape', '[1, 224, 224, 3]') - try: - import json - shape = json.loads(shape_str) - if len(shape) >= 4: - shape_info['out_height'] = shape[1] # NHWC format - shape_info['out_width'] = shape[2] - shape_info['out_channels'] = shape[3] - elif len(shape) >= 2: - shape_info['out_units'] = shape[1] - except: - shape_info['out_height'] = 224 - shape_info['out_width'] = 224 - shape_info['out_channels'] = 3 + # Input nodes have no upstream - parse from config if metadata doesn't exist + if not metadata_shape: + shape_str = config.get('shape', '[1, 224, 224, 3]') + try: + import json + shape = json.loads(shape_str) + if len(shape) >= 4: + shape_info['out_height'] = shape[1] # NHWC format + shape_info['out_width'] = shape[2] + shape_info['out_channels'] = shape[3] + elif len(shape) >= 2: + shape_info['out_units'] = shape[1] + except (json.JSONDecodeError, ValueError, KeyError, IndexError, TypeError) as e: + logger.warning( + f"Failed to parse input shape for node {node_id}: {e}. " + f"Using default shape [1, 224, 224, 3] (NHWC)" + ) + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Failed to parse shape configuration: {str(e)}", + suggestion="Check that the input shape is a valid JSON array like [1, 224, 224, 3]" + )) + shape_info['out_height'] = 224 + shape_info['out_width'] = 224 + shape_info['out_channels'] = 3 elif node_type == 'conv2d': - # Get input channels from previous layer + # Get input channels from upstream layer (ALWAYS required) if incoming and incoming[0] in shape_map: - shape_info['in_channels'] = shape_map[incoming[0]].get('out_channels', 3) + try: + upstream_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels'], + default_values={'out_channels': 3} + ) + shape_info['in_channels'] = upstream_shape['out_channels'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for node {node_id}: {e}. Using default.") + errors.append(e) + shape_info['in_channels'] = 3 else: shape_info['in_channels'] = 3 - # Output channels (filters) from config - shape_info['out_channels'] = config.get('filters', 64) - - # Calculate output spatial dimensions - if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - kernel_size = config.get('kernel_size', 3) - strides = config.get('strides', 1) - padding = config.get('padding', 'valid') - - if 'out_height' in prev_shape and 'out_width' in prev_shape: - if padding == 'same': - # Same padding preserves dimensions (with stride) - shape_info['out_height'] = (prev_shape['out_height'] + strides - 1) // strides - shape_info['out_width'] = (prev_shape['out_width'] + strides - 1) // strides - else: # valid padding - shape_info['out_height'] = (prev_shape['out_height'] - kernel_size) // strides + 1 - shape_info['out_width'] = (prev_shape['out_width'] - kernel_size) // strides + 1 + # Output channels: use metadata if available, otherwise config + if 'out_channels' not in shape_info: + shape_info['out_channels'] = config.get('filters', 64) + + # Spatial dimensions: use metadata if available, otherwise calculate + if 'out_height' not in shape_info or 'out_width' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_height', 'out_width'], + default_values=None + ) + kernel_size = config.get('kernel_size', 3) + strides = config.get('strides', 1) + padding = config.get('padding', 'valid') + + if padding == 'same': + # Same padding preserves dimensions (with stride) + shape_info['out_height'] = (prev_shape['out_height'] + strides - 1) // strides + shape_info['out_width'] = (prev_shape['out_width'] + strides - 1) // strides + else: # valid padding + shape_info['out_height'] = (prev_shape['out_height'] - kernel_size) // strides + 1 + shape_info['out_width'] = (prev_shape['out_width'] - kernel_size) // strides + 1 + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Could not compute spatial dimensions for conv2d {node_id}: {e}") + errors.append(e) elif node_type in ('maxpool2d', 'maxpool'): - # Preserve channels, reduce spatial dimensions + # MaxPool preserves channels from upstream if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - shape_info['in_channels'] = prev_shape.get('out_channels', 64) - shape_info['out_channels'] = shape_info['in_channels'] - - pool_size = config.get('pool_size', 2) - strides = config.get('strides', 2) - padding = config.get('padding', 'valid') - - if 'out_height' in prev_shape and 'out_width' in prev_shape: - if padding == 'same': - shape_info['out_height'] = (prev_shape['out_height'] + strides - 1) // strides - shape_info['out_width'] = (prev_shape['out_width'] + strides - 1) // strides - else: # valid padding - shape_info['out_height'] = (prev_shape['out_height'] - pool_size) // strides + 1 - shape_info['out_width'] = (prev_shape['out_width'] - pool_size) // strides + 1 + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels'], + default_values={'out_channels': 64} + ) + shape_info['out_channels'] = prev_shape['out_channels'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for maxpool {node_id}: {e}") + errors.append(e) + shape_info['out_channels'] = 64 + else: + shape_info['out_channels'] = 64 + + # Spatial dimensions: use metadata if available, otherwise calculate + if 'out_height' not in shape_info or 'out_width' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_height', 'out_width'], + default_values={'out_height': 7, 'out_width': 7} + ) + pool_size = config.get('pool_size', 2) + strides = config.get('strides', 2) + padding = config.get('padding', 'valid') + + if padding == 'same': + shape_info['out_height'] = (prev_shape['out_height'] + strides - 1) // strides + shape_info['out_width'] = (prev_shape['out_width'] + strides - 1) // strides + else: # valid padding + shape_info['out_height'] = (prev_shape['out_height'] - pool_size) // strides + 1 + shape_info['out_width'] = (prev_shape['out_width'] - pool_size) // strides + 1 + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Could not compute spatial dimensions for maxpool {node_id}: {e}") + errors.append(e) elif node_type == 'flatten': - # Convert spatial dimensions to units - if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - channels = prev_shape.get('out_channels', 64) - height = prev_shape.get('out_height', 7) - width = prev_shape.get('out_width', 7) - shape_info['out_units'] = channels * height * width + # Flatten converts spatial dimensions to units + # Use metadata if available, otherwise calculate from upstream + if 'out_units' not in shape_info: + if incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_channels', 'out_height', 'out_width'], + default_values={'out_channels': 64, 'out_height': 7, 'out_width': 7} + ) + channels = prev_shape['out_channels'] + height = prev_shape['out_height'] + width = prev_shape['out_width'] + shape_info['out_units'] = channels * height * width + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for flatten {node_id}: {e}") + errors.append(e) + shape_info['out_units'] = 3136 # 64 * 7 * 7 + else: + shape_info['out_units'] = 3136 # Default elif node_type == 'linear': - # Get input units from previous layer + # Get input units from upstream layer (ALWAYS required) if incoming and incoming[0] in shape_map: - shape_info['in_units'] = shape_map[incoming[0]].get('out_units', 512) + try: + upstream_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=['out_units'], + default_values={'out_units': 512} + ) + shape_info['in_units'] = upstream_shape['out_units'] + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for linear {node_id}: {e}") + errors.append(e) + shape_info['in_units'] = 512 else: shape_info['in_units'] = 512 - # Output units from config - shape_info['out_units'] = config.get('units', 128) + # Output units: use metadata if available, otherwise config + if 'out_units' not in shape_info: + shape_info['out_units'] = config.get('units', 128) elif node_type in ('batchnorm', 'batchnorm2d'): - # Preserve dimensions - if incoming and incoming[0] in shape_map: - prev_shape = shape_map[incoming[0]] - shape_info.update(prev_shape) + # BatchNorm preserves all dimensions from upstream + # Only copy upstream if metadata doesn't provide them + if not metadata_shape and incoming and incoming[0] in shape_map: + try: + prev_shape = safe_get_shape_data( + shape_map=shape_map, + node_id=node_id, + upstream_node_id=incoming[0], + required_keys=[], # Accept whatever keys exist + default_values={} + ) + shape_info.update(prev_shape) + except (MissingShapeDataError, ShapeInferenceError) as e: + logger.warning(f"Shape inference warning for batchnorm {node_id}: {e}") + errors.append(e) + + elif node_type == 'group': + # Group blocks: Use metadata if available, otherwise compute from internal structure + if not metadata_shape: + # No metadata - compute output shape using shape computer + if shape_computer: + group_def_id = node.get('data', {}).get('groupDefinitionId') + + if group_def_id and incoming and incoming[0] in shape_map: + # Get input shape from upstream node + input_shape = shape_map[incoming[0]] + + # Compute output shape using internal structure + logger.debug(f"Computing shape for group block {node_id} (def: {group_def_id})") + output_shape, shape_errors = shape_computer.compute_output_shape( + group_def_id, + input_shape + ) + + # Collect any errors from shape computation + errors.extend(shape_errors) + + if output_shape: + shape_info = output_shape + logger.debug(f"Group block {node_id} output shape: {output_shape}") + else: + # Fallback: copy input shape + shape_info = input_shape.copy() + logger.warning(f"Failed to compute shape for group block {node_id}, using input shape") + elif incoming and incoming[0] in shape_map: + # No definition found, copy input shape + shape_info = shape_map[incoming[0]].copy() + logger.warning(f"Group block {node_id} has no definition ID, using input shape") + else: + # No input, use default + shape_info = {'out_channels': 3, 'out_height': 224, 'out_width': 224} + logger.warning(f"Group block {node_id} has no incoming edges, using default shape") + else: + # No shape computer available, fall back to old behavior + if incoming and incoming[0] in shape_map: + prev_shape = shape_map[incoming[0]] + # Copy input shape as default + shape_info.update(prev_shape) + else: + # Default starting shape + shape_info['out_channels'] = 3 + shape_info['out_height'] = 224 + shape_info['out_width'] = 224 else: - # For other layers, try to preserve shape from input - if incoming and incoming[0] in shape_map: + # For other layers: Use metadata if available, otherwise preserve upstream shape + if not metadata_shape and incoming and incoming[0] in shape_map: prev_shape = shape_map[incoming[0]] shape_info.update(prev_shape) - + shape_map[node_id] = shape_info - return shape_map + return shape_map, errors + + +def validate_shape_map( + nodes: List[Dict], + shape_map: Dict[str, Dict[str, Any]] +) -> List[Exception]: + """ + Validate computed shape map for common critical issues (TensorFlow version). + + This catches problems that would cause runtime errors in generated code: + - Missing shape information + - Invalid dimensions (zero or negative) + - Type-specific requirements not met + + Args: + nodes: List of all nodes + shape_map: Computed shape mapping + + Returns: + List of validation errors (as exceptions for consistency with shape_errors) + """ + errors = [] + + for node in nodes: + node_id = node['id'] + node_type = get_node_type(node) + + # Skip non-layer nodes + if node_type in ('input', 'output', 'dataloader', 'group'): + continue + + shape_info = shape_map.get(node_id) + + # Critical: Shape info must exist + if not shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="No shape information computed for node", + suggestion="Check that node has valid upstream connections and metadata" + )) + continue + + # Type-specific validation + if node_type == 'linear' or node_type == 'dense': + # Linear/Dense MUST have in_features or in_units + if 'in_features' not in shape_info and 'in_units' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Missing required in_features/in_units for Linear/Dense layer", + suggestion="Check upstream Flatten or Linear layer output shape" + )) + # in_features/in_units must be positive + in_val = shape_info.get('in_features') or shape_info.get('in_units', 0) + if in_val <= 0: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Invalid in_features/in_units={in_val} (must be > 0)", + suggestion="Check upstream layer produces valid output shape" + )) + + elif node_type == 'conv2d': + # Conv2d MUST have in_channels + if 'in_channels' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Missing required in_channels for Conv2d layer", + suggestion="Check upstream Conv2d or Input layer provides channels" + )) + + elif node_type == 'flatten': + # Flatten MUST produce out_features + if 'out_features' not in shape_info: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason="Flatten layer must produce out_features", + suggestion="Check upstream layer has spatial dimensions (NHWC format)" + )) + elif shape_info.get('out_features', 0) <= 0: + errors.append(ShapeInferenceError( + node_id=node_id, + node_type=node_type, + reason=f"Invalid out_features={shape_info.get('out_features')} (must be > 0)", + suggestion="Check upstream layer output dimensions are valid" + )) + + return errors + + +def collect_all_nodes_with_internals( + main_nodes: List[Dict], + block_generator: Optional[TensorFlowBlockGenerator] = None +) -> List[Tuple[Dict, int, str]]: + """ + Collect all nodes including internal nodes from group blocks. + Returns list of tuples: (node, index, source_context) + source_context is either 'main' or 'group_{group_def_id}' + + This ensures we generate layer classes for ALL nodes, not just main model nodes. + """ + all_nodes = [] + node_index = 0 + + # Add main model nodes + for node in main_nodes: + all_nodes.append((node, node_index, 'main')) + node_index += 1 + + # Add internal nodes from group definitions + if block_generator: + for group_def_id, group_def in block_generator.group_definitions.items(): + internal_structure = group_def.get('internal_structure', {}) + internal_nodes = internal_structure.get('nodes', []) + + for internal_node in internal_nodes: + node_type = get_node_type(internal_node) + # Skip input/output nodes + if node_type not in ('input', 'dataloader', 'output'): + all_nodes.append((internal_node, node_index, f'group_{group_def_id}')) + node_index += 1 + + return all_nodes + + +def get_layer_signature(node: Dict, config: Dict[str, Any], node_type: str) -> str: + """ + Generate a unique signature for a layer based on its type and config. + Used for deduplication - layers with same signature can share the same class. + """ + if node_type == 'conv2d': + return f"conv2d_{config.get('out_channels', 64)}_{config.get('kernel_size', 3)}_{config.get('stride', 1)}_{config.get('padding', 0)}_{config.get('dilation', 1)}" + elif node_type == 'linear': + return f"linear_{config.get('out_features', 128)}_{config.get('bias', True)}" + elif node_type == 'maxpool': + return f"maxpool_{config.get('kernel_size', 2)}_{config.get('stride', 2)}_{config.get('padding', 0)}" + elif node_type == 'dropout': + return f"dropout_{config.get('p', 0.5)}" + elif node_type == 'batchnorm': + return f"batchnorm_{config.get('eps', 1e-5)}_{config.get('momentum', 0.1)}_{config.get('affine', True)}" + elif node_type == 'softmax': + return f"softmax_{config.get('dim', 1)}" + elif node_type == 'attention': + return f"attention_{config.get('embed_dim', 512)}_{config.get('num_heads', 8)}_{config.get('dropout', 0.0)}" + elif node_type == 'custom': + return f"custom_{config.get('name', 'CustomLayer')}" + else: + # For layers without config (relu, flatten, etc.) + return node_type def generate_model_file( nodes: List[Dict], edges: List[Dict], project_name: str, - shape_map: Dict[str, Dict[str, Any]] + shape_map: Dict[str, Dict[str, Any]], + block_generator: Optional[TensorFlowBlockGenerator] = None ) -> str: """Generate complete model.py file with layer classes and main model class""" class_name = to_class_name(project_name) - # Generate individual layer classes + # Generate block class definitions FIRST (if any) - this populates the cache + block_classes_code = "" + if block_generator: + block_classes_code = block_generator.generate_all_block_classes() + + # COLLECT ALL NODES (main + internal from groups) and generate layer classes + all_nodes_to_generate = collect_all_nodes_with_internals(nodes, block_generator) + + # DEDUPLICATE by signature and generate layer classes + seen_signatures = set() layer_classes = [] + + for node, idx, source_context in all_nodes_to_generate: + node_type = get_node_type(node) + config = node.get('data', {}).get('config', {}) + node_id = node['id'] + + # Get shape info (use shape_map for main nodes, extract for internal) + if source_context == 'main': + shape_info = shape_map.get(node_id, {}) + else: + shape_info = extract_shape_info_from_node(node) + + # Generate signature for deduplication + signature = get_layer_signature(node, config, node_type) + + # Only generate if we haven't seen this signature before + if signature not in seen_signatures: + seen_signatures.add(signature) + layer_class_code = generate_layer_class(node, idx, config, node_type, shape_info) + if layer_class_code: + layer_classes.append(layer_class_code) + + # Now generate layer instantiations and forward pass for MAIN MODEL ONLY layer_instantiations = [] forward_pass_lines = [] @@ -358,10 +1228,102 @@ def generate_model_file( var_map[node_id] = 'x' if not var_map else 'x' continue - # Generate layer class - layer_class_code = generate_layer_class(node, idx, config, node_type, shape_info) - if layer_class_code: - layer_classes.append(layer_class_code) + # Handle group blocks differently + if node_type == 'group': + # Get the group definition ID + group_def_id = node.get('data', {}).get('groupDefinitionId') + + if block_generator and group_def_id: + # Use the block class name from the generator + block_class_name = block_generator.get_block_class_name(group_def_id) + + if block_class_name: + layer_name = f"block_{node_id.replace('-', '_')}" + + # Get upstream node's output shape from shape_map + incoming = edge_map.get(node_id, []) + params = [] + + if incoming and incoming[0] in shape_map: + # Get upstream node's output shape + upstream_shape = shape_map[incoming[0]] + + # Extract in_channels or in_features from upstream shape + # TensorFlow uses same parameter names as PyTorch for consistency + # Pass in_channels if the upstream outputs channels (convolutional layers) + if 'out_channels' in upstream_shape: + in_channels = upstream_shape['out_channels'] + params.append(f"in_channels={in_channels}") + logger.debug(f"TF Block {node_id}: passing in_channels={in_channels} from upstream node {incoming[0]}") + + # Pass in_features if the upstream outputs features (linear layers) + # TensorFlow uses 'out_units' instead of 'out_features' + elif 'out_units' in upstream_shape: + in_units = upstream_shape['out_units'] + params.append(f"in_features={in_units}") + logger.debug(f"TF Block {node_id}: passing in_features={in_units} from upstream node {incoming[0]}") + elif 'out_features' in upstream_shape: + in_features = upstream_shape['out_features'] + params.append(f"in_features={in_features}") + logger.debug(f"TF Block {node_id}: passing in_features={in_features} from upstream node {incoming[0]}") + + # Pass num_features if the upstream outputs num_features (batch norm) + elif 'num_features' in upstream_shape: + num_features = upstream_shape['num_features'] + params.append(f"num_features={num_features}") + logger.debug(f"TF Block {node_id}: passing num_features={num_features} from upstream node {incoming[0]}") + else: + # Upstream shape exists but doesn't have expected keys + logger.warning(f"TF Block {node_id}: upstream shape {upstream_shape} doesn't contain expected keys") + else: + # Handle case where no upstream exists (use input node shape) + # Look for input nodes in the graph + input_nodes = [n for n in nodes if get_node_type(n) == 'input'] + if input_nodes and input_nodes[0]['id'] in shape_map: + input_shape = shape_map[input_nodes[0]['id']] + + # Use input node's output shape + if 'out_channels' in input_shape: + in_channels = input_shape['out_channels'] + params.append(f"in_channels={in_channels}") + logger.debug(f"TF Block {node_id}: no upstream, using input shape in_channels={in_channels}") + elif 'out_units' in input_shape: + in_units = input_shape['out_units'] + params.append(f"in_features={in_units}") + logger.debug(f"TF Block {node_id}: no upstream, using input shape in_features={in_units}") + elif 'out_features' in input_shape: + in_features = input_shape['out_features'] + params.append(f"in_features={in_features}") + logger.debug(f"TF Block {node_id}: no upstream, using input shape in_features={in_features}") + else: + logger.warning(f"TF Block {node_id}: input shape {input_shape} doesn't contain expected keys") + else: + # No upstream and no input node, use defaults + logger.warning(f"TF Block {node_id}: no upstream connection and no input node found") + + # Generate instantiation with computed parameters + # Each instance gets independent shape computation based on its position in the graph + if params: + layer_instantiations.append(f"self.{layer_name} = {block_class_name}({', '.join(params)}) # Instance at position {idx}") + else: + layer_instantiations.append(f"self.{layer_name} = {block_class_name}() # Instance at position {idx}") + + # Generate forward pass line + input_var = get_input_variable(incoming, var_map) + output_var = 'x' + forward_pass_lines.append(f"{output_var} = self.{layer_name}({input_var}, training=training)") + var_map[node_id] = output_var + else: + # Block class not found, skip + logger.warning(f"TF Block class not found for group definition {group_def_id}") + var_map[node_id] = 'x' + else: + # No block generator or definition ID, skip + logger.warning(f"TF No block generator or definition ID for node {node_id}") + var_map[node_id] = 'x' + continue + + # For regular nodes, we already generated the layer class above (no need to generate again) # Generate layer instantiation for __init__ layer_name = get_layer_variable_name(node_type, idx, config) @@ -398,6 +1360,10 @@ def generate_model_file( ''' + # Add block class definitions (already generated at the start) + if block_classes_code: + code += block_classes_code + '\n\n' + # Add all layer class definitions for layer_class in layer_classes: code += layer_class + '\n\n' @@ -488,6 +1454,12 @@ def generate_layer_class( ) -> Optional[str]: """Generate a complete layer class definition with documentation""" + # Special node types that don't generate individual layer classes: + # - input/output/dataloader: Architectural markers for graph structure + # - group: Reusable components generated separately by BlockGenerator + if node_type in ('input', 'output', 'dataloader', 'group'): + return None + class_name = get_layer_class_name(node_type, idx, config) if node_type == 'conv2d': @@ -886,7 +1858,12 @@ def call(self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor: x = inputs return x''' - return None + # If we reach here, the node type is not supported + raise UnsupportedNodeTypeError( + node_id=node.get('id', 'unknown'), + node_type=node_type, + framework='TensorFlow' + ) def generate_layer_instantiation( diff --git a/project/block_manager/services/validation.py b/project/block_manager/services/validation.py index a598c17..da5079b 100644 --- a/project/block_manager/services/validation.py +++ b/project/block_manager/services/validation.py @@ -2,6 +2,7 @@ Validation service for model architectures Handles shape checking, connection validation, and architecture integrity """ +import ast from typing import List, Dict, Any, Tuple, Optional @@ -53,7 +54,8 @@ def validate(self) -> Tuple[bool, List[Dict], List[Dict]]: self._validate_orphaned_blocks() self._validate_block_configurations() self._check_circular_dependencies() - + self._validate_shape_compatibility() + is_valid = len(self.errors) == 0 return ( @@ -172,7 +174,7 @@ def _validate_orphaned_blocks(self): if not self.edges: # If there are no edges but there are nodes, all non-input nodes are orphaned for node in self.nodes: - if self._get_block_type(node) != 'input': + if self._get_block_type(node) != 'input' and self._get_block_type(node) != 'output' and self._get_block_type(node) != 'loss': self.warnings.append(ValidationError( message='Block is not connected to the graph', node_id=node['id'], @@ -233,7 +235,19 @@ def _validate_block_configurations(self): )) elif block_type == 'input': - input_shape = eval(config.get('shape')) + # Use ast.literal_eval for safe evaluation of shape configuration + shape_value = config.get('shape') + try: + input_shape = ast.literal_eval(shape_value) if shape_value else None + except (ValueError, SyntaxError): + self.errors.append(ValidationError( + message='Input block has invalid shape format', + node_id=node_id, + error_type='error', + suggestion='Shape must be a valid Python literal (list or tuple)' + )) + input_shape = None + if not input_shape: self.errors.append(ValidationError( message='Input block requires input shape configuration', @@ -283,6 +297,57 @@ def has_cycle(node_id: str) -> bool: )) break + def _validate_shape_compatibility(self): + """ + Perform basic shape compatibility checks before code generation. + This catches common shape mismatches early in the validation phase. + """ + # Build edge map + edge_map = {} + for edge in self.edges: + target = edge.get('target') + source = edge.get('source') + if target not in edge_map: + edge_map[target] = [] + edge_map[target].append(source) + + # Check each node's connections + for node in self.nodes: + node_id = node['id'] + node_type = self._get_block_type(node) + config = node.get('data', {}).get('config', {}) + + # Skip nodes that don't have shape requirements + if node_type in ('input', 'output', 'dataloader'): + continue + + incoming = edge_map.get(node_id, []) + + # Check that nodes with required inputs have connections + if node_type in ('conv2d', 'linear', 'maxpool2d', 'maxpool', 'batchnorm', 'batchnorm2d', 'flatten'): + if not incoming: + self.errors.append(ValidationError( + message=f'{node_type} layer requires an input connection', + node_id=node_id, + error_type='error', + suggestion=f'Connect an upstream layer to this {node_type} layer' + )) + + # Validate flatten placement + if node_type == 'flatten': + if incoming and len(incoming) == 1: + upstream_node = self.node_map.get(incoming[0]) + if upstream_node: + upstream_type = self._get_block_type(upstream_node) + # Warn if flatten comes after linear (unusual) + if upstream_type == 'linear': + self.warnings.append(ValidationError( + message='Flatten layer after Linear layer may be unnecessary', + node_id=node_id, + error_type='warning', + suggestion='Flatten is typically used before Linear layers, not after' + )) + def validate_architecture(nodes: List[Dict], edges: List[Dict]) -> Dict[str, Any]: """ diff --git a/project/block_manager/urls.py b/project/block_manager/urls.py index bc4e0bb..1207bc6 100644 --- a/project/block_manager/urls.py +++ b/project/block_manager/urls.py @@ -3,7 +3,7 @@ from block_manager.views.project_views import ProjectViewSet from block_manager.views.architecture_views import ( - save_architecture, + save_architecture, load_architecture, get_node_definitions, get_node_definition, @@ -12,6 +12,7 @@ from block_manager.views.validation_views import validate_model from block_manager.views.export_views import export_model from block_manager.views.chat_views import chat_message, get_suggestions, get_environment_info +from block_manager.views.group_views import group_definition_list, group_definition_detail # Create router for viewsets router = DefaultRouter() @@ -24,7 +25,11 @@ # Architecture endpoints path('projects//save-architecture', save_architecture, name='save-architecture'), path('projects//load-architecture', load_architecture, name='load-architecture'), - + + # Group definition endpoints + path('projects//groups', group_definition_list, name='group-definition-list'), + path('projects//groups/', group_definition_detail, name='group-definition-detail'), + # Node definition endpoints path('node-definitions', get_node_definitions, name='node-definitions'), path('node-definitions/', get_node_definition, name='node-definition'), diff --git a/project/block_manager/views/architecture_views.py b/project/block_manager/views/architecture_views.py index 79dd00a..5ba90e1 100644 --- a/project/block_manager/views/architecture_views.py +++ b/project/block_manager/views/architecture_views.py @@ -2,19 +2,25 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from django.shortcuts import get_object_or_404 +from django.db import transaction -from block_manager.models import Project, ModelArchitecture, Block, Connection +from block_manager.models import Project, ModelArchitecture, Block, Connection, GroupBlockDefinition from block_manager.serializers import ( SaveArchitectureSerializer, ModelArchitectureSerializer, + GroupBlockDefinitionSerializer, ) @api_view(['POST']) +@transaction.atomic def save_architecture(request, project_id): """ Save architecture for a project Accepts nodes and edges from frontend canvas + + Uses atomic transaction to ensure data integrity - all database + operations succeed or rollback together on failure. """ project = get_object_or_404(Project, pk=project_id) serializer = SaveArchitectureSerializer(data=request.data) @@ -27,21 +33,46 @@ def save_architecture(request, project_id): nodes = serializer.validated_data['nodes'] edges = serializer.validated_data['edges'] - + group_definitions = serializer.validated_data.get('groupDefinitions', []) + # Get or create architecture architecture, created = ModelArchitecture.objects.get_or_create(project=project) - - # Clear existing blocks and connections + + # Clear existing blocks, connections, and group definitions architecture.blocks.all().delete() architecture.connections.all().delete() - + project.group_definitions.all().delete() + + # Save group definitions first - use serializer for validation + group_def_id_map = {} + for group_def in group_definitions: + # Validate and create group definition using serializer + group_serializer = GroupBlockDefinitionSerializer(data=group_def) + if not group_serializer.is_valid(): + return Response( + { + 'success': False, + 'error': 'Invalid group definition', + 'details': group_serializer.errors + }, + status=status.HTTP_400_BAD_REQUEST + ) + + # Save with project context + gbd = group_serializer.save(project=project) + group_def_id_map[str(gbd.id)] = gbd + # Create blocks from nodes node_id_to_block = {} for node in nodes: node_id = node.get('id') node_data = node.get('data', {}) position = node.get('position', {'x': 0, 'y': 0}) - + + # Get group definition if this is a group block + group_def_id = node_data.get('groupDefinitionId') + group_definition = group_def_id_map.get(group_def_id) if group_def_id else None + block = Block.objects.create( architecture=architecture, node_id=node_id, @@ -51,6 +82,9 @@ def save_architecture(request, project_id): config=node_data.get('config', {}), input_shape=node_data.get('inputShape'), output_shape=node_data.get('outputShape'), + group_definition=group_definition, + is_expanded=node_data.get('isExpanded', False), + repetition_metadata=node_data.get('repetitionMetadata') ) node_id_to_block[node_id] = block @@ -77,6 +111,7 @@ def save_architecture(request, project_id): architecture.canvas_state = { 'nodes': nodes, 'edges': edges, + 'groupDefinitions': group_definitions, } architecture.save() @@ -118,21 +153,30 @@ def load_architecture(request, project_id): # Reconstruct from database nodes = [] for block in architecture.blocks.all(): + node_data = { + 'blockType': block.block_type, + 'config': block.config, + 'inputShape': block.input_shape, + 'outputShape': block.output_shape, + } + + # Add group block specific data + if block.block_type == 'group' and block.group_definition: + node_data['groupDefinitionId'] = str(block.group_definition.id) + node_data['isExpanded'] = block.is_expanded + if block.repetition_metadata: + node_data['repetitionMetadata'] = block.repetition_metadata + nodes.append({ 'id': block.node_id, - 'type': block.block_type, + 'type': 'group' if block.block_type == 'group' else 'custom', 'position': { 'x': block.position_x, 'y': block.position_y, }, - 'data': { - 'blockType': block.block_type, - 'config': block.config, - 'inputShape': block.input_shape, - 'outputShape': block.output_shape, - } + 'data': node_data }) - + edges = [] for conn in architecture.connections.all(): edges.append({ @@ -142,10 +186,17 @@ def load_architecture(request, project_id): 'sourceHandle': conn.source_handle, 'targetHandle': conn.target_handle, }) + + # Load group definitions + group_definitions = [] + for group_def in project.group_definitions.all(): + serializer = GroupBlockDefinitionSerializer(group_def) + group_definitions.append(serializer.data) return Response({ 'nodes': nodes, 'edges': edges, + 'groupDefinitions': group_definitions, }) diff --git a/project/block_manager/views/export_views.py b/project/block_manager/views/export_views.py index a3874a5..0e7e15a 100644 --- a/project/block_manager/views/export_views.py +++ b/project/block_manager/views/export_views.py @@ -26,6 +26,7 @@ def export_model(request): edges = request.data.get('edges', []) export_format = request.data.get('format', 'pytorch') project_name = request.data.get('projectName', 'GeneratedModel') + group_definitions = request.data.get('groupDefinitions', []) if not nodes: return Response( @@ -35,15 +36,64 @@ def export_model(request): try: # Generate code based on framework + shape_errors = [] if export_format == 'pytorch': - generated = generate_pytorch_code(nodes, edges, project_name) + generated, shape_errors = generate_pytorch_code(nodes, edges, project_name, group_definitions) elif export_format == 'tensorflow': - generated = generate_tensorflow_code(nodes, edges, project_name) + generated, shape_errors = generate_tensorflow_code(nodes, edges, project_name, group_definitions) else: return Response( {'error': f'Unsupported export format: {export_format}'}, status=status.HTTP_400_BAD_REQUEST ) + + # Check if there are shape errors that should prevent export + if shape_errors: + # Format shape errors with comprehensive details for frontend + formatted_errors = [] + for error in shape_errors: + error_dict = { + 'type': 'error', # Mark as error type for frontend + 'message': str(error) + } + + # Extract additional context from specific error types + # These attributes come from our custom exception classes + if hasattr(error, 'node_id'): + error_dict['nodeId'] = error.node_id + if hasattr(error, 'node_type'): + error_dict['nodeType'] = error.node_type + if hasattr(error, 'block_name'): + error_dict['blockName'] = error.block_name + if hasattr(error, 'layer_name'): + error_dict['layerName'] = error.layer_name + if hasattr(error, 'expected'): + error_dict['expected'] = error.expected + if hasattr(error, 'actual'): + error_dict['actual'] = error.actual + if hasattr(error, 'suggestion'): + error_dict['suggestion'] = error.suggestion + if hasattr(error, 'reason'): + error_dict['reason'] = error.reason + if hasattr(error, 'upstream_node_id'): + error_dict['upstreamNodeId'] = error.upstream_node_id + if hasattr(error, 'missing_keys'): + error_dict['missingKeys'] = error.missing_keys + if hasattr(error, 'framework'): + error_dict['framework'] = error.framework + + formatted_errors.append(error_dict) + + # Return validation-style error response that frontend expects + return Response( + { + 'error': 'Code generation errors detected', + 'validationErrors': formatted_errors, + 'details': 'Please fix the errors in your architecture before exporting.', + 'errorCount': len(formatted_errors) + }, + status=status.HTTP_400_BAD_REQUEST + ) # Create a zip file with all generated files zip_buffer = io.BytesIO() diff --git a/project/block_manager/views/group_views.py b/project/block_manager/views/group_views.py new file mode 100644 index 0000000..4a67852 --- /dev/null +++ b/project/block_manager/views/group_views.py @@ -0,0 +1,71 @@ +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from django.shortcuts import get_object_or_404 + +from block_manager.models import Project, GroupBlockDefinition +from block_manager.serializers import GroupBlockDefinitionSerializer + + +@api_view(['GET', 'POST']) +def group_definition_list(request, project_id): + """ + List all group definitions for a project or create a new one. + """ + project = get_object_or_404(Project, id=project_id) + + if request.method == 'GET': + definitions = GroupBlockDefinition.objects.filter(project=project) + serializer = GroupBlockDefinitionSerializer(definitions, many=True) + return Response(serializer.data) + + elif request.method == 'POST': + serializer = GroupBlockDefinitionSerializer(data=request.data) + if serializer.is_valid(): + serializer.save(project=project) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +@api_view(['GET', 'PUT', 'DELETE']) +def group_definition_detail(request, project_id, definition_id): + """ + Retrieve, update, or delete a group definition. + """ + project = get_object_or_404(Project, id=project_id) + definition = get_object_or_404(GroupBlockDefinition, id=definition_id, project=project) + + if request.method == 'GET': + serializer = GroupBlockDefinitionSerializer(definition) + return Response(serializer.data) + + elif request.method == 'PUT': + serializer = GroupBlockDefinitionSerializer(definition, data=request.data, partial=True) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + elif request.method == 'DELETE': + # Check if there are instances using this definition + instances_count = definition.instances.count() + + # Get cascade option from query params (default: False) + cascade = request.query_params.get('cascade', 'false').lower() == 'true' + + if instances_count > 0 and not cascade: + return Response( + { + 'error': 'Cannot delete definition with active instances', + 'instances_count': instances_count, + 'message': f'{instances_count} block instance(s) are using this definition. Use cascade=true to delete all instances.' + }, + status=status.HTTP_400_BAD_REQUEST + ) + + # Delete the definition (CASCADE will handle instances if cascade=true) + definition.delete() + return Response( + {'message': 'Group definition deleted successfully'}, + status=status.HTTP_204_NO_CONTENT + ) diff --git a/project/frontend/EXPORT_FORMAT.md b/project/frontend/EXPORT_FORMAT.md deleted file mode 100644 index 34b682b..0000000 --- a/project/frontend/EXPORT_FORMAT.md +++ /dev/null @@ -1,322 +0,0 @@ -# JSON Export Format Documentation - -## Overview - -VisionForge exports neural network architectures in a clean, human-readable JSON format. This format is designed to be: -- **Secure**: No code execution, only configuration data -- **Readable**: Well-structured and easy to understand -- **Portable**: Can be shared, versioned, and imported across projects - -## Format Structure - -```json -{ - "version": "1.0.0", - "projectName": "My Neural Network", - "projectDescription": "A custom architecture for image classification", - "framework": "pytorch", - "architecture": { - "nodes": [...], - "connections": [...] - }, - "metadata": { - "exportedAt": 1704567890123, - "nodeCount": 8, - "edgeCount": 7 - } -} -``` - -## Field Descriptions - -### Top-Level Fields - -- **version** (string): Format version for compatibility checking -- **projectName** (string): Name of the project -- **projectDescription** (string): Optional description -- **framework** (string): Target framework (`"pytorch"` or `"tensorflow"`) -- **architecture** (object): Contains nodes and connections -- **metadata** (object): Export metadata - -### Architecture Object - -#### Nodes Array - -Each node represents a neural network layer/block: - -```json -{ - "id": "node-1704567890123", - "type": "conv2d", - "label": "Conv2D Layer", - "category": "basic", - "config": { - "in_channels": 3, - "out_channels": 64, - "kernel_size": 3, - "stride": 1, - "padding": 1 - }, - "inputShape": { - "dims": ["batch", 3, 224, 224] - }, - "outputShape": { - "dims": ["batch", 64, 224, 224] - } -} -``` - -Fields: -- **id**: Unique identifier for the node -- **type**: Block type (input, linear, conv2d, etc.) -- **label**: Display name -- **category**: Block category (input, basic, advanced, merge) -- **config**: Layer-specific configuration parameters -- **inputShape**: Input tensor shape (optional) -- **outputShape**: Output tensor shape (optional) - -#### Connections Array - -Each connection represents data flow between nodes: - -```json -{ - "from": "node-1704567890123", - "to": "node-1704567891456" -} -``` - -Fields: -- **from**: Source node ID -- **to**: Target node ID - -### Metadata Object - -```json -{ - "exportedAt": 1704567890123, - "nodeCount": 8, - "edgeCount": 7 -} -``` - -- **exportedAt**: Unix timestamp (milliseconds) -- **nodeCount**: Number of nodes in architecture -- **edgeCount**: Number of connections - -## Security Features - -### What's Excluded - -The export format intentionally excludes: -- ❌ UI positioning data (x, y coordinates) -- ❌ Internal React Flow state -- ❌ Generated PyTorch code -- ❌ Training data or weights -- ❌ Executable code of any kind - -### What's Included - -Only safe, declarative configuration: -- ✅ Layer types and names -- ✅ Configuration parameters (numeric values, booleans, strings) -- ✅ Tensor shape information -- ✅ Connection topology - -### Why This is Secure - -1. **No Code Execution**: The format contains only data, no executable code -2. **Validated on Import**: All imported data is validated against known block types -3. **Type-Safe**: TypeScript ensures proper data types -4. **Read-Only**: Import creates new instances, doesn't modify existing code - -## Example: Complete Export - -```json -{ - "version": "1.0.0", - "projectName": "Simple CNN Classifier", - "projectDescription": "Basic image classification architecture", - "framework": "pytorch", - "architecture": { - "nodes": [ - { - "id": "input-1", - "type": "input", - "label": "Input Layer", - "category": "input", - "config": { - "shape": "[\"batch\", 3, 224, 224]" - }, - "outputShape": { - "dims": ["batch", 3, 224, 224] - } - }, - { - "id": "conv-1", - "type": "conv2d", - "label": "Conv2D", - "category": "basic", - "config": { - "in_channels": 3, - "out_channels": 64, - "kernel_size": 3, - "stride": 1, - "padding": 1 - }, - "inputShape": { - "dims": ["batch", 3, 224, 224] - }, - "outputShape": { - "dims": ["batch", 64, 224, 224] - } - }, - { - "id": "relu-1", - "type": "relu", - "label": "ReLU", - "category": "basic", - "config": {}, - "inputShape": { - "dims": ["batch", 64, 224, 224] - }, - "outputShape": { - "dims": ["batch", 64, 224, 224] - } - }, - { - "id": "pool-1", - "type": "maxpool", - "label": "MaxPool2D", - "category": "basic", - "config": { - "kernel_size": 2, - "stride": 2 - }, - "inputShape": { - "dims": ["batch", 64, 224, 224] - }, - "outputShape": { - "dims": ["batch", 64, 112, 112] - } - }, - { - "id": "flatten-1", - "type": "flatten", - "label": "Flatten", - "category": "basic", - "config": {}, - "inputShape": { - "dims": ["batch", 64, 112, 112] - }, - "outputShape": { - "dims": ["batch", 802816] - } - }, - { - "id": "linear-1", - "type": "linear", - "label": "Linear", - "category": "basic", - "config": { - "in_features": 802816, - "out_features": 1000 - }, - "inputShape": { - "dims": ["batch", 802816] - }, - "outputShape": { - "dims": ["batch", 1000] - } - } - ], - "connections": [ - { "from": "input-1", "to": "conv-1" }, - { "from": "conv-1", "to": "relu-1" }, - { "from": "relu-1", "to": "pool-1" }, - { "from": "pool-1", "to": "flatten-1" }, - { "from": "flatten-1", "to": "linear-1" } - ] - }, - "metadata": { - "exportedAt": 1704567890123, - "nodeCount": 6, - "edgeCount": 5 - } -} -``` - -## Import Behavior - -When importing a JSON file: - -1. **Validation**: File is checked for proper structure and version -2. **Node Creation**: Nodes are recreated with configuration -3. **Layout**: Nodes are arranged in a grid pattern -4. **Connections**: Edges are recreated between nodes -5. **Validation**: Architecture is validated for errors -6. **Project**: A new project is created (or current is updated) - -## Version Compatibility - -- Current version: `1.0.0` -- Future versions will maintain backward compatibility -- Unsupported versions will show a clear error message - -## Best Practices - -### Exporting -- Use descriptive project names -- Add meaningful descriptions -- Validate before exporting to ensure completeness - -### Importing -- Keep backup copies of important architectures -- Review imported architectures before modifying -- Validate after import to check for any issues - -### Sharing -- JSON files can be safely shared via email, GitHub, etc. -- No sensitive data is included -- Files can be version-controlled with Git - -## Troubleshooting - -### Import Errors - -**"Invalid export file format"** -- File is not valid JSON -- Required fields are missing -- Check file hasn't been corrupted - -**"Unsupported export version"** -- File was created with a newer version -- Update VisionForge to latest version -- Contact support if issue persists - -**"Unknown block type"** -- File contains block types not available in current version -- Update VisionForge or remove unsupported blocks - -## Migration Guide - -If you need to modify the JSON manually: - -1. Make a backup copy first -2. Ensure JSON syntax is valid -3. Keep all required fields -4. Match block types to available blocks -5. Validate after re-importing - -## Future Enhancements - -Planned additions to the format: -- Model metadata (author, tags, license) -- Custom block definitions -- Training configuration -- Dataset specifications -- Performance metrics - ---- - -**Note**: This format is designed for architecture definitions only. For full model deployment (including weights), use PyTorch's native save/load mechanisms. diff --git a/project/frontend/index.html b/project/frontend/index.html index a4be6e7..cbd499e 100644 --- a/project/frontend/index.html +++ b/project/frontend/index.html @@ -1,4 +1,3 @@ - diff --git a/project/frontend/package-lock.json b/project/frontend/package-lock.json index 2cedeb3..28e02d8 100644 --- a/project/frontend/package-lock.json +++ b/project/frontend/package-lock.json @@ -79,18 +79,23 @@ "devDependencies": { "@eslint/js": "^9.21.0", "@tailwindcss/postcss": "^4.1.8", + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.0", "@types/react": "^19.0.10", "@types/react-dom": "^19.0.4", "@vitejs/plugin-react": "^4.3.4", "@vitejs/plugin-react-swc": "^3.10.1", + "@vitest/ui": "^4.0.15", "eslint": "^9.28.0", "eslint-plugin-react-hooks": "^5.2.0", "eslint-plugin-react-refresh": "^0.4.19", "globals": "^16.0.0", + "jsdom": "^27.2.0", "tailwindcss": "^4.1.11", "typescript": "~5.7.2", "typescript-eslint": "^8.38.0", - "vite": "^6.3.5" + "vite": "^6.3.5", + "vitest": "^4.0.15" }, "workspaces": { "packages": [ @@ -98,6 +103,20 @@ ] } }, + "node_modules/@acemir/cssom": { + "version": "0.9.26", + "resolved": "https://registry.npmjs.org/@acemir/cssom/-/cssom-0.9.26.tgz", + "integrity": "sha512-UMFbL3EnWH/eTvl21dz9s7Td4wYDMtxz/56zD8sL9IZGYyi48RxmdgPMiyT7R6Vn3rjMTwYZ42bqKa7ex74GEQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@adobe/css-tools": { + "version": "4.4.4", + "resolved": "https://registry.npmjs.org/@adobe/css-tools/-/css-tools-4.4.4.tgz", + "integrity": "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg==", + "dev": true, + "license": "MIT" + }, "node_modules/@alloc/quick-lru": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", @@ -111,6 +130,61 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@asamuzakjp/css-color": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/@asamuzakjp/css-color/-/css-color-4.1.0.tgz", + "integrity": "sha512-9xiBAtLn4aNsa4mDnpovJvBn72tNEIACyvlqaNJ+ADemR+yeMJWnBudOi2qGDviJa7SwcDOU/TRh5dnET7qk0w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@csstools/css-calc": "^2.1.4", + "@csstools/css-color-parser": "^3.1.0", + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4", + "lru-cache": "^11.2.2" + } + }, + "node_modules/@asamuzakjp/css-color/node_modules/lru-cache": { + "version": "11.2.4", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.4.tgz", + "integrity": "sha512-B5Y16Jr9LB9dHVkh6ZevG+vAbOsNOYCX+sXvFWFu7B3Iz5mijW3zdbMyhsh8ANd2mSWBYdJgnqi+mL7/LrOPYg==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/dom-selector": { + "version": "6.7.5", + "resolved": "https://registry.npmjs.org/@asamuzakjp/dom-selector/-/dom-selector-6.7.5.tgz", + "integrity": "sha512-Eks6dY8zau4m4wNRQjRVaKQRTalNcPcBvU1ZQ35w5kKRk1gUeNCkVLsRiATurjASTp3TKM4H10wsI50nx3NZdw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/nwsapi": "^2.3.9", + "bidi-js": "^1.0.3", + "css-tree": "^3.1.0", + "is-potential-custom-element-name": "^1.0.1", + "lru-cache": "^11.2.2" + } + }, + "node_modules/@asamuzakjp/dom-selector/node_modules/lru-cache": { + "version": "11.2.4", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.4.tgz", + "integrity": "sha512-B5Y16Jr9LB9dHVkh6ZevG+vAbOsNOYCX+sXvFWFu7B3Iz5mijW3zdbMyhsh8ANd2mSWBYdJgnqi+mL7/LrOPYg==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/nwsapi": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/@asamuzakjp/nwsapi/-/nwsapi-2.3.9.tgz", + "integrity": "sha512-n8GuYSrI9bF7FFZ/SjhwevlHc8xaVlb/7HmHelnc/PZXBD2ZR49NnN9sMMuDdEGPeeRQ5d0hqlSlEpgCX3Wl0Q==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/code-frame": { "version": "7.27.1", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", @@ -510,6 +584,143 @@ "w3c-keyname": "^2.2.4" } }, + "node_modules/@csstools/color-helpers": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/@csstools/color-helpers/-/color-helpers-5.1.0.tgz", + "integrity": "sha512-S11EXWJyy0Mz5SYvRmY8nJYTFFd1LCNV+7cXyAgQtOOuzb4EsgfqDufL+9esx72/eLhsRdGZwaldu/h+E4t4BA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "engines": { + "node": ">=18" + } + }, + "node_modules/@csstools/css-calc": { + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/@csstools/css-calc/-/css-calc-2.1.4.tgz", + "integrity": "sha512-3N8oaj+0juUw/1H3YwmDDJXCgTB1gKU6Hc/bB502u9zR0q2vd786XJH9QfrKIEgFlZmhZiq6epXl4rHqhzsIgQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-color-parser": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@csstools/css-color-parser/-/css-color-parser-3.1.0.tgz", + "integrity": "sha512-nbtKwh3a6xNVIp/VRuXV64yTKnb1IjTAEEh3irzS+HkKjAOYLTGNb9pmVNntZ8iVBHcWDA2Dof0QtPgFI1BaTA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "dependencies": { + "@csstools/color-helpers": "^5.1.0", + "@csstools/css-calc": "^2.1.4" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-parser-algorithms": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@csstools/css-parser-algorithms/-/css-parser-algorithms-3.0.5.tgz", + "integrity": "sha512-DaDeUkXZKjdGhgYaHNJTV9pV7Y9B3b644jCLs9Upc3VeNGg6LWARAT6O+Q+/COo+2gg/bM5rhpMAtf70WqfBdQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "peer": true, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-syntax-patches-for-csstree": { + "version": "1.0.20", + "resolved": "https://registry.npmjs.org/@csstools/css-syntax-patches-for-csstree/-/css-syntax-patches-for-csstree-1.0.20.tgz", + "integrity": "sha512-8BHsjXfSciZxjmHQOuVdW2b8WLUPts9a+mfL13/PzEviufUEW2xnvQuOlKs9dRBHgRqJ53SF/DUoK9+MZk72oQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "engines": { + "node": ">=18" + } + }, + "node_modules/@csstools/css-tokenizer": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@csstools/css-tokenizer/-/css-tokenizer-3.0.4.tgz", + "integrity": "sha512-Vd/9EVDiu6PPJt9yAh6roZP6El1xHrdvIVGjyBsHR0RYwNHgL7FJPyIIW4fANJNG6FtyZfvlRPpFI4ZM/lubvw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "peer": true, + "engines": { + "node": ">=18" + } + }, "node_modules/@date-fns/tz": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/@date-fns/tz/-/tz-1.4.1.tgz", @@ -1679,6 +1890,13 @@ "react-dom": ">= 16.8" } }, + "node_modules/@polka/url": { + "version": "1.0.0-next.29", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.29.tgz", + "integrity": "sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==", + "dev": true, + "license": "MIT" + }, "node_modules/@radix-ui/colors": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/@radix-ui/colors/-/colors-3.0.0.tgz", @@ -3587,6 +3805,13 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.0.0.tgz", + "integrity": "sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==", + "dev": true, + "license": "MIT" + }, "node_modules/@standard-schema/utils": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/@standard-schema/utils/-/utils-0.3.0.tgz", @@ -4125,6 +4350,89 @@ "react": "^18 || ^19" } }, + "node_modules/@testing-library/dom": { + "version": "10.4.1", + "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-10.4.1.tgz", + "integrity": "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/code-frame": "^7.10.4", + "@babel/runtime": "^7.12.5", + "@types/aria-query": "^5.0.1", + "aria-query": "5.3.0", + "dom-accessibility-api": "^0.5.9", + "lz-string": "^1.5.0", + "picocolors": "1.1.1", + "pretty-format": "^27.0.2" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@testing-library/jest-dom": { + "version": "6.9.1", + "resolved": "https://registry.npmjs.org/@testing-library/jest-dom/-/jest-dom-6.9.1.tgz", + "integrity": "sha512-zIcONa+hVtVSSep9UT3jZ5rizo2BsxgyDYU7WFD5eICBE7no3881HGeb/QkGfsJs6JTkY1aQhT7rIPC7e+0nnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@adobe/css-tools": "^4.4.0", + "aria-query": "^5.0.0", + "css.escape": "^1.5.1", + "dom-accessibility-api": "^0.6.3", + "picocolors": "^1.1.1", + "redent": "^3.0.0" + }, + "engines": { + "node": ">=14", + "npm": ">=6", + "yarn": ">=1" + } + }, + "node_modules/@testing-library/jest-dom/node_modules/dom-accessibility-api": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.6.3.tgz", + "integrity": "sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@testing-library/react": { + "version": "16.3.0", + "resolved": "https://registry.npmjs.org/@testing-library/react/-/react-16.3.0.tgz", + "integrity": "sha512-kFSyxiEDwv1WLl2fgsq6pPBbw5aWKrsY2/noi1Id0TK0UParSF62oFQFGHXIyaG4pp2tEub/Zlel+fjjZILDsw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.12.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@testing-library/dom": "^10.0.0", + "@types/react": "^18.0.0 || ^19.0.0", + "@types/react-dom": "^18.0.0 || ^19.0.0", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@types/aria-query": { + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", + "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/aws-lambda": { "version": "8.10.157", "resolved": "https://registry.npmjs.org/@types/aws-lambda/-/aws-lambda-8.10.157.tgz", @@ -4176,6 +4484,17 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/chai": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", + "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, "node_modules/@types/d3-array": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", @@ -4282,6 +4601,13 @@ "@types/ms": "*" } }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", + "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/estree": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", @@ -4752,6 +5078,140 @@ "vite": "^4 || ^5 || ^6 || ^7" } }, + "node_modules/@vitest/expect": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.0.15.tgz", + "integrity": "sha512-Gfyva9/GxPAWXIWjyGDli9O+waHDC0Q0jaLdFP1qPAUUfo1FEXPXUfUkp3eZA0sSq340vPycSyOlYUeM15Ft1w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@standard-schema/spec": "^1.0.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.0.15", + "@vitest/utils": "4.0.15", + "chai": "^6.2.1", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/mocker": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.0.15.tgz", + "integrity": "sha512-CZ28GLfOEIFkvCFngN8Sfx5h+Se0zN+h4B7yOsPVCcgtiO7t5jt9xQh2E1UkFep+eb9fjyMfuC5gBypwb07fvQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "4.0.15", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0-0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.0.15.tgz", + "integrity": "sha512-SWdqR8vEv83WtZcrfLNqlqeQXlQLh2iilO1Wk1gv4eiHKjEzvgHb2OVc3mIPyhZE6F+CtfYjNlDJwP5MN6Km7A==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.0.15.tgz", + "integrity": "sha512-+A+yMY8dGixUhHmNdPUxOh0la6uVzun86vAbuMT3hIDxMrAOmn5ILBHm8ajrqHE0t8R9T1dGnde1A5DTnmi3qw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "4.0.15", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.0.15.tgz", + "integrity": "sha512-A7Ob8EdFZJIBjLjeO0DZF4lqR6U7Ydi5/5LIZ0xcI+23lYlsYJAfGn8PrIWTYdZQRNnSRlzhg0zyGu37mVdy5g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.15", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.0.15.tgz", + "integrity": "sha512-+EIjOJmnY6mIfdXtE/bnozKEvTC4Uczg19yeZ2vtCz5Yyb0QQ31QWVQ8hswJ3Ysx/K2EqaNsVanjr//2+P3FHw==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/ui": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/ui/-/ui-4.0.15.tgz", + "integrity": "sha512-sxSyJMaKp45zI0u+lHrPuZM1ZJQ8FaVD35k+UxVrha1yyvQ+TZuUYllUixwvQXlB7ixoDc7skf3lQPopZIvaQw==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@vitest/utils": "4.0.15", + "fflate": "^0.8.2", + "flatted": "^3.3.3", + "pathe": "^2.0.3", + "sirv": "^3.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "vitest": "4.0.15" + } + }, + "node_modules/@vitest/utils": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.0.15.tgz", + "integrity": "sha512-HXjPW2w5dxhTD0dLwtYHDnelK3j8sR8cWIaLxr22evTyY6q8pRCjZSmhRWVjBaOVXChQd6AwMzi9pucorXCPZA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.15", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, "node_modules/@xyflow/react": { "version": "12.9.2", "resolved": "https://registry.npmjs.org/@xyflow/react/-/react-12.9.2.tgz", @@ -4808,6 +5268,16 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -4825,8 +5295,18 @@ "url": "https://github.com/sponsors/epoberezkin" } }, - "node_modules/ansi-styles": { - "version": "4.3.0", + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", "dev": true, @@ -4860,6 +5340,26 @@ "node": ">=10" } }, + "node_modules/aria-query": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz", + "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "dequal": "^2.0.3" + } + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/bail": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz", @@ -4893,6 +5393,16 @@ "integrity": "sha512-Nik3Sc0ncrMK4UUdXQmAnRtzmNQTAAXmXIopizwZ1W1t8QmfJj+zL4OA2I7XPTPW5z5TDqv4hRo/JzouDJnX3A==", "license": "Apache-2.0" }, + "node_modules/bidi-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/bidi-js/-/bidi-js-1.0.3.tgz", + "integrity": "sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==", + "dev": true, + "license": "MIT", + "dependencies": { + "require-from-string": "^2.0.2" + } + }, "node_modules/bottleneck": { "version": "2.19.5", "resolved": "https://registry.npmjs.org/bottleneck/-/bottleneck-2.19.5.tgz", @@ -4999,6 +5509,16 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/chai": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.1.tgz", + "integrity": "sha512-p4Z49OGG5W/WBCPSS/dH3jQ73kD6tiMmUM+bckNK6Jr5JHMG3k9bg/BvKR8lKmtVBKmOiuVaV2ws8s9oSbwysg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -5197,6 +5717,42 @@ "node": ">= 8" } }, + "node_modules/css-tree": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.1.0.tgz", + "integrity": "sha512-0eW44TGN5SQXU1mWSkKwFstI/22X2bG1nYzZTYMAWjylYURhse752YgbE4Cx46AC+bAvI+/dYTPRk1LqSUnu6w==", + "dev": true, + "license": "MIT", + "dependencies": { + "mdn-data": "2.12.2", + "source-map-js": "^1.0.1" + }, + "engines": { + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0" + } + }, + "node_modules/css.escape": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/css.escape/-/css.escape-1.5.1.tgz", + "integrity": "sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg==", + "dev": true, + "license": "MIT" + }, + "node_modules/cssstyle": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/cssstyle/-/cssstyle-5.3.3.tgz", + "integrity": "sha512-OytmFH+13/QXONJcC75QNdMtKpceNk3u8ThBjyyYjkEcy/ekBwR1mMAuNvi3gdBPW3N5TlCzQ0WZw8H0lN/bDw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/css-color": "^4.0.3", + "@csstools/css-syntax-patches-for-csstree": "^1.0.14", + "css-tree": "^3.1.0" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/csstype": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", @@ -5605,6 +6161,20 @@ "node": ">=12" } }, + "node_modules/data-urls": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/data-urls/-/data-urls-6.0.0.tgz", + "integrity": "sha512-BnBS08aLUM+DKamupXs3w2tJJoqU+AkaE/+6vQxi/G/DPmIZFJJp9Dkb1kM03AZx8ADehDUZgsNxju3mPXZYIA==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-mimetype": "^4.0.0", + "whatwg-url": "^15.0.0" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/date-fns": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-3.6.0.tgz", @@ -5638,6 +6208,13 @@ } } }, + "node_modules/decimal.js": { + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", + "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", + "dev": true, + "license": "MIT" + }, "node_modules/decimal.js-light": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz", @@ -5710,6 +6287,13 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/dom-accessibility-api": { + "version": "0.5.16", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", + "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", + "dev": true, + "license": "MIT" + }, "node_modules/dom-helpers": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz", @@ -5769,6 +6353,26 @@ "node": ">=10.13.0" } }, + "node_modules/entities": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/entities/-/entities-6.0.1.tgz", + "integrity": "sha512-aN97NXWF6AWBTahfVOIrB/NShkzi5H7F9r1s9mD3cDj4Ko5f2qhhVoYMibXF7GlLveb/D2ioWay8lxI97Ven3g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz", + "integrity": "sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==", + "dev": true, + "license": "MIT" + }, "node_modules/esbuild": { "version": "0.25.12", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.12.tgz", @@ -6011,6 +6615,16 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -6027,6 +6641,16 @@ "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", "license": "MIT" }, + "node_modules/expect-type": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.2.2.tgz", + "integrity": "sha512-JhFGDVJ7tmDJItKhYgJCGLOWjuK9vPxiXoUFLwLDc99NlmklilbiQJwoctZtt13+xMw91MCk/REan6MWHqDjyA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -6119,6 +6743,13 @@ "reusify": "^1.0.4" } }, + "node_modules/fflate": { + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", + "dev": true, + "license": "MIT" + }, "node_modules/file-entry-cache": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", @@ -6341,6 +6972,19 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/html-encoding-sniffer": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-4.0.0.tgz", + "integrity": "sha512-Y22oTqIU4uuPgEemfz7NDJz6OeKf12Lsu+QC+s3BVpda64lTiMYCyGwg5ki4vFxkMwQdeZDl2adZoqUgdFuTgQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-encoding": "^3.1.1" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/html-url-attributes": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", @@ -6351,6 +6995,34 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, "node_modules/iconv-lite": { "version": "0.6.3", "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz", @@ -6400,6 +7072,16 @@ "node": ">=0.8.19" } }, + "node_modules/indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/inline-style-parser": { "version": "0.2.6", "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.6.tgz", @@ -6514,6 +7196,13 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-potential-custom-element-name": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", + "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", + "dev": true, + "license": "MIT" + }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", @@ -6549,6 +7238,47 @@ "js-yaml": "bin/js-yaml.js" } }, + "node_modules/jsdom": { + "version": "27.2.0", + "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-27.2.0.tgz", + "integrity": "sha512-454TI39PeRDW1LgpyLPyURtB4Zx1tklSr6+OFOipsxGUH1WMTvk6C65JQdrj455+DP2uJ1+veBEHTGFKWVLFoA==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@acemir/cssom": "^0.9.23", + "@asamuzakjp/dom-selector": "^6.7.4", + "cssstyle": "^5.3.3", + "data-urls": "^6.0.0", + "decimal.js": "^10.6.0", + "html-encoding-sniffer": "^4.0.0", + "http-proxy-agent": "^7.0.2", + "https-proxy-agent": "^7.0.6", + "is-potential-custom-element-name": "^1.0.1", + "parse5": "^8.0.0", + "saxes": "^6.0.0", + "symbol-tree": "^3.2.4", + "tough-cookie": "^6.0.0", + "w3c-xmlserializer": "^5.0.0", + "webidl-conversions": "^8.0.0", + "whatwg-encoding": "^3.1.1", + "whatwg-mimetype": "^4.0.0", + "whatwg-url": "^15.1.0", + "ws": "^8.18.3", + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "canvas": "^3.0.0" + }, + "peerDependenciesMeta": { + "canvas": { + "optional": true + } + } + }, "node_modules/jsesc": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", @@ -6939,6 +7669,16 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/lz-string": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", + "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", + "dev": true, + "license": "MIT", + "bin": { + "lz-string": "bin/bin.js" + } + }, "node_modules/magic-string": { "version": "0.30.21", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz", @@ -7113,6 +7853,13 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/mdn-data": { + "version": "2.12.2", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.12.2.tgz", + "integrity": "sha512-IEn+pegP1aManZuckezWCO+XZQDplx1366JoVhTpMpBB1sPey/SbveZQUosKiKiGYjg1wH4pMlNgXbCiYgihQA==", + "dev": true, + "license": "CC0-1.0" + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -7579,6 +8326,16 @@ "node": ">=8.6" } }, + "node_modules/min-indent": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz", + "integrity": "sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -7607,6 +8364,16 @@ "integrity": "sha512-eAWoPgr4eFEOFfg2WjIsMoqJTW6Z8MTUCgn/GZ3VRpClWBdnbjryiA3ZSNLyxCTmCQx4RmYX6jX1iWHbenUPNQ==", "license": "MIT" }, + "node_modules/mrmime": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.1.tgz", + "integrity": "sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + } + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -7664,6 +8431,17 @@ "node": ">=0.10.0" } }, + "node_modules/obug": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz", + "integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, "node_modules/octokit": { "version": "4.1.4", "resolved": "https://registry.npmjs.org/octokit/-/octokit-4.1.4.tgz", @@ -7774,6 +8552,19 @@ "integrity": "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==", "license": "MIT" }, + "node_modules/parse5": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/parse5/-/parse5-8.0.0.tgz", + "integrity": "sha512-9m4m5GSgXjL4AjumKzq1Fgfp3Z8rsvjRNbnkVwfu2ImRqE5D0LnY2QfDen18FSY9C573YU5XxSapdHZTZ2WolA==", + "dev": true, + "license": "MIT", + "dependencies": { + "entities": "^6.0.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -7794,6 +8585,13 @@ "node": ">=8" } }, + "node_modules/pathe": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", + "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", + "dev": true, + "license": "MIT" + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -7851,6 +8649,41 @@ "node": ">= 0.8.0" } }, + "node_modules/pretty-format": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz", + "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1", + "ansi-styles": "^5.0.0", + "react-is": "^17.0.1" + }, + "engines": { + "node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0" + } + }, + "node_modules/pretty-format/node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/pretty-format/node_modules/react-is": { + "version": "17.0.2", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", + "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", + "dev": true, + "license": "MIT" + }, "node_modules/prop-types": { "version": "15.8.1", "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", @@ -8215,6 +9048,20 @@ "decimal.js-light": "^2.4.1" } }, + "node_modules/redent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/redent/-/redent-3.0.0.tgz", + "integrity": "sha512-6tDA8g98We0zd0GvVeMT9arEOnTw9qM03L9cJXaCjrip1OO764RDBLBfrB4cwzNGDj5OA5ioymC9GkizgWJDUg==", + "dev": true, + "license": "MIT", + "dependencies": { + "indent-string": "^4.0.0", + "strip-indent": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/remark-parse": { "version": "11.0.0", "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", @@ -8248,6 +9095,16 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/resolve-from": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", @@ -8352,6 +9209,19 @@ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", "license": "MIT" }, + "node_modules/saxes": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/saxes/-/saxes-6.0.0.tgz", + "integrity": "sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==", + "dev": true, + "license": "ISC", + "dependencies": { + "xmlchars": "^2.2.0" + }, + "engines": { + "node": ">=v12.22.7" + } + }, "node_modules/scheduler": { "version": "0.27.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", @@ -8397,6 +9267,28 @@ "node": ">=8" } }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, + "node_modules/sirv": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.2.tgz", + "integrity": "sha512-2wcC/oGxHis/BoHkkPwldgiPSYcpZK3JU28WoMVv55yHJgcZ8rlXvuG9iZggz+sU1d4bRgIGASwyWqjxu3FM0g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@polka/url": "^1.0.0-next.24", + "mrmime": "^2.0.0", + "totalist": "^3.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/sonner": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/sonner/-/sonner-2.0.7.tgz", @@ -8426,6 +9318,20 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.10.0.tgz", + "integrity": "sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==", + "dev": true, + "license": "MIT" + }, "node_modules/stringify-entities": { "version": "4.0.4", "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", @@ -8440,6 +9346,19 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/strip-indent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-3.0.0.tgz", + "integrity": "sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "min-indent": "^1.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/strip-json-comments": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", @@ -8490,6 +9409,13 @@ "node": ">=8" } }, + "node_modules/symbol-tree": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/symbol-tree/-/symbol-tree-3.2.4.tgz", + "integrity": "sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==", + "dev": true, + "license": "MIT" + }, "node_modules/tailwind-merge": { "version": "3.3.1", "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.3.1.tgz", @@ -8532,6 +9458,23 @@ "integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==", "license": "MIT" }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.2.tgz", + "integrity": "sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -8578,6 +9521,36 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/tinyrainbow": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.0.3.tgz", + "integrity": "sha512-PSkbLUoxOFRzJYjjxHJt9xro7D+iilgMX/C9lawzVuYiIdcihh9DXmVibBe8lmcFrRi/VzlPjBxbN7rH24q8/Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tldts": { + "version": "7.0.19", + "resolved": "https://registry.npmjs.org/tldts/-/tldts-7.0.19.tgz", + "integrity": "sha512-8PWx8tvC4jDB39BQw1m4x8y5MH1BcQ5xHeL2n7UVFulMPH/3Q0uiamahFJ3lXA0zO2SUyRXuVVbWSDmstlt9YA==", + "dev": true, + "license": "MIT", + "dependencies": { + "tldts-core": "^7.0.19" + }, + "bin": { + "tldts": "bin/cli.js" + } + }, + "node_modules/tldts-core": { + "version": "7.0.19", + "resolved": "https://registry.npmjs.org/tldts-core/-/tldts-core-7.0.19.tgz", + "integrity": "sha512-lJX2dEWx0SGH4O6p+7FPwYmJ/bu1JbcGJ8RLaG9b7liIgZ85itUVEPbMtWRVrde/0fnDPEPHW10ZsKW3kVsE9A==", + "dev": true, + "license": "MIT" + }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -8600,6 +9573,42 @@ "node": ">=12" } }, + "node_modules/totalist": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", + "integrity": "sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/tough-cookie": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-6.0.0.tgz", + "integrity": "sha512-kXuRi1mtaKMrsLUxz3sQYvVl37B0Ns6MzfrtV5DvJceE9bPyspOqk9xxv7XbZWcfLWbFmm997vl83qUWVJA64w==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "tldts": "^7.0.5" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/tr46": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-6.0.0.tgz", + "integrity": "sha512-bLVMLPtstlZ4iMQHpFHTR7GAGj2jxi8Dg0s2h2MafAE4uSWF98FC/3MomU51iQAMf8/qDUbKWf5GxuvvVcXEhw==", + "dev": true, + "license": "MIT", + "dependencies": { + "punycode": "^2.3.1" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/trim-lines": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", @@ -9073,12 +10082,164 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/vitest": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.0.15.tgz", + "integrity": "sha512-n1RxDp8UJm6N0IbJLQo+yzLZ2sQCDyl1o0LeugbPWf8+8Fttp29GghsQBjYJVmWq3gBFfe9Hs1spR44vovn2wA==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@vitest/expect": "4.0.15", + "@vitest/mocker": "4.0.15", + "@vitest/pretty-format": "4.0.15", + "@vitest/runner": "4.0.15", + "@vitest/snapshot": "4.0.15", + "@vitest/spy": "4.0.15", + "@vitest/utils": "4.0.15", + "es-module-lexer": "^1.7.0", + "expect-type": "^1.2.2", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^3.10.0", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.0.3", + "vite": "^6.0.0 || ^7.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.0.15", + "@vitest/browser-preview": "4.0.15", + "@vitest/browser-webdriverio": "4.0.15", + "@vitest/ui": "4.0.15", + "happy-dom": "*", + "jsdom": "*" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + } + } + }, + "node_modules/vitest/node_modules/picomatch": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, "node_modules/w3c-keyname": { "version": "2.2.8", "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", "license": "MIT" }, + "node_modules/w3c-xmlserializer": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz", + "integrity": "sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==", + "dev": true, + "license": "MIT", + "dependencies": { + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/webidl-conversions": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.0.tgz", + "integrity": "sha512-n4W4YFyz5JzOfQeA8oN7dUYpR+MBP3PIUsn2jLjWXwK5ASUzt0Jc/A5sAUZoCYFJRGF0FBKJ+1JjN43rNdsQzA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-encoding": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/whatwg-encoding/-/whatwg-encoding-3.1.1.tgz", + "integrity": "sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "iconv-lite": "0.6.3" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/whatwg-mimetype": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-4.0.0.tgz", + "integrity": "sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/whatwg-url": { + "version": "15.1.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-15.1.0.tgz", + "integrity": "sha512-2ytDk0kiEj/yu90JOAp44PVPUkO9+jVhyf+SybKlRHSDlvOOZhdPIrr7xTH64l4WixO2cP+wQIcgujkGBPPz6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "tr46": "^6.0.0", + "webidl-conversions": "^8.0.0" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -9095,6 +10256,23 @@ "node": ">= 8" } }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/word-wrap": { "version": "1.2.5", "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", @@ -9105,6 +10283,45 @@ "node": ">=0.10.0" } }, + "node_modules/ws": { + "version": "8.18.3", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", + "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xml-name-validator": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-5.0.0.tgz", + "integrity": "sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18" + } + }, + "node_modules/xmlchars": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz", + "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==", + "dev": true, + "license": "MIT" + }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", diff --git a/project/frontend/package.json b/project/frontend/package.json index 3134beb..170a4c7 100644 --- a/project/frontend/package.json +++ b/project/frontend/package.json @@ -9,7 +9,10 @@ "build": "tsc -b --noCheck && vite build", "lint": "eslint .", "optimize": "vite optimize", - "preview": "vite preview" + "preview": "vite preview", + "test": "vitest --run", + "test:watch": "vitest", + "test:ui": "vitest --ui" }, "dependencies": { "@codemirror/lang-python": "^6.2.1", @@ -83,18 +86,23 @@ "devDependencies": { "@eslint/js": "^9.21.0", "@tailwindcss/postcss": "^4.1.8", + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.0", "@types/react": "^19.0.10", "@types/react-dom": "^19.0.4", "@vitejs/plugin-react": "^4.3.4", "@vitejs/plugin-react-swc": "^3.10.1", + "@vitest/ui": "^4.0.15", "eslint": "^9.28.0", "eslint-plugin-react-hooks": "^5.2.0", "eslint-plugin-react-refresh": "^0.4.19", "globals": "^16.0.0", + "jsdom": "^27.2.0", "tailwindcss": "^4.1.11", "typescript": "~5.7.2", "typescript-eslint": "^8.38.0", - "vite": "^6.3.5" + "vite": "^6.3.5", + "vitest": "^4.0.15" }, "workspaces": { "packages": [ diff --git a/project/frontend/src/App.tsx b/project/frontend/src/App.tsx index 593f9d9..6260953 100644 --- a/project/frontend/src/App.tsx +++ b/project/frontend/src/App.tsx @@ -14,7 +14,7 @@ import { LandingPage } from './landing' function ProjectCanvas() { const { projectId } = useParams<{ projectId: string }>() const navigate = useNavigate() - const { setNodes, setEdges, loadProject, currentProject, reset } = useModelBuilderStore() + const { setNodes, setEdges, loadProject, loadGroupDefinitions, currentProject, reset } = useModelBuilderStore() const [isLoading, setIsLoading] = useState(false) const [draggedType, setDraggedType] = useState(null) const { selectedNodeId } = useModelBuilderStore() @@ -28,9 +28,14 @@ function ProjectCanvas() { .then(async (backendProject) => { // Load architecture if it exists try { - const { nodes, edges } = await loadArchitecture(projectId) + const { nodes, edges, groupDefinitions } = await loadArchitecture(projectId) const project = convertToFrontendProject(backendProject, nodes, edges) loadProject(project) + + // Load group definitions if they exist + if (groupDefinitions && groupDefinitions.length > 0) { + loadGroupDefinitions(groupDefinitions) + } } catch (error) { // No architecture yet, just load project metadata const project = convertToFrontendProject(backendProject) @@ -48,7 +53,7 @@ function ProjectCanvas() { setIsLoading(false) }) } - }, [projectId, currentProject, setNodes, setEdges, loadProject, navigate]) + }, [projectId, currentProject, setNodes, setEdges, loadProject, loadGroupDefinitions, navigate]) const handleDragStart = (type: string) => { setDraggedType(type) diff --git a/project/frontend/src/components/BlockDefinitionContextMenu.tsx b/project/frontend/src/components/BlockDefinitionContextMenu.tsx new file mode 100644 index 0000000..4bb76d5 --- /dev/null +++ b/project/frontend/src/components/BlockDefinitionContextMenu.tsx @@ -0,0 +1,97 @@ +import { useEffect, useRef } from 'react' +import { Card } from './ui/card' +import * as Icons from '@phosphor-icons/react' + +interface BlockDefinitionContextMenuProps { + x: number + y: number + definitionId: string + definitionName: string + instanceCount: number + onClose: () => void + onRename: (definitionId: string) => void + onDuplicate: (definitionId: string) => void + onDelete: (definitionId: string) => void +} + +export function BlockDefinitionContextMenu({ + x, + y, + definitionId, + definitionName, + instanceCount, + onClose, + onRename, + onDuplicate, + onDelete +}: BlockDefinitionContextMenuProps) { + const menuRef = useRef(null) + + useEffect(() => { + const handleClickOutside = (e: Event) => { + const target = e.target as Node + if (menuRef.current && !menuRef.current.contains(target)) { + onClose() + } + } + + const timeoutId = setTimeout(() => { + document.addEventListener('pointerdown', handleClickOutside, true) + }, 100) + + return () => { + clearTimeout(timeoutId) + document.removeEventListener('pointerdown', handleClickOutside, true) + } + }, [onClose]) + + return ( + +
+ {definitionName} +
+ + + + + +
+ + + + ) +} diff --git a/project/frontend/src/components/BlockNode.tsx b/project/frontend/src/components/BlockNode.tsx index f932777..e483351 100644 --- a/project/frontend/src/components/BlockNode.tsx +++ b/project/frontend/src/components/BlockNode.tsx @@ -27,6 +27,7 @@ const BlockNode = memo(({ data, selected, id }: BlockNodeProps) => { const nodeDef = getNodeDefinition(data.blockType as BlockType, BackendFramework.PyTorch) const validationErrors = useModelBuilderStore((state) => state.validationErrors) const edges = useModelBuilderStore((state) => state.edges) + const hasConfigOverrides = useModelBuilderStore((state) => state.hasConfigOverrides) if (!nodeDef) return null @@ -37,6 +38,14 @@ const BlockNode = memo(({ data, selected, id }: BlockNodeProps) => { const nodeErrors = validationErrors.filter((error) => error.nodeId === id && error.type === 'error') const hasErrors = nodeErrors.length > 0 + // Check if this is an expanded internal node with overrides + const isExpandedInternal = (data as any)._isExpandedInternal === true + const parentGroupNodeId = (data as any)._expandedFrom as string | undefined + const internalNodeId = id.split('-expanded-')[0] + const hasOverrides = isExpandedInternal && parentGroupNodeId + ? hasConfigOverrides(parentGroupNodeId, internalNodeId) + : false + // Helper to check if a handle is already connected const isHandleConnected = (handleId: string, isTarget: boolean) => { return edges.some(edge => { @@ -71,6 +80,27 @@ const BlockNode = memo(({ data, selected, id }: BlockNodeProps) => {
)} + {/* Override Badge - for expanded internal nodes with customizations */} + {hasOverrides && !hasErrors && ( +
+ + + +
+ +
+
+ +
+
Customized
+
This node has custom configuration
+
+
+
+
+
+ )} + {/* Action Buttons - Only shown when selected */} {selected && (
@@ -233,7 +263,7 @@ const BlockNode = memo(({ data, selected, id }: BlockNodeProps) => { return shapes })()} - {!data.outputShape && data.blockType !== 'input' && data.blockType !== 'dataloader' && data.blockType !== 'empty' && ( + {!data.outputShape && data.blockType !== 'input' && data.blockType !== 'dataloader' && data.blockType !== 'empty' && data.blockType !== 'output' && data.blockType !== 'loss' && (
Configure params
diff --git a/project/frontend/src/components/BlockPalette.tsx b/project/frontend/src/components/BlockPalette.tsx index 1d681c7..d2aeede 100644 --- a/project/frontend/src/components/BlockPalette.tsx +++ b/project/frontend/src/components/BlockPalette.tsx @@ -4,6 +4,11 @@ import { Accordion, AccordionContent, AccordionItem, AccordionTrigger } from '@/ import { Card } from '@/components/ui/card' import { Input } from '@/components/ui/input' import { getAllNodeDefinitions, getNodeDefinitionsByCategory, BackendFramework } from '@/lib/nodes/registry' +import { useModelBuilderStore } from '@/lib/store' +import { BlockDefinitionContextMenu } from './BlockDefinitionContextMenu' +import RenameBlockDialog from './RenameBlockDialog' +import DeleteBlockDialog from './DeleteBlockDialog' +import { toast } from 'sonner' import * as Icons from '@phosphor-icons/react' import Fuse from 'fuse.js' @@ -15,6 +20,27 @@ interface BlockPaletteProps { export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: BlockPaletteProps) { const [searchQuery, setSearchQuery] = useState('') + const [contextMenu, setContextMenu] = useState<{ + x: number + y: number + definitionId: string + definitionName: string + } | null>(null) + const [renameDialog, setRenameDialog] = useState<{ + definitionId: string + currentName: string + } | null>(null) + const [deleteDialog, setDeleteDialog] = useState<{ + definitionId: string + blockName: string + instanceCount: number + } | null>(null) + + const groupDefinitions = useModelBuilderStore((state) => state.groupDefinitions) + const nodes = useModelBuilderStore((state) => state.nodes) + const renameGroupDefinition = useModelBuilderStore((state) => state.renameGroupDefinition) + const deleteGroupDefinition = useModelBuilderStore((state) => state.deleteGroupDefinition) + const duplicateGroupDefinition = useModelBuilderStore((state) => state.duplicateGroupDefinition) const categories = [ { key: 'input', label: 'Input & Data', icon: Icons.DownloadSimple }, @@ -23,12 +49,13 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: { key: 'advanced', label: 'Advanced Layers', icon: Icons.CubeFocus }, { key: 'merge', label: 'Operations', icon: Icons.Unite }, { key: 'output', label: 'Output & Loss', icon: Icons.UploadSimple }, - { key: 'utility', label: 'Utility', icon: Icons.Wrench } + { key: 'utility', label: 'Utility', icon: Icons.Wrench }, + { key: 'custom', label: 'Custom Blocks', icon: Icons.Package } ] // Prepare all blocks for fuzzy search - maintain category order const allBlocks = useMemo(() => { - const categoryOrder = ['input', 'basic', 'activation', 'advanced', 'merge', 'output', 'utility'] + const categoryOrder = ['input', 'basic', 'activation', 'advanced', 'merge', 'output', 'utility', 'custom'] const nodes = getAllNodeDefinitions(BackendFramework.PyTorch) // Group by category @@ -61,14 +88,26 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: category: node.metadata.category, color: node.metadata.color, icon: node.metadata.icon, - description: node.metadata.description + description: node.metadata.description, + isGroup: false + })) + + // Add custom group blocks + const groupBlocks = Array.from(groupDefinitions.values()).map(def => ({ + type: `group:${def.id}`, + label: def.name, + category: 'custom', + color: def.color, + icon: 'SquaresFour', + description: def.description || `Custom block with ${def.internalNodes.length} nodes`, + isGroup: true, + groupDefinitionId: def.id })) - // Debug: log all icons - console.log('Block icons loaded:', blocks.map(b => `${b.label}: ${b.icon}`)) + blocks.push(...groupBlocks) return blocks - }, []) + }, [groupDefinitions]) // Setup fuzzy search const fuse = useMemo(() => { @@ -94,6 +133,60 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: onDragStart(type) } + const handleContextMenu = (e: React.MouseEvent, block: any) => { + if (!block.isGroup) return + + e.preventDefault() + e.stopPropagation() + + setContextMenu({ + x: e.clientX, + y: e.clientY, + definitionId: block.groupDefinitionId, + definitionName: block.label + }) + } + + const handleRename = (definitionId: string) => { + const definition = groupDefinitions.get(definitionId) + if (!definition) return + + setRenameDialog({ + definitionId, + currentName: definition.name + }) + } + + const handleDuplicate = (definitionId: string) => { + const definition = groupDefinitions.get(definitionId) + const newId = duplicateGroupDefinition(definitionId) + if (newId && definition) { + toast.success('Block duplicated', { + description: `Created copy of "${definition.name}"` + }) + } + } + + const handleDelete = (definitionId: string) => { + const definition = groupDefinitions.get(definitionId) + if (!definition) return + + // Count instances on canvas + const instanceCount = nodes.filter(node => { + if (node.data.blockType === 'group') { + const groupData = node.data as any + return groupData.groupDefinitionId === definitionId + } + return false + }).length + + setDeleteDialog({ + definitionId, + blockName: definition.name, + instanceCount + }) + } + const renderBlockCard = (block: { type: string label: string @@ -101,6 +194,8 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: color: string icon: string description: string + isGroup?: boolean + groupDefinitionId?: string }) => { const IconComponent = (Icons as any)[block.icon] @@ -121,13 +216,16 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: handleDragStart(block.type) }} onClick={() => onBlockClick(block.type)} + onContextMenu={(e) => handleContextMenu(e, block)} >
@@ -147,126 +245,263 @@ export default function BlockPalette({ onDragStart, onBlockClick, isCollapsed }: if (isCollapsed) { return ( -
- {/* Scrollable Block Icons */} - -
- {allBlocks.map((block) => { - const IconComponent = (Icons as any)[block.icon] - - // Debug: log if icon is missing - if (!IconComponent && block.icon) { - console.warn(`Icon "${block.icon}" not found for block "${block.label}" (${block.type})`) - } + <> +
+ {/* Scrollable Block Icons */} + +
+ {allBlocks.map((block: any) => { + const IconComponent = (Icons as any)[block.icon] + + // Debug: log if icon is missing + if (!IconComponent && block.icon) { + console.warn(`Icon "${block.icon}" not found for block "${block.label}" (${block.type})`) + } - const FinalIcon = IconComponent || Icons.Cube - - return ( - - ) - })} -
-
-
+ {/* Tooltip on hover */} +
+ {block.label} +
+ + ) + })} +
+
+
+ + {/* Context Menu */} + {contextMenu && ( + { + if (node.data.blockType === 'group') { + const groupData = node.data as any + return groupData.groupDefinitionId === contextMenu.definitionId + } + return false + }).length} + onClose={() => setContextMenu(null)} + onRename={handleRename} + onDuplicate={handleDuplicate} + onDelete={handleDelete} + /> + )} + + {/* Rename Dialog */} + {renameDialog && ( + setRenameDialog(null)} + onSave={(newName) => { + renameGroupDefinition(renameDialog.definitionId, newName) + toast.success('Block renamed', { + description: `Renamed "${renameDialog.currentName}" to "${newName}"` + }) + setRenameDialog(null) + }} + currentName={renameDialog.currentName} + existingNames={Array.from(groupDefinitions.values()).map(def => def.name)} + /> + )} + + {/* Delete Dialog */} + {deleteDialog && ( + setDeleteDialog(null)} + onConfirm={(cascade) => { + deleteGroupDefinition(deleteDialog.definitionId, cascade) + if (cascade && deleteDialog.instanceCount > 0) { + toast.success('Block deleted', { + description: `Deleted "${deleteDialog.blockName}" and ${deleteDialog.instanceCount} instance(s) from canvas` + }) + } else if (deleteDialog.instanceCount > 0) { + toast.warning('Definition deleted', { + description: `"${deleteDialog.blockName}" deleted but ${deleteDialog.instanceCount} instance(s) remain on canvas with errors` + }) + } else { + toast.success('Block deleted', { + description: `Deleted "${deleteDialog.blockName}"` + }) + } + setDeleteDialog(null) + }} + blockName={deleteDialog.blockName} + instanceCount={deleteDialog.instanceCount} + /> + )} + ) } return ( -
-
-
- - setSearchQuery(e.target.value)} - className="pl-9 h-9" - /> - {searchQuery && ( - - )} + <> +
+
+
+ + setSearchQuery(e.target.value)} + className="pl-9 h-9" + /> + {searchQuery && ( + + )} +
+ + +
+ {filteredBlocks !== null ? ( + // Search results view +
+ {filteredBlocks.length > 0 ? ( + filteredBlocks.map((block) => renderBlockCard(block)) + ) : ( +
+ +

No blocks found

+

Try a different search term

+
+ )} +
+ ) : ( + // Categorized view + + {categories.map((category) => { + const blocks = allBlocks.filter(b => b.category === category.key) + const CategoryIcon = category.icon + + return ( + + +
+ + {category.label} +
+
+ +
+ {blocks.map((block) => renderBlockCard(block))} +
+
+
+ ) + })} +
+ )} +
+
- -
- {filteredBlocks !== null ? ( - // Search results view -
- {filteredBlocks.length > 0 ? ( - filteredBlocks.map((block) => renderBlockCard(block)) - ) : ( -
- -

No blocks found

-

Try a different search term

-
- )} -
- ) : ( - // Categorized view - - {categories.map((category) => { - const blocks = allBlocks.filter(b => b.category === category.key) - const CategoryIcon = category.icon + {/* Context Menu */} + {contextMenu && ( + { + if (node.data.blockType === 'group') { + const groupData = node.data as any + return groupData.groupDefinitionId === contextMenu.definitionId + } + return false + }).length} + onClose={() => setContextMenu(null)} + onRename={handleRename} + onDuplicate={handleDuplicate} + onDelete={handleDelete} + /> + )} - return ( - - -
- - {category.label} -
-
- -
- {blocks.map((block) => renderBlockCard(block))} -
-
-
- ) - })} -
- )} -
-
-
+ {/* Rename Dialog */} + {renameDialog && ( + setRenameDialog(null)} + onSave={(newName) => { + renameGroupDefinition(renameDialog.definitionId, newName) + toast.success('Block renamed', { + description: `Renamed "${renameDialog.currentName}" to "${newName}"` + }) + setRenameDialog(null) + }} + currentName={renameDialog.currentName} + existingNames={Array.from(groupDefinitions.values()).map(def => def.name)} + /> + )} + + {/* Delete Dialog */} + {deleteDialog && ( + setDeleteDialog(null)} + onConfirm={(cascade) => { + deleteGroupDefinition(deleteDialog.definitionId, cascade) + if (cascade && deleteDialog.instanceCount > 0) { + toast.success('Block deleted', { + description: `Deleted "${deleteDialog.blockName}" and ${deleteDialog.instanceCount} instance(s) from canvas` + }) + } else if (deleteDialog.instanceCount > 0) { + toast.warning('Definition deleted', { + description: `"${deleteDialog.blockName}" deleted but ${deleteDialog.instanceCount} instance(s) remain on canvas with errors` + }) + } else { + toast.success('Block deleted', { + description: `Deleted "${deleteDialog.blockName}"` + }) + } + setDeleteDialog(null) + }} + blockName={deleteDialog.blockName} + instanceCount={deleteDialog.instanceCount} + /> + )} + ) } diff --git a/project/frontend/src/components/Canvas.tsx b/project/frontend/src/components/Canvas.tsx index ca8f560..4415de7 100644 --- a/project/frontend/src/components/Canvas.tsx +++ b/project/frontend/src/components/Canvas.tsx @@ -15,17 +15,23 @@ import { import '@xyflow/react/dist/style.css' import { useModelBuilderStore } from '@/lib/store' import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' -import { BlockData, BlockType } from '@/lib/types' +import { BlockData, BlockType, GroupBlockData } from '@/lib/types' import BlockNode from './BlockNode' +import GroupBlockNode from './GroupBlockNode' +import ExpandedGroupContainer from './ExpandedGroupContainer' import CustomConnectionLine from './CustomConnectionLine' import { HistoryToolbar } from './HistoryToolbar' import { ContextMenu } from './ContextMenu' import ViewCodeModal from './ViewCodeModal' +import GroupCreationDialog from './GroupCreationDialog' +import ValidationErrorsPanel from './ValidationErrorsPanel' import { renderNodeCode } from '@/lib/api' import { toast } from 'sonner' const nodeTypes = { - custom: BlockNode + custom: BlockNode, + group: GroupBlockNode, + expandedGroupContainer: ExpandedGroupContainer } interface CanvasProps { @@ -50,8 +56,11 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block duplicateNode, recentlyUsedNodes, validateConnection, + validateArchitecture, undo, - redo + redo, + groupDefinitions, + ungroupBlock } = useModelBuilderStore() const { screenToFlowPosition, getViewport } = useReactFlow() @@ -68,7 +77,15 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block const [isLoadingCode, setIsLoadingCode] = useState(false) const currentProject = useModelBuilderStore((state) => state.currentProject) - // Keyboard shortcuts for undo/redo/delete + // GroupCreationDialog state + const [isGroupDialogOpen, setIsGroupDialogOpen] = useState(false) + const [selectedNodesForGrouping, setSelectedNodesForGrouping] = useState([]) + const createGroupBlock = useModelBuilderStore((state) => state.createGroupBlock) + + // Validation is now triggered manually via the Validate button in Header + // Removed automatic validation on nodes/edges change + + // Keyboard shortcuts for undo/redo/delete/group/expand useEffect(() => { const handleKeyDown = (e: KeyboardEvent) => { // Check for Ctrl (Windows/Linux) or Cmd (Mac) @@ -80,6 +97,19 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block } else if (isMod && (e.key === 'y' || (e.key === 'z' && e.shiftKey))) { e.preventDefault() redo() + } else if (isMod && e.key === 'g' && !e.shiftKey) { + // Ctrl+G: Create group from selection + e.preventDefault() + const target = e.target as HTMLElement + if (target.tagName !== 'INPUT' && target.tagName !== 'TEXTAREA' && !target.isContentEditable) { + const selectedNodes = nodes.filter(n => n.selected) + if (selectedNodes.length >= 2) { + setSelectedNodesForGrouping(selectedNodes.map(n => n.id)) + setIsGroupDialogOpen(true) + } else { + toast.error('Select at least 2 nodes to create a group') + } + } } else if ((e.key === 'Delete' || e.key === 'Backspace')) { // Only delete if not typing in an input field const target = e.target as HTMLElement @@ -97,7 +127,7 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block window.addEventListener('keydown', handleKeyDown) return () => window.removeEventListener('keydown', handleKeyDown) - }, [undo, redo, removeNode, removeEdge, selectedNodeId, selectedEdgeId, setSelectedEdgeId]) + }, [undo, redo, removeNode, removeEdge, selectedNodeId, selectedEdgeId, setSelectedEdgeId, nodes]) // Find a suitable position for a new node const findAvailablePosition = useCallback(() => { @@ -128,11 +158,50 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block // Handle block click from palette useEffect(() => { const handleBlockClickInternal = (blockType: string) => { + const position = findAvailablePosition() + + // Check if it's a group block + if (blockType.startsWith('group:')) { + const groupId = blockType.substring(6) // Remove 'group:' prefix + const groupDef = useModelBuilderStore.getState().groupDefinitions.get(groupId) + + if (!groupDef) { + toast.error('Group definition not found') + return + } + + // Create a new instance of the group block + const groupNodeId = `group-block-${Date.now()}` + const groupNode = { + id: groupNodeId, + type: 'group', + position, + data: { + blockType: 'group', + label: groupDef.name, + config: {}, + category: groupDef.category, + groupDefinitionId: groupId, + isExpanded: false + } + } + + addNode(groupNode as any) + + setTimeout(() => { + useModelBuilderStore.getState().inferDimensions() + }, 0) + + toast.success(`Added ${groupDef.name}`, { + description: 'Group block instance added to canvas' + }) + return + } + + // Regular node const nodeDef = getNodeDefinition(blockType as BlockType, BackendFramework.PyTorch) if (!nodeDef) return - const position = findAvailablePosition() - const newNode = { id: `${blockType}-${Date.now()}`, type: 'custom', @@ -178,37 +247,75 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block const type = (window as any).draggedBlockTypeGlobal if (!type) return - const nodeDef = getNodeDefinition(type as BlockType, BackendFramework.PyTorch) - if (!nodeDef) return - const position = screenToFlowPosition({ x: event.clientX, y: event.clientY }) - const newNode = { - id: `${type}-${Date.now()}`, - type: 'custom', - position, - data: { - blockType: nodeDef.metadata.type, - label: nodeDef.metadata.label, - config: {}, - category: nodeDef.metadata.category - } as BlockData - } + // Check if it's a group block + if (type.startsWith('group:')) { + const groupId = type.substring(6) // Remove 'group:' prefix + const groupDef = useModelBuilderStore.getState().groupDefinitions.get(groupId) + + if (!groupDef) { + toast.error('Group definition not found') + return + } - nodeDef.configSchema.forEach((field) => { - if (field.default !== undefined) { - newNode.data.config[field.name] = field.default + // Create a new instance of the group block + const groupNodeId = `group-block-${Date.now()}` + const groupNode = { + id: groupNodeId, + type: 'group', + position, + data: { + blockType: 'group', + label: groupDef.name, + config: {}, + category: groupDef.category, + groupDefinitionId: groupId, + isExpanded: false + } } - }) - addNode(newNode) + addNode(groupNode as any) - setTimeout(() => { - useModelBuilderStore.getState().inferDimensions() - }, 0) + setTimeout(() => { + useModelBuilderStore.getState().inferDimensions() + }, 0) + + toast.success(`Added ${groupDef.name}`, { + description: 'Group block instance added to canvas' + }) + } else { + // Regular node + const nodeDef = getNodeDefinition(type as BlockType, BackendFramework.PyTorch) + if (!nodeDef) return + + const newNode = { + id: `${type}-${Date.now()}`, + type: 'custom', + position, + data: { + blockType: nodeDef.metadata.type, + label: nodeDef.metadata.label, + config: {}, + category: nodeDef.metadata.category + } as BlockData + } + + nodeDef.configSchema.forEach((field) => { + if (field.default !== undefined) { + newNode.data.config[field.name] = field.default + } + }) + + addNode(newNode) + + setTimeout(() => { + useModelBuilderStore.getState().inferDimensions() + }, 0) + } ;(window as any).draggedBlockTypeGlobal = null }, @@ -643,12 +750,14 @@ function FlowCanvas({ onRegisterAddNode }: { onRegisterAddNode: (handler: (block y={contextMenu.y} type={contextMenu.type} nodeId={contextMenu.nodeId} + isGroupBlock={contextMenu.nodeId ? nodes.find(n => n.id === contextMenu.nodeId)?.data.blockType === 'group' : false} recentlyUsedNodes={recentlyUsedNodes} onClose={() => setContextMenu(null)} onAddNode={handleAddNodeFromContextMenu} onDeleteNode={removeNode} onDuplicateNode={duplicateNode} onReplicateNode={handleReplicateNode} + onUngroupNode={ungroupBlock} /> )} + { + setIsGroupDialogOpen(false) + setSelectedNodesForGrouping([]) + }} + onSave={(config) => { + // Pass the full config including portMappings to createGroupBlock + const result = createGroupBlock(selectedNodesForGrouping, config) + + // Only show toast if creation succeeded (result is a valid node ID) + if (result) { + toast.success(`Created block: ${config.name}`) + setIsGroupDialogOpen(false) + setSelectedNodesForGrouping([]) + } else { + toast.error('Failed to create block', { + description: 'Check console for validation errors' + }) + } + }} + selectedNodeIds={selectedNodesForGrouping} + /> +
) } diff --git a/project/frontend/src/components/ConfigPanel.tsx b/project/frontend/src/components/ConfigPanel.tsx index 7c961a0..de1a4f5 100644 --- a/project/frontend/src/components/ConfigPanel.tsx +++ b/project/frontend/src/components/ConfigPanel.tsx @@ -11,14 +11,28 @@ import { Card } from '@/components/ui/card' import { X, Code, UploadSimple } from '@phosphor-icons/react' import { toast } from 'sonner' import CustomLayerModal from './CustomLayerModal' +import InternalNodeConfigPanel from './InternalNodeConfigPanel' export default function ConfigPanel() { - const { nodes, selectedNodeId, updateNode, setSelectedNodeId, removeNode } = useModelBuilderStore() + const { nodes, selectedNodeId, updateNode, setSelectedNodeId, removeNode, repeatGroupBlock, groupDefinitions } = useModelBuilderStore() const [isCustomModalOpen, setIsCustomModalOpen] = useState(false) + const [repeatCount, setRepeatCount] = useState(2) + const [repeatSpacingX, setRepeatSpacingX] = useState(300) + const [repeatSpacingY, setRepeatSpacingY] = useState(0) const fileInputRefs = useRef<{ [key: string]: HTMLInputElement | null }>({}) const selectedNode = nodes.find((n) => n.id === selectedNodeId) + // Check if selected node is an expanded internal node + const isExpandedInternal = selectedNode?.data?._isExpandedInternal === true + const parentGroupNodeId = selectedNode?.data?._expandedFrom as string | undefined + const groupDefinitionId = selectedNode?.data?._groupDefinitionId as string | undefined + + // Extract the original internal node ID from the expanded node ID + // Format: originalId-expanded-timestamp + // We need to split and take everything before the last "-expanded-" occurrence + const internalNodeId = selectedNode?.id ? selectedNode.id.substring(0, selectedNode.id.lastIndexOf('-expanded-')) : undefined + const handleCustomLayerSave = (config: { name: string code: string @@ -42,6 +56,28 @@ export default function ConfigPanel() { } }, [selectedNode?.id, selectedNode?.data.blockType]) + // Handle expanded internal node configuration + if (isExpandedInternal && parentGroupNodeId && groupDefinitionId && internalNodeId) { + // Debug logging + console.log('ConfigPanel - Routing to InternalNodeConfigPanel:', { + selectedNodeId: selectedNode.id, + parentGroupNodeId, + groupDefinitionId, + internalNodeId, + isExpandedInternal + }) + + return ( + setSelectedNodeId(null)} + /> + ) + } + // For custom blocks, don't show the sidebar at all - only the modal if (selectedNode?.data.blockType === 'custom') { return ( @@ -62,6 +98,13 @@ export default function ConfigPanel() { ) } + // Define handleDelete early so it can be used in all sections + const handleDelete = () => { + if (selectedNode) { + removeNode(selectedNode.id) + } + } + if (!selectedNode) { return (
@@ -72,6 +115,133 @@ export default function ConfigPanel() { ) } + // Handle group blocks separately + if (selectedNode.data.blockType === 'group') { + const groupData = selectedNode.data as any + const groupDef = groupDefinitions.get(groupData.groupDefinitionId) + + const handleRepeat = () => { + if (repeatCount < 1 || repeatCount > 10) { + toast.error('Repeat count must be between 1 and 10') + return + } + const newNodeIds = repeatGroupBlock(selectedNode.id, repeatCount, repeatSpacingX, repeatSpacingY) + toast.success(`Created ${repeatCount} copies`, { + description: `${repeatCount} instances added to canvas` + }) + } + + return ( +
+
+
+

{groupDef?.name || 'Group Block'}

+

Group block configuration

+
+ +
+ +
+
+ {groupDef && ( + <> + +
Group Information
+
+
Category: {groupDef.category}
+
Internal Nodes: {groupDef.internalNodes.length}
+
Inputs: {groupDef.portMappings.filter(p => p.type === 'input').length}
+
Outputs: {groupDef.portMappings.filter(p => p.type === 'output').length}
+
+ {groupDef.description && ( +
{groupDef.description}
+ )} +
+ +
+
Repeat Block
+
+
+ + setRepeatCount(parseInt(e.target.value) || 1)} + placeholder="Enter count" + /> +

Create 1-10 copies

+
+
+ + setRepeatSpacingX(parseInt(e.target.value) || 0)} + placeholder="Enter spacing" + /> +
+
+ + setRepeatSpacingY(parseInt(e.target.value) || 0)} + placeholder="Enter spacing" + /> +
+ +
+
+ + )} + + {selectedNode.data.inputShape && ( + +
Input Shape
+
+ [{selectedNode.data.inputShape.dims.join(', ')}] +
+
+ )} + + {selectedNode.data.outputShape && ( + +
Output Shape
+
+ [{selectedNode.data.outputShape.dims.join(', ')}] +
+
+ )} +
+
+ +
+ +
+
+ ) + } + const nodeDef = getNodeDefinition(selectedNode.data.blockType, BackendFramework.PyTorch) if (!nodeDef) return null @@ -99,10 +269,6 @@ export default function ConfigPanel() { } } - const handleDelete = () => { - removeNode(selectedNode.id) - } - const handleFileUpload = async (fieldName: string, file: File) => { try { // Read file as base64 for storage diff --git a/project/frontend/src/components/ContextMenu.tsx b/project/frontend/src/components/ContextMenu.tsx index 4edd7e8..f93d87a 100644 --- a/project/frontend/src/components/ContextMenu.tsx +++ b/project/frontend/src/components/ContextMenu.tsx @@ -9,12 +9,14 @@ interface ContextMenuProps { y: number type: 'canvas' | 'node' nodeId?: string + isGroupBlock?: boolean recentlyUsedNodes?: BlockType[] onClose: () => void onAddNode?: (nodeType: BlockType, x: number, y: number) => void onDeleteNode?: (nodeId: string) => void onDuplicateNode?: (nodeId: string) => void onReplicateNode?: (nodeId: string) => void + onUngroupNode?: (nodeId: string) => void } export function ContextMenu({ @@ -22,12 +24,14 @@ export function ContextMenu({ y, type, nodeId, + isGroupBlock = false, recentlyUsedNodes = [], onClose, onAddNode, onDeleteNode, onDuplicateNode, - onReplicateNode + onReplicateNode, + onUngroupNode }: ContextMenuProps) { const menuRef = useRef(null) @@ -118,6 +122,21 @@ export function ContextMenu({ Replicate as Custom + {isGroupBlock && ( + <> +
+ + + )}
+ + + + + ) +} diff --git a/project/frontend/src/components/ExpandedGroupContainer.tsx b/project/frontend/src/components/ExpandedGroupContainer.tsx new file mode 100644 index 0000000..7d6d8c3 --- /dev/null +++ b/project/frontend/src/components/ExpandedGroupContainer.tsx @@ -0,0 +1,76 @@ +import { memo } from 'react' +import { NodeProps } from '@xyflow/react' +import { useModelBuilderStore } from '@/lib/store' +import * as Icons from '@phosphor-icons/react' +import { Button } from '@/components/ui/button' +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from '@/components/ui/tooltip' + +interface ExpandedGroupContainerData { + _expandedFrom: string + _groupDefinitionId: string + groupName: string + groupColor: string +} + +const ExpandedGroupContainer = memo(({ data, id }: NodeProps) => { + const toggleGroupExpansion = useModelBuilderStore((state) => state.toggleGroupExpansion) + + const handleCollapse = (e: React.MouseEvent) => { + e.stopPropagation() + // Use the _expandedFrom ID to collapse the group + toggleGroupExpansion(data._expandedFrom) + } + + return ( +
+ {/* Collapse button in top right */} +
+ + + + + + Collapse {data.groupName} + + +
+ + {/* Group name label in top left */} +
+ {data.groupName} +
+
+ ) +}) + +ExpandedGroupContainer.displayName = 'ExpandedGroupContainer' + +export default ExpandedGroupContainer diff --git a/project/frontend/src/components/GroupBlockNode.tsx b/project/frontend/src/components/GroupBlockNode.tsx new file mode 100644 index 0000000..2899d4d --- /dev/null +++ b/project/frontend/src/components/GroupBlockNode.tsx @@ -0,0 +1,327 @@ +import { memo } from 'react' +import { Handle, Position, NodeProps } from '@xyflow/react' +import { GroupBlockData, PortMapping } from '@/lib/types' +import { useModelBuilderStore } from '@/lib/store' +import * as Icons from '@phosphor-icons/react' +import { Card } from '@/components/ui/card' +import { Badge } from '@/components/ui/badge' +import { Button } from '@/components/ui/button' +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from '@/components/ui/tooltip' + +interface GroupBlockNodeProps { + data: GroupBlockData + selected?: boolean + id: string +} + +const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => { + const validationErrors = useModelBuilderStore((state) => state.validationErrors) + const edges = useModelBuilderStore((state) => state.edges) + const groupDefinitions = useModelBuilderStore((state) => state.groupDefinitions) + const toggleGroupExpansion = useModelBuilderStore((state) => state.toggleGroupExpansion) + + const groupDef = groupDefinitions.get(data.groupDefinitionId) + if (!groupDef) return null + + const nodeErrors = validationErrors.filter((error) => error.nodeId === id && error.type === 'error') + const hasErrors = nodeErrors.length > 0 + + // Check if this group instance has any configuration overrides + const hasCustomizations = data.instanceConfigOverrides && Object.keys(data.instanceConfigOverrides).length > 0 + const customizationCount = hasCustomizations ? Object.keys(data.instanceConfigOverrides!).length : 0 + + const isHandleConnected = (handleId: string, isTarget: boolean) => { + return edges.some(edge => { + if (isTarget) { + return edge.target === id && (edge.targetHandle || 'default') === handleId + } else { + return edge.source === id && (edge.sourceHandle || 'default') === handleId + } + }) + } + + const inputPorts = groupDef.portMappings.filter(p => p.type === 'input') + const outputPorts = groupDef.portMappings.filter(p => p.type === 'output') + + const getPortColor = (semantic: string) => { + const colors: Record = { + 'data': '#3b82f6', + 'labels': '#10b981', + 'loss': '#ef4444', + 'predictions': '#8b5cf6', + 'anchor': '#ec4899', + 'positive': '#f59e0b', + 'negative': '#f43f5e', + 'input1': '#06b6d4', + 'input2': '#8b5cf6', + 'weights': '#6366f1' + } + return colors[semantic] || '#3b82f6' + } + + return ( + + {/* Error Badge */} + {hasErrors && ( +
+
+ +
+
+ )} + + {/* Repetition Badge */} + {data.repetitionMetadata && ( +
+ + {data.repetitionMetadata.index + 1}/{data.repetitionMetadata.totalCount} + +
+ )} + + {/* Customization Badge */} + {hasCustomizations && !hasErrors && ( +
+ + + + + + {customizationCount} + + + +
+
Customized Instance
+
{customizationCount} internal node{customizationCount > 1 ? 's' : ''} customized
+
+
+
+
+
+ )} + + {/* Action Buttons */} + {selected && ( +
+ + + + + + {data.isExpanded ? 'Collapse' : 'Expand'} (Space) + + +
+ )} + + {/* Render input handles */} + {inputPorts.map((port, index) => { + const spacing = 100 / (inputPorts.length + 1) + const topPercent = spacing * (index + 1) + const color = getPortColor(port.semantic) + const isConnected = isHandleConnected(port.externalPortId, true) + + return ( +
+ + + + + + {port.externalPortLabel} {isConnected && '✓'} + + + +
+
Internal Mapping:
+
Node: {port.internalNodeId.split('-')[0]}
+
Port: {port.internalPortId}
+
+
+
+
+ {selected && ( +
+ )} +
+ ) + })} + +
+
+
+ +
+
+
+ {groupDef.name} +
+
+ + {groupDef.category} + + + {groupDef.internalNodes.length} nodes + +
+
+
+ + {groupDef.description && ( +
+ {groupDef.description} +
+ )} + +
+ + {inputPorts.length} in + + + {outputPorts.length} out +
+
+ + {/* Render output handles */} + {outputPorts.map((port, index) => { + const spacing = 100 / (outputPorts.length + 1) + const topPercent = spacing * (index + 1) + const color = getPortColor(port.semantic) + const isConnected = isHandleConnected(port.externalPortId, false) + + return ( +
+ + + + + {port.externalPortLabel} {isConnected && '✓'} + + + +
+
Internal Mapping:
+
Node: {port.internalNodeId.split('-')[0]}
+
Port: {port.internalPortId}
+
+
+
+
+ + {selected && ( +
+ )} +
+ ) + })} + + ) +}) + +GroupBlockNode.displayName = 'GroupBlockNode' + +export default GroupBlockNode diff --git a/project/frontend/src/components/GroupCreationDialog.test.tsx b/project/frontend/src/components/GroupCreationDialog.test.tsx new file mode 100644 index 0000000..0a1720d --- /dev/null +++ b/project/frontend/src/components/GroupCreationDialog.test.tsx @@ -0,0 +1,370 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import GroupCreationDialog from './GroupCreationDialog' +import { useModelBuilderStore } from '@/lib/store' +import { Node, Edge } from '@xyflow/react' +import { BlockData } from '@/lib/types' + +// Mock the store +vi.mock('@/lib/store', () => ({ + useModelBuilderStore: vi.fn() +})) + +// Mock the node registry +vi.mock('@/lib/nodes/registry', () => ({ + getNodeDefinition: vi.fn((blockType) => { + if (blockType === 'conv2d') { + return { + getInputPorts: () => [{ id: 'input', label: 'Input', semantic: 'data' }], + getOutputPorts: () => [{ id: 'output', label: 'Output', semantic: 'data' }] + } + } + if (blockType === 'linear') { + return { + getInputPorts: () => [{ id: 'input', label: 'Input', semantic: 'data' }], + getOutputPorts: () => [{ id: 'output', label: 'Output', semantic: 'data' }] + } + } + return null + }), + BackendFramework: { + PyTorch: 'pytorch' + } +})) + +// Mock blockValidation +vi.mock('@/lib/blockValidation', () => ({ + validateConnectivity: vi.fn(() => ({ isValid: true, errors: [] })), + detectCycles: vi.fn(() => ({ isValid: true, errors: [] })), + validateBlockName: vi.fn((name: string) => { + if (!name) return { isValid: false, errors: ['Block name is required'] } + if (name.length > 50) return { isValid: false, errors: ['Block name must be 50 characters or less'] } + return { isValid: true, errors: [] } + }) +})) + +describe('GroupCreationDialog - Port Configuration', () => { + const mockNodes: Node[] = [ + { + id: 'node1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + blockType: 'conv2d', + label: 'Conv2D', + config: {}, + category: 'basic' + } + }, + { + id: 'node2', + type: 'custom', + position: { x: 100, y: 0 }, + data: { + blockType: 'linear', + label: 'Linear', + config: {}, + category: 'basic' + } + } + ] + + const mockEdges: Edge[] = [ + { + id: 'edge1', + source: 'node1', + target: 'node2', + sourceHandle: 'output', + targetHandle: 'input' + } + ] + + const mockOnSave = vi.fn() + const mockOnClose = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + ;(useModelBuilderStore as any).mockImplementation((selector: any) => { + const state = { + nodes: mockNodes, + edges: mockEdges, + currentProject: { framework: 'pytorch' }, + groupDefinitions: new Map() + } + return selector ? selector(state) : state + }) + }) + + it('should display comprehensive port selection UI in step 2', async () => { + render( + + ) + + // Fill in name and proceed to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + // Verify step 2 is displayed + await waitFor(() => { + expect(screen.getByText(/Input Ports/i)).toBeInTheDocument() + expect(screen.getByText(/Output Ports/i)).toBeInTheDocument() + }) + }) + + it('should display all available input and output ports from internal layers', async () => { + render( + + ) + + // Navigate to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + // Verify ports are displayed + await waitFor(() => { + // Should show Conv2D and Linear nodes + expect(screen.getByText('Conv2D')).toBeInTheDocument() + expect(screen.getByText('Linear')).toBeInTheDocument() + }) + }) + + it('should allow port selection and deselection', async () => { + render( + + ) + + // Navigate to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + await waitFor(() => { + expect(screen.getByText(/Input Ports/i)).toBeInTheDocument() + }) + + // Find checkboxes + const checkboxes = screen.getAllByRole('checkbox') + expect(checkboxes.length).toBeGreaterThan(0) + + // Toggle a checkbox + const firstCheckbox = checkboxes[0] + const initialChecked = firstCheckbox.getAttribute('data-state') === 'checked' + + fireEvent.click(firstCheckbox) + + // Verify state changed + await waitFor(() => { + const newState = firstCheckbox.getAttribute('data-state') + expect(newState).not.toBe(initialChecked ? 'checked' : 'unchecked') + }) + }) + + it('should provide custom label editing for selected ports', async () => { + render( + + ) + + // Navigate to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + await waitFor(() => { + expect(screen.getByText(/Input Ports/i)).toBeInTheDocument() + }) + + // Find a checkbox and select it + const checkboxes = screen.getAllByRole('checkbox') + const firstCheckbox = checkboxes[0] + + // If not checked, check it + if (firstCheckbox.getAttribute('data-state') !== 'checked') { + fireEvent.click(firstCheckbox) + } + + // Look for label input field + await waitFor(() => { + const labelInputs = screen.getAllByPlaceholderText(/External port label/i) + expect(labelInputs.length).toBeGreaterThan(0) + }) + }) + + it('should validate that at least one port is exposed before allowing creation', async () => { + // Mock edges with no external connections + ;(useModelBuilderStore as any).mockImplementation((selector: any) => { + const state = { + nodes: mockNodes, + edges: mockEdges, + currentProject: { framework: 'pytorch' }, + groupDefinitions: new Map() + } + return selector ? selector(state) : state + }) + + render( + + ) + + // Navigate to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + await waitFor(() => { + expect(screen.getByText(/Input Ports/i)).toBeInTheDocument() + }) + + // Deselect all ports + const checkboxes = screen.getAllByRole('checkbox') + for (const checkbox of checkboxes) { + if (checkbox.getAttribute('data-state') === 'checked') { + fireEvent.click(checkbox) + } + } + + // Try to create block + const createButton = screen.getByText(/Create Block/i) + fireEvent.click(createButton) + + // Should show validation error + await waitFor(() => { + expect(screen.getByText(/At least one port must be exposed/i)).toBeInTheDocument() + }) + + // onSave should not be called + expect(mockOnSave).not.toHaveBeenCalled() + }) + + it('should mark ports with external connections as "External"', async () => { + // Add external edge + const edgesWithExternal: Edge[] = [ + ...mockEdges, + { + id: 'external1', + source: 'external-node', + target: 'node1', + sourceHandle: 'output', + targetHandle: 'input' + } + ] + + ;(useModelBuilderStore as any).mockImplementation((selector: any) => { + const state = { + nodes: [...mockNodes, { + id: 'external-node', + type: 'custom', + position: { x: -100, y: 0 }, + data: { blockType: 'input', label: 'Input', config: {}, category: 'basic' } + }], + edges: edgesWithExternal, + currentProject: { framework: 'pytorch' }, + groupDefinitions: new Map() + } + return selector ? selector(state) : state + }) + + render( + + ) + + // Navigate to step 2 + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + // Verify "External" badge is shown + await waitFor(() => { + expect(screen.getByText('External')).toBeInTheDocument() + }) + }) + + it('should call onSave with correct port mappings configuration', async () => { + render( + + ) + + // Fill in name + const nameInput = screen.getByLabelText(/Block Name/i) + fireEvent.change(nameInput, { target: { value: 'TestBlock' } }) + + // Navigate to step 2 + const nextButton = screen.getByText(/Next: Select Ports/i) + fireEvent.click(nextButton) + + await waitFor(() => { + expect(screen.getByText(/Input Ports/i)).toBeInTheDocument() + }) + + // Ensure at least one port is selected + const checkboxes = screen.getAllByRole('checkbox') + if (checkboxes[0].getAttribute('data-state') !== 'checked') { + fireEvent.click(checkboxes[0]) + } + + // Create block + const createButton = screen.getByText(/Create Block/i) + fireEvent.click(createButton) + + // Verify onSave was called with correct structure + await waitFor(() => { + expect(mockOnSave).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'TestBlock', + description: '', + category: expect.any(String), + color: expect.any(String), + portMappings: expect.any(Array) + }) + ) + }) + }) +}) diff --git a/project/frontend/src/components/GroupCreationDialog.tsx b/project/frontend/src/components/GroupCreationDialog.tsx new file mode 100644 index 0000000..9434bd7 --- /dev/null +++ b/project/frontend/src/components/GroupCreationDialog.tsx @@ -0,0 +1,593 @@ +import { useState, useEffect } from 'react' +import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogDescription, DialogFooter } from '@/components/ui/dialog' +import { Button } from '@/components/ui/button' +import { Input } from '@/components/ui/input' +import { Label } from '@/components/ui/label' +import { Textarea } from '@/components/ui/textarea' +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select' +import { Checkbox } from '@/components/ui/checkbox' +import { ScrollArea } from '@/components/ui/scroll-area' +import { Badge } from '@/components/ui/badge' +import { Alert, AlertDescription } from '@/components/ui/alert' +import { BlockCategory, PortMapping } from '@/lib/types' +import { useModelBuilderStore } from '@/lib/store' +import { getNodeDefinition, BackendFramework } from '@/lib/nodes/registry' +import { validateConnectivity, detectCycles, validateBlockName } from '@/lib/blockValidation' +import { toast } from 'sonner' +import * as Icons from '@phosphor-icons/react' + +interface GroupCreationDialogProps { + isOpen: boolean + onClose: () => void + onSave: (config: { + name: string + description: string + category: BlockCategory + color: string + portMappings: PortMapping[] + }) => void + selectedNodeIds: string[] +} + +const COLOR_OPTIONS = [ + { value: '#9333ea', label: 'Purple', color: '#9333ea' }, + { value: '#ec4899', label: 'Pink', color: '#ec4899' }, + { value: '#f59e0b', label: 'Orange', color: '#f59e0b' }, + { value: '#10b981', label: 'Green', color: '#10b981' }, + { value: '#3b82f6', label: 'Blue', color: '#3b82f6' }, + { value: '#ef4444', label: 'Red', color: '#ef4444' }, + { value: '#8b5cf6', label: 'Violet', color: '#8b5cf6' }, + { value: '#06b6d4', label: 'Cyan', color: '#06b6d4' }, +] + +interface PortInfo { + nodeId: string + nodeName: string + portId: string + portLabel: string + type: 'input' | 'output' + semantic: string + isExternal: boolean +} + +export default function GroupCreationDialog({ + isOpen, + onClose, + onSave, + selectedNodeIds +}: GroupCreationDialogProps) { + const [step, setStep] = useState(1) + const [name, setName] = useState('') + const [description, setDescription] = useState('') + const [category, setCategory] = useState('utility') + const [color, setColor] = useState('#9333ea') + const [nameError, setNameError] = useState('') + const [validationErrors, setValidationErrors] = useState([]) + const [selectedPorts, setSelectedPorts] = useState>(new Set()) + const [portLabels, setPortLabels] = useState>(new Map()) + + const nodes = useModelBuilderStore((state) => state.nodes) + const edges = useModelBuilderStore((state) => state.edges) + const currentProject = useModelBuilderStore((state) => state.currentProject) + const groupDefinitions = useModelBuilderStore((state) => state.groupDefinitions) + + // Discover available ports from selected nodes + const availablePorts: PortInfo[] = [] + const selectedNodes = nodes.filter(n => selectedNodeIds.includes(n.id)) + + selectedNodes.forEach(node => { + const nodeDef = getNodeDefinition(node.data.blockType, currentProject?.framework as any || BackendFramework.PyTorch) + if (!nodeDef) return + + const inputPorts = nodeDef.getInputPorts ? nodeDef.getInputPorts(node.data.config) : [] + const outputPorts = nodeDef.getOutputPorts ? nodeDef.getOutputPorts(node.data.config) : [] + + // Check which ports have external connections + inputPorts.forEach(port => { + const hasExternalConnection = edges.some(e => + e.target === node.id && + (e.targetHandle || 'default') === port.id && + !selectedNodeIds.includes(e.source) + ) + availablePorts.push({ + nodeId: node.id, + nodeName: node.data.label || node.data.blockType, + portId: port.id, + portLabel: port.label, + type: 'input', + semantic: port.semantic, + isExternal: hasExternalConnection + }) + }) + + outputPorts.forEach(port => { + const hasExternalConnection = edges.some(e => + e.source === node.id && + (e.sourceHandle || 'default') === port.id && + !selectedNodeIds.includes(e.target) + ) + availablePorts.push({ + nodeId: node.id, + nodeName: node.data.label || node.data.blockType, + portId: port.id, + portLabel: port.label, + type: 'output', + semantic: port.semantic, + isExternal: hasExternalConnection + }) + }) + }) + + useEffect(() => { + if (isOpen) { + setStep(1) + setName('') + setDescription('') + setCategory('utility') + setColor('#9333ea') + setNameError('') + setValidationErrors([]) + setSelectedPorts(new Set()) + setPortLabels(new Map()) + + // Validate selection on open + const errors: string[] = [] + + // Check connectivity + const connectivityResult = validateConnectivity(selectedNodeIds, edges) + if (!connectivityResult.isValid) { + errors.push(...connectivityResult.errors) + } + + // Check for cycles + const cycleResult = detectCycles(selectedNodeIds, edges) + if (!cycleResult.isValid) { + errors.push(...cycleResult.errors) + } + + setValidationErrors(errors) + + // Auto-select external ports + const autoSelected = new Set() + const autoLabels = new Map() + availablePorts.forEach(port => { + if (port.isExternal) { + const portKey = `${port.nodeId}-${port.portId}-${port.type}` + autoSelected.add(portKey) + autoLabels.set(portKey, `${port.type === 'input' ? 'Input' : 'Output'} ${autoLabels.size + 1}`) + } + }) + setSelectedPorts(autoSelected) + setPortLabels(autoLabels) + } + }, [isOpen, selectedNodeIds, edges]) + + const validateName = (value: string) => { + // Get existing block names + const existingNames = Array.from(groupDefinitions.values()).map(def => def.name) + + const result = validateBlockName(value, existingNames) + + if (!result.isValid && result.errors.length > 0) { + setNameError(result.errors[0]) + return false + } + + setNameError('') + return true + } + + const togglePort = (portKey: string) => { + const newSelected = new Set(selectedPorts) + if (newSelected.has(portKey)) { + newSelected.delete(portKey) + const newLabels = new Map(portLabels) + newLabels.delete(portKey) + setPortLabels(newLabels) + } else { + newSelected.add(portKey) + // Auto-generate label if not exists + if (!portLabels.has(portKey)) { + const port = availablePorts.find(p => `${p.nodeId}-${p.portId}-${p.type}` === portKey) + if (port) { + const newLabels = new Map(portLabels) + const count = Array.from(selectedPorts).filter(k => k.endsWith(port.type)).length + 1 + newLabels.set(portKey, `${port.type === 'input' ? 'Input' : 'Output'} ${count}`) + setPortLabels(newLabels) + } + } + } + setSelectedPorts(newSelected) + } + + const updatePortLabel = (portKey: string, label: string) => { + const newLabels = new Map(portLabels) + newLabels.set(portKey, label) + setPortLabels(newLabels) + } + + const handleNext = () => { + // Check for structural validation errors first + if (validationErrors.length > 0) { + // Show toast with first error for better UX + toast.error('Cannot proceed', { + description: validationErrors[0] + }) + return + } + + if (!validateName(name)) { + toast.error('Invalid block name', { + description: nameError + }) + return + } + setStep(2) + } + + const handleBack = () => { + setStep(1) + } + + const handleSave = () => { + // Check for structural validation errors + if (validationErrors.length > 0) { + toast.error('Cannot create block', { + description: validationErrors[0] + }) + return + } + + if (!validateName(name)) { + toast.error('Invalid block name', { + description: nameError + }) + return + } + + // Validate port selection + if (selectedPorts.size === 0) { + setValidationErrors(['At least one port must be exposed']) + toast.error('No ports selected', { + description: 'At least one port must be exposed on the block' + }) + return + } + + // Build port mappings from selections + const portMappings: PortMapping[] = [] + let inputIndex = 0 + let outputIndex = 0 + + selectedPorts.forEach(portKey => { + const port = availablePorts.find(p => `${p.nodeId}-${p.portId}-${p.type}` === portKey) + if (!port) return + + const externalPortId = port.type === 'input' + ? `group-input-${inputIndex++}` + : `group-output-${outputIndex++}` + + portMappings.push({ + internalNodeId: port.nodeId, + internalPortId: port.portId, + externalPortId, + externalPortLabel: portLabels.get(portKey) || port.portLabel, + type: port.type, + semantic: port.semantic as any + }) + }) + + onSave({ + name: name.trim(), + description: description.trim(), + category, + color, + portMappings + }) + onClose() + } + + const inputPorts = availablePorts.filter(p => p.type === 'input') + const outputPorts = availablePorts.filter(p => p.type === 'output') + + return ( + !open && onClose()}> + + + + + Create Block from Selection + + Step {step} of 2 + + + + Group {selectedNodeIds.length} nodes into a reusable block + + + + {step === 1 && ( +
+ {/* Validation Errors */} + {validationErrors.length > 0 && ( + + + +
    + {validationErrors.map((error, index) => ( +
  • {error}
  • + ))} +
+
+
+ )} + + {/* Name Input */} +
+ + { + setName(e.target.value) + validateName(e.target.value) + }} + className={nameError ? 'border-red-500' : ''} + /> + {nameError && ( +

{nameError}

+ )} +
+ + {/* Description Input */} +
+ +