diff options
Diffstat (limited to 'lay/strain/solve.go')
| -rw-r--r-- | lay/strain/solve.go | 247 |
1 files changed, 151 insertions, 96 deletions
diff --git a/lay/strain/solve.go b/lay/strain/solve.go index cfb1f41..41cabc5 100644 --- a/lay/strain/solve.go +++ b/lay/strain/solve.go @@ -1,8 +1,10 @@ package strain import ( + "context" "fmt" "image" + "math" "sync" "github.com/lithdew/casso" @@ -13,8 +15,9 @@ import ( ) const ( - fieldConstraintPriority = casso.Medium - layoutConstraintPriority = casso.Medium + solverPriority = casso.Medium + layoutPriority = casso.Medium + fieldPriority = casso.Medium ) // Solver uses the Cassowary algorithm to partition a rectangle among @@ -27,18 +30,13 @@ const ( // and size of fields within the container, while the field // constraints control the width and height of each field. type Solver struct { - solver *casso.Solver - style *style.Style - fieldConstrs chan tag.Tagged[Constraint, fieldIndex] // incoming constraints from fields - // External symbols container SymRect // position and size of container fields []SymRect // position and size of each field - // Constraint ID symbols - fieldSizeConstrs []sizeConstraintSymbols - - mu sync.Mutex + fieldConstrs chan tag.Tagged[Constraint, fieldIndex] + layoutConstrs chan constrainRequest + solveReqs chan solveRequest } type fieldIndex int @@ -56,6 +54,24 @@ type sizeConstraintSymbols struct { // Nil means no constraint. } +type constrainRequest struct { + casso.Priority + casso.Op + constant float64 + terms []casso.Term + res chan<- error +} + +type solveRequest struct { + container image.Rectangle + res chan<- solveResponse +} + +type solveResponse struct { + fields []image.Rectangle + error +} + // NewSolver creates a Solver that can be used to resolve constraints // received from the given channels: one channel per field in the // layout. These are generally the receiving side of some Envs' @@ -66,100 +82,114 @@ func NewSolver(styl *style.Style, constraints []<-chan Constraint) (*Solver, err fieldConstrs := make(chan tag.Tagged[Constraint, fieldIndex]) var wg sync.WaitGroup wg.Add(nfields) - for i, cs := range constraints { + ctx, cancel := context.WithCancel(context.Background()) + for i := range fields { fields[i] = NewSymRect() go func() { // Tag incoming field constraints by field index and multiplex them into fieldConstrs - tag.Tag(fieldConstrs, cs, func(c Constraint) fieldIndex { + tag.Tag(ctx, fieldConstrs, constraints[i], func(c Constraint) fieldIndex { return fieldIndex(i) }) wg.Done() }() } - go func() { wg.Wait() // once all fields close their constraints channels, - close(fieldConstrs) + cancel() // destroy the solver }() solver := casso.NewSolver() container := NewSymRect() if err := editRect(solver, container, casso.Strong); err != nil { + cancel() return nil, fmt.Errorf("error marking container symbol as editable: %w", err) } s := &Solver{ - solver, - styl, - fieldConstrs, container, fields, - make([]sizeConstraintSymbols, nfields), - sync.Mutex{}, + fieldConstrs, + make(chan constrainRequest), + make(chan solveRequest), + } + go s.run(ctx, solver, styl) + if err := s.addDefaultConstraints(); err != nil { + cancel() + return nil, fmt.Errorf("error adding default constraint: %w", err) } + return s, nil +} - go func() { - for tc := range fieldConstrs { - constr, fieldIdx := tc.Val, tc.Tag - s.mu.Lock() - err := s.addFieldSizeConstraint(constr, fieldIdx) - s.mu.Unlock() +func (s *Solver) run(ctx context.Context, solver *casso.Solver, styl *style.Style) { + defer close(s.fieldConstrs) + defer close(s.layoutConstrs) + defer close(s.solveReqs) + + fieldSizeConstrs := make([]sizeConstraintSymbols, len(s.fields)) + for { + select { + case tc, ok := <-s.fieldConstrs: + if !ok { + return + } + constr, i := tc.Val, tc.Tag + err := addFieldSizeConstraint(solver, styl, constr, s.fields[i], &fieldSizeConstrs[i]) if err != nil { log.Err.Printf("error adding layout constraint %#v from field %d: %v\n", - constr, fieldIdx, err) + constr, i, err) } - } - }() - if err := s.addDefaultConstraints(); err != nil { - return nil, fmt.Errorf("error adding default constraint: %w", err) - } + case req := <-s.layoutConstrs: + _, err := addConstraint(solver, req.Priority, req.Op, req.constant, req.terms...) + req.res <- err - return s, nil + case req := <-s.solveReqs: + fields, err := s.solve(solver, req.container) + req.res <- solveResponse{fields, err} + + case <-ctx.Done(): + return + } + } } func (s *Solver) addDefaultConstraints() error { for _, field := range s.fields { - if err := s.AddConstraintPt(casso.GTE, field.Origin, s.container.Origin); err != nil { + if err := s.addConstraintPt(solverPriority, casso.GTE, field.Origin, s.container.Origin); err != nil { return err } - if err := s.AddConstraintPt(casso.LTE, field.Size, s.container.Size); err != nil { + if err := s.addConstraintPt(solverPriority, casso.LTE, field.Size, s.container.Size); err != nil { return err } } return nil } -// addSizeConstraint adds or modifies a constraint on the size of a -// field, removing mutually exclusive constraints. -// -// Solver.mu must be held. -func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { - fieldSize := s.fields[i].Size - fieldConstrs := &s.fieldSizeConstrs[i] - +// addFieldSizeConstraint adds or modifies a constraint on the size of +// a field, removing mutually exclusive constraints. +func addFieldSizeConstraint(s *casso.Solver, styl *style.Style, constr Constraint, field SymRect, fieldConstrs *sizeConstraintSymbols) error { // Clear mutually exclusive constraints and replace with new one switch constr.Dimension { case Width: - width := float64(s.style.Pixels(constr.Value).Round()) + width := float64(styl.Pixels(constr.Value).Round()) switch constr.Op { case casso.EQ: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } fieldConstrs.widthEq = &c case casso.GTE: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } fieldConstrs.widthGte = &c case casso.LTE: - s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthLte) - c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0)) + removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0)) if err != nil { return err } @@ -168,25 +198,25 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { panic(fmt.Sprintf("unreachable: impossible %T: %v", constr.Op, constr.Op)) } case Height: - height := float64(s.style.Pixels(constr.Value).Round()) + height := float64(styl.Pixels(constr.Value).Round()) switch constr.Op { case casso.EQ: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } fieldConstrs.heightEq = &c case casso.GTE: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } fieldConstrs.heightGte = &c case casso.LTE: - s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightLte) - c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0)) + removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightLte) + c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0)) if err != nil { return err } @@ -201,11 +231,10 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error { return nil } -// Solver.mu must be held. -func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error { +func removeConstraints(s *casso.Solver, constrs ...*casso.Symbol) error { for _, constr := range constrs { if constr != nil { - if err := s.solver.RemoveConstraint(*constr); err != nil { + if err := s.RemoveConstraint(*constr); err != nil { return err } } @@ -213,11 +242,30 @@ func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error { return nil } -// Solver.mu must be held. -func (s *Solver) addFieldConstraint(op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) { - return s.solver.AddConstraintWithPriority( - fieldConstraintPriority, - casso.NewConstraint(op, constant, terms...)) +func addConstraint(s *casso.Solver, priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) { + return s.AddConstraintWithPriority(priority, casso.NewConstraint(op, constant, terms...)) +} + +func (s *Solver) solve(solver *casso.Solver, container image.Rectangle) (fields []image.Rectangle, err error) { + if err := suggestRect(solver, s.container, container); err != nil { + return nil, err + } + + fields = make([]image.Rectangle, len(s.fields)) + for i, field := range s.fields { + min := image.Pt( + s.val(solver, field.Origin.X), + s.val(solver, field.Origin.Y)) + max := min.Add(image.Pt( + s.val(solver, field.Size.X), + s.val(solver, field.Size.Y))) + fields[i] = image.Rectangle{min, max} + } + return fields, nil +} + +func (s *Solver) val(solver *casso.Solver, sym casso.Symbol) int { + return int(math.Round(solver.Val(sym))) } // Container returns the Cassowary symbols representing the layout @@ -228,24 +276,48 @@ func (s *Solver) Container() SymRect { return s.container } // position and size. func (s *Solver) Field(i int) SymRect { return s.fields[i] } -// AddConstraint imposes a constraint between two symbols. The symbols -// may be aspects of the Container() or Field()s. +// AddConstraint imposes a constraint on a set of terms: +// +// constant + ∑terms op 0 +// +// The terms can be obtained by calling the T(coeff) method of the +// symbols returned by Container() or Field(n). Think of terms with +// positive coefficients as being on the LHS and terms with negative +// coefficients as being on the RHS. // // Once added, a constraint cannot be removed or modified. -func (s *Solver) AddConstraint(op casso.Op, lhs, rhs casso.Symbol) error { - s.mu.Lock() - defer s.mu.Unlock() - _, err := s.solver.AddConstraintWithPriority(layoutConstraintPriority, casso.NewConstraint(op, 0, lhs.T(1.0), rhs.T(-1.0))) - return err +func (s *Solver) AddConstraint(op casso.Op, constant float64, terms ...casso.Term) error { + return s.addConstraint(layoutPriority, op, constant, terms...) } // AddConstraintPt imposes a constraint between two point symbols: // (lhs.X op rhs.X) and (lhs.Y op rhs.Y). func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error { - if err := s.AddConstraint(op, lhs.X, rhs.X); err != nil { + return s.addConstraintPt(layoutPriority, op, lhs, rhs) +} + +func (s *Solver) addConstraint(priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) error { + resc := make(chan error) + defer close(resc) + s.layoutConstrs <- constrainRequest{priority, op, constant, terms, resc} + return <-resc +} + +func (s *Solver) addConstraintPt(priority casso.Priority, op casso.Op, lhs, rhs SymPt) error { + if err := s.addConstraint(priority, op, 0, lhs.X.T(1.0), rhs.X.T(-1.0)); err != nil { return err } - if err := s.AddConstraint(op, lhs.Y, rhs.Y); err != nil { + if err := s.addConstraint(priority, op, 0, lhs.Y.T(1.0), rhs.Y.T(-1.0)); err != nil { + return err + } + return nil +} + +func (s *Solver) addSizeConstraint(priority casso.Priority, op casso.Op, lhs SymPt, rhs image.Point) error { + if err := s.addConstraint(priority, op, -float64(rhs.X), lhs.X.T(1.0)); err != nil { + return err + } + if err := s.addConstraint(priority, op, -float64(rhs.Y), lhs.Y.T(1.0)); err != nil { return err } return nil @@ -257,27 +329,10 @@ func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error { // It returns a slice of Rectangles, one per field. The slice has the // same length as the slice of Constraint channels that were passed to // NewSolver(). -func (s *Solver) Solve(container image.Rectangle) ([]image.Rectangle, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if err := suggestRect(s.solver, s.container, container); err != nil { - return nil, err - } - - fields := make([]image.Rectangle, len(s.fields)) - for i := range fields { - field := s.fields[i] - min := image.Pt( - s.intVal(field.Origin.X), - s.intVal(field.Origin.Y)) - max := min.Add(image.Pt( - s.intVal(field.Size.X), - s.intVal(field.Size.Y))) - fields[i] = image.Rectangle{min, max} - } - return fields, nil +func (s *Solver) Solve(container image.Rectangle) (fields []image.Rectangle, err error) { + resc := make(chan solveResponse) + defer close(resc) + s.solveReqs <- solveRequest{container, resc} + res := <-resc + return res.fields, res.error } - -// Solver.mu must be held. -func (s *Solver) intVal(sym casso.Symbol) int { return int(s.solver.Val(sym)) } |