Skip to content

Commit c1be1bc

Browse files
committed
Create ondemand-estimate-tool.py
Added tool to estimate ondemand cost from provisioned capacity history
1 parent f3e6cf3 commit c1be1bc

File tree

1 file changed

+302
-0
lines changed

1 file changed

+302
-0
lines changed

bin/ondemand-estimate-tool.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
import re
2+
import boto3
3+
from datetime import datetime, timedelta, timezone
4+
import statistics
5+
import json
6+
import argparse
7+
from decimal import Decimal, getcontext
8+
9+
10+
#
11+
# --- HELPERS ---
12+
#
13+
14+
"""
15+
Retrieve data points for a single CloudWatch metric (e.g. 'WriteThrottleEvent')
16+
within the specified time range, at the specified period (defaults to 1 hour).
17+
18+
Returns a list of the chosen statistic (e.g., average) values for each period.
19+
"""
20+
def get_operation_metric_data_points(
21+
cloudwatch_client,
22+
metric_name,
23+
keyspace_name,
24+
table_name,
25+
operation,
26+
start_time,
27+
end_time,
28+
period=3600,
29+
statistic='Average'
30+
):
31+
# Get the metric data points
32+
response = cloudwatch_client.get_metric_statistics(
33+
Namespace='AWS/Cassandra',
34+
MetricName=metric_name,
35+
Dimensions=[{'Name': 'TableName', 'Value': table_name}, {'Name': 'Keyspace', 'Value': keyspace_name}, {'Name': 'Operation', 'Value': operation}],
36+
StartTime=start_time,
37+
EndTime=end_time,
38+
Period=period,
39+
Statistics=[statistic]
40+
)
41+
42+
# Sort datapoints by timestamp just in case
43+
data_points = sorted(response.get('Datapoints', []), key=lambda d: d['Timestamp'])
44+
45+
# Extract the values from the data points
46+
values = [dp[statistic] for dp in data_points]
47+
return values
48+
49+
"""
50+
Retrieve data points for a single CloudWatch metric (e.g. 'ProvisionedReadCapacityUnits')
51+
within the specified time range, at the specified period (defaults to 1 hour).
52+
53+
Returns a list of the chosen statistic (e.g., average) values for each period.
54+
"""
55+
def get_table_metric_data_points(
56+
cloudwatch_client,
57+
metric_name,
58+
keyspace_name,
59+
table_name,
60+
start_time,
61+
end_time,
62+
period=3600,
63+
statistic='Average'
64+
):
65+
# Get the metric data points
66+
response = cloudwatch_client.get_metric_statistics(
67+
Namespace='AWS/Cassandra',
68+
MetricName=metric_name,
69+
Dimensions=[{'Name': 'TableName', 'Value': table_name}, {'Name': 'Keyspace', 'Value': keyspace_name}],
70+
StartTime=start_time,
71+
EndTime=end_time,
72+
Period=period,
73+
Statistics=[statistic]
74+
)
75+
76+
# Sort datapoints by timestamp just in case
77+
data_points = sorted(response.get('Datapoints', []), key=lambda d: d['Timestamp'])
78+
79+
# Extract the values from the data points
80+
values = [Decimal(dp[statistic]) for dp in data_points]
81+
return values
82+
83+
"""
84+
Total the number of throttles for all operations.
85+
"""
86+
def sum_all_throttles(insert_throttles, update_throttles, delete_throttles, select_throttles):
87+
88+
total = 0;
89+
90+
total += sum(insert_throttles) + sum(update_throttles) + sum(delete_throttles) + sum(select_throttles)
91+
92+
# Cost = (RCU * read_rate + WCU * write_rate) * hours
93+
return (total)
94+
95+
"""
96+
Estimate cost in US$ for capacity over 'hours' hours.
97+
rcu, wcu are capacity units.
98+
"""
99+
def estimate_cost(vals, price):
100+
101+
return sum(val * price for val in vals)
102+
103+
def get_keyspaces_throughput_mode(client, keyspace_name, table_name):
    """
    Retrieve the throughputMode ('PAY_PER_REQUEST' or 'PROVISIONED')
    for the specified Amazon Keyspaces table, or 'UNKNOWN' when the
    response carries no capacity specification.
    """
    response = client.get_table(
        keyspaceName=keyspace_name,
        tableName=table_name
    )

    # get_table returns 'capacitySpecification' at the top level of the
    # response; default to 'UNKNOWN' when it (or the mode) is absent.
    capacity_spec = response.get('capacitySpecification', {})
    return capacity_spec.get('throughputMode', 'UNKNOWN')
