
Transformer embeddings for clinical NLP

10 May 2024


Objective:

  • Develop a set of embeddings suitable for clinical record linkage, for example identifying when two pathology test result names refer to the same test, including when abbreviations are used

References:

Clinical abbreviation datasets: https://www.nature.com/articles/s41597-021-00929-4 and https://github.com/lisavirginia/clinical-abbreviations

Fine-tuning sentence transformers (Hugging Face): https://huggingface.co/blog/how-to-train-sentence-transformers

import pandas as pd
import numpy as np

from os import listdir

import matplotlib.pyplot as plt
import seaborn as sns

from sentence_transformers import SentenceTransformer, models
from datasets import load_dataset
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from sentence_transformers import losses

sns.set()
%matplotlib inline

Data sources: Load and clean

Using the clinical abbreviation datasets listed in the references above.

source_folder = '/Users/alexlee/Desktop/Data/clinical/clinical_abbreviations/'
filenames = listdir(source_folder)
df1 = (
    pd
    .read_csv(f'{source_folder}/{filenames[0]}', sep='=', header=None, names=['abbreviation', 'sense'])
    .assign(abbreviation=lambda df_: df_.abbreviation.str.strip(), 
            sense=lambda df_: df_.sense.str.strip())
)
df2 = (
    pd
    .read_csv(f'{source_folder}/{filenames[1]}', sep='\t', header=None, names=['abbreviation', 'sense', 'similarity'])
    .assign(abbreviation=lambda df_: df_.abbreviation.str.strip(), 
            sense=lambda df_: df_.sense.str.strip())
)
df3 = (
    pd
    .read_csv(f'{source_folder}/{filenames[2]}', sep=',', names=['abbreviation', 'sense'])
    .assign(abbreviation=lambda df_: df_.abbreviation.str.strip(), 
            sense=lambda df_: df_.sense.str.strip())
)
df4 = (
    pd
    .read_csv(f'{source_folder}/vanderbilt_clinic_notes.txt', sep='\t')
    .assign(abbreviation=lambda df_: df_.abbreviation.str.strip(), 
            sense=lambda df_: df_.sense.str.strip())
) 
df5 = (
    pd
    .read_csv(f'{source_folder}/vanderbilt_discharge_sums.txt', sep='\t')
    .assign(abbreviation=lambda df_: df_.abbreviation.str.strip(), 
            sense=lambda df_: df_.sense.str.strip())
) 
df_all = pd.concat([df1, df2, df3, df4, df5])
df_all = (
    df_all
    .loc[:, ['abbreviation', 'sense']]
    .dropna(subset=['abbreviation', 'sense'])
    .drop_duplicates()
    .sort_values(by=['abbreviation'])
    .reset_index(drop=True)
)

# training data
train_data = df_all.values
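
As a quick sanity check before training (a minimal sketch; the exact pair count depends on the files in source_folder):

# each row of train_data is an (abbreviation, sense) pair
print(train_data.shape)   # (n_pairs, 2)
print(train_data[:3])     # peek at the first few pairs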

Load model

## Step 1: use an existing language model
word_embedding_model = models.Transformer('distilroberta-base')
## Step 2: use a pooling function over the token embeddings
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

## Join steps 1 and 2 using the modules argument
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
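
The assembled model can already encode text before any fine-tuning; a quick check that the pooled embedding has the expected width (768 hidden units for distilroberta-base, mean-pooled by default):

emb = model.encode('alkaline phosphatase')
print(emb.shape)   # (768,): one pooled vector for the input string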

Create dataloader object

train_examples = []

# convert each (abbreviation, sense) pair into an InputExample
for abbreviation, sense in train_data:
    train_examples.append(InputExample(texts=[abbreviation, sense]))

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

Loss function

Use MultipleNegativesRankingLoss, since our training data consists of pairs of similar strings with no explicit negative examples.

train_loss = losses.MultipleNegativesRankingLoss(model=model)
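
For a batch of (abbreviation, sense) pairs, this loss scores every abbreviation against every sense in the batch and treats the matching sense as the correct class, so all the other senses in the batch act as in-batch negatives. A rough sketch of the idea (the library's default similarity is cosine with a scale factor, assumed here to be 20):

import torch
import torch.nn.functional as F

def mnr_loss_sketch(anchor_emb, positive_emb, scale=20.0):
    # cosine similarity of every anchor against every positive in the batch
    a = F.normalize(anchor_emb, dim=1)
    p = F.normalize(positive_emb, dim=1)
    scores = a @ p.T * scale   # (batch, batch) similarity matrix
    # matching pairs sit on the diagonal, so label i is correct for row i
    labels = torch.arange(scores.size(0))
    return F.cross_entropy(scores, labels)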

Fine-tuning

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=10) 
model.save('clinical_embeddings_100524')
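
The saved model can be reloaded later by pointing SentenceTransformer at the output directory:

model = SentenceTransformer('clinical_embeddings_100524')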

Evaluation

# initially inspecting some of the matches
texts = ['sob', 
         'shortness of breath', 
         'hbg', 
         'plt', 
         'bilirubin', 
         'haemoglobin', 
         'platelets', 
         'alp', 
         'alkaline phosphatase', 
         'hb', 'hb.', 'plt.', 'plat', 's.o.b', 'sob on arrival']
texts_emb = model.encode(texts)
s1 = []
s2 = []
scores = []

# pairwise (unnormalised) dot products between all of the embeddings
for n in range(len(texts)):
    for m in range(len(texts)):
        s1.append(texts[n])
        s2.append(texts[m])
        scores.append(np.dot(texts_emb[n], texts_emb[m]))

df = pd.DataFrame(data={'text1': s1,
                        'text2': s2,
                        'similarity': scores})
query = 'sob on arrival'
df.query(f'text1 == "{query}"').sort_values(by='similarity', ascending=False)
              text1                 text2  similarity
224  sob on arrival        sob on arrival  266.899292
210  sob on arrival                   sob  183.777115
211  sob on arrival   shortness of breath  130.903046
223  sob on arrival                 s.o.b  107.093262
215  sob on arrival           haemoglobin   45.915550
214  sob on arrival             bilirubin   27.788834
218  sob on arrival  alkaline phosphatase   15.709982
219  sob on arrival                    hb   14.771133
220  sob on arrival                   hb.    8.055744
217  sob on arrival                   alp   -5.120555
212  sob on arrival                   hbg  -10.904447
222  sob on arrival                  plat  -18.356281
216  sob on arrival             platelets  -39.142982
221  sob on arrival                  plt.  -47.600090
213  sob on arrival                   plt  -53.838760
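
Note that these scores are unnormalised dot products, so they are dominated by vector magnitudes. For a score bounded to [-1, 1], the same comparison can be done with cosine similarity using the library's util helpers, for example:

from sentence_transformers import util

cos_scores = util.cos_sim(texts_emb, texts_emb)   # (len(texts), len(texts)) matrix
print(cos_scores[texts.index('sob on arrival'),
                 texts.index('shortness of breath')])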