Few-Shot Lung Disease Classification via Prototypical Networks

Deep Learning, Medical Imaging, Computer Vision

Project Overview

This project implements a PyTorch-based Prototypical Network with a simple CNN encoder for few-shot multi-class classification of chest X-ray images. The model is designed to distinguish between five classes: Normal, Bacterial Pneumonia, Viral Pneumonia, COVID-19, and Tuberculosis.

The approach suits medical imaging scenarios where only a few labeled images are available per class. Instead of training a fixed classifier head, the network learns an embedding in which each class is represented by the mean of a handful of support examples. With this setup the project reports 81.6% overall classification accuracy.

PyTorch, Python, CNN, Prototypical Networks, Few-Shot Learning, Medical Imaging, Computer Vision
Project Type: Academic Research
Timeline: 2024
Role: ML Researcher, Deep Learning Engineer
Accuracy: 81.6%
Environment: Windows 10/11, Ubuntu 20.04+
Project Link

Key Features & Methodology

The Few-Shot Lung Disease Classification system is built from the following components:

  • Episodic Training: Implements episodic sampling with N_WAY classes, N_SUPPORT support images, and N_QUERY query images per class for effective few-shot learning.
  • Prototypical Networks: Computes one prototype per class as the mean of that class's support embeddings, then assigns each query to the nearest prototype under Euclidean distance.
  • Simple CNN Architecture: Uses 3 Conv2d layers with ReLU activation and MaxPool2d, followed by AdaptiveAvgPool2d, Flatten, and a Linear layer to produce the embedding.
  • Multi-Class Classification: Distinguishes between five classes: Normal, Bacterial Pneumonia, Viral Pneumonia, COVID-19, and Tuberculosis.
  • Medical Image Processing: Preprocessing and augmentation are tailored to chest X-ray images.
Overall Classification Accuracy: 81.6%
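
The prototype step described above can be illustrated with a minimal sketch. It is not the project's code: the function name `prototypical_logits` and the toy embeddings are made up for illustration, and the distance is squared here, which does not change which prototype is nearest.

```python
import torch


def prototypical_logits(support_emb, support_labels, query_emb, n_way):
    """Return (num_queries, n_way) logits: the negative squared Euclidean
    distance from each query embedding to each class prototype."""
    # Prototype = mean of the support embeddings that belong to each class.
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_way)]
    )
    # cdist returns pairwise Euclidean distances; squaring is an assumption
    # of this sketch and leaves the nearest prototype unchanged.
    return -(torch.cdist(query_emb, prototypes) ** 2)


# Toy episode (hypothetical values): 3 classes, 2 support embeddings each.
support = torch.tensor([
    [0.0, 0.0, 0.0, 0.0], [0.2, 0.0, 0.0, 0.0],   # class 0
    [5.0, 5.0, 0.0, 0.0], [5.0, 5.2, 0.0, 0.0],   # class 1
    [0.0, 0.0, 9.0, 9.0], [0.0, 0.2, 9.0, 9.0],   # class 2
])
support_labels = torch.tensor([0, 0, 1, 1, 2, 2])
queries = torch.tensor([
    [0.1, 0.1, 0.0, 0.0],
    [4.9, 5.1, 0.1, 0.0],
    [0.0, 0.0, 8.5, 9.1],
])

logits = prototypical_logits(support, support_labels, queries, n_way=3)
predictions = logits.argmax(dim=1)  # largest logit = nearest prototype
```

Returning negative distances as logits means the same tensor can be passed to a cross-entropy loss during training and to `argmax` at inference time.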

Dataset Structure

The project uses a chest X-ray dataset split into training and validation sets, with one folder per class:

Lung Disease Dataset/
  train/
    Normal/
    Bacterial Pneumonia/
    Viral Pneumonia/
    COVID/
    Tuberculosis/
  val/
    Normal/
    ...

Dataset Size: 6,000 training images and 2,000 validation images (customizable based on requirements)
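
As a rough sketch of how episodes might be drawn from a one-folder-per-class layout like this, the snippet below indexes a split and samples a single episode. The helper names `index_split` and `sample_episode`, the accepted file suffixes, and the episode sizes (5-way, 5 support, 10 query) are assumptions for illustration; the project's actual N_WAY, N_SUPPORT, and N_QUERY values are not stated here.

```python
import random
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}  # assumed file types


def index_split(split_dir):
    """Map each class folder name to a sorted list of its image paths."""
    return {
        class_dir.name: sorted(
            p for p in class_dir.iterdir()
            if p.suffix.lower() in IMAGE_SUFFIXES
        )
        for class_dir in sorted(Path(split_dir).iterdir())
        if class_dir.is_dir()
    }


def sample_episode(class_to_items, n_way, n_support, n_query, rng=random):
    """Draw one episode: n_way classes, each with disjoint support and
    query sets of n_support and n_query items."""
    classes = rng.sample(sorted(class_to_items), n_way)
    support, query = [], []
    for label, name in enumerate(classes):
        picked = rng.sample(class_to_items[name], n_support + n_query)
        support += [(item, label) for item in picked[:n_support]]
        query += [(item, label) for item in picked[n_support:]]
    return classes, support, query


# Demo on placeholder file names, so it runs without the dataset on disk.
fake_index = {
    name: [f"{name}/{i:04d}.png" for i in range(50)]
    for name in ["Normal", "Bacterial Pneumonia", "Viral Pneumonia",
                 "COVID", "Tuberculosis"]
}
classes, support, query = sample_episode(
    fake_index, n_way=5, n_support=5, n_query=10, rng=random.Random(0)
)
```

With the real data, `index_split("Lung Disease Dataset/train")` would take the place of `fake_index`. Labels are re-assigned from 0 to n_way - 1 inside each episode, which is what the prototype-based loss expects.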

Technical Implementation

The implementation pairs a small CNN encoder with episodic prototypical-network training. Configuration and requirements:

  • Model Architecture: Simple CNN Encoder with 3 convolutional layers, ReLU activation, and max pooling
  • Training Configuration: Adam optimizer with learning rate 1e-3, CrossEntropyLoss, 20 epochs default
  • Hardware Requirements: Intel Core i7-1185G7 or better, 32 GB RAM, CUDA support for GPU acceleration
  • Dependencies: PyTorch 1.9+, torchvision, numpy, pandas, matplotlib, seaborn, scikit-learn
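
The encoder and training configuration listed above could look roughly like the sketch below, which runs one optimization step on a random episode. Channel widths, the embedding size, single-channel 64x64 inputs, and the class name `SimpleCNNEncoder` are assumptions, not values taken from the project.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNNEncoder(nn.Module):
    """Three conv blocks, global average pooling, then a linear projection."""

    def __init__(self, in_channels=1, embedding_dim=64):
        super().__init__()
        widths = [in_channels, 32, 64, 128]  # assumed channel widths
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                       nn.ReLU(),
                       nn.MaxPool2d(2)]
        self.features = nn.Sequential(*blocks)
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                  nn.Flatten(),
                                  nn.Linear(widths[-1], embedding_dim))

    def forward(self, x):
        return self.head(self.features(x))


torch.manual_seed(0)
encoder = SimpleCNNEncoder()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# One random 5-way episode; images are ordered class by class, so the
# support embeddings can be reshaped to (n_way, n_support, dim).
n_way, n_support, n_query = 5, 2, 3
support_x = torch.randn(n_way * n_support, 1, 64, 64)
query_x = torch.randn(n_way * n_query, 1, 64, 64)
query_y = torch.arange(n_way).repeat_interleave(n_query)

prototypes = encoder(support_x).view(n_way, n_support, -1).mean(dim=1)
logits = -torch.cdist(encoder(query_x), prototypes)  # (15, n_way)
loss = F.cross_entropy(logits, query_y)  # cross-entropy over -distances

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

A full run would repeat this step over many sampled episodes per epoch and keep the checkpoint with the best validation result.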

Usage Instructions

Follow these steps to set up and run the project:

  1. Clone the Repository:
git clone https://github.com/shiga2006/Multi-class-lung-disease-classification
  2. Install Dependencies:
pip install -r requirements.txt
  3. Prepare Dataset: Organize chest X-ray images according to the folder structure shown above
  4. Run Training: Edit DATA_DIR in the main script if needed, then execute:
python main.py

Results & Evaluation

The model reaches 81.6% overall classification accuracy across the five classes. Alongside that figure, the project provides:

  • Best model checkpoint saved as: best_protonet_xray_long[5class].pth
  • Comprehensive evaluation includes training/validation loss curves and confusion matrices
  • Automatic saving of results and visualizations in the working directory
  • Academic and research-focused implementation with detailed documentation
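
For readers unfamiliar with how a confusion matrix relates to overall accuracy, here is a small dependency-free sketch. The labels are made up for illustration and are not the project's results; the project itself lists scikit-learn among its dependencies, which offers an equivalent `confusion_matrix` function.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """matrix[i][j] counts samples of true class i predicted as class j."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def accuracy(matrix):
    """Overall accuracy = diagonal (correct) count / total count."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


# Made-up labels for illustration only; NOT the project's results.
y_true = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
y_pred = [0, 0, 1, 2, 2, 2, 3, 3, 4, 1]
cm = confusion_matrix(y_true, y_pred, n_classes=5)
overall = accuracy(cm)
```

Off-diagonal entries show which pairs of classes are confused with each other, which a single accuracy figure does not reveal.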
