Building Synthetic Medical Records using GANs¶

In [11]:
import pandas as pd
pd.set_option('display.max_rows', 100); pd.set_option('display.max_columns', 100)
In [5]:
# dataset taken from Synthea, a synthetic patient generator. The dataset contains 1000 patients and their medical records, 
# including demographics, conditions, medications, and procedures.

patients = pd.read_csv("synthea_sample_data_csv_apr2020\csv\patients.csv")
In [ ]:
patients.head()
Out[ ]:
Id BIRTHDATE DEATHDATE SSN DRIVERS PASSPORT PREFIX FIRST LAST SUFFIX MAIDEN MARITAL RACE ETHNICITY GENDER BIRTHPLACE ADDRESS CITY STATE COUNTY ZIP LAT LON HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE
0 1d604da9-9a81-4ba9-80c2-de3375d59b40 1989-05-25 NaN 999-76-6866 S99984236 X19277260X Mr. José Eduardo181 Gómez206 NaN NaN M white hispanic M Marigot Saint Andrew Parish DM 427 Balistreri Way Unit 19 Chicopee Massachusetts Hampden County 1013.0 42.228354 -72.562951 271227.08 1334.88
1 034e9e3b-2def-4559-bb2a-7850888ae060 1983-11-14 NaN 999-73-5361 S99962402 X88275464X Mr. Milo271 Feil794 NaN NaN M white nonhispanic M Danvers Massachusetts US 422 Farrell Path Unit 69 Somerville Massachusetts Middlesex County 2143.0 42.360697 -71.126531 793946.01 3204.49
2 10339b10-3cd1-4ac3-ac13-ec26728cb592 1992-06-02 NaN 999-27-3385 S99972682 X73754411X Mr. Jayson808 Fadel536 NaN NaN M white nonhispanic M Springfield Massachusetts US 1056 Harris Lane Suite 70 Chicopee Massachusetts Hampden County 1020.0 42.181642 -72.608842 574111.90 2606.40
3 8d4c4326-e9de-4f45-9a4c-f8c36bff89ae 1978-05-27 NaN 999-85-4926 S99974448 X40915583X Mrs. Mariana775 Rutherford999 NaN Williamson769 M white nonhispanic F Yarmouth Massachusetts US 999 Kuhn Forge Lowell Massachusetts Middlesex County 1851.0 42.636143 -71.343255 935630.30 8756.19
4 f5dcd418-09fe-4a2f-baa0-3da800bd8c3a 1996-10-18 NaN 999-60-7372 S99915787 X86772962X Mr. Gregorio366 Auer97 NaN NaN NaN white nonhispanic M Patras Achaea GR 1050 Lindgren Extension Apt 38 Boston Massachusetts Suffolk County 2135.0 42.352434 -71.028610 598763.07 3772.20
In [17]:
len(patients)
Out[17]:
1171
In [6]:
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import numpy as np

# Remove ID/High-Cardinality columns (Crucial for GANs to work on generic data)
# We drop columns where almost every row is unique (like IDs, Names, SSNs)
limit = 0.95 * len(patients)
patients = patients.dropna(axis=1, how='all') # Drop empty cols
for col in patients.select_dtypes(include=['object']):
    if patients[col].nunique() > limit:
        print(f"Dropping high-cardinality column: {col}")
        patients = patients.drop(columns=[col])

num_cols = patients.select_dtypes(include=['int64', 'float64']).columns
cat_cols = patients.select_dtypes(include=['object']).columns

# Impute NaNs to avoid the BCELoss error
patients[num_cols] = patients[num_cols].fillna(patients[num_cols].mean())
for col in cat_cols:
    patients[col] = patients[col].fillna(patients[col].mode()[0]) 

encoder = OneHotEncoder(sparse_output=False)
cat_encoded = encoder.fit_transform(patients[cat_cols])

scaler = MinMaxScaler(feature_range=(-1, 1))
num_scaled = scaler.fit_transform(patients[num_cols])
# combine processed data
data_processed = np.hstack((num_scaled, cat_encoded))
Dropping high-cardinality column: Id
Dropping high-cardinality column: SSN
Dropping high-cardinality column: ADDRESS
In [19]:
data_processed
Out[19]:
array([[-0.98709677,  0.14299744, -0.5239573 , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.22795699,  0.31291511,  0.31085808, ...,  0.        ,
         0.        ,  0.        ],
       [-0.97956989,  0.08302387, -0.55062853, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [        nan,  0.12118423,  0.2602432 , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.14086022,  0.06441201,  0.24755032, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.17096774,  0.15019466,  0.24187614, ...,  0.        ,
         0.        ,  0.        ]], shape=(1171, 8831))
In [6]:
pip install torch
Collecting torch
  Downloading torch-2.10.0-cp313-cp313-win_amd64.whl.metadata (31 kB)
Requirement already satisfied: filelock in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.20.0)
Requirement already satisfied: typing-extensions>=4.10.0 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (4.15.0)
Requirement already satisfied: sympy>=1.13.3 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.5)
Requirement already satisfied: jinja2 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.1.6)
Requirement already satisfied: fsspec>=0.8.5 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (2025.10.0)
Requirement already satisfied: setuptools in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (80.9.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\users\vyasv\anaconda3\lib\site-packages (from sympy>=1.13.3->torch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in c:\users\vyasv\anaconda3\lib\site-packages (from jinja2->torch) (3.0.2)
Downloading torch-2.10.0-cp313-cp313-win_amd64.whl (113.8 MB)
   ---------------------------------------- 0.0/113.8 MB ? eta -:--:--
   - -------------------------------------- 3.1/113.8 MB 19.9 MB/s eta 0:00:06
   -- ------------------------------------- 7.9/113.8 MB 20.1 MB/s eta 0:00:06
   ---- ----------------------------------- 13.6/113.8 MB 22.3 MB/s eta 0:00:05
   ------ --------------------------------- 18.6/113.8 MB 22.5 MB/s eta 0:00:05
   ------- -------------------------------- 21.2/113.8 MB 20.3 MB/s eta 0:00:05
   -------- ------------------------------- 24.6/113.8 MB 19.7 MB/s eta 0:00:05
   ---------- ----------------------------- 29.4/113.8 MB 20.1 MB/s eta 0:00:05
   ----------- ---------------------------- 34.1/113.8 MB 20.3 MB/s eta 0:00:04
   ------------- -------------------------- 38.8/113.8 MB 20.5 MB/s eta 0:00:04
   --------------- ------------------------ 43.5/113.8 MB 20.6 MB/s eta 0:00:04
   ----------------- ---------------------- 48.5/113.8 MB 20.9 MB/s eta 0:00:04
   ------------------- -------------------- 54.3/113.8 MB 21.4 MB/s eta 0:00:03
   -------------------- ------------------- 59.5/113.8 MB 21.7 MB/s eta 0:00:03
   ----------------------- ---------------- 65.5/113.8 MB 22.2 MB/s eta 0:00:03
   ------------------------ --------------- 70.8/113.8 MB 22.2 MB/s eta 0:00:02
   --------------------------- ------------ 76.8/113.8 MB 22.7 MB/s eta 0:00:02
   ----------------------------- ---------- 82.8/113.8 MB 23.1 MB/s eta 0:00:02
   ------------------------------- -------- 89.1/113.8 MB 23.5 MB/s eta 0:00:02
   --------------------------------- ------ 96.2/113.8 MB 24.0 MB/s eta 0:00:01
   ----------------------------------- --- 103.0/113.8 MB 24.4 MB/s eta 0:00:01
   ------------------------------------- - 110.1/113.8 MB 24.8 MB/s eta 0:00:01
   --------------------------------------  113.5/113.8 MB 24.9 MB/s eta 0:00:01
   --------------------------------------  113.5/113.8 MB 24.9 MB/s eta 0:00:01
   ---------------------------------------- 113.8/113.8 MB 23.3 MB/s  0:00:04
Installing collected packages: torch
Successfully installed torch-2.10.0
Note: you may need to restart the kernel to use updated packages.
In [7]:
## Here’s how to build the architecture of GANs:\
import torch
import torch.nn as nn

data_dim = data_processed.shape[1]  # total features
latent_dim = 64  # size of random noise input

# generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, data_dim),
            nn.Tanh()  # output in range [-1, 1]
        )
    def forward(self, z):
        return self.model(z)

# discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(data_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()  # probability of real/fake
        )
    def forward(self, x):
        return self.model(x)
