This notebook is available on GitHub.
Context¶
The problem I ran into was essentially a customer journey. A customer may first do A, then B, then B again, then C three times, then B twice more, which gives a series like ABBCCCBB. The goal is to remove duplicates only among neighbouring events: if another element sits between two identical elements, those two are not duplicates. For the example above, the desired result is ABCB. My first attempt was the drop_duplicates method of the Pandas data frame. In this post, I would like to share the frustration and the lesson I learned from solving this problem.
%matplotlib inline
import pandas as pd
import numpy as np
Problem¶
Unfortunately, there isn't a built-in function in Pandas that does exactly this. Let's first see what we get with the built-in drop_duplicates method.
dat = pd.DataFrame({'event': list('ABBCCCBB')})
dat
dat.drop_duplicates()
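To make the mismatch concrete, here is a small sketch that prints the result of drop_duplicates as a plain list and compares it with the result we want. The names global_dedup and wanted are mine and only serve the illustration.

```python
import pandas as pd

dat = pd.DataFrame({'event': list('ABBCCCBB')})

# drop_duplicates() keeps the first occurrence of each distinct value
# in the whole column, no matter where the repeats are located.
global_dedup = dat.drop_duplicates()['event'].tolist()
print(global_dedup)            # ['A', 'B', 'C']

# The neighbour-only result described in the Context section.
wanted = list('ABCB')
print(global_dedup == wanted)  # False
```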
Solution in a Simple Case¶
The method finds only three unique elements and keeps the first occurrence of each, so it drops the trailing B's that we actually want to keep. One vital observation leads to the solution: if a cell differs from the cell above it, it should be kept. So let's shift the column down by one row so that each row shows a cell next to the cell above it.
dat.loc[:, 'event_shifted'] = dat.event.shift()
dat.loc[:, 'is_different'] = dat.event != dat.event_shifted
dat
Therefore, if I keep only the rows where is_different is true, the problem is solved.
dat.loc[dat.is_different, ['event']]
In retrospect, this makes total sense: our definition of "duplicates" is limited to the row directly above, rather than the whole column as drop_duplicates assumes.
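If the helper columns are not needed afterwards, the same idea can probably be written as a single filter. This is only a compact sketch of the steps above; the variable name deduped is mine.

```python
import pandas as pd

dat = pd.DataFrame({'event': list('ABBCCCBB')})

# Keep a row only when its event differs from the event in the row above.
# shift() puts NaN in the first row, and NaN != 'A' is True, so row 0 is kept.
deduped = dat[dat['event'] != dat['event'].shift()]

print(deduped['event'].tolist())  # ['A', 'B', 'C', 'B']
print(deduped.index.tolist())     # [0, 1, 3, 6]
```

One caveat worth keeping in mind: NaN never equals NaN, so if the column itself contains missing values, consecutive missing values are all treated as different and none of them is dropped.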
A Little More Complex Case¶
While this simple solution works like magic, what about the situation with multiple columns? For example, instead of one customer we now have two. Running the same solution will occasionally give wrong results, as in the example below.
dat = pd.DataFrame({'event': list('ABBCCCBBBBCCCBB'), 'customer_id': [1]*8 + [2]*7})
dat
dat.loc[:, 'event_shifted'] = dat.event.shift()
dat.loc[:, 'is_different'] = dat.event != dat.event_shifted
dat.loc[dat.is_different, ['customer_id', 'event']]
The first event of customer 2 was removed because it is the same as the last event of customer 1. Therefore the customer id should also be compared.
dat = pd.DataFrame({'event': list('ABBCCCBBBBCCCBB'), 'customer_id': [1]*8 + [2]*7})
shifted = dat.shift()
is_different = (dat.customer_id != shifted.customer_id) | (dat.event != shifted.event)
dat.loc[is_different]
Now we get the correct final data set.
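The comparison above assumes that each customer's rows sit next to each other. As an alternative sketch (not what I originally used), the shift can be done inside each customer group with groupby, so the first event of a customer is never compared with another customer's last event. The names previous_event and deduped are mine.

```python
import pandas as pd

dat = pd.DataFrame({'event': list('ABBCCCBBBBCCCBB'),
                    'customer_id': [1] * 8 + [2] * 7})

# Shift events within each customer: the first row of every customer
# gets NaN, so it always counts as different and is kept.
previous_event = dat.groupby('customer_id')['event'].shift()
deduped = dat[dat['event'] != previous_event]

print(deduped['event'].tolist())  # ['A', 'B', 'C', 'B', 'B', 'C', 'B']
print(deduped.index.tolist())     # [0, 1, 3, 6, 8, 10, 13]
```

Because the shift happens per group, this variant should also cope with data where the rows of different customers are interleaved, as long as each customer's own events are in the right order.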
What I Learned¶
We found a simple solution that drops duplicates only across neighbouring rows. It relies entirely on Pandas built-in methods and functions, with no iteration through the rows, which keeps it fast.
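To wrap up, the approach could be packaged into a small reusable function. drop_consecutive_duplicates is a hypothetical helper of my own, not part of Pandas, and it assumes the rows are already ordered so that each customer's events are contiguous.

```python
import pandas as pd

def drop_consecutive_duplicates(df, subset=None):
    """Hypothetical helper: drop rows equal to the row directly above on `subset`."""
    cols = list(df.columns) if subset is None else list(subset)
    current = df[cols]
    previous = current.shift()
    # A row is kept if at least one compared column differs from the row above.
    keep = (current != previous).any(axis=1)
    return df[keep]

dat = pd.DataFrame({'event': list('ABBCCCBBBBCCCBB'),
                    'customer_id': [1] * 8 + [2] * 7})
result = drop_consecutive_duplicates(dat, subset=['customer_id', 'event'])
print(result['event'].tolist())  # ['A', 'B', 'C', 'B', 'B', 'C', 'B']
```

Passing the columns through subset mirrors the parameter of the same name in drop_duplicates, which I hope makes the helper feel familiar.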