55
66from blendplot .obj_graph import *
77
8+ import utilities
9+
810class TestObjGraph (unittest .TestCase ):
911 def test_add_cube_verticies (self ):
1012 cube_str = ""
@@ -120,6 +122,52 @@ def test_plot_file(self):
120122
121123 self .assertEquals (actual_output , expected_output )
122124
125+ def test_plot_file_invalid_column (self ):
126+ input_filename = "test/resources/data_01.csv"
127+ output_file = io .StringIO ()
128+ num_rows = None
129+ invalid_column = "u"
130+ columns = [invalid_column , "b" , "c" ]
131+ spacing = 0.5
132+ point_size = 0.1
133+ category_column = None
134+
135+ func = lambda x : plot_file (input_filename , output_file , num_rows , columns , spacing , point_size , category_column )
136+
137+ (actual_return , actual_error ) = utilities .capture_stderr (func )
138+ expected_error = "Invalid column(s): %s\n Valid columns are: a, b, c, d, category\n " % invalid_column
139+
140+ self .assertEquals (actual_error , expected_error )
141+ self .assertEquals (actual_return , None )
142+
143+ actual_output = output_file .getvalue ()
144+ expected_output = ""
145+
146+ self .assertEquals (actual_output , expected_output )
147+
148+ def test_plot_file_invalid_column_multiple (self ):
149+ input_filename = "test/resources/data_01.csv"
150+ output_file = io .StringIO ()
151+ num_rows = None
152+ invalid_columns = ["u" , "A" ]
153+ columns = ["c" ] + invalid_columns
154+ spacing = 0.5
155+ point_size = 0.1
156+ category_column = None
157+
158+ func = lambda x : plot_file (input_filename , output_file , num_rows , columns , spacing , point_size , category_column )
159+
160+ (actual_return , actual_error ) = utilities .capture_stderr (func )
161+ expected_error = "Invalid column(s): %s, %s\n Valid columns are: a, b, c, d, category\n " % (invalid_columns [0 ], invalid_columns [1 ])
162+
163+ self .assertEquals (actual_error , expected_error )
164+ self .assertEquals (actual_return , None )
165+
166+ actual_output = output_file .getvalue ()
167+ expected_output = ""
168+
169+ self .assertEquals (actual_output , expected_output )
170+
123171 def test_plot_file_category (self ):
124172 input_filename = "test/resources/data_01.csv"
125173 output_file = io .StringIO ()
@@ -142,6 +190,28 @@ def test_plot_file_category(self):
142190
143191 self .assertEquals (actual_output , expected_output )
144192
193+ def test_plot_file_category_invalid (self ):
194+ input_filename = "test/resources/data_01.csv"
195+ output_file = io .StringIO ()
196+ num_rows = 4
197+ columns = ["a" , "b" , "c" ]
198+ spacing = 0.5
199+ point_size = 0.1
200+ invalid_category_column = "cats"
201+
202+ func = lambda x : plot_file (input_filename , output_file , num_rows , columns , spacing , point_size , invalid_category_column )
203+
204+ (actual_return , actual_error ) = utilities .capture_stderr (func )
205+ expected_error = "Invalid column(s): %s\n Valid columns are: a, b, c, d, category\n " % invalid_category_column
206+
207+ self .assertEquals (actual_error , expected_error )
208+ self .assertEquals (actual_return , None )
209+
210+ actual_output = output_file .getvalue ()
211+ expected_output = ""
212+
213+ self .assertEquals (actual_output , expected_output )
214+
145215@given (
146216 st .text (),
147217 st .floats (allow_nan = False , allow_infinity = False ),
0 commit comments