In [8]:
from torch.utils.data import DataLoader, TensorDataset

# convert data to PyTorch tensors
real_data = torch.tensor(data_processed, dtype=torch.float32)
dataset = TensorDataset(real_data)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# initialize models
generator = Generator()
discriminator = Discriminator()

# optimizers
lr = 0.0002
optim_G = torch.optim.Adam(generator.parameters(), lr=lr)
optim_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

# loss
criterion = nn.BCELoss()

epochs = 2000
for epoch in range(epochs):
    for real_batch, in loader:
        batch_size = real_batch.size(0)

        # labels for real and fake data
        real_labels = torch.ones((batch_size, 1))
        fake_labels = torch.zeros((batch_size, 1))

        # train discriminator
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z)

        real_loss = criterion(discriminator(real_batch), real_labels)
        fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optim_D.zero_grad()
        d_loss.backward()
        optim_D.step()

        # train generator
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z)
        g_loss = criterion(discriminator(fake_data), real_labels)  # want fake to be real

        optim_G.zero_grad()
        g_loss.backward()
        optim_G.step()

    if epoch % 200 == 0:
        print(f"Epoch [{epoch}/{epochs}]  D_loss: {d_loss.item():.4f}  G_loss: {g_loss.item():.4f}")
Epoch [0/2000]  D_loss: 0.8390  G_loss: 0.5230
Epoch [200/2000]  D_loss: 0.5556  G_loss: 1.5419
Epoch [400/2000]  D_loss: 0.2877  G_loss: 2.8194
Epoch [600/2000]  D_loss: 0.0972  G_loss: 5.7064
Epoch [800/2000]  D_loss: 0.0681  G_loss: 4.5970
Epoch [1000/2000]  D_loss: 0.0828  G_loss: 5.7511
Epoch [1200/2000]  D_loss: 0.0042  G_loss: 5.8584
Epoch [1400/2000]  D_loss: 0.0000  G_loss: 80.8665
Epoch [1600/2000]  D_loss: 0.0000  G_loss: 100.0000
Epoch [1800/2000]  D_loss: 0.0000  G_loss: 100.0000
In [9]:
# generate new synthetic data
z = torch.randn(1000, latent_dim)  # 1000 synthetic samples
synthetic_data_scaled = generator(z).detach().numpy()

# inverse transform
num_synthetic = scaler.inverse_transform(synthetic_data_scaled[:, :len(num_cols)])
cat_synthetic = encoder.inverse_transform(synthetic_data_scaled[:, len(num_cols):])

# combine into dataframe
synthetic_df = pd.DataFrame(num_synthetic, columns=num_cols)
synthetic_df[cat_cols] = cat_synthetic

print(synthetic_df)
             ZIP        LAT        LON  HEALTHCARE_EXPENSES  \
0    2161.355957  42.356140 -71.173187         539843.12500   
1    2186.058350  42.299240 -71.162605         880593.93750   
2    1817.902222  42.277500 -71.543427         364262.21875   
3    1853.842651  42.238213 -71.468315         511571.37500   
4    2140.446777  42.312260 -71.257774         700188.75000   
..           ...        ...        ...                  ...   
995  1830.314331  42.248318 -71.447144         554655.75000   
996  2145.418945  42.367058 -71.091789         730281.62500   
997  2163.928955  42.347404 -71.090889         822593.87500   
998  2201.406982  42.337494 -71.128876         772574.50000   
999  2034.734375  42.301086 -71.240829         634362.06250   

     HEALTHCARE_COVERAGE   BIRTHDATE   DEATHDATE    DRIVERS    PASSPORT  \
