Skip to content

Commit 6f2b0d3

Browse files
committed
Add error message on invalid column
Add an error message that is printed when the user supplies an invalid column to prevent unwanted calculations and user-visible exceptions. Also add tests to ensure that the error message is printed to stderr. Closes #1
1 parent 1d9b976 commit 6f2b0d3

File tree

4 files changed

+132
-3
lines changed

4 files changed

+132
-3
lines changed

blendplot/__main__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
A cli application for plotting 3D data in obj format for use in Blender.
33
"""
44
import cli.app
5+
import sys
56
import time
67

78
from . import obj_graph
@@ -70,8 +71,11 @@ def main(input_filename, output_filename, num_rows, columns, spacing, point_size
7071
end = time.time()
7172
output_file.close()
7273

73-
print("Wrote plot file to %s" % output_filename)
74-
print("Plotted %s points in %f seconds" % (points, end - start))
74+
if points is None:
75+
sys.exit(1)
76+
else:
77+
print("Wrote plot file to %s" % output_filename)
78+
print("Plotted %s points in %f seconds" % (points, end - start))
7579

7680
def run():
7781
blendplot.run()

blendplot/obj_graph.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
from sklearn import preprocessing
55
import pandas as pd
6+
import sys
67

78
def add_cube_verticies(cube_str, x, y, z, point_size):
89
"""
@@ -170,9 +171,20 @@ def plot_file(input_filename, output_file, num_rows, columns, spacing, point_siz
170171
Returns
171172
-------
172173
points : int
173-
the number of points that were plotted
174+
the number of points that were plotted, or None if the plot is unable
175+
to be made
174176
"""
175177
original_data = pd.read_csv(input_filename, nrows = num_rows)
178+
179+
missing = get_missing_columns(original_data, columns, category_column)
180+
if len(missing) > 0:
181+
missing_columns = ", ".join(missing)
182+
valid_columns = ", ".join(list(original_data.columns))
183+
error_msg = "Invalid column(s): %s\n" % missing_columns
184+
error_msg += "Valid columns are: %s" % valid_columns
185+
print(error_msg, file=sys.stderr)
186+
return None
187+
176188
data = pd.DataFrame(original_data, columns = columns).dropna()
177189
data = pd.DataFrame(preprocessing.scale(data), columns = data.columns)
178190

@@ -196,3 +208,32 @@ def plot_file(input_filename, output_file, num_rows, columns, spacing, point_siz
196208

197209
points = num_rows if num_rows is not None else len(data.index)
198210
return points
211+
212+
def get_missing_columns(data, columns, category_column):
213+
"""
214+
Returns all of the given columns that are not in the given dataframe.
215+
216+
Parameters
217+
----------
218+
columns : List[str]
219+
the columns to look for
220+
category_column : str
221+
the category column to look for, or None if there is no category column
222+
being used
223+
224+
Returns
225+
-------
226+
missing : List[str]
227+
a list of the missing columns
228+
"""
229+
data_columns = set(data.columns)
230+
231+
if not category_column is None:
232+
columns = columns + [category_column]
233+
234+
missing = []
235+
for col in columns:
236+
if not col in data_columns:
237+
missing.append(col)
238+
239+
return missing

test/test_obj_graph.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from blendplot.obj_graph import *
77

8+
import utilities
9+
810
class TestObjGraph(unittest.TestCase):
911
def test_add_cube_verticies(self):
1012
cube_str = ""
@@ -120,6 +122,52 @@ def test_plot_file(self):
120122

121123
self.assertEquals(actual_output, expected_output)
122124

125+
def test_plot_file_invalid_column(self):
126+
input_filename = "test/resources/data_01.csv"
127+
output_file = io.StringIO()
128+
num_rows = None
129+
invalid_column = "u"
130+
columns = [invalid_column, "b", "c"]
131+
spacing = 0.5
132+
point_size = 0.1
133+
category_column = None
134+
135+
func = lambda x: plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, category_column)
136+
137+
(actual_return, actual_error) = utilities.capture_stderr(func)
138+
expected_error = "Invalid column(s): %s\nValid columns are: a, b, c, d, category\n" % invalid_column
139+
140+
self.assertEquals(actual_error, expected_error)
141+
self.assertEquals(actual_return, None)
142+
143+
actual_output = output_file.getvalue()
144+
expected_output = ""
145+
146+
self.assertEquals(actual_output, expected_output)
147+
148+
def test_plot_file_invalid_column_multiple(self):
149+
input_filename = "test/resources/data_01.csv"
150+
output_file = io.StringIO()
151+
num_rows = None
152+
invalid_columns = ["u", "A"]
153+
columns = ["c"] + invalid_columns
154+
spacing = 0.5
155+
point_size = 0.1
156+
category_column = None
157+
158+
func = lambda x: plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, category_column)
159+
160+
(actual_return, actual_error) = utilities.capture_stderr(func)
161+
expected_error = "Invalid column(s): %s, %s\nValid columns are: a, b, c, d, category\n" % (invalid_columns[0], invalid_columns[1])
162+
163+
self.assertEquals(actual_error, expected_error)
164+
self.assertEquals(actual_return, None)
165+
166+
actual_output = output_file.getvalue()
167+
expected_output = ""
168+
169+
self.assertEquals(actual_output, expected_output)
170+
123171
def test_plot_file_category(self):
124172
input_filename = "test/resources/data_01.csv"
125173
output_file = io.StringIO()
@@ -142,6 +190,28 @@ def test_plot_file_category(self):
142190

143191
self.assertEquals(actual_output, expected_output)
144192

193+
def test_plot_file_category_invalid(self):
194+
input_filename = "test/resources/data_01.csv"
195+
output_file = io.StringIO()
196+
num_rows = 4
197+
columns = ["a", "b", "c"]
198+
spacing = 0.5
199+
point_size = 0.1
200+
invalid_category_column = "cats"
201+
202+
func = lambda x: plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, invalid_category_column)
203+
204+
(actual_return, actual_error) = utilities.capture_stderr(func)
205+
expected_error = "Invalid column(s): %s\nValid columns are: a, b, c, d, category\n" % invalid_category_column
206+
207+
self.assertEquals(actual_error, expected_error)
208+
self.assertEquals(actual_return, None)
209+
210+
actual_output = output_file.getvalue()
211+
expected_output = ""
212+
213+
self.assertEquals(actual_output, expected_output)
214+
145215
@given(
146216
st.text(),
147217
st.floats(allow_nan=False, allow_infinity=False),

test/utilities.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import io
2+
import sys
3+
4+
def capture_stderr(func):
5+
err, sys.stderr = sys.stderr, io.StringIO()
6+
value = None
7+
try:
8+
ret = func(None)
9+
sys.stderr.seek(0)
10+
value = (ret, sys.stderr.read())
11+
finally:
12+
sys.stderr = err
13+
14+
return value

0 commit comments

Comments
 (0)