diff --git a/benches/plot.py b/benches/plot.py
index 58acd8c850aa3d64bb4d9045b0eb8698695e58ec..348d9f003b3c74588696de90c482fa463b38b6cb 100644
--- a/benches/plot.py
+++ b/benches/plot.py
@@ -55,8 +55,13 @@ def read_data(data_file: str) -> Data:
 
 def plot_linalg(data: Data, save: bool = False):
     # key: the start of the `$.id` field
-    def plot(data: Data, key: str, ax):
-        filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
+    def plot(data: Data, key: str, curve: str, color: str, ax):
+        filtered_data = list(filter(
+            lambda line: line["id"].startswith(key) and line["id"].endswith(f" on {curve}"),
+            data
+        ))
+        if len(filtered_data) == 0:
+            return
 
         sizes = [
             int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data
@@ -66,22 +71,16 @@ def plot_linalg(data: Data, save: bool = False):
         up = ns_to_ms(extract(filtered_data, "mean", "upper_bound"))
         down = ns_to_ms(extract(filtered_data, "mean", "lower_bound"))
 
-        ax.plot(sizes, means, label="mean", color="blue")
-        ax.fill_between(sizes, down, up, color="blue", alpha=0.3)
-
-        medians = ns_to_ms(extract(filtered_data, "median", "estimate"))
-        up = ns_to_ms(extract(filtered_data, "median", "upper_bound"))
-        down = ns_to_ms(extract(filtered_data, "median", "lower_bound"))
-
-        ax.plot(sizes, medians, label="median", color="orange")
-        ax.fill_between(sizes, down, up, color="orange", alpha=0.3)
+        ax.plot(sizes, means, label=curve, color=color)
+        ax.fill_between(sizes, down, up, color=color, alpha=0.3)
 
     labels = ["transpose", "mul", "inverse"]
 
     fig, axs = plt.subplots(len(labels), 1, figsize=(16, 9))
 
     for label, ax in zip(labels, axs):
-        plot(data, key=label, ax=ax)
+        for (curve, color) in [("BLS12-381", "blue"), ("BN-254", "orange"), ("PALLAS", "green")]:
+            plot(data, key=label, curve=curve, color=color, ax=ax)
         ax.set_title(label)
         ax.set_yscale("log")
         ax.set_ylabel("time in ms")