Project Overview
This project implements a PyTorch-based Prototypical Network with a simple CNN encoder for few-shot multi-class classification of chest X-ray images. The model is designed to distinguish between five classes: Normal, Bacterial Pneumonia, Viral Pneumonia, COVID-19, and Tuberculosis.
The approach suits medical imaging scenarios where only a few labeled images are available per class. Instead of fitting a fixed classifier on a large labeled set, the model learns an embedding in which each class is represented by the mean of a handful of support examples. The reported overall classification accuracy is 81.6%.
Tags: PyTorch, Python, CNN, Prototypical Networks, Few-Shot Learning, Medical Imaging, Computer Vision
Key Features & Methodology
The main components of the Few-Shot Lung Disease Classification system are:
- Episodic Training: Implements episodic sampling with N_WAY classes, N_SUPPORT support images, and N_QUERY query images per class for effective few-shot learning.
- Prototypical Networks: Computes each class prototype as the mean of that class's support embeddings and assigns every query to the nearest prototype by Euclidean distance.
- Simple CNN Architecture: Uses 3 Conv2d layers with ReLU activation and MaxPool, followed by AdaptiveAvgPool2d, Flatten, and a Linear layer for feature extraction.
- Multi-Class Classification: Distinguishes between five classes: Normal, Bacterial Pneumonia, Viral Pneumonia, COVID-19, and Tuberculosis.
- Medical Image Processing: Specialized for chest X-ray analysis with optimized preprocessing and augmentation techniques.
Overall Classification Accuracy: 81.6%
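The prototype step described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the project's actual code: the function name prototypical_loss and its argument layout are hypothetical, and it uses squared Euclidean distance (as in the original Prototypical Networks formulation), which gives the same nearest prototype as plain Euclidean distance.

```python
import torch
import torch.nn.functional as F


def prototypical_loss(support_emb, support_labels, query_emb, query_labels, n_way):
    """Return (loss, accuracy) for one episode (hypothetical helper).

    support_emb: (n_way * n_support, dim) embeddings of the support images
    query_emb:   (n_way * n_query, dim) embeddings of the query images
    Labels are integer episode labels in [0, n_way).
    """
    # Class prototype = mean of that class's support embeddings.
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_way)]
    )  # shape: (n_way, dim)

    # Squared Euclidean distance from every query to every prototype.
    dists = torch.cdist(query_emb, prototypes) ** 2  # (num_queries, n_way)

    # Negative distance is used as the logit: closer prototype, higher score.
    logits = -dists
    loss = F.cross_entropy(logits, query_labels)
    acc = (logits.argmax(dim=1) == query_labels).float().mean().item()
    return loss, acc
```

Feeding negative distances into cross-entropy is one common way to combine a nearest-prototype rule with CrossEntropyLoss; whether the project scales or normalizes the distances is not stated in this document.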
Dataset Structure
The project uses a comprehensive chest X-ray dataset organized into training and validation sets:
Lung Disease Dataset/
    train/
        Normal/
        Bacterial Pneumonia/
        Viral Pneumonia/
        COVID/
        Tuberculosis/
    val/
        Normal/
        ...
Dataset Size: 6,000 training images and 2,000 validation images (customizable based on requirements)
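With one folder per class, an episode can be drawn by picking N_WAY class folders and then N_SUPPORT + N_QUERY distinct images from each. The sketch below shows one way to do this using only the standard library. The function name sample_episode, the paths_by_class dictionary, and the file names in the demo are hypothetical, and the demo values (5-way, 5 support, 10 query) are illustrative because the project's actual settings are not given here.

```python
import random


def sample_episode(paths_by_class, n_way, n_support, n_query, rng=random):
    """Pick n_way classes, then n_support + n_query distinct images per class.

    paths_by_class: dict mapping class name -> list of image paths.
    Returns (support, query), each a list of (path, episode_label) pairs,
    where episode_label is in [0, n_way).
    """
    classes = rng.sample(sorted(paths_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        picked = rng.sample(paths_by_class[cls], n_support + n_query)
        support += [(path, label) for path in picked[:n_support]]
        query += [(path, label) for path in picked[n_support:]]
    return support, query


# Demo with made-up file names; real paths would come from the train/ folder.
class_names = ["Normal", "Bacterial Pneumonia", "Viral Pneumonia", "COVID", "Tuberculosis"]
demo_paths = {c: [f"train/{c}/img_{i}.png" for i in range(30)] for c in class_names}
support_set, query_set = sample_episode(demo_paths, n_way=5, n_support=5, n_query=10)
```

Labels are reassigned per episode (0 to n_way - 1), so the model cannot rely on a fixed class index and has to classify from the support examples alone.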
Technical Implementation
The implementation pairs a prototypical-network objective with a small CNN encoder:
- Model Architecture: Simple CNN Encoder with 3 convolutional layers, ReLU activation, and max pooling
- Training Configuration: Adam optimizer (learning rate 1e-3), CrossEntropyLoss, 20 epochs by default
- Hardware Requirements: Intel Core i7-1185G7 or better, 32 GB RAM, CUDA support for GPU acceleration
- Dependencies: PyTorch 1.9+, torchvision, numpy, pandas, matplotlib, seaborn, scikit-learn
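For orientation, an encoder matching the layer list above might look like the following. This is a sketch under stated assumptions: the class name SimpleCNNEncoder, the channel widths (32, 64, 128), the single-channel grayscale input, and the 64-dimensional embedding are all guesses, since only the layer types, the optimizer, and the learning rate are documented.

```python
import torch
from torch import nn


class SimpleCNNEncoder(nn.Module):
    """3 x (Conv2d -> ReLU -> MaxPool2d), then AdaptiveAvgPool2d, Flatten, Linear."""

    def __init__(self, in_channels=1, embed_dim=64):  # both values are assumptions
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),  # collapses any spatial size to 1 x 1
            nn.Flatten(),             # (batch, 128, 1, 1) -> (batch, 128)
        )
        self.fc = nn.Linear(128, embed_dim)

    def forward(self, x):
        return self.fc(self.features(x))


encoder = SimpleCNNEncoder()
# Optimizer settings taken from the training configuration listed above.
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
```

Because of the adaptive pooling layer, the encoder accepts X-rays of varying resolution and always returns a fixed-length embedding, which is what the prototype computation needs.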
Usage Instructions
Follow these steps to set up and run the project:
- Clone the Repository:
git clone https://github.com/shiga2006/Multi-class-lung-disease-classification
- Install Dependencies:
pip install -r requirements.txt
- Prepare Dataset: Organize chest X-ray images according to the folder structure shown above
- Run Training: Edit DATA_DIR in the main script if needed, then execute:
python main.py
Results & Evaluation
The reported overall classification accuracy is 81.6%. Notes on outputs and evaluation:
- Best model checkpoint saved as:
best_protonet_xray_long[5class].pth
- Comprehensive evaluation includes training/validation loss curves and confusion matrices
- Automatic saving of results and visualizations in the working directory
- Academic and research-focused implementation with detailed documentation
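To make the confusion-matrix evaluation concrete, the sketch below builds one from integer labels and derives overall accuracy from its diagonal. The helper names are hypothetical and the code uses plain Python for clarity; scikit-learn, which is among the listed dependencies, provides an equivalent confusion_matrix function in sklearn.metrics.

```python
def build_confusion_matrix(y_true, y_pred, n_classes):
    """Rows index the true class, columns the predicted class."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[true_label][pred_label] += 1
    return matrix


def overall_accuracy(matrix):
    """Fraction of samples on the diagonal (correct predictions)."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total if total else 0.0
```

Off-diagonal cells show which classes are confused with each other, for example Bacterial versus Viral Pneumonia, which a single accuracy figure hides.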