0            2038.231567  1991-10-25  1926-03-05  S99934122  X10056752X   
1           25678.156250  1974-05-30  1926-03-05  S99996704    X622036X   
2              23.090076  2017-12-23  1926-03-05  S99927008  X10056752X   
3             123.995094  2017-12-23  1926-03-05  S99927475  X55831878X   
4            5956.271973  1983-09-02  1926-03-05  S99934122  X10056752X   
..                   ...         ...         ...        ...         ...   
995            87.852898  2017-12-23  1926-03-05  S99955008  X10056752X   
996         10727.262695  1923-05-15  1926-03-05  S99934122  X10056752X   
997         21109.224609  1990-07-18  1926-03-05  S99934122  X10056752X   
998         14073.028320  1979-03-26  1926-03-05  S99934122  X10056752X   
999          1211.357910  2019-03-05  1926-03-05  S99934122  X10056752X   

    PREFIX       FIRST          LAST SUFFIX     MAIDEN MARITAL   RACE  \
0      Mr.  Patrick786     Wehner319     JD  Keeling57       M  white   
1      Mr.  Felicia295     Abbott774     JD  Keeling57       M  white   
2      Mr.     Eura647      Grant908     JD  Keeling57       M  white   
3      Mr.   Ángela136      Grant908     JD  Keeling57       M  white   
4      Mr.      Oda116      Grant908     JD  Keeling57       M  white   
..     ...         ...           ...    ...        ...     ...    ...   
995    Mr.   Monroe732      Grant908     JD  Keeling57       M  white   
996    Mr.   Monroe732      Grant908     JD  Keeling57       M  white   
997    Mr.    Iraida50     Zulauf375     JD  Keeling57       M  white   
998    Mr.  Patrick786   Buckridge80     JD  Keeling57       M  white   
999    Mr.    Sixta311  Bergstrom287     JD  Keeling57       M  white   

       ETHNICITY GENDER                     BIRTHPLACE         CITY  \
0    nonhispanic      F    Hamilton  Massachusetts  US       Boston   
1    nonhispanic      F      Boston  Massachusetts  US       Boston   
2    nonhispanic      F  Somerville  Massachusetts  US  Southampton   
3    nonhispanic      F      Boston  Massachusetts  US      Melrose   
4    nonhispanic      F      Boston  Massachusetts  US       Boston   
..           ...    ...                            ...          ...   
995  nonhispanic      F  Somerville  Massachusetts  US      Melrose   
996  nonhispanic      F      Boston  Massachusetts  US       Boston   
997  nonhispanic      F      Boston  Massachusetts  US       Boston   
998  nonhispanic      F      Boston  Massachusetts  US       Boston   
999  nonhispanic      F      Boston  Massachusetts  US   Fall River   

             STATE            COUNTY  
0    Massachusetts  Middlesex County  
1    Massachusetts  Middlesex County  
2    Massachusetts      Essex County  
3    Massachusetts    Suffolk County  
4    Massachusetts  Middlesex County  
..             ...               ...  
995  Massachusetts    Suffolk County  
996  Massachusetts  Middlesex County  
997  Massachusetts  Middlesex County  
998  Massachusetts  Middlesex County  
999  Massachusetts    Suffolk County  

[1000 rows x 22 columns]

This is the generated dataset! Quite cool.¶