122+
123+
def get_keyspaces_pricing(pricing_client, region_name):
    """
    Build a {usageType: Decimal(price)} map for Amazon Keyspaces (Price List
    service code 'AmazonMCS') in the given region.

    The regional usage-type prefix (e.g. 'USE1-') is stripped so lookups such
    as 'WriteRequestUnits' or 'ReadCapacityUnit-Hrs' work for every region.
    Follows NextToken so result sets larger than one page are not truncated.
    """
    # Strip a leading 2-, 3-, or 4-character region prefix (e.g. 'USE1-').
    prefix_pattern = re.compile(r'^[A-Za-z0-9]{2}-|^[A-Za-z0-9]{3}-|^[A-Za-z0-9]{4}-')

    results = {}
    next_token = None
    while True:
        kwargs = {
            'ServiceCode': 'AmazonMCS',
            'Filters': [
                {
                    'Type': 'TERM_MATCH',
                    'Field': 'regionCode',
                    'Value': region_name
                }
            ],
            'MaxResults': 100,
        }
        if next_token:
            kwargs['NextToken'] = next_token
        response = pricing_client.get_products(**kwargs)

        # 'PriceList' is a list of stringified-JSON product documents.
        for price_item_json in response['PriceList']:
            price_item = json.loads(price_item_json)
            usage_type = price_item.get('product', {}).get('attributes', {}).get('usagetype', '')
            usage_type = prefix_pattern.sub('', usage_type)

            for one_value in price_item.get('terms', {}).get('OnDemand', {}).values():
                for one_dimension in one_value.get('priceDimensions', {}).values():
                    price = one_dimension.get('pricePerUnit', {}).get('USD', '0.0')
                    results[usage_type] = Decimal(price)

        # Keep paging until the Price List API reports no further results.
        next_token = response.get('NextToken')
        if not next_token:
            break

    return results
