Skip to content

Commit 772beaa

Browse files
committed
improve plot layout
1 parent 11e808f commit 772beaa

File tree

3 files changed

+198
-15
lines changed

3 files changed

+198
-15
lines changed

git_commits_graph/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77

88
# bmh | fivethirtyeight | seaborn-whitegrid | seaborn-darkgrid | seaborn-ticks
99
DEFAULT_STYLE = "fivethirtyeight"
10+
DEFAULT_BACKEND = "matplotlib"
11+
DEFAULT_OUTPUT_FILE = "out.html"
12+
MAX_NUM_BARS = 200

git_commits_graph/main.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
import git
33
import matplotlib.pyplot as plt
44
import pandas as pd
5+
from git_commits_graph.config import DEFAULT_BACKEND
6+
from git_commits_graph.config import DEFAULT_OUTPUT_FILE
57
from git_commits_graph.config import DEFAULT_STYLE
68
from git_commits_graph.plotters import plot_changes
9+
from git_commits_graph.plotters import plot_changes_px
710
from git_commits_graph.plotters import plot_total_lines
11+
from git_commits_graph.plotters import plot_total_lines_px
812

913

1014
@click.command()
@@ -35,6 +39,14 @@
3539
is_flag=True,
3640
help="list available plot styles and exit.",
3741
)
42+
@click.option(
43+
"-e",
44+
"--engine",
45+
is_flag=False,
46+
default=DEFAULT_BACKEND,
47+
help="plotting engine to use (matplotlib | plotly)",
48+
)
49+
@click.option("-o", "--output-file", help="output file name (for plotly backend)")
3850
def main(
3951
git_dir,
4052
branch,
@@ -44,8 +56,10 @@ def main(
4456
style,
4557
list_available_plot_styles,
4658
aggregate_by,
59+
engine="matplotlib",
60+
output_file=DEFAULT_OUTPUT_FILE,
4761
):
48-
62+
"""Plot git commits timeline main function."""
4963
if list_available_plot_styles:
5064
print(plt.style.available)
5165
exit()
@@ -58,6 +72,8 @@ def main(
5872
log_scale=log_scale,
5973
style=style,
6074
aggregate_by=aggregate_by,
75+
backend=engine,
76+
output_file=output_file,
6177
)
6278

6379

@@ -69,35 +85,51 @@ def git_graph(
6985
log_scale=False,
7086
style=DEFAULT_STYLE,
7187
aggregate_by=None,
88+
backend=DEFAULT_BACKEND,
89+
output_file=None,
7290
):
73-
# TODO: KS: 2022-06-06: Fetch commits from the github repo without cloning it.
91+
"""Plot git commits timeline."""
92+
# TODO: KS: 2022-06-06: Fetch commits from the GitHub repo without cloning it.
7493
# see: https://stackoverflow.com/a/64561416/3247880
7594

7695
git_dir, repo = get_git_repo(
7796
changes=changes, git_dir=git_dir, total_lines=total_lines
7897
)
79-
commits = fetch_commits(branch, repo)
98+
commits = fetch_commits(branch, repo) # this might take a long time
8099

81100
plt.style.use(style)
101+
102+
if backend == "plotly":
103+
func_total = plot_total_lines_px
104+
func_changes = plot_changes_px
105+
elif backend == "matplotlib":
106+
func_total = plot_total_lines
107+
func_changes = plot_changes
108+
else:
109+
raise ValueError(f"Unknown backend: {backend}")
110+
82111
if total_lines:
83-
plot_total_lines(
112+
func_total(
84113
commits=commits,
85114
git_dir=git_dir,
86115
log_scale=log_scale,
87116
aggregate_by=aggregate_by,
117+
output_file=output_file,
88118
)
89119
plt.show()
90120
if changes:
91-
plot_changes(
121+
func_changes(
92122
commits=commits,
93123
git_dir=git_dir,
94124
log_scale=log_scale,
95125
aggregate_by=aggregate_by,
126+
output_file=output_file,
96127
)
97128
plt.show()
98129

99130

100131
def get_git_repo(changes, git_dir, total_lines):
132+
"""Get git repository object."""
101133
git_dir = git_dir.strip('"').strip("'")
102134
try:
103135
repo = git.repo.Repo(git_dir)
@@ -120,6 +152,7 @@ def get_git_repo(changes, git_dir, total_lines):
120152

121153

122154
def fetch_commits(branch, repo):
155+
"""Fetch commits from the git repository."""
123156
commits = []
124157
try:
125158
for i in reversed(list(repo.iter_commits(rev=branch))):
@@ -146,6 +179,7 @@ def fetch_commits(branch, repo):
146179

147180

148181
def prepare_commits_dataframe(commits):
182+
"""Make DataFrame from commits list."""
149183
commits = pd.DataFrame(commits, columns=["date", "added", "removed"])
150184
commits["delta"] = commits["added"] - commits["removed"]
151185
commits.date = pd.to_datetime(commits.date)

git_commits_graph/plotters.py

Lines changed: 156 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,56 @@
22
import os
33
from typing import Optional
44

5+
import matplotlib.dates as mdates
6+
import matplotlib.ticker as ticker
57
import pandas as pd
8+
import plotly.graph_objs as go
69
from git_commits_graph.config import FIGSIZE
10+
from git_commits_graph.config import MAX_NUM_BARS
711
from git_commits_graph.config import XTICKS_FMT
812
from matplotlib import pyplot as plt
913

1014

11-
def plot_changes(commits, git_dir: str, log_scale: bool, aggregate_by: Optional[str]):
15+
# Define a function to format the y-axis labels
16+
def y_fmt(x, pos):
    """Format a y-axis tick value compactly, using ``k`` and ``M`` suffixes.

    The suffix is chosen from the magnitude of ``x``, so negative ticks (the
    "removed" bars are plotted below zero) are abbreviated the same way as
    positive ones, e.g. ``-5000`` becomes ``"-5k"``.

    Args:
        x: Tick value.
        pos: Tick position; unused, but part of the two-argument signature
            that ``ticker.FuncFormatter`` calls the function with.

    Returns:
        The value rounded to zero decimals, scaled to thousands (``k``) or
        millions (``M``) when its magnitude is large enough.
    """
    magnitude = abs(x)
    if magnitude >= 1_000_000:
        return f"{x * 1e-6:.0f}M"
    if magnitude >= 1_000:
        return f"{x * 1e-3:.0f}k"
    return f"{x:.0f}"
23+
24+
25+
# Create a FuncFormatter object from the y_fmt function
26+
y_formatter = ticker.FuncFormatter(y_fmt)
27+
28+
29+
def data_aggregation(aggregate_by, plot_data_add, plot_data_rem):
    """Aggregate the added/removed series for plotting.

    When ``aggregate_by`` is given, both series are aggregated once with that
    period. Otherwise the data is aggregated with progressively coarser
    periods ("D", "W", "M", "Y") for as long as the "added" series still has
    more than ``MAX_NUM_BARS`` entries; data that is already short enough is
    returned untouched.

    Args:
        aggregate_by: Period passed to ``run_aggregation``, or a falsy value
            to select the period automatically.
        plot_data_add: Series of added lines.
        plot_data_rem: Series of removed lines.

    Returns:
        Tuple ``(added, removed)`` of the (possibly aggregated) series.
    """
    if aggregate_by:
        # An explicit period always wins and is applied exactly once.
        added = run_aggregation(plot_data_add, col="added", period=aggregate_by)
        removed = run_aggregation(
            plot_data_rem, col="removed", period=aggregate_by
        )
        return added, removed

    added, removed = plot_data_add, plot_data_rem
    for period in ("D", "W", "M", "Y"):
        # Stop coarsening as soon as the bar count is acceptable.
        if len(added) <= MAX_NUM_BARS:
            break
        added = run_aggregation(added, col="added", period=period)
        removed = run_aggregation(removed, col="removed", period=period)
    return added, removed
45+
46+
47+
def plot_changes_px(
48+
commits,
49+
git_dir: str,
50+
log_scale: bool,
51+
aggregate_by: Optional[str],
52+
output_file="out.html",
53+
):
54+
"""Plot added/removed lines timeline."""
1255
plot_data_add = commits.added
1356
plot_data_rem = commits.removed
1457

@@ -18,27 +61,89 @@ def plot_changes(commits, git_dir: str, log_scale: bool, aggregate_by: Optional[
1861
plot_data_rem = (plot_data_rem + 1).apply(math.log10)
1962
ylabel = "log number of lines added/removed"
2063

21-
if aggregate_by:
22-
plot_data_add = run_aggregation(plot_data_add, col="added", period=aggregate_by)
23-
plot_data_rem = run_aggregation(
24-
plot_data_rem, col="removed", period=aggregate_by
64+
plot_data_add, plot_data_rem = data_aggregation(
65+
aggregate_by=aggregate_by,
66+
plot_data_add=plot_data_add,
67+
plot_data_rem=plot_data_rem,
68+
)
69+
70+
plot_data_add = pd.DataFrame(plot_data_add)
71+
plot_data_rem = pd.DataFrame(plot_data_rem)
72+
fig = go.Figure()
73+
74+
fig.add_trace(
75+
go.Bar(
76+
x=plot_data_add.index,
77+
y=plot_data_add.added,
78+
name="added",
79+
marker_color="green",
80+
)
81+
)
82+
fig.add_trace(
83+
go.Bar(
84+
x=plot_data_rem.index,
85+
y=-plot_data_rem.removed,
86+
name="removed",
87+
marker_color="red",
2588
)
89+
)
90+
91+
fig.update_layout(
92+
title=f"Added/Removed Lines in repo {os.path.basename(git_dir)}",
93+
xaxis_title="Date",
94+
yaxis_title=ylabel,
95+
)
96+
if log_scale:
97+
fig.update_yaxes(type="log")
98+
# fig.show()
99+
fig.write_html(output_file)
100+
print(f"Saved to {output_file}")
101+
102+
103+
def plot_changes(
    commits,
    git_dir: str,
    log_scale: bool,
    aggregate_by: Optional[str],
    output_file=None,
):
    """Plot added/removed lines as a matplotlib bar chart.

    Added lines are drawn as green bars above zero and removed lines as red
    bars below zero. The figure is created but not shown; the caller is
    expected to call ``plt.show()``.

    Args:
        commits: Object exposing ``added`` and ``removed`` series indexed by
            commit date.
        git_dir: Repository path; only its base name is used, in the title.
        log_scale: If true, plot ``log10(value + 1)`` instead of raw counts.
        aggregate_by: Aggregation period forwarded to ``data_aggregation``,
            or ``None`` to let it pick one automatically.
        output_file: Unused by this backend; accepted (and optional) so both
            plotting backends can be called with the same arguments.
    """
    plot_data_add = commits.added
    plot_data_rem = commits.removed

    ylabel = "number of lines added/removed"
    if log_scale:
        # NOTE(review): the log transform is applied before aggregation, so
        # aggregated bars are built from per-commit logs rather than the log
        # of the aggregated counts - confirm this is intended.
        plot_data_add = (plot_data_add + 1).apply(math.log10)
        plot_data_rem = (plot_data_rem + 1).apply(math.log10)
        ylabel = "log number of lines added/removed"

    plot_data_add, plot_data_rem = data_aggregation(
        aggregate_by=aggregate_by,
        plot_data_add=plot_data_add,
        plot_data_rem=plot_data_rem,
    )

    plot_data_add = pd.DataFrame(plot_data_add)
    # Removed lines are negated so their bars hang below the x axis.
    plot_data_rem = pd.DataFrame(-plot_data_rem)
    fig, ax = plt.subplots(1, 1, figsize=FIGSIZE)

    ax = plot_data_add.plot(kind="bar", ax=ax, color="green", label="added")
    ax = plot_data_rem.plot(kind="bar", ax=ax, color="red", label="removed")
    ax.yaxis.set_major_formatter(y_formatter)

    # pandas draws bars at integer positions 0..n-1, so a date locator cannot
    # be used on this axis. Instead, keep at most ``max_ticks`` evenly spaced
    # ticks and label each one with the date of the bar it sits under; fixing
    # the tick positions first keeps labels and bars aligned.
    max_ticks = 15
    n_bars = len(plot_data_add)
    step = max(1, math.ceil(n_bars / max_ticks))
    labels = pd.to_datetime(plot_data_add.index, utc=True).strftime(XTICKS_FMT)
    ax.set_xticks(list(range(0, n_bars, step)))
    ax.set_xticklabels(labels[::step], rotation=45)

    plt.ylabel(ylabel)
    ax.set_title(f"Added/Removed Lines in repo {os.path.basename(git_dir)}")
    fig.tight_layout()
38141

39142

40-
def plot_total_lines(commits, git_dir, log_scale, aggregate_by):
41-
fig, ax = plt.subplots(1, 1, figsize=[8, 6])
143+
def plot_total_lines_px(
144+
commits, git_dir, log_scale, aggregate_by, output_file="out.html"
145+
):
146+
"""Plot total lines timeline."""
42147

43148
_delta = commits.delta
44149
if aggregate_by:
@@ -55,24 +160,65 @@ def plot_total_lines(commits, git_dir, log_scale, aggregate_by):
55160
plot_data = (plot_data + 1).apply(math.log10)
56161
ylabel = "log number of lines"
57162

58-
ax = plot_data.plot()
163+
fig = go.Figure()
164+
fig.add_trace(go.Scatter(x=plot_data.index, y=plot_data.values, mode="lines"))
165+
fig.update_layout(
166+
title=f"Number of Lines Progress in repo {os.path.basename(git_dir)}",
167+
xaxis_title="Date",
168+
yaxis_title=ylabel,
169+
)
170+
fig.update_yaxes(range=[0, 1.1 * plot_data.max()])
171+
# fig.show()
172+
fig.write_html(output_file)
173+
print(f"Saved to {output_file}")
174+
175+
176+
def plot_total_lines(commits, git_dir, log_scale, aggregate_by, output_file=None):
    """Plot the running total of lines as a matplotlib line chart.

    The total is the cumulative sum of ``commits.delta``, optionally summed
    per period first. The figure is created but not shown; the caller is
    expected to call ``plt.show()``.

    Args:
        commits: Object exposing a ``delta`` series indexed by commit date
            (the index must be named ``date`` when ``aggregate_by`` is used).
        git_dir: Repository path; only its base name is used, in the title.
        log_scale: If true, plot ``log10(total + 1)`` instead of the total.
        aggregate_by: Frequency string used to group deltas before the
            cumulative sum, or a falsy value for no aggregation.
        output_file: Unused by this backend; accepted (and optional) so both
            plotting backends can be called with the same arguments.
    """
    fig, ax = plt.subplots(1, 1, figsize=[8, 6])

    # Work on a copy: replacing the index below must not alter the caller's
    # ``commits`` data.
    _delta = commits.delta.copy()
    _delta.index = pd.to_datetime(_delta.index, utc=True)
    if aggregate_by:
        _delta = _delta.reset_index()
        _delta = _delta.groupby(pd.Grouper(key="date", axis=0, freq=aggregate_by))[
            "delta"
        ].sum()

    plot_data = _delta.cumsum()

    ylabel = "number of lines"
    if log_scale:
        plot_data = (plot_data + 1).apply(math.log10)
        ylabel = "log number of lines"

    # x_compat=True keeps pandas from switching the x axis to its own
    # period-based units for regular-frequency data, which would make the
    # matplotlib date locator/formatter below interpret the values wrongly.
    ax = plot_data.plot(ax=ax, kind="line", x_compat=True)
    locator = mdates.AutoDateLocator(minticks=10, maxticks=15)
    formatter = mdates.ConciseDateFormatter(locator)
    ax.xaxis.set_major_locator(locator)
    # Install the concise formatter itself; wrapping it in AutoDateFormatter
    # (which expects a locator) would leave it unused.
    ax.xaxis.set_major_formatter(formatter)
    ax.yaxis.set_major_formatter(y_formatter)
    plt.ylabel(ylabel)
    ax.set_title(f"Number of Lines Progress in repo {os.path.basename(git_dir)}")
    ax.set_ylim([0, 1.1 * plot_data.max()])
    ax.xaxis_date()
    # Optional. Just rotates x-ticklabels in this case.
    fig.autofmt_xdate()
    fig.tight_layout()
66210

67211

68212
def format_xticklabels(ax, plot_data_add, fmt=XTICKS_FMT):
    """Label the x ticks of ``ax`` with formatted dates.

    The dates come from the ``date`` column that ``plot_data_add`` exposes
    after ``reset_index()``; they are parsed as UTC timestamps and rendered
    with ``fmt``. The labels are drawn vertically (rotated 90 degrees). The
    caller's data is not modified.

    Args:
        ax: Matplotlib axes whose x tick labels are replaced.
        plot_data_add: Data whose index holds the dates to display.
        fmt: ``strftime`` pattern for the labels.
    """
    frame = plot_data_add.reset_index()
    dates = pd.to_datetime(frame.date, utc=True)
    tick_labels = dates.dt.strftime(fmt)
    ax.set_xticklabels(tick_labels, rotation=90)
73218

74219

75220
def run_aggregation(plot_data_add, col, period):
221+
"""Aggregate data by period."""
76222
plot_data_add = plot_data_add.reset_index()
77223
plot_data_add.date = pd.to_datetime(plot_data_add.date, utc=True)
78224
plot_data_add = plot_data_add.groupby(pd.Grouper(key="date", axis=0, freq=period))[

0 commit comments

Comments
 (0)