diff options
| -rw-r--r-- | internal/tag/tag.go | 16 | ||||
| -rw-r--r-- | lay/strain/solve.go | 247 | ||||
| -rw-r--r-- | lay/strain/solve_test.go | 85 |
3 files changed, 239 insertions, 109 deletions
diff --git a/internal/tag/tag.go b/internal/tag/tag.go index 197d2d6..b4b627d 100644 --- a/internal/tag/tag.go +++ b/internal/tag/tag.go @@ -1,12 +1,22 @@ package tag +import "context" + type Tagged[V, T any] struct { Val V Tag T } -func Tag[V, T any](out chan<- Tagged[V, T], in <-chan V, f func(V) T) { - for val := range in { - out <- Tagged[V, T]{val, f(val)} +func Tag[V, T any](ctx context.Context, out chan<- Tagged[V, T], in <-chan V, f func(V) T) { + for { + select { + case val, ok := <-in: + if !ok { + return + } + out <- Tagged[V, T]{val, f(val)} + case <-ctx.Done(): + return + } } } diff --git a/lay/strain/solve.go b/lay/strain/solve.go index cfb1f41..41cabc5 100644 --- a/lay/strain/solve.go +++ b/lay/strain/solve.go @@ -1,8 +1,10 @@ package strain import ( + "context" "fmt" "image" + "math" "sync" "github.com/lithdew/casso" @@ -13,8 +15,9 @@ import ( ) const ( - fieldConstraintPriority = casso.Medium - layoutConstraintPriority = casso.Medium + solverPriority = casso.Medium + layoutPriority = casso.Medium + fieldPriority = casso.Medium ) // Solver uses the Cassowary algorithm to partition a rectangle among @@ -27,18 +30,13 @@ const ( // and size of fields within the container, while the field // constraints control the width and height of each field. type Solver struct { - solver *casso.Solver - style *style.Style - fieldConstrs chan tag.Tagged[Constraint, fieldIndex] // incoming constraints from fields - // External symbols container SymRect // position and size of container fields []SymRect // position and size of each field - // Constraint ID symbols - fieldSizeConstrs []sizeConstraintSymbols - - mu sync.Mutex + fieldConstrs chan tag.Tagged[Constraint, fieldIndex] + layoutConstrs chan constrainRequest + solveReqs chan solveRequest } type fieldIndex int @@ -56,6 +54,24 @@ type sizeConstraintSymbols struct { // Nil means no constraint. } +type constrainRequest struct { + casso.Priority + casso.Op + constant float64 + terms []casso.Term + res chan<- error +} + +type solveRequest struct { + container image.Rectangle + res chan<- solveResponse +} + +type solveResponse struct { + fields []image.Rectangle + error +} + // NewSolver creates a Solver that can be used to resolve constraints // received from the given channels: one channel per field in the // layout. These are generally the receiving side of some Envs' @@ -66,100 +82,114 @@ func NewSolver(styl *style.Style, constraints []<-chan Constraint) (*Solver, err fieldConstrs := make(chan tag.Tagged[Constraint, fieldIndex]) var wg sync.WaitGroup wg.Add(nfields) - for i, cs := range constraints { + ctx, cancel := context.WithCancel(context.Background()) + for i := range fields { fields[i] = NewSymRect() go func() { // Tag incoming field constraints by field index and multiplex them into fieldConstrs - tag.Tag(fieldConstrs, cs, func(c Constraint) fieldIndex { + tag.Tag(ctx, fieldConstrs, constraints[i], func(c Constraint) fieldIndex { return fieldIndex(i) }) wg.Done() }() } - go func() { wg.Wait() // once all fields close their constraints channels, - close(fieldConstrs) + cancel() // destroy the solver }() solver := casso.NewSolver() container := NewSymRect() if err := editRect(solver, container, casso.Strong); err != nil { + cancel() return nil, fmt.Errorf("error marking container symbol as editable: %w", err) } s := &Solver{ - solver, - styl, - fieldConstrs, container, fields, - make([]sizeConstraintSymbols, nfields), - sync.Mutex{}, + fieldConstrs, + make(chan constrainRequest), + make(chan solveRequest), + } + go s.run(ctx, solver, styl) + if err := s.addDefaultConstraints(); err != nil { + cancel() + return nil, fmt.Errorf("error adding default constraint: %w", err) } + return s, nil +} - go func() { - for tc := range fieldConstrs { - constr, fieldIdx := tc.Val, tc.Tag - s.mu.Lock() - err := s.addFieldSizeConstraint(constr, fieldIdx) - s.mu.Unlock() +func (s *Solver) run(ctx context.Context, solver *casso.Solver, styl *style.Style) { + defer close(s.fieldConstrs) + defer close(s.layoutConstrs) + defer close(s.solveReqs) + + fieldSizeConstrs := make([]sizeConstraintSymbols, len(s.fields)) + for { + select { + case tc, ok := <-s.fieldConstrs: + if !ok { + return + } + constr, i := tc.Val, tc.Tag + err := addFieldSizeConstraint(solver, styl, constr, s.fields[i], &fieldSizeConstrs[i]) if err != nil { log.Err.Printf("error adding layout constraint %#v from field %d: %v\n", - constr, fieldIdx, err) + constr, i, err) } - } - }() - if err := s.addDefaultConstraints(); err != nil { - return nil, fmt.Errorf("error adding default constraint: %w", err) - } + case req := <-s.layoutConstrs: + _, err := addConstraint(solver, req.Priority, req.Op, req.constant, req.terms...) + req.res <- err - return s, nil + case req := <-s.solveReqs: + fields, err := s.solve(solver, req.container) + req.res <- solveResponse{fields, err} + + case <-ctx.Done(): + return + } + } } func (s *Solver) addDefaultConstraints() error { for _, field := range s.fields { - if err := s.AddConstraintPt(casso.GTE, field.Origin, s.container.Origin); err != nil { + if err := s.addConstraintPt(solverPriority, casso.GTE, field.Origin, s.container.Origin); err != nil { return err } - if err := s.AddConstraintPt(casso.LTE, field.Size, s.container.Size); err != nil { + if err := s.addConstraintPt(solverPriority, casso.LTE, field.Size, s.container.Size); err != nil { return err } } return nil } -// addSizeConstraint adds or modifies a constraint on the size of a -// field, removing mutually exclusive constraints. -// -// Solver.mu must be held. -func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { - fieldSize := s.fields[i].Size - fieldConstrs := &s.fieldSizeConstrs[i] - +// addFieldSizeConstraint adds or modifies a constraint on the size of +// a field, removing mutually exclusive constraints. +func addFieldSizeConstraint(s *casso.Solver, styl *style.Style, constr Constraint, field SymRect, fieldConstrs *sizeConstraintSymbols) error { // Clear mutually exclusive constraints and replace with new one switch constr.Dimension { case Width: - width := float64(s.style.Pixels(constr.Value).Round()) + width := float64(styl.Pixels(constr.Value).Round()) switch constr.Op { case casso.EQ: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } fieldConstrs.widthEq = &c case casso.GTE: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } fieldConstrs.widthGte = &c case casso.LTE: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthLte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } @@ -168,25 +198,25 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { panic(fmt.Sprintf("unreachable: impossible %T: %v", constr.Op, constr.Op)) } case Height: - height := float64(s.style.Pixels(constr.Value).Round()) + height := float64(styl.Pixels(constr.Value).Round()) switch constr.Op { case casso.EQ: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } fieldConstrs.heightEq = &c case casso.GTE: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } fieldConstrs.heightGte = &c case casso.LTE: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightLte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } @@ -201,11 +231,10 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { return nil } -// Solver.mu must be held. -func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error { +func removeConstraints(s *casso.Solver, constrs ...*casso.Symbol) error { for _, constr := range constrs { if constr != nil { - if err := s.solver.RemoveConstraint(*constr); err != nil { + if err := s.RemoveConstraint(*constr); err != nil { return err } } @@ -213,11 +242,30 @@ func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error { return nil } -// Solver.mu must be held. -func (s *Solver) addFieldConstraint(op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) { - return s.solver.AddConstraintWithPriority( - fieldConstraintPriority, - casso.NewConstraint(op, constant, terms...)) +func addConstraint(s *casso.Solver, priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) { + return s.AddConstraintWithPriority(priority, casso.NewConstraint(op, constant, terms...)) +} + +func (s *Solver) solve(solver *casso.Solver, container image.Rectangle) (fields []image.Rectangle, err error) { + if err := suggestRect(solver, s.container, container); err != nil { + return nil, err + } + + fields = make([]image.Rectangle, len(s.fields)) + for i, field := range s.fields { + min := image.Pt( + s.val(solver, field.Origin.X), + s.val(solver, field.Origin.Y)) + max := min.Add(image.Pt( + s.val(solver, field.Size.X), + s.val(solver, field.Size.Y))) + fields[i] = image.Rectangle{min, max} + } + return fields, nil +} + +func (s *Solver) val(solver *casso.Solver, sym casso.Symbol) int { + return int(math.Round(solver.Val(sym))) } // Container returns the Cassowary symbols representing the layout @@ -228,24 +276,48 @@ func (s *Solver) Container() SymRect { return s.container } // position and size. func (s *Solver) Field(i int) SymRect { return s.fields[i] } -// AddConstraint imposes a constraint between two symbols. The symbols -// may be aspects of the Container() or Field()s. +// AddConstraint imposes a constraint on a set of terms: +// +// constant + ∑terms op 0 +// +// The terms can be obtained by calling the T(coeff) method of the +// symbols returned by Container() or Field(n). Think of terms with +// positive coefficients as being on the LHS and terms with negative +// coefficients as being on the RHS. // // Once added, a constraint cannot be removed or modified. -func (s *Solver) AddConstraint(op casso.Op, lhs, rhs casso.Symbol) error { - s.mu.Lock() - defer s.mu.Unlock() - _, err := s.solver.AddConstraintWithPriority(layoutConstraintPriority, casso.NewConstraint(op, 0, lhs.T(1.0), rhs.T(-1.0))) - return err +func (s *Solver) AddConstraint(op casso.Op, constant float64, terms ...casso.Term) error { + return s.addConstraint(layoutPriority, op, constant, terms...) } // AddConstraintPt imposes a constraint between two point symbols: // (lhs.X op rhs.X) and (lhs.Y op rhs.Y). func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error { - if err := s.AddConstraint(op, lhs.X, rhs.X); err != nil { + return s.addConstraintPt(layoutPriority, op, lhs, rhs) +} + +func (s *Solver) addConstraint(priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) error { + resc := make(chan error) + defer close(resc) + s.layoutConstrs <- constrainRequest{priority, op, constant, terms, resc} + return <-resc +} + +func (s *Solver) addConstraintPt(priority casso.Priority, op casso.Op, lhs, rhs SymPt) error { + if err := s.addConstraint(priority, op, 0, lhs.X.T(1.0), rhs.X.T(-1.0)); err != nil { return err } - if err := s.AddConstraint(op, lhs.Y, rhs.Y); err != nil { + if err := s.addConstraint(priority, op, 0, lhs.Y.T(1.0), rhs.Y.T(-1.0)); err != nil { + return err + } + return nil +} + +func (s *Solver) addSizeConstraint(priority casso.Priority, op casso.Op, lhs SymPt, rhs image.Point) error { + if err := s.addConstraint(priority, op, -float64(rhs.X), lhs.X.T(1.0)); err != nil { + return err + } + if err := s.addConstraint(priority, op, -float64(rhs.Y), lhs.Y.T(1.0)); err != nil { return err } return nil @@ -257,27 +329,10 @@ func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error { // It returns a slice of Rectangles, one per field. The slice has the // same length as the slice of Constraint channels that were passed to // NewSolver(). -func (s *Solver) Solve(container image.Rectangle) ([]image.Rectangle, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if err := suggestRect(s.solver, s.container, container); err != nil { - return nil, err - } - - fields := make([]image.Rectangle, len(s.fields)) - for i := range fields { - field := s.fields[i] - min := image.Pt( - s.intVal(field.Origin.X), - s.intVal(field.Origin.Y)) - max := min.Add(image.Pt( - s.intVal(field.Size.X), - s.intVal(field.Size.Y))) - fields[i] = image.Rectangle{min, max} - } - return fields, nil +func (s *Solver) Solve(container image.Rectangle) (fields []image.Rectangle, err error) { + resc := make(chan solveResponse) + defer close(resc) + s.solveReqs <- solveRequest{container, resc} + res := <-resc + return res.fields, res.error } - -// Solver.mu must be held. -func (s *Solver) intVal(sym casso.Symbol) int { return int(s.solver.Val(sym)) } diff --git a/lay/strain/solve_test.go b/lay/strain/solve_test.go index 6281124..4c0cc83 100644 --- a/lay/strain/solve_test.go +++ b/lay/strain/solve_test.go @@ -80,8 +80,8 @@ func TestSingleField(t *testing.T) { // Setup constraints := make(chan strain.Constraint) st := newSolverTest(t, []<-chan strain.Constraint{constraints}) - defer st.Close() defer close(constraints) + defer st.Close() // Add layout constraints container := st.Solver.Container() @@ -107,10 +107,10 @@ func TestFieldMinSize(t *testing.T) { // Setup constraints := make(chan strain.Constraint) st := newSolverTest(t, []<-chan strain.Constraint{constraints}) - defer st.Close() defer close(constraints) + defer st.Close() - // Add widget constraints + // Add field constraints minWidth := unit.Value{32, unit.Ch} minHeight := unit.Value{1.5, unit.Em} constraints <- strain.Constraint{strain.Width, casso.GTE, minWidth} @@ -140,26 +140,91 @@ func TestFieldMinSizeLargerThanContainer(t *testing.T) { // Setup constraints := make(chan strain.Constraint) st := newSolverTest(t, []<-chan strain.Constraint{constraints}) - defer st.Close() defer close(constraints) + defer st.Close() - // Add widget constraints + // Add field constraints constraints <- strain.Constraint{strain.Width, casso.GTE, unit.Value{200, unit.Px}} constraints <- strain.Constraint{strain.Height, casso.GTE, unit.Value{300, unit.Px}} - synctest.Wait() // Solve container := image.Rect(12, 34, 100, 200) + synctest.Wait() st.solve(container, validateEq([]image.Rectangle{container})) }) } -// Solver with only layout constaints, no field constraints. -func TestLayConstrs(t *testing.T) { +// Fields arranged as rows. +func TestRows(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + // Setup + nrows := 8 + constraintss := make([]chan strain.Constraint, nrows) + for i := range constraintss { + constraintss[i] = make(chan strain.Constraint) + } + st := newSolverTest(t, castRx(constraintss)) + defer func() { + for _, c := range constraintss { + close(c) + } + st.Close() + }() + + // Add layout constraints + require.NoError(t, st.Solver.AddConstraintPt(casso.EQ, st.Solver.Field(0).Origin, st.Solver.Container().Origin)) // start from top left corner + fieldHeights := make([]casso.Term, nrows) + for i := 0; i < nrows; i++ { + fieldHeights[i] = st.Field(i).Size.Y.T(1) + container := st.Solver.Container() + f := st.Solver.Field(i) + require.NoError(t, st.Solver.AddConstraint(casso.EQ, 0, f.Size.X.T(1), container.Size.X.T(-1))) // span full width + } + terms := append(fieldHeights, st.Solver.Container().Size.Y.T(-1)) // ∑field[i].height <= container.height + require.NoError(t, st.Solver.AddConstraint(casso.LTE, 0, terms...)) + for i := 1; i < nrows; i++ { + f0, f1 := st.Solver.Field(i-1), st.Solver.Field(i) + require.NoError(t, st.Solver.AddConstraint(casso.EQ, 0, f1.Origin.Y.T(1), f0.Origin.Y.T(-1), f0.Size.Y.T(-1))) // in order + require.NoError(t, st.Solver.AddConstraintPt(casso.EQ, f1.Size, f0.Size)) // same size + } - st := newSolverTest(t, nil) - defer st.Close() + // Add field constraints + var rowHeight int + for i, c := range constraintss { + rowHeight = 2 * i + c <- strain.Constraint{strain.Width, casso.GTE, unit.Value{float64(16 + i), unit.Ch}} + c <- strain.Constraint{strain.Height, casso.GTE, unit.Value{float64(rowHeight), unit.Px}} + } + + // Solve + container := image.Rect(123, 234, 567, 678) + synctest.Wait() + fields, err := st.Solver.Solve(container) + if err != nil { + t.Fatal(err) + } + require.EqualValues(t, nrows, len(fields), "wrong number of fields") + for _, field := range fields { + t.Logf("%v (%d %d)\n", field, field.Dx(), field.Dy()) + } + for i, field := range fields { + require.Equal(t, container.Min.X, field.Min.X, "not left-aligned with container") + require.Equal(t, container.Min.Y+i*rowHeight, field.Min.Y, "wrong y position") + require.Equal(t, container.Dx(), field.Dx(), "wrong width") + require.Equal(t, rowHeight, field.Dy(), "wrong height") + } + }) +} +func TestSolver(t *testing.T) { t.Fail() // TODO: more tests } + +func castRx[T any](cs []chan T) []<-chan T { + rcs := make([]<-chan T, len(cs)) + for i, c := range cs { + rcs[i] = c + } + return rcs +} |