Python Pandas: đừng dùng IF, hãy dùng where!#

Giới thiệu:#

Đôi khi trong phân tích, ta cần tạo ra các field thứ cấp từ các field ban đầu, trong đó có sử dụng điều kiện.

  • Thông thường để thực hiện điều này, theo bản năng ta dùng IF/ELIF để tạo function và loop qua từng giá trị trong field.

  • Cách làm trên tuy vẫn hoạt động và concept quen thuộc, dễ hiểu tuy nhiên không mang lại hiệu suất tính toán tốt, do đó trong trường hợp dữ liệu lớn sẽ làm kéo dài thời gian tính toán.

  • Để tối ưu tính toán, ta cần nắm rõ và vận dụng tốt các build-in function trong numpy/pandas, những hàm này đã được tối ưu cho các phép toán vectorized operation.

  • Cụ thể đối với trường hợp này, ta sẽ sử dụng where thay vì IF.

Cú pháp:#

Syntax :numpy.where(condition[, x, y])

Hàm này tương tự như hàm if trong Excel: Nếu + điều kiện , giá trị trả về nếu True , giá trị trả về nếu False

Bài toán:#

Giả sử ta có dữ liệu nhịp tim như bên dưới, biết rằng nhịp tim của người bình thường sẽ rơi vào khoảng 60-100 nhịp/phút. Ta cần phân loại nhịp tim thành 3 nhóm:

  • normal: [60-100]

  • unnormal: other case

import pandas as pd 
import numpy as np
sample = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/refs/heads/master/exercise.csv',index_col=0)[['pulse']]
sample.head()
pulse
0 85
1 85
2 88
3 90
4 92

Cách 1: sử dụng IF#

%%time
sample['type'] = ['normal' if  ((row['pulse'] >=60) & ( row['pulse'] <=100 )) else  'unnormal' for _, row in sample.iterrows()  ]
CPU times: total: 0 ns
Wall time: 3 ms
sample
pulse type
0 85 normal
1 85 normal
2 88 normal
3 90 normal
4 92 normal
... ... ...
85 135 unnormal
86 130 unnormal
87 99 normal
88 111 unnormal
89 150 unnormal

90 rows × 2 columns

Cách 2: Sử dụng pd.where#

%%time
sample['type2'] =  np.where((sample.pulse>=60)&(sample.pulse<=100),'normal','unnormal')
CPU times: total: 0 ns
Wall time: 0 ns
sample
pulse type type2
0 85 normal normal
1 85 normal normal
2 88 normal normal
3 90 normal normal
4 92 normal normal
... ... ... ...
85 135 unnormal unnormal
86 130 unnormal unnormal
87 99 normal normal
88 111 unnormal unnormal
89 150 unnormal unnormal

90 rows × 3 columns

Kết luận:#

  • So sách thời gian thực hiện bằng 2 cách sẽ thấy rõ sự khác biệt.

  • np.where đặc biệt hiệu quả khi kết hợp với các build-in function khác có sẵn trên numpy và pandas.