In [10]:
synthetic_df
Out[10]:
ZIP LAT LON HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE BIRTHDATE DEATHDATE DRIVERS PASSPORT PREFIX FIRST LAST SUFFIX MAIDEN MARITAL RACE ETHNICITY GENDER BIRTHPLACE CITY STATE COUNTY
0 2161.355957 42.356140 -71.173187 539843.12500 2038.231567 1991-10-25 1926-03-05 S99934122 X10056752X Mr. Patrick786 Wehner319 JD Keeling57 M white nonhispanic F Hamilton Massachusetts US Boston Massachusetts Middlesex County
1 2186.058350 42.299240 -71.162605 880593.93750 25678.156250 1974-05-30 1926-03-05 S99996704 X622036X Mr. Felicia295 Abbott774 JD Keeling57 M white nonhispanic F Boston Massachusetts US Boston Massachusetts Middlesex County
2 1817.902222 42.277500 -71.543427 364262.21875 23.090076 2017-12-23 1926-03-05 S99927008 X10056752X Mr. Eura647 Grant908 JD Keeling57 M white nonhispanic F Somerville Massachusetts US Southampton Massachusetts Essex County
3 1853.842651 42.238213 -71.468315 511571.37500 123.995094 2017-12-23 1926-03-05 S99927475 X55831878X Mr. Ángela136 Grant908 JD Keeling57 M white nonhispanic F Boston Massachusetts US Melrose Massachusetts Suffolk County
4 2140.446777 42.312260 -71.257774 700188.75000 5956.271973 1983-09-02 1926-03-05 S99934122 X10056752X Mr. Oda116 Grant908 JD Keeling57 M white nonhispanic F Boston Massachusetts US Boston Massachusetts Middlesex County
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 1830.314331 42.248318 -71.447144 554655.75000 87.852898 2017-12-23 1926-03-05 S99955008 X10056752X Mr. Monroe732 Grant908 JD Keeling57 M white nonhispanic F Somerville Massachusetts US Melrose Massachusetts Suffolk County
996 2145.418945 42.367058 -71.091789 730281.62500 10727.262695 1923-05-15 1926-03-05 S99934122 X10056752X Mr. Monroe732 Grant908 JD Keeling57 M white nonhispanic F Boston Massachusetts US Boston Massachusetts Middlesex County
997 2163.928955 42.347404 -71.090889 822593.87500 21109.224609 1990-07-18 1926-03-05 S99934122 X10056752X Mr. Iraida50 Zulauf375 JD Keeling57 M white nonhispanic F Boston Massachusetts US Boston Massachusetts Middlesex County
998 2201.406982 42.337494 -71.128876 772574.50000 14073.028320 1979-03-26 1926-03-05 S99934122 X10056752X Mr. Patrick786 Buckridge80 JD Keeling57 M white nonhispanic F Boston Massachusetts US Boston Massachusetts Middlesex County
999 2034.734375 42.301086 -71.240829 634362.06250 1211.357910 2019-03-05 1926-03-05 S99934122 X10056752X Mr. Sixta311 Bergstrom287 JD Keeling57 M white nonhispanic F Boston Massachusetts US Fall River Massachusetts Suffolk County

1000 rows × 22 columns

In [12]:
synthetic_df['PASSPORT'].value_counts()
Out[12]:
PASSPORT
X10056752X    623
X42955785X     75
X86615364X     34
X58347333X     32
X622036X       23
X35334533X     23
X80908865X     17
X66208297X     15
X42581289X     13
X20177258X     13
X70628984X     12
X62794246X     11
X22725008X     10
X57951475X     10
X28593812X     10
X41095772X     10
X55831878X      9
X69522263X      8
X37764385X      7
X38846992X      7
X41431358X      6
X85804833X      6
X87625600X      4
X65838399X      4
X14093941X      3
X27959569X      3
X78670564X      2
X66829529X      2
X35840246X      1
X40355437X      1
X3274633X       1
X6198901X       1
X27167656X      1
X82842351X      1
X36011830X      1
X70211797X      1
Name: count, dtype: int64
In [ ]:
# saving this notebook as an html file
In [2]:
import os

# Save notebook as HTML using nbconvert
notebook_filename = os.path.basename(__file__) if '__file__' in globals() else 'gan_medical_records.ipynb'
!jupyter nbconvert --to html --output "synthetic_medical_records.html" "{notebook_filename}"
[NbConvertApp] Converting notebook gan_medical_records.ipynb to html
[NbConvertApp] Writing 335095 bytes to synthetic_medical_records.html
In [ ]: