Skip to content

Commit 309800d

Browse files
fixed constraints interface
1 parent 1d10086 commit 309800d

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

batchglm/data.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,16 @@ def constraint_system_from_star(
217217
sample_description=sample_description,
218218
formula=formula,
219219
as_categorical=as_categorical,
220-
constraints=constraints
220+
constraints=constraints,
221+
return_type="dataframe"
221222
)
222223
elif isinstance(constraints, tuple) or isinstance(constraints, list):
223224
cmat = constraint_matrix_from_string(
224225
dmat=dmat,
225226
constraints=constraints
226227
)
227228
elif isinstance(constraints, np.ndarray):
228-
cmat = parse_constraints
229+
cmat = constraints
229230
elif constraints is None:
230231
cmat = None
231232
else:
@@ -238,7 +239,8 @@ def constraint_matrix_from_dict(
238239
sample_description: pd.DataFrame,
239240
formula: str,
240241
as_categorical: Union[bool, list] = True,
241-
constraints: dict = {}
242+
constraints: dict = {},
243+
return_type: str = "dataframe"
242244
) -> Tuple:
243245
"""
244246
Create a design matrix from some sample description and a constraint matrix
@@ -303,9 +305,14 @@ def constraint_matrix_from_dict(
303305
# Build constraint matrix.
304306
constraints_ar = constraint_matrix_from_string(
305307
dmat=dmat,
308+
coef_names=coef_names,
306309
constraints=constraints_ls
307310
)
308311

312+
# Format return type
313+
if return_type == "dataframe":
314+
dmat = pd.DataFrame(dmat, columns=coef_names)
315+
309316
return dmat, constraints_ar
310317

311318

@@ -362,6 +369,7 @@ def string_constraints_from_dict(
362369

363370
def constraint_matrix_from_string(
364371
dmat: np.ndarray,
372+
coef_names: list,
365373
constraints: Union[Tuple[str, str], List[str]]
366374
):
367375
r"""
@@ -375,10 +383,10 @@ def constraint_matrix_from_string(
375383
"""
376384
assert len(constraints) > 0, "supply constraints"
377385

378-
n_par_all = dmat.values.shape[1]
386+
n_par_all = dmat.shape[1]
379387
n_par_free = n_par_all - len(constraints)
380388

381-
di = patsy.DesignInfo(dmat.coords["design_params"].values)
389+
di = patsy.DesignInfo(coef_names)
382390
constraint_ls = [di.linear_constraint(x).coefs[0] for x in constraints]
383391
idx_constr = np.asarray([np.where(x == 1)[0][0] for x in constraint_ls])
384392
idx_depending = [np.where(x == 1)[0][1:] for x in constraint_ls]

0 commit comments

Comments
 (0)