154+
155+
"""
156+
By default, analyzes the last 'days' (7) days of metrics for
157+
all tables in the specified region.
158+
"""
159+
def main():
160+
# Set decimal precision to 10
161+
getcontext().prec = 10
162+
163+
# Parse command-line arguments
164+
parser = argparse.ArgumentParser(
165+
description='Generate a report from nodetool tablestats and nodetool info and row size sampler outputs.'
166+
)
167+
168+
169+
parser.add_argument('--number-of-days', help='Number of days in the past to look back', type=int, default=7)
170+
parser.add_argument('--region-name', help='The AWS region where you have Amazon Keyspaces usage',type=str, default='us-east-1')
171+
parser.add_argument('--single-keyspace', type=str, default=None,
172+
help='Calculate a single keyspace. Leave out all other keyspaces')
173+
174+
# Parse arguments
175+
args = parser.parse_args()
176+
177+
region_name = args.region_name
178+
days = args.number_of_days
179+
single_keyspace = args.single_keyspace
180+
181+
print(f"Estimating Amazon Keyspaces OnDemand costs using CloudWatch metrics for region {region_name} and {days} days")
182+
183+
# Check if the region is China
184+
if region_name in ['cn-north-1', 'cn-northwest-1']:
185+
print("Amazon Keyspaces pricing is not available in China regions through the pricing api")
186+
exit(1)
187+
188+
189+
service_client = boto3.client('keyspaces', region_name=region_name)
190+
cloudwatch = boto3.client('cloudwatch', region_name=region_name)
191+
pricing_client = boto3.client('pricing', region_name=('ap-south-1' if region_name == 'ap-south-1' else 'us-east-1'))
192+
193+
price_dictionary = get_keyspaces_pricing(pricing_client, region_name)
194+
195+
# Determine default time range if none provided
196+
end_time = datetime.now(timezone.utc)
197+
start_time = end_time - timedelta(days=days)
198+
199+
# List all tables
200+
all_tables = []
201+
202+
ks_paginator = service_client.get_paginator('list_keyspaces')
203+
tbl_paginator = service_client.get_paginator('list_tables')
204+
205+
# Iterate over all keyspaces and tables.
206+
# If a single keyspace is specified, only that keyspace is analyzed.
207+
# filter system tables
208+
# capture Provisioned tables
209+
for page in ks_paginator.paginate():
210+
for one_keyspace in page["keyspaces"]:
211+
one_keyspace_name = one_keyspace['keyspaceName']
212+
if single_keyspace == None or one_keyspace_name == single_keyspace:
213+
if one_keyspace_name not in ['system', 'system_auth', 'system_distributed', 'system_schema', 'system_traces', 'system_schema_mcs', 'system_multiregion_info' ]:
214+
for table_page in tbl_paginator.paginate(keyspaceName=one_keyspace_name):
215+
for one_table in table_page["tables"]:
216+
one_table_name = one_table["tableName"]
217+
throughput_mode = get_keyspaces_throughput_mode(client=service_client, keyspace_name=one_keyspace_name, table_name=one_table_name)
218+
if(throughput_mode == 'PROVISIONED'):
219+
all_tables.append({'keyspaceName': one_keyspace_name, 'tableName': one_table["tableName"], 'throughputMode': throughput_mode})
220+
221+
222+
if not all_tables:
223+
print("No provisioned capacity mode tables found in this account/region/keyspace")
224+
return
225+
226+
period = 3600 # 1-hour granularity
227+
# total_hours = (end_time - start_time).total_seconds() / 3600.0
228+
229+
print(f"Analyzing tables in region {region_name} from {start_time} to {end_time}")
230+
231+
print(
232+
f"{'Keyspace':20s} "
233+
f"{'Table':30s} "
234+
f"{'current mode':17s} "
235+
f"{'provisioned reads':>17s} "
236+
f"{'provisioned writes':>17s} "
237+
f"{'ondemand reads':>17s} "
238+
f"{'ondemand writes':>17s} "
239+
f"{'provision estimate':>17s} "
240+
f"{'ondemand estimate':>17s} "
241+
f"{'total throttles':>17s} "
242+
f"{'ondemand savings':>17s} "
243+
)
244+
# For each table, gather data
245+
for one_table in all_tables:
246+
table_name = one_table["tableName"]
247+
keyspace_name = one_table["keyspaceName"]
248+
throughput_mode = one_table["throughputMode"]
249+
250+
# Fetch Provisioned & Consumed metrics
251+
prov_read_vals = get_table_metric_data_points(cloudwatch, 'ProvisionedReadCapacityUnits', keyspace_name, table_name, start_time, end_time, period)
252+
prov_write_vals = get_table_metric_data_points(cloudwatch, 'ProvisionedWriteCapacityUnits', keyspace_name, table_name, start_time, end_time, period)
253+
cons_read_vals = get_table_metric_data_points(cloudwatch, 'ConsumedReadCapacityUnits', keyspace_name, table_name, start_time, end_time, period, 'Sum')
254+
cons_write_vals = get_table_metric_data_points(cloudwatch, 'ConsumedWriteCapacityUnits', keyspace_name, table_name, start_time, end_time, period, 'Sum')
255+
256+
# Fetch Throttle metrics
257+
total_insert_throttles = get_operation_metric_data_points(cloudwatch, 'WriteThrottleEvents', keyspace_name, table_name, 'INSERT', start_time, end_time, period, 'Sum')
258+
total_update_throttles = get_operation_metric_data_points(cloudwatch, 'WriteThrottleEvents', keyspace_name, table_name, 'UPDATE', start_time, end_time, period, 'Sum')
259+
total_delete_throttles = get_operation_metric_data_points(cloudwatch, 'WriteThrottleEvents', keyspace_name, table_name, 'DELETE', start_time, end_time, period, 'Sum')
260+
total_select_throttles = get_operation_metric_data_points(cloudwatch, 'ReadThrottleEvents', keyspace_name, table_name, 'SELECT', start_time, end_time, period, 'Sum')
261+
262+
# Calculate total throttles
263+
total_throttles = sum_all_throttles(total_insert_throttles, total_update_throttles, total_delete_throttles, total_select_throttles)
264+
265+
# Estimate provisioned costs
266+
provision_read_cost = estimate_cost(prov_read_vals, price_dictionary.get('ReadCapacityUnit-Hrs'))
267+
provision_write_cost = estimate_cost(prov_write_vals, price_dictionary.get('WriteCapacityUnit-Hrs'))
268+
269+
# Estimate on-demand costs
270+
ondemand_read_cost = estimate_cost(cons_read_vals, price_dictionary.get('ReadRequestUnits'))
271+
ondemand_write_cost = estimate_cost(cons_write_vals, price_dictionary.get('WriteRequestUnits'))
272+
273+
# Calculate total costs
274+
provision_total_cost = provision_read_cost + provision_write_cost
275+
ondemand_total_cost = ondemand_read_cost + ondemand_write_cost
276+
277+
# Difference = On-Demand total minus Provisioned total
278+
# Negative => on-demand is cheaper
279+
ondemand_difference = (( provision_total_cost - ondemand_total_cost ) / provision_total_cost) * 100 if provision_total_cost > 0.0 else 0.0
280+
281+
# Print one CSV row
282+
print(
283+
f"{keyspace_name:<20s} "
284+
f"{table_name:<30s} "
285+
f"{throughput_mode:<17s} "
286+
f"{provision_read_cost:>17.2f} $"
287+
f"{provision_write_cost:>17.2f} $"
288+
f"{ondemand_read_cost:>17.2f} $"
289+
f"{ondemand_write_cost:>17.2f} $"
290+
f"{provision_total_cost:>17.2f} $"
291+
f"{ondemand_total_cost:>17.2f} $"
292+
f"{total_throttles:>17.0f} "
293+
f"{ondemand_difference:>17.2f} %"
294+
)
295+
296+
#
# --- RUN EXAMPLE ---
#
if __name__ == '__main__':
    # Default: last 7 days, US East-1 (overridable via --number-of-days,
    # --region-name, and --single-keyspace).
    main()
302+

0 commit comments

Comments
 (0)