Building Synthetic Medical Records using GANs¶
In [11]:
import pandas as pd
pd.set_option('display.max_rows', 100); pd.set_option('display.max_columns', 100)
In [5]:
# dataset taken from Synthea, a synthetic patient generator. The dataset contains 1000 patients and their medical records,
# including demographics, conditions, medications, and procedures.
patients = pd.read_csv("synthea_sample_data_csv_apr2020\csv\patients.csv")
In [ ]:
patients.head()
Out[ ]:
| Id | BIRTHDATE | DEATHDATE | SSN | DRIVERS | PASSPORT | PREFIX | FIRST | LAST | SUFFIX | MAIDEN | MARITAL | RACE | ETHNICITY | GENDER | BIRTHPLACE | ADDRESS | CITY | STATE | COUNTY | ZIP | LAT | LON | HEALTHCARE_EXPENSES | HEALTHCARE_COVERAGE | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1d604da9-9a81-4ba9-80c2-de3375d59b40 | 1989-05-25 | NaN | 999-76-6866 | S99984236 | X19277260X | Mr. | José Eduardo181 | Gómez206 | NaN | NaN | M | white | hispanic | M | Marigot Saint Andrew Parish DM | 427 Balistreri Way Unit 19 | Chicopee | Massachusetts | Hampden County | 1013.0 | 42.228354 | -72.562951 | 271227.08 | 1334.88 |
| 1 | 034e9e3b-2def-4559-bb2a-7850888ae060 | 1983-11-14 | NaN | 999-73-5361 | S99962402 | X88275464X | Mr. | Milo271 | Feil794 | NaN | NaN | M | white | nonhispanic | M | Danvers Massachusetts US | 422 Farrell Path Unit 69 | Somerville | Massachusetts | Middlesex County | 2143.0 | 42.360697 | -71.126531 | 793946.01 | 3204.49 |
| 2 | 10339b10-3cd1-4ac3-ac13-ec26728cb592 | 1992-06-02 | NaN | 999-27-3385 | S99972682 | X73754411X | Mr. | Jayson808 | Fadel536 | NaN | NaN | M | white | nonhispanic | M | Springfield Massachusetts US | 1056 Harris Lane Suite 70 | Chicopee | Massachusetts | Hampden County | 1020.0 | 42.181642 | -72.608842 | 574111.90 | 2606.40 |
| 3 | 8d4c4326-e9de-4f45-9a4c-f8c36bff89ae | 1978-05-27 | NaN | 999-85-4926 | S99974448 | X40915583X | Mrs. | Mariana775 | Rutherford999 | NaN | Williamson769 | M | white | nonhispanic | F | Yarmouth Massachusetts US | 999 Kuhn Forge | Lowell | Massachusetts | Middlesex County | 1851.0 | 42.636143 | -71.343255 | 935630.30 | 8756.19 |
| 4 | f5dcd418-09fe-4a2f-baa0-3da800bd8c3a | 1996-10-18 | NaN | 999-60-7372 | S99915787 | X86772962X | Mr. | Gregorio366 | Auer97 | NaN | NaN | NaN | white | nonhispanic | M | Patras Achaea GR | 1050 Lindgren Extension Apt 38 | Boston | Massachusetts | Suffolk County | 2135.0 | 42.352434 | -71.028610 | 598763.07 | 3772.20 |
In [17]:
len(patients)
Out[17]:
1171
In [6]:
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import numpy as np
# Remove ID/High-Cardinality columns (Crucial for GANs to work on generic data)
# We drop columns where almost every row is unique (like IDs, Names, SSNs)
limit = 0.95 * len(patients)
patients = patients.dropna(axis=1, how='all') # Drop empty cols
for col in patients.select_dtypes(include=['object']):
if patients[col].nunique() > limit:
print(f"Dropping high-cardinality column: {col}")
patients = patients.drop(columns=[col])
num_cols = patients.select_dtypes(include=['int64', 'float64']).columns
cat_cols = patients.select_dtypes(include=['object']).columns
# Impute NaNs to avoid the BCELoss error
patients[num_cols] = patients[num_cols].fillna(patients[num_cols].mean())
for col in cat_cols:
patients[col] = patients[col].fillna(patients[col].mode()[0])
encoder = OneHotEncoder(sparse_output=False)
cat_encoded = encoder.fit_transform(patients[cat_cols])
scaler = MinMaxScaler(feature_range=(-1, 1))
num_scaled = scaler.fit_transform(patients[num_cols])
# combine processed data
data_processed = np.hstack((num_scaled, cat_encoded))
Dropping high-cardinality column: Id Dropping high-cardinality column: SSN Dropping high-cardinality column: ADDRESS
In [19]:
data_processed
Out[19]:
array([[-0.98709677, 0.14299744, -0.5239573 , ..., 0. ,
0. , 0. ],
[ 0.22795699, 0.31291511, 0.31085808, ..., 0. ,
0. , 0. ],
[-0.97956989, 0.08302387, -0.55062853, ..., 0. ,
0. , 0. ],
...,
[ nan, 0.12118423, 0.2602432 , ..., 0. ,
0. , 0. ],
[ 0.14086022, 0.06441201, 0.24755032, ..., 0. ,
0. , 0. ],
[ 0.17096774, 0.15019466, 0.24187614, ..., 0. ,
0. , 0. ]], shape=(1171, 8831))
In [6]:
pip install torch
Collecting torch Downloading torch-2.10.0-cp313-cp313-win_amd64.whl.metadata (31 kB) Requirement already satisfied: filelock in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.20.0) Requirement already satisfied: typing-extensions>=4.10.0 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (4.15.0) Requirement already satisfied: sympy>=1.13.3 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (1.14.0) Requirement already satisfied: networkx>=2.5.1 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.5) Requirement already satisfied: jinja2 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (3.1.6) Requirement already satisfied: fsspec>=0.8.5 in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (2025.10.0) Requirement already satisfied: setuptools in c:\users\vyasv\anaconda3\lib\site-packages (from torch) (80.9.0) Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\users\vyasv\anaconda3\lib\site-packages (from sympy>=1.13.3->torch) (1.3.0) Requirement already satisfied: MarkupSafe>=2.0 in c:\users\vyasv\anaconda3\lib\site-packages (from jinja2->torch) (3.0.2) Downloading torch-2.10.0-cp313-cp313-win_amd64.whl (113.8 MB) ---------------------------------------- 0.0/113.8 MB ? eta -:--:-- - -------------------------------------- 3.1/113.8 MB 19.9 MB/s eta 0:00:06 -- ------------------------------------- 7.9/113.8 MB 20.1 MB/s eta 0:00:06 ---- ----------------------------------- 13.6/113.8 MB 22.3 MB/s eta 0:00:05 ------ --------------------------------- 18.6/113.8 MB 22.5 MB/s eta 0:00:05 ------- -------------------------------- 21.2/113.8 MB 20.3 MB/s eta 0:00:05 -------- ------------------------------- 24.6/113.8 MB 19.7 MB/s eta 0:00:05 ---------- ----------------------------- 29.4/113.8 MB 20.1 MB/s eta 0:00:05 ----------- ---------------------------- 34.1/113.8 MB 20.3 MB/s eta 0:00:04 ------------- -------------------------- 38.8/113.8 MB 20.5 MB/s eta 0:00:04 --------------- ------------------------ 43.5/113.8 MB 20.6 MB/s eta 0:00:04 ----------------- ---------------------- 48.5/113.8 MB 20.9 MB/s eta 0:00:04 ------------------- -------------------- 54.3/113.8 MB 21.4 MB/s eta 0:00:03 -------------------- ------------------- 59.5/113.8 MB 21.7 MB/s eta 0:00:03 ----------------------- ---------------- 65.5/113.8 MB 22.2 MB/s eta 0:00:03 ------------------------ --------------- 70.8/113.8 MB 22.2 MB/s eta 0:00:02 --------------------------- ------------ 76.8/113.8 MB 22.7 MB/s eta 0:00:02 ----------------------------- ---------- 82.8/113.8 MB 23.1 MB/s eta 0:00:02 ------------------------------- -------- 89.1/113.8 MB 23.5 MB/s eta 0:00:02 --------------------------------- ------ 96.2/113.8 MB 24.0 MB/s eta 0:00:01 ----------------------------------- --- 103.0/113.8 MB 24.4 MB/s eta 0:00:01 ------------------------------------- - 110.1/113.8 MB 24.8 MB/s eta 0:00:01 -------------------------------------- 113.5/113.8 MB 24.9 MB/s eta 0:00:01 -------------------------------------- 113.5/113.8 MB 24.9 MB/s eta 0:00:01 ---------------------------------------- 113.8/113.8 MB 23.3 MB/s 0:00:04 Installing collected packages: torch Successfully installed torch-2.10.0 Note: you may need to restart the kernel to use updated packages.
In [7]:
## Here’s how to build the architecture of GANs:\
import torch
import torch.nn as nn
data_dim = data_processed.shape[1] # total features
latent_dim = 64 # size of random noise input
# generator
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, data_dim),
nn.Tanh() # output in range [-1, 1]
)
def forward(self, z):
return self.model(z)
# discriminator
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(data_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid() # probability of real/fake
)
def forward(self, x):
return self.model(x)
In [8]:
from torch.utils.data import DataLoader, TensorDataset
# convert data to PyTorch tensors
real_data = torch.tensor(data_processed, dtype=torch.float32)
dataset = TensorDataset(real_data)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# initialize models
generator = Generator()
discriminator = Discriminator()
# optimizers
lr = 0.0002
optim_G = torch.optim.Adam(generator.parameters(), lr=lr)
optim_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
# loss
criterion = nn.BCELoss()
epochs = 2000
for epoch in range(epochs):
for real_batch, in loader:
batch_size = real_batch.size(0)
# labels for real and fake data
real_labels = torch.ones((batch_size, 1))
fake_labels = torch.zeros((batch_size, 1))
# train discriminator
z = torch.randn(batch_size, latent_dim)
fake_data = generator(z)
real_loss = criterion(discriminator(real_batch), real_labels)
fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
d_loss = (real_loss + fake_loss) / 2
optim_D.zero_grad()
d_loss.backward()
optim_D.step()
# train generator
z = torch.randn(batch_size, latent_dim)
fake_data = generator(z)
g_loss = criterion(discriminator(fake_data), real_labels) # want fake to be real
optim_G.zero_grad()
g_loss.backward()
optim_G.step()
if epoch % 200 == 0:
print(f"Epoch [{epoch}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")
Epoch [0/2000] D_loss: 0.8390 G_loss: 0.5230 Epoch [200/2000] D_loss: 0.5556 G_loss: 1.5419 Epoch [400/2000] D_loss: 0.2877 G_loss: 2.8194 Epoch [600/2000] D_loss: 0.0972 G_loss: 5.7064 Epoch [800/2000] D_loss: 0.0681 G_loss: 4.5970 Epoch [1000/2000] D_loss: 0.0828 G_loss: 5.7511 Epoch [1200/2000] D_loss: 0.0042 G_loss: 5.8584 Epoch [1400/2000] D_loss: 0.0000 G_loss: 80.8665 Epoch [1600/2000] D_loss: 0.0000 G_loss: 100.0000 Epoch [1800/2000] D_loss: 0.0000 G_loss: 100.0000
In [9]:
# generate new synthetic data
z = torch.randn(1000, latent_dim) # 1000 synthetic samples
synthetic_data_scaled = generator(z).detach().numpy()
# inverse transform
num_synthetic = scaler.inverse_transform(synthetic_data_scaled[:, :len(num_cols)])
cat_synthetic = encoder.inverse_transform(synthetic_data_scaled[:, len(num_cols):])
# combine into dataframe
synthetic_df = pd.DataFrame(num_synthetic, columns=num_cols)
synthetic_df[cat_cols] = cat_synthetic
print(synthetic_df)
ZIP LAT LON HEALTHCARE_EXPENSES \
0 2161.355957 42.356140 -71.173187 539843.12500
1 2186.058350 42.299240 -71.162605 880593.93750
2 1817.902222 42.277500 -71.543427 364262.21875
3 1853.842651 42.238213 -71.468315 511571.37500
4 2140.446777 42.312260 -71.257774 700188.75000
.. ... ... ... ...
995 1830.314331 42.248318 -71.447144 554655.75000
996 2145.418945 42.367058 -71.091789 730281.62500
997 2163.928955 42.347404 -71.090889 822593.87500
998 2201.406982 42.337494 -71.128876 772574.50000
999 2034.734375 42.301086 -71.240829 634362.06250
HEALTHCARE_COVERAGE BIRTHDATE DEATHDATE DRIVERS PASSPORT \
0 2038.231567 1991-10-25 1926-03-05 S99934122 X10056752X
1 25678.156250 1974-05-30 1926-03-05 S99996704 X622036X
2 23.090076 2017-12-23 1926-03-05 S99927008 X10056752X
3 123.995094 2017-12-23 1926-03-05 S99927475 X55831878X
4 5956.271973 1983-09-02 1926-03-05 S99934122 X10056752X
.. ... ... ... ... ...
995 87.852898 2017-12-23 1926-03-05 S99955008 X10056752X
996 10727.262695 1923-05-15 1926-03-05 S99934122 X10056752X
997 21109.224609 1990-07-18 1926-03-05 S99934122 X10056752X
998 14073.028320 1979-03-26 1926-03-05 S99934122 X10056752X
999 1211.357910 2019-03-05 1926-03-05 S99934122 X10056752X
PREFIX FIRST LAST SUFFIX MAIDEN MARITAL RACE \
0 Mr. Patrick786 Wehner319 JD Keeling57 M white
1 Mr. Felicia295 Abbott774 JD Keeling57 M white
2 Mr. Eura647 Grant908 JD Keeling57 M white
3 Mr. Ángela136 Grant908 JD Keeling57 M white
4 Mr. Oda116 Grant908 JD Keeling57 M white
.. ... ... ... ... ... ... ...
995 Mr. Monroe732 Grant908 JD Keeling57 M white
996 Mr. Monroe732 Grant908 JD Keeling57 M white
997 Mr. Iraida50 Zulauf375 JD Keeling57 M white
998 Mr. Patrick786 Buckridge80 JD Keeling57 M white
999 Mr. Sixta311 Bergstrom287 JD Keeling57 M white
ETHNICITY GENDER BIRTHPLACE CITY \
0 nonhispanic F Hamilton Massachusetts US Boston
1 nonhispanic F Boston Massachusetts US Boston
2 nonhispanic F Somerville Massachusetts US Southampton
3 nonhispanic F Boston Massachusetts US Melrose
4 nonhispanic F Boston Massachusetts US Boston
.. ... ... ... ...
995 nonhispanic F Somerville Massachusetts US Melrose
996 nonhispanic F Boston Massachusetts US Boston
997 nonhispanic F Boston Massachusetts US Boston
998 nonhispanic F Boston Massachusetts US Boston
999 nonhispanic F Boston Massachusetts US Fall River
STATE COUNTY
0 Massachusetts Middlesex County
1 Massachusetts Middlesex County
2 Massachusetts Essex County
3 Massachusetts Suffolk County
4 Massachusetts Middlesex County
.. ... ...
995 Massachusetts Suffolk County
996 Massachusetts Middlesex County
997 Massachusetts Middlesex County
998 Massachusetts Middlesex County
999 Massachusetts Suffolk County
[1000 rows x 22 columns]
This is the generated dataset! Quite cool.¶
In [10]:
synthetic_df
Out[10]:
| ZIP | LAT | LON | HEALTHCARE_EXPENSES | HEALTHCARE_COVERAGE | BIRTHDATE | DEATHDATE | DRIVERS | PASSPORT | PREFIX | FIRST | LAST | SUFFIX | MAIDEN | MARITAL | RACE | ETHNICITY | GENDER | BIRTHPLACE | CITY | STATE | COUNTY | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2161.355957 | 42.356140 | -71.173187 | 539843.12500 | 2038.231567 | 1991-10-25 | 1926-03-05 | S99934122 | X10056752X | Mr. | Patrick786 | Wehner319 | JD | Keeling57 | M | white | nonhispanic | F | Hamilton Massachusetts US | Boston | Massachusetts | Middlesex County |
| 1 | 2186.058350 | 42.299240 | -71.162605 | 880593.93750 | 25678.156250 | 1974-05-30 | 1926-03-05 | S99996704 | X622036X | Mr. | Felicia295 | Abbott774 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Boston | Massachusetts | Middlesex County |
| 2 | 1817.902222 | 42.277500 | -71.543427 | 364262.21875 | 23.090076 | 2017-12-23 | 1926-03-05 | S99927008 | X10056752X | Mr. | Eura647 | Grant908 | JD | Keeling57 | M | white | nonhispanic | F | Somerville Massachusetts US | Southampton | Massachusetts | Essex County |
| 3 | 1853.842651 | 42.238213 | -71.468315 | 511571.37500 | 123.995094 | 2017-12-23 | 1926-03-05 | S99927475 | X55831878X | Mr. | Ángela136 | Grant908 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Melrose | Massachusetts | Suffolk County |
| 4 | 2140.446777 | 42.312260 | -71.257774 | 700188.75000 | 5956.271973 | 1983-09-02 | 1926-03-05 | S99934122 | X10056752X | Mr. | Oda116 | Grant908 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Boston | Massachusetts | Middlesex County |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 1830.314331 | 42.248318 | -71.447144 | 554655.75000 | 87.852898 | 2017-12-23 | 1926-03-05 | S99955008 | X10056752X | Mr. | Monroe732 | Grant908 | JD | Keeling57 | M | white | nonhispanic | F | Somerville Massachusetts US | Melrose | Massachusetts | Suffolk County |
| 996 | 2145.418945 | 42.367058 | -71.091789 | 730281.62500 | 10727.262695 | 1923-05-15 | 1926-03-05 | S99934122 | X10056752X | Mr. | Monroe732 | Grant908 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Boston | Massachusetts | Middlesex County |
| 997 | 2163.928955 | 42.347404 | -71.090889 | 822593.87500 | 21109.224609 | 1990-07-18 | 1926-03-05 | S99934122 | X10056752X | Mr. | Iraida50 | Zulauf375 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Boston | Massachusetts | Middlesex County |
| 998 | 2201.406982 | 42.337494 | -71.128876 | 772574.50000 | 14073.028320 | 1979-03-26 | 1926-03-05 | S99934122 | X10056752X | Mr. | Patrick786 | Buckridge80 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Boston | Massachusetts | Middlesex County |
| 999 | 2034.734375 | 42.301086 | -71.240829 | 634362.06250 | 1211.357910 | 2019-03-05 | 1926-03-05 | S99934122 | X10056752X | Mr. | Sixta311 | Bergstrom287 | JD | Keeling57 | M | white | nonhispanic | F | Boston Massachusetts US | Fall River | Massachusetts | Suffolk County |
1000 rows × 22 columns
In [12]:
synthetic_df['PASSPORT'].value_counts()
Out[12]:
PASSPORT X10056752X 623 X42955785X 75 X86615364X 34 X58347333X 32 X622036X 23 X35334533X 23 X80908865X 17 X66208297X 15 X42581289X 13 X20177258X 13 X70628984X 12 X62794246X 11 X22725008X 10 X57951475X 10 X28593812X 10 X41095772X 10 X55831878X 9 X69522263X 8 X37764385X 7 X38846992X 7 X41431358X 6 X85804833X 6 X87625600X 4 X65838399X 4 X14093941X 3 X27959569X 3 X78670564X 2 X66829529X 2 X35840246X 1 X40355437X 1 X3274633X 1 X6198901X 1 X27167656X 1 X82842351X 1 X36011830X 1 X70211797X 1 Name: count, dtype: int64
In [ ]:
# saving this notebook as an html file
In [2]:
import os
# Save notebook as HTML using nbconvert
notebook_filename = os.path.basename(__file__) if '__file__' in globals() else 'gan_medical_records.ipynb'
!jupyter nbconvert --to html --output "synthetic_medical_records.html" "{notebook_filename}"
[NbConvertApp] Converting notebook gan_medical_records.ipynb to html [NbConvertApp] Writing 335095 bytes to synthetic_medical_records.html
In [ ]: